Browse Source

formatting: run precommit on all files

master
lanvent 1 year ago
parent
commit
618c94edb8
40 changed files with 228 additions and 646 deletions
  1. +1
    -0
      app.py
  2. +2
    -10
      bot/baidu/baidu_unit_bot.py
  3. +7
    -22
      bot/chatgpt/chat_gpt_bot.py
  4. +5
    -17
      bot/chatgpt/chat_gpt_session.py
  5. +7
    -21
      bot/openai/open_ai_bot.py
  6. +2
    -8
      bot/openai/open_ai_image.py
  7. +3
    -13
      bot/openai/open_ai_session.py
  8. +4
    -16
      bot/session_manager.py
  9. +1
    -3
      bridge/context.py
  10. +28
    -103
      channel/chat_channel.py
  11. +1
    -3
      channel/terminal/terminal_channel.py
  12. +8
    -31
      channel/wechat/wechat_channel.py
  13. +6
    -20
      channel/wechat/wechat_message.py
  14. +5
    -15
      channel/wechat/wechaty_channel.py
  15. +3
    -9
      channel/wechat/wechaty_message.py
  16. +8
    -17
      channel/wechatmp/active_reply.py
  17. +5
    -3
      channel/wechatmp/common.py
  18. +18
    -36
      channel/wechatmp/passive_reply.py
  19. +20
    -27
      channel/wechatmp/wechatmp_channel.py
  20. +10
    -11
      channel/wechatmp/wechatmp_client.py
  21. +5
    -14
      channel/wechatmp/wechatmp_message.py
  22. +3
    -11
      common/time_check.py
  23. +1
    -3
      config.py
  24. +3
    -9
      plugins/banwords/banwords.py
  25. +7
    -32
      plugins/bdunit/bdunit.py
  26. +2
    -8
      plugins/dungeon/dungeon.py
  27. +7
    -23
      plugins/godcmd/godcmd.py
  28. +2
    -6
      plugins/hello/hello.py
  29. +2
    -2
      plugins/keyword/config.json.template
  30. +1
    -3
      plugins/keyword/keyword.py
  31. +14
    -46
      plugins/plugin_manager.py
  32. +7
    -24
      plugins/role/role.py
  33. +1
    -1
      plugins/tool/README.md
  34. +4
    -10
      plugins/tool/tool.py
  35. +5
    -15
      voice/audio_convert.py
  36. +9
    -33
      voice/azure/azure_voice.py
  37. +1
    -3
      voice/baidu/baidu_voice.py
  38. +2
    -8
      voice/google/google_voice.py
  39. +1
    -5
      voice/openai/openai_voice.py
  40. +7
    -5
      voice/pytts/pytts_voice.py

+ 1
- 0
app.py View File

@@ -19,6 +19,7 @@ def sigterm_handler_wrap(_signo):
if callable(old_handler): # check old_handler if callable(old_handler): # check old_handler
return old_handler(_signo, _stack_frame) return old_handler(_signo, _stack_frame)
sys.exit(0) sys.exit(0)

signal.signal(_signo, func) signal.signal(_signo, func)






+ 2
- 10
bot/baidu/baidu_unit_bot.py View File

@@ -10,10 +10,7 @@ from bridge.reply import Reply, ReplyType
class BaiduUnitBot(Bot): class BaiduUnitBot(Bot):
def reply(self, query, context=None): def reply(self, query, context=None):
token = self.get_token() token = self.get_token()
url = (
"https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token="
+ token
)
url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + token
post_data = ( post_data = (
'{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"' '{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"'
+ query + query
@@ -32,12 +29,7 @@ class BaiduUnitBot(Bot):
def get_token(self): def get_token(self):
access_key = "YOUR_ACCESS_KEY" access_key = "YOUR_ACCESS_KEY"
secret_key = "YOUR_SECRET_KEY" secret_key = "YOUR_SECRET_KEY"
host = (
"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id="
+ access_key
+ "&client_secret="
+ secret_key
)
host = "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" + access_key + "&client_secret=" + secret_key
response = requests.get(host) response = requests.get(host)
if response: if response:
print(response.json()) print(response.json())


+ 7
- 22
bot/chatgpt/chat_gpt_bot.py View File

@@ -30,23 +30,15 @@ class ChatGPTBot(Bot, OpenAIImage):
if conf().get("rate_limit_chatgpt"): if conf().get("rate_limit_chatgpt"):
self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20)) self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))


self.sessions = SessionManager(
ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo"
)
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
self.args = { self.args = {
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称 "model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 "temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
# "max_tokens":4096, # 回复最大的字符数 # "max_tokens":4096, # 回复最大的字符数
"top_p": 1, "top_p": 1,
"frequency_penalty": conf().get(
"frequency_penalty", 0.0
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"presence_penalty": conf().get(
"presence_penalty", 0.0
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"request_timeout": conf().get(
"request_timeout", None
), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试 "timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
} }


@@ -87,15 +79,10 @@ class ChatGPTBot(Bot, OpenAIImage):
reply_content["completion_tokens"], reply_content["completion_tokens"],
) )
) )
if (
reply_content["completion_tokens"] == 0
and len(reply_content["content"]) > 0
):
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
reply = Reply(ReplyType.ERROR, reply_content["content"]) reply = Reply(ReplyType.ERROR, reply_content["content"])
elif reply_content["completion_tokens"] > 0: elif reply_content["completion_tokens"] > 0:
self.sessions.session_reply(
reply_content["content"], session_id, reply_content["total_tokens"]
)
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
reply = Reply(ReplyType.TEXT, reply_content["content"]) reply = Reply(ReplyType.TEXT, reply_content["content"])
else: else:
reply = Reply(ReplyType.ERROR, reply_content["content"]) reply = Reply(ReplyType.ERROR, reply_content["content"])
@@ -126,9 +113,7 @@ class ChatGPTBot(Bot, OpenAIImage):
if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token(): if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded") raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
# if api_key == None, the default openai.api_key will be used # if api_key == None, the default openai.api_key will be used
response = openai.ChatCompletion.create(
api_key=api_key, messages=session.messages, **self.args
)
response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **self.args)
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
return { return {
"total_tokens": response["usage"]["total_tokens"], "total_tokens": response["usage"]["total_tokens"],


+ 5
- 17
bot/chatgpt/chat_gpt_session.py View File

@@ -25,9 +25,7 @@ class ChatGPTSession(Session):
precise = False precise = False
if cur_tokens is None: if cur_tokens is None:
raise e raise e
logger.debug(
"Exception when counting tokens precisely for query: {}".format(e)
)
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
while cur_tokens > max_tokens: while cur_tokens > max_tokens:
if len(self.messages) > 2: if len(self.messages) > 2:
self.messages.pop(1) self.messages.pop(1)
@@ -39,16 +37,10 @@ class ChatGPTSession(Session):
cur_tokens = cur_tokens - max_tokens cur_tokens = cur_tokens - max_tokens
break break
elif len(self.messages) == 2 and self.messages[1]["role"] == "user": elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
logger.warn(
"user message exceed max_tokens. total_tokens={}".format(cur_tokens)
)
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
break break
else: else:
logger.debug(
"max_tokens={}, total_tokens={}, len(messages)={}".format(
max_tokens, cur_tokens, len(self.messages)
)
)
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
break break
if precise: if precise:
cur_tokens = self.calc_tokens() cur_tokens = self.calc_tokens()
@@ -75,17 +67,13 @@ def num_tokens_from_messages(messages, model):
elif model == "gpt-4": elif model == "gpt-4":
return num_tokens_from_messages(messages, model="gpt-4-0314") return num_tokens_from_messages(messages, model="gpt-4-0314")
elif model == "gpt-3.5-turbo-0301": elif model == "gpt-3.5-turbo-0301":
tokens_per_message = (
4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
)
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
tokens_per_name = -1 # if there's a name, the role is omitted tokens_per_name = -1 # if there's a name, the role is omitted
elif model == "gpt-4-0314": elif model == "gpt-4-0314":
tokens_per_message = 3 tokens_per_message = 3
tokens_per_name = 1 tokens_per_name = 1
else: else:
logger.warn(
f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301."
)
logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301.")
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301") return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
num_tokens = 0 num_tokens = 0
for message in messages: for message in messages:


+ 7
- 21
bot/openai/open_ai_bot.py View File

@@ -28,23 +28,15 @@ class OpenAIBot(Bot, OpenAIImage):
if proxy: if proxy:
openai.proxy = proxy openai.proxy = proxy


self.sessions = SessionManager(
OpenAISession, model=conf().get("model") or "text-davinci-003"
)
self.sessions = SessionManager(OpenAISession, model=conf().get("model") or "text-davinci-003")
self.args = { self.args = {
"model": conf().get("model") or "text-davinci-003", # 对话模型的名称 "model": conf().get("model") or "text-davinci-003", # 对话模型的名称
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 "temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
"max_tokens": 1200, # 回复最大的字符数 "max_tokens": 1200, # 回复最大的字符数
"top_p": 1, "top_p": 1,
"frequency_penalty": conf().get(
"frequency_penalty", 0.0
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"presence_penalty": conf().get(
"presence_penalty", 0.0
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"request_timeout": conf().get(
"request_timeout", None
), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试 "timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
"stop": ["\n\n\n"], "stop": ["\n\n\n"],
} }
@@ -71,17 +63,13 @@ class OpenAIBot(Bot, OpenAIImage):
result["content"], result["content"],
) )
logger.debug( logger.debug(
"[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
str(session), session_id, reply_content, completion_tokens
)
"[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)
) )


if total_tokens == 0: if total_tokens == 0:
reply = Reply(ReplyType.ERROR, reply_content) reply = Reply(ReplyType.ERROR, reply_content)
else: else:
self.sessions.session_reply(
reply_content, session_id, total_tokens
)
self.sessions.session_reply(reply_content, session_id, total_tokens)
reply = Reply(ReplyType.TEXT, reply_content) reply = Reply(ReplyType.TEXT, reply_content)
return reply return reply
elif context.type == ContextType.IMAGE_CREATE: elif context.type == ContextType.IMAGE_CREATE:
@@ -96,9 +84,7 @@ class OpenAIBot(Bot, OpenAIImage):
def reply_text(self, session: OpenAISession, retry_count=0): def reply_text(self, session: OpenAISession, retry_count=0):
try: try:
response = openai.Completion.create(prompt=str(session), **self.args) response = openai.Completion.create(prompt=str(session), **self.args)
res_content = (
response.choices[0]["text"].strip().replace("<|endoftext|>", "")
)
res_content = response.choices[0]["text"].strip().replace("<|endoftext|>", "")
total_tokens = response["usage"]["total_tokens"] total_tokens = response["usage"]["total_tokens"]
completion_tokens = response["usage"]["completion_tokens"] completion_tokens = response["usage"]["completion_tokens"]
logger.info("[OPEN_AI] reply={}".format(res_content)) logger.info("[OPEN_AI] reply={}".format(res_content))


+ 2
- 8
bot/openai/open_ai_image.py View File

@@ -23,9 +23,7 @@ class OpenAIImage(object):
response = openai.Image.create( response = openai.Image.create(
prompt=query, # 图片描述 prompt=query, # 图片描述
n=1, # 每次生成图片的数量 n=1, # 每次生成图片的数量
size=conf().get(
"image_create_size", "256x256"
), # 图片大小,可选有 256x256, 512x512, 1024x1024
size=conf().get("image_create_size", "256x256"), # 图片大小,可选有 256x256, 512x512, 1024x1024
) )
image_url = response["data"][0]["url"] image_url = response["data"][0]["url"]
logger.info("[OPEN_AI] image_url={}".format(image_url)) logger.info("[OPEN_AI] image_url={}".format(image_url))
@@ -34,11 +32,7 @@ class OpenAIImage(object):
logger.warn(e) logger.warn(e)
if retry_count < 1: if retry_count < 1:
time.sleep(5) time.sleep(5)
logger.warn(
"[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(
retry_count + 1
)
)
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1))
return self.create_img(query, retry_count + 1) return self.create_img(query, retry_count + 1)
else: else:
return False, "提问太快啦,请休息一下再问我吧" return False, "提问太快啦,请休息一下再问我吧"


+ 3
- 13
bot/openai/open_ai_session.py View File

@@ -36,9 +36,7 @@ class OpenAISession(Session):
precise = False precise = False
if cur_tokens is None: if cur_tokens is None:
raise e raise e
logger.debug(
"Exception when counting tokens precisely for query: {}".format(e)
)
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
while cur_tokens > max_tokens: while cur_tokens > max_tokens:
if len(self.messages) > 1: if len(self.messages) > 1:
self.messages.pop(0) self.messages.pop(0)
@@ -50,18 +48,10 @@ class OpenAISession(Session):
cur_tokens = len(str(self)) cur_tokens = len(str(self))
break break
elif len(self.messages) == 1 and self.messages[0]["role"] == "user": elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
logger.warn(
"user question exceed max_tokens. total_tokens={}".format(
cur_tokens
)
)
logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens))
break break
else: else:
logger.debug(
"max_tokens={}, total_tokens={}, len(conversation)={}".format(
max_tokens, cur_tokens, len(self.messages)
)
)
logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
break break
if precise: if precise:
cur_tokens = self.calc_tokens() cur_tokens = self.calc_tokens()


+ 4
- 16
bot/session_manager.py View File

@@ -55,9 +55,7 @@ class SessionManager(object):
return self.sessioncls(session_id, system_prompt, **self.session_args) return self.sessioncls(session_id, system_prompt, **self.session_args)


if session_id not in self.sessions: if session_id not in self.sessions:
self.sessions[session_id] = self.sessioncls(
session_id, system_prompt, **self.session_args
)
self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args)
elif system_prompt is not None: # 如果有新的system_prompt,更新并重置session elif system_prompt is not None: # 如果有新的system_prompt,更新并重置session
self.sessions[session_id].set_system_prompt(system_prompt) self.sessions[session_id].set_system_prompt(system_prompt)
session = self.sessions[session_id] session = self.sessions[session_id]
@@ -71,9 +69,7 @@ class SessionManager(object):
total_tokens = session.discard_exceeding(max_tokens, None) total_tokens = session.discard_exceeding(max_tokens, None)
logger.debug("prompt tokens used={}".format(total_tokens)) logger.debug("prompt tokens used={}".format(total_tokens))
except Exception as e: except Exception as e:
logger.debug(
"Exception when counting tokens precisely for prompt: {}".format(str(e))
)
logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e)))
return session return session


def session_reply(self, reply, session_id, total_tokens=None): def session_reply(self, reply, session_id, total_tokens=None):
@@ -82,17 +78,9 @@ class SessionManager(object):
try: try:
max_tokens = conf().get("conversation_max_tokens", 1000) max_tokens = conf().get("conversation_max_tokens", 1000)
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens) tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
logger.debug(
"raw total_tokens={}, savesession tokens={}".format(
total_tokens, tokens_cnt
)
)
logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
except Exception as e: except Exception as e:
logger.debug(
"Exception when counting tokens precisely for session: {}".format(
str(e)
)
)
logger.debug("Exception when counting tokens precisely for session: {}".format(str(e)))
return session return session


def clear_session(self, session_id): def clear_session(self, session_id):


+ 1
- 3
bridge/context.py View File

@@ -60,6 +60,4 @@ class Context:
del self.kwargs[key] del self.kwargs[key]


def __str__(self): def __str__(self):
return "Context(type={}, content={}, kwargs={})".format(
self.type, self.content, self.kwargs
)
return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)

+ 28
- 103
channel/chat_channel.py View File

@@ -53,9 +53,7 @@ class ChatChannel(Channel):
group_id = cmsg.other_user_id group_id = cmsg.other_user_id


group_name_white_list = config.get("group_name_white_list", []) group_name_white_list = config.get("group_name_white_list", [])
group_name_keyword_white_list = config.get(
"group_name_keyword_white_list", []
)
group_name_keyword_white_list = config.get("group_name_keyword_white_list", [])
if any( if any(
[ [
group_name in group_name_white_list, group_name in group_name_white_list,
@@ -63,9 +61,7 @@ class ChatChannel(Channel):
check_contain(group_name, group_name_keyword_white_list), check_contain(group_name, group_name_keyword_white_list),
] ]
): ):
group_chat_in_one_session = conf().get(
"group_chat_in_one_session", []
)
group_chat_in_one_session = conf().get("group_chat_in_one_session", [])
session_id = cmsg.actual_user_id session_id = cmsg.actual_user_id
if any( if any(
[ [
@@ -81,17 +77,11 @@ class ChatChannel(Channel):
else: else:
context["session_id"] = cmsg.other_user_id context["session_id"] = cmsg.other_user_id
context["receiver"] = cmsg.other_user_id context["receiver"] = cmsg.other_user_id
e_context = PluginManager().emit_event(
EventContext(
Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context}
)
)
e_context = PluginManager().emit_event(EventContext(Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context}))
context = e_context["context"] context = e_context["context"]
if e_context.is_pass() or context is None: if e_context.is_pass() or context is None:
return context return context
if cmsg.from_user_id == self.user_id and not config.get(
"trigger_by_self", True
):
if cmsg.from_user_id == self.user_id and not config.get("trigger_by_self", True):
logger.debug("[WX]self message skipped") logger.debug("[WX]self message skipped")
return None return None


@@ -119,19 +109,13 @@ class ChatChannel(Channel):


if not flag: if not flag:
if context["origin_ctype"] == ContextType.VOICE: if context["origin_ctype"] == ContextType.VOICE:
logger.info(
"[WX]receive group voice, but checkprefix didn't match"
)
logger.info("[WX]receive group voice, but checkprefix didn't match")
return None return None
else: # 单聊 else: # 单聊
match_prefix = check_prefix(
content, conf().get("single_chat_prefix", [""])
)
match_prefix = check_prefix(content, conf().get("single_chat_prefix", [""]))
if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容 if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
content = content.replace(match_prefix, "", 1).strip() content = content.replace(match_prefix, "", 1).strip()
elif (
context["origin_ctype"] == ContextType.VOICE
): # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
pass pass
else: else:
return None return None
@@ -143,18 +127,10 @@ class ChatChannel(Channel):
else: else:
context.type = ContextType.TEXT context.type = ContextType.TEXT
context.content = content.strip() context.content = content.strip()
if (
"desire_rtype" not in context
and conf().get("always_reply_voice")
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
):
if "desire_rtype" not in context and conf().get("always_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
context["desire_rtype"] = ReplyType.VOICE context["desire_rtype"] = ReplyType.VOICE
elif context.type == ContextType.VOICE: elif context.type == ContextType.VOICE:
if (
"desire_rtype" not in context
and conf().get("voice_reply_voice")
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
):
if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
context["desire_rtype"] = ReplyType.VOICE context["desire_rtype"] = ReplyType.VOICE


return context return context
@@ -182,15 +158,8 @@ class ChatChannel(Channel):
) )
reply = e_context["reply"] reply = e_context["reply"]
if not e_context.is_pass(): if not e_context.is_pass():
logger.debug(
"[WX] ready to handle context: type={}, content={}".format(
context.type, context.content
)
)
if (
context.type == ContextType.TEXT
or context.type == ContextType.IMAGE_CREATE
): # 文字和图片消息
logger.debug("[WX] ready to handle context: type={}, content={}".format(context.type, context.content))
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
reply = super().build_reply_content(context.content, context) reply = super().build_reply_content(context.content, context)
elif context.type == ContextType.VOICE: # 语音消息 elif context.type == ContextType.VOICE: # 语音消息
cmsg = context["msg"] cmsg = context["msg"]
@@ -214,9 +183,7 @@ class ChatChannel(Channel):
# logger.warning("[WX]delete temp file error: " + str(e)) # logger.warning("[WX]delete temp file error: " + str(e))


if reply.type == ReplyType.TEXT: if reply.type == ReplyType.TEXT:
new_context = self._compose_context(
ContextType.TEXT, reply.content, **context.kwargs
)
new_context = self._compose_context(ContextType.TEXT, reply.content, **context.kwargs)
if new_context: if new_context:
reply = self._generate_reply(new_context) reply = self._generate_reply(new_context)
else: else:
@@ -246,48 +213,24 @@ class ChatChannel(Channel):


if reply.type == ReplyType.TEXT: if reply.type == ReplyType.TEXT:
reply_text = reply.content reply_text = reply.content
if (
desire_rtype == ReplyType.VOICE
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
):
if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
reply = super().build_text_to_voice(reply.content) reply = super().build_text_to_voice(reply.content)
return self._decorate_reply(context, reply) return self._decorate_reply(context, reply)
if context.get("isgroup", False): if context.get("isgroup", False):
reply_text = (
"@"
+ context["msg"].actual_user_nickname
+ " "
+ reply_text.strip()
)
reply_text = (
conf().get("group_chat_reply_prefix", "") + reply_text
)
reply_text = "@" + context["msg"].actual_user_nickname + " " + reply_text.strip()
reply_text = conf().get("group_chat_reply_prefix", "") + reply_text
else: else:
reply_text = (
conf().get("single_chat_reply_prefix", "") + reply_text
)
reply_text = conf().get("single_chat_reply_prefix", "") + reply_text
reply.content = reply_text reply.content = reply_text
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
reply.content = "[" + str(reply.type) + "]\n" + reply.content reply.content = "[" + str(reply.type) + "]\n" + reply.content
elif (
reply.type == ReplyType.IMAGE_URL
or reply.type == ReplyType.VOICE
or reply.type == ReplyType.IMAGE
):
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
pass pass
else: else:
logger.error("[WX] unknown reply type: {}".format(reply.type)) logger.error("[WX] unknown reply type: {}".format(reply.type))
return return
if (
desire_rtype
and desire_rtype != reply.type
and reply.type not in [ReplyType.ERROR, ReplyType.INFO]
):
logger.warning(
"[WX] desire_rtype: {}, but reply type: {}".format(
context.get("desire_rtype"), reply.type
)
)
if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
logger.warning("[WX] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type))
return reply return reply


def _send_reply(self, context: Context, reply: Reply): def _send_reply(self, context: Context, reply: Reply):
@@ -300,9 +243,7 @@ class ChatChannel(Channel):
) )
reply = e_context["reply"] reply = e_context["reply"]
if not e_context.is_pass() and reply and reply.type: if not e_context.is_pass() and reply and reply.type:
logger.debug(
"[WX] ready to send reply: {}, context: {}".format(reply, context)
)
logger.debug("[WX] ready to send reply: {}, context: {}".format(reply, context))
self._send(reply, context) self._send(reply, context)


def _send(self, reply: Reply, context: Context, retry_cnt=0): def _send(self, reply: Reply, context: Context, retry_cnt=0):
@@ -328,9 +269,7 @@ class ChatChannel(Channel):
try: try:
worker_exception = worker.exception() worker_exception = worker.exception()
if worker_exception: if worker_exception:
self._fail_callback(
session_id, exception=worker_exception, **kwargs
)
self._fail_callback(session_id, exception=worker_exception, **kwargs)
else: else:
self._success_callback(session_id, **kwargs) self._success_callback(session_id, **kwargs)
except CancelledError as e: except CancelledError as e:
@@ -366,24 +305,14 @@ class ChatChannel(Channel):
if not context_queue.empty(): if not context_queue.empty():
context = context_queue.get() context = context_queue.get()
logger.debug("[WX] consume context: {}".format(context)) logger.debug("[WX] consume context: {}".format(context))
future: Future = self.handler_pool.submit(
self._handle, context
)
future.add_done_callback(
self._thread_pool_callback(session_id, context=context)
)
future: Future = self.handler_pool.submit(self._handle, context)
future.add_done_callback(self._thread_pool_callback(session_id, context=context))
if session_id not in self.futures: if session_id not in self.futures:
self.futures[session_id] = [] self.futures[session_id] = []
self.futures[session_id].append(future) self.futures[session_id].append(future)
elif (
semaphore._initial_value == semaphore._value + 1
): # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
self.futures[session_id] = [
t for t in self.futures[session_id] if not t.done()
]
assert (
len(self.futures[session_id]) == 0
), "thread pool error"
elif semaphore._initial_value == semaphore._value + 1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
assert len(self.futures[session_id]) == 0, "thread pool error"
del self.sessions[session_id] del self.sessions[session_id]
else: else:
semaphore.release() semaphore.release()
@@ -397,9 +326,7 @@ class ChatChannel(Channel):
future.cancel() future.cancel()
cnt = self.sessions[session_id][0].qsize() cnt = self.sessions[session_id][0].qsize()
if cnt > 0: if cnt > 0:
logger.info(
"Cancel {} messages in session {}".format(cnt, session_id)
)
logger.info("Cancel {} messages in session {}".format(cnt, session_id))
self.sessions[session_id][0] = Dequeue() self.sessions[session_id][0] = Dequeue()


def cancel_all_session(self): def cancel_all_session(self):
@@ -409,9 +336,7 @@ class ChatChannel(Channel):
future.cancel() future.cancel()
cnt = self.sessions[session_id][0].qsize() cnt = self.sessions[session_id][0].qsize()
if cnt > 0: if cnt > 0:
logger.info(
"Cancel {} messages in session {}".format(cnt, session_id)
)
logger.info("Cancel {} messages in session {}".format(cnt, session_id))
self.sessions[session_id][0] = Dequeue() self.sessions[session_id][0] = Dequeue()






+ 1
- 3
channel/terminal/terminal_channel.py View File

@@ -77,9 +77,7 @@ class TerminalChannel(ChatChannel):
if check_prefix(prompt, trigger_prefixs) is None: if check_prefix(prompt, trigger_prefixs) is None:
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀 prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀


context = self._compose_context(
ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt)
)
context = self._compose_context(ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt))
if context: if context:
self.produce(context) self.produce(context)
else: else:


+ 8
- 31
channel/wechat/wechat_channel.py View File

@@ -56,10 +56,7 @@ def _check(func):
return return
self.receivedMsgs[msgId] = cmsg self.receivedMsgs[msgId] = cmsg
create_time = cmsg.create_time # 消息时间戳 create_time = cmsg.create_time # 消息时间戳
if (
conf().get("hot_reload") == True
and int(create_time) < int(time.time()) - 60
): # 跳过1分钟前的历史消息
if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
logger.debug("[WX]history message {} skipped".format(msgId)) logger.debug("[WX]history message {} skipped".format(msgId))
return return
return func(self, cmsg) return func(self, cmsg)
@@ -88,15 +85,9 @@ def qrCallback(uuid, status, qrcode):
url = f"https://login.weixin.qq.com/l/{uuid}" url = f"https://login.weixin.qq.com/l/{uuid}"


qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url) qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url)
qr_api2 = (
"https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(
url
)
)
qr_api2 = "https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url)
qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url) qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url)
qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(
url
)
qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url)
print("You can also scan QRCode in any website below:") print("You can also scan QRCode in any website below:")
print(qr_api3) print(qr_api3)
print(qr_api4) print(qr_api4)
@@ -134,18 +125,12 @@ class WechatChannel(ChatChannel):
logger.error("Hot reload failed, try to login without hot reload") logger.error("Hot reload failed, try to login without hot reload")
itchat.logout() itchat.logout()
os.remove(status_path) os.remove(status_path)
itchat.auto_login(
enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback
)
itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback)
else: else:
raise e raise e
self.user_id = itchat.instance.storageClass.userName self.user_id = itchat.instance.storageClass.userName
self.name = itchat.instance.storageClass.nickName self.name = itchat.instance.storageClass.nickName
logger.info(
"Wechat login success, user_id: {}, nickname: {}".format(
self.user_id, self.name
)
)
logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
# start message listener # start message listener
itchat.run() itchat.run()


@@ -173,16 +158,10 @@ class WechatChannel(ChatChannel):
elif cmsg.ctype == ContextType.PATPAT: elif cmsg.ctype == ContextType.PATPAT:
logger.debug("[WX]receive patpat msg: {}".format(cmsg.content)) logger.debug("[WX]receive patpat msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.TEXT: elif cmsg.ctype == ContextType.TEXT:
logger.debug(
"[WX]receive text msg: {}, cmsg={}".format(
json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg
)
)
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
else: else:
logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg)) logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg))
context = self._compose_context(
cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg
)
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
if context: if context:
self.produce(context) self.produce(context)


@@ -202,9 +181,7 @@ class WechatChannel(ChatChannel):
pass pass
else: else:
logger.debug("[WX]receive group msg: {}".format(cmsg.content)) logger.debug("[WX]receive group msg: {}".format(cmsg.content))
context = self._compose_context(
cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg
)
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
if context: if context:
self.produce(context) self.produce(context)




+ 6
- 20
channel/wechat/wechat_message.py View File

@@ -27,37 +27,23 @@ class WeChatMessage(ChatMessage):
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径 self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
self._prepare_fn = lambda: itchat_msg.download(self.content) self._prepare_fn = lambda: itchat_msg.download(self.content)
elif itchat_msg["Type"] == NOTE and itchat_msg["MsgType"] == 10000: elif itchat_msg["Type"] == NOTE and itchat_msg["MsgType"] == 10000:
if is_group and (
"加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"]
):
if is_group and ("加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"]):
self.ctype = ContextType.JOIN_GROUP self.ctype = ContextType.JOIN_GROUP
self.content = itchat_msg["Content"] self.content = itchat_msg["Content"]
# 这里只能得到nickname, actual_user_id还是机器人的id # 这里只能得到nickname, actual_user_id还是机器人的id
if "加入了群聊" in itchat_msg["Content"]: if "加入了群聊" in itchat_msg["Content"]:
self.actual_user_nickname = re.findall(
r"\"(.*?)\"", itchat_msg["Content"]
)[-1]
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[-1]
elif "加入群聊" in itchat_msg["Content"]: elif "加入群聊" in itchat_msg["Content"]:
self.actual_user_nickname = re.findall(
r"\"(.*?)\"", itchat_msg["Content"]
)[0]
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
elif "拍了拍我" in itchat_msg["Content"]: elif "拍了拍我" in itchat_msg["Content"]:
self.ctype = ContextType.PATPAT self.ctype = ContextType.PATPAT
self.content = itchat_msg["Content"] self.content = itchat_msg["Content"]
if is_group: if is_group:
self.actual_user_nickname = re.findall(
r"\"(.*?)\"", itchat_msg["Content"]
)[0]
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
else: else:
raise NotImplementedError(
"Unsupported note message: " + itchat_msg["Content"]
)
raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"])
else: else:
raise NotImplementedError(
"Unsupported message type: Type:{} MsgType:{}".format(
itchat_msg["Type"], itchat_msg["MsgType"]
)
)
raise NotImplementedError("Unsupported message type: Type:{} MsgType:{}".format(itchat_msg["Type"], itchat_msg["MsgType"]))


self.from_user_id = itchat_msg["FromUserName"] self.from_user_id = itchat_msg["FromUserName"]
self.to_user_id = itchat_msg["ToUserName"] self.to_user_id = itchat_msg["ToUserName"]


+ 5
- 15
channel/wechat/wechaty_channel.py View File

@@ -60,13 +60,9 @@ class WechatyChannel(ChatChannel):
receiver_id = context["receiver"] receiver_id = context["receiver"]
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
if context["isgroup"]: if context["isgroup"]:
receiver = asyncio.run_coroutine_threadsafe(
self.bot.Room.find(receiver_id), loop
).result()
receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id), loop).result()
else: else:
receiver = asyncio.run_coroutine_threadsafe(
self.bot.Contact.find(receiver_id), loop
).result()
receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id), loop).result()
msg = None msg = None
if reply.type == ReplyType.TEXT: if reply.type == ReplyType.TEXT:
msg = reply.content msg = reply.content
@@ -83,9 +79,7 @@ class WechatyChannel(ChatChannel):
voiceLength = int(any_to_sil(file_path, sil_file)) voiceLength = int(any_to_sil(file_path, sil_file))
if voiceLength >= 60000: if voiceLength >= 60000:
voiceLength = 60000 voiceLength = 60000
logger.info(
"[WX] voice too long, length={}, set to 60s".format(voiceLength)
)
logger.info("[WX] voice too long, length={}, set to 60s".format(voiceLength))
# 发送语音 # 发送语音
t = int(time.time()) t = int(time.time())
msg = FileBox.from_file(sil_file, name=str(t) + ".sil") msg = FileBox.from_file(sil_file, name=str(t) + ".sil")
@@ -98,9 +92,7 @@ class WechatyChannel(ChatChannel):
os.remove(sil_file) os.remove(sil_file)
except Exception as e: except Exception as e:
pass pass
logger.info(
"[WX] sendVoice={}, receiver={}".format(reply.content, receiver)
)
logger.info("[WX] sendVoice={}, receiver={}".format(reply.content, receiver))
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
img_url = reply.content img_url = reply.content
t = int(time.time()) t = int(time.time())
@@ -111,9 +103,7 @@ class WechatyChannel(ChatChannel):
image_storage = reply.content image_storage = reply.content
image_storage.seek(0) image_storage.seek(0)
t = int(time.time()) t = int(time.time())
msg = FileBox.from_base64(
base64.b64encode(image_storage.read()), str(t) + ".png"
)
msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + ".png")
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result() asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
logger.info("[WX] sendImage, receiver={}".format(receiver)) logger.info("[WX] sendImage, receiver={}".format(receiver))




+ 3
- 9
channel/wechat/wechaty_message.py View File

@@ -45,16 +45,12 @@ class WechatyMessage(ChatMessage, aobject):


def func(): def func():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
asyncio.run_coroutine_threadsafe(
voice_file.to_file(self.content), loop
).result()
asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content), loop).result()


self._prepare_fn = func self._prepare_fn = func


else: else:
raise NotImplementedError(
"Unsupported message type: {}".format(wechaty_msg.type())
)
raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type()))


from_contact = wechaty_msg.talker() # 获取消息的发送者 from_contact = wechaty_msg.talker() # 获取消息的发送者
self.from_user_id = from_contact.contact_id self.from_user_id = from_contact.contact_id
@@ -73,9 +69,7 @@ class WechatyMessage(ChatMessage, aobject):
self.to_user_id = to_contact.contact_id self.to_user_id = to_contact.contact_id
self.to_user_nickname = to_contact.name self.to_user_nickname = to_contact.name


if (
self.is_group or wechaty_msg.is_self()
): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。
if self.is_group or wechaty_msg.is_self(): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。
self.other_user_id = self.to_user_id self.other_user_id = self.to_user_id
self.other_user_nickname = self.to_user_nickname self.other_user_nickname = self.to_user_nickname
else: else:


+ 8
- 17
channel/wechatmp/active_reply.py View File

@@ -1,16 +1,17 @@
import time import time


import web import web
from wechatpy import parse_message
from wechatpy.replies import create_reply


from channel.wechatmp.wechatmp_message import WeChatMPMessage
from bridge.context import * from bridge.context import *
from bridge.reply import * from bridge.reply import *
from channel.wechatmp.common import * from channel.wechatmp.common import *
from channel.wechatmp.wechatmp_channel import WechatMPChannel from channel.wechatmp.wechatmp_channel import WechatMPChannel
from wechatpy import parse_message
from channel.wechatmp.wechatmp_message import WeChatMPMessage
from common.log import logger from common.log import logger
from config import conf from config import conf
from wechatpy.replies import create_reply


# This class is instantiated once per query # This class is instantiated once per query
class Query: class Query:
@@ -50,29 +51,19 @@ class Query:
) )
) )
if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False): if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False):
context = channel._compose_context(
wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg
)
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg)
else: else:
context = channel._compose_context(
wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg
)
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg)
if context: if context:
# set private openai_api_key # set private openai_api_key
# if from_user is not changed in itchat, this can be placed at chat_channel # if from_user is not changed in itchat, this can be placed at chat_channel
user_data = conf().get_user_data(from_user) user_data = conf().get_user_data(from_user)
context["openai_api_key"] = user_data.get(
"openai_api_key"
) # None or user openai_api_key
context["openai_api_key"] = user_data.get("openai_api_key") # None or user openai_api_key
channel.produce(context) channel.produce(context)
# The reply will be sent by channel.send() in another thread # The reply will be sent by channel.send() in another thread
return "success" return "success"
elif msg.type == "event": elif msg.type == "event":
logger.info(
"[wechatmp] Event {} from {}".format(
msg.event, msg.source
)
)
logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source))
if msg.event in ["subscribe", "subscribe_scan"]: if msg.event in ["subscribe", "subscribe_scan"]:
reply_text = subscribe_msg() reply_text = subscribe_msg()
replyPost = create_reply(reply_text, msg) replyPost = create_reply(reply_text, msg)


+ 5
- 3
channel/wechatmp/common.py View File

@@ -1,10 +1,12 @@
import textwrap import textwrap
import web


from config import conf
from wechatpy.utils import check_signature
import web
from wechatpy.crypto import WeChatCrypto from wechatpy.crypto import WeChatCrypto
from wechatpy.exceptions import InvalidSignatureException from wechatpy.exceptions import InvalidSignatureException
from wechatpy.utils import check_signature

from config import conf

MAX_UTF8_LEN = 2048 MAX_UTF8_LEN = 2048






+ 18
- 36
channel/wechatmp/passive_reply.py View File

@@ -1,17 +1,18 @@
import time
import asyncio import asyncio
import time


import web import web
from wechatpy import parse_message
from wechatpy.replies import ImageReply, VoiceReply, create_reply


from channel.wechatmp.wechatmp_message import WeChatMPMessage
from bridge.context import * from bridge.context import *
from bridge.reply import * from bridge.reply import *
from channel.wechatmp.common import * from channel.wechatmp.common import *
from channel.wechatmp.wechatmp_channel import WechatMPChannel from channel.wechatmp.wechatmp_channel import WechatMPChannel
from channel.wechatmp.wechatmp_message import WeChatMPMessage
from common.log import logger from common.log import logger
from config import conf from config import conf
from wechatpy import parse_message
from wechatpy.replies import create_reply, ImageReply, VoiceReply



# This class is instantiated once per query # This class is instantiated once per query
class Query: class Query:
@@ -49,21 +50,15 @@ class Query:
if ( if (
from_user not in channel.cache_dict from_user not in channel.cache_dict
and from_user not in channel.running and from_user not in channel.running
or content.startswith("#")
and message_id not in channel.request_cnt # insert the godcmd
or content.startswith("#")
and message_id not in channel.request_cnt # insert the godcmd
): ):
# The first query begin # The first query begin
if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False): if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False):
context = channel._compose_context(
wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg
)
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg)
else: else:
context = channel._compose_context(
wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg
)
logger.debug(
"[wechatmp] context: {} {} {}".format(context, wechatmp_msg, supported)
)
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg)
logger.debug("[wechatmp] context: {} {} {}".format(context, wechatmp_msg, supported))


if supported and context: if supported and context:
# set private openai_api_key # set private openai_api_key
@@ -94,23 +89,17 @@ class Query:
"""\ """\
未知错误,请稍后再试""" 未知错误,请稍后再试"""
) )
replyPost = create_reply(reply_text, msg) replyPost = create_reply(reply_text, msg)
return encrypt_func(replyPost.render()) return encrypt_func(replyPost.render())



# Wechat official server will request 3 times (5 seconds each), with the same message_id. # Wechat official server will request 3 times (5 seconds each), with the same message_id.
# Because the interval is 5 seconds, here assumed that do not have multithreading problems. # Because the interval is 5 seconds, here assumed that do not have multithreading problems.
request_cnt = channel.request_cnt.get(message_id, 0) + 1 request_cnt = channel.request_cnt.get(message_id, 0) + 1
channel.request_cnt[message_id] = request_cnt channel.request_cnt[message_id] = request_cnt
logger.info( logger.info(
"[wechatmp] Request {} from {} {} {}:{}\n{}".format( "[wechatmp] Request {} from {} {} {}:{}\n{}".format(
request_cnt,
from_user,
message_id,
web.ctx.env.get("REMOTE_ADDR"),
web.ctx.env.get("REMOTE_PORT"),
content
request_cnt, from_user, message_id, web.ctx.env.get("REMOTE_ADDR"), web.ctx.env.get("REMOTE_PORT"), content
) )
) )


@@ -130,7 +119,7 @@ class Query:
time.sleep(2) time.sleep(2)
# and do nothing, waiting for the next request # and do nothing, waiting for the next request
return "success" return "success"
else: # request_cnt == 3:
else: # request_cnt == 3:
# return timeout message # return timeout message
reply_text = "【正在思考中,回复任意文字尝试获取回复】" reply_text = "【正在思考中,回复任意文字尝试获取回复】"
replyPost = create_reply(reply_text, msg) replyPost = create_reply(reply_text, msg)
@@ -140,10 +129,7 @@ class Query:
channel.request_cnt.pop(message_id) channel.request_cnt.pop(message_id)


# no return because of bandwords or other reasons # no return because of bandwords or other reasons
if (
from_user not in channel.cache_dict
and from_user not in channel.running
):
if from_user not in channel.cache_dict and from_user not in channel.running:
return "success" return "success"


# Only one request can access to the cached data # Only one request can access to the cached data
@@ -152,7 +138,7 @@ class Query:
except KeyError: except KeyError:
return "success" return "success"


if (reply_type == "text"):
if reply_type == "text":
if len(reply_content.encode("utf8")) <= MAX_UTF8_LEN: if len(reply_content.encode("utf8")) <= MAX_UTF8_LEN:
reply_text = reply_content reply_text = reply_content
else: else:
@@ -177,7 +163,7 @@ class Query:
replyPost = create_reply(reply_text, msg) replyPost = create_reply(reply_text, msg)
return encrypt_func(replyPost.render()) return encrypt_func(replyPost.render())


elif (reply_type == "voice"):
elif reply_type == "voice":
media_id = reply_content media_id = reply_content
asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop) asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop)
logger.info( logger.info(
@@ -193,7 +179,7 @@ class Query:
replyPost.media_id = media_id replyPost.media_id = media_id
return encrypt_func(replyPost.render()) return encrypt_func(replyPost.render())


elif (reply_type == "image"):
elif reply_type == "image":
media_id = reply_content media_id = reply_content
asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop) asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop)
logger.info( logger.info(
@@ -210,11 +196,7 @@ class Query:
return encrypt_func(replyPost.render()) return encrypt_func(replyPost.render())


elif msg.type == "event": elif msg.type == "event":
logger.info(
"[wechatmp] Event {} from {}".format(
msg.event, msg.source
)
)
logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source))
if msg.event in ["subscribe", "subscribe_scan"]: if msg.event in ["subscribe", "subscribe_scan"]:
reply_text = subscribe_msg() reply_text = subscribe_msg()
replyPost = create_reply(reply_text, msg) replyPost = create_reply(reply_text, msg)


+ 20
- 27
channel/wechatmp/wechatmp_channel.py View File

@@ -1,24 +1,26 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import asyncio
import imghdr
import io import io
import os import os
import threading
import time import time
import imghdr
import requests import requests
import asyncio
import threading
from config import conf
import web
from wechatpy.crypto import WeChatCrypto
from wechatpy.exceptions import WeChatClientException

from bridge.context import * from bridge.context import *
from bridge.reply import * from bridge.reply import *
from common.log import logger
from common.singleton import singleton
from voice.audio_convert import any_to_mp3
from channel.chat_channel import ChatChannel from channel.chat_channel import ChatChannel
from channel.wechatmp.common import * from channel.wechatmp.common import *
from channel.wechatmp.wechatmp_client import WechatMPClient from channel.wechatmp.wechatmp_client import WechatMPClient
from wechatpy.exceptions import WeChatClientException
from wechatpy.crypto import WeChatCrypto
from common.log import logger
from common.singleton import singleton
from config import conf
from voice.audio_convert import any_to_mp3


import web
# If using SSL, uncomment the following lines, and modify the certificate path. # If using SSL, uncomment the following lines, and modify the certificate path.
# from cheroot.server import HTTPServer # from cheroot.server import HTTPServer
# from cheroot.ssl.builtin import BuiltinSSLAdapter # from cheroot.ssl.builtin import BuiltinSSLAdapter
@@ -54,7 +56,6 @@ class WechatMPChannel(ChatChannel):
t.setDaemon(True) t.setDaemon(True)
t.start() t.start()



def startup(self): def startup(self):
if self.passive_reply: if self.passive_reply:
urls = ("/wx", "channel.wechatmp.passive_reply.Query") urls = ("/wx", "channel.wechatmp.passive_reply.Query")
@@ -84,7 +85,7 @@ class WechatMPChannel(ChatChannel):
elif reply.type == ReplyType.VOICE: elif reply.type == ReplyType.VOICE:
try: try:
voice_file_path = reply.content voice_file_path = reply.content
with open(voice_file_path, 'rb') as f:
with open(voice_file_path, "rb") as f:
# support: <2M, <60s, mp3/wma/wav/amr # support: <2M, <60s, mp3/wma/wav/amr
response = self.client.material.add("voice", f) response = self.client.material.add("voice", f)
logger.debug("[wechatmp] upload voice response: {}".format(response)) logger.debug("[wechatmp] upload voice response: {}".format(response))
@@ -107,7 +108,7 @@ class WechatMPChannel(ChatChannel):
image_storage.write(block) image_storage.write(block)
image_storage.seek(0) image_storage.seek(0)
image_type = imghdr.what(image_storage) image_type = imghdr.what(image_storage)
filename = receiver + "-" + str(context['msg'].msg_id) + "." + image_type
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
content_type = "image/" + image_type content_type = "image/" + image_type
try: try:
response = self.client.material.add("image", (filename, image_storage, content_type)) response = self.client.material.add("image", (filename, image_storage, content_type))
@@ -122,7 +123,7 @@ class WechatMPChannel(ChatChannel):
image_storage = reply.content image_storage = reply.content
image_storage.seek(0) image_storage.seek(0)
image_type = imghdr.what(image_storage) image_type = imghdr.what(image_storage)
filename = receiver + "-" + str(context['msg'].msg_id) + "." + image_type
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
content_type = "image/" + image_type content_type = "image/" + image_type
try: try:
response = self.client.material.add("image", (filename, image_storage, content_type)) response = self.client.material.add("image", (filename, image_storage, content_type))
@@ -137,7 +138,7 @@ class WechatMPChannel(ChatChannel):
if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR: if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
reply_text = reply.content reply_text = reply.content
texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN) texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN)
if len(texts)>1:
if len(texts) > 1:
logger.info("[wechatmp] text too long, split into {} parts".format(len(texts))) logger.info("[wechatmp] text too long, split into {} parts".format(len(texts)))
for text in texts: for text in texts:
self.client.message.send_text(receiver, text) self.client.message.send_text(receiver, text)
@@ -174,7 +175,7 @@ class WechatMPChannel(ChatChannel):
image_storage.write(block) image_storage.write(block)
image_storage.seek(0) image_storage.seek(0)
image_type = imghdr.what(image_storage) image_type = imghdr.what(image_storage)
filename = receiver + "-" + str(context['msg'].msg_id) + "." + image_type
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
content_type = "image/" + image_type content_type = "image/" + image_type
try: try:
response = self.client.media.upload("image", (filename, image_storage, content_type)) response = self.client.media.upload("image", (filename, image_storage, content_type))
@@ -188,7 +189,7 @@ class WechatMPChannel(ChatChannel):
image_storage = reply.content image_storage = reply.content
image_storage.seek(0) image_storage.seek(0)
image_type = imghdr.what(image_storage) image_type = imghdr.what(image_storage)
filename = receiver + "-" + str(context['msg'].msg_id) + "." + image_type
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
content_type = "image/" + image_type content_type = "image/" + image_type
try: try:
response = self.client.media.upload("image", (filename, image_storage, content_type)) response = self.client.media.upload("image", (filename, image_storage, content_type))
@@ -201,20 +202,12 @@ class WechatMPChannel(ChatChannel):
return return


def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数 def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数
logger.debug(
"[wechatmp] Success to generate reply, msgId={}".format(
context["msg"].msg_id
)
)
logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context["msg"].msg_id))
if self.passive_reply: if self.passive_reply:
self.running.remove(session_id) self.running.remove(session_id)


def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数 def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数
logger.exception(
"[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(
context["msg"].msg_id, exception
)
)
logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context["msg"].msg_id, exception))
if self.passive_reply: if self.passive_reply:
assert session_id not in self.cache_dict assert session_id not in self.cache_dict
self.running.remove(session_id) self.running.remove(session_id)

+ 10
- 11
channel/wechatmp/wechatmp_client.py View File

@@ -1,17 +1,16 @@
import time
import threading import threading
from channel.wechatmp.common import *
import time

from wechatpy.client import WeChatClient from wechatpy.client import WeChatClient
from common.log import logger
from wechatpy.exceptions import APILimitedException from wechatpy.exceptions import APILimitedException


from channel.wechatmp.common import *
from common.log import logger



class WechatMPClient(WeChatClient): class WechatMPClient(WeChatClient):
def __init__(self, appid, secret, access_token=None,
session=None, timeout=None, auto_retry=True):
super(WechatMPClient, self).__init__(
appid, secret, access_token, session, timeout, auto_retry
)
def __init__(self, appid, secret, access_token=None, session=None, timeout=None, auto_retry=True):
super(WechatMPClient, self).__init__(appid, secret, access_token, session, timeout, auto_retry)
self.fetch_access_token_lock = threading.Lock() self.fetch_access_token_lock = threading.Lock()


def clear_quota(self): def clear_quota(self):
@@ -20,7 +19,7 @@ class WechatMPClient(WeChatClient):
def clear_quota_v2(self): def clear_quota_v2(self):
return self.post("clear_quota/v2", params={"appid": self.appid, "appsecret": self.secret}) return self.post("clear_quota/v2", params={"appid": self.appid, "appsecret": self.secret})


def fetch_access_token(self): # 重载父类方法,加锁避免多线程重复获取access_token
def fetch_access_token(self): # 重载父类方法,加锁避免多线程重复获取access_token
with self.fetch_access_token_lock: with self.fetch_access_token_lock:
access_token = self.session.get(self.access_token_key) access_token = self.session.get(self.access_token_key)
if access_token: if access_token:
@@ -31,11 +30,11 @@ class WechatMPClient(WeChatClient):
return access_token return access_token
return super().fetch_access_token() return super().fetch_access_token()


def _request(self, method, url_or_endpoint, **kwargs): # 重载父类方法,遇到API限流时,清除quota后重试
def _request(self, method, url_or_endpoint, **kwargs): # 重载父类方法,遇到API限流时,清除quota后重试
try: try:
return super()._request(method, url_or_endpoint, **kwargs) return super()._request(method, url_or_endpoint, **kwargs)
except APILimitedException as e: except APILimitedException as e:
logger.error("[wechatmp] API quata has been used up. {}".format(e)) logger.error("[wechatmp] API quata has been used up. {}".format(e))
response = self.clear_quota_v2() response = self.clear_quota_v2()
logger.debug("[wechatmp] API quata has been cleard, {}".format(response)) logger.debug("[wechatmp] API quata has been cleard, {}".format(response))
return super()._request(method, url_or_endpoint, **kwargs)
return super()._request(method, url_or_endpoint, **kwargs)

+ 5
- 14
channel/wechatmp/wechatmp_message.py View File

@@ -6,7 +6,6 @@ from common.log import logger
from common.tmp_dir import TmpDir from common.tmp_dir import TmpDir





class WeChatMPMessage(ChatMessage): class WeChatMPMessage(ChatMessage):
def __init__(self, msg, client=None): def __init__(self, msg, client=None):
super().__init__(msg) super().__init__(msg)
@@ -18,12 +17,9 @@ class WeChatMPMessage(ChatMessage):
self.ctype = ContextType.TEXT self.ctype = ContextType.TEXT
self.content = msg.content self.content = msg.content
elif msg.type == "voice": elif msg.type == "voice":
if msg.recognition == None: if msg.recognition == None:
self.ctype = ContextType.VOICE self.ctype = ContextType.VOICE
self.content = (
TmpDir().path() + msg.media_id + "." + msg.format
) # content直接存临时目录路径
self.content = TmpDir().path() + msg.media_id + "." + msg.format # content直接存临时目录路径


def download_voice(): def download_voice():
# 如果响应状态码是200,则将响应内容写入本地文件 # 如果响应状态码是200,则将响应内容写入本地文件
@@ -32,9 +28,7 @@ class WeChatMPMessage(ChatMessage):
with open(self.content, "wb") as f: with open(self.content, "wb") as f:
f.write(response.content) f.write(response.content)
else: else:
logger.info(
f"[wechatmp] Failed to download voice file, {response.content}"
)
logger.info(f"[wechatmp] Failed to download voice file, {response.content}")


self._prepare_fn = download_voice self._prepare_fn = download_voice
else: else:
@@ -43,6 +37,7 @@ class WeChatMPMessage(ChatMessage):
elif msg.type == "image": elif msg.type == "image":
self.ctype = ContextType.IMAGE self.ctype = ContextType.IMAGE
self.content = TmpDir().path() + msg.media_id + ".png" # content直接存临时目录路径 self.content = TmpDir().path() + msg.media_id + ".png" # content直接存临时目录路径

def download_image(): def download_image():
# 如果响应状态码是200,则将响应内容写入本地文件 # 如果响应状态码是200,则将响应内容写入本地文件
response = client.media.download(msg.media_id) response = client.media.download(msg.media_id)
@@ -50,15 +45,11 @@ class WeChatMPMessage(ChatMessage):
with open(self.content, "wb") as f: with open(self.content, "wb") as f:
f.write(response.content) f.write(response.content)
else: else:
logger.info(
f"[wechatmp] Failed to download image file, {response.content}"
)
logger.info(f"[wechatmp] Failed to download image file, {response.content}")


self._prepare_fn = download_image self._prepare_fn = download_image
else: else:
raise NotImplementedError(
"Unsupported message type: Type:{} ".format(msg.type)
)
raise NotImplementedError("Unsupported message type: Type:{} ".format(msg.type))


self.from_user_id = msg.source self.from_user_id = msg.source
self.to_user_id = msg.target self.to_user_id = msg.target


+ 3
- 11
common/time_check.py View File

@@ -13,23 +13,15 @@ def time_checker(f):
if chat_time_module: if chat_time_module:
chat_start_time = _config.get("chat_start_time", "00:00") chat_start_time = _config.get("chat_start_time", "00:00")
chat_stopt_time = _config.get("chat_stop_time", "24:00") chat_stopt_time = _config.get("chat_stop_time", "24:00")
time_regex = re.compile(
r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$"
) # 时间匹配,包含24:00
time_regex = re.compile(r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$") # 时间匹配,包含24:00


starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式 starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式
stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式 stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式
chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间 chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间


# 时间格式检查 # 时间格式检查
if not (
starttime_format_check and stoptime_format_check and chat_time_check
):
logger.warn(
"时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format(
starttime_format_check, stoptime_format_check
)
)
if not (starttime_format_check and stoptime_format_check and chat_time_check):
logger.warn("时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format(starttime_format_check, stoptime_format_check))
if chat_start_time > "23:59": if chat_start_time > "23:59":
logger.error("启动时间可能存在问题,请修改!") logger.error("启动时间可能存在问题,请修改!")




+ 1
- 3
config.py View File

@@ -158,9 +158,7 @@ def load_config():
for name, value in os.environ.items(): for name, value in os.environ.items():
name = name.lower() name = name.lower()
if name in available_setting: if name in available_setting:
logger.info(
"[INIT] override config by environ args: {}={}".format(name, value)
)
logger.info("[INIT] override config by environ args: {}={}".format(name, value))
try: try:
config[name] = eval(value) config[name] = eval(value)
except: except:


+ 3
- 9
plugins/banwords/banwords.py View File

@@ -50,9 +50,7 @@ class Banwords(Plugin):
self.reply_action = conf.get("reply_action", "ignore") self.reply_action = conf.get("reply_action", "ignore")
logger.info("[Banwords] inited") logger.info("[Banwords] inited")
except Exception as e: except Exception as e:
logger.warn(
"[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords ."
)
logger.warn("[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords .")
raise e raise e


def on_handle_context(self, e_context: EventContext): def on_handle_context(self, e_context: EventContext):
@@ -72,9 +70,7 @@ class Banwords(Plugin):
return return
elif self.action == "replace": elif self.action == "replace":
if self.searchr.ContainsAny(content): if self.searchr.ContainsAny(content):
reply = Reply(
ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content)
)
reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content))
e_context["reply"] = reply e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS e_context.action = EventAction.BREAK_PASS
return return
@@ -94,9 +90,7 @@ class Banwords(Plugin):
return return
elif self.reply_action == "replace": elif self.reply_action == "replace":
if self.searchr.ContainsAny(content): if self.searchr.ContainsAny(content):
reply = Reply(
ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content)
)
reply = Reply(ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content))
e_context["reply"] = reply e_context["reply"] = reply
e_context.action = EventAction.CONTINUE e_context.action = EventAction.CONTINUE
return return


+ 7
- 32
plugins/bdunit/bdunit.py View File

@@ -76,9 +76,7 @@ class BDunit(Plugin):
Returns: Returns:
string: access_token string: access_token
""" """
url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(
self.api_key, self.secret_key
)
url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(self.api_key, self.secret_key)
payload = "" payload = ""
headers = {"Content-Type": "application/json", "Accept": "application/json"} headers = {"Content-Type": "application/json", "Accept": "application/json"}


@@ -94,10 +92,7 @@ class BDunit(Plugin):
:returns: UNIT 解析结果。如果解析失败,返回 None :returns: UNIT 解析结果。如果解析失败,返回 None
""" """


url = (
"https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token="
+ self.access_token
)
url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + self.access_token
request = { request = {
"query": query, "query": query,
"user_id": str(get_mac())[:32], "user_id": str(get_mac())[:32],
@@ -124,10 +119,7 @@ class BDunit(Plugin):
:param query: 用户的指令字符串 :param query: 用户的指令字符串
:returns: UNIT 解析结果。如果解析失败,返回 None :returns: UNIT 解析结果。如果解析失败,返回 None
""" """
url = (
"https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token="
+ self.access_token
)
url = "https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token=" + self.access_token
request = {"query": query, "user_id": str(get_mac())[:32]} request = {"query": query, "user_id": str(get_mac())[:32]}
body = { body = {
"log_id": str(uuid.uuid1()), "log_id": str(uuid.uuid1()),
@@ -170,11 +162,7 @@ class BDunit(Plugin):
if parsed and "result" in parsed and "response_list" in parsed["result"]: if parsed and "result" in parsed and "response_list" in parsed["result"]:
response_list = parsed["result"]["response_list"] response_list = parsed["result"]["response_list"]
for response in response_list: for response in response_list:
if (
"schema" in response
and "intent" in response["schema"]
and response["schema"]["intent"] == intent
):
if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent:
return True return True
return False return False
else: else:
@@ -198,12 +186,7 @@ class BDunit(Plugin):
logger.warning(e) logger.warning(e)
return [] return []
for response in response_list: for response in response_list:
if (
"schema" in response
and "intent" in response["schema"]
and "slots" in response["schema"]
and response["schema"]["intent"] == intent
):
if "schema" in response and "intent" in response["schema"] and "slots" in response["schema"] and response["schema"]["intent"] == intent:
return response["schema"]["slots"] return response["schema"]["slots"]
return [] return []
else: else:
@@ -239,11 +222,7 @@ class BDunit(Plugin):
if ( if (
"schema" in response "schema" in response
and "intent_confidence" in response["schema"] and "intent_confidence" in response["schema"]
and (
not answer
or response["schema"]["intent_confidence"]
> answer["schema"]["intent_confidence"]
)
and (not answer or response["schema"]["intent_confidence"] > answer["schema"]["intent_confidence"])
): ):
answer = response answer = response
return answer["action_list"][0]["say"] return answer["action_list"][0]["say"]
@@ -267,11 +246,7 @@ class BDunit(Plugin):
logger.warning(e) logger.warning(e)
return "" return ""
for response in response_list: for response in response_list:
if (
"schema" in response
and "intent" in response["schema"]
and response["schema"]["intent"] == intent
):
if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent:
try: try:
return response["action_list"][0]["say"] return response["action_list"][0]["say"]
except Exception as e: except Exception as e:


+ 2
- 8
plugins/dungeon/dungeon.py View File

@@ -84,9 +84,7 @@ class Dungeon(Plugin):
if len(clist) > 1: if len(clist) > 1:
story = clist[1] story = clist[1]
else: else:
story = (
"你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。"
)
story = "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。"
self.games[sessionid] = StoryTeller(bot, sessionid, story) self.games[sessionid] = StoryTeller(bot, sessionid, story)
reply = Reply(ReplyType.INFO, "冒险开始,你可以输入任意内容,让故事继续下去。故事背景是:" + story) reply = Reply(ReplyType.INFO, "冒险开始,你可以输入任意内容,让故事继续下去。故事背景是:" + story)
e_context["reply"] = reply e_context["reply"] = reply
@@ -102,11 +100,7 @@ class Dungeon(Plugin):
if kwargs.get("verbose") != True: if kwargs.get("verbose") != True:
return help_text return help_text
trigger_prefix = conf().get("plugin_trigger_prefix", "$") trigger_prefix = conf().get("plugin_trigger_prefix", "$")
help_text = (
f"{trigger_prefix}开始冒险 "
+ "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n"
+ f"{trigger_prefix}停止冒险: 结束游戏。\n"
)
help_text = f"{trigger_prefix}开始冒险 " + "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n" + f"{trigger_prefix}停止冒险: 结束游戏。\n"
if kwargs.get("verbose") == True: if kwargs.get("verbose") == True:
help_text += f"\n命令例子: '{trigger_prefix}开始冒险 你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。'" help_text += f"\n命令例子: '{trigger_prefix}开始冒险 你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。'"
return help_text return help_text

+ 7
- 23
plugins/godcmd/godcmd.py View File

@@ -140,9 +140,7 @@ def get_help_text(isadmin, isgroup):
if plugins[plugin].enabled and not plugins[plugin].hidden: if plugins[plugin].enabled and not plugins[plugin].hidden:
namecn = plugins[plugin].namecn namecn = plugins[plugin].namecn
help_text += "\n%s:" % namecn help_text += "\n%s:" % namecn
help_text += (
PluginManager().instances[plugin].get_help_text(verbose=False).strip()
)
help_text += PluginManager().instances[plugin].get_help_text(verbose=False).strip()


if ADMIN_COMMANDS and isadmin: if ADMIN_COMMANDS and isadmin:
help_text += "\n\n管理员指令:\n" help_text += "\n\n管理员指令:\n"
@@ -191,9 +189,7 @@ class Godcmd(Plugin):
COMMANDS["reset"]["alias"].append(custom_command) COMMANDS["reset"]["alias"].append(custom_command)


self.password = gconf["password"] self.password = gconf["password"]
self.admin_users = gconf[
"admin_users"
] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用
self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用
self.isrunning = True # 机器人是否运行中 self.isrunning = True # 机器人是否运行中


self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
@@ -215,7 +211,7 @@ class Godcmd(Plugin):
reply.content = f"空指令,输入#help查看指令列表\n" reply.content = f"空指令,输入#help查看指令列表\n"
e_context["reply"] = reply e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS e_context.action = EventAction.BREAK_PASS
return
return
# msg = e_context['context']['msg'] # msg = e_context['context']['msg']
channel = e_context["channel"] channel = e_context["channel"]
user = e_context["context"]["receiver"] user = e_context["context"]["receiver"]
@@ -248,11 +244,7 @@ class Godcmd(Plugin):
if not plugincls.enabled: if not plugincls.enabled:
continue continue
if query_name == name or query_name == plugincls.namecn: if query_name == name or query_name == plugincls.namecn:
ok, result = True, PluginManager().instances[
name
].get_help_text(
isgroup=isgroup, isadmin=isadmin, verbose=True
)
ok, result = True, PluginManager().instances[name].get_help_text(isgroup=isgroup, isadmin=isadmin, verbose=True)
break break
if not ok: if not ok:
result = "插件不存在或未启用" result = "插件不存在或未启用"
@@ -285,11 +277,7 @@ class Godcmd(Plugin):
if isgroup: if isgroup:
ok, result = False, "群聊不可执行管理员指令" ok, result = False, "群聊不可执行管理员指令"
else: else:
cmd = next(
c
for c, info in ADMIN_COMMANDS.items()
if cmd in info["alias"]
)
cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info["alias"])
if cmd == "stop": if cmd == "stop":
self.isrunning = False self.isrunning = False
ok, result = True, "服务已暂停" ok, result = True, "服务已暂停"
@@ -325,18 +313,14 @@ class Godcmd(Plugin):
PluginManager().activate_plugins() PluginManager().activate_plugins()
if len(new_plugins) > 0: if len(new_plugins) > 0:
result += "\n发现新插件:\n" result += "\n发现新插件:\n"
result += "\n".join(
[f"{p.name}_v{p.version}" for p in new_plugins]
)
result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins])
else: else:
result += ", 未发现新插件" result += ", 未发现新插件"
elif cmd == "setpri": elif cmd == "setpri":
if len(args) != 2: if len(args) != 2:
ok, result = False, "请提供插件名和优先级" ok, result = False, "请提供插件名和优先级"
else: else:
ok = PluginManager().set_plugin_priority(
args[0], int(args[1])
)
ok = PluginManager().set_plugin_priority(args[0], int(args[1]))
if ok: if ok:
result = "插件" + args[0] + "优先级已设置为" + args[1] result = "插件" + args[0] + "优先级已设置为" + args[1]
else: else:


+ 2
- 6
plugins/hello/hello.py View File

@@ -33,9 +33,7 @@ class Hello(Plugin):
if e_context["context"].type == ContextType.JOIN_GROUP: if e_context["context"].type == ContextType.JOIN_GROUP:
e_context["context"].type = ContextType.TEXT e_context["context"].type = ContextType.TEXT
msg: ChatMessage = e_context["context"]["msg"] msg: ChatMessage = e_context["context"]["msg"]
e_context[
"context"
].content = f'请你随机使用一种风格说一句问候语来欢迎新用户"{msg.actual_user_nickname}"加入群聊。'
e_context["context"].content = f'请你随机使用一种风格说一句问候语来欢迎新用户"{msg.actual_user_nickname}"加入群聊。'
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
return return


@@ -53,9 +51,7 @@ class Hello(Plugin):
reply.type = ReplyType.TEXT reply.type = ReplyType.TEXT
msg: ChatMessage = e_context["context"]["msg"] msg: ChatMessage = e_context["context"]["msg"]
if e_context["context"]["isgroup"]: if e_context["context"]["isgroup"]:
reply.content = (
f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
)
reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
else: else:
reply.content = f"Hello, {msg.from_user_nickname}" reply.content = f"Hello, {msg.from_user_nickname}"
e_context["reply"] = reply e_context["reply"] = reply


+ 2
- 2
plugins/keyword/config.json.template View File

@@ -1,5 +1,5 @@
{ {
"keyword": { "keyword": {
"关键字匹配": "测试成功"
"关键字匹配": "测试成功"
} }
}
}

+ 1
- 3
plugins/keyword/keyword.py View File

@@ -41,9 +41,7 @@ class Keyword(Plugin):
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
logger.info("[keyword] inited.") logger.info("[keyword] inited.")
except Exception as e: except Exception as e:
logger.warn(
"[keyword] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/keyword ."
)
logger.warn("[keyword] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/keyword .")
raise e raise e


def on_handle_context(self, e_context: EventContext): def on_handle_context(self, e_context: EventContext):


+ 14
- 46
plugins/plugin_manager.py View File

@@ -31,23 +31,14 @@ class PluginManager:
plugincls.desc = kwargs.get("desc") plugincls.desc = kwargs.get("desc")
plugincls.author = kwargs.get("author") plugincls.author = kwargs.get("author")
plugincls.path = self.current_plugin_path plugincls.path = self.current_plugin_path
plugincls.version = (
kwargs.get("version") if kwargs.get("version") != None else "1.0"
)
plugincls.namecn = (
kwargs.get("namecn") if kwargs.get("namecn") != None else name
)
plugincls.hidden = (
kwargs.get("hidden") if kwargs.get("hidden") != None else False
)
plugincls.version = kwargs.get("version") if kwargs.get("version") != None else "1.0"
plugincls.namecn = kwargs.get("namecn") if kwargs.get("namecn") != None else name
plugincls.hidden = kwargs.get("hidden") if kwargs.get("hidden") != None else False
plugincls.enabled = True plugincls.enabled = True
if self.current_plugin_path == None: if self.current_plugin_path == None:
raise Exception("Plugin path not set") raise Exception("Plugin path not set")
self.plugins[name.upper()] = plugincls self.plugins[name.upper()] = plugincls
logger.info(
"Plugin %s_v%s registered, path=%s"
% (name, plugincls.version, plugincls.path)
)
logger.info("Plugin %s_v%s registered, path=%s" % (name, plugincls.version, plugincls.path))


return wrapper return wrapper


@@ -62,9 +53,7 @@ class PluginManager:
if os.path.exists("./plugins/plugins.json"): if os.path.exists("./plugins/plugins.json"):
with open("./plugins/plugins.json", "r", encoding="utf-8") as f: with open("./plugins/plugins.json", "r", encoding="utf-8") as f:
pconf = json.load(f) pconf = json.load(f)
pconf["plugins"] = SortedDict(
lambda k, v: v["priority"], pconf["plugins"], reverse=True
)
pconf["plugins"] = SortedDict(lambda k, v: v["priority"], pconf["plugins"], reverse=True)
else: else:
modified = True modified = True
pconf = {"plugins": SortedDict(lambda k, v: v["priority"], reverse=True)} pconf = {"plugins": SortedDict(lambda k, v: v["priority"], reverse=True)}
@@ -90,26 +79,16 @@ class PluginManager:
if plugin_path in self.loaded: if plugin_path in self.loaded:
if self.loaded[plugin_path] == None: if self.loaded[plugin_path] == None:
logger.info("reload module %s" % plugin_name) logger.info("reload module %s" % plugin_name)
self.loaded[plugin_path] = importlib.reload(
sys.modules[import_path]
)
dependent_module_names = [
name
for name in sys.modules.keys()
if name.startswith(import_path + ".")
]
self.loaded[plugin_path] = importlib.reload(sys.modules[import_path])
dependent_module_names = [name for name in sys.modules.keys() if name.startswith(import_path + ".")]
for name in dependent_module_names: for name in dependent_module_names:
logger.info("reload module %s" % name) logger.info("reload module %s" % name)
importlib.reload(sys.modules[name]) importlib.reload(sys.modules[name])
else: else:
self.loaded[plugin_path] = importlib.import_module(
import_path
)
self.loaded[plugin_path] = importlib.import_module(import_path)
self.current_plugin_path = None self.current_plugin_path = None
except Exception as e: except Exception as e:
logger.exception(
"Failed to import plugin %s: %s" % (plugin_name, e)
)
logger.exception("Failed to import plugin %s: %s" % (plugin_name, e))
continue continue
pconf = self.pconf pconf = self.pconf
news = [self.plugins[name] for name in self.plugins] news = [self.plugins[name] for name in self.plugins]
@@ -119,9 +98,7 @@ class PluginManager:
rawname = plugincls.name rawname = plugincls.name
if rawname not in pconf["plugins"]: if rawname not in pconf["plugins"]:
modified = True modified = True
logger.info(
"Plugin %s not found in pconfig, adding to pconfig..." % name
)
logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name)
pconf["plugins"][rawname] = { pconf["plugins"][rawname] = {
"enabled": plugincls.enabled, "enabled": plugincls.enabled,
"priority": plugincls.priority, "priority": plugincls.priority,
@@ -136,9 +113,7 @@ class PluginManager:


def refresh_order(self): def refresh_order(self):
for event in self.listening_plugins.keys(): for event in self.listening_plugins.keys():
self.listening_plugins[event].sort(
key=lambda name: self.plugins[name].priority, reverse=True
)
self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True)


def activate_plugins(self): # 生成新开启的插件实例 def activate_plugins(self): # 生成新开启的插件实例
failed_plugins = [] failed_plugins = []
@@ -184,13 +159,8 @@ class PluginManager:
def emit_event(self, e_context: EventContext, *args, **kwargs): def emit_event(self, e_context: EventContext, *args, **kwargs):
if e_context.event in self.listening_plugins: if e_context.event in self.listening_plugins:
for name in self.listening_plugins[e_context.event]: for name in self.listening_plugins[e_context.event]:
if (
self.plugins[name].enabled
and e_context.action == EventAction.CONTINUE
):
logger.debug(
"Plugin %s triggered by event %s" % (name, e_context.event)
)
if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE:
logger.debug("Plugin %s triggered by event %s" % (name, e_context.event))
instance = self.instances[name] instance = self.instances[name]
instance.handlers[e_context.event](e_context, *args, **kwargs) instance.handlers[e_context.event](e_context, *args, **kwargs)
return e_context return e_context
@@ -262,9 +232,7 @@ class PluginManager:
source = json.load(f) source = json.load(f)
if repo in source["repo"]: if repo in source["repo"]:
repo = source["repo"][repo]["url"] repo = source["repo"][repo]["url"]
match = re.match(
r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo
)
match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo)
if not match: if not match:
return False, "安装插件失败,source中的仓库地址不合法" return False, "安装插件失败,source中的仓库地址不合法"
else: else:


+ 7
- 24
plugins/role/role.py View File

@@ -69,13 +69,9 @@ class Role(Plugin):
logger.info("[Role] inited") logger.info("[Role] inited")
except Exception as e: except Exception as e:
if isinstance(e, FileNotFoundError): if isinstance(e, FileNotFoundError):
logger.warn(
f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ."
)
logger.warn(f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .")
else: else:
logger.warn(
"[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ."
)
logger.warn("[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .")
raise e raise e


def get_role(self, name, find_closest=True, min_sim=0.35): def get_role(self, name, find_closest=True, min_sim=0.35):
@@ -143,9 +139,7 @@ class Role(Plugin):
else: else:
help_text = f"未知角色类型。\n" help_text = f"未知角色类型。\n"
help_text += "目前的角色类型有: \n" help_text += "目前的角色类型有: \n"
help_text += (
",".join([self.tags[tag][0] for tag in self.tags]) + "\n"
)
help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "\n"
else: else:
help_text = f"请输入角色类型。\n" help_text = f"请输入角色类型。\n"
help_text += "目前的角色类型有: \n" help_text += "目前的角色类型有: \n"
@@ -158,9 +152,7 @@ class Role(Plugin):
return return
logger.debug("[Role] on_handle_context. content: %s" % content) logger.debug("[Role] on_handle_context. content: %s" % content)
if desckey is not None: if desckey is not None:
if len(clist) == 1 or (
len(clist) > 1 and clist[1].lower() in ["help", "帮助"]
):
if len(clist) == 1 or (len(clist) > 1 and clist[1].lower() in ["help", "帮助"]):
reply = Reply(ReplyType.INFO, self.get_help_text(verbose=True)) reply = Reply(ReplyType.INFO, self.get_help_text(verbose=True))
e_context["reply"] = reply e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS e_context.action = EventAction.BREAK_PASS
@@ -178,9 +170,7 @@ class Role(Plugin):
self.roles[role][desckey], self.roles[role][desckey],
self.roles[role].get("wrapper", "%s"), self.roles[role].get("wrapper", "%s"),
) )
reply = Reply(
ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey]
)
reply = Reply(ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey])
e_context["reply"] = reply e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS e_context.action = EventAction.BREAK_PASS
elif customize == True: elif customize == True:
@@ -199,17 +189,10 @@ class Role(Plugin):
if not verbose: if not verbose:
return help_text return help_text
trigger_prefix = conf().get("plugin_trigger_prefix", "$") trigger_prefix = conf().get("plugin_trigger_prefix", "$")
help_text = (
f"使用方法:\n{trigger_prefix}角色"
+ " 预设角色名: 设定角色为{预设角色名}。\n"
+ f"{trigger_prefix}role"
+ " 预设角色名: 同上,但使用英文设定。\n"
)
help_text = f"使用方法:\n{trigger_prefix}角色" + " 预设角色名: 设定角色为{预设角色名}。\n" + f"{trigger_prefix}role" + " 预设角色名: 同上,但使用英文设定。\n"
help_text += f"{trigger_prefix}设定扮演" + " 角色设定: 设定自定义角色人设为{角色设定}。\n" help_text += f"{trigger_prefix}设定扮演" + " 角色设定: 设定自定义角色人设为{角色设定}。\n"
help_text += f"{trigger_prefix}停止扮演: 清除设定的角色。\n" help_text += f"{trigger_prefix}停止扮演: 清除设定的角色。\n"
help_text += (
f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n"
)
help_text += f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n"
help_text += "\n目前的角色类型有: \n" help_text += "\n目前的角色类型有: \n"
help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "。\n" help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "。\n"
help_text += f"\n命令例子: \n{trigger_prefix}角色 写作助理\n" help_text += f"\n命令例子: \n{trigger_prefix}角色 写作助理\n"


+ 1
- 1
plugins/tool/README.md View File

@@ -60,7 +60,7 @@


> 该tool每天返回内容相同 > 该tool每天返回内容相同


#### 6.3. finance-news
#### 6.3. finance-news
###### 获取实时的金融财政新闻 ###### 获取实时的金融财政新闻


> 该工具需要解决browser tool 的google-chrome依赖安装 > 该工具需要解决browser tool 的google-chrome依赖安装


+ 4
- 10
plugins/tool/tool.py View File

@@ -82,9 +82,7 @@ class Tool(Plugin):
return return
elif content_list[1].startswith("reset"): elif content_list[1].startswith("reset"):
logger.debug("[tool]: remind") logger.debug("[tool]: remind")
e_context[
"context"
].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符"
e_context["context"].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符"


e_context.action = EventAction.BREAK e_context.action = EventAction.BREAK
return return
@@ -93,18 +91,14 @@ class Tool(Plugin):


# Don't modify bot name # Don't modify bot name
all_sessions = Bridge().get_bot("chat").sessions all_sessions = Bridge().get_bot("chat").sessions
user_session = all_sessions.session_query(
query, e_context["context"]["session_id"]
).messages
user_session = all_sessions.session_query(query, e_context["context"]["session_id"]).messages


# chatgpt-tool-hub will reply you with many tools # chatgpt-tool-hub will reply you with many tools
logger.debug("[tool]: just-go") logger.debug("[tool]: just-go")
try: try:
_reply = self.app.ask(query, user_session) _reply = self.app.ask(query, user_session)
e_context.action = EventAction.BREAK_PASS e_context.action = EventAction.BREAK_PASS
all_sessions.session_reply(
_reply, e_context["context"]["session_id"]
)
all_sessions.session_reply(_reply, e_context["context"]["session_id"])
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
logger.error(str(e)) logger.error(str(e))
@@ -178,4 +172,4 @@ class Tool(Plugin):
# filter not support tool # filter not support tool
tool_list = self._filter_tool_list(tool_config.get("tools", [])) tool_list = self._filter_tool_list(tool_config.get("tools", []))


return app.create_app(tools_list=tool_list, **app_kwargs)
return app.create_app(tools_list=tool_list, **app_kwargs)

+ 5
- 15
voice/audio_convert.py View File

@@ -33,6 +33,7 @@ def get_pcm_from_wav(wav_path):
wav = wave.open(wav_path, "rb") wav = wave.open(wav_path, "rb")
return wav.readframes(wav.getnframes()) return wav.readframes(wav.getnframes())



def any_to_mp3(any_path, mp3_path): def any_to_mp3(any_path, mp3_path):
""" """
把任意格式转成mp3文件 把任意格式转成mp3文件
@@ -40,16 +41,13 @@ def any_to_mp3(any_path, mp3_path):
if any_path.endswith(".mp3"): if any_path.endswith(".mp3"):
shutil.copy2(any_path, mp3_path) shutil.copy2(any_path, mp3_path)
return return
if (
any_path.endswith(".sil")
or any_path.endswith(".silk")
or any_path.endswith(".slk")
):
if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
sil_to_wav(any_path, any_path) sil_to_wav(any_path, any_path)
any_path = mp3_path any_path = mp3_path
audio = AudioSegment.from_file(any_path) audio = AudioSegment.from_file(any_path)
audio.export(mp3_path, format="mp3") audio.export(mp3_path, format="mp3")



def any_to_wav(any_path, wav_path): def any_to_wav(any_path, wav_path):
""" """
把任意格式转成wav文件 把任意格式转成wav文件
@@ -57,11 +55,7 @@ def any_to_wav(any_path, wav_path):
if any_path.endswith(".wav"): if any_path.endswith(".wav"):
shutil.copy2(any_path, wav_path) shutil.copy2(any_path, wav_path)
return return
if (
any_path.endswith(".sil")
or any_path.endswith(".silk")
or any_path.endswith(".slk")
):
if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
return sil_to_wav(any_path, wav_path) return sil_to_wav(any_path, wav_path)
audio = AudioSegment.from_file(any_path) audio = AudioSegment.from_file(any_path)
audio.export(wav_path, format="wav") audio.export(wav_path, format="wav")
@@ -71,11 +65,7 @@ def any_to_sil(any_path, sil_path):
""" """
把任意格式转成sil文件 把任意格式转成sil文件
""" """
if (
any_path.endswith(".sil")
or any_path.endswith(".silk")
or any_path.endswith(".slk")
):
if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
shutil.copy2(any_path, sil_path) shutil.copy2(any_path, sil_path)
return 10000 return 10000
audio = AudioSegment.from_file(any_path) audio = AudioSegment.from_file(any_path)


+ 9
- 33
voice/azure/azure_voice.py View File

@@ -40,57 +40,33 @@ class AzureVoice(Voice):
config = json.load(fr) config = json.load(fr)
self.api_key = conf().get("azure_voice_api_key") self.api_key = conf().get("azure_voice_api_key")
self.api_region = conf().get("azure_voice_region") self.api_region = conf().get("azure_voice_region")
self.speech_config = speechsdk.SpeechConfig(
subscription=self.api_key, region=self.api_region
)
self.speech_config.speech_synthesis_voice_name = config[
"speech_synthesis_voice_name"
]
self.speech_config.speech_recognition_language = config[
"speech_recognition_language"
]
self.speech_config = speechsdk.SpeechConfig(subscription=self.api_key, region=self.api_region)
self.speech_config.speech_synthesis_voice_name = config["speech_synthesis_voice_name"]
self.speech_config.speech_recognition_language = config["speech_recognition_language"]
except Exception as e: except Exception as e:
logger.warn("AzureVoice init failed: %s, ignore " % e) logger.warn("AzureVoice init failed: %s, ignore " % e)


def voiceToText(self, voice_file): def voiceToText(self, voice_file):
audio_config = speechsdk.AudioConfig(filename=voice_file) audio_config = speechsdk.AudioConfig(filename=voice_file)
speech_recognizer = speechsdk.SpeechRecognizer(
speech_config=self.speech_config, audio_config=audio_config
)
speech_recognizer = speechsdk.SpeechRecognizer(speech_config=self.speech_config, audio_config=audio_config)
result = speech_recognizer.recognize_once() result = speech_recognizer.recognize_once()
if result.reason == speechsdk.ResultReason.RecognizedSpeech: if result.reason == speechsdk.ResultReason.RecognizedSpeech:
logger.info(
"[Azure] voiceToText voice file name={} text={}".format(
voice_file, result.text
)
)
logger.info("[Azure] voiceToText voice file name={} text={}".format(voice_file, result.text))
reply = Reply(ReplyType.TEXT, result.text) reply = Reply(ReplyType.TEXT, result.text)
else: else:
logger.error(
"[Azure] voiceToText error, result={}, canceldetails={}".format(
result, result.cancellation_details
)
)
logger.error("[Azure] voiceToText error, result={}, canceldetails={}".format(result, result.cancellation_details))
reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败") reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败")
return reply return reply


def textToVoice(self, text): def textToVoice(self, text):
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav" fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav"
audio_config = speechsdk.AudioConfig(filename=fileName) audio_config = speechsdk.AudioConfig(filename=fileName)
speech_synthesizer = speechsdk.SpeechSynthesizer(
speech_config=self.speech_config, audio_config=audio_config
)
speech_synthesizer = speechsdk.SpeechSynthesizer(speech_config=self.speech_config, audio_config=audio_config)
result = speech_synthesizer.speak_text(text) result = speech_synthesizer.speak_text(text)
if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted: if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
logger.info(
"[Azure] textToVoice text={} voice file name={}".format(text, fileName)
)
logger.info("[Azure] textToVoice text={} voice file name={}".format(text, fileName))
reply = Reply(ReplyType.VOICE, fileName) reply = Reply(ReplyType.VOICE, fileName)
else: else:
logger.error(
"[Azure] textToVoice error, result={}, canceldetails={}".format(
result, result.cancellation_details
)
)
logger.error("[Azure] textToVoice error, result={}, canceldetails={}".format(result, result.cancellation_details))
reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败") reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
return reply return reply

+ 1
- 3
voice/baidu/baidu_voice.py View File

@@ -85,9 +85,7 @@ class BaiduVoice(Voice):
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3" fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3"
with open(fileName, "wb") as f: with open(fileName, "wb") as f:
f.write(result) f.write(result)
logger.info(
"[Baidu] textToVoice text={} voice file name={}".format(text, fileName)
)
logger.info("[Baidu] textToVoice text={} voice file name={}".format(text, fileName))
reply = Reply(ReplyType.VOICE, fileName) reply = Reply(ReplyType.VOICE, fileName)
else: else:
logger.error("[Baidu] textToVoice error={}".format(result)) logger.error("[Baidu] textToVoice error={}".format(result))


+ 2
- 8
voice/google/google_voice.py View File

@@ -24,11 +24,7 @@ class GoogleVoice(Voice):
audio = self.recognizer.record(source) audio = self.recognizer.record(source)
try: try:
text = self.recognizer.recognize_google(audio, language="zh-CN") text = self.recognizer.recognize_google(audio, language="zh-CN")
logger.info(
"[Google] voiceToText text={} voice file name={}".format(
text, voice_file
)
)
logger.info("[Google] voiceToText text={} voice file name={}".format(text, voice_file))
reply = Reply(ReplyType.TEXT, text) reply = Reply(ReplyType.TEXT, text)
except speech_recognition.UnknownValueError: except speech_recognition.UnknownValueError:
reply = Reply(ReplyType.ERROR, "抱歉,我听不懂") reply = Reply(ReplyType.ERROR, "抱歉,我听不懂")
@@ -42,9 +38,7 @@ class GoogleVoice(Voice):
mp3File = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3" mp3File = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3"
tts = gTTS(text=text, lang="zh") tts = gTTS(text=text, lang="zh")
tts.save(mp3File) tts.save(mp3File)
logger.info(
"[Google] textToVoice text={} voice file name={}".format(text, mp3File)
)
logger.info("[Google] textToVoice text={} voice file name={}".format(text, mp3File))
reply = Reply(ReplyType.VOICE, mp3File) reply = Reply(ReplyType.VOICE, mp3File)
except Exception as e: except Exception as e:
reply = Reply(ReplyType.ERROR, str(e)) reply = Reply(ReplyType.ERROR, str(e))


+ 1
- 5
voice/openai/openai_voice.py View File

@@ -22,11 +22,7 @@ class OpenaiVoice(Voice):
result = openai.Audio.transcribe("whisper-1", file) result = openai.Audio.transcribe("whisper-1", file)
text = result["text"] text = result["text"]
reply = Reply(ReplyType.TEXT, text) reply = Reply(ReplyType.TEXT, text)
logger.info(
"[Openai] voiceToText text={} voice file name={}".format(
text, voice_file
)
)
logger.info("[Openai] voiceToText text={} voice file name={}".format(text, voice_file))
except Exception as e: except Exception as e:
reply = Reply(ReplyType.ERROR, str(e)) reply = Reply(ReplyType.ERROR, str(e))
finally: finally:


+ 7
- 5
voice/pytts/pytts_voice.py View File

@@ -5,6 +5,7 @@ pytts voice service (offline)
import os import os
import sys import sys
import time import time

import pyttsx3 import pyttsx3


from bridge.reply import Reply, ReplyType from bridge.reply import Reply, ReplyType
@@ -12,6 +13,7 @@ from common.log import logger
from common.tmp_dir import TmpDir from common.tmp_dir import TmpDir
from voice.voice import Voice from voice.voice import Voice



class PyttsVoice(Voice): class PyttsVoice(Voice):
engine = pyttsx3.init() engine = pyttsx3.init()


@@ -20,7 +22,7 @@ class PyttsVoice(Voice):
self.engine.setProperty("rate", 125) self.engine.setProperty("rate", 125)
# 音量 # 音量
self.engine.setProperty("volume", 1.0) self.engine.setProperty("volume", 1.0)
if sys.platform == 'win32':
if sys.platform == "win32":
for voice in self.engine.getProperty("voices"): for voice in self.engine.getProperty("voices"):
if "Chinese" in voice.name: if "Chinese" in voice.name:
self.engine.setProperty("voice", voice.id) self.engine.setProperty("voice", voice.id)
@@ -33,23 +35,23 @@ class PyttsVoice(Voice):
def textToVoice(self, text): def textToVoice(self, text):
try: try:
# avoid the same filename # avoid the same filename
wavFileName = "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7fffffff) + ".wav"
wavFileName = "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".wav"
wavFile = TmpDir().path() + wavFileName wavFile = TmpDir().path() + wavFileName
logger.info("[Pytts] textToVoice text={} voice file name={}".format(text, wavFile)) logger.info("[Pytts] textToVoice text={} voice file name={}".format(text, wavFile))


self.engine.save_to_file(text, wavFile) self.engine.save_to_file(text, wavFile)


if sys.platform == 'win32':
if sys.platform == "win32":
self.engine.runAndWait() self.engine.runAndWait()
else: else:
# In ubuntu, runAndWait do not really wait until the file created.
# In ubuntu, runAndWait do not really wait until the file created.
# It will return once the task queue is empty, but the task is still running in coroutine. # It will return once the task queue is empty, but the task is still running in coroutine.
# And if you call runAndWait() and time.sleep() twice, it will stuck, so do not use this. # And if you call runAndWait() and time.sleep() twice, it will stuck, so do not use this.
# If you want to fix this, add self._proxy.setBusy(True) in line 127 in espeak.py, at the beginning of the function save_to_file. # If you want to fix this, add self._proxy.setBusy(True) in line 127 in espeak.py, at the beginning of the function save_to_file.
# self.engine.runAndWait() # self.engine.runAndWait()


# Before espeak fix this problem, we iterate the generator and control the waiting by ourself. # Before espeak fix this problem, we iterate the generator and control the waiting by ourself.
# But this is not the canonical way to use it, for example if the file already exists it also cannot wait.
# But this is not the canonical way to use it, for example if the file already exists it also cannot wait.
self.engine.iterate() self.engine.iterate()
while self.engine.isBusy() or wavFileName not in os.listdir(TmpDir().path()): while self.engine.isBusy() or wavFileName not in os.listdir(TmpDir().path()):
time.sleep(0.1) time.sleep(0.1)


Loading…
Cancel
Save