From 38c8ceba12b9a0e1e6ce2e52de6e4cf6faa3a1fe Mon Sep 17 00:00:00 2001 From: lanvent Date: Sat, 11 Mar 2023 02:20:39 +0800 Subject: [PATCH 01/21] avoid repeatedly instantiating bot --- bot/chatgpt/chat_gpt_bot.py | 49 ++++++++++++++++--------------------- bridge/bridge.py | 19 ++++++++++---- common/singleton.py | 9 +++++++ 3 files changed, 44 insertions(+), 33 deletions(-) create mode 100644 common/singleton.py diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index 78e344f..f94b1f7 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -7,16 +7,13 @@ from common.expired_dict import ExpiredDict import openai import time -if conf().get('expires_in_seconds'): - all_sessions = ExpiredDict(conf().get('expires_in_seconds')) -else: - all_sessions = dict() # OpenAI对话模型API (可用) class ChatGPTBot(Bot): def __init__(self): openai.api_key = conf().get('open_ai_api_key') proxy = conf().get('proxy') + self.sessions=SessionManager() if proxy: openai.proxy = proxy @@ -26,16 +23,16 @@ class ChatGPTBot(Bot): logger.info("[OPEN_AI] query={}".format(query)) session_id = context.get('session_id') or context.get('from_user_id') if query == '#清除记忆': - Session.clear_session(session_id) + self.sessions.clear_session(session_id) return '记忆已清除' elif query == '#清除所有': - Session.clear_all_session() + self.sessions.clear_all_session() return '所有人记忆已清除' elif query == '#更新配置': load_config() return '配置已更新' - session = Session.build_session_query(query, session_id) + session = self.sessions.build_session_query(query, session_id) logger.debug("[OPEN_AI] session query={}".format(session)) # if context.get('stream'): @@ -45,7 +42,7 @@ class ChatGPTBot(Bot): reply_content = self.reply_text(session, session_id, 0) logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"])) if reply_content["completion_tokens"] > 0: - Session.save_session(reply_content["content"], session_id, reply_content["total_tokens"]) + self.sessions.save_session(reply_content["content"], session_id, reply_content["total_tokens"]) return reply_content["content"] elif context.get('type', None) == 'IMAGE_CREATE': @@ -94,7 +91,7 @@ class ChatGPTBot(Bot): except Exception as e: # unknown exception logger.exception(e) - Session.clear_session(session_id) + self.sessions.clear_session(session_id) return {"completion_tokens": 0, "content": "请再问我一次吧"} def create_img(self, query, retry_count=0): @@ -119,10 +116,11 @@ class ChatGPTBot(Bot): except Exception as e: logger.exception(e) return None - -class Session(object): - @staticmethod - def build_session_query(query, session_id): + +class SessionManager(object): + def __init__(self): + self.sessions = {} + def build_session_query(self,query, session_id): ''' build query with conversation history e.g. [ @@ -135,36 +133,33 @@ class Session(object): :param session_id: session id :return: query content with conversaction ''' - session = all_sessions.get(session_id, []) + session = self.sessions.get(session_id, []) if len(session) == 0: system_prompt = conf().get("character_desc", "") system_item = {'role': 'system', 'content': system_prompt} session.append(system_item) - all_sessions[session_id] = session + self.sessions[session_id] = session user_item = {'role': 'user', 'content': query} session.append(user_item) return session - @staticmethod - def save_session(answer, session_id, total_tokens): + def save_session(self, answer, session_id, total_tokens): max_tokens = conf().get("conversation_max_tokens") if not max_tokens: # default 3000 max_tokens = 1000 max_tokens=int(max_tokens) - session = all_sessions.get(session_id) + session = self.sessions.get(session_id) if session: # append conversation gpt_item = {'role': 'assistant', 'content': answer} session.append(gpt_item) # discard exceed limit conversation - Session.discard_exceed_conversation(session, max_tokens, total_tokens) + self.discard_exceed_conversation(session, max_tokens, total_tokens) - - @staticmethod - def discard_exceed_conversation(session, max_tokens, total_tokens): + def discard_exceed_conversation(self, session, max_tokens, total_tokens): dec_tokens = int(total_tokens) # logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens)) while dec_tokens > max_tokens: @@ -176,10 +171,8 @@ class Session(object): break dec_tokens = dec_tokens - max_tokens - @staticmethod - def clear_session(session_id): - all_sessions[session_id] = [] + def clear_session(self,session_id): + self.sessions[session_id] = [] - @staticmethod - def clear_all_session(): - all_sessions.clear() + def clear_all_session(self): + self.sessions.clear() diff --git a/bridge/bridge.py b/bridge/bridge.py index e739a7f..068a58e 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -1,16 +1,25 @@ from bot import bot_factory +from common.singleton import singleton from voice import voice_factory - +@singleton class Bridge(object): def __init__(self): - pass + self.bots = { + "chat": bot_factory.create_bot("chatGPT"), + "voice_to_text": voice_factory.create_voice("openai"), + # "text_to_voice": voice_factory.create_voice("baidu") + } + try: + self.bots["text_to_voice"] = voice_factory.create_voice("baidu") + except ModuleNotFoundError as e: + print(e) def fetch_reply_content(self, query, context): - return bot_factory.create_bot("chatGPT").reply(query, context) + return self.bots["chat"].reply(query, context) def fetch_voice_to_text(self, voiceFile): - return voice_factory.create_voice("openai").voiceToText(voiceFile) + return self.bots["voice_to_text"].voiceToText(voiceFile) def fetch_text_to_voice(self, text): - return voice_factory.create_voice("baidu").textToVoice(text) \ No newline at end of file + return self.bots["text_to_voice"].textToVoice(text) \ No newline at end of file diff --git a/common/singleton.py b/common/singleton.py new file mode 100644 index 0000000..b46095c --- /dev/null +++ b/common/singleton.py @@ -0,0 +1,9 @@ +def singleton(cls): + instances = {} + + def get_instance(*args, **kwargs): + if cls not in instances: + instances[cls] = cls(*args, **kwargs) + return instances[cls] + + return get_instance From d6037422ac9c32523083dc301da63faa10eb352a Mon Sep 17 00:00:00 2001 From: lanvent Date: Sun, 12 Mar 2023 00:58:49 +0800 Subject: [PATCH 02/21] decouple message processing process --- bot/chatgpt/chat_gpt_bot.py | 45 ++++-- bridge/bridge.py | 5 + channel/wechat/wechat_channel.py | 253 ++++++++++++++++--------------- 3 files changed, 167 insertions(+), 136 deletions(-) diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index f94b1f7..327cc95 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -19,19 +19,24 @@ class ChatGPTBot(Bot): def reply(self, query, context=None): # acquire reply content - if not context or not context.get('type') or context.get('type') == 'TEXT': + if context['type']=='TEXT': logger.info("[OPEN_AI] query={}".format(query)) - session_id = context.get('session_id') or context.get('from_user_id') + session_id = context['session_id'] + reply=None if query == '#清除记忆': self.sessions.clear_session(session_id) - return '记忆已清除' + reply={'type':'INFO', 'content':'记忆已清除'} elif query == '#清除所有': self.sessions.clear_all_session() - return '所有人记忆已清除' + reply={'type':'INFO', 'content':'所有人记忆已清除'} elif query == '#更新配置': load_config() - return '配置已更新' - + reply={'type':'INFO', 'content':'配置已更新'} + elif query == '#DEBUG': + logger.setLevel('DEBUG') + reply={'type':'INFO', 'content':'DEBUG模式已开启'} + if reply: + return reply session = self.sessions.build_session_query(query, session_id) logger.debug("[OPEN_AI] session query={}".format(session)) @@ -41,12 +46,26 @@ class ChatGPTBot(Bot): reply_content = self.reply_text(session, session_id, 0) logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"])) - if reply_content["completion_tokens"] > 0: + if reply_content['completion_tokens']==0 and len(reply_content['content'])>0: + reply={'type':'ERROR', 'content':reply_content['content']} + elif reply_content["completion_tokens"] > 0: self.sessions.save_session(reply_content["content"], session_id, reply_content["total_tokens"]) - return reply_content["content"] + reply={'type':'TEXT', 'content':reply_content["content"]} + else: + reply={'type':'ERROR', 'content':reply_content['content']} + logger.debug("[OPEN_AI] reply {} used 0 tokens.".format(reply_content)) + return reply - elif context.get('type', None) == 'IMAGE_CREATE': - return self.create_img(query, 0) + elif context['type'] == 'IMAGE_CREATE': + ok, retstring=self.create_img(query, 0) + reply=None + if ok: + reply = {'type':'IMAGE', 'content':retstring} + else: + reply = {'type':'ERROR', 'content':retstring} + return reply + else: + reply= {'type':'ERROR', 'content':'Bot不支持处理{}类型的消息'.format(context['type'])} def reply_text(self, session, session_id, retry_count=0) ->dict: ''' @@ -104,7 +123,7 @@ class ChatGPTBot(Bot): ) image_url = response['data'][0]['url'] logger.info("[OPEN_AI] image_url={}".format(image_url)) - return image_url + return True,image_url except openai.error.RateLimitError as e: logger.warn(e) if retry_count < 1: @@ -112,10 +131,10 @@ class ChatGPTBot(Bot): logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1)) return self.create_img(query, retry_count+1) else: - return "提问太快啦,请休息一下再问我吧" + return False,"提问太快啦,请休息一下再问我吧" except Exception as e: logger.exception(e) - return None + return False,str(e) class SessionManager(object): def __init__(self): diff --git a/bridge/bridge.py b/bridge/bridge.py index 068a58e..392d9e8 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -15,6 +15,11 @@ class Bridge(object): except ModuleNotFoundError as e: print(e) + + # 以下所有函数需要得到一个reply字典,格式如下: + # reply["type"] = "ERROR" / "TEXT" / "VOICE" / ... + # reply["content"] = reply的内容 + def fetch_reply_content(self, query, context): return self.bots["chat"].reply(query, context) diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index d43a0a3..ddf38a4 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -46,62 +46,55 @@ class WechatChannel(Channel): # start message listener itchat.run() + + # handle_* 系列函数处理收到的消息后构造context,然后调用handle函数处理context + # context是一个字典,包含了消息的所有信息,包括以下key + # type: 消息类型,包括TEXT、VOICE、CMD_IMAGE_CREATE + # content: 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是CMD_IMAGE_CREATE类型,content就是图片生成命令 + # session_id: 会话id + # isgroup: 是否是群聊 + # msg: 原始消息对象 + # receiver: 需要回复的对象 def handle_voice(self, msg): if conf().get('speech_recognition') != True : return logger.debug("[WX]receive voice msg: " + msg['FileName']) - thread_pool.submit(self._do_handle_voice, msg) - - def _do_handle_voice(self, msg): from_user_id = msg['FromUserName'] other_user_id = msg['User']['UserName'] if from_user_id == other_user_id: - file_name = TmpDir().path() + msg['FileName'] - msg.download(file_name) - query = super().build_voice_to_text(file_name) - if conf().get('voice_reply_voice'): - self._do_send_voice(query, from_user_id) - else: - self._do_send_text(query, from_user_id) + context = { 'isgroup': False, 'msg': msg, 'receiver': other_user_id} + context['type']='VOICE' + context['session_id']=other_user_id + thread_pool.submit(self.handle, context) + def handle_text(self, msg): logger.debug("[WX]receive text msg: " + json.dumps(msg, ensure_ascii=False)) content = msg['Text'] - self._handle_single_msg(msg, content) - - def _handle_single_msg(self, msg, content): from_user_id = msg['FromUserName'] to_user_id = msg['ToUserName'] # 接收人id other_user_id = msg['User']['UserName'] # 对手方id - match_prefix = self.check_prefix(content, conf().get('single_chat_prefix')) + match_prefix = check_prefix(content, conf().get('single_chat_prefix')) if "」\n- - - - - - - - - - - - - - -" in content: logger.debug("[WX]reference query skipped") return - if from_user_id == other_user_id and match_prefix is not None: - # 好友向自己发送消息 - if match_prefix != '': - str_list = content.split(match_prefix, 1) - if len(str_list) == 2: - content = str_list[1].strip() - - img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix')) - if img_match_prefix: - content = content.split(img_match_prefix, 1)[1].strip() - thread_pool.submit(self._do_send_img, content, from_user_id) - else : - thread_pool.submit(self._do_send_text, content, from_user_id) - elif to_user_id == other_user_id and match_prefix: - # 自己给好友发送消息 - str_list = content.split(match_prefix, 1) - if len(str_list) == 2: - content = str_list[1].strip() - img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix')) - if img_match_prefix: - content = content.split(img_match_prefix, 1)[1].strip() - thread_pool.submit(self._do_send_img, content, to_user_id) - else: - thread_pool.submit(self._do_send_text, content, to_user_id) + if match_prefix: + content=content.replace(match_prefix,'',1).strip() + else: + return + context = { 'isgroup': False, 'msg': msg, 'receiver': other_user_id} + context['session_id']=other_user_id + + img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) + if img_match_prefix: + content=content.replace(img_match_prefix,'',1).strip() + context['type']='CMD_IMAGE_CREATE' + else: + context['type']='TEXT' + + context['content']=content + thread_pool.submit(self.handle, context) def handle_group(self, msg): @@ -122,100 +115,114 @@ class WechatChannel(Channel): logger.debug("[WX]reference query skipped") return "" config = conf() - match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or self.check_prefix(origin_content, config.get('group_chat_prefix')) \ - or self.check_contain(origin_content, config.get('group_chat_keyword')) - if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or self.check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix: - img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix')) + match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or check_prefix(origin_content, config.get('group_chat_prefix')) \ + or check_contain(origin_content, config.get('group_chat_keyword')) + if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix: + context = { 'isgroup': True, 'msg': msg, 'receiver': group_id} + + img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) if img_match_prefix: - content = content.split(img_match_prefix, 1)[1].strip() - thread_pool.submit(self._do_send_img, content, group_id) + content=content.replace(img_match_prefix,'',1).strip() + context['type']='CMD_IMAGE_CREATE' else: - thread_pool.submit(self._do_send_group, content, msg) - - def send(self, msg, receiver): - itchat.send(msg, toUserName=receiver) - logger.info('[WX] sendMsg={}, receiver={}'.format(msg, receiver)) - - def _do_send_voice(self, query, reply_user_id): - try: - if not query: - return - context = dict() - context['from_user_id'] = reply_user_id - reply_text = super().build_reply_content(query, context) - if reply_text: - replyFile = super().build_text_to_voice(reply_text) - itchat.send_file(replyFile, toUserName=reply_user_id) - logger.info('[WX] sendFile={}, receiver={}'.format(replyFile, reply_user_id)) - except Exception as e: - logger.exception(e) - - def _do_send_text(self, query, reply_user_id): - try: - if not query: - return - context = dict() - context['session_id'] = reply_user_id - reply_text = super().build_reply_content(query, context) - if reply_text: - self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id) - except Exception as e: - logger.exception(e) - - def _do_send_img(self, query, reply_user_id): - try: - if not query: - return - context = dict() - context['type'] = 'IMAGE_CREATE' - img_url = super().build_reply_content(query, context) - if not img_url: - return - - # 图片下载 + context['type']='TEXT' + context['content']=content + + group_chat_in_one_session = conf().get('group_chat_in_one_session', []) + if ('ALL_GROUP' in group_chat_in_one_session or \ + group_name in group_chat_in_one_session or \ + check_contain(group_name, group_chat_in_one_session)): + context['session_id'] = group_id + else: + context['session_id'] = msg['ActualUserName'] + + thread_pool.submit(self.handle, context) + + # 统一的发送函数,根据reply的type字段发送不同类型的消息 + + def send(self, reply, receiver): + if reply['type']=='TEXT': + itchat.send(reply['content'], toUserName=receiver) + logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) + elif reply['type']=='ERROR' or reply['type']=='INFO': + itchat.send(reply['content'], toUserName=receiver) + logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) + elif reply['type']=='VOICE': + itchat.send_file(reply['content'], toUserName=receiver) + logger.info('[WX] sendFile={}, receiver={}'.format(reply['content'], receiver)) + elif reply['type']=='IMAGE_URL': # 从网络下载图片 + img_url = reply['content'] pic_res = requests.get(img_url, stream=True) image_storage = io.BytesIO() for block in pic_res.iter_content(1024): image_storage.write(block) image_storage.seek(0) - - # 图片发送 - itchat.send_image(image_storage, reply_user_id) - logger.info('[WX] sendImage, receiver={}'.format(reply_user_id)) - except Exception as e: - logger.exception(e) - - def _do_send_group(self, query, msg): - if not query: - return - context = dict() - group_name = msg['User']['NickName'] - group_id = msg['User']['UserName'] - group_chat_in_one_session = conf().get('group_chat_in_one_session', []) - if ('ALL_GROUP' in group_chat_in_one_session or \ - group_name in group_chat_in_one_session or \ - self.check_contain(group_name, group_chat_in_one_session)): - context['session_id'] = group_id + itchat.send_image(image_storage, toUserName=receiver) + logger.info('[WX] sendImage url=, receiver={}'.format(img_url,receiver)) + elif reply['type']=='IMAGE': # 从文件读取图片 + image_storage = reply['content'] + image_storage.seek(0) + itchat.send_image(image_storage, toUserName=receiver) + logger.info('[WX] sendImage, receiver={}'.format(receiver)) + + # 处理消息 + def handle(self, context): + content=context['content'] + reply=None + + logger.debug('[WX] ready to handle context: {}'.format(context)) + # reply的构建步骤 + if context['type']=='TEXT' or context['type']=='CMD_IMAGE_CREATE': + reply = super().build_reply_content(content,context) + elif context['type']=='VOICE': + msg=context['msg'] + file_name = TmpDir().path() + msg['FileName'] + msg.download(file_name) + reply = super().build_voice_to_text(file_name) + if reply['type'] != 'ERROR' and reply['type'] != 'INFO': + reply = super().build_reply_content(reply['content'],context) + if reply['type']=='TEXT': + if conf().get('voice_reply_voice'): + reply = super().build_text_to_voice(reply['content']) else: - context['session_id'] = msg['ActualUserName'] - reply_text = super().build_reply_content(query, context) - if reply_text: - reply_text = '@' + msg['ActualNickName'] + ' ' + reply_text.strip() - self.send(conf().get("group_chat_reply_prefix", "") + reply_text, group_id) - - - def check_prefix(self, content, prefix_list): - for prefix in prefix_list: - if content.startswith(prefix): - return prefix - return None + logger.error('[WX] unknown context type: {}'.format(context['type'])) + return + + logger.debug('[WX] ready to decorate reply: {}'.format(reply)) + # reply的包装步骤 + if reply: + if reply['type']=='TEXT': + reply_text=reply['content'] + if context['isgroup']: + reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip() + reply_text=conf().get("group_chat_reply_prefix","")+reply_text + else: + reply_text=conf().get("single_chat_reply_prefix","")+reply_text + reply['content']=reply_text + elif reply['type']=='ERROR' or reply['type']=='INFO': + reply['content']=reply['type']+": "+ reply['content'] + elif reply['type']=='IMAGE_URL' or reply['type']=='VOICE': + pass + else: + logger.error('[WX] unknown reply type: {}'.format(reply['type'])) + return + if reply: + logger.debug('[WX] ready to send reply: {} to {}'.format(reply,context['receiver'])) + self.send(reply, context['receiver']) + + +def check_prefix(content, prefix_list): + for prefix in prefix_list: + if content.startswith(prefix): + return prefix + return None - def check_contain(self, content, keyword_list): - if not keyword_list: - return None - for ky in keyword_list: - if content.find(ky) != -1: - return True +def check_contain(content, keyword_list): + if not keyword_list: return None + for ky in keyword_list: + if content.find(ky) != -1: + return True + return None From 9ae7b7773eaa3e0dcfa319ae953c5aa8bc14b1ab Mon Sep 17 00:00:00 2001 From: lanvent Date: Sun, 12 Mar 2023 01:10:18 +0800 Subject: [PATCH 03/21] simple compatibility for wechaty --- channel/wechat/wechaty_channel.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/channel/wechat/wechaty_channel.py b/channel/wechat/wechaty_channel.py index 5e01464..7ae4683 100644 --- a/channel/wechat/wechaty_channel.py +++ b/channel/wechat/wechaty_channel.py @@ -129,7 +129,9 @@ class WechatyChannel(Channel): return context = dict() context['session_id'] = reply_user_id - reply_text = super().build_reply_content(query, context) + context['type'] = 'TEXT' + context['content'] = query + reply_text = super().build_reply_content(query, context)['content'] if reply_text: await self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id) except Exception as e: @@ -141,7 +143,8 @@ class WechatyChannel(Channel): return context = dict() context['type'] = 'IMAGE_CREATE' - img_url = super().build_reply_content(query, context) + context['content'] = query + img_url = super().build_reply_content(query, context)['content'] if not img_url: return # 图片下载 @@ -170,7 +173,9 @@ class WechatyChannel(Channel): context['session_id'] = str(group_id) else: context['session_id'] = str(group_id) + '-' + str(group_user_id) - reply_text = super().build_reply_content(query, context) + context['type'] = 'TEXT' + context['content'] = query + reply_text = super().build_reply_content(query, context)['content'] if reply_text: reply_text = '@' + group_user_name + ' ' + reply_text.strip() await self.send_group(conf().get("group_chat_reply_prefix", "") + reply_text, group_id) @@ -181,7 +186,8 @@ class WechatyChannel(Channel): return context = dict() context['type'] = 'IMAGE_CREATE' - img_url = super().build_reply_content(query, context) + context['content'] = query + img_url = super().build_reply_content(query, context)['content'] if not img_url: return # 图片发送 From 9e07703eb1264367d147d50c679b40c580278cee Mon Sep 17 00:00:00 2001 From: lanvent Date: Sun, 12 Mar 2023 01:22:49 +0800 Subject: [PATCH 04/21] formatting code --- bot/chatgpt/chat_gpt_bot.py | 58 +++++++++--------- bridge/bridge.py | 4 +- channel/wechat/wechat_channel.py | 101 ++++++++++++++++--------------- 3 files changed, 83 insertions(+), 80 deletions(-) diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index 327cc95..9bfea5b 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -13,28 +13,28 @@ class ChatGPTBot(Bot): def __init__(self): openai.api_key = conf().get('open_ai_api_key') proxy = conf().get('proxy') - self.sessions=SessionManager() + self.sessions = SessionManager() if proxy: openai.proxy = proxy def reply(self, query, context=None): # acquire reply content - if context['type']=='TEXT': + if context['type'] == 'TEXT': logger.info("[OPEN_AI] query={}".format(query)) session_id = context['session_id'] - reply=None + reply = None if query == '#清除记忆': self.sessions.clear_session(session_id) - reply={'type':'INFO', 'content':'记忆已清除'} + reply = {'type': 'INFO', 'content': '记忆已清除'} elif query == '#清除所有': self.sessions.clear_all_session() - reply={'type':'INFO', 'content':'所有人记忆已清除'} + reply = {'type': 'INFO', 'content': '所有人记忆已清除'} elif query == '#更新配置': load_config() - reply={'type':'INFO', 'content':'配置已更新'} + reply = {'type': 'INFO', 'content': '配置已更新'} elif query == '#DEBUG': logger.setLevel('DEBUG') - reply={'type':'INFO', 'content':'DEBUG模式已开启'} + reply = {'type': 'INFO', 'content': 'DEBUG模式已开启'} if reply: return reply session = self.sessions.build_session_query(query, session_id) @@ -46,28 +46,28 @@ class ChatGPTBot(Bot): reply_content = self.reply_text(session, session_id, 0) logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"])) - if reply_content['completion_tokens']==0 and len(reply_content['content'])>0: - reply={'type':'ERROR', 'content':reply_content['content']} + if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0: + reply = {'type': 'ERROR', 'content': reply_content['content']} elif reply_content["completion_tokens"] > 0: self.sessions.save_session(reply_content["content"], session_id, reply_content["total_tokens"]) reply={'type':'TEXT', 'content':reply_content["content"]} else: - reply={'type':'ERROR', 'content':reply_content['content']} + reply = {'type': 'ERROR', 'content': reply_content['content']} logger.debug("[OPEN_AI] reply {} used 0 tokens.".format(reply_content)) return reply elif context['type'] == 'IMAGE_CREATE': - ok, retstring=self.create_img(query, 0) - reply=None + ok, retstring = self.create_img(query, 0) + reply = None if ok: - reply = {'type':'IMAGE', 'content':retstring} + reply = {'type': 'IMAGE', 'content': retstring} else: - reply = {'type':'ERROR', 'content':retstring} + reply = {'type': 'ERROR', 'content': retstring} return reply else: reply= {'type':'ERROR', 'content':'Bot不支持处理{}类型的消息'.format(context['type'])} - def reply_text(self, session, session_id, retry_count=0) ->dict: + def reply_text(self, session, session_id, retry_count=0) -> dict: ''' call openai's ChatCompletion to get the answer :param session: a conversation session @@ -86,8 +86,8 @@ class ChatGPTBot(Bot): presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容 ) # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) - return {"total_tokens": response["usage"]["total_tokens"], - "completion_tokens": response["usage"]["completion_tokens"], + return {"total_tokens": response["usage"]["total_tokens"], + "completion_tokens": response["usage"]["completion_tokens"], "content": response.choices[0]['message']['content']} except openai.error.RateLimitError as e: # rate limit exception @@ -102,11 +102,11 @@ class ChatGPTBot(Bot): # api connection exception logger.warn(e) logger.warn("[OPEN_AI] APIConnection failed") - return {"completion_tokens": 0, "content":"我连接不到你的网络"} + return {"completion_tokens": 0, "content": "我连接不到你的网络"} except openai.error.Timeout as e: logger.warn(e) logger.warn("[OPEN_AI] Timeout") - return {"completion_tokens": 0, "content":"我没有收到你的消息"} + return {"completion_tokens": 0, "content": "我没有收到你的消息"} except Exception as e: # unknown exception logger.exception(e) @@ -123,7 +123,7 @@ class ChatGPTBot(Bot): ) image_url = response['data'][0]['url'] logger.info("[OPEN_AI] image_url={}".format(image_url)) - return True,image_url + return True, image_url except openai.error.RateLimitError as e: logger.warn(e) if retry_count < 1: @@ -131,15 +131,17 @@ class ChatGPTBot(Bot): logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1)) return self.create_img(query, retry_count+1) else: - return False,"提问太快啦,请休息一下再问我吧" + return False, "提问太快啦,请休息一下再问我吧" except Exception as e: logger.exception(e) - return False,str(e) - + return False, str(e) + + class SessionManager(object): def __init__(self): self.sessions = {} - def build_session_query(self,query, session_id): + + def build_session_query(self, query, session_id): ''' build query with conversation history e.g. [ @@ -167,7 +169,7 @@ class SessionManager(object): if not max_tokens: # default 3000 max_tokens = 1000 - max_tokens=int(max_tokens) + max_tokens = int(max_tokens) session = self.sessions.get(session_id) if session: @@ -177,7 +179,7 @@ class SessionManager(object): # discard exceed limit conversation self.discard_exceed_conversation(session, max_tokens, total_tokens) - + def discard_exceed_conversation(self, session, max_tokens, total_tokens): dec_tokens = int(total_tokens) # logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens)) @@ -187,10 +189,10 @@ class SessionManager(object): session.pop(1) session.pop(1) else: - break + break dec_tokens = dec_tokens - max_tokens - def clear_session(self,session_id): + def clear_session(self, session_id): self.sessions[session_id] = [] def clear_all_session(self): diff --git a/bridge/bridge.py b/bridge/bridge.py index 392d9e8..81e5d73 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -2,6 +2,7 @@ from bot import bot_factory from common.singleton import singleton from voice import voice_factory + @singleton class Bridge(object): def __init__(self): @@ -15,7 +16,6 @@ class Bridge(object): except ModuleNotFoundError as e: print(e) - # 以下所有函数需要得到一个reply字典,格式如下: # reply["type"] = "ERROR" / "TEXT" / "VOICE" / ... # reply["content"] = reply的内容 @@ -27,4 +27,4 @@ class Bridge(object): return self.bots["voice_to_text"].voiceToText(voiceFile) def fetch_text_to_voice(self, text): - return self.bots["text_to_voice"].textToVoice(text) \ No newline at end of file + return self.bots["text_to_voice"].textToVoice(text) diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index ddf38a4..e8be17e 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -46,7 +46,7 @@ class WechatChannel(Channel): # start message listener itchat.run() - + # handle_* 系列函数处理收到的消息后构造context,然后调用handle函数处理context # context是一个字典,包含了消息的所有信息,包括以下key # type: 消息类型,包括TEXT、VOICE、CMD_IMAGE_CREATE @@ -57,18 +57,17 @@ class WechatChannel(Channel): # receiver: 需要回复的对象 def handle_voice(self, msg): - if conf().get('speech_recognition') != True : + if conf().get('speech_recognition') != True: return logger.debug("[WX]receive voice msg: " + msg['FileName']) from_user_id = msg['FromUserName'] other_user_id = msg['User']['UserName'] if from_user_id == other_user_id: - context = { 'isgroup': False, 'msg': msg, 'receiver': other_user_id} - context['type']='VOICE' - context['session_id']=other_user_id + context = {'isgroup': False, 'msg': msg, 'receiver': other_user_id} + context['type'] = 'VOICE' + context['session_id'] = other_user_id thread_pool.submit(self.handle, context) - def handle_text(self, msg): logger.debug("[WX]receive text msg: " + json.dumps(msg, ensure_ascii=False)) content = msg['Text'] @@ -80,22 +79,21 @@ class WechatChannel(Channel): logger.debug("[WX]reference query skipped") return if match_prefix: - content=content.replace(match_prefix,'',1).strip() + content = content.replace(match_prefix, '', 1).strip() else: return - context = { 'isgroup': False, 'msg': msg, 'receiver': other_user_id} - context['session_id']=other_user_id - + context = {'isgroup': False, 'msg': msg, 'receiver': other_user_id} + context['session_id'] = other_user_id + img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) if img_match_prefix: - content=content.replace(img_match_prefix,'',1).strip() - context['type']='CMD_IMAGE_CREATE' + content = content.replace(img_match_prefix, '', 1).strip() + context['type'] = 'CMD_IMAGE_CREATE' else: - context['type']='TEXT' - - context['content']=content - thread_pool.submit(self.handle, context) + context['type'] = 'TEXT' + context['content'] = content + thread_pool.submit(self.handle, context) def handle_group(self, msg): logger.debug("[WX]receive group msg: " + json.dumps(msg, ensure_ascii=False)) @@ -122,32 +120,32 @@ class WechatChannel(Channel): img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) if img_match_prefix: - content=content.replace(img_match_prefix,'',1).strip() - context['type']='CMD_IMAGE_CREATE' + content = content.replace(img_match_prefix, '', 1).strip() + context['type'] = 'CMD_IMAGE_CREATE' else: - context['type']='TEXT' - context['content']=content + context['type'] = 'TEXT' + context['content'] = content group_chat_in_one_session = conf().get('group_chat_in_one_session', []) - if ('ALL_GROUP' in group_chat_in_one_session or \ - group_name in group_chat_in_one_session or \ + if ('ALL_GROUP' in group_chat_in_one_session or + group_name in group_chat_in_one_session or check_contain(group_name, group_chat_in_one_session)): context['session_id'] = group_id else: context['session_id'] = msg['ActualUserName'] - + thread_pool.submit(self.handle, context) - + # 统一的发送函数,根据reply的type字段发送不同类型的消息 def send(self, reply, receiver): - if reply['type']=='TEXT': + if reply['type'] == 'TEXT': itchat.send(reply['content'], toUserName=receiver) logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) - elif reply['type']=='ERROR' or reply['type']=='INFO': + elif reply['type'] == 'ERROR' or reply['type'] == 'INFO': itchat.send(reply['content'], toUserName=receiver) logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) - elif reply['type']=='VOICE': + elif reply['type'] == 'VOICE': itchat.send_file(reply['content'], toUserName=receiver) logger.info('[WX] sendFile={}, receiver={}'.format(reply['content'], receiver)) elif reply['type']=='IMAGE_URL': # 从网络下载图片 @@ -164,52 +162,56 @@ class WechatChannel(Channel): image_storage.seek(0) itchat.send_image(image_storage, toUserName=receiver) logger.info('[WX] sendImage, receiver={}'.format(receiver)) - + # 处理消息 def handle(self, context): - content=context['content'] - reply=None + content = context['content'] + reply = None logger.debug('[WX] ready to handle context: {}'.format(context)) # reply的构建步骤 - if context['type']=='TEXT' or context['type']=='CMD_IMAGE_CREATE': - reply = super().build_reply_content(content,context) - elif context['type']=='VOICE': - msg=context['msg'] + if context['type'] == 'TEXT' or context['type'] == 'CMD_IMAGE_CREATE': + reply = super().build_reply_content(content, context) + elif context['type'] == 'VOICE': + msg = context['msg'] file_name = TmpDir().path() + msg['FileName'] msg.download(file_name) reply = super().build_voice_to_text(file_name) if reply['type'] != 'ERROR' and reply['type'] != 'INFO': - reply = super().build_reply_content(reply['content'],context) - if reply['type']=='TEXT': + reply = super().build_reply_content(reply['content'], context) + if reply['type'] == 'TEXT': if conf().get('voice_reply_voice'): reply = super().build_text_to_voice(reply['content']) else: logger.error('[WX] unknown context type: {}'.format(context['type'])) return - + logger.debug('[WX] ready to decorate reply: {}'.format(reply)) # reply的包装步骤 if reply: - if reply['type']=='TEXT': - reply_text=reply['content'] + if reply['type'] == 'TEXT': + reply_text = reply['content'] if context['isgroup']: - reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip() - reply_text=conf().get("group_chat_reply_prefix","")+reply_text + reply_text = '@' + \ + context['msg']['ActualNickName'] + \ + ' ' + reply_text.strip() + reply_text = conf().get("group_chat_reply_prefix", "")+reply_text else: - reply_text=conf().get("single_chat_reply_prefix","")+reply_text - reply['content']=reply_text - elif reply['type']=='ERROR' or reply['type']=='INFO': - reply['content']=reply['type']+": "+ reply['content'] - elif reply['type']=='IMAGE_URL' or reply['type']=='VOICE': + reply_text = conf().get("single_chat_reply_prefix", "")+reply_text + reply['content'] = reply_text + elif reply['type'] == 'ERROR' or reply['type'] == 'INFO': + reply['content'] = reply['type']+": " + reply['content'] + elif reply['type'] == 'IMAGE_URL' or reply['type'] == 'VOICE': pass else: - logger.error('[WX] unknown reply type: {}'.format(reply['type'])) + logger.error( + '[WX] unknown reply type: {}'.format(reply['type'])) return if reply: - logger.debug('[WX] ready to send reply: {} to {}'.format(reply,context['receiver'])) + logger.debug('[WX] ready to send reply: {} to {}'.format( + reply, context['receiver'])) self.send(reply, context['receiver']) - + def check_prefix(content, prefix_list): for prefix in prefix_list: @@ -225,4 +227,3 @@ def check_contain(content, keyword_list): if content.find(ky) != -1: return True return None - From 0fcf0824dcf7830bd60aaa025e6c115364e5c4d8 Mon Sep 17 00:00:00 2001 From: lanvent Date: Sun, 12 Mar 2023 11:53:06 +0800 Subject: [PATCH 05/21] feat: support plugins --- .gitignore | 1 + app.py | 7 ++- bot/chatgpt/chat_gpt_bot.py | 9 ++- channel/wechat/wechat_channel.py | 103 +++++++++++++++++-------------- plugins/__init__.py | 9 +++ plugins/event.py | 49 +++++++++++++++ plugins/plugin.py | 3 + plugins/plugin_manager.py | 89 ++++++++++++++++++++++++++ 8 files changed, 220 insertions(+), 50 deletions(-) create mode 100644 plugins/__init__.py create mode 100644 plugins/event.py create mode 100644 plugins/plugin.py create mode 100644 plugins/plugin_manager.py diff --git a/.gitignore b/.gitignore index 8bc62f3..c349037 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ config.json QR.png nohup.out tmp +plugins.json \ No newline at end of file diff --git a/app.py b/app.py index 1ca359f..f07b275 100644 --- a/app.py +++ b/app.py @@ -4,14 +4,17 @@ import config from channel import channel_factory from common.log import logger - +from plugins import * if __name__ == '__main__': try: # load config config.load_config() # create channel - channel = channel_factory.create_channel("wx") + channel_name='wx' + channel = channel_factory.create_channel(channel_name) + if channel_name=='wx': + PluginManager().load_plugins() # startup channel channel.startup() diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index 9bfea5b..2c8567d 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -60,12 +60,13 @@ class ChatGPTBot(Bot): ok, retstring = self.create_img(query, 0) reply = None if ok: - reply = {'type': 'IMAGE', 'content': retstring} + reply = {'type': 'IMAGE_URL', 'content': retstring} else: reply = {'type': 'ERROR', 'content': retstring} return reply else: reply= {'type':'ERROR', 'content':'Bot不支持处理{}类型的消息'.format(context['type'])} + return reply def reply_text(self, session, session_id, retry_count=0) -> dict: ''' @@ -139,7 +140,11 @@ class ChatGPTBot(Bot): class SessionManager(object): def __init__(self): - self.sessions = {} + if conf().get('expires_in_seconds'): + sessions = ExpiredDict(conf().get('expires_in_seconds')) + else: + sessions = dict() + self.sessions = sessions def build_session_query(self, query, session_id): ''' diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index e8be17e..f436e48 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -12,9 +12,12 @@ from concurrent.futures import ThreadPoolExecutor from common.log import logger from common.tmp_dir import TmpDir from config import conf +from plugins import * + import requests import io + thread_pool = ThreadPoolExecutor(max_workers=8) @@ -49,8 +52,8 @@ class WechatChannel(Channel): # handle_* 系列函数处理收到的消息后构造context,然后调用handle函数处理context # context是一个字典,包含了消息的所有信息,包括以下key - # type: 消息类型,包括TEXT、VOICE、CMD_IMAGE_CREATE - # content: 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是CMD_IMAGE_CREATE类型,content就是图片生成命令 + # type: 消息类型,包括TEXT、VOICE、IMAGE_CREATE + # content: 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是IMAGE_CREATE类型,content就是图片生成命令 # session_id: 会话id # isgroup: 是否是群聊 # msg: 原始消息对象 @@ -88,7 +91,7 @@ class WechatChannel(Channel): img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) if img_match_prefix: content = content.replace(img_match_prefix, '', 1).strip() - context['type'] = 'CMD_IMAGE_CREATE' + context['type'] = 'IMAGE_CREATE' else: context['type'] = 'TEXT' @@ -121,7 +124,7 @@ class WechatChannel(Channel): img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) if img_match_prefix: content = content.replace(img_match_prefix, '', 1).strip() - context['type'] = 'CMD_IMAGE_CREATE' + context['type'] = 'IMAGE_CREATE' else: context['type'] = 'TEXT' context['content'] = content @@ -136,8 +139,7 @@ class WechatChannel(Channel): thread_pool.submit(self.handle, context) - # 统一的发送函数,根据reply的type字段发送不同类型的消息 - + # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 def send(self, reply, receiver): if reply['type'] == 'TEXT': itchat.send(reply['content'], toUserName=receiver) @@ -163,54 +165,63 @@ class WechatChannel(Channel): itchat.send_image(image_storage, toUserName=receiver) logger.info('[WX] sendImage, receiver={}'.format(receiver)) - # 处理消息 + # 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类 def handle(self, context): - content = context['content'] - reply = None + reply = {} logger.debug('[WX] ready to handle context: {}'.format(context)) + # reply的构建步骤 - if context['type'] == 'TEXT' or context['type'] == 'CMD_IMAGE_CREATE': - reply = super().build_reply_content(content, context) - elif context['type'] == 'VOICE': - msg = context['msg'] - file_name = TmpDir().path() + msg['FileName'] - msg.download(file_name) - reply = super().build_voice_to_text(file_name) - if reply['type'] != 'ERROR' and reply['type'] != 'INFO': - reply = super().build_reply_content(reply['content'], context) - if reply['type'] == 'TEXT': - if conf().get('voice_reply_voice'): - reply = super().build_text_to_voice(reply['content']) - else: - logger.error('[WX] unknown context type: {}'.format(context['type'])) - return + e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply})) + reply=e_context['reply'] + if not e_context.is_pass(): + logger.debug('[WX] ready to handle context: type={}, content={}'.format(context['type'], context['content'])) + if context['type'] == 'TEXT' or context['type'] == 'IMAGE_CREATE': + reply = super().build_reply_content(context['content'], context) + elif context['type'] == 'VOICE': + msg = context['msg'] + file_name = TmpDir().path() + msg['FileName'] + msg.download(file_name) + reply = super().build_voice_to_text(file_name) + if reply['type'] != 'ERROR' and reply['type'] != 'INFO': + reply = super().build_reply_content(reply['content'], context) + if reply['type'] == 'TEXT': + if conf().get('voice_reply_voice'): + reply = super().build_text_to_voice(reply['content']) + else: + logger.error('[WX] unknown context type: {}'.format(context['type'])) + return logger.debug('[WX] ready to decorate reply: {}'.format(reply)) + # reply的包装步骤 - if reply: - if reply['type'] == 'TEXT': - reply_text = reply['content'] - if context['isgroup']: - reply_text = '@' + \ - context['msg']['ActualNickName'] + \ - ' ' + reply_text.strip() - reply_text = conf().get("group_chat_reply_prefix", "")+reply_text + if reply and reply['type']: + e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply})) + reply=e_context['reply'] + if not e_context.is_pass() and reply and reply['type']: + if reply['type'] == 'TEXT': + reply_text = reply['content'] + if context['isgroup']: + reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip() + reply_text = conf().get("group_chat_reply_prefix", "")+reply_text + else: + reply_text = conf().get("single_chat_reply_prefix", "")+reply_text + reply['content'] = reply_text + elif reply['type'] == 'ERROR' or reply['type'] == 'INFO': + reply['content'] = reply['type']+": " + reply['content'] + elif reply['type'] == 'IMAGE_URL' or reply['type'] == 'VOICE' or reply['type'] == 'IMAGE': + pass else: - reply_text = conf().get("single_chat_reply_prefix", "")+reply_text - reply['content'] = reply_text - elif reply['type'] == 'ERROR' or reply['type'] == 'INFO': - reply['content'] = reply['type']+": " + reply['content'] - elif reply['type'] == 'IMAGE_URL' or reply['type'] == 'VOICE': - pass - else: - logger.error( - '[WX] unknown reply type: {}'.format(reply['type'])) - return - if reply: - logger.debug('[WX] ready to send reply: {} to {}'.format( - reply, context['receiver'])) - self.send(reply, context['receiver']) + logger.error('[WX] unknown reply type: {}'.format(reply['type'])) + return + + # reply的发送步骤 + if reply and reply['type']: + e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply})) + reply=e_context['reply'] + if not e_context.is_pass() and reply and reply['type']: + logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver'])) + self.send(reply, context['receiver']) def check_prefix(content, prefix_list): diff --git a/plugins/__init__.py b/plugins/__init__.py new file mode 100644 index 0000000..6137d4a --- /dev/null +++ b/plugins/__init__.py @@ -0,0 +1,9 @@ +from .plugin_manager import PluginManager +from .event import * +from .plugin import * + +instance = PluginManager() + +register = instance.register +# load_plugins = instance.load_plugins +# emit_event = instance.emit_event diff --git a/plugins/event.py b/plugins/event.py new file mode 100644 index 0000000..a65e548 --- /dev/null +++ b/plugins/event.py @@ -0,0 +1,49 @@ +# encoding:utf-8 + +from enum import Enum + + +class Event(Enum): + # ON_RECEIVE_MESSAGE = 1 # 收到消息 + + ON_HANDLE_CONTEXT = 2 # 处理消息前 + """ + e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复,初始为空 } + """ + + ON_DECORATE_REPLY = 3 # 得到回复后准备装饰 + """ + e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 } + """ + + ON_SEND_REPLY = 4 # 发送回复前 + """ + e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 } + """ + + # AFTER_SEND_REPLY = 5 # 发送回复后 + + +class EventAction(Enum): + CONTINUE = 1 # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑 + BREAK = 2 # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑 + BREAK_PASS = 3 # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑 + + +class EventContext: + def __init__(self, event, econtext=dict()): + self.event = event + self.econtext = econtext + self.action = EventAction.CONTINUE + + def __getitem__(self, key): + return self.econtext[key] + + def __setitem__(self, key, value): + self.econtext[key] = value + + def __delitem__(self, key): + del self.econtext[key] + + def is_pass(self): + return self.action == EventAction.BREAK_PASS diff --git a/plugins/plugin.py b/plugins/plugin.py new file mode 100644 index 0000000..865eecb --- /dev/null +++ b/plugins/plugin.py @@ -0,0 +1,3 @@ +class Plugin: + def __init__(self): + self.handlers = {} diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py new file mode 100644 index 0000000..d4cda12 --- /dev/null +++ b/plugins/plugin_manager.py @@ -0,0 +1,89 @@ +# encoding:utf-8 + +import importlib +import json +import os +from common.singleton import singleton +from .event import * +from .plugin import * +from common.log import logger + + +@singleton +class PluginManager: + def __init__(self): + self.plugins = {} + self.listening_plugins = {} + self.instances = {} + + def register(self, name: str, desc: str, version: str, author: str): + def wrapper(plugincls): + self.plugins[name] = plugincls + plugincls.name = name + plugincls.desc = desc + plugincls.version = version + plugincls.author = author + plugincls.enabled = True + logger.info("Plugin %s registered" % name) + return plugincls + return wrapper + + def save_config(self, pconf): + with open("plugins/plugins.json", "w", encoding="utf-8") as f: + json.dump(pconf, f, indent=4, ensure_ascii=False) + + def load_config(self): + logger.info("Loading plugins config...") + plugins_dir = "plugins" + for plugin_name in os.listdir(plugins_dir): + plugin_path = os.path.join(plugins_dir, plugin_name) + if os.path.isdir(plugin_path): + # 判断插件是否包含main.py文件 + main_module_path = os.path.join(plugin_path, "main.py") + if os.path.isfile(main_module_path): + # 导入插件的main + import_path = "{}.{}.main".format(plugins_dir, plugin_name) + main_module = importlib.import_module(import_path) + + modified = False + if os.path.exists("plugins/plugins.json"): + with open("plugins/plugins.json", "r", encoding="utf-8") as f: + pconf = json.load(f) + else: + modified = True + pconf = {"plugins": []} + for name, plugincls in self.plugins.items(): + if name not in [plugin["name"] for plugin in pconf["plugins"]]: + modified = True + logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name) + pconf["plugins"].append({"name": name, "enabled": True}) + if modified: + self.save_config(pconf) + return pconf + + def load_plugins(self): + pconf = self.load_config() + + for plugin in pconf["plugins"]: + name = plugin["name"] + enabled = plugin["enabled"] + self.plugins[name].enabled = enabled + + for name, plugincls in self.plugins.items(): + if plugincls.enabled: + if name not in self.instances: + instance = plugincls() + self.instances[name] = instance + for event in instance.handlers: + if event not in self.listening_plugins: + self.listening_plugins[event] = [] + self.listening_plugins[event].append(name) + + def emit_event(self, e_context: EventContext, *args, **kwargs): + if e_context.event in self.listening_plugins: + for name in self.listening_plugins[e_context.event]: + if e_context.action == EventAction.CONTINUE: + logger.debug("Plugin %s triggered by event %s" % (name,e_context.event)) + instance = self.instances[name] + instance.handlers[e_context.event](e_context, *args, **kwargs) + return e_context From d9b902f6ee92f07004ca32cf4025ae40c83c3950 Mon Sep 17 00:00:00 2001 From: lanvent Date: Sun, 12 Mar 2023 11:53:47 +0800 Subject: [PATCH 06/21] add a plugin example --- plugins/hello/__init__.py | 0 plugins/hello/main.py | 40 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) create mode 100644 plugins/hello/__init__.py create mode 100644 plugins/hello/main.py diff --git a/plugins/hello/__init__.py b/plugins/hello/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plugins/hello/main.py b/plugins/hello/main.py new file mode 100644 index 0000000..c7d8ee1 --- /dev/null +++ b/plugins/hello/main.py @@ -0,0 +1,40 @@ +# encoding:utf-8 + +import plugins +from plugins import * +from common.log import logger + + +@plugins.register(name="Hello", desc="A simple plugin that says hello", version="0.1", author="lanvent") +class Hello(Plugin): + def __init__(self): + super().__init__() + self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context + # self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context + logger.info("[hello] inited") + + def on_handle_context(self, e_context: EventContext): + + logger.debug("on_handle_context. content: %s" % e_context['context']['content']) + + if e_context['context']['content'] == "Hello": + e_context['reply']['type'] = "TEXT" + msg = e_context['context']['msg'] + if e_context['context']['isgroup']: + e_context['reply']['content'] = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group") + else: + e_context['reply']['content'] = "Hello, " + msg['User'].get('NickName', "My friend") + + e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 + + if e_context['context']['content'] == "Hi": + e_context['reply']['type'] = "TEXT" + e_context['reply']['content'] = "Hi" + e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑,一般会覆写reply + + if e_context['context']['content'] == "End": + # 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World" + if e_context['context']['type'] == "TEXT": + e_context['context']['type'] = "IMAGE_CREATE" + e_context['context']['content'] = "The World" + e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 From 73de429af1935aa95331c63218f8ecb7d7233d0f Mon Sep 17 00:00:00 2001 From: lanvent Date: Sun, 12 Mar 2023 12:57:27 +0800 Subject: [PATCH 07/21] import file with the same name as plugin --- plugins/hello/{main.py => hello.py} | 4 ++-- plugins/plugin_manager.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) rename plugins/hello/{main.py => hello.py} (92%) diff --git a/plugins/hello/main.py b/plugins/hello/hello.py similarity index 92% rename from plugins/hello/main.py rename to plugins/hello/hello.py index c7d8ee1..c380b96 100644 --- a/plugins/hello/main.py +++ b/plugins/hello/hello.py @@ -11,11 +11,11 @@ class Hello(Plugin): super().__init__() self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context # self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context - logger.info("[hello] inited") + logger.info("[Hello] inited") def on_handle_context(self, e_context: EventContext): - logger.debug("on_handle_context. content: %s" % e_context['context']['content']) + logger.debug("[Hello] on_handle_context. content: %s" % e_context['context']['content']) if e_context['context']['content'] == "Hello": e_context['reply']['type'] = "TEXT" diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py index d4cda12..b7d88e7 100644 --- a/plugins/plugin_manager.py +++ b/plugins/plugin_manager.py @@ -38,11 +38,11 @@ class PluginManager: for plugin_name in os.listdir(plugins_dir): plugin_path = os.path.join(plugins_dir, plugin_name) if os.path.isdir(plugin_path): - # 判断插件是否包含main.py文件 - main_module_path = os.path.join(plugin_path, "main.py") + # 判断插件是否包含同名.py文件 + main_module_path = os.path.join(plugin_path, plugin_name+".py") if os.path.isfile(main_module_path): - # 导入插件的main - import_path = "{}.{}.main".format(plugins_dir, plugin_name) + # 导入插件 + import_path = "{}.{}.{}".format(plugins_dir, plugin_name, plugin_name) main_module = importlib.import_module(import_path) modified = False @@ -63,7 +63,7 @@ class PluginManager: def load_plugins(self): pconf = self.load_config() - + logger.debug("plugins.json config={}" % pconf) for plugin in pconf["plugins"]: name = plugin["name"] enabled = plugin["enabled"] From 8847b5b6742509591f4085eea68595b871f63283 Mon Sep 17 00:00:00 2001 From: lanvent Date: Sun, 12 Mar 2023 13:25:23 +0800 Subject: [PATCH 08/21] create bot when first need --- bridge/bridge.py | 26 ++++++++++++++++++-------- plugins/plugin_manager.py | 2 +- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/bridge/bridge.py b/bridge/bridge.py index 81e5d73..e16d721 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -1,3 +1,4 @@ +from common.log import logger from bot import bot_factory from common.singleton import singleton from voice import voice_factory @@ -6,16 +7,24 @@ from voice import voice_factory @singleton class Bridge(object): def __init__(self): - self.bots = { - "chat": bot_factory.create_bot("chatGPT"), - "voice_to_text": voice_factory.create_voice("openai"), - # "text_to_voice": voice_factory.create_voice("baidu") + self.btype={ + "chat": "chatGPT", + "voice_to_text": "openai", + "text_to_voice": "baidu" } - try: - self.bots["text_to_voice"] = voice_factory.create_voice("baidu") - except ModuleNotFoundError as e: - print(e) + self.bots={} + def getbot(self,typename): + if self.bots.get(typename) is None: + logger.info("create bot {} for {}".format(self.btype[typename],typename)) + if typename == "text_to_voice": + self.bots[typename] = voice_factory.create_voice(self.btype[typename]) + elif typename == "voice_to_text": + self.bots[typename] = voice_factory.create_voice(self.btype[typename]) + elif typename == "chat": + self.bots[typename] = bot_factory.create_bot(self.btype[typename]) + return self.bots[typename] + # 以下所有函数需要得到一个reply字典,格式如下: # reply["type"] = "ERROR" / "TEXT" / "VOICE" / ... # reply["content"] = reply的内容 @@ -28,3 +37,4 @@ class Bridge(object): def fetch_text_to_voice(self, text): return self.bots["text_to_voice"].textToVoice(text) + diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py index b7d88e7..bf202f8 100644 --- a/plugins/plugin_manager.py +++ b/plugins/plugin_manager.py @@ -63,7 +63,7 @@ class PluginManager: def load_plugins(self): pconf = self.load_config() - logger.debug("plugins.json config={}" % pconf) + logger.debug("plugins.json config={}".format(pconf)) for plugin in pconf["plugins"]: name = plugin["name"] enabled = plugin["enabled"] From 475ada22e7d64a04cbcad0576feca0c859b8d23c Mon Sep 17 00:00:00 2001 From: lanvent Date: Sun, 12 Mar 2023 22:49:07 +0800 Subject: [PATCH 09/21] catch thread exception --- bridge/bridge.py | 11 +++++++---- channel/wechat/wechat_channel.py | 13 ++++++++----- plugins/hello/hello.py | 1 - plugins/plugin_manager.py | 2 +- 4 files changed, 16 insertions(+), 11 deletions(-) diff --git a/bridge/bridge.py b/bridge/bridge.py index e16d721..304046e 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -14,7 +14,7 @@ class Bridge(object): } self.bots={} - def getbot(self,typename): + def get_bot(self,typename): if self.bots.get(typename) is None: logger.info("create bot {} for {}".format(self.btype[typename],typename)) if typename == "text_to_voice": @@ -25,16 +25,19 @@ class Bridge(object): self.bots[typename] = bot_factory.create_bot(self.btype[typename]) return self.bots[typename] + def get_bot_type(self,typename): + return self.btype[typename] + # 以下所有函数需要得到一个reply字典,格式如下: # reply["type"] = "ERROR" / "TEXT" / "VOICE" / ... # reply["content"] = reply的内容 def fetch_reply_content(self, query, context): - return self.bots["chat"].reply(query, context) + return self.get_bot("chat").reply(query, context) def fetch_voice_to_text(self, voiceFile): - return self.bots["voice_to_text"].voiceToText(voiceFile) + return self.get_bot("voice_to_text").voiceToText(voiceFile) def fetch_text_to_voice(self, text): - return self.bots["text_to_voice"].textToVoice(text) + return self.get_bot("text_to_voice").textToVoice(text) diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index f436e48..73bb59a 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -19,7 +19,10 @@ import io thread_pool = ThreadPoolExecutor(max_workers=8) - +def thread_pool_callback(worker): + worker_exception = worker.exception() + if worker_exception: + logger.exception("Worker return exception: {}".format(worker_exception)) @itchat.msg_register(TEXT) def handler_single_msg(msg): @@ -69,7 +72,7 @@ class WechatChannel(Channel): context = {'isgroup': False, 'msg': msg, 'receiver': other_user_id} context['type'] = 'VOICE' context['session_id'] = other_user_id - thread_pool.submit(self.handle, context) + thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) def handle_text(self, msg): logger.debug("[WX]receive text msg: " + json.dumps(msg, ensure_ascii=False)) @@ -96,7 +99,7 @@ class WechatChannel(Channel): context['type'] = 'TEXT' context['content'] = content - thread_pool.submit(self.handle, context) + thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) def handle_group(self, msg): logger.debug("[WX]receive group msg: " + json.dumps(msg, ensure_ascii=False)) @@ -137,7 +140,7 @@ class WechatChannel(Channel): else: context['session_id'] = msg['ActualUserName'] - thread_pool.submit(self.handle, context) + thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 def send(self, reply, receiver): @@ -208,7 +211,7 @@ class WechatChannel(Channel): reply_text = conf().get("single_chat_reply_prefix", "")+reply_text reply['content'] = reply_text elif reply['type'] == 'ERROR' or reply['type'] == 'INFO': - reply['content'] = reply['type']+": " + reply['content'] + reply['content'] = reply['type']+":\n" + reply['content'] elif reply['type'] == 'IMAGE_URL' or reply['type'] == 'VOICE' or reply['type'] == 'IMAGE': pass else: diff --git a/plugins/hello/hello.py b/plugins/hello/hello.py index c380b96..144906b 100644 --- a/plugins/hello/hello.py +++ b/plugins/hello/hello.py @@ -10,7 +10,6 @@ class Hello(Plugin): def __init__(self): super().__init__() self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context - # self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context logger.info("[Hello] inited") def on_handle_context(self, e_context: EventContext): diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py index bf202f8..dc8e892 100644 --- a/plugins/plugin_manager.py +++ b/plugins/plugin_manager.py @@ -24,7 +24,7 @@ class PluginManager: plugincls.version = version plugincls.author = author plugincls.enabled = True - logger.info("Plugin %s registered" % name) + logger.info("Plugin %s_v%s registered" % (name, version)) return plugincls return wrapper From cee57e4ffc05c7ea6e848a90cb1a6fbbfafd1730 Mon Sep 17 00:00:00 2001 From: lanvent Date: Sun, 12 Mar 2023 23:05:28 +0800 Subject: [PATCH 10/21] plugin: add godcmd plugin --- plugins/godcmd/__init__.py | 0 plugins/godcmd/godcmd.py | 197 +++++++++++++++++++++++++++++++++++++ 2 files changed, 197 insertions(+) create mode 100644 plugins/godcmd/__init__.py create mode 100644 plugins/godcmd/godcmd.py diff --git a/plugins/godcmd/__init__.py b/plugins/godcmd/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py new file mode 100644 index 0000000..300353d --- /dev/null +++ b/plugins/godcmd/godcmd.py @@ -0,0 +1,197 @@ +# encoding:utf-8 + +import json +import os +import traceback +from typing import Tuple +from bridge.bridge import Bridge +from config import load_config +import plugins +from plugins import * +from common.log import logger + +# 定义指令集 +COMMANDS = { + "help": { + "alias": ["help", "帮助"], + "desc": "打印指令集合", + }, + "auth": { + "alias": ["auth", "认证"], + "args": ["口令"], + "desc": "管理员认证", + }, + # "id": { + # "alias": ["id", "用户"], + # "desc": "获取用户id", #目前无实际意义 + # }, + "reset": { + "alias": ["reset", "重置会话"], + "desc": "重置会话", + }, +} + +ADMIN_COMMANDS = { + "resume": { + "alias": ["resume", "恢复服务"], + "desc": "恢复服务", + }, + "stop": { + "alias": ["stop", "暂停服务"], + "desc": "暂停服务", + }, + "reconf": { + "alias": ["reconf", "重载配置"], + "desc": "重载配置(不包含插件配置)", + }, + "resetall": { + "alias": ["resetall", "重置所有会话"], + "desc": "重置所有会话", + }, + "debug": { + "alias": ["debug", "调试模式", "DEBUG"], + "desc": "开启机器调试日志", + }, +} +# 定义帮助函数 +def get_help_text(isadmin, isgroup): + help_text = "可用指令:\n" + for cmd, info in COMMANDS.items(): + if cmd=="auth" and (isadmin or isgroup): # 群聊不可认证 + continue + + alias=["#"+a for a in info['alias']] + help_text += f"{','.join(alias)} " + if 'args' in info: + args=["{"+a+"}" for a in info['args']] + help_text += f"{' '.join(args)} " + help_text += f": {info['desc']}\n" + if ADMIN_COMMANDS and isadmin: + help_text += "\n管理员指令:\n" + for cmd, info in ADMIN_COMMANDS.items(): + alias=["#"+a for a in info['alias']] + help_text += f"{','.join(alias)} " + help_text += f": {info['desc']}\n" + return help_text + +@plugins.register(name="Godcmd", desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证", version="1.0", author="lanvent") +class Godcmd(Plugin): + + def __init__(self): + super().__init__() + + curdir=os.path.dirname(__file__) + config_path=os.path.join(curdir,"config.json") + gconf=None + if not os.path.exists(config_path): + gconf={"password":"","admin_users":[]} + with open(config_path,"w") as f: + json.dump(gconf,f,indent=4) + else: + with open(config_path,"r") as f: + gconf=json.load(f) + + self.password = gconf["password"] + self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证 TODO: 用户名每次都会变,目前不可用 + self.isrunning = True # 机器人是否运行中 + + self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context + logger.info("[Godcmd] inited") + + + def on_handle_context(self, e_context: EventContext): + content = e_context['context']['content'] + context_type = e_context['context']['type'] + logger.debug("[Godcmd] on_handle_context. content: %s" % content) + + if content.startswith("#") and context_type == "TEXT": + # msg = e_context['context']['msg'] + user = e_context['context']['receiver'] + session_id = e_context['context']['session_id'] + isgroup = e_context['context']['isgroup'] + bottype = Bridge().get_bot_type("chat") + bot = Bridge().get_bot("chat") + # 将命令和参数分割 + command_parts = content[1:].split(" ") + cmd = command_parts[0] + args = command_parts[1:] + isadmin=False + if user in self.admin_users: + isadmin=True + ok=False + result="string" + if any(cmd in info['alias'] for info in COMMANDS.values()): + cmd = next(c for c, info in COMMANDS.items() if cmd in info['alias']) + if cmd == "auth": + ok, result = self.authenticate(user, args, isadmin, isgroup) + elif cmd == "help": + ok, result = True, get_help_text(isadmin, isgroup) + elif cmd == "id": + ok, result = True, f"用户id=\n{user}" + elif cmd == "reset": + if bottype == "chatGPT": + bot.sessions.clear_session(session_id) + ok, result = True, "会话已重置" + else: + ok, result = False, "当前机器人不支持重置会话" + logger.debug("[Godcmd] command: %s by %s" % (cmd, user)) + elif any(cmd in info['alias'] for info in ADMIN_COMMANDS.values()): + if isadmin: + if isgroup: + ok, result = False, "群聊不可执行管理员指令" + else: + cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info['alias']) + if cmd == "stop": + self.isrunning = False + ok, result = True, "服务已暂停" + elif cmd == "resume": + self.isrunning = True + ok, result = True, "服务已恢复" + elif cmd == "reconf": + load_config() + ok, result = True, "配置已重载" + elif cmd == "resetall": + if bottype == "chatGPT": + bot.sessions.clear_all_session() + ok, result = True, "重置所有会话成功" + else: + ok, result = False, "当前机器人不支持重置会话" + elif cmd == "debug": + logger.setLevel('DEBUG') + ok, result = True, "DEBUG模式已开启" + logger.debug("[Godcmd] admin command: %s by %s" % (cmd, user)) + else: + ok, result = False, "需要管理员权限才能执行该指令" + else: + ok, result = False, f"未知指令:{cmd}\n查看指令列表请输入#help \n" + + reply = {} + if ok: + reply["type"] = "INFO" + else: + reply["type"] = "ERROR" + reply["content"] = result + e_context['reply'] = reply + + e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 + elif not self.isrunning: + e_context.action = EventAction.BREAK_PASS + else: + e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 + + def authenticate(self, userid, args, isadmin, isgroup) -> Tuple[bool,str] : + if isgroup: + return False,"请勿在群聊中认证" + + if isadmin: + return False,"管理员账号无需认证" + + if len(args) != 1: + return False,"请提供口令" + password = args[0] + if password == self.password: + self.admin_users.append(userid) + return True,"认证成功" + else: + return False,"认证失败" + From 8d2e81815c104f2082dad4b695c6ccdbac2d6240 Mon Sep 17 00:00:00 2001 From: lanvent Date: Mon, 13 Mar 2023 00:12:34 +0800 Subject: [PATCH 11/21] compatible for voice --- channel/wechat/wechat_channel.py | 7 +++++-- plugins/godcmd/godcmd.py | 12 ++++++----- plugins/hello/hello.py | 35 ++++++++++++++++++-------------- voice/baidu/baidu_voice.py | 5 +++-- voice/google/google_voice.py | 27 +++++++++++++++--------- voice/openai/openai_voice.py | 18 ++++++++++------ 6 files changed, 64 insertions(+), 40 deletions(-) diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index 73bb59a..0ba923f 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -71,6 +71,7 @@ class WechatChannel(Channel): if from_user_id == other_user_id: context = {'isgroup': False, 'msg': msg, 'receiver': other_user_id} context['type'] = 'VOICE' + context['content'] = msg['FileName'] context['session_id'] = other_user_id thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) @@ -183,11 +184,13 @@ class WechatChannel(Channel): reply = super().build_reply_content(context['content'], context) elif context['type'] == 'VOICE': msg = context['msg'] - file_name = TmpDir().path() + msg['FileName'] + file_name = TmpDir().path() + context['content'] msg.download(file_name) reply = super().build_voice_to_text(file_name) if reply['type'] != 'ERROR' and reply['type'] != 'INFO': - reply = super().build_reply_content(reply['content'], context) + context['content'] = reply['content'] # 语音转文字后,将文字内容作为新的context + context['type'] = reply['type'] + reply = super().build_reply_content(context['content'], context) if reply['type'] == 'TEXT': if conf().get('voice_reply_voice'): reply = super().build_text_to_voice(reply['content']) diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index 300353d..3dd8760 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -100,11 +100,15 @@ class Godcmd(Plugin): def on_handle_context(self, e_context: EventContext): - content = e_context['context']['content'] context_type = e_context['context']['type'] - logger.debug("[Godcmd] on_handle_context. content: %s" % content) + if context_type != "TEXT": + if not self.isrunning: + e_context.action = EventAction.BREAK_PASS + return - if content.startswith("#") and context_type == "TEXT": + content = e_context['context']['content'] + logger.debug("[Godcmd] on_handle_context. content: %s" % content) + if content.startswith("#"): # msg = e_context['context']['msg'] user = e_context['context']['receiver'] session_id = e_context['context']['session_id'] @@ -176,8 +180,6 @@ class Godcmd(Plugin): e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 elif not self.isrunning: e_context.action = EventAction.BREAK_PASS - else: - e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 def authenticate(self, userid, args, isadmin, isgroup) -> Tuple[bool,str] : if isgroup: diff --git a/plugins/hello/hello.py b/plugins/hello/hello.py index 144906b..ca1d257 100644 --- a/plugins/hello/hello.py +++ b/plugins/hello/hello.py @@ -14,26 +14,31 @@ class Hello(Plugin): def on_handle_context(self, e_context: EventContext): - logger.debug("[Hello] on_handle_context. content: %s" % e_context['context']['content']) - - if e_context['context']['content'] == "Hello": - e_context['reply']['type'] = "TEXT" + if e_context['context']['type'] != "TEXT": + return + + content = e_context['context']['content'] + logger.debug("[Hello] on_handle_context. content: %s" % content) + if content == "Hello": + reply = {} + reply['type'] = "TEXT" msg = e_context['context']['msg'] if e_context['context']['isgroup']: - e_context['reply']['content'] = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group") + reply['content'] = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group") else: - e_context['reply']['content'] = "Hello, " + msg['User'].get('NickName', "My friend") - + reply['content'] = "Hello, " + msg['User'].get('NickName', "My friend") + e_context['reply'] = reply e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 - if e_context['context']['content'] == "Hi": - e_context['reply']['type'] = "TEXT" - e_context['reply']['content'] = "Hi" + if content == "Hi": + reply={} + reply['type'] = "TEXT" + reply['content'] = "Hi" + e_context['reply'] = reply e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑,一般会覆写reply - if e_context['context']['content'] == "End": + if content == "End": # 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World" - if e_context['context']['type'] == "TEXT": - e_context['context']['type'] = "IMAGE_CREATE" - e_context['context']['content'] = "The World" - e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 + e_context['context']['type'] = "IMAGE_CREATE" + content = "The World" + e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 diff --git a/voice/baidu/baidu_voice.py b/voice/baidu/baidu_voice.py index d99db37..adab169 100644 --- a/voice/baidu/baidu_voice.py +++ b/voice/baidu/baidu_voice.py @@ -30,7 +30,8 @@ class BaiduVoice(Voice): with open(fileName, 'wb') as f: f.write(result) logger.info('[Baidu] textToVoice text={} voice file name={}'.format(text, fileName)) - return fileName + reply = {"type": "VOICE", "content": fileName} else: logger.error('[Baidu] textToVoice error={}'.format(result)) - return None + reply = {"type": "ERROR", "content": "抱歉,语音合成失败"} + return reply diff --git a/voice/google/google_voice.py b/voice/google/google_voice.py index 8e339f2..6c00892 100644 --- a/voice/google/google_voice.py +++ b/voice/google/google_voice.py @@ -32,20 +32,27 @@ class GoogleVoice(Voice): ' -acodec pcm_s16le -ac 1 -ar 16000 ' + new_file, shell=True) with speech_recognition.AudioFile(new_file) as source: audio = self.recognizer.record(source) + reply = {} try: text = self.recognizer.recognize_google(audio, language='zh-CN') logger.info( '[Google] voiceToText text={} voice file name={}'.format(text, voice_file)) - return text + reply = {"type": "TEXT", "content": text} except speech_recognition.UnknownValueError: - return "抱歉,我听不懂。" + reply = {"type": "ERROR", "content": "抱歉,我听不懂"} except speech_recognition.RequestError as e: - return "抱歉,无法连接到 Google 语音识别服务;{0}".format(e) - + reply = {"type": "ERROR", "content": "抱歉,无法连接到 Google 语音识别服务;{0}".format(e)} + finally: + return reply def textToVoice(self, text): - textFile = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3' - self.engine.save_to_file(text, textFile) - self.engine.runAndWait() - logger.info( - '[Google] textToVoice text={} voice file name={}'.format(text, textFile)) - return textFile + try: + textFile = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3' + self.engine.save_to_file(text, textFile) + self.engine.runAndWait() + logger.info( + '[Google] textToVoice text={} voice file name={}'.format(text, textFile)) + reply = {"type": "VOICE", "content": textFile} + except Exception as e: + reply = {"type": "ERROR", "content": str(e)} + finally: + return reply diff --git a/voice/openai/openai_voice.py b/voice/openai/openai_voice.py index 475aac6..3b77c52 100644 --- a/voice/openai/openai_voice.py +++ b/voice/openai/openai_voice.py @@ -16,12 +16,18 @@ class OpenaiVoice(Voice): def voiceToText(self, voice_file): logger.debug( '[Openai] voice file name={}'.format(voice_file)) - file = open(voice_file, "rb") - reply = openai.Audio.transcribe("whisper-1", file) - text = reply["text"] - logger.info( - '[Openai] voiceToText text={} voice file name={}'.format(text, voice_file)) - return text + reply={} + try: + file = open(voice_file, "rb") + result = openai.Audio.transcribe("whisper-1", file) + text = result["text"] + reply = {"type": "TEXT", "content": text} + logger.info( + '[Openai] voiceToText text={} voice file name={}'.format(text, voice_file)) + except Exception as e: + reply = {"type": "ERROR", "content": str(e)} + finally: + return reply def textToVoice(self, text): pass From cb7bf446e3b9e55b3428ab8859de2b3d6c6238b2 Mon Sep 17 00:00:00 2001 From: lanvent Date: Mon, 13 Mar 2023 01:50:37 +0800 Subject: [PATCH 12/21] plugin: godcmd support manage plugins --- plugins/godcmd/godcmd.py | 56 ++++++++++++++++++++++++++ plugins/plugin_manager.py | 82 +++++++++++++++++++++++++++++---------- 2 files changed, 118 insertions(+), 20 deletions(-) diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index 3dd8760..2f57d32 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -48,6 +48,24 @@ ADMIN_COMMANDS = { "alias": ["resetall", "重置所有会话"], "desc": "重置所有会话", }, + "scanp": { + "alias": ["scanp", "扫描插件"], + "desc": "扫描插件目录是否有新插件", + }, + "plist": { + "alias": ["plist", "插件"], + "desc": "打印当前插件列表", + }, + "enablep": { + "alias": ["enablep", "启用插件"], + "args": ["插件名"], + "desc": "启用指定插件", + }, + "disablep": { + "alias": ["disablep", "禁用插件"], + "args": ["插件名"], + "desc": "禁用指定插件", + }, "debug": { "alias": ["debug", "调试模式", "DEBUG"], "desc": "开启机器调试日志", @@ -163,6 +181,44 @@ class Godcmd(Plugin): elif cmd == "debug": logger.setLevel('DEBUG') ok, result = True, "DEBUG模式已开启" + elif cmd == "plist": + plugins = PluginManager().list_plugins() + ok = True + result = "插件列表:\n" + for name,plugincls in plugins.items(): + result += f"{name}_v{plugincls.version} - " + if plugincls.enabled: + result += "已启用\n" + else: + result += "未启用\n" + elif cmd == "scanp": + new_plugins = PluginManager().scan_plugins() + ok, result = True, "插件扫描完成" + if len(new_plugins) >0 : + PluginManager().activate_plugins() + result += "\n发现新插件:\n" + result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins]) + else : + result +=", 未发现新插件" + elif cmd == "enablep": + if len(args) != 1: + ok, result = False, "请提供插件名" + else: + ok = PluginManager().enable_plugin(args[0]) + if ok: + result = "插件已启用" + else: + result = "插件不存在" + elif cmd == "disablep": + if len(args) != 1: + ok, result = False, "请提供插件名" + else: + ok = PluginManager().disable_plugin(args[0]) + if ok: + result = "插件已禁用" + else: + result = "插件不存在" + logger.debug("[Godcmd] admin command: %s by %s" % (cmd, user)) else: ok, result = False, "需要管理员权限才能执行该指令" diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py index dc8e892..630041a 100644 --- a/plugins/plugin_manager.py +++ b/plugins/plugin_manager.py @@ -15,6 +15,7 @@ class PluginManager: self.plugins = {} self.listening_plugins = {} self.instances = {} + self.pconf = {} def register(self, name: str, desc: str, version: str, author: str): def wrapper(plugincls): @@ -28,12 +29,27 @@ class PluginManager: return plugincls return wrapper - def save_config(self, pconf): + def save_config(self): with open("plugins/plugins.json", "w", encoding="utf-8") as f: - json.dump(pconf, f, indent=4, ensure_ascii=False) + json.dump(self.pconf, f, indent=4, ensure_ascii=False) def load_config(self): logger.info("Loading plugins config...") + + modified = False + if os.path.exists("plugins/plugins.json"): + with open("plugins/plugins.json", "r", encoding="utf-8") as f: + pconf = json.load(f) + else: + modified = True + pconf = {"plugins": []} + self.pconf = pconf + if modified: + self.save_config() + return pconf + + def scan_plugins(self): + logger.info("Scaning plugins ...") plugins_dir = "plugins" for plugin_name in os.listdir(plugins_dir): plugin_path = os.path.join(plugins_dir, plugin_name) @@ -44,31 +60,20 @@ class PluginManager: # 导入插件 import_path = "{}.{}.{}".format(plugins_dir, plugin_name, plugin_name) main_module = importlib.import_module(import_path) - + pconf = self.pconf + new_plugins = [] modified = False - if os.path.exists("plugins/plugins.json"): - with open("plugins/plugins.json", "r", encoding="utf-8") as f: - pconf = json.load(f) - else: - modified = True - pconf = {"plugins": []} for name, plugincls in self.plugins.items(): if name not in [plugin["name"] for plugin in pconf["plugins"]]: + new_plugins.append(plugincls) modified = True logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name) pconf["plugins"].append({"name": name, "enabled": True}) if modified: - self.save_config(pconf) - return pconf - - def load_plugins(self): - pconf = self.load_config() - logger.debug("plugins.json config={}".format(pconf)) - for plugin in pconf["plugins"]: - name = plugin["name"] - enabled = plugin["enabled"] - self.plugins[name].enabled = enabled + self.save_config() + return new_plugins + def activate_plugins(self): for name, plugincls in self.plugins.items(): if plugincls.enabled: if name not in self.instances: @@ -79,11 +84,48 @@ class PluginManager: self.listening_plugins[event] = [] self.listening_plugins[event].append(name) + def load_plugins(self): + self.load_config() + self.scan_plugins() + pconf = self.pconf + logger.debug("plugins.json config={}".format(pconf)) + for plugin in pconf["plugins"]: + name = plugin["name"] + enabled = plugin["enabled"] + self.plugins[name].enabled = enabled + self.activate_plugins() + def emit_event(self, e_context: EventContext, *args, **kwargs): if e_context.event in self.listening_plugins: for name in self.listening_plugins[e_context.event]: - if e_context.action == EventAction.CONTINUE: + if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE: logger.debug("Plugin %s triggered by event %s" % (name,e_context.event)) instance = self.instances[name] instance.handlers[e_context.event](e_context, *args, **kwargs) return e_context + + def enable_plugin(self,name): + if name not in self.plugins: + return False + if not self.plugins[name].enabled : + self.plugins[name].enabled = True + idx = next(i for i in range(len(self.pconf['plugins'])) if self.pconf["plugins"][i]['name'] == name) + self.pconf["plugins"][idx]["enabled"] = True + self.save_config() + self.activate_plugins() + return True + return True + + def disable_plugin(self,name): + if name not in self.plugins: + return False + if self.plugins[name].enabled : + self.plugins[name].enabled = False + idx = next(i for i in range(len(self.pconf['plugins'])) if self.pconf["plugins"][i]['name'] == name) + self.pconf["plugins"][idx]["enabled"] = False + self.save_config() + return True + return True + + def list_plugins(self): + return self.plugins \ No newline at end of file From 1dc3f85a66d0063a68494f257c365566d5e9c81c Mon Sep 17 00:00:00 2001 From: lanvent Date: Mon, 13 Mar 2023 15:32:28 +0800 Subject: [PATCH 13/21] plugin: support priority to decide trigger order --- common/sorted_dict.py | 65 +++++++++++++++++++++++++++++++++++++++ plugins/godcmd/godcmd.py | 20 ++++++++++-- plugins/hello/hello.py | 2 +- plugins/plugin_manager.py | 52 ++++++++++++++++++++++--------- 4 files changed, 120 insertions(+), 19 deletions(-) create mode 100644 common/sorted_dict.py diff --git a/common/sorted_dict.py b/common/sorted_dict.py new file mode 100644 index 0000000..a918a0c --- /dev/null +++ b/common/sorted_dict.py @@ -0,0 +1,65 @@ +import heapq + + +class SortedDict(dict): + def __init__(self, sort_func=lambda k, v: k, init_dict=None, reverse=False): + if init_dict is None: + init_dict = [] + if isinstance(init_dict, dict): + init_dict = init_dict.items() + self.sort_func = sort_func + self.sorted_keys = None + self.reverse = reverse + self.heap = [] + for k, v in init_dict: + self[k] = v + + def __setitem__(self, key, value): + if key in self: + super().__setitem__(key, value) + for i, (priority, k) in enumerate(self.heap): + if k == key: + self.heap[i] = (self.sort_func(key, value), key) + heapq.heapify(self.heap) + break + self.sorted_keys = None + else: + super().__setitem__(key, value) + heapq.heappush(self.heap, (self.sort_func(key, value), key)) + self.sorted_keys = None + + def __delitem__(self, key): + super().__delitem__(key) + for i, (priority, k) in enumerate(self.heap): + if k == key: + del self.heap[i] + heapq.heapify(self.heap) + break + self.sorted_keys = None + + def keys(self): + if self.sorted_keys is None: + self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)] + return self.sorted_keys + + def items(self): + if self.sorted_keys is None: + self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)] + sorted_items = [(k, self[k]) for k in self.sorted_keys] + return sorted_items + + def _update_heap(self, key): + for i, (priority, k) in enumerate(self.heap): + if k == key: + new_priority = self.sort_func(key, self[key]) + if new_priority != priority: + self.heap[i] = (new_priority, key) + heapq.heapify(self.heap) + self.sorted_keys = None + break + + def __iter__(self): + return iter(self.keys()) + + def __repr__(self): + return f'{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})' diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index 2f57d32..3bd24dd 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -56,6 +56,11 @@ ADMIN_COMMANDS = { "alias": ["plist", "插件"], "desc": "打印当前插件列表", }, + "setpri": { + "alias": ["setpri", "设置插件优先级"], + "args": ["插件名", "优先级"], + "desc": "设置指定插件的优先级,越大越优先", + }, "enablep": { "alias": ["enablep", "启用插件"], "args": ["插件名"], @@ -92,7 +97,7 @@ def get_help_text(isadmin, isgroup): help_text += f": {info['desc']}\n" return help_text -@plugins.register(name="Godcmd", desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证", version="1.0", author="lanvent") +@plugins.register(name="Godcmd", desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证", version="1.0", author="lanvent", desire_priority= 999) class Godcmd(Plugin): def __init__(self): @@ -186,7 +191,7 @@ class Godcmd(Plugin): ok = True result = "插件列表:\n" for name,plugincls in plugins.items(): - result += f"{name}_v{plugincls.version} - " + result += f"{name}_v{plugincls.version} {plugincls.priority} - " if plugincls.enabled: result += "已启用\n" else: @@ -194,12 +199,21 @@ class Godcmd(Plugin): elif cmd == "scanp": new_plugins = PluginManager().scan_plugins() ok, result = True, "插件扫描完成" + PluginManager().activate_plugins() if len(new_plugins) >0 : - PluginManager().activate_plugins() result += "\n发现新插件:\n" result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins]) else : result +=", 未发现新插件" + elif cmd == "setpri": + if len(args) != 2: + ok, result = False, "请提供插件名和优先级" + else: + ok = PluginManager().set_plugin_priority(args[0], int(args[1])) + if ok: + result = "插件" + args[0] + "优先级已设置为" + args[1] + else: + result = "插件不存在" elif cmd == "enablep": if len(args) != 1: ok, result = False, "请提供插件名" diff --git a/plugins/hello/hello.py b/plugins/hello/hello.py index ca1d257..1eb409e 100644 --- a/plugins/hello/hello.py +++ b/plugins/hello/hello.py @@ -5,7 +5,7 @@ from plugins import * from common.log import logger -@plugins.register(name="Hello", desc="A simple plugin that says hello", version="0.1", author="lanvent") +@plugins.register(name="Hello", desc="A simple plugin that says hello", version="0.1", author="lanvent", desire_priority= -1) class Hello(Plugin): def __init__(self): super().__init__() diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py index 630041a..107c63b 100644 --- a/plugins/plugin_manager.py +++ b/plugins/plugin_manager.py @@ -4,6 +4,7 @@ import importlib import json import os from common.singleton import singleton +from common.sorted_dict import SortedDict from .event import * from .plugin import * from common.log import logger @@ -12,19 +13,20 @@ from common.log import logger @singleton class PluginManager: def __init__(self): - self.plugins = {} + self.plugins = SortedDict(lambda k,v: v.priority,reverse=True) self.listening_plugins = {} self.instances = {} self.pconf = {} - def register(self, name: str, desc: str, version: str, author: str): + def register(self, name: str, desc: str, version: str, author: str, desire_priority: int = 0): def wrapper(plugincls): - self.plugins[name] = plugincls plugincls.name = name plugincls.desc = desc plugincls.version = version plugincls.author = author + plugincls.priority = desire_priority plugincls.enabled = True + self.plugins[name] = plugincls logger.info("Plugin %s_v%s registered" % (name, version)) return plugincls return wrapper @@ -40,9 +42,10 @@ class PluginManager: if os.path.exists("plugins/plugins.json"): with open("plugins/plugins.json", "r", encoding="utf-8") as f: pconf = json.load(f) + pconf['plugins'] = SortedDict(lambda k,v: v["priority"],pconf['plugins'],reverse=True) else: modified = True - pconf = {"plugins": []} + pconf = {"plugins": SortedDict(lambda k,v: v["priority"],reverse=True)} self.pconf = pconf if modified: self.save_config() @@ -64,16 +67,24 @@ class PluginManager: new_plugins = [] modified = False for name, plugincls in self.plugins.items(): - if name not in [plugin["name"] for plugin in pconf["plugins"]]: + if name not in pconf["plugins"]: new_plugins.append(plugincls) modified = True logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name) - pconf["plugins"].append({"name": name, "enabled": True}) + pconf["plugins"][name] = {"enabled": plugincls.enabled, "priority": plugincls.priority} + else: + self.plugins[name].enabled = pconf["plugins"][name]["enabled"] + self.plugins[name].priority = pconf["plugins"][name]["priority"] + self.plugins._update_heap(name) # 更新下plugins中的顺序 if modified: self.save_config() return new_plugins - def activate_plugins(self): + def refresh_order(self): + for event in self.listening_plugins.keys(): + self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True) + + def activate_plugins(self): # 生成新开启的插件实例 for name, plugincls in self.plugins.items(): if plugincls.enabled: if name not in self.instances: @@ -83,16 +94,16 @@ class PluginManager: if event not in self.listening_plugins: self.listening_plugins[event] = [] self.listening_plugins[event].append(name) + self.refresh_order() def load_plugins(self): self.load_config() self.scan_plugins() pconf = self.pconf logger.debug("plugins.json config={}".format(pconf)) - for plugin in pconf["plugins"]: - name = plugin["name"] - enabled = plugin["enabled"] - self.plugins[name].enabled = enabled + for name,plugin in pconf["plugins"].items(): + if name not in self.plugins: + logger.error("Plugin %s not found, but found in plugins.json" % name) self.activate_plugins() def emit_event(self, e_context: EventContext, *args, **kwargs): @@ -104,13 +115,25 @@ class PluginManager: instance.handlers[e_context.event](e_context, *args, **kwargs) return e_context + def set_plugin_priority(self,name,priority): + if name not in self.plugins: + return False + if self.plugins[name].priority == priority: + return True + self.plugins[name].priority = priority + self.plugins._update_heap(name) + self.pconf["plugins"][name]["priority"] = priority + self.pconf["plugins"]._update_heap(name) + self.save_config() + self.refresh_order() + return True + def enable_plugin(self,name): if name not in self.plugins: return False if not self.plugins[name].enabled : self.plugins[name].enabled = True - idx = next(i for i in range(len(self.pconf['plugins'])) if self.pconf["plugins"][i]['name'] == name) - self.pconf["plugins"][idx]["enabled"] = True + self.pconf["plugins"][name]["enabled"] = True self.save_config() self.activate_plugins() return True @@ -121,8 +144,7 @@ class PluginManager: return False if self.plugins[name].enabled : self.plugins[name].enabled = False - idx = next(i for i in range(len(self.pconf['plugins'])) if self.pconf["plugins"][i]['name'] == name) - self.pconf["plugins"][idx]["enabled"] = False + self.pconf["plugins"][name]["enabled"] = False self.save_config() return True return True From ad6ae0b32a15b7cd75e1060971f95461408c9944 Mon Sep 17 00:00:00 2001 From: lanvent Date: Mon, 13 Mar 2023 19:44:24 +0800 Subject: [PATCH 14/21] refactor: use enum to specify type --- bot/bot.py | 6 +- bot/chatgpt/chat_gpt_bot.py | 27 ++++---- bridge/bridge.py | 11 ++-- bridge/context.py | 42 +++++++++++++ bridge/reply.py | 22 +++++++ channel/channel.py | 8 ++- channel/wechat/wechat_channel.py | 101 +++++++++++++++--------------- channel/wechat/wechaty_channel.py | 25 +++----- plugins/godcmd/godcmd.py | 16 ++--- plugins/hello/hello.py | 22 ++++--- voice/baidu/baidu_voice.py | 5 +- voice/google/google_voice.py | 12 ++-- voice/openai/openai_voice.py | 6 +- 13 files changed, 185 insertions(+), 118 deletions(-) create mode 100644 bridge/context.py create mode 100644 bridge/reply.py diff --git a/bot/bot.py b/bot/bot.py index 850ba3b..fd56e50 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -3,8 +3,12 @@ Auto-replay chat robot abstract class """ +from bridge.context import Context +from bridge.reply import Reply + + class Bot(object): - def reply(self, query, context=None): + def reply(self, query, context : Context =None) -> Reply: """ bot auto-reply content :param req: received message diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index 2c8567d..b2e062d 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -1,6 +1,8 @@ # encoding:utf-8 from bot.bot import Bot +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType from config import conf, load_config from common.log import logger from common.expired_dict import ExpiredDict @@ -19,22 +21,19 @@ class ChatGPTBot(Bot): def reply(self, query, context=None): # acquire reply content - if context['type'] == 'TEXT': + if context.type == ContextType.TEXT: logger.info("[OPEN_AI] query={}".format(query)) session_id = context['session_id'] reply = None if query == '#清除记忆': self.sessions.clear_session(session_id) - reply = {'type': 'INFO', 'content': '记忆已清除'} + reply = Reply(ReplyType.INFO, '记忆已清除') elif query == '#清除所有': self.sessions.clear_all_session() - reply = {'type': 'INFO', 'content': '所有人记忆已清除'} + reply = Reply(ReplyType.INFO, '所有人记忆已清除') elif query == '#更新配置': load_config() - reply = {'type': 'INFO', 'content': '配置已更新'} - elif query == '#DEBUG': - logger.setLevel('DEBUG') - reply = {'type': 'INFO', 'content': 'DEBUG模式已开启'} + reply = Reply(ReplyType.INFO, '配置已更新') if reply: return reply session = self.sessions.build_session_query(query, session_id) @@ -47,25 +46,25 @@ class ChatGPTBot(Bot): reply_content = self.reply_text(session, session_id, 0) logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"])) if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0: - reply = {'type': 'ERROR', 'content': reply_content['content']} + reply = Reply(ReplyType.ERROR, reply_content['content']) elif reply_content["completion_tokens"] > 0: self.sessions.save_session(reply_content["content"], session_id, reply_content["total_tokens"]) - reply={'type':'TEXT', 'content':reply_content["content"]} + reply = Reply(ReplyType.TEXT, reply_content["content"]) else: - reply = {'type': 'ERROR', 'content': reply_content['content']} + reply = Reply(ReplyType.ERROR, reply_content['content']) logger.debug("[OPEN_AI] reply {} used 0 tokens.".format(reply_content)) return reply - elif context['type'] == 'IMAGE_CREATE': + elif context.type == ContextType.IMAGE_CREATE: ok, retstring = self.create_img(query, 0) reply = None if ok: - reply = {'type': 'IMAGE_URL', 'content': retstring} + reply = Reply(ReplyType.IMAGE_URL, retstring) else: - reply = {'type': 'ERROR', 'content': retstring} + reply = Reply(ReplyType.ERROR, retstring) return reply else: - reply= {'type':'ERROR', 'content':'Bot不支持处理{}类型的消息'.format(context['type'])} + reply = Reply(ReplyType.ERROR, 'Bot不支持处理{}类型的消息'.format(context.type)) return reply def reply_text(self, session, session_id, retry_count=0) -> dict: diff --git a/bridge/bridge.py b/bridge/bridge.py index 304046e..2b67a8a 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -1,3 +1,5 @@ +from bridge.context import Context +from bridge.reply import Reply from common.log import logger from bot import bot_factory from common.singleton import singleton @@ -28,16 +30,13 @@ class Bridge(object): def get_bot_type(self,typename): return self.btype[typename] - # 以下所有函数需要得到一个reply字典,格式如下: - # reply["type"] = "ERROR" / "TEXT" / "VOICE" / ... - # reply["content"] = reply的内容 - def fetch_reply_content(self, query, context): + def fetch_reply_content(self, query, context : Context) -> Reply: return self.get_bot("chat").reply(query, context) - def fetch_voice_to_text(self, voiceFile): + def fetch_voice_to_text(self, voiceFile) -> Reply: return self.get_bot("voice_to_text").voiceToText(voiceFile) - def fetch_text_to_voice(self, text): + def fetch_text_to_voice(self, text) -> Reply: return self.get_bot("text_to_voice").textToVoice(text) diff --git a/bridge/context.py b/bridge/context.py new file mode 100644 index 0000000..1fbe4d4 --- /dev/null +++ b/bridge/context.py @@ -0,0 +1,42 @@ +# encoding:utf-8 + +from enum import Enum + +class ContextType (Enum): + TEXT = 1 # 文本消息 + VOICE = 2 # 音频消息 + IMAGE_CREATE = 3 # 创建图片命令 + + def __str__(self): + return self.name +class Context: + def __init__(self, type : ContextType = None , content = None, kwargs = dict()): + self.type = type + self.content = content + self.kwargs = kwargs + def __getitem__(self, key): + if key == 'type': + return self.type + elif key == 'content': + return self.content + else: + return self.kwargs[key] + + def __setitem__(self, key, value): + if key == 'type': + self.type = value + elif key == 'content': + self.content = value + else: + self.kwargs[key] = value + + def __delitem__(self, key): + if key == 'type': + self.type = None + elif key == 'content': + self.content = None + else: + del self.kwargs[key] + + def __str__(self): + return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs) \ No newline at end of file diff --git a/bridge/reply.py b/bridge/reply.py new file mode 100644 index 0000000..c6bcd54 --- /dev/null +++ b/bridge/reply.py @@ -0,0 +1,22 @@ + +# encoding:utf-8 + +from enum import Enum + +class ReplyType(Enum): + TEXT = 1 # 文本 + VOICE = 2 # 音频文件 + IMAGE = 3 # 图片文件 + IMAGE_URL = 4 # 图片URL + + INFO = 9 + ERROR = 10 + def __str__(self): + return self.name + +class Reply: + def __init__(self, type : ReplyType = None , content = None): + self.type = type + self.content = content + def __str__(self): + return "Reply(type={}, content={})".format(self.type, self.content) \ No newline at end of file diff --git a/channel/channel.py b/channel/channel.py index a1395c4..62e1ad3 100644 --- a/channel/channel.py +++ b/channel/channel.py @@ -3,6 +3,8 @@ Message sending channel abstract class """ from bridge.bridge import Bridge +from bridge.context import Context +from bridge.reply import Reply class Channel(object): def startup(self): @@ -27,11 +29,11 @@ class Channel(object): """ raise NotImplementedError - def build_reply_content(self, query, context=None): + def build_reply_content(self, query, context : Context=None) -> Reply: return Bridge().fetch_reply_content(query, context) - def build_voice_to_text(self, voice_file): + def build_voice_to_text(self, voice_file) -> Reply: return Bridge().fetch_voice_to_text(voice_file) - def build_text_to_voice(self, text): + def build_text_to_voice(self, text) -> Reply: return Bridge().fetch_text_to_voice(text) diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index 0ba923f..eff788d 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -7,6 +7,8 @@ wechat channel import itchat import json from itchat.content import * +from bridge.reply import * +from bridge.context import * from channel.channel import Channel from concurrent.futures import ThreadPoolExecutor from common.log import logger @@ -69,10 +71,8 @@ class WechatChannel(Channel): from_user_id = msg['FromUserName'] other_user_id = msg['User']['UserName'] if from_user_id == other_user_id: - context = {'isgroup': False, 'msg': msg, 'receiver': other_user_id} - context['type'] = 'VOICE' - context['content'] = msg['FileName'] - context['session_id'] = other_user_id + context = Context(ContextType.VOICE,msg['FileName']) + context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id} thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) def handle_text(self, msg): @@ -89,17 +89,17 @@ class WechatChannel(Channel): content = content.replace(match_prefix, '', 1).strip() else: return - context = {'isgroup': False, 'msg': msg, 'receiver': other_user_id} - context['session_id'] = other_user_id + context = Context() + context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id} img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) if img_match_prefix: content = content.replace(img_match_prefix, '', 1).strip() - context['type'] = 'IMAGE_CREATE' + context.type = ContextType.IMAGE_CREATE else: - context['type'] = 'TEXT' + context.type = ContextType.TEXT - context['content'] = content + context.content = content thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) def handle_group(self, msg): @@ -123,15 +123,16 @@ class WechatChannel(Channel): match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or check_prefix(origin_content, config.get('group_chat_prefix')) \ or check_contain(origin_content, config.get('group_chat_keyword')) if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix: - context = { 'isgroup': True, 'msg': msg, 'receiver': group_id} + context = Context() + context.kwargs = { 'isgroup': True, 'msg': msg, 'receiver': group_id} img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) if img_match_prefix: content = content.replace(img_match_prefix, '', 1).strip() - context['type'] = 'IMAGE_CREATE' + context.type = ContextType.IMAGE_CREATE else: - context['type'] = 'TEXT' - context['content'] = content + context.type = ContextType.TEXT + context.content = content group_chat_in_one_session = conf().get('group_chat_in_one_session', []) if ('ALL_GROUP' in group_chat_in_one_session or @@ -144,18 +145,18 @@ class WechatChannel(Channel): thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 - def send(self, reply, receiver): - if reply['type'] == 'TEXT': - itchat.send(reply['content'], toUserName=receiver) + def send(self, reply : Reply, receiver): + if reply.type == ReplyType.TEXT: + itchat.send(reply.content, toUserName=receiver) logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) - elif reply['type'] == 'ERROR' or reply['type'] == 'INFO': - itchat.send(reply['content'], toUserName=receiver) + elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: + itchat.send(reply.content, toUserName=receiver) logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) - elif reply['type'] == 'VOICE': - itchat.send_file(reply['content'], toUserName=receiver) - logger.info('[WX] sendFile={}, receiver={}'.format(reply['content'], receiver)) - elif reply['type']=='IMAGE_URL': # 从网络下载图片 - img_url = reply['content'] + elif reply.type == ReplyType.VOICE: + itchat.send_file(reply.content, toUserName=receiver) + logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver)) + elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 + img_url = reply.content pic_res = requests.get(img_url, stream=True) image_storage = io.BytesIO() for block in pic_res.iter_content(1024): @@ -163,69 +164,69 @@ class WechatChannel(Channel): image_storage.seek(0) itchat.send_image(image_storage, toUserName=receiver) logger.info('[WX] sendImage url=, receiver={}'.format(img_url,receiver)) - elif reply['type']=='IMAGE': # 从文件读取图片 - image_storage = reply['content'] + elif reply.type == ReplyType.IMAGE: # 从文件读取图片 + image_storage = reply.content image_storage.seek(0) itchat.send_image(image_storage, toUserName=receiver) logger.info('[WX] sendImage, receiver={}'.format(receiver)) # 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类 def handle(self, context): - reply = {} + reply = Reply() logger.debug('[WX] ready to handle context: {}'.format(context)) # reply的构建步骤 e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply})) - reply=e_context['reply'] + reply = e_context['reply'] if not e_context.is_pass(): - logger.debug('[WX] ready to handle context: type={}, content={}'.format(context['type'], context['content'])) - if context['type'] == 'TEXT' or context['type'] == 'IMAGE_CREATE': - reply = super().build_reply_content(context['content'], context) - elif context['type'] == 'VOICE': + logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content)) + if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: + reply = super().build_reply_content(context.content, context) + elif context.type == ContextType.VOICE: msg = context['msg'] - file_name = TmpDir().path() + context['content'] + file_name = TmpDir().path() + context.content msg.download(file_name) reply = super().build_voice_to_text(file_name) - if reply['type'] != 'ERROR' and reply['type'] != 'INFO': - context['content'] = reply['content'] # 语音转文字后,将文字内容作为新的context - context['type'] = reply['type'] - reply = super().build_reply_content(context['content'], context) - if reply['type'] == 'TEXT': + if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO: + context.content = reply.content # 语音转文字后,将文字内容作为新的context + context.type = ContextType.TEXT + reply = super().build_reply_content(context.content, context) + if reply.type == ReplyType.TEXT: if conf().get('voice_reply_voice'): - reply = super().build_text_to_voice(reply['content']) + reply = super().build_text_to_voice(reply.content) else: - logger.error('[WX] unknown context type: {}'.format(context['type'])) + logger.error('[WX] unknown context type: {}'.format(context.type)) return logger.debug('[WX] ready to decorate reply: {}'.format(reply)) # reply的包装步骤 - if reply and reply['type']: + if reply and reply.type: e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply})) reply=e_context['reply'] - if not e_context.is_pass() and reply and reply['type']: - if reply['type'] == 'TEXT': - reply_text = reply['content'] + if not e_context.is_pass() and reply and reply.type: + if reply.type == ReplyType.TEXT: + reply_text = reply.content if context['isgroup']: reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip() reply_text = conf().get("group_chat_reply_prefix", "")+reply_text else: reply_text = conf().get("single_chat_reply_prefix", "")+reply_text - reply['content'] = reply_text - elif reply['type'] == 'ERROR' or reply['type'] == 'INFO': - reply['content'] = reply['type']+":\n" + reply['content'] - elif reply['type'] == 'IMAGE_URL' or reply['type'] == 'VOICE' or reply['type'] == 'IMAGE': + reply.content = reply_text + elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: + reply.content = str(reply.type)+":\n" + reply.content + elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE: pass else: - logger.error('[WX] unknown reply type: {}'.format(reply['type'])) + logger.error('[WX] unknown reply type: {}'.format(reply.type)) return # reply的发送步骤 - if reply and reply['type']: + if reply and reply.type: e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply})) reply=e_context['reply'] - if not e_context.is_pass() and reply and reply['type']: + if not e_context.is_pass() and reply and reply.type: logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver'])) self.send(reply, context['receiver']) diff --git a/channel/wechat/wechaty_channel.py b/channel/wechat/wechaty_channel.py index 7ae4683..3a7fe2c 100644 --- a/channel/wechat/wechaty_channel.py +++ b/channel/wechat/wechaty_channel.py @@ -11,6 +11,7 @@ import time import asyncio import requests from typing import Optional, Union +from bridge.context import Context, ContextType from wechaty_puppet import MessageType, FileBox, ScanStatus # type: ignore from wechaty import Wechaty, Contact from wechaty.user import Message, Room, MiniProgram, UrlLink @@ -127,11 +128,9 @@ class WechatyChannel(Channel): try: if not query: return - context = dict() + context = Context(ContextType.TEXT, query) context['session_id'] = reply_user_id - context['type'] = 'TEXT' - context['content'] = query - reply_text = super().build_reply_content(query, context)['content'] + reply_text = super().build_reply_content(query, context).content if reply_text: await self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id) except Exception as e: @@ -141,10 +140,8 @@ class WechatyChannel(Channel): try: if not query: return - context = dict() - context['type'] = 'IMAGE_CREATE' - context['content'] = query - img_url = super().build_reply_content(query, context)['content'] + context = Context(ContextType.IMAGE_CREATE, query) + img_url = super().build_reply_content(query, context).content if not img_url: return # 图片下载 @@ -165,7 +162,7 @@ class WechatyChannel(Channel): async def _do_send_group(self, query, group_id, group_name, group_user_id, group_user_name): if not query: return - context = dict() + context = Context(ContextType.TEXT, query) group_chat_in_one_session = conf().get('group_chat_in_one_session', []) if ('ALL_GROUP' in group_chat_in_one_session or \ group_name in group_chat_in_one_session or \ @@ -173,9 +170,7 @@ class WechatyChannel(Channel): context['session_id'] = str(group_id) else: context['session_id'] = str(group_id) + '-' + str(group_user_id) - context['type'] = 'TEXT' - context['content'] = query - reply_text = super().build_reply_content(query, context)['content'] + reply_text = super().build_reply_content(query, context).content if reply_text: reply_text = '@' + group_user_name + ' ' + reply_text.strip() await self.send_group(conf().get("group_chat_reply_prefix", "") + reply_text, group_id) @@ -184,10 +179,8 @@ class WechatyChannel(Channel): try: if not query: return - context = dict() - context['type'] = 'IMAGE_CREATE' - context['content'] = query - img_url = super().build_reply_content(query, context)['content'] + context = Context(ContextType.IMAGE_CREATE, query) + img_url = super().build_reply_content(query, context).content if not img_url: return # 图片发送 diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index 3bd24dd..e988452 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -5,6 +5,8 @@ import os import traceback from typing import Tuple from bridge.bridge import Bridge +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType from config import load_config import plugins from plugins import * @@ -123,13 +125,13 @@ class Godcmd(Plugin): def on_handle_context(self, e_context: EventContext): - context_type = e_context['context']['type'] - if context_type != "TEXT": + context_type = e_context['context'].type + if context_type != ContextType.TEXT: if not self.isrunning: e_context.action = EventAction.BREAK_PASS return - content = e_context['context']['content'] + content = e_context['context'].content logger.debug("[Godcmd] on_handle_context. content: %s" % content) if content.startswith("#"): # msg = e_context['context']['msg'] @@ -239,12 +241,12 @@ class Godcmd(Plugin): else: ok, result = False, f"未知指令:{cmd}\n查看指令列表请输入#help \n" - reply = {} + reply = Reply() if ok: - reply["type"] = "INFO" + reply.type = ReplyType.INFO else: - reply["type"] = "ERROR" - reply["content"] = result + reply.type = ReplyType.ERROR + reply.content = result e_context['reply'] = reply e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 diff --git a/plugins/hello/hello.py b/plugins/hello/hello.py index 1eb409e..53d87e6 100644 --- a/plugins/hello/hello.py +++ b/plugins/hello/hello.py @@ -1,5 +1,7 @@ # encoding:utf-8 +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType import plugins from plugins import * from common.log import logger @@ -14,31 +16,31 @@ class Hello(Plugin): def on_handle_context(self, e_context: EventContext): - if e_context['context']['type'] != "TEXT": + if e_context['context'].type != ContextType.TEXT: return - content = e_context['context']['content'] + content = e_context['context'].content logger.debug("[Hello] on_handle_context. content: %s" % content) if content == "Hello": - reply = {} - reply['type'] = "TEXT" + reply = Reply() + reply.type = ReplyType.TEXT msg = e_context['context']['msg'] if e_context['context']['isgroup']: - reply['content'] = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group") + reply.content = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group") else: - reply['content'] = "Hello, " + msg['User'].get('NickName', "My friend") + reply.content = "Hello, " + msg['User'].get('NickName', "My friend") e_context['reply'] = reply e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 if content == "Hi": - reply={} - reply['type'] = "TEXT" - reply['content'] = "Hi" + reply = Reply() + reply.type = ReplyType.TEXT + reply.content = "Hi" e_context['reply'] = reply e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑,一般会覆写reply if content == "End": # 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World" - e_context['context']['type'] = "IMAGE_CREATE" + e_context['context'].type = "IMAGE_CREATE" content = "The World" e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 diff --git a/voice/baidu/baidu_voice.py b/voice/baidu/baidu_voice.py index adab169..531d8ce 100644 --- a/voice/baidu/baidu_voice.py +++ b/voice/baidu/baidu_voice.py @@ -4,6 +4,7 @@ baidu voice service """ import time from aip import AipSpeech +from bridge.reply import Reply, ReplyType from common.log import logger from common.tmp_dir import TmpDir from voice.voice import Voice @@ -30,8 +31,8 @@ class BaiduVoice(Voice): with open(fileName, 'wb') as f: f.write(result) logger.info('[Baidu] textToVoice text={} voice file name={}'.format(text, fileName)) - reply = {"type": "VOICE", "content": fileName} + reply = Reply(ReplyType.VOICE, fileName) else: logger.error('[Baidu] textToVoice error={}'.format(result)) - reply = {"type": "ERROR", "content": "抱歉,语音合成失败"} + reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败") return reply diff --git a/voice/google/google_voice.py b/voice/google/google_voice.py index 6c00892..74431db 100644 --- a/voice/google/google_voice.py +++ b/voice/google/google_voice.py @@ -6,6 +6,7 @@ google voice service import pathlib import subprocess import time +from bridge.reply import Reply, ReplyType import speech_recognition import pyttsx3 from common.log import logger @@ -32,16 +33,15 @@ class GoogleVoice(Voice): ' -acodec pcm_s16le -ac 1 -ar 16000 ' + new_file, shell=True) with speech_recognition.AudioFile(new_file) as source: audio = self.recognizer.record(source) - reply = {} try: text = self.recognizer.recognize_google(audio, language='zh-CN') logger.info( '[Google] voiceToText text={} voice file name={}'.format(text, voice_file)) - reply = {"type": "TEXT", "content": text} + reply = Reply(ReplyType.TEXT, text) except speech_recognition.UnknownValueError: - reply = {"type": "ERROR", "content": "抱歉,我听不懂"} + reply = Reply(ReplyType.ERROR, "抱歉,我听不懂") except speech_recognition.RequestError as e: - reply = {"type": "ERROR", "content": "抱歉,无法连接到 Google 语音识别服务;{0}".format(e)} + reply = Reply(ReplyType.ERROR, "抱歉,无法连接到 Google 语音识别服务;{0}".format(e)) finally: return reply def textToVoice(self, text): @@ -51,8 +51,8 @@ class GoogleVoice(Voice): self.engine.runAndWait() logger.info( '[Google] textToVoice text={} voice file name={}'.format(text, textFile)) - reply = {"type": "VOICE", "content": textFile} + reply = Reply(ReplyType.VOICE, textFile) except Exception as e: - reply = {"type": "ERROR", "content": str(e)} + reply = Reply(ReplyType.ERROR, str(e)) finally: return reply diff --git a/voice/openai/openai_voice.py b/voice/openai/openai_voice.py index 3b77c52..2e85e10 100644 --- a/voice/openai/openai_voice.py +++ b/voice/openai/openai_voice.py @@ -4,6 +4,7 @@ google voice service """ import json import openai +from bridge.reply import Reply, ReplyType from config import conf from common.log import logger from voice.voice import Voice @@ -16,16 +17,15 @@ class OpenaiVoice(Voice): def voiceToText(self, voice_file): logger.debug( '[Openai] voice file name={}'.format(voice_file)) - reply={} try: file = open(voice_file, "rb") result = openai.Audio.transcribe("whisper-1", file) text = result["text"] - reply = {"type": "TEXT", "content": text} + reply = Reply(ReplyType.TEXT, text) logger.info( '[Openai] voiceToText text={} voice file name={}'.format(text, voice_file)) except Exception as e: - reply = {"type": "ERROR", "content": str(e)} + reply = Reply(ReplyType.ERROR, str(e)) finally: return reply From dce9c4dccbfb7ff7fbc63093a470ba0165fcb694 Mon Sep 17 00:00:00 2001 From: lanvent Date: Mon, 13 Mar 2023 19:58:35 +0800 Subject: [PATCH 15/21] compatible with openai bot --- bot/baidu/baidu_unit_bot.py | 4 +++- bot/openai/open_ai_bot.py | 47 ++++++++++++++++++++----------------- plugins/godcmd/godcmd.py | 4 ++-- 3 files changed, 30 insertions(+), 25 deletions(-) diff --git a/bot/baidu/baidu_unit_bot.py b/bot/baidu/baidu_unit_bot.py index a84ac57..2b7dd8d 100644 --- a/bot/baidu/baidu_unit_bot.py +++ b/bot/baidu/baidu_unit_bot.py @@ -2,6 +2,7 @@ import requests from bot.bot import Bot +from bridge.reply import Reply, ReplyType # Baidu Unit对话接口 (可用, 但能力较弱) @@ -14,7 +15,8 @@ class BaiduUnitBot(Bot): headers = {'content-type': 'application/x-www-form-urlencoded'} response = requests.post(url, data=post_data.encode(), headers=headers) if response: - return response.json()['result']['context']['SYS_PRESUMED_HIST'][1] + reply = Reply(ReplyType.TEXT, response.json()['result']['context']['SYS_PRESUMED_HIST'][1]) + return reply def get_token(self): access_key = 'YOUR_ACCESS_KEY' diff --git a/bot/openai/open_ai_bot.py b/bot/openai/open_ai_bot.py index c948a7c..e96d60f 100644 --- a/bot/openai/open_ai_bot.py +++ b/bot/openai/open_ai_bot.py @@ -1,6 +1,8 @@ # encoding:utf-8 from bot.bot import Bot +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType from config import conf from common.log import logger import openai @@ -13,30 +15,31 @@ class OpenAIBot(Bot): def __init__(self): openai.api_key = conf().get('open_ai_api_key') - def reply(self, query, context=None): # acquire reply content - if not context or not context.get('type') or context.get('type') == 'TEXT': - logger.info("[OPEN_AI] query={}".format(query)) - from_user_id = context.get('from_user_id') or context.get('session_id') - if query == '#清除记忆': - Session.clear_session(from_user_id) - return '记忆已清除' - elif query == '#清除所有': - Session.clear_all_session() - return '所有人记忆已清除' - - new_query = Session.build_session_query(query, from_user_id) - logger.debug("[OPEN_AI] session query={}".format(new_query)) - - reply_content = self.reply_text(new_query, from_user_id, 0) - logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content)) - if reply_content and query: - Session.save_session(query, reply_content, from_user_id) - return reply_content - - elif context.get('type', None) == 'IMAGE_CREATE': - return self.create_img(query, 0) + if context and context.type: + if context.type == ContextType.TEXT: + logger.info("[OPEN_AI] query={}".format(query)) + from_user_id = context['session_id'] + reply = None + if query == '#清除记忆': + Session.clear_session(from_user_id) + reply = Reply(ReplyType.INFO, '记忆已清除') + elif query == '#清除所有': + Session.clear_all_session() + reply = Reply(ReplyType.INFO, '所有人记忆已清除') + else: + new_query = Session.build_session_query(query, from_user_id) + logger.debug("[OPEN_AI] session query={}".format(new_query)) + + reply_content = self.reply_text(new_query, from_user_id, 0) + logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content)) + if reply_content and query: + Session.save_session(query, reply_content, from_user_id) + reply = Reply(ReplyType.TEXT, reply_content) + return reply + elif context.type == ContextType.IMAGE_CREATE: + return self.create_img(query, 0) def reply_text(self, query, user_id, retry_count=0): try: diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index e988452..e6c971b 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -162,7 +162,7 @@ class Godcmd(Plugin): bot.sessions.clear_session(session_id) ok, result = True, "会话已重置" else: - ok, result = False, "当前机器人不支持重置会话" + ok, result = False, "当前对话机器人不支持重置会话" logger.debug("[Godcmd] command: %s by %s" % (cmd, user)) elif any(cmd in info['alias'] for info in ADMIN_COMMANDS.values()): if isadmin: @@ -184,7 +184,7 @@ class Godcmd(Plugin): bot.sessions.clear_all_session() ok, result = True, "重置所有会话成功" else: - ok, result = False, "当前机器人不支持重置会话" + ok, result = False, "当前对话机器人不支持重置会话" elif cmd == "debug": logger.setLevel('DEBUG') ok, result = True, "DEBUG模式已开启" From e6d148e72945813d668aba092c8ece4e90a6782e Mon Sep 17 00:00:00 2001 From: lanvent Date: Tue, 14 Mar 2023 00:49:28 +0800 Subject: [PATCH 16/21] plugins: add sdwebui(stable diffusion) plugin --- plugins/sdwebui/__init__.py | 0 plugins/sdwebui/config.json.template | 67 +++++++++++++++++++++ plugins/sdwebui/readme.md | 63 ++++++++++++++++++++ plugins/sdwebui/sdwebui.py | 88 ++++++++++++++++++++++++++++ 4 files changed, 218 insertions(+) create mode 100644 plugins/sdwebui/__init__.py create mode 100644 plugins/sdwebui/config.json.template create mode 100644 plugins/sdwebui/readme.md create mode 100644 plugins/sdwebui/sdwebui.py diff --git a/plugins/sdwebui/__init__.py b/plugins/sdwebui/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plugins/sdwebui/config.json.template b/plugins/sdwebui/config.json.template new file mode 100644 index 0000000..4fca569 --- /dev/null +++ b/plugins/sdwebui/config.json.template @@ -0,0 +1,67 @@ +{ + "start":{ + "host" : "127.0.0.1", + "port" : 7860 + }, + "defaults": { + "params": { + "sampler_name": "DPM++ 2M Karras", + "steps": 20, + "width": 512, + "height": 512, + "cfg_scale": 7, + "prompt":"masterpiece, best quality", + "negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)", + "enable_hr": false, + "hr_scale": 2, + "hr_upscaler": "Latent", + "hr_second_pass_steps": 15, + "denoising_strength": 0.7 + }, + "options": { + "sd_model_checkpoint": "perfectWorld_v2Baked" + } + }, + "rules": [ + { + "keywords": [ + "横版", + "壁纸" + ], + "params": { + "width": 640, + "height": 384 + } + }, + { + "keywords": [ + "竖版" + ], + "params": { + "width": 384, + "height": 640 + } + }, + { + "keywords": [ + "高清" + ], + "params": { + "enable_hr": true, + "hr_scale": 1.6 + } + }, + { + "keywords": [ + "二次元" + ], + "params": { + "negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)", + "prompt": "masterpiece, best quality" + }, + "options": { + "sd_model_checkpoint": "meinamix_meinaV8" + } + } + ] +} \ No newline at end of file diff --git a/plugins/sdwebui/readme.md b/plugins/sdwebui/readme.md new file mode 100644 index 0000000..ef325cd --- /dev/null +++ b/plugins/sdwebui/readme.md @@ -0,0 +1,63 @@ +本插件用于将画图请求转发给stable diffusion webui +使用前先安装stable diffusion webui,并在它的启动参数中添加 "--api" +具体参考(https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API) + +请**安装**本插件的依赖包```webuiapi``` +``` + ```pip install webuiapi``` +``` +请将```config.json.template```复制为```config.json```,并修改其中的参数和规则 + +用户的画图请求格式为: +``` + <画图触发词><关键词1> <关键词2> ... <关键词n>: +``` +本插件会对画图触发词后的关键词进行逐个匹配,如果触发了规则中的关键词,则会在画图请求中重载对应的参数。 +规则会按顺序匹配,每个关键词最多匹配到1次,如果有重复的参数,则以最后一个为准: +第一个"**:**"号之后的内容会作为附加的**prompt**,接在最终的prompt后 + +例如: 画横版 高清 二次元:cat +会触发三个关键词 "横版", "高清", "二次元",prompt为"cat" +若默认参数是: +``` + "width": 512, + "height": 512, + "enable_hr": false, + "prompt": "8k" + "negative_prompt": "nsfw", + "sd_model_checkpoint": "perfectWorld_v2Baked" +``` + +"横版"触发的规则参数为: +``` + "width": 640, + "height": 384, +``` +"高清"触发的规则参数为: +``` + "enable_hr": true, + "hr_scale": 1.6, +``` +"二次元"触发的规则参数为: +``` + "negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)", + "steps": 20, + "prompt": "masterpiece, best quality", + + "sd_model_checkpoint": "meinamix_meinaV8" +``` +最后将第一个":"后的内容cat连接在prompt后,得到最终参数为: +``` + "width": 640, + "height": 384, + "enable_hr": true, + "hr_scale": 1.6, + "negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)", + "steps": 20, + "prompt": "masterpiece, best quality, cat", + + "sd_model_checkpoint": "meinamix_meinaV8" +``` +PS: 参数分为两部分, +一部分是params,为画画的参数;参数名**必须**与webuiapi包中[txt2img api](https://github.com/mix1009/sdwebuiapi/blob/fb2054e149c0a4e25125c0cd7e7dca06bda839d4/webuiapi/webuiapi.py#L163)的参数名一致 +另一部分是options,指sdwebui的设置,使用的模型和vae需要写在里面。它和http://127.0.0.1:7860/sdapi/v1/options所返回的键一致。 \ No newline at end of file diff --git a/plugins/sdwebui/sdwebui.py b/plugins/sdwebui/sdwebui.py new file mode 100644 index 0000000..f7da523 --- /dev/null +++ b/plugins/sdwebui/sdwebui.py @@ -0,0 +1,88 @@ +# encoding:utf-8 + +import json +import os +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType +import plugins +from plugins import * +from common.log import logger +import webuiapi +import io + + +@plugins.register(name="sdwebui", desc="利用stable-diffusion webui来画图", version="2.0", author="lanvent") +class SDWebUI(Plugin): + def __init__(self): + super().__init__() + curdir = os.path.dirname(__file__) + config_path = os.path.join(curdir, "config.json") + try: + with open(config_path, "r", encoding="utf-8") as f: + config = json.load(f) + self.rules = config["rules"] + defaults = config["defaults"] + self.default_params = defaults["params"] + self.default_options = defaults["options"] + self.start_args = config["start"] + self.api = webuiapi.WebUIApi(**self.start_args) + self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context + logger.info("[SD] inited") + except FileNotFoundError: + logger.error(f"[SD] init failed, {config_path} not found") + except Exception as e: + logger.error("[SD] init failed, exception: %s" % e) + + def on_handle_context(self, e_context: EventContext): + + if e_context['context'].type != ContextType.IMAGE_CREATE: + return + + logger.debug("[SD] on_handle_context. content: %s" %e_context['context'].content) + + logger.info("[SD] image_query={}".format(e_context['context'].content)) + reply = Reply() + try: + content = e_context['context'].content[:] + # 解析用户输入 如"横版 高清 二次元:cat" + keywords, prompt = content.split(":", 1) + keywords = keywords.split() + + rule_params = {} + rule_options = {} + for keyword in keywords: + matched = False + for rule in self.rules: + if keyword in rule["keywords"]: + for key in rule["params"]: + rule_params[key] = rule["params"][key] + if "options" in rule: + for key in rule["options"]: + rule_options[key] = rule["options"][key] + matched = True + break # 一个关键词只匹配一个规则 + if not matched: + logger.warning("[SD] keyword not matched: %s" % keyword) + + params = {**self.default_params, **rule_params} + options = {**self.default_options, **rule_options} + params["prompt"] = params.get("prompt", "")+f", {prompt}" + if len(options) > 0: + logger.info("[SD] cover rule_options={}".format(rule_options)) + self.api.set_options(options) + logger.info("[SD] params={}".format(params)) + result = self.api.txt2img( + **params + ) + reply.type = ReplyType.IMAGE + b_img = io.BytesIO() + result.image.save(b_img, format="PNG") + reply.content = b_img + e_context.action = EventAction.BREAK_PASS # 事件结束后,不跳过处理context的默认逻辑 + except Exception as e: + reply.type = ReplyType.ERROR + reply.content = "[SD] "+str(e) + logger.error("[SD] exception: %s" % e) + e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 + finally: + e_context['reply'] = reply From e6b65437e4197ca148ff017bac3a6a80ed4f01dd Mon Sep 17 00:00:00 2001 From: lanvent Date: Tue, 14 Mar 2023 12:07:03 +0800 Subject: [PATCH 17/21] sdwebui : add help reply --- plugins/sdwebui/config.json.template | 9 ++- plugins/sdwebui/sdwebui.py | 90 ++++++++++++++++++---------- 2 files changed, 64 insertions(+), 35 deletions(-) diff --git a/plugins/sdwebui/config.json.template b/plugins/sdwebui/config.json.template index 4fca569..213acdc 100644 --- a/plugins/sdwebui/config.json.template +++ b/plugins/sdwebui/config.json.template @@ -31,7 +31,8 @@ "params": { "width": 640, "height": 384 - } + }, + "desc": "分辨率会变成640x384" }, { "keywords": [ @@ -49,7 +50,8 @@ "params": { "enable_hr": true, "hr_scale": 1.6 - } + }, + "desc": "出图分辨率长宽都会提高1.6倍" }, { "keywords": [ @@ -61,7 +63,8 @@ }, "options": { "sd_model_checkpoint": "meinamix_meinaV8" - } + }, + "desc": "使用二次元风格模型出图" } ] } \ No newline at end of file diff --git a/plugins/sdwebui/sdwebui.py b/plugins/sdwebui/sdwebui.py index f7da523..cc07c7a 100644 --- a/plugins/sdwebui/sdwebui.py +++ b/plugins/sdwebui/sdwebui.py @@ -4,6 +4,7 @@ import json import os from bridge.context import ContextType from bridge.reply import Reply, ReplyType +from config import conf import plugins from plugins import * from common.log import logger @@ -45,40 +46,49 @@ class SDWebUI(Plugin): try: content = e_context['context'].content[:] # 解析用户输入 如"横版 高清 二次元:cat" - keywords, prompt = content.split(":", 1) + if ":" in content: + keywords, prompt = content.split(":", 1) + else: + keywords = content + prompt = "" + keywords = keywords.split() - rule_params = {} - rule_options = {} - for keyword in keywords: - matched = False - for rule in self.rules: - if keyword in rule["keywords"]: - for key in rule["params"]: - rule_params[key] = rule["params"][key] - if "options" in rule: - for key in rule["options"]: - rule_options[key] = rule["options"][key] - matched = True - break # 一个关键词只匹配一个规则 - if not matched: - logger.warning("[SD] keyword not matched: %s" % keyword) - - params = {**self.default_params, **rule_params} - options = {**self.default_options, **rule_options} - params["prompt"] = params.get("prompt", "")+f", {prompt}" - if len(options) > 0: - logger.info("[SD] cover rule_options={}".format(rule_options)) - self.api.set_options(options) - logger.info("[SD] params={}".format(params)) - result = self.api.txt2img( - **params - ) - reply.type = ReplyType.IMAGE - b_img = io.BytesIO() - result.image.save(b_img, format="PNG") - reply.content = b_img - e_context.action = EventAction.BREAK_PASS # 事件结束后,不跳过处理context的默认逻辑 + if "help" in keywords or "帮助" in keywords: + reply.type = ReplyType.INFO + reply.content = self.get_help_text() + else: + rule_params = {} + rule_options = {} + for keyword in keywords: + matched = False + for rule in self.rules: + if keyword in rule["keywords"]: + for key in rule["params"]: + rule_params[key] = rule["params"][key] + if "options" in rule: + for key in rule["options"]: + rule_options[key] = rule["options"][key] + matched = True + break # 一个关键词只匹配一个规则 + if not matched: + logger.warning("[SD] keyword not matched: %s" % keyword) + + params = {**self.default_params, **rule_params} + options = {**self.default_options, **rule_options} + params["prompt"] = params.get("prompt", "")+f", {prompt}" + if len(options) > 0: + logger.info("[SD] cover rule_options={}".format(rule_options)) + self.api.set_options(options) + logger.info("[SD] params={}".format(params)) + result = self.api.txt2img( + **params + ) + reply.type = ReplyType.IMAGE + b_img = io.BytesIO() + result.image.save(b_img, format="PNG") + reply.content = b_img + e_context.action = EventAction.BREAK_PASS # 事件结束后,跳过处理context的默认逻辑 except Exception as e: reply.type = ReplyType.ERROR reply.content = "[SD] "+str(e) @@ -86,3 +96,19 @@ class SDWebUI(Plugin): e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 finally: e_context['reply'] = reply + + def get_help_text(self): + if not conf().get('image_create_prefix'): + return "画图功能未启用" + else: + trigger = conf()['image_create_prefix'][0] + help_text = f"请使用<{trigger}[关键词1] [关键词2]...:提示语>的格式作画,如\"{trigger}横版 高清:cat\"\n" + help_text += "目前可用关键词:\n" + for rule in self.rules: + keywords = [f"[{keyword}]" for keyword in rule['keywords']] + help_text += f"{','.join(keywords)}" + if "desc" in rule: + help_text += f"-{rule['desc']}\n" + else: + help_text += "\n" + return help_text \ No newline at end of file From c782b38ba18613a22566f8dd96f1024f74bfee02 Mon Sep 17 00:00:00 2001 From: lanvent Date: Tue, 14 Mar 2023 15:31:26 +0800 Subject: [PATCH 18/21] sdwebui: modify README.md --- plugins/sdwebui/readme.md | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/plugins/sdwebui/readme.md b/plugins/sdwebui/readme.md index ef325cd..bb8c62c 100644 --- a/plugins/sdwebui/readme.md +++ b/plugins/sdwebui/readme.md @@ -1,19 +1,25 @@ -本插件用于将画图请求转发给stable diffusion webui -使用前先安装stable diffusion webui,并在它的启动参数中添加 "--api" -具体参考(https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API) +### 插件描述 +本插件用于将画图请求转发给stable diffusion webui。 + +### 环境要求 +使用前先安装stable diffusion webui,并在它的启动参数中添加 "--api"。 +具体信息,请参考[文章](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API)。 请**安装**本插件的依赖包```webuiapi``` ``` ```pip install webuiapi``` ``` -请将```config.json.template```复制为```config.json```,并修改其中的参数和规则 +### 使用说明 +请将`config.json.template`复制为`config.json`,并修改其中的参数和规则。 +#### 画图请求格式 用户的画图请求格式为: ``` <画图触发词><关键词1> <关键词2> ... <关键词n>: ``` -本插件会对画图触发词后的关键词进行逐个匹配,如果触发了规则中的关键词,则会在画图请求中重载对应的参数。 -规则会按顺序匹配,每个关键词最多匹配到1次,如果有重复的参数,则以最后一个为准: +- 本插件会对画图触发词后的关键词进行逐个匹配,如果触发了规则中的关键词,则会在画图请求中重载对应的参数。 +- 规则的匹配顺序参考`config.json`中的顺序,每个关键词最多被匹配到1次,如果多个关键词触发了重复的参数,重复参数以最后一个关键词为准: +- 关键词中包含`help`或`帮助`,会打印出帮助文档。 第一个"**:**"号之后的内容会作为附加的**prompt**,接在最终的prompt后 例如: 画横版 高清 二次元:cat @@ -58,6 +64,6 @@ "sd_model_checkpoint": "meinamix_meinaV8" ``` -PS: 参数分为两部分, -一部分是params,为画画的参数;参数名**必须**与webuiapi包中[txt2img api](https://github.com/mix1009/sdwebuiapi/blob/fb2054e149c0a4e25125c0cd7e7dca06bda839d4/webuiapi/webuiapi.py#L163)的参数名一致 -另一部分是options,指sdwebui的设置,使用的模型和vae需要写在里面。它和http://127.0.0.1:7860/sdapi/v1/options所返回的键一致。 \ No newline at end of file +PS: 参数分为两部分: +- 一部分是params,为画画的参数;参数名**必须**与webuiapi包中[txt2img api](https://github.com/mix1009/sdwebuiapi/blob/fb2054e149c0a4e25125c0cd7e7dca06bda839d4/webuiapi/webuiapi.py#L163)的参数名一致 +- 另一部分是options,指sdwebui的设置,使用的模型和vae需要写在里面。它和http://127.0.0.1:7860/sdapi/v1/options所返回的键一致。 \ No newline at end of file From 300b7b96874338505ff3147519925dce815fd96f Mon Sep 17 00:00:00 2001 From: lanvent Date: Tue, 14 Mar 2023 15:59:52 +0800 Subject: [PATCH 19/21] plugins: support reload plugin --- plugins/godcmd/godcmd.py | 14 ++++++++++++++ plugins/plugin_manager.py | 10 ++++++++++ plugins/sdwebui/sdwebui.py | 2 +- 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index e6c971b..296b9a9 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -63,6 +63,11 @@ ADMIN_COMMANDS = { "args": ["插件名", "优先级"], "desc": "设置指定插件的优先级,越大越优先", }, + "reloadp": { + "alias": ["reloadp", "重载插件"], + "args": ["插件名"], + "desc": "重载指定插件配置", + }, "enablep": { "alias": ["enablep", "启用插件"], "args": ["插件名"], @@ -216,6 +221,15 @@ class Godcmd(Plugin): result = "插件" + args[0] + "优先级已设置为" + args[1] else: result = "插件不存在" + elif cmd == "reloadp": + if len(args) != 1: + ok, result = False, "请提供插件名" + else: + ok = PluginManager().reload_plugin(args[0]) + if ok: + result = "插件配置已重载" + else: + result = "插件不存在" elif cmd == "enablep": if len(args) != 1: ok, result = False, "请提供插件名" diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py index 107c63b..9fada89 100644 --- a/plugins/plugin_manager.py +++ b/plugins/plugin_manager.py @@ -96,6 +96,16 @@ class PluginManager: self.listening_plugins[event].append(name) self.refresh_order() + def reload_plugin(self, name): + if name in self.instances: + for event in self.listening_plugins: + if name in self.listening_plugins[event]: + self.listening_plugins[event].remove(name) + del self.instances[name] + self.activate_plugins() + return True + return False + def load_plugins(self): self.load_config() self.scan_plugins() diff --git a/plugins/sdwebui/sdwebui.py b/plugins/sdwebui/sdwebui.py index cc07c7a..56842e8 100644 --- a/plugins/sdwebui/sdwebui.py +++ b/plugins/sdwebui/sdwebui.py @@ -78,7 +78,7 @@ class SDWebUI(Plugin): options = {**self.default_options, **rule_options} params["prompt"] = params.get("prompt", "")+f", {prompt}" if len(options) > 0: - logger.info("[SD] cover rule_options={}".format(rule_options)) + logger.info("[SD] cover options={}".format(options)) self.api.set_options(options) logger.info("[SD] params={}".format(params)) result = self.api.txt2img( From 8915149d361cf0260f11b21e406091d968c7b0a7 Mon Sep 17 00:00:00 2001 From: lanvent Date: Tue, 14 Mar 2023 17:30:30 +0800 Subject: [PATCH 20/21] plugin: add banwords plugin --- plugins/banwords/.gitignore | 1 + plugins/banwords/README.md | 9 + plugins/banwords/WordsSearch.py | 250 +++++++++++++++++++++++++ plugins/banwords/__init__.py | 0 plugins/banwords/banwords.py | 63 +++++++ plugins/banwords/banwords.txt.template | 3 + plugins/banwords/config.json.template | 3 + plugins/godcmd/config.json.template | 4 + plugins/godcmd/godcmd.py | 6 +- 9 files changed, 338 insertions(+), 1 deletion(-) create mode 100644 plugins/banwords/.gitignore create mode 100644 plugins/banwords/README.md create mode 100644 plugins/banwords/WordsSearch.py create mode 100644 plugins/banwords/__init__.py create mode 100644 plugins/banwords/banwords.py create mode 100644 plugins/banwords/banwords.txt.template create mode 100644 plugins/banwords/config.json.template create mode 100644 plugins/godcmd/config.json.template diff --git a/plugins/banwords/.gitignore b/plugins/banwords/.gitignore new file mode 100644 index 0000000..a6593bf --- /dev/null +++ b/plugins/banwords/.gitignore @@ -0,0 +1 @@ +banwords.txt \ No newline at end of file diff --git a/plugins/banwords/README.md b/plugins/banwords/README.md new file mode 100644 index 0000000..9c7e498 --- /dev/null +++ b/plugins/banwords/README.md @@ -0,0 +1,9 @@ +### 说明 +简易的敏感词插件,暂不支持分词,请自行导入词库到插件文件夹中的`banwords.txt`,每行一个词,一个参考词库是[1](https://github.com/cjh0613/tencent-sensitive-words/blob/main/sensitive_words_lines.txt)。 + +`config.json`中能够填写默认的处理行为,目前行为有: +- `ignore` : 无视这条消息。 +- `replace` : 将消息中的敏感词替换成"*",并回复违规。 + +### 致谢 +搜索功能实现来自https://github.com/toolgood/ToolGood.Words \ No newline at end of file diff --git a/plugins/banwords/WordsSearch.py b/plugins/banwords/WordsSearch.py new file mode 100644 index 0000000..d41d6e7 --- /dev/null +++ b/plugins/banwords/WordsSearch.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +# ToolGood.Words.WordsSearch.py +# 2020, Lin Zhijun, https://github.com/toolgood/ToolGood.Words +# Licensed under the Apache License 2.0 +# 更新日志 +# 2020.04.06 第一次提交 +# 2020.05.16 修改,支持大于0xffff的字符 + +__all__ = ['WordsSearch'] +__author__ = 'Lin Zhijun' +__date__ = '2020.05.16' + +class TrieNode(): + def __init__(self): + self.Index = 0 + self.Index = 0 + self.Layer = 0 + self.End = False + self.Char = '' + self.Results = [] + self.m_values = {} + self.Failure = None + self.Parent = None + + def Add(self,c): + if c in self.m_values : + return self.m_values[c] + node = TrieNode() + node.Parent = self + node.Char = c + self.m_values[c] = node + return node + + def SetResults(self,index): + if (self.End == False): + self.End = True + self.Results.append(index) + +class TrieNode2(): + def __init__(self): + self.End = False + self.Results = [] + self.m_values = {} + self.minflag = 0xffff + self.maxflag = 0 + + def Add(self,c,node3): + if (self.minflag > c): + self.minflag = c + if (self.maxflag < c): + self.maxflag = c + self.m_values[c] = node3 + + def SetResults(self,index): + if (self.End == False) : + self.End = True + if (index in self.Results )==False : + self.Results.append(index) + + def HasKey(self,c): + return c in self.m_values + + + def TryGetValue(self,c): + if (self.minflag <= c and self.maxflag >= c): + if c in self.m_values: + return self.m_values[c] + return None + + +class WordsSearch(): + def __init__(self): + self._first = {} + self._keywords = [] + self._indexs=[] + + def SetKeywords(self,keywords): + self._keywords = keywords + self._indexs=[] + for i in range(len(keywords)): + self._indexs.append(i) + + root = TrieNode() + allNodeLayer={} + + for i in range(len(self._keywords)): # for (i = 0; i < _keywords.length; i++) + p = self._keywords[i] + nd = root + for j in range(len(p)): # for (j = 0; j < p.length; j++) + nd = nd.Add(ord(p[j])) + if (nd.Layer == 0): + nd.Layer = j + 1 + if nd.Layer in allNodeLayer: + allNodeLayer[nd.Layer].append(nd) + else: + allNodeLayer[nd.Layer]=[] + allNodeLayer[nd.Layer].append(nd) + nd.SetResults(i) + + + allNode = [] + allNode.append(root) + for key in allNodeLayer.keys(): + for nd in allNodeLayer[key]: + allNode.append(nd) + allNodeLayer=None + + for i in range(len(allNode)): # for (i = 0; i < allNode.length; i++) + if i==0 : + continue + nd=allNode[i] + nd.Index = i + r = nd.Parent.Failure + c = nd.Char + while (r != None and (c in r.m_values)==False): + r = r.Failure + if (r == None): + nd.Failure = root + else: + nd.Failure = r.m_values[c] + for key2 in nd.Failure.Results : + nd.SetResults(key2) + root.Failure = root + + allNode2 = [] + for i in range(len(allNode)): # for (i = 0; i < allNode.length; i++) + allNode2.append( TrieNode2()) + + for i in range(len(allNode2)): # for (i = 0; i < allNode2.length; i++) + oldNode = allNode[i] + newNode = allNode2[i] + + for key in oldNode.m_values : + index = oldNode.m_values[key].Index + newNode.Add(key, allNode2[index]) + + for index in range(len(oldNode.Results)): # for (index = 0; index < oldNode.Results.length; index++) + item = oldNode.Results[index] + newNode.SetResults(item) + + oldNode=oldNode.Failure + while oldNode != root: + for key in oldNode.m_values : + if (newNode.HasKey(key) == False): + index = oldNode.m_values[key].Index + newNode.Add(key, allNode2[index]) + for index in range(len(oldNode.Results)): + item = oldNode.Results[index] + newNode.SetResults(item) + oldNode=oldNode.Failure + allNode = None + root = None + + # first = [] + # for index in range(65535):# for (index = 0; index < 0xffff; index++) + # first.append(None) + + # for key in allNode2[0].m_values : + # first[key] = allNode2[0].m_values[key] + + self._first = allNode2[0] + + + def FindFirst(self,text): + ptr = None + for index in range(len(text)): # for (index = 0; index < text.length; index++) + t =ord(text[index]) # text.charCodeAt(index) + tn = None + if (ptr == None): + tn = self._first.TryGetValue(t) + else: + tn = ptr.TryGetValue(t) + if (tn==None): + tn = self._first.TryGetValue(t) + + + if (tn != None): + if (tn.End): + item = tn.Results[0] + keyword = self._keywords[item] + return { "Keyword": keyword, "Success": True, "End": index, "Start": index + 1 - len(keyword), "Index": self._indexs[item] } + ptr = tn + return None + + def FindAll(self,text): + ptr = None + list = [] + + for index in range(len(text)): # for (index = 0; index < text.length; index++) + t =ord(text[index]) # text.charCodeAt(index) + tn = None + if (ptr == None): + tn = self._first.TryGetValue(t) + else: + tn = ptr.TryGetValue(t) + if (tn==None): + tn = self._first.TryGetValue(t) + + + if (tn != None): + if (tn.End): + for j in range(len(tn.Results)): # for (j = 0; j < tn.Results.length; j++) + item = tn.Results[j] + keyword = self._keywords[item] + list.append({ "Keyword": keyword, "Success": True, "End": index, "Start": index + 1 - len(keyword), "Index": self._indexs[item] }) + ptr = tn + return list + + + def ContainsAny(self,text): + ptr = None + for index in range(len(text)): # for (index = 0; index < text.length; index++) + t =ord(text[index]) # text.charCodeAt(index) + tn = None + if (ptr == None): + tn = self._first.TryGetValue(t) + else: + tn = ptr.TryGetValue(t) + if (tn==None): + tn = self._first.TryGetValue(t) + + if (tn != None): + if (tn.End): + return True + ptr = tn + return False + + def Replace(self,text, replaceChar = '*'): + result = list(text) + + ptr = None + for i in range(len(text)): # for (i = 0; i < text.length; i++) + t =ord(text[i]) # text.charCodeAt(index) + tn = None + if (ptr == None): + tn = self._first.TryGetValue(t) + else: + tn = ptr.TryGetValue(t) + if (tn==None): + tn = self._first.TryGetValue(t) + + if (tn != None): + if (tn.End): + maxLength = len( self._keywords[tn.Results[0]]) + start = i + 1 - maxLength + for j in range(start,i+1): # for (j = start; j <= i; j++) + result[j] = replaceChar + ptr = tn + return ''.join(result) \ No newline at end of file diff --git a/plugins/banwords/__init__.py b/plugins/banwords/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plugins/banwords/banwords.py b/plugins/banwords/banwords.py new file mode 100644 index 0000000..2b4a711 --- /dev/null +++ b/plugins/banwords/banwords.py @@ -0,0 +1,63 @@ +# encoding:utf-8 + +import json +import os +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType +import plugins +from plugins import * +from common.log import logger +from .WordsSearch import WordsSearch + + +@plugins.register(name="Banwords", desc="判断消息中是否有敏感词、决定是否回复。", version="1.0", author="lanvent", desire_priority= 100) +class Banwords(Plugin): + def __init__(self): + super().__init__() + try: + curdir=os.path.dirname(__file__) + config_path=os.path.join(curdir,"config.json") + conf=None + if not os.path.exists(config_path): + conf={"action":"ignore"} + with open(config_path,"w") as f: + json.dump(conf,f,indent=4) + else: + with open(config_path,"r") as f: + conf=json.load(f) + self.searchr = WordsSearch() + self.action = conf["action"] + banwords_path = os.path.join(curdir,"banwords.txt") + with open(banwords_path, 'r', encoding='utf-8') as f: + words=[] + for line in f: + word = line.strip() + if word: + words.append(word) + self.searchr.SetKeywords(words) + self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context + logger.info("[Banwords] inited") + except Exception as e: + logger.error("Banwords init failed: %s" % e) + + + + def on_handle_context(self, e_context: EventContext): + + if e_context['context'].type not in [ContextType.TEXT,ContextType.IMAGE_CREATE]: + return + + content = e_context['context'].content + logger.debug("[Banwords] on_handle_context. content: %s" % content) + if self.action == "ignore": + f = self.searchr.FindFirst(content) + if f: + logger.info("Banwords: %s" % f["Keyword"]) + e_context.action = EventAction.BREAK_PASS + return + elif self.action == "replace": + if self.searchr.ContainsAny(content): + reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n"+self.searchr.Replace(content)) + e_context['reply'] = reply + e_context.action = EventAction.BREAK_PASS + return \ No newline at end of file diff --git a/plugins/banwords/banwords.txt.template b/plugins/banwords/banwords.txt.template new file mode 100644 index 0000000..9b2e8ed --- /dev/null +++ b/plugins/banwords/banwords.txt.template @@ -0,0 +1,3 @@ +nipples +pennis +法轮功 \ No newline at end of file diff --git a/plugins/banwords/config.json.template b/plugins/banwords/config.json.template new file mode 100644 index 0000000..000fdda --- /dev/null +++ b/plugins/banwords/config.json.template @@ -0,0 +1,3 @@ +{ + "action": "ignore" +} \ No newline at end of file diff --git a/plugins/godcmd/config.json.template b/plugins/godcmd/config.json.template new file mode 100644 index 0000000..5240738 --- /dev/null +++ b/plugins/godcmd/config.json.template @@ -0,0 +1,4 @@ +{ + "password": "", + "admin_users": [] +} \ No newline at end of file diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index 296b9a9..2a942df 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -273,9 +273,13 @@ class Godcmd(Plugin): if isadmin: return False,"管理员账号无需认证" - + + if len(self.password) == 0: + return False,"未设置口令,无法认证" + if len(args) != 1: return False,"请提供口令" + password = args[0] if password == self.password: self.admin_users.append(userid) From 2568322879158cf0ecaa5327affffde1a16fa478 Mon Sep 17 00:00:00 2001 From: lanvent Date: Tue, 14 Mar 2023 18:02:07 +0800 Subject: [PATCH 21/21] plugin: ignore cases when manage plugins --- plugins/godcmd/godcmd.py | 2 +- plugins/plugin_manager.py | 36 ++++++++++++++++++++++-------------- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index 2a942df..c33c1e2 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -198,7 +198,7 @@ class Godcmd(Plugin): ok = True result = "插件列表:\n" for name,plugincls in plugins.items(): - result += f"{name}_v{plugincls.version} {plugincls.priority} - " + result += f"{plugincls.name}_v{plugincls.version} {plugincls.priority} - " if plugincls.enabled: result += "已启用\n" else: diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py index 9fada89..d946786 100644 --- a/plugins/plugin_manager.py +++ b/plugins/plugin_manager.py @@ -26,7 +26,7 @@ class PluginManager: plugincls.author = author plugincls.priority = desire_priority plugincls.enabled = True - self.plugins[name] = plugincls + self.plugins[name.upper()] = plugincls logger.info("Plugin %s_v%s registered" % (name, version)) return plugincls return wrapper @@ -67,14 +67,15 @@ class PluginManager: new_plugins = [] modified = False for name, plugincls in self.plugins.items(): - if name not in pconf["plugins"]: + rawname = plugincls.name + if rawname not in pconf["plugins"]: new_plugins.append(plugincls) modified = True logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name) - pconf["plugins"][name] = {"enabled": plugincls.enabled, "priority": plugincls.priority} + pconf["plugins"][rawname] = {"enabled": plugincls.enabled, "priority": plugincls.priority} else: - self.plugins[name].enabled = pconf["plugins"][name]["enabled"] - self.plugins[name].priority = pconf["plugins"][name]["priority"] + self.plugins[name].enabled = pconf["plugins"][rawname]["enabled"] + self.plugins[name].priority = pconf["plugins"][rawname]["priority"] self.plugins._update_heap(name) # 更新下plugins中的顺序 if modified: self.save_config() @@ -96,7 +97,8 @@ class PluginManager: self.listening_plugins[event].append(name) self.refresh_order() - def reload_plugin(self, name): + def reload_plugin(self, name:str): + name = name.upper() if name in self.instances: for event in self.listening_plugins: if name in self.listening_plugins[event]: @@ -112,7 +114,7 @@ class PluginManager: pconf = self.pconf logger.debug("plugins.json config={}".format(pconf)) for name,plugin in pconf["plugins"].items(): - if name not in self.plugins: + if name.upper() not in self.plugins: logger.error("Plugin %s not found, but found in plugins.json" % name) self.activate_plugins() @@ -125,36 +127,42 @@ class PluginManager: instance.handlers[e_context.event](e_context, *args, **kwargs) return e_context - def set_plugin_priority(self,name,priority): + def set_plugin_priority(self, name:str, priority:int): + name = name.upper() if name not in self.plugins: return False if self.plugins[name].priority == priority: return True self.plugins[name].priority = priority self.plugins._update_heap(name) - self.pconf["plugins"][name]["priority"] = priority - self.pconf["plugins"]._update_heap(name) + rawname = self.plugins[name].name + self.pconf["plugins"][rawname]["priority"] = priority + self.pconf["plugins"]._update_heap(rawname) self.save_config() self.refresh_order() return True - def enable_plugin(self,name): + def enable_plugin(self, name:str): + name = name.upper() if name not in self.plugins: return False if not self.plugins[name].enabled : self.plugins[name].enabled = True - self.pconf["plugins"][name]["enabled"] = True + rawname = self.plugins[name].name + self.pconf["plugins"][rawname]["enabled"] = True self.save_config() self.activate_plugins() return True return True - def disable_plugin(self,name): + def disable_plugin(self, name:str): + name = name.upper() if name not in self.plugins: return False if self.plugins[name].enabled : self.plugins[name].enabled = False - self.pconf["plugins"][name]["enabled"] = False + rawname = self.plugins[name].name + self.pconf["plugins"][rawname]["enabled"] = False self.save_config() return True return True