@@ -19,6 +19,7 @@ def sigterm_handler_wrap(_signo): | |||||
if callable(old_handler): # check old_handler | if callable(old_handler): # check old_handler | ||||
return old_handler(_signo, _stack_frame) | return old_handler(_signo, _stack_frame) | ||||
sys.exit(0) | sys.exit(0) | ||||
signal.signal(_signo, func) | signal.signal(_signo, func) | ||||
@@ -10,10 +10,7 @@ from bridge.reply import Reply, ReplyType | |||||
class BaiduUnitBot(Bot): | class BaiduUnitBot(Bot): | ||||
def reply(self, query, context=None): | def reply(self, query, context=None): | ||||
token = self.get_token() | token = self.get_token() | ||||
url = ( | |||||
"https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" | |||||
+ token | |||||
) | |||||
url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + token | |||||
post_data = ( | post_data = ( | ||||
'{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"' | '{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"' | ||||
+ query | + query | ||||
@@ -32,12 +29,7 @@ class BaiduUnitBot(Bot): | |||||
def get_token(self): | def get_token(self): | ||||
access_key = "YOUR_ACCESS_KEY" | access_key = "YOUR_ACCESS_KEY" | ||||
secret_key = "YOUR_SECRET_KEY" | secret_key = "YOUR_SECRET_KEY" | ||||
host = ( | |||||
"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" | |||||
+ access_key | |||||
+ "&client_secret=" | |||||
+ secret_key | |||||
) | |||||
host = "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" + access_key + "&client_secret=" + secret_key | |||||
response = requests.get(host) | response = requests.get(host) | ||||
if response: | if response: | ||||
print(response.json()) | print(response.json()) | ||||
@@ -30,23 +30,15 @@ class ChatGPTBot(Bot, OpenAIImage): | |||||
if conf().get("rate_limit_chatgpt"): | if conf().get("rate_limit_chatgpt"): | ||||
self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20)) | self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20)) | ||||
self.sessions = SessionManager( | |||||
ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo" | |||||
) | |||||
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo") | |||||
self.args = { | self.args = { | ||||
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称 | "model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称 | ||||
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 | "temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 | ||||
# "max_tokens":4096, # 回复最大的字符数 | # "max_tokens":4096, # 回复最大的字符数 | ||||
"top_p": 1, | "top_p": 1, | ||||
"frequency_penalty": conf().get( | |||||
"frequency_penalty", 0.0 | |||||
), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | |||||
"presence_penalty": conf().get( | |||||
"presence_penalty", 0.0 | |||||
), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | |||||
"request_timeout": conf().get( | |||||
"request_timeout", None | |||||
), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 | |||||
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | |||||
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | |||||
"request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 | |||||
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试 | "timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试 | ||||
} | } | ||||
@@ -87,15 +79,10 @@ class ChatGPTBot(Bot, OpenAIImage): | |||||
reply_content["completion_tokens"], | reply_content["completion_tokens"], | ||||
) | ) | ||||
) | ) | ||||
if ( | |||||
reply_content["completion_tokens"] == 0 | |||||
and len(reply_content["content"]) > 0 | |||||
): | |||||
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0: | |||||
reply = Reply(ReplyType.ERROR, reply_content["content"]) | reply = Reply(ReplyType.ERROR, reply_content["content"]) | ||||
elif reply_content["completion_tokens"] > 0: | elif reply_content["completion_tokens"] > 0: | ||||
self.sessions.session_reply( | |||||
reply_content["content"], session_id, reply_content["total_tokens"] | |||||
) | |||||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"]) | |||||
reply = Reply(ReplyType.TEXT, reply_content["content"]) | reply = Reply(ReplyType.TEXT, reply_content["content"]) | ||||
else: | else: | ||||
reply = Reply(ReplyType.ERROR, reply_content["content"]) | reply = Reply(ReplyType.ERROR, reply_content["content"]) | ||||
@@ -126,9 +113,7 @@ class ChatGPTBot(Bot, OpenAIImage): | |||||
if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token(): | if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token(): | ||||
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded") | raise openai.error.RateLimitError("RateLimitError: rate limit exceeded") | ||||
# if api_key == None, the default openai.api_key will be used | # if api_key == None, the default openai.api_key will be used | ||||
response = openai.ChatCompletion.create( | |||||
api_key=api_key, messages=session.messages, **self.args | |||||
) | |||||
response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **self.args) | |||||
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) | # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) | ||||
return { | return { | ||||
"total_tokens": response["usage"]["total_tokens"], | "total_tokens": response["usage"]["total_tokens"], | ||||
@@ -25,9 +25,7 @@ class ChatGPTSession(Session): | |||||
precise = False | precise = False | ||||
if cur_tokens is None: | if cur_tokens is None: | ||||
raise e | raise e | ||||
logger.debug( | |||||
"Exception when counting tokens precisely for query: {}".format(e) | |||||
) | |||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e)) | |||||
while cur_tokens > max_tokens: | while cur_tokens > max_tokens: | ||||
if len(self.messages) > 2: | if len(self.messages) > 2: | ||||
self.messages.pop(1) | self.messages.pop(1) | ||||
@@ -39,16 +37,10 @@ class ChatGPTSession(Session): | |||||
cur_tokens = cur_tokens - max_tokens | cur_tokens = cur_tokens - max_tokens | ||||
break | break | ||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "user": | elif len(self.messages) == 2 and self.messages[1]["role"] == "user": | ||||
logger.warn( | |||||
"user message exceed max_tokens. total_tokens={}".format(cur_tokens) | |||||
) | |||||
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens)) | |||||
break | break | ||||
else: | else: | ||||
logger.debug( | |||||
"max_tokens={}, total_tokens={}, len(messages)={}".format( | |||||
max_tokens, cur_tokens, len(self.messages) | |||||
) | |||||
) | |||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages))) | |||||
break | break | ||||
if precise: | if precise: | ||||
cur_tokens = self.calc_tokens() | cur_tokens = self.calc_tokens() | ||||
@@ -75,17 +67,13 @@ def num_tokens_from_messages(messages, model): | |||||
elif model == "gpt-4": | elif model == "gpt-4": | ||||
return num_tokens_from_messages(messages, model="gpt-4-0314") | return num_tokens_from_messages(messages, model="gpt-4-0314") | ||||
elif model == "gpt-3.5-turbo-0301": | elif model == "gpt-3.5-turbo-0301": | ||||
tokens_per_message = ( | |||||
4 # every message follows <|start|>{role/name}\n{content}<|end|>\n | |||||
) | |||||
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n | |||||
tokens_per_name = -1 # if there's a name, the role is omitted | tokens_per_name = -1 # if there's a name, the role is omitted | ||||
elif model == "gpt-4-0314": | elif model == "gpt-4-0314": | ||||
tokens_per_message = 3 | tokens_per_message = 3 | ||||
tokens_per_name = 1 | tokens_per_name = 1 | ||||
else: | else: | ||||
logger.warn( | |||||
f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301." | |||||
) | |||||
logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301.") | |||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301") | return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301") | ||||
num_tokens = 0 | num_tokens = 0 | ||||
for message in messages: | for message in messages: | ||||
@@ -28,23 +28,15 @@ class OpenAIBot(Bot, OpenAIImage): | |||||
if proxy: | if proxy: | ||||
openai.proxy = proxy | openai.proxy = proxy | ||||
self.sessions = SessionManager( | |||||
OpenAISession, model=conf().get("model") or "text-davinci-003" | |||||
) | |||||
self.sessions = SessionManager(OpenAISession, model=conf().get("model") or "text-davinci-003") | |||||
self.args = { | self.args = { | ||||
"model": conf().get("model") or "text-davinci-003", # 对话模型的名称 | "model": conf().get("model") or "text-davinci-003", # 对话模型的名称 | ||||
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 | "temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 | ||||
"max_tokens": 1200, # 回复最大的字符数 | "max_tokens": 1200, # 回复最大的字符数 | ||||
"top_p": 1, | "top_p": 1, | ||||
"frequency_penalty": conf().get( | |||||
"frequency_penalty", 0.0 | |||||
), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | |||||
"presence_penalty": conf().get( | |||||
"presence_penalty", 0.0 | |||||
), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | |||||
"request_timeout": conf().get( | |||||
"request_timeout", None | |||||
), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 | |||||
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | |||||
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | |||||
"request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 | |||||
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试 | "timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试 | ||||
"stop": ["\n\n\n"], | "stop": ["\n\n\n"], | ||||
} | } | ||||
@@ -71,17 +63,13 @@ class OpenAIBot(Bot, OpenAIImage): | |||||
result["content"], | result["content"], | ||||
) | ) | ||||
logger.debug( | logger.debug( | ||||
"[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format( | |||||
str(session), session_id, reply_content, completion_tokens | |||||
) | |||||
"[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens) | |||||
) | ) | ||||
if total_tokens == 0: | if total_tokens == 0: | ||||
reply = Reply(ReplyType.ERROR, reply_content) | reply = Reply(ReplyType.ERROR, reply_content) | ||||
else: | else: | ||||
self.sessions.session_reply( | |||||
reply_content, session_id, total_tokens | |||||
) | |||||
self.sessions.session_reply(reply_content, session_id, total_tokens) | |||||
reply = Reply(ReplyType.TEXT, reply_content) | reply = Reply(ReplyType.TEXT, reply_content) | ||||
return reply | return reply | ||||
elif context.type == ContextType.IMAGE_CREATE: | elif context.type == ContextType.IMAGE_CREATE: | ||||
@@ -96,9 +84,7 @@ class OpenAIBot(Bot, OpenAIImage): | |||||
def reply_text(self, session: OpenAISession, retry_count=0): | def reply_text(self, session: OpenAISession, retry_count=0): | ||||
try: | try: | ||||
response = openai.Completion.create(prompt=str(session), **self.args) | response = openai.Completion.create(prompt=str(session), **self.args) | ||||
res_content = ( | |||||
response.choices[0]["text"].strip().replace("<|endoftext|>", "") | |||||
) | |||||
res_content = response.choices[0]["text"].strip().replace("<|endoftext|>", "") | |||||
total_tokens = response["usage"]["total_tokens"] | total_tokens = response["usage"]["total_tokens"] | ||||
completion_tokens = response["usage"]["completion_tokens"] | completion_tokens = response["usage"]["completion_tokens"] | ||||
logger.info("[OPEN_AI] reply={}".format(res_content)) | logger.info("[OPEN_AI] reply={}".format(res_content)) | ||||
@@ -23,9 +23,7 @@ class OpenAIImage(object): | |||||
response = openai.Image.create( | response = openai.Image.create( | ||||
prompt=query, # 图片描述 | prompt=query, # 图片描述 | ||||
n=1, # 每次生成图片的数量 | n=1, # 每次生成图片的数量 | ||||
size=conf().get( | |||||
"image_create_size", "256x256" | |||||
), # 图片大小,可选有 256x256, 512x512, 1024x1024 | |||||
size=conf().get("image_create_size", "256x256"), # 图片大小,可选有 256x256, 512x512, 1024x1024 | |||||
) | ) | ||||
image_url = response["data"][0]["url"] | image_url = response["data"][0]["url"] | ||||
logger.info("[OPEN_AI] image_url={}".format(image_url)) | logger.info("[OPEN_AI] image_url={}".format(image_url)) | ||||
@@ -34,11 +32,7 @@ class OpenAIImage(object): | |||||
logger.warn(e) | logger.warn(e) | ||||
if retry_count < 1: | if retry_count < 1: | ||||
time.sleep(5) | time.sleep(5) | ||||
logger.warn( | |||||
"[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format( | |||||
retry_count + 1 | |||||
) | |||||
) | |||||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1)) | |||||
return self.create_img(query, retry_count + 1) | return self.create_img(query, retry_count + 1) | ||||
else: | else: | ||||
return False, "提问太快啦,请休息一下再问我吧" | return False, "提问太快啦,请休息一下再问我吧" | ||||
@@ -36,9 +36,7 @@ class OpenAISession(Session): | |||||
precise = False | precise = False | ||||
if cur_tokens is None: | if cur_tokens is None: | ||||
raise e | raise e | ||||
logger.debug( | |||||
"Exception when counting tokens precisely for query: {}".format(e) | |||||
) | |||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e)) | |||||
while cur_tokens > max_tokens: | while cur_tokens > max_tokens: | ||||
if len(self.messages) > 1: | if len(self.messages) > 1: | ||||
self.messages.pop(0) | self.messages.pop(0) | ||||
@@ -50,18 +48,10 @@ class OpenAISession(Session): | |||||
cur_tokens = len(str(self)) | cur_tokens = len(str(self)) | ||||
break | break | ||||
elif len(self.messages) == 1 and self.messages[0]["role"] == "user": | elif len(self.messages) == 1 and self.messages[0]["role"] == "user": | ||||
logger.warn( | |||||
"user question exceed max_tokens. total_tokens={}".format( | |||||
cur_tokens | |||||
) | |||||
) | |||||
logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens)) | |||||
break | break | ||||
else: | else: | ||||
logger.debug( | |||||
"max_tokens={}, total_tokens={}, len(conversation)={}".format( | |||||
max_tokens, cur_tokens, len(self.messages) | |||||
) | |||||
) | |||||
logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages))) | |||||
break | break | ||||
if precise: | if precise: | ||||
cur_tokens = self.calc_tokens() | cur_tokens = self.calc_tokens() | ||||
@@ -55,9 +55,7 @@ class SessionManager(object): | |||||
return self.sessioncls(session_id, system_prompt, **self.session_args) | return self.sessioncls(session_id, system_prompt, **self.session_args) | ||||
if session_id not in self.sessions: | if session_id not in self.sessions: | ||||
self.sessions[session_id] = self.sessioncls( | |||||
session_id, system_prompt, **self.session_args | |||||
) | |||||
self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args) | |||||
elif system_prompt is not None: # 如果有新的system_prompt,更新并重置session | elif system_prompt is not None: # 如果有新的system_prompt,更新并重置session | ||||
self.sessions[session_id].set_system_prompt(system_prompt) | self.sessions[session_id].set_system_prompt(system_prompt) | ||||
session = self.sessions[session_id] | session = self.sessions[session_id] | ||||
@@ -71,9 +69,7 @@ class SessionManager(object): | |||||
total_tokens = session.discard_exceeding(max_tokens, None) | total_tokens = session.discard_exceeding(max_tokens, None) | ||||
logger.debug("prompt tokens used={}".format(total_tokens)) | logger.debug("prompt tokens used={}".format(total_tokens)) | ||||
except Exception as e: | except Exception as e: | ||||
logger.debug( | |||||
"Exception when counting tokens precisely for prompt: {}".format(str(e)) | |||||
) | |||||
logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e))) | |||||
return session | return session | ||||
def session_reply(self, reply, session_id, total_tokens=None): | def session_reply(self, reply, session_id, total_tokens=None): | ||||
@@ -82,17 +78,9 @@ class SessionManager(object): | |||||
try: | try: | ||||
max_tokens = conf().get("conversation_max_tokens", 1000) | max_tokens = conf().get("conversation_max_tokens", 1000) | ||||
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens) | tokens_cnt = session.discard_exceeding(max_tokens, total_tokens) | ||||
logger.debug( | |||||
"raw total_tokens={}, savesession tokens={}".format( | |||||
total_tokens, tokens_cnt | |||||
) | |||||
) | |||||
logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt)) | |||||
except Exception as e: | except Exception as e: | ||||
logger.debug( | |||||
"Exception when counting tokens precisely for session: {}".format( | |||||
str(e) | |||||
) | |||||
) | |||||
logger.debug("Exception when counting tokens precisely for session: {}".format(str(e))) | |||||
return session | return session | ||||
def clear_session(self, session_id): | def clear_session(self, session_id): | ||||
@@ -60,6 +60,4 @@ class Context: | |||||
del self.kwargs[key] | del self.kwargs[key] | ||||
def __str__(self): | def __str__(self): | ||||
return "Context(type={}, content={}, kwargs={})".format( | |||||
self.type, self.content, self.kwargs | |||||
) | |||||
return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs) |
@@ -53,9 +53,7 @@ class ChatChannel(Channel): | |||||
group_id = cmsg.other_user_id | group_id = cmsg.other_user_id | ||||
group_name_white_list = config.get("group_name_white_list", []) | group_name_white_list = config.get("group_name_white_list", []) | ||||
group_name_keyword_white_list = config.get( | |||||
"group_name_keyword_white_list", [] | |||||
) | |||||
group_name_keyword_white_list = config.get("group_name_keyword_white_list", []) | |||||
if any( | if any( | ||||
[ | [ | ||||
group_name in group_name_white_list, | group_name in group_name_white_list, | ||||
@@ -63,9 +61,7 @@ class ChatChannel(Channel): | |||||
check_contain(group_name, group_name_keyword_white_list), | check_contain(group_name, group_name_keyword_white_list), | ||||
] | ] | ||||
): | ): | ||||
group_chat_in_one_session = conf().get( | |||||
"group_chat_in_one_session", [] | |||||
) | |||||
group_chat_in_one_session = conf().get("group_chat_in_one_session", []) | |||||
session_id = cmsg.actual_user_id | session_id = cmsg.actual_user_id | ||||
if any( | if any( | ||||
[ | [ | ||||
@@ -81,17 +77,11 @@ class ChatChannel(Channel): | |||||
else: | else: | ||||
context["session_id"] = cmsg.other_user_id | context["session_id"] = cmsg.other_user_id | ||||
context["receiver"] = cmsg.other_user_id | context["receiver"] = cmsg.other_user_id | ||||
e_context = PluginManager().emit_event( | |||||
EventContext( | |||||
Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context} | |||||
) | |||||
) | |||||
e_context = PluginManager().emit_event(EventContext(Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context})) | |||||
context = e_context["context"] | context = e_context["context"] | ||||
if e_context.is_pass() or context is None: | if e_context.is_pass() or context is None: | ||||
return context | return context | ||||
if cmsg.from_user_id == self.user_id and not config.get( | |||||
"trigger_by_self", True | |||||
): | |||||
if cmsg.from_user_id == self.user_id and not config.get("trigger_by_self", True): | |||||
logger.debug("[WX]self message skipped") | logger.debug("[WX]self message skipped") | ||||
return None | return None | ||||
@@ -119,19 +109,13 @@ class ChatChannel(Channel): | |||||
if not flag: | if not flag: | ||||
if context["origin_ctype"] == ContextType.VOICE: | if context["origin_ctype"] == ContextType.VOICE: | ||||
logger.info( | |||||
"[WX]receive group voice, but checkprefix didn't match" | |||||
) | |||||
logger.info("[WX]receive group voice, but checkprefix didn't match") | |||||
return None | return None | ||||
else: # 单聊 | else: # 单聊 | ||||
match_prefix = check_prefix( | |||||
content, conf().get("single_chat_prefix", [""]) | |||||
) | |||||
match_prefix = check_prefix(content, conf().get("single_chat_prefix", [""])) | |||||
if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容 | if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容 | ||||
content = content.replace(match_prefix, "", 1).strip() | content = content.replace(match_prefix, "", 1).strip() | ||||
elif ( | |||||
context["origin_ctype"] == ContextType.VOICE | |||||
): # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件 | |||||
elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件 | |||||
pass | pass | ||||
else: | else: | ||||
return None | return None | ||||
@@ -143,18 +127,10 @@ class ChatChannel(Channel): | |||||
else: | else: | ||||
context.type = ContextType.TEXT | context.type = ContextType.TEXT | ||||
context.content = content.strip() | context.content = content.strip() | ||||
if ( | |||||
"desire_rtype" not in context | |||||
and conf().get("always_reply_voice") | |||||
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE | |||||
): | |||||
if "desire_rtype" not in context and conf().get("always_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: | |||||
context["desire_rtype"] = ReplyType.VOICE | context["desire_rtype"] = ReplyType.VOICE | ||||
elif context.type == ContextType.VOICE: | elif context.type == ContextType.VOICE: | ||||
if ( | |||||
"desire_rtype" not in context | |||||
and conf().get("voice_reply_voice") | |||||
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE | |||||
): | |||||
if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: | |||||
context["desire_rtype"] = ReplyType.VOICE | context["desire_rtype"] = ReplyType.VOICE | ||||
return context | return context | ||||
@@ -182,15 +158,8 @@ class ChatChannel(Channel): | |||||
) | ) | ||||
reply = e_context["reply"] | reply = e_context["reply"] | ||||
if not e_context.is_pass(): | if not e_context.is_pass(): | ||||
logger.debug( | |||||
"[WX] ready to handle context: type={}, content={}".format( | |||||
context.type, context.content | |||||
) | |||||
) | |||||
if ( | |||||
context.type == ContextType.TEXT | |||||
or context.type == ContextType.IMAGE_CREATE | |||||
): # 文字和图片消息 | |||||
logger.debug("[WX] ready to handle context: type={}, content={}".format(context.type, context.content)) | |||||
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息 | |||||
reply = super().build_reply_content(context.content, context) | reply = super().build_reply_content(context.content, context) | ||||
elif context.type == ContextType.VOICE: # 语音消息 | elif context.type == ContextType.VOICE: # 语音消息 | ||||
cmsg = context["msg"] | cmsg = context["msg"] | ||||
@@ -214,9 +183,7 @@ class ChatChannel(Channel): | |||||
# logger.warning("[WX]delete temp file error: " + str(e)) | # logger.warning("[WX]delete temp file error: " + str(e)) | ||||
if reply.type == ReplyType.TEXT: | if reply.type == ReplyType.TEXT: | ||||
new_context = self._compose_context( | |||||
ContextType.TEXT, reply.content, **context.kwargs | |||||
) | |||||
new_context = self._compose_context(ContextType.TEXT, reply.content, **context.kwargs) | |||||
if new_context: | if new_context: | ||||
reply = self._generate_reply(new_context) | reply = self._generate_reply(new_context) | ||||
else: | else: | ||||
@@ -246,48 +213,24 @@ class ChatChannel(Channel): | |||||
if reply.type == ReplyType.TEXT: | if reply.type == ReplyType.TEXT: | ||||
reply_text = reply.content | reply_text = reply.content | ||||
if ( | |||||
desire_rtype == ReplyType.VOICE | |||||
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE | |||||
): | |||||
if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: | |||||
reply = super().build_text_to_voice(reply.content) | reply = super().build_text_to_voice(reply.content) | ||||
return self._decorate_reply(context, reply) | return self._decorate_reply(context, reply) | ||||
if context.get("isgroup", False): | if context.get("isgroup", False): | ||||
reply_text = ( | |||||
"@" | |||||
+ context["msg"].actual_user_nickname | |||||
+ " " | |||||
+ reply_text.strip() | |||||
) | |||||
reply_text = ( | |||||
conf().get("group_chat_reply_prefix", "") + reply_text | |||||
) | |||||
reply_text = "@" + context["msg"].actual_user_nickname + " " + reply_text.strip() | |||||
reply_text = conf().get("group_chat_reply_prefix", "") + reply_text | |||||
else: | else: | ||||
reply_text = ( | |||||
conf().get("single_chat_reply_prefix", "") + reply_text | |||||
) | |||||
reply_text = conf().get("single_chat_reply_prefix", "") + reply_text | |||||
reply.content = reply_text | reply.content = reply_text | ||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: | elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: | ||||
reply.content = "[" + str(reply.type) + "]\n" + reply.content | reply.content = "[" + str(reply.type) + "]\n" + reply.content | ||||
elif ( | |||||
reply.type == ReplyType.IMAGE_URL | |||||
or reply.type == ReplyType.VOICE | |||||
or reply.type == ReplyType.IMAGE | |||||
): | |||||
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE: | |||||
pass | pass | ||||
else: | else: | ||||
logger.error("[WX] unknown reply type: {}".format(reply.type)) | logger.error("[WX] unknown reply type: {}".format(reply.type)) | ||||
return | return | ||||
if ( | |||||
desire_rtype | |||||
and desire_rtype != reply.type | |||||
and reply.type not in [ReplyType.ERROR, ReplyType.INFO] | |||||
): | |||||
logger.warning( | |||||
"[WX] desire_rtype: {}, but reply type: {}".format( | |||||
context.get("desire_rtype"), reply.type | |||||
) | |||||
) | |||||
if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]: | |||||
logger.warning("[WX] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type)) | |||||
return reply | return reply | ||||
def _send_reply(self, context: Context, reply: Reply): | def _send_reply(self, context: Context, reply: Reply): | ||||
@@ -300,9 +243,7 @@ class ChatChannel(Channel): | |||||
) | ) | ||||
reply = e_context["reply"] | reply = e_context["reply"] | ||||
if not e_context.is_pass() and reply and reply.type: | if not e_context.is_pass() and reply and reply.type: | ||||
logger.debug( | |||||
"[WX] ready to send reply: {}, context: {}".format(reply, context) | |||||
) | |||||
logger.debug("[WX] ready to send reply: {}, context: {}".format(reply, context)) | |||||
self._send(reply, context) | self._send(reply, context) | ||||
def _send(self, reply: Reply, context: Context, retry_cnt=0): | def _send(self, reply: Reply, context: Context, retry_cnt=0): | ||||
@@ -328,9 +269,7 @@ class ChatChannel(Channel): | |||||
try: | try: | ||||
worker_exception = worker.exception() | worker_exception = worker.exception() | ||||
if worker_exception: | if worker_exception: | ||||
self._fail_callback( | |||||
session_id, exception=worker_exception, **kwargs | |||||
) | |||||
self._fail_callback(session_id, exception=worker_exception, **kwargs) | |||||
else: | else: | ||||
self._success_callback(session_id, **kwargs) | self._success_callback(session_id, **kwargs) | ||||
except CancelledError as e: | except CancelledError as e: | ||||
@@ -366,24 +305,14 @@ class ChatChannel(Channel): | |||||
if not context_queue.empty(): | if not context_queue.empty(): | ||||
context = context_queue.get() | context = context_queue.get() | ||||
logger.debug("[WX] consume context: {}".format(context)) | logger.debug("[WX] consume context: {}".format(context)) | ||||
future: Future = self.handler_pool.submit( | |||||
self._handle, context | |||||
) | |||||
future.add_done_callback( | |||||
self._thread_pool_callback(session_id, context=context) | |||||
) | |||||
future: Future = self.handler_pool.submit(self._handle, context) | |||||
future.add_done_callback(self._thread_pool_callback(session_id, context=context)) | |||||
if session_id not in self.futures: | if session_id not in self.futures: | ||||
self.futures[session_id] = [] | self.futures[session_id] = [] | ||||
self.futures[session_id].append(future) | self.futures[session_id].append(future) | ||||
elif ( | |||||
semaphore._initial_value == semaphore._value + 1 | |||||
): # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕 | |||||
self.futures[session_id] = [ | |||||
t for t in self.futures[session_id] if not t.done() | |||||
] | |||||
assert ( | |||||
len(self.futures[session_id]) == 0 | |||||
), "thread pool error" | |||||
elif semaphore._initial_value == semaphore._value + 1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕 | |||||
self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()] | |||||
assert len(self.futures[session_id]) == 0, "thread pool error" | |||||
del self.sessions[session_id] | del self.sessions[session_id] | ||||
else: | else: | ||||
semaphore.release() | semaphore.release() | ||||
@@ -397,9 +326,7 @@ class ChatChannel(Channel): | |||||
future.cancel() | future.cancel() | ||||
cnt = self.sessions[session_id][0].qsize() | cnt = self.sessions[session_id][0].qsize() | ||||
if cnt > 0: | if cnt > 0: | ||||
logger.info( | |||||
"Cancel {} messages in session {}".format(cnt, session_id) | |||||
) | |||||
logger.info("Cancel {} messages in session {}".format(cnt, session_id)) | |||||
self.sessions[session_id][0] = Dequeue() | self.sessions[session_id][0] = Dequeue() | ||||
def cancel_all_session(self): | def cancel_all_session(self): | ||||
@@ -409,9 +336,7 @@ class ChatChannel(Channel): | |||||
future.cancel() | future.cancel() | ||||
cnt = self.sessions[session_id][0].qsize() | cnt = self.sessions[session_id][0].qsize() | ||||
if cnt > 0: | if cnt > 0: | ||||
logger.info( | |||||
"Cancel {} messages in session {}".format(cnt, session_id) | |||||
) | |||||
logger.info("Cancel {} messages in session {}".format(cnt, session_id)) | |||||
self.sessions[session_id][0] = Dequeue() | self.sessions[session_id][0] = Dequeue() | ||||
@@ -77,9 +77,7 @@ class TerminalChannel(ChatChannel): | |||||
if check_prefix(prompt, trigger_prefixs) is None: | if check_prefix(prompt, trigger_prefixs) is None: | ||||
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀 | prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀 | ||||
context = self._compose_context( | |||||
ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt) | |||||
) | |||||
context = self._compose_context(ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt)) | |||||
if context: | if context: | ||||
self.produce(context) | self.produce(context) | ||||
else: | else: | ||||
@@ -56,10 +56,7 @@ def _check(func): | |||||
return | return | ||||
self.receivedMsgs[msgId] = cmsg | self.receivedMsgs[msgId] = cmsg | ||||
create_time = cmsg.create_time # 消息时间戳 | create_time = cmsg.create_time # 消息时间戳 | ||||
if ( | |||||
conf().get("hot_reload") == True | |||||
and int(create_time) < int(time.time()) - 60 | |||||
): # 跳过1分钟前的历史消息 | |||||
if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息 | |||||
logger.debug("[WX]history message {} skipped".format(msgId)) | logger.debug("[WX]history message {} skipped".format(msgId)) | ||||
return | return | ||||
return func(self, cmsg) | return func(self, cmsg) | ||||
@@ -88,15 +85,9 @@ def qrCallback(uuid, status, qrcode): | |||||
url = f"https://login.weixin.qq.com/l/{uuid}" | url = f"https://login.weixin.qq.com/l/{uuid}" | ||||
qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url) | qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url) | ||||
qr_api2 = ( | |||||
"https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format( | |||||
url | |||||
) | |||||
) | |||||
qr_api2 = "https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url) | |||||
qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url) | qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url) | ||||
qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format( | |||||
url | |||||
) | |||||
qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url) | |||||
print("You can also scan QRCode in any website below:") | print("You can also scan QRCode in any website below:") | ||||
print(qr_api3) | print(qr_api3) | ||||
print(qr_api4) | print(qr_api4) | ||||
@@ -134,18 +125,12 @@ class WechatChannel(ChatChannel): | |||||
logger.error("Hot reload failed, try to login without hot reload") | logger.error("Hot reload failed, try to login without hot reload") | ||||
itchat.logout() | itchat.logout() | ||||
os.remove(status_path) | os.remove(status_path) | ||||
itchat.auto_login( | |||||
enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback | |||||
) | |||||
itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback) | |||||
else: | else: | ||||
raise e | raise e | ||||
self.user_id = itchat.instance.storageClass.userName | self.user_id = itchat.instance.storageClass.userName | ||||
self.name = itchat.instance.storageClass.nickName | self.name = itchat.instance.storageClass.nickName | ||||
logger.info( | |||||
"Wechat login success, user_id: {}, nickname: {}".format( | |||||
self.user_id, self.name | |||||
) | |||||
) | |||||
logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name)) | |||||
# start message listener | # start message listener | ||||
itchat.run() | itchat.run() | ||||
@@ -173,16 +158,10 @@ class WechatChannel(ChatChannel): | |||||
elif cmsg.ctype == ContextType.PATPAT: | elif cmsg.ctype == ContextType.PATPAT: | ||||
logger.debug("[WX]receive patpat msg: {}".format(cmsg.content)) | logger.debug("[WX]receive patpat msg: {}".format(cmsg.content)) | ||||
elif cmsg.ctype == ContextType.TEXT: | elif cmsg.ctype == ContextType.TEXT: | ||||
logger.debug( | |||||
"[WX]receive text msg: {}, cmsg={}".format( | |||||
json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg | |||||
) | |||||
) | |||||
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg)) | |||||
else: | else: | ||||
logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg)) | logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg)) | ||||
context = self._compose_context( | |||||
cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg | |||||
) | |||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg) | |||||
if context: | if context: | ||||
self.produce(context) | self.produce(context) | ||||
@@ -202,9 +181,7 @@ class WechatChannel(ChatChannel): | |||||
pass | pass | ||||
else: | else: | ||||
logger.debug("[WX]receive group msg: {}".format(cmsg.content)) | logger.debug("[WX]receive group msg: {}".format(cmsg.content)) | ||||
context = self._compose_context( | |||||
cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg | |||||
) | |||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg) | |||||
if context: | if context: | ||||
self.produce(context) | self.produce(context) | ||||
@@ -27,37 +27,23 @@ class WeChatMessage(ChatMessage): | |||||
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径 | self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径 | ||||
self._prepare_fn = lambda: itchat_msg.download(self.content) | self._prepare_fn = lambda: itchat_msg.download(self.content) | ||||
elif itchat_msg["Type"] == NOTE and itchat_msg["MsgType"] == 10000: | elif itchat_msg["Type"] == NOTE and itchat_msg["MsgType"] == 10000: | ||||
if is_group and ( | |||||
"加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"] | |||||
): | |||||
if is_group and ("加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"]): | |||||
self.ctype = ContextType.JOIN_GROUP | self.ctype = ContextType.JOIN_GROUP | ||||
self.content = itchat_msg["Content"] | self.content = itchat_msg["Content"] | ||||
# 这里只能得到nickname, actual_user_id还是机器人的id | # 这里只能得到nickname, actual_user_id还是机器人的id | ||||
if "加入了群聊" in itchat_msg["Content"]: | if "加入了群聊" in itchat_msg["Content"]: | ||||
self.actual_user_nickname = re.findall( | |||||
r"\"(.*?)\"", itchat_msg["Content"] | |||||
)[-1] | |||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[-1] | |||||
elif "加入群聊" in itchat_msg["Content"]: | elif "加入群聊" in itchat_msg["Content"]: | ||||
self.actual_user_nickname = re.findall( | |||||
r"\"(.*?)\"", itchat_msg["Content"] | |||||
)[0] | |||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0] | |||||
elif "拍了拍我" in itchat_msg["Content"]: | elif "拍了拍我" in itchat_msg["Content"]: | ||||
self.ctype = ContextType.PATPAT | self.ctype = ContextType.PATPAT | ||||
self.content = itchat_msg["Content"] | self.content = itchat_msg["Content"] | ||||
if is_group: | if is_group: | ||||
self.actual_user_nickname = re.findall( | |||||
r"\"(.*?)\"", itchat_msg["Content"] | |||||
)[0] | |||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0] | |||||
else: | else: | ||||
raise NotImplementedError( | |||||
"Unsupported note message: " + itchat_msg["Content"] | |||||
) | |||||
raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"]) | |||||
else: | else: | ||||
raise NotImplementedError( | |||||
"Unsupported message type: Type:{} MsgType:{}".format( | |||||
itchat_msg["Type"], itchat_msg["MsgType"] | |||||
) | |||||
) | |||||
raise NotImplementedError("Unsupported message type: Type:{} MsgType:{}".format(itchat_msg["Type"], itchat_msg["MsgType"])) | |||||
self.from_user_id = itchat_msg["FromUserName"] | self.from_user_id = itchat_msg["FromUserName"] | ||||
self.to_user_id = itchat_msg["ToUserName"] | self.to_user_id = itchat_msg["ToUserName"] | ||||
@@ -60,13 +60,9 @@ class WechatyChannel(ChatChannel): | |||||
receiver_id = context["receiver"] | receiver_id = context["receiver"] | ||||
loop = asyncio.get_event_loop() | loop = asyncio.get_event_loop() | ||||
if context["isgroup"]: | if context["isgroup"]: | ||||
receiver = asyncio.run_coroutine_threadsafe( | |||||
self.bot.Room.find(receiver_id), loop | |||||
).result() | |||||
receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id), loop).result() | |||||
else: | else: | ||||
receiver = asyncio.run_coroutine_threadsafe( | |||||
self.bot.Contact.find(receiver_id), loop | |||||
).result() | |||||
receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id), loop).result() | |||||
msg = None | msg = None | ||||
if reply.type == ReplyType.TEXT: | if reply.type == ReplyType.TEXT: | ||||
msg = reply.content | msg = reply.content | ||||
@@ -83,9 +79,7 @@ class WechatyChannel(ChatChannel): | |||||
voiceLength = int(any_to_sil(file_path, sil_file)) | voiceLength = int(any_to_sil(file_path, sil_file)) | ||||
if voiceLength >= 60000: | if voiceLength >= 60000: | ||||
voiceLength = 60000 | voiceLength = 60000 | ||||
logger.info( | |||||
"[WX] voice too long, length={}, set to 60s".format(voiceLength) | |||||
) | |||||
logger.info("[WX] voice too long, length={}, set to 60s".format(voiceLength)) | |||||
# 发送语音 | # 发送语音 | ||||
t = int(time.time()) | t = int(time.time()) | ||||
msg = FileBox.from_file(sil_file, name=str(t) + ".sil") | msg = FileBox.from_file(sil_file, name=str(t) + ".sil") | ||||
@@ -98,9 +92,7 @@ class WechatyChannel(ChatChannel): | |||||
os.remove(sil_file) | os.remove(sil_file) | ||||
except Exception as e: | except Exception as e: | ||||
pass | pass | ||||
logger.info( | |||||
"[WX] sendVoice={}, receiver={}".format(reply.content, receiver) | |||||
) | |||||
logger.info("[WX] sendVoice={}, receiver={}".format(reply.content, receiver)) | |||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 | elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 | ||||
img_url = reply.content | img_url = reply.content | ||||
t = int(time.time()) | t = int(time.time()) | ||||
@@ -111,9 +103,7 @@ class WechatyChannel(ChatChannel): | |||||
image_storage = reply.content | image_storage = reply.content | ||||
image_storage.seek(0) | image_storage.seek(0) | ||||
t = int(time.time()) | t = int(time.time()) | ||||
msg = FileBox.from_base64( | |||||
base64.b64encode(image_storage.read()), str(t) + ".png" | |||||
) | |||||
msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + ".png") | |||||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result() | asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result() | ||||
logger.info("[WX] sendImage, receiver={}".format(receiver)) | logger.info("[WX] sendImage, receiver={}".format(receiver)) | ||||
@@ -45,16 +45,12 @@ class WechatyMessage(ChatMessage, aobject): | |||||
def func(): | def func(): | ||||
loop = asyncio.get_event_loop() | loop = asyncio.get_event_loop() | ||||
asyncio.run_coroutine_threadsafe( | |||||
voice_file.to_file(self.content), loop | |||||
).result() | |||||
asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content), loop).result() | |||||
self._prepare_fn = func | self._prepare_fn = func | ||||
else: | else: | ||||
raise NotImplementedError( | |||||
"Unsupported message type: {}".format(wechaty_msg.type()) | |||||
) | |||||
raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type())) | |||||
from_contact = wechaty_msg.talker() # 获取消息的发送者 | from_contact = wechaty_msg.talker() # 获取消息的发送者 | ||||
self.from_user_id = from_contact.contact_id | self.from_user_id = from_contact.contact_id | ||||
@@ -73,9 +69,7 @@ class WechatyMessage(ChatMessage, aobject): | |||||
self.to_user_id = to_contact.contact_id | self.to_user_id = to_contact.contact_id | ||||
self.to_user_nickname = to_contact.name | self.to_user_nickname = to_contact.name | ||||
if ( | |||||
self.is_group or wechaty_msg.is_self() | |||||
): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。 | |||||
if self.is_group or wechaty_msg.is_self(): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。 | |||||
self.other_user_id = self.to_user_id | self.other_user_id = self.to_user_id | ||||
self.other_user_nickname = self.to_user_nickname | self.other_user_nickname = self.to_user_nickname | ||||
else: | else: | ||||
@@ -1,16 +1,17 @@ | |||||
import time | import time | ||||
import web | import web | ||||
from wechatpy import parse_message | |||||
from wechatpy.replies import create_reply | |||||
from channel.wechatmp.wechatmp_message import WeChatMPMessage | |||||
from bridge.context import * | from bridge.context import * | ||||
from bridge.reply import * | from bridge.reply import * | ||||
from channel.wechatmp.common import * | from channel.wechatmp.common import * | ||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel | from channel.wechatmp.wechatmp_channel import WechatMPChannel | ||||
from wechatpy import parse_message | |||||
from channel.wechatmp.wechatmp_message import WeChatMPMessage | |||||
from common.log import logger | from common.log import logger | ||||
from config import conf | from config import conf | ||||
from wechatpy.replies import create_reply | |||||
# This class is instantiated once per query | # This class is instantiated once per query | ||||
class Query: | class Query: | ||||
@@ -50,29 +51,19 @@ class Query: | |||||
) | ) | ||||
) | ) | ||||
if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False): | if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False): | ||||
context = channel._compose_context( | |||||
wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg | |||||
) | |||||
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg) | |||||
else: | else: | ||||
context = channel._compose_context( | |||||
wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg | |||||
) | |||||
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg) | |||||
if context: | if context: | ||||
# set private openai_api_key | # set private openai_api_key | ||||
# if from_user is not changed in itchat, this can be placed at chat_channel | # if from_user is not changed in itchat, this can be placed at chat_channel | ||||
user_data = conf().get_user_data(from_user) | user_data = conf().get_user_data(from_user) | ||||
context["openai_api_key"] = user_data.get( | |||||
"openai_api_key" | |||||
) # None or user openai_api_key | |||||
context["openai_api_key"] = user_data.get("openai_api_key") # None or user openai_api_key | |||||
channel.produce(context) | channel.produce(context) | ||||
# The reply will be sent by channel.send() in another thread | # The reply will be sent by channel.send() in another thread | ||||
return "success" | return "success" | ||||
elif msg.type == "event": | elif msg.type == "event": | ||||
logger.info( | |||||
"[wechatmp] Event {} from {}".format( | |||||
msg.event, msg.source | |||||
) | |||||
) | |||||
logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source)) | |||||
if msg.event in ["subscribe", "subscribe_scan"]: | if msg.event in ["subscribe", "subscribe_scan"]: | ||||
reply_text = subscribe_msg() | reply_text = subscribe_msg() | ||||
replyPost = create_reply(reply_text, msg) | replyPost = create_reply(reply_text, msg) | ||||
@@ -1,10 +1,12 @@ | |||||
import textwrap | import textwrap | ||||
import web | |||||
from config import conf | |||||
from wechatpy.utils import check_signature | |||||
import web | |||||
from wechatpy.crypto import WeChatCrypto | from wechatpy.crypto import WeChatCrypto | ||||
from wechatpy.exceptions import InvalidSignatureException | from wechatpy.exceptions import InvalidSignatureException | ||||
from wechatpy.utils import check_signature | |||||
from config import conf | |||||
MAX_UTF8_LEN = 2048 | MAX_UTF8_LEN = 2048 | ||||
@@ -1,17 +1,18 @@ | |||||
import time | |||||
import asyncio | import asyncio | ||||
import time | |||||
import web | import web | ||||
from wechatpy import parse_message | |||||
from wechatpy.replies import ImageReply, VoiceReply, create_reply | |||||
from channel.wechatmp.wechatmp_message import WeChatMPMessage | |||||
from bridge.context import * | from bridge.context import * | ||||
from bridge.reply import * | from bridge.reply import * | ||||
from channel.wechatmp.common import * | from channel.wechatmp.common import * | ||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel | from channel.wechatmp.wechatmp_channel import WechatMPChannel | ||||
from channel.wechatmp.wechatmp_message import WeChatMPMessage | |||||
from common.log import logger | from common.log import logger | ||||
from config import conf | from config import conf | ||||
from wechatpy import parse_message | |||||
from wechatpy.replies import create_reply, ImageReply, VoiceReply | |||||
# This class is instantiated once per query | # This class is instantiated once per query | ||||
class Query: | class Query: | ||||
@@ -49,21 +50,15 @@ class Query: | |||||
if ( | if ( | ||||
from_user not in channel.cache_dict | from_user not in channel.cache_dict | ||||
and from_user not in channel.running | and from_user not in channel.running | ||||
or content.startswith("#") | |||||
and message_id not in channel.request_cnt # insert the godcmd | |||||
or content.startswith("#") | |||||
and message_id not in channel.request_cnt # insert the godcmd | |||||
): | ): | ||||
# The first query begin | # The first query begin | ||||
if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False): | if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False): | ||||
context = channel._compose_context( | |||||
wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg | |||||
) | |||||
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg) | |||||
else: | else: | ||||
context = channel._compose_context( | |||||
wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg | |||||
) | |||||
logger.debug( | |||||
"[wechatmp] context: {} {} {}".format(context, wechatmp_msg, supported) | |||||
) | |||||
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg) | |||||
logger.debug("[wechatmp] context: {} {} {}".format(context, wechatmp_msg, supported)) | |||||
if supported and context: | if supported and context: | ||||
# set private openai_api_key | # set private openai_api_key | ||||
@@ -94,23 +89,17 @@ class Query: | |||||
"""\ | """\ | ||||
未知错误,请稍后再试""" | 未知错误,请稍后再试""" | ||||
) | ) | ||||
replyPost = create_reply(reply_text, msg) | replyPost = create_reply(reply_text, msg) | ||||
return encrypt_func(replyPost.render()) | return encrypt_func(replyPost.render()) | ||||
# Wechat official server will request 3 times (5 seconds each), with the same message_id. | # Wechat official server will request 3 times (5 seconds each), with the same message_id. | ||||
# Because the interval is 5 seconds, here assumed that do not have multithreading problems. | # Because the interval is 5 seconds, here assumed that do not have multithreading problems. | ||||
request_cnt = channel.request_cnt.get(message_id, 0) + 1 | request_cnt = channel.request_cnt.get(message_id, 0) + 1 | ||||
channel.request_cnt[message_id] = request_cnt | channel.request_cnt[message_id] = request_cnt | ||||
logger.info( | logger.info( | ||||
"[wechatmp] Request {} from {} {} {}:{}\n{}".format( | "[wechatmp] Request {} from {} {} {}:{}\n{}".format( | ||||
request_cnt, | |||||
from_user, | |||||
message_id, | |||||
web.ctx.env.get("REMOTE_ADDR"), | |||||
web.ctx.env.get("REMOTE_PORT"), | |||||
content | |||||
request_cnt, from_user, message_id, web.ctx.env.get("REMOTE_ADDR"), web.ctx.env.get("REMOTE_PORT"), content | |||||
) | ) | ||||
) | ) | ||||
@@ -130,7 +119,7 @@ class Query: | |||||
time.sleep(2) | time.sleep(2) | ||||
# and do nothing, waiting for the next request | # and do nothing, waiting for the next request | ||||
return "success" | return "success" | ||||
else: # request_cnt == 3: | |||||
else: # request_cnt == 3: | |||||
# return timeout message | # return timeout message | ||||
reply_text = "【正在思考中,回复任意文字尝试获取回复】" | reply_text = "【正在思考中,回复任意文字尝试获取回复】" | ||||
replyPost = create_reply(reply_text, msg) | replyPost = create_reply(reply_text, msg) | ||||
@@ -140,10 +129,7 @@ class Query: | |||||
channel.request_cnt.pop(message_id) | channel.request_cnt.pop(message_id) | ||||
# no return because of bandwords or other reasons | # no return because of bandwords or other reasons | ||||
if ( | |||||
from_user not in channel.cache_dict | |||||
and from_user not in channel.running | |||||
): | |||||
if from_user not in channel.cache_dict and from_user not in channel.running: | |||||
return "success" | return "success" | ||||
# Only one request can access to the cached data | # Only one request can access to the cached data | ||||
@@ -152,7 +138,7 @@ class Query: | |||||
except KeyError: | except KeyError: | ||||
return "success" | return "success" | ||||
if (reply_type == "text"): | |||||
if reply_type == "text": | |||||
if len(reply_content.encode("utf8")) <= MAX_UTF8_LEN: | if len(reply_content.encode("utf8")) <= MAX_UTF8_LEN: | ||||
reply_text = reply_content | reply_text = reply_content | ||||
else: | else: | ||||
@@ -177,7 +163,7 @@ class Query: | |||||
replyPost = create_reply(reply_text, msg) | replyPost = create_reply(reply_text, msg) | ||||
return encrypt_func(replyPost.render()) | return encrypt_func(replyPost.render()) | ||||
elif (reply_type == "voice"): | |||||
elif reply_type == "voice": | |||||
media_id = reply_content | media_id = reply_content | ||||
asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop) | asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop) | ||||
logger.info( | logger.info( | ||||
@@ -193,7 +179,7 @@ class Query: | |||||
replyPost.media_id = media_id | replyPost.media_id = media_id | ||||
return encrypt_func(replyPost.render()) | return encrypt_func(replyPost.render()) | ||||
elif (reply_type == "image"): | |||||
elif reply_type == "image": | |||||
media_id = reply_content | media_id = reply_content | ||||
asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop) | asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop) | ||||
logger.info( | logger.info( | ||||
@@ -210,11 +196,7 @@ class Query: | |||||
return encrypt_func(replyPost.render()) | return encrypt_func(replyPost.render()) | ||||
elif msg.type == "event": | elif msg.type == "event": | ||||
logger.info( | |||||
"[wechatmp] Event {} from {}".format( | |||||
msg.event, msg.source | |||||
) | |||||
) | |||||
logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source)) | |||||
if msg.event in ["subscribe", "subscribe_scan"]: | if msg.event in ["subscribe", "subscribe_scan"]: | ||||
reply_text = subscribe_msg() | reply_text = subscribe_msg() | ||||
replyPost = create_reply(reply_text, msg) | replyPost = create_reply(reply_text, msg) | ||||
@@ -1,24 +1,26 @@ | |||||
# -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
import asyncio | |||||
import imghdr | |||||
import io | import io | ||||
import os | import os | ||||
import threading | |||||
import time | import time | ||||
import imghdr | |||||
import requests | import requests | ||||
import asyncio | |||||
import threading | |||||
from config import conf | |||||
import web | |||||
from wechatpy.crypto import WeChatCrypto | |||||
from wechatpy.exceptions import WeChatClientException | |||||
from bridge.context import * | from bridge.context import * | ||||
from bridge.reply import * | from bridge.reply import * | ||||
from common.log import logger | |||||
from common.singleton import singleton | |||||
from voice.audio_convert import any_to_mp3 | |||||
from channel.chat_channel import ChatChannel | from channel.chat_channel import ChatChannel | ||||
from channel.wechatmp.common import * | from channel.wechatmp.common import * | ||||
from channel.wechatmp.wechatmp_client import WechatMPClient | from channel.wechatmp.wechatmp_client import WechatMPClient | ||||
from wechatpy.exceptions import WeChatClientException | |||||
from wechatpy.crypto import WeChatCrypto | |||||
from common.log import logger | |||||
from common.singleton import singleton | |||||
from config import conf | |||||
from voice.audio_convert import any_to_mp3 | |||||
import web | |||||
# If using SSL, uncomment the following lines, and modify the certificate path. | # If using SSL, uncomment the following lines, and modify the certificate path. | ||||
# from cheroot.server import HTTPServer | # from cheroot.server import HTTPServer | ||||
# from cheroot.ssl.builtin import BuiltinSSLAdapter | # from cheroot.ssl.builtin import BuiltinSSLAdapter | ||||
@@ -54,7 +56,6 @@ class WechatMPChannel(ChatChannel): | |||||
t.setDaemon(True) | t.setDaemon(True) | ||||
t.start() | t.start() | ||||
def startup(self): | def startup(self): | ||||
if self.passive_reply: | if self.passive_reply: | ||||
urls = ("/wx", "channel.wechatmp.passive_reply.Query") | urls = ("/wx", "channel.wechatmp.passive_reply.Query") | ||||
@@ -84,7 +85,7 @@ class WechatMPChannel(ChatChannel): | |||||
elif reply.type == ReplyType.VOICE: | elif reply.type == ReplyType.VOICE: | ||||
try: | try: | ||||
voice_file_path = reply.content | voice_file_path = reply.content | ||||
with open(voice_file_path, 'rb') as f: | |||||
with open(voice_file_path, "rb") as f: | |||||
# support: <2M, <60s, mp3/wma/wav/amr | # support: <2M, <60s, mp3/wma/wav/amr | ||||
response = self.client.material.add("voice", f) | response = self.client.material.add("voice", f) | ||||
logger.debug("[wechatmp] upload voice response: {}".format(response)) | logger.debug("[wechatmp] upload voice response: {}".format(response)) | ||||
@@ -107,7 +108,7 @@ class WechatMPChannel(ChatChannel): | |||||
image_storage.write(block) | image_storage.write(block) | ||||
image_storage.seek(0) | image_storage.seek(0) | ||||
image_type = imghdr.what(image_storage) | image_type = imghdr.what(image_storage) | ||||
filename = receiver + "-" + str(context['msg'].msg_id) + "." + image_type | |||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type | |||||
content_type = "image/" + image_type | content_type = "image/" + image_type | ||||
try: | try: | ||||
response = self.client.material.add("image", (filename, image_storage, content_type)) | response = self.client.material.add("image", (filename, image_storage, content_type)) | ||||
@@ -122,7 +123,7 @@ class WechatMPChannel(ChatChannel): | |||||
image_storage = reply.content | image_storage = reply.content | ||||
image_storage.seek(0) | image_storage.seek(0) | ||||
image_type = imghdr.what(image_storage) | image_type = imghdr.what(image_storage) | ||||
filename = receiver + "-" + str(context['msg'].msg_id) + "." + image_type | |||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type | |||||
content_type = "image/" + image_type | content_type = "image/" + image_type | ||||
try: | try: | ||||
response = self.client.material.add("image", (filename, image_storage, content_type)) | response = self.client.material.add("image", (filename, image_storage, content_type)) | ||||
@@ -137,7 +138,7 @@ class WechatMPChannel(ChatChannel): | |||||
if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR: | if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR: | ||||
reply_text = reply.content | reply_text = reply.content | ||||
texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN) | texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN) | ||||
if len(texts)>1: | |||||
if len(texts) > 1: | |||||
logger.info("[wechatmp] text too long, split into {} parts".format(len(texts))) | logger.info("[wechatmp] text too long, split into {} parts".format(len(texts))) | ||||
for text in texts: | for text in texts: | ||||
self.client.message.send_text(receiver, text) | self.client.message.send_text(receiver, text) | ||||
@@ -174,7 +175,7 @@ class WechatMPChannel(ChatChannel): | |||||
image_storage.write(block) | image_storage.write(block) | ||||
image_storage.seek(0) | image_storage.seek(0) | ||||
image_type = imghdr.what(image_storage) | image_type = imghdr.what(image_storage) | ||||
filename = receiver + "-" + str(context['msg'].msg_id) + "." + image_type | |||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type | |||||
content_type = "image/" + image_type | content_type = "image/" + image_type | ||||
try: | try: | ||||
response = self.client.media.upload("image", (filename, image_storage, content_type)) | response = self.client.media.upload("image", (filename, image_storage, content_type)) | ||||
@@ -188,7 +189,7 @@ class WechatMPChannel(ChatChannel): | |||||
image_storage = reply.content | image_storage = reply.content | ||||
image_storage.seek(0) | image_storage.seek(0) | ||||
image_type = imghdr.what(image_storage) | image_type = imghdr.what(image_storage) | ||||
filename = receiver + "-" + str(context['msg'].msg_id) + "." + image_type | |||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type | |||||
content_type = "image/" + image_type | content_type = "image/" + image_type | ||||
try: | try: | ||||
response = self.client.media.upload("image", (filename, image_storage, content_type)) | response = self.client.media.upload("image", (filename, image_storage, content_type)) | ||||
@@ -201,20 +202,12 @@ class WechatMPChannel(ChatChannel): | |||||
return | return | ||||
def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数 | def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数 | ||||
logger.debug( | |||||
"[wechatmp] Success to generate reply, msgId={}".format( | |||||
context["msg"].msg_id | |||||
) | |||||
) | |||||
logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context["msg"].msg_id)) | |||||
if self.passive_reply: | if self.passive_reply: | ||||
self.running.remove(session_id) | self.running.remove(session_id) | ||||
def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数 | def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数 | ||||
logger.exception( | |||||
"[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format( | |||||
context["msg"].msg_id, exception | |||||
) | |||||
) | |||||
logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context["msg"].msg_id, exception)) | |||||
if self.passive_reply: | if self.passive_reply: | ||||
assert session_id not in self.cache_dict | assert session_id not in self.cache_dict | ||||
self.running.remove(session_id) | self.running.remove(session_id) |
@@ -1,17 +1,16 @@ | |||||
import time | |||||
import threading | import threading | ||||
from channel.wechatmp.common import * | |||||
import time | |||||
from wechatpy.client import WeChatClient | from wechatpy.client import WeChatClient | ||||
from common.log import logger | |||||
from wechatpy.exceptions import APILimitedException | from wechatpy.exceptions import APILimitedException | ||||
from channel.wechatmp.common import * | |||||
from common.log import logger | |||||
class WechatMPClient(WeChatClient): | class WechatMPClient(WeChatClient): | ||||
def __init__(self, appid, secret, access_token=None, | |||||
session=None, timeout=None, auto_retry=True): | |||||
super(WechatMPClient, self).__init__( | |||||
appid, secret, access_token, session, timeout, auto_retry | |||||
) | |||||
def __init__(self, appid, secret, access_token=None, session=None, timeout=None, auto_retry=True): | |||||
super(WechatMPClient, self).__init__(appid, secret, access_token, session, timeout, auto_retry) | |||||
self.fetch_access_token_lock = threading.Lock() | self.fetch_access_token_lock = threading.Lock() | ||||
def clear_quota(self): | def clear_quota(self): | ||||
@@ -20,7 +19,7 @@ class WechatMPClient(WeChatClient): | |||||
def clear_quota_v2(self): | def clear_quota_v2(self): | ||||
return self.post("clear_quota/v2", params={"appid": self.appid, "appsecret": self.secret}) | return self.post("clear_quota/v2", params={"appid": self.appid, "appsecret": self.secret}) | ||||
def fetch_access_token(self): # 重载父类方法,加锁避免多线程重复获取access_token | |||||
def fetch_access_token(self): # 重载父类方法,加锁避免多线程重复获取access_token | |||||
with self.fetch_access_token_lock: | with self.fetch_access_token_lock: | ||||
access_token = self.session.get(self.access_token_key) | access_token = self.session.get(self.access_token_key) | ||||
if access_token: | if access_token: | ||||
@@ -31,11 +30,11 @@ class WechatMPClient(WeChatClient): | |||||
return access_token | return access_token | ||||
return super().fetch_access_token() | return super().fetch_access_token() | ||||
def _request(self, method, url_or_endpoint, **kwargs): # 重载父类方法,遇到API限流时,清除quota后重试 | |||||
def _request(self, method, url_or_endpoint, **kwargs): # 重载父类方法,遇到API限流时,清除quota后重试 | |||||
try: | try: | ||||
return super()._request(method, url_or_endpoint, **kwargs) | return super()._request(method, url_or_endpoint, **kwargs) | ||||
except APILimitedException as e: | except APILimitedException as e: | ||||
logger.error("[wechatmp] API quata has been used up. {}".format(e)) | logger.error("[wechatmp] API quata has been used up. {}".format(e)) | ||||
response = self.clear_quota_v2() | response = self.clear_quota_v2() | ||||
logger.debug("[wechatmp] API quata has been cleard, {}".format(response)) | logger.debug("[wechatmp] API quata has been cleard, {}".format(response)) | ||||
return super()._request(method, url_or_endpoint, **kwargs) | |||||
return super()._request(method, url_or_endpoint, **kwargs) |
@@ -6,7 +6,6 @@ from common.log import logger | |||||
from common.tmp_dir import TmpDir | from common.tmp_dir import TmpDir | ||||
class WeChatMPMessage(ChatMessage): | class WeChatMPMessage(ChatMessage): | ||||
def __init__(self, msg, client=None): | def __init__(self, msg, client=None): | ||||
super().__init__(msg) | super().__init__(msg) | ||||
@@ -18,12 +17,9 @@ class WeChatMPMessage(ChatMessage): | |||||
self.ctype = ContextType.TEXT | self.ctype = ContextType.TEXT | ||||
self.content = msg.content | self.content = msg.content | ||||
elif msg.type == "voice": | elif msg.type == "voice": | ||||
if msg.recognition == None: | if msg.recognition == None: | ||||
self.ctype = ContextType.VOICE | self.ctype = ContextType.VOICE | ||||
self.content = ( | |||||
TmpDir().path() + msg.media_id + "." + msg.format | |||||
) # content直接存临时目录路径 | |||||
self.content = TmpDir().path() + msg.media_id + "." + msg.format # content直接存临时目录路径 | |||||
def download_voice(): | def download_voice(): | ||||
# 如果响应状态码是200,则将响应内容写入本地文件 | # 如果响应状态码是200,则将响应内容写入本地文件 | ||||
@@ -32,9 +28,7 @@ class WeChatMPMessage(ChatMessage): | |||||
with open(self.content, "wb") as f: | with open(self.content, "wb") as f: | ||||
f.write(response.content) | f.write(response.content) | ||||
else: | else: | ||||
logger.info( | |||||
f"[wechatmp] Failed to download voice file, {response.content}" | |||||
) | |||||
logger.info(f"[wechatmp] Failed to download voice file, {response.content}") | |||||
self._prepare_fn = download_voice | self._prepare_fn = download_voice | ||||
else: | else: | ||||
@@ -43,6 +37,7 @@ class WeChatMPMessage(ChatMessage): | |||||
elif msg.type == "image": | elif msg.type == "image": | ||||
self.ctype = ContextType.IMAGE | self.ctype = ContextType.IMAGE | ||||
self.content = TmpDir().path() + msg.media_id + ".png" # content直接存临时目录路径 | self.content = TmpDir().path() + msg.media_id + ".png" # content直接存临时目录路径 | ||||
def download_image(): | def download_image(): | ||||
# 如果响应状态码是200,则将响应内容写入本地文件 | # 如果响应状态码是200,则将响应内容写入本地文件 | ||||
response = client.media.download(msg.media_id) | response = client.media.download(msg.media_id) | ||||
@@ -50,15 +45,11 @@ class WeChatMPMessage(ChatMessage): | |||||
with open(self.content, "wb") as f: | with open(self.content, "wb") as f: | ||||
f.write(response.content) | f.write(response.content) | ||||
else: | else: | ||||
logger.info( | |||||
f"[wechatmp] Failed to download image file, {response.content}" | |||||
) | |||||
logger.info(f"[wechatmp] Failed to download image file, {response.content}") | |||||
self._prepare_fn = download_image | self._prepare_fn = download_image | ||||
else: | else: | ||||
raise NotImplementedError( | |||||
"Unsupported message type: Type:{} ".format(msg.type) | |||||
) | |||||
raise NotImplementedError("Unsupported message type: Type:{} ".format(msg.type)) | |||||
self.from_user_id = msg.source | self.from_user_id = msg.source | ||||
self.to_user_id = msg.target | self.to_user_id = msg.target | ||||
@@ -13,23 +13,15 @@ def time_checker(f): | |||||
if chat_time_module: | if chat_time_module: | ||||
chat_start_time = _config.get("chat_start_time", "00:00") | chat_start_time = _config.get("chat_start_time", "00:00") | ||||
chat_stopt_time = _config.get("chat_stop_time", "24:00") | chat_stopt_time = _config.get("chat_stop_time", "24:00") | ||||
time_regex = re.compile( | |||||
r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$" | |||||
) # 时间匹配,包含24:00 | |||||
time_regex = re.compile(r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$") # 时间匹配,包含24:00 | |||||
starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式 | starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式 | ||||
stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式 | stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式 | ||||
chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间 | chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间 | ||||
# 时间格式检查 | # 时间格式检查 | ||||
if not ( | |||||
starttime_format_check and stoptime_format_check and chat_time_check | |||||
): | |||||
logger.warn( | |||||
"时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format( | |||||
starttime_format_check, stoptime_format_check | |||||
) | |||||
) | |||||
if not (starttime_format_check and stoptime_format_check and chat_time_check): | |||||
logger.warn("时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format(starttime_format_check, stoptime_format_check)) | |||||
if chat_start_time > "23:59": | if chat_start_time > "23:59": | ||||
logger.error("启动时间可能存在问题,请修改!") | logger.error("启动时间可能存在问题,请修改!") | ||||
@@ -158,9 +158,7 @@ def load_config(): | |||||
for name, value in os.environ.items(): | for name, value in os.environ.items(): | ||||
name = name.lower() | name = name.lower() | ||||
if name in available_setting: | if name in available_setting: | ||||
logger.info( | |||||
"[INIT] override config by environ args: {}={}".format(name, value) | |||||
) | |||||
logger.info("[INIT] override config by environ args: {}={}".format(name, value)) | |||||
try: | try: | ||||
config[name] = eval(value) | config[name] = eval(value) | ||||
except: | except: | ||||
@@ -50,9 +50,7 @@ class Banwords(Plugin): | |||||
self.reply_action = conf.get("reply_action", "ignore") | self.reply_action = conf.get("reply_action", "ignore") | ||||
logger.info("[Banwords] inited") | logger.info("[Banwords] inited") | ||||
except Exception as e: | except Exception as e: | ||||
logger.warn( | |||||
"[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords ." | |||||
) | |||||
logger.warn("[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords .") | |||||
raise e | raise e | ||||
def on_handle_context(self, e_context: EventContext): | def on_handle_context(self, e_context: EventContext): | ||||
@@ -72,9 +70,7 @@ class Banwords(Plugin): | |||||
return | return | ||||
elif self.action == "replace": | elif self.action == "replace": | ||||
if self.searchr.ContainsAny(content): | if self.searchr.ContainsAny(content): | ||||
reply = Reply( | |||||
ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content) | |||||
) | |||||
reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content)) | |||||
e_context["reply"] = reply | e_context["reply"] = reply | ||||
e_context.action = EventAction.BREAK_PASS | e_context.action = EventAction.BREAK_PASS | ||||
return | return | ||||
@@ -94,9 +90,7 @@ class Banwords(Plugin): | |||||
return | return | ||||
elif self.reply_action == "replace": | elif self.reply_action == "replace": | ||||
if self.searchr.ContainsAny(content): | if self.searchr.ContainsAny(content): | ||||
reply = Reply( | |||||
ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content) | |||||
) | |||||
reply = Reply(ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content)) | |||||
e_context["reply"] = reply | e_context["reply"] = reply | ||||
e_context.action = EventAction.CONTINUE | e_context.action = EventAction.CONTINUE | ||||
return | return | ||||
@@ -76,9 +76,7 @@ class BDunit(Plugin): | |||||
Returns: | Returns: | ||||
string: access_token | string: access_token | ||||
""" | """ | ||||
url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format( | |||||
self.api_key, self.secret_key | |||||
) | |||||
url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(self.api_key, self.secret_key) | |||||
payload = "" | payload = "" | ||||
headers = {"Content-Type": "application/json", "Accept": "application/json"} | headers = {"Content-Type": "application/json", "Accept": "application/json"} | ||||
@@ -94,10 +92,7 @@ class BDunit(Plugin): | |||||
:returns: UNIT 解析结果。如果解析失败,返回 None | :returns: UNIT 解析结果。如果解析失败,返回 None | ||||
""" | """ | ||||
url = ( | |||||
"https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" | |||||
+ self.access_token | |||||
) | |||||
url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + self.access_token | |||||
request = { | request = { | ||||
"query": query, | "query": query, | ||||
"user_id": str(get_mac())[:32], | "user_id": str(get_mac())[:32], | ||||
@@ -124,10 +119,7 @@ class BDunit(Plugin): | |||||
:param query: 用户的指令字符串 | :param query: 用户的指令字符串 | ||||
:returns: UNIT 解析结果。如果解析失败,返回 None | :returns: UNIT 解析结果。如果解析失败,返回 None | ||||
""" | """ | ||||
url = ( | |||||
"https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token=" | |||||
+ self.access_token | |||||
) | |||||
url = "https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token=" + self.access_token | |||||
request = {"query": query, "user_id": str(get_mac())[:32]} | request = {"query": query, "user_id": str(get_mac())[:32]} | ||||
body = { | body = { | ||||
"log_id": str(uuid.uuid1()), | "log_id": str(uuid.uuid1()), | ||||
@@ -170,11 +162,7 @@ class BDunit(Plugin): | |||||
if parsed and "result" in parsed and "response_list" in parsed["result"]: | if parsed and "result" in parsed and "response_list" in parsed["result"]: | ||||
response_list = parsed["result"]["response_list"] | response_list = parsed["result"]["response_list"] | ||||
for response in response_list: | for response in response_list: | ||||
if ( | |||||
"schema" in response | |||||
and "intent" in response["schema"] | |||||
and response["schema"]["intent"] == intent | |||||
): | |||||
if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent: | |||||
return True | return True | ||||
return False | return False | ||||
else: | else: | ||||
@@ -198,12 +186,7 @@ class BDunit(Plugin): | |||||
logger.warning(e) | logger.warning(e) | ||||
return [] | return [] | ||||
for response in response_list: | for response in response_list: | ||||
if ( | |||||
"schema" in response | |||||
and "intent" in response["schema"] | |||||
and "slots" in response["schema"] | |||||
and response["schema"]["intent"] == intent | |||||
): | |||||
if "schema" in response and "intent" in response["schema"] and "slots" in response["schema"] and response["schema"]["intent"] == intent: | |||||
return response["schema"]["slots"] | return response["schema"]["slots"] | ||||
return [] | return [] | ||||
else: | else: | ||||
@@ -239,11 +222,7 @@ class BDunit(Plugin): | |||||
if ( | if ( | ||||
"schema" in response | "schema" in response | ||||
and "intent_confidence" in response["schema"] | and "intent_confidence" in response["schema"] | ||||
and ( | |||||
not answer | |||||
or response["schema"]["intent_confidence"] | |||||
> answer["schema"]["intent_confidence"] | |||||
) | |||||
and (not answer or response["schema"]["intent_confidence"] > answer["schema"]["intent_confidence"]) | |||||
): | ): | ||||
answer = response | answer = response | ||||
return answer["action_list"][0]["say"] | return answer["action_list"][0]["say"] | ||||
@@ -267,11 +246,7 @@ class BDunit(Plugin): | |||||
logger.warning(e) | logger.warning(e) | ||||
return "" | return "" | ||||
for response in response_list: | for response in response_list: | ||||
if ( | |||||
"schema" in response | |||||
and "intent" in response["schema"] | |||||
and response["schema"]["intent"] == intent | |||||
): | |||||
if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent: | |||||
try: | try: | ||||
return response["action_list"][0]["say"] | return response["action_list"][0]["say"] | ||||
except Exception as e: | except Exception as e: | ||||
@@ -84,9 +84,7 @@ class Dungeon(Plugin): | |||||
if len(clist) > 1: | if len(clist) > 1: | ||||
story = clist[1] | story = clist[1] | ||||
else: | else: | ||||
story = ( | |||||
"你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。" | |||||
) | |||||
story = "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。" | |||||
self.games[sessionid] = StoryTeller(bot, sessionid, story) | self.games[sessionid] = StoryTeller(bot, sessionid, story) | ||||
reply = Reply(ReplyType.INFO, "冒险开始,你可以输入任意内容,让故事继续下去。故事背景是:" + story) | reply = Reply(ReplyType.INFO, "冒险开始,你可以输入任意内容,让故事继续下去。故事背景是:" + story) | ||||
e_context["reply"] = reply | e_context["reply"] = reply | ||||
@@ -102,11 +100,7 @@ class Dungeon(Plugin): | |||||
if kwargs.get("verbose") != True: | if kwargs.get("verbose") != True: | ||||
return help_text | return help_text | ||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$") | trigger_prefix = conf().get("plugin_trigger_prefix", "$") | ||||
help_text = ( | |||||
f"{trigger_prefix}开始冒险 " | |||||
+ "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n" | |||||
+ f"{trigger_prefix}停止冒险: 结束游戏。\n" | |||||
) | |||||
help_text = f"{trigger_prefix}开始冒险 " + "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n" + f"{trigger_prefix}停止冒险: 结束游戏。\n" | |||||
if kwargs.get("verbose") == True: | if kwargs.get("verbose") == True: | ||||
help_text += f"\n命令例子: '{trigger_prefix}开始冒险 你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。'" | help_text += f"\n命令例子: '{trigger_prefix}开始冒险 你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。'" | ||||
return help_text | return help_text |
@@ -140,9 +140,7 @@ def get_help_text(isadmin, isgroup): | |||||
if plugins[plugin].enabled and not plugins[plugin].hidden: | if plugins[plugin].enabled and not plugins[plugin].hidden: | ||||
namecn = plugins[plugin].namecn | namecn = plugins[plugin].namecn | ||||
help_text += "\n%s:" % namecn | help_text += "\n%s:" % namecn | ||||
help_text += ( | |||||
PluginManager().instances[plugin].get_help_text(verbose=False).strip() | |||||
) | |||||
help_text += PluginManager().instances[plugin].get_help_text(verbose=False).strip() | |||||
if ADMIN_COMMANDS and isadmin: | if ADMIN_COMMANDS and isadmin: | ||||
help_text += "\n\n管理员指令:\n" | help_text += "\n\n管理员指令:\n" | ||||
@@ -191,9 +189,7 @@ class Godcmd(Plugin): | |||||
COMMANDS["reset"]["alias"].append(custom_command) | COMMANDS["reset"]["alias"].append(custom_command) | ||||
self.password = gconf["password"] | self.password = gconf["password"] | ||||
self.admin_users = gconf[ | |||||
"admin_users" | |||||
] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用 | |||||
self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用 | |||||
self.isrunning = True # 机器人是否运行中 | self.isrunning = True # 机器人是否运行中 | ||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context | self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context | ||||
@@ -215,7 +211,7 @@ class Godcmd(Plugin): | |||||
reply.content = f"空指令,输入#help查看指令列表\n" | reply.content = f"空指令,输入#help查看指令列表\n" | ||||
e_context["reply"] = reply | e_context["reply"] = reply | ||||
e_context.action = EventAction.BREAK_PASS | e_context.action = EventAction.BREAK_PASS | ||||
return | |||||
return | |||||
# msg = e_context['context']['msg'] | # msg = e_context['context']['msg'] | ||||
channel = e_context["channel"] | channel = e_context["channel"] | ||||
user = e_context["context"]["receiver"] | user = e_context["context"]["receiver"] | ||||
@@ -248,11 +244,7 @@ class Godcmd(Plugin): | |||||
if not plugincls.enabled: | if not plugincls.enabled: | ||||
continue | continue | ||||
if query_name == name or query_name == plugincls.namecn: | if query_name == name or query_name == plugincls.namecn: | ||||
ok, result = True, PluginManager().instances[ | |||||
name | |||||
].get_help_text( | |||||
isgroup=isgroup, isadmin=isadmin, verbose=True | |||||
) | |||||
ok, result = True, PluginManager().instances[name].get_help_text(isgroup=isgroup, isadmin=isadmin, verbose=True) | |||||
break | break | ||||
if not ok: | if not ok: | ||||
result = "插件不存在或未启用" | result = "插件不存在或未启用" | ||||
@@ -285,11 +277,7 @@ class Godcmd(Plugin): | |||||
if isgroup: | if isgroup: | ||||
ok, result = False, "群聊不可执行管理员指令" | ok, result = False, "群聊不可执行管理员指令" | ||||
else: | else: | ||||
cmd = next( | |||||
c | |||||
for c, info in ADMIN_COMMANDS.items() | |||||
if cmd in info["alias"] | |||||
) | |||||
cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info["alias"]) | |||||
if cmd == "stop": | if cmd == "stop": | ||||
self.isrunning = False | self.isrunning = False | ||||
ok, result = True, "服务已暂停" | ok, result = True, "服务已暂停" | ||||
@@ -325,18 +313,14 @@ class Godcmd(Plugin): | |||||
PluginManager().activate_plugins() | PluginManager().activate_plugins() | ||||
if len(new_plugins) > 0: | if len(new_plugins) > 0: | ||||
result += "\n发现新插件:\n" | result += "\n发现新插件:\n" | ||||
result += "\n".join( | |||||
[f"{p.name}_v{p.version}" for p in new_plugins] | |||||
) | |||||
result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins]) | |||||
else: | else: | ||||
result += ", 未发现新插件" | result += ", 未发现新插件" | ||||
elif cmd == "setpri": | elif cmd == "setpri": | ||||
if len(args) != 2: | if len(args) != 2: | ||||
ok, result = False, "请提供插件名和优先级" | ok, result = False, "请提供插件名和优先级" | ||||
else: | else: | ||||
ok = PluginManager().set_plugin_priority( | |||||
args[0], int(args[1]) | |||||
) | |||||
ok = PluginManager().set_plugin_priority(args[0], int(args[1])) | |||||
if ok: | if ok: | ||||
result = "插件" + args[0] + "优先级已设置为" + args[1] | result = "插件" + args[0] + "优先级已设置为" + args[1] | ||||
else: | else: | ||||
@@ -33,9 +33,7 @@ class Hello(Plugin): | |||||
if e_context["context"].type == ContextType.JOIN_GROUP: | if e_context["context"].type == ContextType.JOIN_GROUP: | ||||
e_context["context"].type = ContextType.TEXT | e_context["context"].type = ContextType.TEXT | ||||
msg: ChatMessage = e_context["context"]["msg"] | msg: ChatMessage = e_context["context"]["msg"] | ||||
e_context[ | |||||
"context" | |||||
].content = f'请你随机使用一种风格说一句问候语来欢迎新用户"{msg.actual_user_nickname}"加入群聊。' | |||||
e_context["context"].content = f'请你随机使用一种风格说一句问候语来欢迎新用户"{msg.actual_user_nickname}"加入群聊。' | |||||
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 | e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 | ||||
return | return | ||||
@@ -53,9 +51,7 @@ class Hello(Plugin): | |||||
reply.type = ReplyType.TEXT | reply.type = ReplyType.TEXT | ||||
msg: ChatMessage = e_context["context"]["msg"] | msg: ChatMessage = e_context["context"]["msg"] | ||||
if e_context["context"]["isgroup"]: | if e_context["context"]["isgroup"]: | ||||
reply.content = ( | |||||
f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}" | |||||
) | |||||
reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}" | |||||
else: | else: | ||||
reply.content = f"Hello, {msg.from_user_nickname}" | reply.content = f"Hello, {msg.from_user_nickname}" | ||||
e_context["reply"] = reply | e_context["reply"] = reply | ||||
@@ -1,5 +1,5 @@ | |||||
{ | { | ||||
"keyword": { | "keyword": { | ||||
"关键字匹配": "测试成功" | |||||
"关键字匹配": "测试成功" | |||||
} | } | ||||
} | |||||
} |
@@ -41,9 +41,7 @@ class Keyword(Plugin): | |||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context | self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context | ||||
logger.info("[keyword] inited.") | logger.info("[keyword] inited.") | ||||
except Exception as e: | except Exception as e: | ||||
logger.warn( | |||||
"[keyword] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/keyword ." | |||||
) | |||||
logger.warn("[keyword] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/keyword .") | |||||
raise e | raise e | ||||
def on_handle_context(self, e_context: EventContext): | def on_handle_context(self, e_context: EventContext): | ||||
@@ -31,23 +31,14 @@ class PluginManager: | |||||
plugincls.desc = kwargs.get("desc") | plugincls.desc = kwargs.get("desc") | ||||
plugincls.author = kwargs.get("author") | plugincls.author = kwargs.get("author") | ||||
plugincls.path = self.current_plugin_path | plugincls.path = self.current_plugin_path | ||||
plugincls.version = ( | |||||
kwargs.get("version") if kwargs.get("version") != None else "1.0" | |||||
) | |||||
plugincls.namecn = ( | |||||
kwargs.get("namecn") if kwargs.get("namecn") != None else name | |||||
) | |||||
plugincls.hidden = ( | |||||
kwargs.get("hidden") if kwargs.get("hidden") != None else False | |||||
) | |||||
plugincls.version = kwargs.get("version") if kwargs.get("version") != None else "1.0" | |||||
plugincls.namecn = kwargs.get("namecn") if kwargs.get("namecn") != None else name | |||||
plugincls.hidden = kwargs.get("hidden") if kwargs.get("hidden") != None else False | |||||
plugincls.enabled = True | plugincls.enabled = True | ||||
if self.current_plugin_path == None: | if self.current_plugin_path == None: | ||||
raise Exception("Plugin path not set") | raise Exception("Plugin path not set") | ||||
self.plugins[name.upper()] = plugincls | self.plugins[name.upper()] = plugincls | ||||
logger.info( | |||||
"Plugin %s_v%s registered, path=%s" | |||||
% (name, plugincls.version, plugincls.path) | |||||
) | |||||
logger.info("Plugin %s_v%s registered, path=%s" % (name, plugincls.version, plugincls.path)) | |||||
return wrapper | return wrapper | ||||
@@ -62,9 +53,7 @@ class PluginManager: | |||||
if os.path.exists("./plugins/plugins.json"): | if os.path.exists("./plugins/plugins.json"): | ||||
with open("./plugins/plugins.json", "r", encoding="utf-8") as f: | with open("./plugins/plugins.json", "r", encoding="utf-8") as f: | ||||
pconf = json.load(f) | pconf = json.load(f) | ||||
pconf["plugins"] = SortedDict( | |||||
lambda k, v: v["priority"], pconf["plugins"], reverse=True | |||||
) | |||||
pconf["plugins"] = SortedDict(lambda k, v: v["priority"], pconf["plugins"], reverse=True) | |||||
else: | else: | ||||
modified = True | modified = True | ||||
pconf = {"plugins": SortedDict(lambda k, v: v["priority"], reverse=True)} | pconf = {"plugins": SortedDict(lambda k, v: v["priority"], reverse=True)} | ||||
@@ -90,26 +79,16 @@ class PluginManager: | |||||
if plugin_path in self.loaded: | if plugin_path in self.loaded: | ||||
if self.loaded[plugin_path] == None: | if self.loaded[plugin_path] == None: | ||||
logger.info("reload module %s" % plugin_name) | logger.info("reload module %s" % plugin_name) | ||||
self.loaded[plugin_path] = importlib.reload( | |||||
sys.modules[import_path] | |||||
) | |||||
dependent_module_names = [ | |||||
name | |||||
for name in sys.modules.keys() | |||||
if name.startswith(import_path + ".") | |||||
] | |||||
self.loaded[plugin_path] = importlib.reload(sys.modules[import_path]) | |||||
dependent_module_names = [name for name in sys.modules.keys() if name.startswith(import_path + ".")] | |||||
for name in dependent_module_names: | for name in dependent_module_names: | ||||
logger.info("reload module %s" % name) | logger.info("reload module %s" % name) | ||||
importlib.reload(sys.modules[name]) | importlib.reload(sys.modules[name]) | ||||
else: | else: | ||||
self.loaded[plugin_path] = importlib.import_module( | |||||
import_path | |||||
) | |||||
self.loaded[plugin_path] = importlib.import_module(import_path) | |||||
self.current_plugin_path = None | self.current_plugin_path = None | ||||
except Exception as e: | except Exception as e: | ||||
logger.exception( | |||||
"Failed to import plugin %s: %s" % (plugin_name, e) | |||||
) | |||||
logger.exception("Failed to import plugin %s: %s" % (plugin_name, e)) | |||||
continue | continue | ||||
pconf = self.pconf | pconf = self.pconf | ||||
news = [self.plugins[name] for name in self.plugins] | news = [self.plugins[name] for name in self.plugins] | ||||
@@ -119,9 +98,7 @@ class PluginManager: | |||||
rawname = plugincls.name | rawname = plugincls.name | ||||
if rawname not in pconf["plugins"]: | if rawname not in pconf["plugins"]: | ||||
modified = True | modified = True | ||||
logger.info( | |||||
"Plugin %s not found in pconfig, adding to pconfig..." % name | |||||
) | |||||
logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name) | |||||
pconf["plugins"][rawname] = { | pconf["plugins"][rawname] = { | ||||
"enabled": plugincls.enabled, | "enabled": plugincls.enabled, | ||||
"priority": plugincls.priority, | "priority": plugincls.priority, | ||||
@@ -136,9 +113,7 @@ class PluginManager: | |||||
def refresh_order(self): | def refresh_order(self): | ||||
for event in self.listening_plugins.keys(): | for event in self.listening_plugins.keys(): | ||||
self.listening_plugins[event].sort( | |||||
key=lambda name: self.plugins[name].priority, reverse=True | |||||
) | |||||
self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True) | |||||
def activate_plugins(self): # 生成新开启的插件实例 | def activate_plugins(self): # 生成新开启的插件实例 | ||||
failed_plugins = [] | failed_plugins = [] | ||||
@@ -184,13 +159,8 @@ class PluginManager: | |||||
def emit_event(self, e_context: EventContext, *args, **kwargs): | def emit_event(self, e_context: EventContext, *args, **kwargs): | ||||
if e_context.event in self.listening_plugins: | if e_context.event in self.listening_plugins: | ||||
for name in self.listening_plugins[e_context.event]: | for name in self.listening_plugins[e_context.event]: | ||||
if ( | |||||
self.plugins[name].enabled | |||||
and e_context.action == EventAction.CONTINUE | |||||
): | |||||
logger.debug( | |||||
"Plugin %s triggered by event %s" % (name, e_context.event) | |||||
) | |||||
if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE: | |||||
logger.debug("Plugin %s triggered by event %s" % (name, e_context.event)) | |||||
instance = self.instances[name] | instance = self.instances[name] | ||||
instance.handlers[e_context.event](e_context, *args, **kwargs) | instance.handlers[e_context.event](e_context, *args, **kwargs) | ||||
return e_context | return e_context | ||||
@@ -262,9 +232,7 @@ class PluginManager: | |||||
source = json.load(f) | source = json.load(f) | ||||
if repo in source["repo"]: | if repo in source["repo"]: | ||||
repo = source["repo"][repo]["url"] | repo = source["repo"][repo]["url"] | ||||
match = re.match( | |||||
r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo | |||||
) | |||||
match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo) | |||||
if not match: | if not match: | ||||
return False, "安装插件失败,source中的仓库地址不合法" | return False, "安装插件失败,source中的仓库地址不合法" | ||||
else: | else: | ||||
@@ -69,13 +69,9 @@ class Role(Plugin): | |||||
logger.info("[Role] inited") | logger.info("[Role] inited") | ||||
except Exception as e: | except Exception as e: | ||||
if isinstance(e, FileNotFoundError): | if isinstance(e, FileNotFoundError): | ||||
logger.warn( | |||||
f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ." | |||||
) | |||||
logger.warn(f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .") | |||||
else: | else: | ||||
logger.warn( | |||||
"[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ." | |||||
) | |||||
logger.warn("[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .") | |||||
raise e | raise e | ||||
def get_role(self, name, find_closest=True, min_sim=0.35): | def get_role(self, name, find_closest=True, min_sim=0.35): | ||||
@@ -143,9 +139,7 @@ class Role(Plugin): | |||||
else: | else: | ||||
help_text = f"未知角色类型。\n" | help_text = f"未知角色类型。\n" | ||||
help_text += "目前的角色类型有: \n" | help_text += "目前的角色类型有: \n" | ||||
help_text += ( | |||||
",".join([self.tags[tag][0] for tag in self.tags]) + "\n" | |||||
) | |||||
help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "\n" | |||||
else: | else: | ||||
help_text = f"请输入角色类型。\n" | help_text = f"请输入角色类型。\n" | ||||
help_text += "目前的角色类型有: \n" | help_text += "目前的角色类型有: \n" | ||||
@@ -158,9 +152,7 @@ class Role(Plugin): | |||||
return | return | ||||
logger.debug("[Role] on_handle_context. content: %s" % content) | logger.debug("[Role] on_handle_context. content: %s" % content) | ||||
if desckey is not None: | if desckey is not None: | ||||
if len(clist) == 1 or ( | |||||
len(clist) > 1 and clist[1].lower() in ["help", "帮助"] | |||||
): | |||||
if len(clist) == 1 or (len(clist) > 1 and clist[1].lower() in ["help", "帮助"]): | |||||
reply = Reply(ReplyType.INFO, self.get_help_text(verbose=True)) | reply = Reply(ReplyType.INFO, self.get_help_text(verbose=True)) | ||||
e_context["reply"] = reply | e_context["reply"] = reply | ||||
e_context.action = EventAction.BREAK_PASS | e_context.action = EventAction.BREAK_PASS | ||||
@@ -178,9 +170,7 @@ class Role(Plugin): | |||||
self.roles[role][desckey], | self.roles[role][desckey], | ||||
self.roles[role].get("wrapper", "%s"), | self.roles[role].get("wrapper", "%s"), | ||||
) | ) | ||||
reply = Reply( | |||||
ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey] | |||||
) | |||||
reply = Reply(ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey]) | |||||
e_context["reply"] = reply | e_context["reply"] = reply | ||||
e_context.action = EventAction.BREAK_PASS | e_context.action = EventAction.BREAK_PASS | ||||
elif customize == True: | elif customize == True: | ||||
@@ -199,17 +189,10 @@ class Role(Plugin): | |||||
if not verbose: | if not verbose: | ||||
return help_text | return help_text | ||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$") | trigger_prefix = conf().get("plugin_trigger_prefix", "$") | ||||
help_text = ( | |||||
f"使用方法:\n{trigger_prefix}角色" | |||||
+ " 预设角色名: 设定角色为{预设角色名}。\n" | |||||
+ f"{trigger_prefix}role" | |||||
+ " 预设角色名: 同上,但使用英文设定。\n" | |||||
) | |||||
help_text = f"使用方法:\n{trigger_prefix}角色" + " 预设角色名: 设定角色为{预设角色名}。\n" + f"{trigger_prefix}role" + " 预设角色名: 同上,但使用英文设定。\n" | |||||
help_text += f"{trigger_prefix}设定扮演" + " 角色设定: 设定自定义角色人设为{角色设定}。\n" | help_text += f"{trigger_prefix}设定扮演" + " 角色设定: 设定自定义角色人设为{角色设定}。\n" | ||||
help_text += f"{trigger_prefix}停止扮演: 清除设定的角色。\n" | help_text += f"{trigger_prefix}停止扮演: 清除设定的角色。\n" | ||||
help_text += ( | |||||
f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n" | |||||
) | |||||
help_text += f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n" | |||||
help_text += "\n目前的角色类型有: \n" | help_text += "\n目前的角色类型有: \n" | ||||
help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "。\n" | help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "。\n" | ||||
help_text += f"\n命令例子: \n{trigger_prefix}角色 写作助理\n" | help_text += f"\n命令例子: \n{trigger_prefix}角色 写作助理\n" | ||||
@@ -60,7 +60,7 @@ | |||||
> 该tool每天返回内容相同 | > 该tool每天返回内容相同 | ||||
#### 6.3. finance-news | |||||
#### 6.3. finance-news | |||||
###### 获取实时的金融财政新闻 | ###### 获取实时的金融财政新闻 | ||||
> 该工具需要解决browser tool 的google-chrome依赖安装 | > 该工具需要解决browser tool 的google-chrome依赖安装 | ||||
@@ -82,9 +82,7 @@ class Tool(Plugin): | |||||
return | return | ||||
elif content_list[1].startswith("reset"): | elif content_list[1].startswith("reset"): | ||||
logger.debug("[tool]: remind") | logger.debug("[tool]: remind") | ||||
e_context[ | |||||
"context" | |||||
].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符" | |||||
e_context["context"].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符" | |||||
e_context.action = EventAction.BREAK | e_context.action = EventAction.BREAK | ||||
return | return | ||||
@@ -93,18 +91,14 @@ class Tool(Plugin): | |||||
# Don't modify bot name | # Don't modify bot name | ||||
all_sessions = Bridge().get_bot("chat").sessions | all_sessions = Bridge().get_bot("chat").sessions | ||||
user_session = all_sessions.session_query( | |||||
query, e_context["context"]["session_id"] | |||||
).messages | |||||
user_session = all_sessions.session_query(query, e_context["context"]["session_id"]).messages | |||||
# chatgpt-tool-hub will reply you with many tools | # chatgpt-tool-hub will reply you with many tools | ||||
logger.debug("[tool]: just-go") | logger.debug("[tool]: just-go") | ||||
try: | try: | ||||
_reply = self.app.ask(query, user_session) | _reply = self.app.ask(query, user_session) | ||||
e_context.action = EventAction.BREAK_PASS | e_context.action = EventAction.BREAK_PASS | ||||
all_sessions.session_reply( | |||||
_reply, e_context["context"]["session_id"] | |||||
) | |||||
all_sessions.session_reply(_reply, e_context["context"]["session_id"]) | |||||
except Exception as e: | except Exception as e: | ||||
logger.exception(e) | logger.exception(e) | ||||
logger.error(str(e)) | logger.error(str(e)) | ||||
@@ -178,4 +172,4 @@ class Tool(Plugin): | |||||
# filter not support tool | # filter not support tool | ||||
tool_list = self._filter_tool_list(tool_config.get("tools", [])) | tool_list = self._filter_tool_list(tool_config.get("tools", [])) | ||||
return app.create_app(tools_list=tool_list, **app_kwargs) | |||||
return app.create_app(tools_list=tool_list, **app_kwargs) |
@@ -33,6 +33,7 @@ def get_pcm_from_wav(wav_path): | |||||
wav = wave.open(wav_path, "rb") | wav = wave.open(wav_path, "rb") | ||||
return wav.readframes(wav.getnframes()) | return wav.readframes(wav.getnframes()) | ||||
def any_to_mp3(any_path, mp3_path): | def any_to_mp3(any_path, mp3_path): | ||||
""" | """ | ||||
把任意格式转成mp3文件 | 把任意格式转成mp3文件 | ||||
@@ -40,16 +41,13 @@ def any_to_mp3(any_path, mp3_path): | |||||
if any_path.endswith(".mp3"): | if any_path.endswith(".mp3"): | ||||
shutil.copy2(any_path, mp3_path) | shutil.copy2(any_path, mp3_path) | ||||
return | return | ||||
if ( | |||||
any_path.endswith(".sil") | |||||
or any_path.endswith(".silk") | |||||
or any_path.endswith(".slk") | |||||
): | |||||
if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"): | |||||
sil_to_wav(any_path, any_path) | sil_to_wav(any_path, any_path) | ||||
any_path = mp3_path | any_path = mp3_path | ||||
audio = AudioSegment.from_file(any_path) | audio = AudioSegment.from_file(any_path) | ||||
audio.export(mp3_path, format="mp3") | audio.export(mp3_path, format="mp3") | ||||
def any_to_wav(any_path, wav_path): | def any_to_wav(any_path, wav_path): | ||||
""" | """ | ||||
把任意格式转成wav文件 | 把任意格式转成wav文件 | ||||
@@ -57,11 +55,7 @@ def any_to_wav(any_path, wav_path): | |||||
if any_path.endswith(".wav"): | if any_path.endswith(".wav"): | ||||
shutil.copy2(any_path, wav_path) | shutil.copy2(any_path, wav_path) | ||||
return | return | ||||
if ( | |||||
any_path.endswith(".sil") | |||||
or any_path.endswith(".silk") | |||||
or any_path.endswith(".slk") | |||||
): | |||||
if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"): | |||||
return sil_to_wav(any_path, wav_path) | return sil_to_wav(any_path, wav_path) | ||||
audio = AudioSegment.from_file(any_path) | audio = AudioSegment.from_file(any_path) | ||||
audio.export(wav_path, format="wav") | audio.export(wav_path, format="wav") | ||||
@@ -71,11 +65,7 @@ def any_to_sil(any_path, sil_path): | |||||
""" | """ | ||||
把任意格式转成sil文件 | 把任意格式转成sil文件 | ||||
""" | """ | ||||
if ( | |||||
any_path.endswith(".sil") | |||||
or any_path.endswith(".silk") | |||||
or any_path.endswith(".slk") | |||||
): | |||||
if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"): | |||||
shutil.copy2(any_path, sil_path) | shutil.copy2(any_path, sil_path) | ||||
return 10000 | return 10000 | ||||
audio = AudioSegment.from_file(any_path) | audio = AudioSegment.from_file(any_path) | ||||
@@ -40,57 +40,33 @@ class AzureVoice(Voice): | |||||
config = json.load(fr) | config = json.load(fr) | ||||
self.api_key = conf().get("azure_voice_api_key") | self.api_key = conf().get("azure_voice_api_key") | ||||
self.api_region = conf().get("azure_voice_region") | self.api_region = conf().get("azure_voice_region") | ||||
self.speech_config = speechsdk.SpeechConfig( | |||||
subscription=self.api_key, region=self.api_region | |||||
) | |||||
self.speech_config.speech_synthesis_voice_name = config[ | |||||
"speech_synthesis_voice_name" | |||||
] | |||||
self.speech_config.speech_recognition_language = config[ | |||||
"speech_recognition_language" | |||||
] | |||||
self.speech_config = speechsdk.SpeechConfig(subscription=self.api_key, region=self.api_region) | |||||
self.speech_config.speech_synthesis_voice_name = config["speech_synthesis_voice_name"] | |||||
self.speech_config.speech_recognition_language = config["speech_recognition_language"] | |||||
except Exception as e: | except Exception as e: | ||||
logger.warn("AzureVoice init failed: %s, ignore " % e) | logger.warn("AzureVoice init failed: %s, ignore " % e) | ||||
def voiceToText(self, voice_file): | def voiceToText(self, voice_file): | ||||
audio_config = speechsdk.AudioConfig(filename=voice_file) | audio_config = speechsdk.AudioConfig(filename=voice_file) | ||||
speech_recognizer = speechsdk.SpeechRecognizer( | |||||
speech_config=self.speech_config, audio_config=audio_config | |||||
) | |||||
speech_recognizer = speechsdk.SpeechRecognizer(speech_config=self.speech_config, audio_config=audio_config) | |||||
result = speech_recognizer.recognize_once() | result = speech_recognizer.recognize_once() | ||||
if result.reason == speechsdk.ResultReason.RecognizedSpeech: | if result.reason == speechsdk.ResultReason.RecognizedSpeech: | ||||
logger.info( | |||||
"[Azure] voiceToText voice file name={} text={}".format( | |||||
voice_file, result.text | |||||
) | |||||
) | |||||
logger.info("[Azure] voiceToText voice file name={} text={}".format(voice_file, result.text)) | |||||
reply = Reply(ReplyType.TEXT, result.text) | reply = Reply(ReplyType.TEXT, result.text) | ||||
else: | else: | ||||
logger.error( | |||||
"[Azure] voiceToText error, result={}, canceldetails={}".format( | |||||
result, result.cancellation_details | |||||
) | |||||
) | |||||
logger.error("[Azure] voiceToText error, result={}, canceldetails={}".format(result, result.cancellation_details)) | |||||
reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败") | reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败") | ||||
return reply | return reply | ||||
def textToVoice(self, text): | def textToVoice(self, text): | ||||
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav" | fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav" | ||||
audio_config = speechsdk.AudioConfig(filename=fileName) | audio_config = speechsdk.AudioConfig(filename=fileName) | ||||
speech_synthesizer = speechsdk.SpeechSynthesizer( | |||||
speech_config=self.speech_config, audio_config=audio_config | |||||
) | |||||
speech_synthesizer = speechsdk.SpeechSynthesizer(speech_config=self.speech_config, audio_config=audio_config) | |||||
result = speech_synthesizer.speak_text(text) | result = speech_synthesizer.speak_text(text) | ||||
if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted: | if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted: | ||||
logger.info( | |||||
"[Azure] textToVoice text={} voice file name={}".format(text, fileName) | |||||
) | |||||
logger.info("[Azure] textToVoice text={} voice file name={}".format(text, fileName)) | |||||
reply = Reply(ReplyType.VOICE, fileName) | reply = Reply(ReplyType.VOICE, fileName) | ||||
else: | else: | ||||
logger.error( | |||||
"[Azure] textToVoice error, result={}, canceldetails={}".format( | |||||
result, result.cancellation_details | |||||
) | |||||
) | |||||
logger.error("[Azure] textToVoice error, result={}, canceldetails={}".format(result, result.cancellation_details)) | |||||
reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败") | reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败") | ||||
return reply | return reply |
@@ -85,9 +85,7 @@ class BaiduVoice(Voice): | |||||
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3" | fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3" | ||||
with open(fileName, "wb") as f: | with open(fileName, "wb") as f: | ||||
f.write(result) | f.write(result) | ||||
logger.info( | |||||
"[Baidu] textToVoice text={} voice file name={}".format(text, fileName) | |||||
) | |||||
logger.info("[Baidu] textToVoice text={} voice file name={}".format(text, fileName)) | |||||
reply = Reply(ReplyType.VOICE, fileName) | reply = Reply(ReplyType.VOICE, fileName) | ||||
else: | else: | ||||
logger.error("[Baidu] textToVoice error={}".format(result)) | logger.error("[Baidu] textToVoice error={}".format(result)) | ||||
@@ -24,11 +24,7 @@ class GoogleVoice(Voice): | |||||
audio = self.recognizer.record(source) | audio = self.recognizer.record(source) | ||||
try: | try: | ||||
text = self.recognizer.recognize_google(audio, language="zh-CN") | text = self.recognizer.recognize_google(audio, language="zh-CN") | ||||
logger.info( | |||||
"[Google] voiceToText text={} voice file name={}".format( | |||||
text, voice_file | |||||
) | |||||
) | |||||
logger.info("[Google] voiceToText text={} voice file name={}".format(text, voice_file)) | |||||
reply = Reply(ReplyType.TEXT, text) | reply = Reply(ReplyType.TEXT, text) | ||||
except speech_recognition.UnknownValueError: | except speech_recognition.UnknownValueError: | ||||
reply = Reply(ReplyType.ERROR, "抱歉,我听不懂") | reply = Reply(ReplyType.ERROR, "抱歉,我听不懂") | ||||
@@ -42,9 +38,7 @@ class GoogleVoice(Voice): | |||||
mp3File = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3" | mp3File = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3" | ||||
tts = gTTS(text=text, lang="zh") | tts = gTTS(text=text, lang="zh") | ||||
tts.save(mp3File) | tts.save(mp3File) | ||||
logger.info( | |||||
"[Google] textToVoice text={} voice file name={}".format(text, mp3File) | |||||
) | |||||
logger.info("[Google] textToVoice text={} voice file name={}".format(text, mp3File)) | |||||
reply = Reply(ReplyType.VOICE, mp3File) | reply = Reply(ReplyType.VOICE, mp3File) | ||||
except Exception as e: | except Exception as e: | ||||
reply = Reply(ReplyType.ERROR, str(e)) | reply = Reply(ReplyType.ERROR, str(e)) | ||||
@@ -22,11 +22,7 @@ class OpenaiVoice(Voice): | |||||
result = openai.Audio.transcribe("whisper-1", file) | result = openai.Audio.transcribe("whisper-1", file) | ||||
text = result["text"] | text = result["text"] | ||||
reply = Reply(ReplyType.TEXT, text) | reply = Reply(ReplyType.TEXT, text) | ||||
logger.info( | |||||
"[Openai] voiceToText text={} voice file name={}".format( | |||||
text, voice_file | |||||
) | |||||
) | |||||
logger.info("[Openai] voiceToText text={} voice file name={}".format(text, voice_file)) | |||||
except Exception as e: | except Exception as e: | ||||
reply = Reply(ReplyType.ERROR, str(e)) | reply = Reply(ReplyType.ERROR, str(e)) | ||||
finally: | finally: | ||||
@@ -5,6 +5,7 @@ pytts voice service (offline) | |||||
import os | import os | ||||
import sys | import sys | ||||
import time | import time | ||||
import pyttsx3 | import pyttsx3 | ||||
from bridge.reply import Reply, ReplyType | from bridge.reply import Reply, ReplyType | ||||
@@ -12,6 +13,7 @@ from common.log import logger | |||||
from common.tmp_dir import TmpDir | from common.tmp_dir import TmpDir | ||||
from voice.voice import Voice | from voice.voice import Voice | ||||
class PyttsVoice(Voice): | class PyttsVoice(Voice): | ||||
engine = pyttsx3.init() | engine = pyttsx3.init() | ||||
@@ -20,7 +22,7 @@ class PyttsVoice(Voice): | |||||
self.engine.setProperty("rate", 125) | self.engine.setProperty("rate", 125) | ||||
# 音量 | # 音量 | ||||
self.engine.setProperty("volume", 1.0) | self.engine.setProperty("volume", 1.0) | ||||
if sys.platform == 'win32': | |||||
if sys.platform == "win32": | |||||
for voice in self.engine.getProperty("voices"): | for voice in self.engine.getProperty("voices"): | ||||
if "Chinese" in voice.name: | if "Chinese" in voice.name: | ||||
self.engine.setProperty("voice", voice.id) | self.engine.setProperty("voice", voice.id) | ||||
@@ -33,23 +35,23 @@ class PyttsVoice(Voice): | |||||
def textToVoice(self, text): | def textToVoice(self, text): | ||||
try: | try: | ||||
# avoid the same filename | # avoid the same filename | ||||
wavFileName = "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7fffffff) + ".wav" | |||||
wavFileName = "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".wav" | |||||
wavFile = TmpDir().path() + wavFileName | wavFile = TmpDir().path() + wavFileName | ||||
logger.info("[Pytts] textToVoice text={} voice file name={}".format(text, wavFile)) | logger.info("[Pytts] textToVoice text={} voice file name={}".format(text, wavFile)) | ||||
self.engine.save_to_file(text, wavFile) | self.engine.save_to_file(text, wavFile) | ||||
if sys.platform == 'win32': | |||||
if sys.platform == "win32": | |||||
self.engine.runAndWait() | self.engine.runAndWait() | ||||
else: | else: | ||||
# In ubuntu, runAndWait do not really wait until the file created. | |||||
# In ubuntu, runAndWait do not really wait until the file created. | |||||
# It will return once the task queue is empty, but the task is still running in coroutine. | # It will return once the task queue is empty, but the task is still running in coroutine. | ||||
# And if you call runAndWait() and time.sleep() twice, it will stuck, so do not use this. | # And if you call runAndWait() and time.sleep() twice, it will stuck, so do not use this. | ||||
# If you want to fix this, add self._proxy.setBusy(True) in line 127 in espeak.py, at the beginning of the function save_to_file. | # If you want to fix this, add self._proxy.setBusy(True) in line 127 in espeak.py, at the beginning of the function save_to_file. | ||||
# self.engine.runAndWait() | # self.engine.runAndWait() | ||||
# Before espeak fix this problem, we iterate the generator and control the waiting by ourself. | # Before espeak fix this problem, we iterate the generator and control the waiting by ourself. | ||||
# But this is not the canonical way to use it, for example if the file already exists it also cannot wait. | |||||
# But this is not the canonical way to use it, for example if the file already exists it also cannot wait. | |||||
self.engine.iterate() | self.engine.iterate() | ||||
while self.engine.isBusy() or wavFileName not in os.listdir(TmpDir().path()): | while self.engine.isBusy() or wavFileName not in os.listdir(TmpDir().path()): | ||||
time.sleep(0.1) | time.sleep(0.1) | ||||