Browse Source

formatting: run precommit on all files

master
lanvent 1 year ago
parent
commit
618c94edb8
40 changed files with 228 additions and 646 deletions
  1. +1
    -0
      app.py
  2. +2
    -10
      bot/baidu/baidu_unit_bot.py
  3. +7
    -22
      bot/chatgpt/chat_gpt_bot.py
  4. +5
    -17
      bot/chatgpt/chat_gpt_session.py
  5. +7
    -21
      bot/openai/open_ai_bot.py
  6. +2
    -8
      bot/openai/open_ai_image.py
  7. +3
    -13
      bot/openai/open_ai_session.py
  8. +4
    -16
      bot/session_manager.py
  9. +1
    -3
      bridge/context.py
  10. +28
    -103
      channel/chat_channel.py
  11. +1
    -3
      channel/terminal/terminal_channel.py
  12. +8
    -31
      channel/wechat/wechat_channel.py
  13. +6
    -20
      channel/wechat/wechat_message.py
  14. +5
    -15
      channel/wechat/wechaty_channel.py
  15. +3
    -9
      channel/wechat/wechaty_message.py
  16. +8
    -17
      channel/wechatmp/active_reply.py
  17. +5
    -3
      channel/wechatmp/common.py
  18. +18
    -36
      channel/wechatmp/passive_reply.py
  19. +20
    -27
      channel/wechatmp/wechatmp_channel.py
  20. +10
    -11
      channel/wechatmp/wechatmp_client.py
  21. +5
    -14
      channel/wechatmp/wechatmp_message.py
  22. +3
    -11
      common/time_check.py
  23. +1
    -3
      config.py
  24. +3
    -9
      plugins/banwords/banwords.py
  25. +7
    -32
      plugins/bdunit/bdunit.py
  26. +2
    -8
      plugins/dungeon/dungeon.py
  27. +7
    -23
      plugins/godcmd/godcmd.py
  28. +2
    -6
      plugins/hello/hello.py
  29. +2
    -2
      plugins/keyword/config.json.template
  30. +1
    -3
      plugins/keyword/keyword.py
  31. +14
    -46
      plugins/plugin_manager.py
  32. +7
    -24
      plugins/role/role.py
  33. +1
    -1
      plugins/tool/README.md
  34. +4
    -10
      plugins/tool/tool.py
  35. +5
    -15
      voice/audio_convert.py
  36. +9
    -33
      voice/azure/azure_voice.py
  37. +1
    -3
      voice/baidu/baidu_voice.py
  38. +2
    -8
      voice/google/google_voice.py
  39. +1
    -5
      voice/openai/openai_voice.py
  40. +7
    -5
      voice/pytts/pytts_voice.py

+ 1
- 0
app.py View File

@@ -19,6 +19,7 @@ def sigterm_handler_wrap(_signo):
if callable(old_handler): # check old_handler
return old_handler(_signo, _stack_frame)
sys.exit(0)

signal.signal(_signo, func)




+ 2
- 10
bot/baidu/baidu_unit_bot.py View File

@@ -10,10 +10,7 @@ from bridge.reply import Reply, ReplyType
class BaiduUnitBot(Bot):
def reply(self, query, context=None):
token = self.get_token()
url = (
"https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token="
+ token
)
url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + token
post_data = (
'{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"'
+ query
@@ -32,12 +29,7 @@ class BaiduUnitBot(Bot):
def get_token(self):
access_key = "YOUR_ACCESS_KEY"
secret_key = "YOUR_SECRET_KEY"
host = (
"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id="
+ access_key
+ "&client_secret="
+ secret_key
)
host = "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" + access_key + "&client_secret=" + secret_key
response = requests.get(host)
if response:
print(response.json())


+ 7
- 22
bot/chatgpt/chat_gpt_bot.py View File

@@ -30,23 +30,15 @@ class ChatGPTBot(Bot, OpenAIImage):
if conf().get("rate_limit_chatgpt"):
self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))

self.sessions = SessionManager(
ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo"
)
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
self.args = {
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
# "max_tokens":4096, # 回复最大的字符数
"top_p": 1,
"frequency_penalty": conf().get(
"frequency_penalty", 0.0
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"presence_penalty": conf().get(
"presence_penalty", 0.0
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"request_timeout": conf().get(
"request_timeout", None
), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
}

@@ -87,15 +79,10 @@ class ChatGPTBot(Bot, OpenAIImage):
reply_content["completion_tokens"],
)
)
if (
reply_content["completion_tokens"] == 0
and len(reply_content["content"]) > 0
):
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
reply = Reply(ReplyType.ERROR, reply_content["content"])
elif reply_content["completion_tokens"] > 0:
self.sessions.session_reply(
reply_content["content"], session_id, reply_content["total_tokens"]
)
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
reply = Reply(ReplyType.TEXT, reply_content["content"])
else:
reply = Reply(ReplyType.ERROR, reply_content["content"])
@@ -126,9 +113,7 @@ class ChatGPTBot(Bot, OpenAIImage):
if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
# if api_key == None, the default openai.api_key will be used
response = openai.ChatCompletion.create(
api_key=api_key, messages=session.messages, **self.args
)
response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **self.args)
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
return {
"total_tokens": response["usage"]["total_tokens"],


+ 5
- 17
bot/chatgpt/chat_gpt_session.py View File

@@ -25,9 +25,7 @@ class ChatGPTSession(Session):
precise = False
if cur_tokens is None:
raise e
logger.debug(
"Exception when counting tokens precisely for query: {}".format(e)
)
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
while cur_tokens > max_tokens:
if len(self.messages) > 2:
self.messages.pop(1)
@@ -39,16 +37,10 @@ class ChatGPTSession(Session):
cur_tokens = cur_tokens - max_tokens
break
elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
logger.warn(
"user message exceed max_tokens. total_tokens={}".format(cur_tokens)
)
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
break
else:
logger.debug(
"max_tokens={}, total_tokens={}, len(messages)={}".format(
max_tokens, cur_tokens, len(self.messages)
)
)
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
break
if precise:
cur_tokens = self.calc_tokens()
@@ -75,17 +67,13 @@ def num_tokens_from_messages(messages, model):
elif model == "gpt-4":
return num_tokens_from_messages(messages, model="gpt-4-0314")
elif model == "gpt-3.5-turbo-0301":
tokens_per_message = (
4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
)
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
tokens_per_name = -1 # if there's a name, the role is omitted
elif model == "gpt-4-0314":
tokens_per_message = 3
tokens_per_name = 1
else:
logger.warn(
f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301."
)
logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301.")
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
num_tokens = 0
for message in messages:


+ 7
- 21
bot/openai/open_ai_bot.py View File

@@ -28,23 +28,15 @@ class OpenAIBot(Bot, OpenAIImage):
if proxy:
openai.proxy = proxy

self.sessions = SessionManager(
OpenAISession, model=conf().get("model") or "text-davinci-003"
)
self.sessions = SessionManager(OpenAISession, model=conf().get("model") or "text-davinci-003")
self.args = {
"model": conf().get("model") or "text-davinci-003", # 对话模型的名称
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
"max_tokens": 1200, # 回复最大的字符数
"top_p": 1,
"frequency_penalty": conf().get(
"frequency_penalty", 0.0
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"presence_penalty": conf().get(
"presence_penalty", 0.0
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"request_timeout": conf().get(
"request_timeout", None
), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
"stop": ["\n\n\n"],
}
@@ -71,17 +63,13 @@ class OpenAIBot(Bot, OpenAIImage):
result["content"],
)
logger.debug(
"[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
str(session), session_id, reply_content, completion_tokens
)
"[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)
)

if total_tokens == 0:
reply = Reply(ReplyType.ERROR, reply_content)
else:
self.sessions.session_reply(
reply_content, session_id, total_tokens
)
self.sessions.session_reply(reply_content, session_id, total_tokens)
reply = Reply(ReplyType.TEXT, reply_content)
return reply
elif context.type == ContextType.IMAGE_CREATE:
@@ -96,9 +84,7 @@ class OpenAIBot(Bot, OpenAIImage):
def reply_text(self, session: OpenAISession, retry_count=0):
try:
response = openai.Completion.create(prompt=str(session), **self.args)
res_content = (
response.choices[0]["text"].strip().replace("<|endoftext|>", "")
)
res_content = response.choices[0]["text"].strip().replace("<|endoftext|>", "")
total_tokens = response["usage"]["total_tokens"]
completion_tokens = response["usage"]["completion_tokens"]
logger.info("[OPEN_AI] reply={}".format(res_content))


+ 2
- 8
bot/openai/open_ai_image.py View File

@@ -23,9 +23,7 @@ class OpenAIImage(object):
response = openai.Image.create(
prompt=query, # 图片描述
n=1, # 每次生成图片的数量
size=conf().get(
"image_create_size", "256x256"
), # 图片大小,可选有 256x256, 512x512, 1024x1024
size=conf().get("image_create_size", "256x256"), # 图片大小,可选有 256x256, 512x512, 1024x1024
)
image_url = response["data"][0]["url"]
logger.info("[OPEN_AI] image_url={}".format(image_url))
@@ -34,11 +32,7 @@ class OpenAIImage(object):
logger.warn(e)
if retry_count < 1:
time.sleep(5)
logger.warn(
"[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(
retry_count + 1
)
)
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1))
return self.create_img(query, retry_count + 1)
else:
return False, "提问太快啦,请休息一下再问我吧"


+ 3
- 13
bot/openai/open_ai_session.py View File

@@ -36,9 +36,7 @@ class OpenAISession(Session):
precise = False
if cur_tokens is None:
raise e
logger.debug(
"Exception when counting tokens precisely for query: {}".format(e)
)
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
while cur_tokens > max_tokens:
if len(self.messages) > 1:
self.messages.pop(0)
@@ -50,18 +48,10 @@ class OpenAISession(Session):
cur_tokens = len(str(self))
break
elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
logger.warn(
"user question exceed max_tokens. total_tokens={}".format(
cur_tokens
)
)
logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens))
break
else:
logger.debug(
"max_tokens={}, total_tokens={}, len(conversation)={}".format(
max_tokens, cur_tokens, len(self.messages)
)
)
logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
break
if precise:
cur_tokens = self.calc_tokens()


+ 4
- 16
bot/session_manager.py View File

@@ -55,9 +55,7 @@ class SessionManager(object):
return self.sessioncls(session_id, system_prompt, **self.session_args)

if session_id not in self.sessions:
self.sessions[session_id] = self.sessioncls(
session_id, system_prompt, **self.session_args
)
self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args)
elif system_prompt is not None: # 如果有新的system_prompt,更新并重置session
self.sessions[session_id].set_system_prompt(system_prompt)
session = self.sessions[session_id]
@@ -71,9 +69,7 @@ class SessionManager(object):
total_tokens = session.discard_exceeding(max_tokens, None)
logger.debug("prompt tokens used={}".format(total_tokens))
except Exception as e:
logger.debug(
"Exception when counting tokens precisely for prompt: {}".format(str(e))
)
logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e)))
return session

def session_reply(self, reply, session_id, total_tokens=None):
@@ -82,17 +78,9 @@ class SessionManager(object):
try:
max_tokens = conf().get("conversation_max_tokens", 1000)
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
logger.debug(
"raw total_tokens={}, savesession tokens={}".format(
total_tokens, tokens_cnt
)
)
logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
except Exception as e:
logger.debug(
"Exception when counting tokens precisely for session: {}".format(
str(e)
)
)
logger.debug("Exception when counting tokens precisely for session: {}".format(str(e)))
return session

def clear_session(self, session_id):


+ 1
- 3
bridge/context.py View File

@@ -60,6 +60,4 @@ class Context:
del self.kwargs[key]

def __str__(self):
return "Context(type={}, content={}, kwargs={})".format(
self.type, self.content, self.kwargs
)
return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)

+ 28
- 103
channel/chat_channel.py View File

@@ -53,9 +53,7 @@ class ChatChannel(Channel):
group_id = cmsg.other_user_id

group_name_white_list = config.get("group_name_white_list", [])
group_name_keyword_white_list = config.get(
"group_name_keyword_white_list", []
)
group_name_keyword_white_list = config.get("group_name_keyword_white_list", [])
if any(
[
group_name in group_name_white_list,
@@ -63,9 +61,7 @@ class ChatChannel(Channel):
check_contain(group_name, group_name_keyword_white_list),
]
):
group_chat_in_one_session = conf().get(
"group_chat_in_one_session", []
)
group_chat_in_one_session = conf().get("group_chat_in_one_session", [])
session_id = cmsg.actual_user_id
if any(
[
@@ -81,17 +77,11 @@ class ChatChannel(Channel):
else:
context["session_id"] = cmsg.other_user_id
context["receiver"] = cmsg.other_user_id
e_context = PluginManager().emit_event(
EventContext(
Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context}
)
)
e_context = PluginManager().emit_event(EventContext(Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context}))
context = e_context["context"]
if e_context.is_pass() or context is None:
return context
if cmsg.from_user_id == self.user_id and not config.get(
"trigger_by_self", True
):
if cmsg.from_user_id == self.user_id and not config.get("trigger_by_self", True):
logger.debug("[WX]self message skipped")
return None

@@ -119,19 +109,13 @@ class ChatChannel(Channel):

if not flag:
if context["origin_ctype"] == ContextType.VOICE:
logger.info(
"[WX]receive group voice, but checkprefix didn't match"
)
logger.info("[WX]receive group voice, but checkprefix didn't match")
return None
else: # 单聊
match_prefix = check_prefix(
content, conf().get("single_chat_prefix", [""])
)
match_prefix = check_prefix(content, conf().get("single_chat_prefix", [""]))
if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
content = content.replace(match_prefix, "", 1).strip()
elif (
context["origin_ctype"] == ContextType.VOICE
): # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
pass
else:
return None
@@ -143,18 +127,10 @@ class ChatChannel(Channel):
else:
context.type = ContextType.TEXT
context.content = content.strip()
if (
"desire_rtype" not in context
and conf().get("always_reply_voice")
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
):
if "desire_rtype" not in context and conf().get("always_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
context["desire_rtype"] = ReplyType.VOICE
elif context.type == ContextType.VOICE:
if (
"desire_rtype" not in context
and conf().get("voice_reply_voice")
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
):
if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
context["desire_rtype"] = ReplyType.VOICE

return context
@@ -182,15 +158,8 @@ class ChatChannel(Channel):
)
reply = e_context["reply"]
if not e_context.is_pass():
logger.debug(
"[WX] ready to handle context: type={}, content={}".format(
context.type, context.content
)
)
if (
context.type == ContextType.TEXT
or context.type == ContextType.IMAGE_CREATE
): # 文字和图片消息
logger.debug("[WX] ready to handle context: type={}, content={}".format(context.type, context.content))
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
reply = super().build_reply_content(context.content, context)
elif context.type == ContextType.VOICE: # 语音消息
cmsg = context["msg"]
@@ -214,9 +183,7 @@ class ChatChannel(Channel):
# logger.warning("[WX]delete temp file error: " + str(e))

if reply.type == ReplyType.TEXT:
new_context = self._compose_context(
ContextType.TEXT, reply.content, **context.kwargs
)
new_context = self._compose_context(ContextType.TEXT, reply.content, **context.kwargs)
if new_context:
reply = self._generate_reply(new_context)
else:
@@ -246,48 +213,24 @@ class ChatChannel(Channel):

if reply.type == ReplyType.TEXT:
reply_text = reply.content
if (
desire_rtype == ReplyType.VOICE
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
):
if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
reply = super().build_text_to_voice(reply.content)
return self._decorate_reply(context, reply)
if context.get("isgroup", False):
reply_text = (
"@"
+ context["msg"].actual_user_nickname
+ " "
+ reply_text.strip()
)
reply_text = (
conf().get("group_chat_reply_prefix", "") + reply_text
)
reply_text = "@" + context["msg"].actual_user_nickname + " " + reply_text.strip()
reply_text = conf().get("group_chat_reply_prefix", "") + reply_text
else:
reply_text = (
conf().get("single_chat_reply_prefix", "") + reply_text
)
reply_text = conf().get("single_chat_reply_prefix", "") + reply_text
reply.content = reply_text
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
reply.content = "[" + str(reply.type) + "]\n" + reply.content
elif (
reply.type == ReplyType.IMAGE_URL
or reply.type == ReplyType.VOICE
or reply.type == ReplyType.IMAGE
):
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
pass
else:
logger.error("[WX] unknown reply type: {}".format(reply.type))
return
if (
desire_rtype
and desire_rtype != reply.type
and reply.type not in [ReplyType.ERROR, ReplyType.INFO]
):
logger.warning(
"[WX] desire_rtype: {}, but reply type: {}".format(
context.get("desire_rtype"), reply.type
)
)
if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
logger.warning("[WX] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type))
return reply

def _send_reply(self, context: Context, reply: Reply):
@@ -300,9 +243,7 @@ class ChatChannel(Channel):
)
reply = e_context["reply"]
if not e_context.is_pass() and reply and reply.type:
logger.debug(
"[WX] ready to send reply: {}, context: {}".format(reply, context)
)
logger.debug("[WX] ready to send reply: {}, context: {}".format(reply, context))
self._send(reply, context)

def _send(self, reply: Reply, context: Context, retry_cnt=0):
@@ -328,9 +269,7 @@ class ChatChannel(Channel):
try:
worker_exception = worker.exception()
if worker_exception:
self._fail_callback(
session_id, exception=worker_exception, **kwargs
)
self._fail_callback(session_id, exception=worker_exception, **kwargs)
else:
self._success_callback(session_id, **kwargs)
except CancelledError as e:
@@ -366,24 +305,14 @@ class ChatChannel(Channel):
if not context_queue.empty():
context = context_queue.get()
logger.debug("[WX] consume context: {}".format(context))
future: Future = self.handler_pool.submit(
self._handle, context
)
future.add_done_callback(
self._thread_pool_callback(session_id, context=context)
)
future: Future = self.handler_pool.submit(self._handle, context)
future.add_done_callback(self._thread_pool_callback(session_id, context=context))
if session_id not in self.futures:
self.futures[session_id] = []
self.futures[session_id].append(future)
elif (
semaphore._initial_value == semaphore._value + 1
): # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
self.futures[session_id] = [
t for t in self.futures[session_id] if not t.done()
]
assert (
len(self.futures[session_id]) == 0
), "thread pool error"
elif semaphore._initial_value == semaphore._value + 1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
assert len(self.futures[session_id]) == 0, "thread pool error"
del self.sessions[session_id]
else:
semaphore.release()
@@ -397,9 +326,7 @@ class ChatChannel(Channel):
future.cancel()
cnt = self.sessions[session_id][0].qsize()
if cnt > 0:
logger.info(
"Cancel {} messages in session {}".format(cnt, session_id)
)
logger.info("Cancel {} messages in session {}".format(cnt, session_id))
self.sessions[session_id][0] = Dequeue()

def cancel_all_session(self):
@@ -409,9 +336,7 @@ class ChatChannel(Channel):
future.cancel()
cnt = self.sessions[session_id][0].qsize()
if cnt > 0:
logger.info(
"Cancel {} messages in session {}".format(cnt, session_id)
)
logger.info("Cancel {} messages in session {}".format(cnt, session_id))
self.sessions[session_id][0] = Dequeue()




+ 1
- 3
channel/terminal/terminal_channel.py View File

@@ -77,9 +77,7 @@ class TerminalChannel(ChatChannel):
if check_prefix(prompt, trigger_prefixs) is None:
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀

context = self._compose_context(
ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt)
)
context = self._compose_context(ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt))
if context:
self.produce(context)
else:


+ 8
- 31
channel/wechat/wechat_channel.py View File

@@ -56,10 +56,7 @@ def _check(func):
return
self.receivedMsgs[msgId] = cmsg
create_time = cmsg.create_time # 消息时间戳
if (
conf().get("hot_reload") == True
and int(create_time) < int(time.time()) - 60
): # 跳过1分钟前的历史消息
if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
logger.debug("[WX]history message {} skipped".format(msgId))
return
return func(self, cmsg)
@@ -88,15 +85,9 @@ def qrCallback(uuid, status, qrcode):
url = f"https://login.weixin.qq.com/l/{uuid}"

qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url)
qr_api2 = (
"https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(
url
)
)
qr_api2 = "https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url)
qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url)
qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(
url
)
qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url)
print("You can also scan QRCode in any website below:")
print(qr_api3)
print(qr_api4)
@@ -134,18 +125,12 @@ class WechatChannel(ChatChannel):
logger.error("Hot reload failed, try to login without hot reload")
itchat.logout()
os.remove(status_path)
itchat.auto_login(
enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback
)
itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback)
else:
raise e
self.user_id = itchat.instance.storageClass.userName
self.name = itchat.instance.storageClass.nickName
logger.info(
"Wechat login success, user_id: {}, nickname: {}".format(
self.user_id, self.name
)
)
logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
# start message listener
itchat.run()

@@ -173,16 +158,10 @@ class WechatChannel(ChatChannel):
elif cmsg.ctype == ContextType.PATPAT:
logger.debug("[WX]receive patpat msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.TEXT:
logger.debug(
"[WX]receive text msg: {}, cmsg={}".format(
json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg
)
)
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
else:
logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg))
context = self._compose_context(
cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg
)
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
if context:
self.produce(context)

@@ -202,9 +181,7 @@ class WechatChannel(ChatChannel):
pass
else:
logger.debug("[WX]receive group msg: {}".format(cmsg.content))
context = self._compose_context(
cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg
)
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
if context:
self.produce(context)



+ 6
- 20
channel/wechat/wechat_message.py View File

@@ -27,37 +27,23 @@ class WeChatMessage(ChatMessage):
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
self._prepare_fn = lambda: itchat_msg.download(self.content)
elif itchat_msg["Type"] == NOTE and itchat_msg["MsgType"] == 10000:
if is_group and (
"加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"]
):
if is_group and ("加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"]):
self.ctype = ContextType.JOIN_GROUP
self.content = itchat_msg["Content"]
# 这里只能得到nickname, actual_user_id还是机器人的id
if "加入了群聊" in itchat_msg["Content"]:
self.actual_user_nickname = re.findall(
r"\"(.*?)\"", itchat_msg["Content"]
)[-1]
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[-1]
elif "加入群聊" in itchat_msg["Content"]:
self.actual_user_nickname = re.findall(
r"\"(.*?)\"", itchat_msg["Content"]
)[0]
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
elif "拍了拍我" in itchat_msg["Content"]:
self.ctype = ContextType.PATPAT
self.content = itchat_msg["Content"]
if is_group:
self.actual_user_nickname = re.findall(
r"\"(.*?)\"", itchat_msg["Content"]
)[0]
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
else:
raise NotImplementedError(
"Unsupported note message: " + itchat_msg["Content"]
)
raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"])
else:
raise NotImplementedError(
"Unsupported message type: Type:{} MsgType:{}".format(
itchat_msg["Type"], itchat_msg["MsgType"]
)
)
raise NotImplementedError("Unsupported message type: Type:{} MsgType:{}".format(itchat_msg["Type"], itchat_msg["MsgType"]))

self.from_user_id = itchat_msg["FromUserName"]
self.to_user_id = itchat_msg["ToUserName"]


+ 5
- 15
channel/wechat/wechaty_channel.py View File

@@ -60,13 +60,9 @@ class WechatyChannel(ChatChannel):
receiver_id = context["receiver"]
loop = asyncio.get_event_loop()
if context["isgroup"]:
receiver = asyncio.run_coroutine_threadsafe(
self.bot.Room.find(receiver_id), loop
).result()
receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id), loop).result()
else:
receiver = asyncio.run_coroutine_threadsafe(
self.bot.Contact.find(receiver_id), loop
).result()
receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id), loop).result()
msg = None
if reply.type == ReplyType.TEXT:
msg = reply.content
@@ -83,9 +79,7 @@ class WechatyChannel(ChatChannel):
voiceLength = int(any_to_sil(file_path, sil_file))
if voiceLength >= 60000:
voiceLength = 60000
logger.info(
"[WX] voice too long, length={}, set to 60s".format(voiceLength)
)
logger.info("[WX] voice too long, length={}, set to 60s".format(voiceLength))
# 发送语音
t = int(time.time())
msg = FileBox.from_file(sil_file, name=str(t) + ".sil")
@@ -98,9 +92,7 @@ class WechatyChannel(ChatChannel):
os.remove(sil_file)
except Exception as e:
pass
logger.info(
"[WX] sendVoice={}, receiver={}".format(reply.content, receiver)
)
logger.info("[WX] sendVoice={}, receiver={}".format(reply.content, receiver))
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
img_url = reply.content
t = int(time.time())
@@ -111,9 +103,7 @@ class WechatyChannel(ChatChannel):
image_storage = reply.content
image_storage.seek(0)
t = int(time.time())
msg = FileBox.from_base64(
base64.b64encode(image_storage.read()), str(t) + ".png"
)
msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + ".png")
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
logger.info("[WX] sendImage, receiver={}".format(receiver))



+ 3
- 9
channel/wechat/wechaty_message.py View File

@@ -45,16 +45,12 @@ class WechatyMessage(ChatMessage, aobject):

def func():
loop = asyncio.get_event_loop()
asyncio.run_coroutine_threadsafe(
voice_file.to_file(self.content), loop
).result()
asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content), loop).result()

self._prepare_fn = func

else:
raise NotImplementedError(
"Unsupported message type: {}".format(wechaty_msg.type())
)
raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type()))

from_contact = wechaty_msg.talker() # 获取消息的发送者
self.from_user_id = from_contact.contact_id
@@ -73,9 +69,7 @@ class WechatyMessage(ChatMessage, aobject):
self.to_user_id = to_contact.contact_id
self.to_user_nickname = to_contact.name

if (
self.is_group or wechaty_msg.is_self()
): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。
if self.is_group or wechaty_msg.is_self(): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。
self.other_user_id = self.to_user_id
self.other_user_nickname = self.to_user_nickname
else:


+ 8
- 17
channel/wechatmp/active_reply.py View File

@@ -1,16 +1,17 @@
import time

import web
from wechatpy import parse_message
from wechatpy.replies import create_reply

from channel.wechatmp.wechatmp_message import WeChatMPMessage
from bridge.context import *
from bridge.reply import *
from channel.wechatmp.common import *
from channel.wechatmp.wechatmp_channel import WechatMPChannel
from wechatpy import parse_message
from channel.wechatmp.wechatmp_message import WeChatMPMessage
from common.log import logger
from config import conf
from wechatpy.replies import create_reply

# This class is instantiated once per query
class Query:
@@ -50,29 +51,19 @@ class Query:
)
)
if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False):
context = channel._compose_context(
wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg
)
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg)
else:
context = channel._compose_context(
wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg
)
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg)
if context:
# set private openai_api_key
# if from_user is not changed in itchat, this can be placed at chat_channel
user_data = conf().get_user_data(from_user)
context["openai_api_key"] = user_data.get(
"openai_api_key"
) # None or user openai_api_key
context["openai_api_key"] = user_data.get("openai_api_key") # None or user openai_api_key
channel.produce(context)
# The reply will be sent by channel.send() in another thread
return "success"
elif msg.type == "event":
logger.info(
"[wechatmp] Event {} from {}".format(
msg.event, msg.source
)
)
logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source))
if msg.event in ["subscribe", "subscribe_scan"]:
reply_text = subscribe_msg()
replyPost = create_reply(reply_text, msg)


+ 5
- 3
channel/wechatmp/common.py View File

@@ -1,10 +1,12 @@
import textwrap
import web

from config import conf
from wechatpy.utils import check_signature
import web
from wechatpy.crypto import WeChatCrypto
from wechatpy.exceptions import InvalidSignatureException
from wechatpy.utils import check_signature

from config import conf

MAX_UTF8_LEN = 2048




+ 18
- 36
channel/wechatmp/passive_reply.py View File

@@ -1,17 +1,18 @@
import time
import asyncio
import time

import web
from wechatpy import parse_message
from wechatpy.replies import ImageReply, VoiceReply, create_reply

from channel.wechatmp.wechatmp_message import WeChatMPMessage
from bridge.context import *
from bridge.reply import *
from channel.wechatmp.common import *
from channel.wechatmp.wechatmp_channel import WechatMPChannel
from channel.wechatmp.wechatmp_message import WeChatMPMessage
from common.log import logger
from config import conf
from wechatpy import parse_message
from wechatpy.replies import create_reply, ImageReply, VoiceReply


# This class is instantiated once per query
class Query:
@@ -49,21 +50,15 @@ class Query:
if (
from_user not in channel.cache_dict
and from_user not in channel.running
or content.startswith("#")
and message_id not in channel.request_cnt # insert the godcmd
or content.startswith("#")
and message_id not in channel.request_cnt # insert the godcmd
):
# The first query begin
if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False):
context = channel._compose_context(
wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg
)
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg)
else:
context = channel._compose_context(
wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg
)
logger.debug(
"[wechatmp] context: {} {} {}".format(context, wechatmp_msg, supported)
)
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg)
logger.debug("[wechatmp] context: {} {} {}".format(context, wechatmp_msg, supported))

if supported and context:
# set private openai_api_key
@@ -94,23 +89,17 @@ class Query:
"""\
未知错误,请稍后再试"""
)
replyPost = create_reply(reply_text, msg)
return encrypt_func(replyPost.render())


# Wechat official server will request 3 times (5 seconds each), with the same message_id.
# Because the interval is 5 seconds, here assumed that do not have multithreading problems.
request_cnt = channel.request_cnt.get(message_id, 0) + 1
channel.request_cnt[message_id] = request_cnt
logger.info(
"[wechatmp] Request {} from {} {} {}:{}\n{}".format(
request_cnt,
from_user,
message_id,
web.ctx.env.get("REMOTE_ADDR"),
web.ctx.env.get("REMOTE_PORT"),
content
request_cnt, from_user, message_id, web.ctx.env.get("REMOTE_ADDR"), web.ctx.env.get("REMOTE_PORT"), content
)
)

@@ -130,7 +119,7 @@ class Query:
time.sleep(2)
# and do nothing, waiting for the next request
return "success"
else: # request_cnt == 3:
else: # request_cnt == 3:
# return timeout message
reply_text = "【正在思考中,回复任意文字尝试获取回复】"
replyPost = create_reply(reply_text, msg)
@@ -140,10 +129,7 @@ class Query:
channel.request_cnt.pop(message_id)

# no return because of bandwords or other reasons
if (
from_user not in channel.cache_dict
and from_user not in channel.running
):
if from_user not in channel.cache_dict and from_user not in channel.running:
return "success"

# Only one request can access to the cached data
@@ -152,7 +138,7 @@ class Query:
except KeyError:
return "success"

if (reply_type == "text"):
if reply_type == "text":
if len(reply_content.encode("utf8")) <= MAX_UTF8_LEN:
reply_text = reply_content
else:
@@ -177,7 +163,7 @@ class Query:
replyPost = create_reply(reply_text, msg)
return encrypt_func(replyPost.render())

elif (reply_type == "voice"):
elif reply_type == "voice":
media_id = reply_content
asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop)
logger.info(
@@ -193,7 +179,7 @@ class Query:
replyPost.media_id = media_id
return encrypt_func(replyPost.render())

elif (reply_type == "image"):
elif reply_type == "image":
media_id = reply_content
asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop)
logger.info(
@@ -210,11 +196,7 @@ class Query:
return encrypt_func(replyPost.render())

elif msg.type == "event":
logger.info(
"[wechatmp] Event {} from {}".format(
msg.event, msg.source
)
)
logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source))
if msg.event in ["subscribe", "subscribe_scan"]:
reply_text = subscribe_msg()
replyPost = create_reply(reply_text, msg)


+ 20
- 27
channel/wechatmp/wechatmp_channel.py View File

@@ -1,24 +1,26 @@
# -*- coding: utf-8 -*-
import asyncio
import imghdr
import io
import os
import threading
import time
import imghdr
import requests
import asyncio
import threading
from config import conf
import web
from wechatpy.crypto import WeChatCrypto
from wechatpy.exceptions import WeChatClientException

from bridge.context import *
from bridge.reply import *
from common.log import logger
from common.singleton import singleton
from voice.audio_convert import any_to_mp3
from channel.chat_channel import ChatChannel
from channel.wechatmp.common import *
from channel.wechatmp.wechatmp_client import WechatMPClient
from wechatpy.exceptions import WeChatClientException
from wechatpy.crypto import WeChatCrypto
from common.log import logger
from common.singleton import singleton
from config import conf
from voice.audio_convert import any_to_mp3

import web
# If using SSL, uncomment the following lines, and modify the certificate path.
# from cheroot.server import HTTPServer
# from cheroot.ssl.builtin import BuiltinSSLAdapter
@@ -54,7 +56,6 @@ class WechatMPChannel(ChatChannel):
t.setDaemon(True)
t.start()


def startup(self):
if self.passive_reply:
urls = ("/wx", "channel.wechatmp.passive_reply.Query")
@@ -84,7 +85,7 @@ class WechatMPChannel(ChatChannel):
elif reply.type == ReplyType.VOICE:
try:
voice_file_path = reply.content
with open(voice_file_path, 'rb') as f:
with open(voice_file_path, "rb") as f:
# support: <2M, <60s, mp3/wma/wav/amr
response = self.client.material.add("voice", f)
logger.debug("[wechatmp] upload voice response: {}".format(response))
@@ -107,7 +108,7 @@ class WechatMPChannel(ChatChannel):
image_storage.write(block)
image_storage.seek(0)
image_type = imghdr.what(image_storage)
filename = receiver + "-" + str(context['msg'].msg_id) + "." + image_type
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
content_type = "image/" + image_type
try:
response = self.client.material.add("image", (filename, image_storage, content_type))
@@ -122,7 +123,7 @@ class WechatMPChannel(ChatChannel):
image_storage = reply.content
image_storage.seek(0)
image_type = imghdr.what(image_storage)
filename = receiver + "-" + str(context['msg'].msg_id) + "." + image_type
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
content_type = "image/" + image_type
try:
response = self.client.material.add("image", (filename, image_storage, content_type))
@@ -137,7 +138,7 @@ class WechatMPChannel(ChatChannel):
if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
reply_text = reply.content
texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN)
if len(texts)>1:
if len(texts) > 1:
logger.info("[wechatmp] text too long, split into {} parts".format(len(texts)))
for text in texts:
self.client.message.send_text(receiver, text)
@@ -174,7 +175,7 @@ class WechatMPChannel(ChatChannel):
image_storage.write(block)
image_storage.seek(0)
image_type = imghdr.what(image_storage)
filename = receiver + "-" + str(context['msg'].msg_id) + "." + image_type
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
content_type = "image/" + image_type
try:
response = self.client.media.upload("image", (filename, image_storage, content_type))
@@ -188,7 +189,7 @@ class WechatMPChannel(ChatChannel):
image_storage = reply.content
image_storage.seek(0)
image_type = imghdr.what(image_storage)
filename = receiver + "-" + str(context['msg'].msg_id) + "." + image_type
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
content_type = "image/" + image_type
try:
response = self.client.media.upload("image", (filename, image_storage, content_type))
@@ -201,20 +202,12 @@ class WechatMPChannel(ChatChannel):
return

def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数
logger.debug(
"[wechatmp] Success to generate reply, msgId={}".format(
context["msg"].msg_id
)
)
logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context["msg"].msg_id))
if self.passive_reply:
self.running.remove(session_id)

def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数
logger.exception(
"[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(
context["msg"].msg_id, exception
)
)
logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context["msg"].msg_id, exception))
if self.passive_reply:
assert session_id not in self.cache_dict
self.running.remove(session_id)

+ 10
- 11
channel/wechatmp/wechatmp_client.py View File

@@ -1,17 +1,16 @@
import time
import threading
from channel.wechatmp.common import *
import time

from wechatpy.client import WeChatClient
from common.log import logger
from wechatpy.exceptions import APILimitedException

from channel.wechatmp.common import *
from common.log import logger


class WechatMPClient(WeChatClient):
def __init__(self, appid, secret, access_token=None,
session=None, timeout=None, auto_retry=True):
super(WechatMPClient, self).__init__(
appid, secret, access_token, session, timeout, auto_retry
)
def __init__(self, appid, secret, access_token=None, session=None, timeout=None, auto_retry=True):
super(WechatMPClient, self).__init__(appid, secret, access_token, session, timeout, auto_retry)
self.fetch_access_token_lock = threading.Lock()

def clear_quota(self):
@@ -20,7 +19,7 @@ class WechatMPClient(WeChatClient):
def clear_quota_v2(self):
return self.post("clear_quota/v2", params={"appid": self.appid, "appsecret": self.secret})

def fetch_access_token(self): # 重载父类方法,加锁避免多线程重复获取access_token
def fetch_access_token(self): # 重载父类方法,加锁避免多线程重复获取access_token
with self.fetch_access_token_lock:
access_token = self.session.get(self.access_token_key)
if access_token:
@@ -31,11 +30,11 @@ class WechatMPClient(WeChatClient):
return access_token
return super().fetch_access_token()

def _request(self, method, url_or_endpoint, **kwargs): # 重载父类方法,遇到API限流时,清除quota后重试
def _request(self, method, url_or_endpoint, **kwargs): # 重载父类方法,遇到API限流时,清除quota后重试
try:
return super()._request(method, url_or_endpoint, **kwargs)
except APILimitedException as e:
logger.error("[wechatmp] API quata has been used up. {}".format(e))
response = self.clear_quota_v2()
logger.debug("[wechatmp] API quata has been cleard, {}".format(response))
return super()._request(method, url_or_endpoint, **kwargs)
return super()._request(method, url_or_endpoint, **kwargs)

+ 5
- 14
channel/wechatmp/wechatmp_message.py View File

@@ -6,7 +6,6 @@ from common.log import logger
from common.tmp_dir import TmpDir



class WeChatMPMessage(ChatMessage):
def __init__(self, msg, client=None):
super().__init__(msg)
@@ -18,12 +17,9 @@ class WeChatMPMessage(ChatMessage):
self.ctype = ContextType.TEXT
self.content = msg.content
elif msg.type == "voice":
if msg.recognition == None:
self.ctype = ContextType.VOICE
self.content = (
TmpDir().path() + msg.media_id + "." + msg.format
) # content直接存临时目录路径
self.content = TmpDir().path() + msg.media_id + "." + msg.format # content直接存临时目录路径

def download_voice():
# 如果响应状态码是200,则将响应内容写入本地文件
@@ -32,9 +28,7 @@ class WeChatMPMessage(ChatMessage):
with open(self.content, "wb") as f:
f.write(response.content)
else:
logger.info(
f"[wechatmp] Failed to download voice file, {response.content}"
)
logger.info(f"[wechatmp] Failed to download voice file, {response.content}")

self._prepare_fn = download_voice
else:
@@ -43,6 +37,7 @@ class WeChatMPMessage(ChatMessage):
elif msg.type == "image":
self.ctype = ContextType.IMAGE
self.content = TmpDir().path() + msg.media_id + ".png" # content直接存临时目录路径

def download_image():
# 如果响应状态码是200,则将响应内容写入本地文件
response = client.media.download(msg.media_id)
@@ -50,15 +45,11 @@ class WeChatMPMessage(ChatMessage):
with open(self.content, "wb") as f:
f.write(response.content)
else:
logger.info(
f"[wechatmp] Failed to download image file, {response.content}"
)
logger.info(f"[wechatmp] Failed to download image file, {response.content}")

self._prepare_fn = download_image
else:
raise NotImplementedError(
"Unsupported message type: Type:{} ".format(msg.type)
)
raise NotImplementedError("Unsupported message type: Type:{} ".format(msg.type))

self.from_user_id = msg.source
self.to_user_id = msg.target


+ 3
- 11
common/time_check.py View File

@@ -13,23 +13,15 @@ def time_checker(f):
if chat_time_module:
chat_start_time = _config.get("chat_start_time", "00:00")
chat_stopt_time = _config.get("chat_stop_time", "24:00")
time_regex = re.compile(
r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$"
) # 时间匹配,包含24:00
time_regex = re.compile(r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$") # 时间匹配,包含24:00

starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式
stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式
chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间

# 时间格式检查
if not (
starttime_format_check and stoptime_format_check and chat_time_check
):
logger.warn(
"时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format(
starttime_format_check, stoptime_format_check
)
)
if not (starttime_format_check and stoptime_format_check and chat_time_check):
logger.warn("时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format(starttime_format_check, stoptime_format_check))
if chat_start_time > "23:59":
logger.error("启动时间可能存在问题,请修改!")



+ 1
- 3
config.py View File

@@ -158,9 +158,7 @@ def load_config():
for name, value in os.environ.items():
name = name.lower()
if name in available_setting:
logger.info(
"[INIT] override config by environ args: {}={}".format(name, value)
)
logger.info("[INIT] override config by environ args: {}={}".format(name, value))
try:
config[name] = eval(value)
except:


+ 3
- 9
plugins/banwords/banwords.py View File

@@ -50,9 +50,7 @@ class Banwords(Plugin):
self.reply_action = conf.get("reply_action", "ignore")
logger.info("[Banwords] inited")
except Exception as e:
logger.warn(
"[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords ."
)
logger.warn("[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords .")
raise e

def on_handle_context(self, e_context: EventContext):
@@ -72,9 +70,7 @@ class Banwords(Plugin):
return
elif self.action == "replace":
if self.searchr.ContainsAny(content):
reply = Reply(
ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content)
)
reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content))
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
return
@@ -94,9 +90,7 @@ class Banwords(Plugin):
return
elif self.reply_action == "replace":
if self.searchr.ContainsAny(content):
reply = Reply(
ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content)
)
reply = Reply(ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content))
e_context["reply"] = reply
e_context.action = EventAction.CONTINUE
return


+ 7
- 32
plugins/bdunit/bdunit.py View File

@@ -76,9 +76,7 @@ class BDunit(Plugin):
Returns:
string: access_token
"""
url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(
self.api_key, self.secret_key
)
url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(self.api_key, self.secret_key)
payload = ""
headers = {"Content-Type": "application/json", "Accept": "application/json"}

@@ -94,10 +92,7 @@ class BDunit(Plugin):
:returns: UNIT 解析结果。如果解析失败,返回 None
"""

url = (
"https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token="
+ self.access_token
)
url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + self.access_token
request = {
"query": query,
"user_id": str(get_mac())[:32],
@@ -124,10 +119,7 @@ class BDunit(Plugin):
:param query: 用户的指令字符串
:returns: UNIT 解析结果。如果解析失败,返回 None
"""
url = (
"https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token="
+ self.access_token
)
url = "https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token=" + self.access_token
request = {"query": query, "user_id": str(get_mac())[:32]}
body = {
"log_id": str(uuid.uuid1()),
@@ -170,11 +162,7 @@ class BDunit(Plugin):
if parsed and "result" in parsed and "response_list" in parsed["result"]:
response_list = parsed["result"]["response_list"]
for response in response_list:
if (
"schema" in response
and "intent" in response["schema"]
and response["schema"]["intent"] == intent
):
if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent:
return True
return False
else:
@@ -198,12 +186,7 @@ class BDunit(Plugin):
logger.warning(e)
return []
for response in response_list:
if (
"schema" in response
and "intent" in response["schema"]
and "slots" in response["schema"]
and response["schema"]["intent"] == intent
):
if "schema" in response and "intent" in response["schema"] and "slots" in response["schema"] and response["schema"]["intent"] == intent:
return response["schema"]["slots"]
return []
else:
@@ -239,11 +222,7 @@ class BDunit(Plugin):
if (
"schema" in response
and "intent_confidence" in response["schema"]
and (
not answer
or response["schema"]["intent_confidence"]
> answer["schema"]["intent_confidence"]
)
and (not answer or response["schema"]["intent_confidence"] > answer["schema"]["intent_confidence"])
):
answer = response
return answer["action_list"][0]["say"]
@@ -267,11 +246,7 @@ class BDunit(Plugin):
logger.warning(e)
return ""
for response in response_list:
if (
"schema" in response
and "intent" in response["schema"]
and response["schema"]["intent"] == intent
):
if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent:
try:
return response["action_list"][0]["say"]
except Exception as e:


+ 2
- 8
plugins/dungeon/dungeon.py View File

@@ -84,9 +84,7 @@ class Dungeon(Plugin):
if len(clist) > 1:
story = clist[1]
else:
story = (
"你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。"
)
story = "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。"
self.games[sessionid] = StoryTeller(bot, sessionid, story)
reply = Reply(ReplyType.INFO, "冒险开始,你可以输入任意内容,让故事继续下去。故事背景是:" + story)
e_context["reply"] = reply
@@ -102,11 +100,7 @@ class Dungeon(Plugin):
if kwargs.get("verbose") != True:
return help_text
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
help_text = (
f"{trigger_prefix}开始冒险 "
+ "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n"
+ f"{trigger_prefix}停止冒险: 结束游戏。\n"
)
help_text = f"{trigger_prefix}开始冒险 " + "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n" + f"{trigger_prefix}停止冒险: 结束游戏。\n"
if kwargs.get("verbose") == True:
help_text += f"\n命令例子: '{trigger_prefix}开始冒险 你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。'"
return help_text

+ 7
- 23
plugins/godcmd/godcmd.py View File

@@ -140,9 +140,7 @@ def get_help_text(isadmin, isgroup):
if plugins[plugin].enabled and not plugins[plugin].hidden:
namecn = plugins[plugin].namecn
help_text += "\n%s:" % namecn
help_text += (
PluginManager().instances[plugin].get_help_text(verbose=False).strip()
)
help_text += PluginManager().instances[plugin].get_help_text(verbose=False).strip()

if ADMIN_COMMANDS and isadmin:
help_text += "\n\n管理员指令:\n"
@@ -191,9 +189,7 @@ class Godcmd(Plugin):
COMMANDS["reset"]["alias"].append(custom_command)

self.password = gconf["password"]
self.admin_users = gconf[
"admin_users"
] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用
self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用
self.isrunning = True # 机器人是否运行中

self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
@@ -215,7 +211,7 @@ class Godcmd(Plugin):
reply.content = f"空指令,输入#help查看指令列表\n"
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
return
return
# msg = e_context['context']['msg']
channel = e_context["channel"]
user = e_context["context"]["receiver"]
@@ -248,11 +244,7 @@ class Godcmd(Plugin):
if not plugincls.enabled:
continue
if query_name == name or query_name == plugincls.namecn:
ok, result = True, PluginManager().instances[
name
].get_help_text(
isgroup=isgroup, isadmin=isadmin, verbose=True
)
ok, result = True, PluginManager().instances[name].get_help_text(isgroup=isgroup, isadmin=isadmin, verbose=True)
break
if not ok:
result = "插件不存在或未启用"
@@ -285,11 +277,7 @@ class Godcmd(Plugin):
if isgroup:
ok, result = False, "群聊不可执行管理员指令"
else:
cmd = next(
c
for c, info in ADMIN_COMMANDS.items()
if cmd in info["alias"]
)
cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info["alias"])
if cmd == "stop":
self.isrunning = False
ok, result = True, "服务已暂停"
@@ -325,18 +313,14 @@ class Godcmd(Plugin):
PluginManager().activate_plugins()
if len(new_plugins) > 0:
result += "\n发现新插件:\n"
result += "\n".join(
[f"{p.name}_v{p.version}" for p in new_plugins]
)
result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins])
else:
result += ", 未发现新插件"
elif cmd == "setpri":
if len(args) != 2:
ok, result = False, "请提供插件名和优先级"
else:
ok = PluginManager().set_plugin_priority(
args[0], int(args[1])
)
ok = PluginManager().set_plugin_priority(args[0], int(args[1]))
if ok:
result = "插件" + args[0] + "优先级已设置为" + args[1]
else:


+ 2
- 6
plugins/hello/hello.py View File

@@ -33,9 +33,7 @@ class Hello(Plugin):
if e_context["context"].type == ContextType.JOIN_GROUP:
e_context["context"].type = ContextType.TEXT
msg: ChatMessage = e_context["context"]["msg"]
e_context[
"context"
].content = f'请你随机使用一种风格说一句问候语来欢迎新用户"{msg.actual_user_nickname}"加入群聊。'
e_context["context"].content = f'请你随机使用一种风格说一句问候语来欢迎新用户"{msg.actual_user_nickname}"加入群聊。'
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
return

@@ -53,9 +51,7 @@ class Hello(Plugin):
reply.type = ReplyType.TEXT
msg: ChatMessage = e_context["context"]["msg"]
if e_context["context"]["isgroup"]:
reply.content = (
f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
)
reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
else:
reply.content = f"Hello, {msg.from_user_nickname}"
e_context["reply"] = reply


+ 2
- 2
plugins/keyword/config.json.template View File

@@ -1,5 +1,5 @@
{
"keyword": {
"关键字匹配": "测试成功"
"关键字匹配": "测试成功"
}
}
}

+ 1
- 3
plugins/keyword/keyword.py View File

@@ -41,9 +41,7 @@ class Keyword(Plugin):
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
logger.info("[keyword] inited.")
except Exception as e:
logger.warn(
"[keyword] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/keyword ."
)
logger.warn("[keyword] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/keyword .")
raise e

def on_handle_context(self, e_context: EventContext):


+ 14
- 46
plugins/plugin_manager.py View File

@@ -31,23 +31,14 @@ class PluginManager:
plugincls.desc = kwargs.get("desc")
plugincls.author = kwargs.get("author")
plugincls.path = self.current_plugin_path
plugincls.version = (
kwargs.get("version") if kwargs.get("version") != None else "1.0"
)
plugincls.namecn = (
kwargs.get("namecn") if kwargs.get("namecn") != None else name
)
plugincls.hidden = (
kwargs.get("hidden") if kwargs.get("hidden") != None else False
)
plugincls.version = kwargs.get("version") if kwargs.get("version") != None else "1.0"
plugincls.namecn = kwargs.get("namecn") if kwargs.get("namecn") != None else name
plugincls.hidden = kwargs.get("hidden") if kwargs.get("hidden") != None else False
plugincls.enabled = True
if self.current_plugin_path == None:
raise Exception("Plugin path not set")
self.plugins[name.upper()] = plugincls
logger.info(
"Plugin %s_v%s registered, path=%s"
% (name, plugincls.version, plugincls.path)
)
logger.info("Plugin %s_v%s registered, path=%s" % (name, plugincls.version, plugincls.path))

return wrapper

@@ -62,9 +53,7 @@ class PluginManager:
if os.path.exists("./plugins/plugins.json"):
with open("./plugins/plugins.json", "r", encoding="utf-8") as f:
pconf = json.load(f)
pconf["plugins"] = SortedDict(
lambda k, v: v["priority"], pconf["plugins"], reverse=True
)
pconf["plugins"] = SortedDict(lambda k, v: v["priority"], pconf["plugins"], reverse=True)
else:
modified = True
pconf = {"plugins": SortedDict(lambda k, v: v["priority"], reverse=True)}
@@ -90,26 +79,16 @@ class PluginManager:
if plugin_path in self.loaded:
if self.loaded[plugin_path] == None:
logger.info("reload module %s" % plugin_name)
self.loaded[plugin_path] = importlib.reload(
sys.modules[import_path]
)
dependent_module_names = [
name
for name in sys.modules.keys()
if name.startswith(import_path + ".")
]
self.loaded[plugin_path] = importlib.reload(sys.modules[import_path])
dependent_module_names = [name for name in sys.modules.keys() if name.startswith(import_path + ".")]
for name in dependent_module_names:
logger.info("reload module %s" % name)
importlib.reload(sys.modules[name])
else:
self.loaded[plugin_path] = importlib.import_module(
import_path
)
self.loaded[plugin_path] = importlib.import_module(import_path)
self.current_plugin_path = None
except Exception as e:
logger.exception(
"Failed to import plugin %s: %s" % (plugin_name, e)
)
logger.exception("Failed to import plugin %s: %s" % (plugin_name, e))
continue
pconf = self.pconf
news = [self.plugins[name] for name in self.plugins]
@@ -119,9 +98,7 @@ class PluginManager:
rawname = plugincls.name
if rawname not in pconf["plugins"]:
modified = True
logger.info(
"Plugin %s not found in pconfig, adding to pconfig..." % name
)
logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name)
pconf["plugins"][rawname] = {
"enabled": plugincls.enabled,
"priority": plugincls.priority,
@@ -136,9 +113,7 @@ class PluginManager:

def refresh_order(self):
for event in self.listening_plugins.keys():
self.listening_plugins[event].sort(
key=lambda name: self.plugins[name].priority, reverse=True
)
self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True)

def activate_plugins(self): # 生成新开启的插件实例
failed_plugins = []
@@ -184,13 +159,8 @@ class PluginManager:
def emit_event(self, e_context: EventContext, *args, **kwargs):
if e_context.event in self.listening_plugins:
for name in self.listening_plugins[e_context.event]:
if (
self.plugins[name].enabled
and e_context.action == EventAction.CONTINUE
):
logger.debug(
"Plugin %s triggered by event %s" % (name, e_context.event)
)
if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE:
logger.debug("Plugin %s triggered by event %s" % (name, e_context.event))
instance = self.instances[name]
instance.handlers[e_context.event](e_context, *args, **kwargs)
return e_context
@@ -262,9 +232,7 @@ class PluginManager:
source = json.load(f)
if repo in source["repo"]:
repo = source["repo"][repo]["url"]
match = re.match(
r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo
)
match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo)
if not match:
return False, "安装插件失败,source中的仓库地址不合法"
else:


+ 7
- 24
plugins/role/role.py View File

@@ -69,13 +69,9 @@ class Role(Plugin):
logger.info("[Role] inited")
except Exception as e:
if isinstance(e, FileNotFoundError):
logger.warn(
f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ."
)
logger.warn(f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .")
else:
logger.warn(
"[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ."
)
logger.warn("[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .")
raise e

def get_role(self, name, find_closest=True, min_sim=0.35):
@@ -143,9 +139,7 @@ class Role(Plugin):
else:
help_text = f"未知角色类型。\n"
help_text += "目前的角色类型有: \n"
help_text += (
",".join([self.tags[tag][0] for tag in self.tags]) + "\n"
)
help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "\n"
else:
help_text = f"请输入角色类型。\n"
help_text += "目前的角色类型有: \n"
@@ -158,9 +152,7 @@ class Role(Plugin):
return
logger.debug("[Role] on_handle_context. content: %s" % content)
if desckey is not None:
if len(clist) == 1 or (
len(clist) > 1 and clist[1].lower() in ["help", "帮助"]
):
if len(clist) == 1 or (len(clist) > 1 and clist[1].lower() in ["help", "帮助"]):
reply = Reply(ReplyType.INFO, self.get_help_text(verbose=True))
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
@@ -178,9 +170,7 @@ class Role(Plugin):
self.roles[role][desckey],
self.roles[role].get("wrapper", "%s"),
)
reply = Reply(
ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey]
)
reply = Reply(ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey])
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
elif customize == True:
@@ -199,17 +189,10 @@ class Role(Plugin):
if not verbose:
return help_text
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
help_text = (
f"使用方法:\n{trigger_prefix}角色"
+ " 预设角色名: 设定角色为{预设角色名}。\n"
+ f"{trigger_prefix}role"
+ " 预设角色名: 同上,但使用英文设定。\n"
)
help_text = f"使用方法:\n{trigger_prefix}角色" + " 预设角色名: 设定角色为{预设角色名}。\n" + f"{trigger_prefix}role" + " 预设角色名: 同上,但使用英文设定。\n"
help_text += f"{trigger_prefix}设定扮演" + " 角色设定: 设定自定义角色人设为{角色设定}。\n"
help_text += f"{trigger_prefix}停止扮演: 清除设定的角色。\n"
help_text += (
f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n"
)
help_text += f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n"
help_text += "\n目前的角色类型有: \n"
help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "。\n"
help_text += f"\n命令例子: \n{trigger_prefix}角色 写作助理\n"


+ 1
- 1
plugins/tool/README.md View File

@@ -60,7 +60,7 @@

> 该tool每天返回内容相同

#### 6.3. finance-news
#### 6.3. finance-news
###### 获取实时的金融财政新闻

> 该工具需要解决browser tool 的google-chrome依赖安装


+ 4
- 10
plugins/tool/tool.py View File

@@ -82,9 +82,7 @@ class Tool(Plugin):
return
elif content_list[1].startswith("reset"):
logger.debug("[tool]: remind")
e_context[
"context"
].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符"
e_context["context"].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符"

e_context.action = EventAction.BREAK
return
@@ -93,18 +91,14 @@ class Tool(Plugin):

# Don't modify bot name
all_sessions = Bridge().get_bot("chat").sessions
user_session = all_sessions.session_query(
query, e_context["context"]["session_id"]
).messages
user_session = all_sessions.session_query(query, e_context["context"]["session_id"]).messages

# chatgpt-tool-hub will reply you with many tools
logger.debug("[tool]: just-go")
try:
_reply = self.app.ask(query, user_session)
e_context.action = EventAction.BREAK_PASS
all_sessions.session_reply(
_reply, e_context["context"]["session_id"]
)
all_sessions.session_reply(_reply, e_context["context"]["session_id"])
except Exception as e:
logger.exception(e)
logger.error(str(e))
@@ -178,4 +172,4 @@ class Tool(Plugin):
# filter not support tool
tool_list = self._filter_tool_list(tool_config.get("tools", []))

return app.create_app(tools_list=tool_list, **app_kwargs)
return app.create_app(tools_list=tool_list, **app_kwargs)

+ 5
- 15
voice/audio_convert.py View File

@@ -33,6 +33,7 @@ def get_pcm_from_wav(wav_path):
wav = wave.open(wav_path, "rb")
return wav.readframes(wav.getnframes())


def any_to_mp3(any_path, mp3_path):
"""
把任意格式转成mp3文件
@@ -40,16 +41,13 @@ def any_to_mp3(any_path, mp3_path):
if any_path.endswith(".mp3"):
shutil.copy2(any_path, mp3_path)
return
if (
any_path.endswith(".sil")
or any_path.endswith(".silk")
or any_path.endswith(".slk")
):
if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
sil_to_wav(any_path, any_path)
any_path = mp3_path
audio = AudioSegment.from_file(any_path)
audio.export(mp3_path, format="mp3")


def any_to_wav(any_path, wav_path):
"""
把任意格式转成wav文件
@@ -57,11 +55,7 @@ def any_to_wav(any_path, wav_path):
if any_path.endswith(".wav"):
shutil.copy2(any_path, wav_path)
return
if (
any_path.endswith(".sil")
or any_path.endswith(".silk")
or any_path.endswith(".slk")
):
if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
return sil_to_wav(any_path, wav_path)
audio = AudioSegment.from_file(any_path)
audio.export(wav_path, format="wav")
@@ -71,11 +65,7 @@ def any_to_sil(any_path, sil_path):
"""
把任意格式转成sil文件
"""
if (
any_path.endswith(".sil")
or any_path.endswith(".silk")
or any_path.endswith(".slk")
):
if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
shutil.copy2(any_path, sil_path)
return 10000
audio = AudioSegment.from_file(any_path)


+ 9
- 33
voice/azure/azure_voice.py View File

@@ -40,57 +40,33 @@ class AzureVoice(Voice):
config = json.load(fr)
self.api_key = conf().get("azure_voice_api_key")
self.api_region = conf().get("azure_voice_region")
self.speech_config = speechsdk.SpeechConfig(
subscription=self.api_key, region=self.api_region
)
self.speech_config.speech_synthesis_voice_name = config[
"speech_synthesis_voice_name"
]
self.speech_config.speech_recognition_language = config[
"speech_recognition_language"
]
self.speech_config = speechsdk.SpeechConfig(subscription=self.api_key, region=self.api_region)
self.speech_config.speech_synthesis_voice_name = config["speech_synthesis_voice_name"]
self.speech_config.speech_recognition_language = config["speech_recognition_language"]
except Exception as e:
logger.warn("AzureVoice init failed: %s, ignore " % e)

def voiceToText(self, voice_file):
audio_config = speechsdk.AudioConfig(filename=voice_file)
speech_recognizer = speechsdk.SpeechRecognizer(
speech_config=self.speech_config, audio_config=audio_config
)
speech_recognizer = speechsdk.SpeechRecognizer(speech_config=self.speech_config, audio_config=audio_config)
result = speech_recognizer.recognize_once()
if result.reason == speechsdk.ResultReason.RecognizedSpeech:
logger.info(
"[Azure] voiceToText voice file name={} text={}".format(
voice_file, result.text
)
)
logger.info("[Azure] voiceToText voice file name={} text={}".format(voice_file, result.text))
reply = Reply(ReplyType.TEXT, result.text)
else:
logger.error(
"[Azure] voiceToText error, result={}, canceldetails={}".format(
result, result.cancellation_details
)
)
logger.error("[Azure] voiceToText error, result={}, canceldetails={}".format(result, result.cancellation_details))
reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败")
return reply

def textToVoice(self, text):
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav"
audio_config = speechsdk.AudioConfig(filename=fileName)
speech_synthesizer = speechsdk.SpeechSynthesizer(
speech_config=self.speech_config, audio_config=audio_config
)
speech_synthesizer = speechsdk.SpeechSynthesizer(speech_config=self.speech_config, audio_config=audio_config)
result = speech_synthesizer.speak_text(text)
if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
logger.info(
"[Azure] textToVoice text={} voice file name={}".format(text, fileName)
)
logger.info("[Azure] textToVoice text={} voice file name={}".format(text, fileName))
reply = Reply(ReplyType.VOICE, fileName)
else:
logger.error(
"[Azure] textToVoice error, result={}, canceldetails={}".format(
result, result.cancellation_details
)
)
logger.error("[Azure] textToVoice error, result={}, canceldetails={}".format(result, result.cancellation_details))
reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
return reply

+ 1
- 3
voice/baidu/baidu_voice.py View File

@@ -85,9 +85,7 @@ class BaiduVoice(Voice):
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3"
with open(fileName, "wb") as f:
f.write(result)
logger.info(
"[Baidu] textToVoice text={} voice file name={}".format(text, fileName)
)
logger.info("[Baidu] textToVoice text={} voice file name={}".format(text, fileName))
reply = Reply(ReplyType.VOICE, fileName)
else:
logger.error("[Baidu] textToVoice error={}".format(result))


+ 2
- 8
voice/google/google_voice.py View File

@@ -24,11 +24,7 @@ class GoogleVoice(Voice):
audio = self.recognizer.record(source)
try:
text = self.recognizer.recognize_google(audio, language="zh-CN")
logger.info(
"[Google] voiceToText text={} voice file name={}".format(
text, voice_file
)
)
logger.info("[Google] voiceToText text={} voice file name={}".format(text, voice_file))
reply = Reply(ReplyType.TEXT, text)
except speech_recognition.UnknownValueError:
reply = Reply(ReplyType.ERROR, "抱歉,我听不懂")
@@ -42,9 +38,7 @@ class GoogleVoice(Voice):
mp3File = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3"
tts = gTTS(text=text, lang="zh")
tts.save(mp3File)
logger.info(
"[Google] textToVoice text={} voice file name={}".format(text, mp3File)
)
logger.info("[Google] textToVoice text={} voice file name={}".format(text, mp3File))
reply = Reply(ReplyType.VOICE, mp3File)
except Exception as e:
reply = Reply(ReplyType.ERROR, str(e))


+ 1
- 5
voice/openai/openai_voice.py View File

@@ -22,11 +22,7 @@ class OpenaiVoice(Voice):
result = openai.Audio.transcribe("whisper-1", file)
text = result["text"]
reply = Reply(ReplyType.TEXT, text)
logger.info(
"[Openai] voiceToText text={} voice file name={}".format(
text, voice_file
)
)
logger.info("[Openai] voiceToText text={} voice file name={}".format(text, voice_file))
except Exception as e:
reply = Reply(ReplyType.ERROR, str(e))
finally:


+ 7
- 5
voice/pytts/pytts_voice.py View File

@@ -5,6 +5,7 @@ pytts voice service (offline)
import os
import sys
import time

import pyttsx3

from bridge.reply import Reply, ReplyType
@@ -12,6 +13,7 @@ from common.log import logger
from common.tmp_dir import TmpDir
from voice.voice import Voice


class PyttsVoice(Voice):
engine = pyttsx3.init()

@@ -20,7 +22,7 @@ class PyttsVoice(Voice):
self.engine.setProperty("rate", 125)
# 音量
self.engine.setProperty("volume", 1.0)
if sys.platform == 'win32':
if sys.platform == "win32":
for voice in self.engine.getProperty("voices"):
if "Chinese" in voice.name:
self.engine.setProperty("voice", voice.id)
@@ -33,23 +35,23 @@ class PyttsVoice(Voice):
def textToVoice(self, text):
try:
# avoid the same filename
wavFileName = "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7fffffff) + ".wav"
wavFileName = "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".wav"
wavFile = TmpDir().path() + wavFileName
logger.info("[Pytts] textToVoice text={} voice file name={}".format(text, wavFile))

self.engine.save_to_file(text, wavFile)

if sys.platform == 'win32':
if sys.platform == "win32":
self.engine.runAndWait()
else:
# In ubuntu, runAndWait do not really wait until the file created.
# In ubuntu, runAndWait do not really wait until the file created.
# It will return once the task queue is empty, but the task is still running in coroutine.
# And if you call runAndWait() and time.sleep() twice, it will stuck, so do not use this.
# If you want to fix this, add self._proxy.setBusy(True) in line 127 in espeak.py, at the beginning of the function save_to_file.
# self.engine.runAndWait()

# Before espeak fix this problem, we iterate the generator and control the waiting by ourself.
# But this is not the canonical way to use it, for example if the file already exists it also cannot wait.
# But this is not the canonical way to use it, for example if the file already exists it also cannot wait.
self.engine.iterate()
while self.engine.isBusy() or wavFileName not in os.listdir(TmpDir().path()):
time.sleep(0.1)


Loading…
Cancel
Save