From 38c8ceba12b9a0e1e6ce2e52de6e4cf6faa3a1fe Mon Sep 17 00:00:00 2001 From: lanvent Date: Sat, 11 Mar 2023 02:20:39 +0800 Subject: [PATCH 01/29] avoid repeatedly instantiating bot --- bot/chatgpt/chat_gpt_bot.py | 49 ++++++++++++++++--------------------- bridge/bridge.py | 19 ++++++++++---- common/singleton.py | 9 +++++++ 3 files changed, 44 insertions(+), 33 deletions(-) create mode 100644 common/singleton.py diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index 78e344f..f94b1f7 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -7,16 +7,13 @@ from common.expired_dict import ExpiredDict import openai import time -if conf().get('expires_in_seconds'): - all_sessions = ExpiredDict(conf().get('expires_in_seconds')) -else: - all_sessions = dict() # OpenAI对话模型API (可用) class ChatGPTBot(Bot): def __init__(self): openai.api_key = conf().get('open_ai_api_key') proxy = conf().get('proxy') + self.sessions=SessionManager() if proxy: openai.proxy = proxy @@ -26,16 +23,16 @@ class ChatGPTBot(Bot): logger.info("[OPEN_AI] query={}".format(query)) session_id = context.get('session_id') or context.get('from_user_id') if query == '#清除记忆': - Session.clear_session(session_id) + self.sessions.clear_session(session_id) return '记忆已清除' elif query == '#清除所有': - Session.clear_all_session() + self.sessions.clear_all_session() return '所有人记忆已清除' elif query == '#更新配置': load_config() return '配置已更新' - session = Session.build_session_query(query, session_id) + session = self.sessions.build_session_query(query, session_id) logger.debug("[OPEN_AI] session query={}".format(session)) # if context.get('stream'): @@ -45,7 +42,7 @@ class ChatGPTBot(Bot): reply_content = self.reply_text(session, session_id, 0) logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"])) if reply_content["completion_tokens"] > 0: - Session.save_session(reply_content["content"], session_id, reply_content["total_tokens"]) + self.sessions.save_session(reply_content["content"], session_id, reply_content["total_tokens"]) return reply_content["content"] elif context.get('type', None) == 'IMAGE_CREATE': @@ -94,7 +91,7 @@ class ChatGPTBot(Bot): except Exception as e: # unknown exception logger.exception(e) - Session.clear_session(session_id) + self.sessions.clear_session(session_id) return {"completion_tokens": 0, "content": "请再问我一次吧"} def create_img(self, query, retry_count=0): @@ -119,10 +116,11 @@ class ChatGPTBot(Bot): except Exception as e: logger.exception(e) return None - -class Session(object): - @staticmethod - def build_session_query(query, session_id): + +class SessionManager(object): + def __init__(self): + self.sessions = {} + def build_session_query(self,query, session_id): ''' build query with conversation history e.g. [ @@ -135,36 +133,33 @@ class Session(object): :param session_id: session id :return: query content with conversaction ''' - session = all_sessions.get(session_id, []) + session = self.sessions.get(session_id, []) if len(session) == 0: system_prompt = conf().get("character_desc", "") system_item = {'role': 'system', 'content': system_prompt} session.append(system_item) - all_sessions[session_id] = session + self.sessions[session_id] = session user_item = {'role': 'user', 'content': query} session.append(user_item) return session - @staticmethod - def save_session(answer, session_id, total_tokens): + def save_session(self, answer, session_id, total_tokens): max_tokens = conf().get("conversation_max_tokens") if not max_tokens: # default 3000 max_tokens = 1000 max_tokens=int(max_tokens) - session = all_sessions.get(session_id) + session = self.sessions.get(session_id) if session: # append conversation gpt_item = {'role': 'assistant', 'content': answer} session.append(gpt_item) # discard exceed limit conversation - Session.discard_exceed_conversation(session, max_tokens, total_tokens) + self.discard_exceed_conversation(session, max_tokens, total_tokens) - - @staticmethod - def discard_exceed_conversation(session, max_tokens, total_tokens): + def discard_exceed_conversation(self, session, max_tokens, total_tokens): dec_tokens = int(total_tokens) # logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens)) while dec_tokens > max_tokens: @@ -176,10 +171,8 @@ class Session(object): break dec_tokens = dec_tokens - max_tokens - @staticmethod - def clear_session(session_id): - all_sessions[session_id] = [] + def clear_session(self,session_id): + self.sessions[session_id] = [] - @staticmethod - def clear_all_session(): - all_sessions.clear() + def clear_all_session(self): + self.sessions.clear() diff --git a/bridge/bridge.py b/bridge/bridge.py index e739a7f..068a58e 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -1,16 +1,25 @@ from bot import bot_factory +from common.singleton import singleton from voice import voice_factory - +@singleton class Bridge(object): def __init__(self): - pass + self.bots = { + "chat": bot_factory.create_bot("chatGPT"), + "voice_to_text": voice_factory.create_voice("openai"), + # "text_to_voice": voice_factory.create_voice("baidu") + } + try: + self.bots["text_to_voice"] = voice_factory.create_voice("baidu") + except ModuleNotFoundError as e: + print(e) def fetch_reply_content(self, query, context): - return bot_factory.create_bot("chatGPT").reply(query, context) + return self.bots["chat"].reply(query, context) def fetch_voice_to_text(self, voiceFile): - return voice_factory.create_voice("openai").voiceToText(voiceFile) + return self.bots["voice_to_text"].voiceToText(voiceFile) def fetch_text_to_voice(self, text): - return voice_factory.create_voice("baidu").textToVoice(text) \ No newline at end of file + return self.bots["text_to_voice"].textToVoice(text) \ No newline at end of file diff --git a/common/singleton.py b/common/singleton.py new file mode 100644 index 0000000..b46095c --- /dev/null +++ b/common/singleton.py @@ -0,0 +1,9 @@ +def singleton(cls): + instances = {} + + def get_instance(*args, **kwargs): + if cls not in instances: + instances[cls] = cls(*args, **kwargs) + return instances[cls] + + return get_instance From d6037422ac9c32523083dc301da63faa10eb352a Mon Sep 17 00:00:00 2001 From: lanvent Date: Sun, 12 Mar 2023 00:58:49 +0800 Subject: [PATCH 02/29] decouple message processing process --- bot/chatgpt/chat_gpt_bot.py | 45 ++++-- bridge/bridge.py | 5 + channel/wechat/wechat_channel.py | 253 ++++++++++++++++--------------- 3 files changed, 167 insertions(+), 136 deletions(-) diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index f94b1f7..327cc95 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -19,19 +19,24 @@ class ChatGPTBot(Bot): def reply(self, query, context=None): # acquire reply content - if not context or not context.get('type') or context.get('type') == 'TEXT': + if context['type']=='TEXT': logger.info("[OPEN_AI] query={}".format(query)) - session_id = context.get('session_id') or context.get('from_user_id') + session_id = context['session_id'] + reply=None if query == '#清除记忆': self.sessions.clear_session(session_id) - return '记忆已清除' + reply={'type':'INFO', 'content':'记忆已清除'} elif query == '#清除所有': self.sessions.clear_all_session() - return '所有人记忆已清除' + reply={'type':'INFO', 'content':'所有人记忆已清除'} elif query == '#更新配置': load_config() - return '配置已更新' - + reply={'type':'INFO', 'content':'配置已更新'} + elif query == '#DEBUG': + logger.setLevel('DEBUG') + reply={'type':'INFO', 'content':'DEBUG模式已开启'} + if reply: + return reply session = self.sessions.build_session_query(query, session_id) logger.debug("[OPEN_AI] session query={}".format(session)) @@ -41,12 +46,26 @@ class ChatGPTBot(Bot): reply_content = self.reply_text(session, session_id, 0) logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"])) - if reply_content["completion_tokens"] > 0: + if reply_content['completion_tokens']==0 and len(reply_content['content'])>0: + reply={'type':'ERROR', 'content':reply_content['content']} + elif reply_content["completion_tokens"] > 0: self.sessions.save_session(reply_content["content"], session_id, reply_content["total_tokens"]) - return reply_content["content"] + reply={'type':'TEXT', 'content':reply_content["content"]} + else: + reply={'type':'ERROR', 'content':reply_content['content']} + logger.debug("[OPEN_AI] reply {} used 0 tokens.".format(reply_content)) + return reply - elif context.get('type', None) == 'IMAGE_CREATE': - return self.create_img(query, 0) + elif context['type'] == 'IMAGE_CREATE': + ok, retstring=self.create_img(query, 0) + reply=None + if ok: + reply = {'type':'IMAGE', 'content':retstring} + else: + reply = {'type':'ERROR', 'content':retstring} + return reply + else: + reply= {'type':'ERROR', 'content':'Bot不支持处理{}类型的消息'.format(context['type'])} def reply_text(self, session, session_id, retry_count=0) ->dict: ''' @@ -104,7 +123,7 @@ class ChatGPTBot(Bot): ) image_url = response['data'][0]['url'] logger.info("[OPEN_AI] image_url={}".format(image_url)) - return image_url + return True,image_url except openai.error.RateLimitError as e: logger.warn(e) if retry_count < 1: @@ -112,10 +131,10 @@ class ChatGPTBot(Bot): logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1)) return self.create_img(query, retry_count+1) else: - return "提问太快啦,请休息一下再问我吧" + return False,"提问太快啦,请休息一下再问我吧" except Exception as e: logger.exception(e) - return None + return False,str(e) class SessionManager(object): def __init__(self): diff --git a/bridge/bridge.py b/bridge/bridge.py index 068a58e..392d9e8 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -15,6 +15,11 @@ class Bridge(object): except ModuleNotFoundError as e: print(e) + + # 以下所有函数需要得到一个reply字典,格式如下: + # reply["type"] = "ERROR" / "TEXT" / "VOICE" / ... + # reply["content"] = reply的内容 + def fetch_reply_content(self, query, context): return self.bots["chat"].reply(query, context) diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index d43a0a3..ddf38a4 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -46,62 +46,55 @@ class WechatChannel(Channel): # start message listener itchat.run() + + # handle_* 系列函数处理收到的消息后构造context,然后调用handle函数处理context + # context是一个字典,包含了消息的所有信息,包括以下key + # type: 消息类型,包括TEXT、VOICE、CMD_IMAGE_CREATE + # content: 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是CMD_IMAGE_CREATE类型,content就是图片生成命令 + # session_id: 会话id + # isgroup: 是否是群聊 + # msg: 原始消息对象 + # receiver: 需要回复的对象 def handle_voice(self, msg): if conf().get('speech_recognition') != True : return logger.debug("[WX]receive voice msg: " + msg['FileName']) - thread_pool.submit(self._do_handle_voice, msg) - - def _do_handle_voice(self, msg): from_user_id = msg['FromUserName'] other_user_id = msg['User']['UserName'] if from_user_id == other_user_id: - file_name = TmpDir().path() + msg['FileName'] - msg.download(file_name) - query = super().build_voice_to_text(file_name) - if conf().get('voice_reply_voice'): - self._do_send_voice(query, from_user_id) - else: - self._do_send_text(query, from_user_id) + context = { 'isgroup': False, 'msg': msg, 'receiver': other_user_id} + context['type']='VOICE' + context['session_id']=other_user_id + thread_pool.submit(self.handle, context) + def handle_text(self, msg): logger.debug("[WX]receive text msg: " + json.dumps(msg, ensure_ascii=False)) content = msg['Text'] - self._handle_single_msg(msg, content) - - def _handle_single_msg(self, msg, content): from_user_id = msg['FromUserName'] to_user_id = msg['ToUserName'] # 接收人id other_user_id = msg['User']['UserName'] # 对手方id - match_prefix = self.check_prefix(content, conf().get('single_chat_prefix')) + match_prefix = check_prefix(content, conf().get('single_chat_prefix')) if "」\n- - - - - - - - - - - - - - -" in content: logger.debug("[WX]reference query skipped") return - if from_user_id == other_user_id and match_prefix is not None: - # 好友向自己发送消息 - if match_prefix != '': - str_list = content.split(match_prefix, 1) - if len(str_list) == 2: - content = str_list[1].strip() - - img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix')) - if img_match_prefix: - content = content.split(img_match_prefix, 1)[1].strip() - thread_pool.submit(self._do_send_img, content, from_user_id) - else : - thread_pool.submit(self._do_send_text, content, from_user_id) - elif to_user_id == other_user_id and match_prefix: - # 自己给好友发送消息 - str_list = content.split(match_prefix, 1) - if len(str_list) == 2: - content = str_list[1].strip() - img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix')) - if img_match_prefix: - content = content.split(img_match_prefix, 1)[1].strip() - thread_pool.submit(self._do_send_img, content, to_user_id) - else: - thread_pool.submit(self._do_send_text, content, to_user_id) + if match_prefix: + content=content.replace(match_prefix,'',1).strip() + else: + return + context = { 'isgroup': False, 'msg': msg, 'receiver': other_user_id} + context['session_id']=other_user_id + + img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) + if img_match_prefix: + content=content.replace(img_match_prefix,'',1).strip() + context['type']='CMD_IMAGE_CREATE' + else: + context['type']='TEXT' + + context['content']=content + thread_pool.submit(self.handle, context) def handle_group(self, msg): @@ -122,100 +115,114 @@ class WechatChannel(Channel): logger.debug("[WX]reference query skipped") return "" config = conf() - match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or self.check_prefix(origin_content, config.get('group_chat_prefix')) \ - or self.check_contain(origin_content, config.get('group_chat_keyword')) - if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or self.check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix: - img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix')) + match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or check_prefix(origin_content, config.get('group_chat_prefix')) \ + or check_contain(origin_content, config.get('group_chat_keyword')) + if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix: + context = { 'isgroup': True, 'msg': msg, 'receiver': group_id} + + img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) if img_match_prefix: - content = content.split(img_match_prefix, 1)[1].strip() - thread_pool.submit(self._do_send_img, content, group_id) + content=content.replace(img_match_prefix,'',1).strip() + context['type']='CMD_IMAGE_CREATE' else: - thread_pool.submit(self._do_send_group, content, msg) - - def send(self, msg, receiver): - itchat.send(msg, toUserName=receiver) - logger.info('[WX] sendMsg={}, receiver={}'.format(msg, receiver)) - - def _do_send_voice(self, query, reply_user_id): - try: - if not query: - return - context = dict() - context['from_user_id'] = reply_user_id - reply_text = super().build_reply_content(query, context) - if reply_text: - replyFile = super().build_text_to_voice(reply_text) - itchat.send_file(replyFile, toUserName=reply_user_id) - logger.info('[WX] sendFile={}, receiver={}'.format(replyFile, reply_user_id)) - except Exception as e: - logger.exception(e) - - def _do_send_text(self, query, reply_user_id): - try: - if not query: - return - context = dict() - context['session_id'] = reply_user_id - reply_text = super().build_reply_content(query, context) - if reply_text: - self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id) - except Exception as e: - logger.exception(e) - - def _do_send_img(self, query, reply_user_id): - try: - if not query: - return - context = dict() - context['type'] = 'IMAGE_CREATE' - img_url = super().build_reply_content(query, context) - if not img_url: - return - - # 图片下载 + context['type']='TEXT' + context['content']=content + + group_chat_in_one_session = conf().get('group_chat_in_one_session', []) + if ('ALL_GROUP' in group_chat_in_one_session or \ + group_name in group_chat_in_one_session or \ + check_contain(group_name, group_chat_in_one_session)): + context['session_id'] = group_id + else: + context['session_id'] = msg['ActualUserName'] + + thread_pool.submit(self.handle, context) + + # 统一的发送函数,根据reply的type字段发送不同类型的消息 + + def send(self, reply, receiver): + if reply['type']=='TEXT': + itchat.send(reply['content'], toUserName=receiver) + logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) + elif reply['type']=='ERROR' or reply['type']=='INFO': + itchat.send(reply['content'], toUserName=receiver) + logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) + elif reply['type']=='VOICE': + itchat.send_file(reply['content'], toUserName=receiver) + logger.info('[WX] sendFile={}, receiver={}'.format(reply['content'], receiver)) + elif reply['type']=='IMAGE_URL': # 从网络下载图片 + img_url = reply['content'] pic_res = requests.get(img_url, stream=True) image_storage = io.BytesIO() for block in pic_res.iter_content(1024): image_storage.write(block) image_storage.seek(0) - - # 图片发送 - itchat.send_image(image_storage, reply_user_id) - logger.info('[WX] sendImage, receiver={}'.format(reply_user_id)) - except Exception as e: - logger.exception(e) - - def _do_send_group(self, query, msg): - if not query: - return - context = dict() - group_name = msg['User']['NickName'] - group_id = msg['User']['UserName'] - group_chat_in_one_session = conf().get('group_chat_in_one_session', []) - if ('ALL_GROUP' in group_chat_in_one_session or \ - group_name in group_chat_in_one_session or \ - self.check_contain(group_name, group_chat_in_one_session)): - context['session_id'] = group_id + itchat.send_image(image_storage, toUserName=receiver) + logger.info('[WX] sendImage url=, receiver={}'.format(img_url,receiver)) + elif reply['type']=='IMAGE': # 从文件读取图片 + image_storage = reply['content'] + image_storage.seek(0) + itchat.send_image(image_storage, toUserName=receiver) + logger.info('[WX] sendImage, receiver={}'.format(receiver)) + + # 处理消息 + def handle(self, context): + content=context['content'] + reply=None + + logger.debug('[WX] ready to handle context: {}'.format(context)) + # reply的构建步骤 + if context['type']=='TEXT' or context['type']=='CMD_IMAGE_CREATE': + reply = super().build_reply_content(content,context) + elif context['type']=='VOICE': + msg=context['msg'] + file_name = TmpDir().path() + msg['FileName'] + msg.download(file_name) + reply = super().build_voice_to_text(file_name) + if reply['type'] != 'ERROR' and reply['type'] != 'INFO': + reply = super().build_reply_content(reply['content'],context) + if reply['type']=='TEXT': + if conf().get('voice_reply_voice'): + reply = super().build_text_to_voice(reply['content']) else: - context['session_id'] = msg['ActualUserName'] - reply_text = super().build_reply_content(query, context) - if reply_text: - reply_text = '@' + msg['ActualNickName'] + ' ' + reply_text.strip() - self.send(conf().get("group_chat_reply_prefix", "") + reply_text, group_id) - - - def check_prefix(self, content, prefix_list): - for prefix in prefix_list: - if content.startswith(prefix): - return prefix - return None + logger.error('[WX] unknown context type: {}'.format(context['type'])) + return + + logger.debug('[WX] ready to decorate reply: {}'.format(reply)) + # reply的包装步骤 + if reply: + if reply['type']=='TEXT': + reply_text=reply['content'] + if context['isgroup']: + reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip() + reply_text=conf().get("group_chat_reply_prefix","")+reply_text + else: + reply_text=conf().get("single_chat_reply_prefix","")+reply_text + reply['content']=reply_text + elif reply['type']=='ERROR' or reply['type']=='INFO': + reply['content']=reply['type']+": "+ reply['content'] + elif reply['type']=='IMAGE_URL' or reply['type']=='VOICE': + pass + else: + logger.error('[WX] unknown reply type: {}'.format(reply['type'])) + return + if reply: + logger.debug('[WX] ready to send reply: {} to {}'.format(reply,context['receiver'])) + self.send(reply, context['receiver']) + + +def check_prefix(content, prefix_list): + for prefix in prefix_list: + if content.startswith(prefix): + return prefix + return None - def check_contain(self, content, keyword_list): - if not keyword_list: - return None - for ky in keyword_list: - if content.find(ky) != -1: - return True +def check_contain(content, keyword_list): + if not keyword_list: return None + for ky in keyword_list: + if content.find(ky) != -1: + return True + return None From 9ae7b7773eaa3e0dcfa319ae953c5aa8bc14b1ab Mon Sep 17 00:00:00 2001 From: lanvent Date: Sun, 12 Mar 2023 01:10:18 +0800 Subject: [PATCH 03/29] simple compatibility for wechaty --- channel/wechat/wechaty_channel.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/channel/wechat/wechaty_channel.py b/channel/wechat/wechaty_channel.py index 5e01464..7ae4683 100644 --- a/channel/wechat/wechaty_channel.py +++ b/channel/wechat/wechaty_channel.py @@ -129,7 +129,9 @@ class WechatyChannel(Channel): return context = dict() context['session_id'] = reply_user_id - reply_text = super().build_reply_content(query, context) + context['type'] = 'TEXT' + context['content'] = query + reply_text = super().build_reply_content(query, context)['content'] if reply_text: await self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id) except Exception as e: @@ -141,7 +143,8 @@ class WechatyChannel(Channel): return context = dict() context['type'] = 'IMAGE_CREATE' - img_url = super().build_reply_content(query, context) + context['content'] = query + img_url = super().build_reply_content(query, context)['content'] if not img_url: return # 图片下载 @@ -170,7 +173,9 @@ class WechatyChannel(Channel): context['session_id'] = str(group_id) else: context['session_id'] = str(group_id) + '-' + str(group_user_id) - reply_text = super().build_reply_content(query, context) + context['type'] = 'TEXT' + context['content'] = query + reply_text = super().build_reply_content(query, context)['content'] if reply_text: reply_text = '@' + group_user_name + ' ' + reply_text.strip() await self.send_group(conf().get("group_chat_reply_prefix", "") + reply_text, group_id) @@ -181,7 +186,8 @@ class WechatyChannel(Channel): return context = dict() context['type'] = 'IMAGE_CREATE' - img_url = super().build_reply_content(query, context) + context['content'] = query + img_url = super().build_reply_content(query, context)['content'] if not img_url: return # 图片发送 From 9e07703eb1264367d147d50c679b40c580278cee Mon Sep 17 00:00:00 2001 From: lanvent Date: Sun, 12 Mar 2023 01:22:49 +0800 Subject: [PATCH 04/29] formatting code --- bot/chatgpt/chat_gpt_bot.py | 58 +++++++++--------- bridge/bridge.py | 4 +- channel/wechat/wechat_channel.py | 101 ++++++++++++++++--------------- 3 files changed, 83 insertions(+), 80 deletions(-) diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index 327cc95..9bfea5b 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -13,28 +13,28 @@ class ChatGPTBot(Bot): def __init__(self): openai.api_key = conf().get('open_ai_api_key') proxy = conf().get('proxy') - self.sessions=SessionManager() + self.sessions = SessionManager() if proxy: openai.proxy = proxy def reply(self, query, context=None): # acquire reply content - if context['type']=='TEXT': + if context['type'] == 'TEXT': logger.info("[OPEN_AI] query={}".format(query)) session_id = context['session_id'] - reply=None + reply = None if query == '#清除记忆': self.sessions.clear_session(session_id) - reply={'type':'INFO', 'content':'记忆已清除'} + reply = {'type': 'INFO', 'content': '记忆已清除'} elif query == '#清除所有': self.sessions.clear_all_session() - reply={'type':'INFO', 'content':'所有人记忆已清除'} + reply = {'type': 'INFO', 'content': '所有人记忆已清除'} elif query == '#更新配置': load_config() - reply={'type':'INFO', 'content':'配置已更新'} + reply = {'type': 'INFO', 'content': '配置已更新'} elif query == '#DEBUG': logger.setLevel('DEBUG') - reply={'type':'INFO', 'content':'DEBUG模式已开启'} + reply = {'type': 'INFO', 'content': 'DEBUG模式已开启'} if reply: return reply session = self.sessions.build_session_query(query, session_id) @@ -46,28 +46,28 @@ class ChatGPTBot(Bot): reply_content = self.reply_text(session, session_id, 0) logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"])) - if reply_content['completion_tokens']==0 and len(reply_content['content'])>0: - reply={'type':'ERROR', 'content':reply_content['content']} + if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0: + reply = {'type': 'ERROR', 'content': reply_content['content']} elif reply_content["completion_tokens"] > 0: self.sessions.save_session(reply_content["content"], session_id, reply_content["total_tokens"]) reply={'type':'TEXT', 'content':reply_content["content"]} else: - reply={'type':'ERROR', 'content':reply_content['content']} + reply = {'type': 'ERROR', 'content': reply_content['content']} logger.debug("[OPEN_AI] reply {} used 0 tokens.".format(reply_content)) return reply elif context['type'] == 'IMAGE_CREATE': - ok, retstring=self.create_img(query, 0) - reply=None + ok, retstring = self.create_img(query, 0) + reply = None if ok: - reply = {'type':'IMAGE', 'content':retstring} + reply = {'type': 'IMAGE', 'content': retstring} else: - reply = {'type':'ERROR', 'content':retstring} + reply = {'type': 'ERROR', 'content': retstring} return reply else: reply= {'type':'ERROR', 'content':'Bot不支持处理{}类型的消息'.format(context['type'])} - def reply_text(self, session, session_id, retry_count=0) ->dict: + def reply_text(self, session, session_id, retry_count=0) -> dict: ''' call openai's ChatCompletion to get the answer :param session: a conversation session @@ -86,8 +86,8 @@ class ChatGPTBot(Bot): presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容 ) # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) - return {"total_tokens": response["usage"]["total_tokens"], - "completion_tokens": response["usage"]["completion_tokens"], + return {"total_tokens": response["usage"]["total_tokens"], + "completion_tokens": response["usage"]["completion_tokens"], "content": response.choices[0]['message']['content']} except openai.error.RateLimitError as e: # rate limit exception @@ -102,11 +102,11 @@ class ChatGPTBot(Bot): # api connection exception logger.warn(e) logger.warn("[OPEN_AI] APIConnection failed") - return {"completion_tokens": 0, "content":"我连接不到你的网络"} + return {"completion_tokens": 0, "content": "我连接不到你的网络"} except openai.error.Timeout as e: logger.warn(e) logger.warn("[OPEN_AI] Timeout") - return {"completion_tokens": 0, "content":"我没有收到你的消息"} + return {"completion_tokens": 0, "content": "我没有收到你的消息"} except Exception as e: # unknown exception logger.exception(e) @@ -123,7 +123,7 @@ class ChatGPTBot(Bot): ) image_url = response['data'][0]['url'] logger.info("[OPEN_AI] image_url={}".format(image_url)) - return True,image_url + return True, image_url except openai.error.RateLimitError as e: logger.warn(e) if retry_count < 1: @@ -131,15 +131,17 @@ class ChatGPTBot(Bot): logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1)) return self.create_img(query, retry_count+1) else: - return False,"提问太快啦,请休息一下再问我吧" + return False, "提问太快啦,请休息一下再问我吧" except Exception as e: logger.exception(e) - return False,str(e) - + return False, str(e) + + class SessionManager(object): def __init__(self): self.sessions = {} - def build_session_query(self,query, session_id): + + def build_session_query(self, query, session_id): ''' build query with conversation history e.g. [ @@ -167,7 +169,7 @@ class SessionManager(object): if not max_tokens: # default 3000 max_tokens = 1000 - max_tokens=int(max_tokens) + max_tokens = int(max_tokens) session = self.sessions.get(session_id) if session: @@ -177,7 +179,7 @@ class SessionManager(object): # discard exceed limit conversation self.discard_exceed_conversation(session, max_tokens, total_tokens) - + def discard_exceed_conversation(self, session, max_tokens, total_tokens): dec_tokens = int(total_tokens) # logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens)) @@ -187,10 +189,10 @@ class SessionManager(object): session.pop(1) session.pop(1) else: - break + break dec_tokens = dec_tokens - max_tokens - def clear_session(self,session_id): + def clear_session(self, session_id): self.sessions[session_id] = [] def clear_all_session(self): diff --git a/bridge/bridge.py b/bridge/bridge.py index 392d9e8..81e5d73 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -2,6 +2,7 @@ from bot import bot_factory from common.singleton import singleton from voice import voice_factory + @singleton class Bridge(object): def __init__(self): @@ -15,7 +16,6 @@ class Bridge(object): except ModuleNotFoundError as e: print(e) - # 以下所有函数需要得到一个reply字典,格式如下: # reply["type"] = "ERROR" / "TEXT" / "VOICE" / ... # reply["content"] = reply的内容 @@ -27,4 +27,4 @@ class Bridge(object): return self.bots["voice_to_text"].voiceToText(voiceFile) def fetch_text_to_voice(self, text): - return self.bots["text_to_voice"].textToVoice(text) \ No newline at end of file + return self.bots["text_to_voice"].textToVoice(text) diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index ddf38a4..e8be17e 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -46,7 +46,7 @@ class WechatChannel(Channel): # start message listener itchat.run() - + # handle_* 系列函数处理收到的消息后构造context,然后调用handle函数处理context # context是一个字典,包含了消息的所有信息,包括以下key # type: 消息类型,包括TEXT、VOICE、CMD_IMAGE_CREATE @@ -57,18 +57,17 @@ class WechatChannel(Channel): # receiver: 需要回复的对象 def handle_voice(self, msg): - if conf().get('speech_recognition') != True : + if conf().get('speech_recognition') != True: return logger.debug("[WX]receive voice msg: " + msg['FileName']) from_user_id = msg['FromUserName'] other_user_id = msg['User']['UserName'] if from_user_id == other_user_id: - context = { 'isgroup': False, 'msg': msg, 'receiver': other_user_id} - context['type']='VOICE' - context['session_id']=other_user_id + context = {'isgroup': False, 'msg': msg, 'receiver': other_user_id} + context['type'] = 'VOICE' + context['session_id'] = other_user_id thread_pool.submit(self.handle, context) - def handle_text(self, msg): logger.debug("[WX]receive text msg: " + json.dumps(msg, ensure_ascii=False)) content = msg['Text'] @@ -80,22 +79,21 @@ class WechatChannel(Channel): logger.debug("[WX]reference query skipped") return if match_prefix: - content=content.replace(match_prefix,'',1).strip() + content = content.replace(match_prefix, '', 1).strip() else: return - context = { 'isgroup': False, 'msg': msg, 'receiver': other_user_id} - context['session_id']=other_user_id - + context = {'isgroup': False, 'msg': msg, 'receiver': other_user_id} + context['session_id'] = other_user_id + img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) if img_match_prefix: - content=content.replace(img_match_prefix,'',1).strip() - context['type']='CMD_IMAGE_CREATE' + content = content.replace(img_match_prefix, '', 1).strip() + context['type'] = 'CMD_IMAGE_CREATE' else: - context['type']='TEXT' - - context['content']=content - thread_pool.submit(self.handle, context) + context['type'] = 'TEXT' + context['content'] = content + thread_pool.submit(self.handle, context) def handle_group(self, msg): logger.debug("[WX]receive group msg: " + json.dumps(msg, ensure_ascii=False)) @@ -122,32 +120,32 @@ class WechatChannel(Channel): img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) if img_match_prefix: - content=content.replace(img_match_prefix,'',1).strip() - context['type']='CMD_IMAGE_CREATE' + content = content.replace(img_match_prefix, '', 1).strip() + context['type'] = 'CMD_IMAGE_CREATE' else: - context['type']='TEXT' - context['content']=content + context['type'] = 'TEXT' + context['content'] = content group_chat_in_one_session = conf().get('group_chat_in_one_session', []) - if ('ALL_GROUP' in group_chat_in_one_session or \ - group_name in group_chat_in_one_session or \ + if ('ALL_GROUP' in group_chat_in_one_session or + group_name in group_chat_in_one_session or check_contain(group_name, group_chat_in_one_session)): context['session_id'] = group_id else: context['session_id'] = msg['ActualUserName'] - + thread_pool.submit(self.handle, context) - + # 统一的发送函数,根据reply的type字段发送不同类型的消息 def send(self, reply, receiver): - if reply['type']=='TEXT': + if reply['type'] == 'TEXT': itchat.send(reply['content'], toUserName=receiver) logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) - elif reply['type']=='ERROR' or reply['type']=='INFO': + elif reply['type'] == 'ERROR' or reply['type'] == 'INFO': itchat.send(reply['content'], toUserName=receiver) logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) - elif reply['type']=='VOICE': + elif reply['type'] == 'VOICE': itchat.send_file(reply['content'], toUserName=receiver) logger.info('[WX] sendFile={}, receiver={}'.format(reply['content'], receiver)) elif reply['type']=='IMAGE_URL': # 从网络下载图片 @@ -164,52 +162,56 @@ class WechatChannel(Channel): image_storage.seek(0) itchat.send_image(image_storage, toUserName=receiver) logger.info('[WX] sendImage, receiver={}'.format(receiver)) - + # 处理消息 def handle(self, context): - content=context['content'] - reply=None + content = context['content'] + reply = None logger.debug('[WX] ready to handle context: {}'.format(context)) # reply的构建步骤 - if context['type']=='TEXT' or context['type']=='CMD_IMAGE_CREATE': - reply = super().build_reply_content(content,context) - elif context['type']=='VOICE': - msg=context['msg'] + if context['type'] == 'TEXT' or context['type'] == 'CMD_IMAGE_CREATE': + reply = super().build_reply_content(content, context) + elif context['type'] == 'VOICE': + msg = context['msg'] file_name = TmpDir().path() + msg['FileName'] msg.download(file_name) reply = super().build_voice_to_text(file_name) if reply['type'] != 'ERROR' and reply['type'] != 'INFO': - reply = super().build_reply_content(reply['content'],context) - if reply['type']=='TEXT': + reply = super().build_reply_content(reply['content'], context) + if reply['type'] == 'TEXT': if conf().get('voice_reply_voice'): reply = super().build_text_to_voice(reply['content']) else: logger.error('[WX] unknown context type: {}'.format(context['type'])) return - + logger.debug('[WX] ready to decorate reply: {}'.format(reply)) # reply的包装步骤 if reply: - if reply['type']=='TEXT': - reply_text=reply['content'] + if reply['type'] == 'TEXT': + reply_text = reply['content'] if context['isgroup']: - reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip() - reply_text=conf().get("group_chat_reply_prefix","")+reply_text + reply_text = '@' + \ + context['msg']['ActualNickName'] + \ + ' ' + reply_text.strip() + reply_text = conf().get("group_chat_reply_prefix", "")+reply_text else: - reply_text=conf().get("single_chat_reply_prefix","")+reply_text - reply['content']=reply_text - elif reply['type']=='ERROR' or reply['type']=='INFO': - reply['content']=reply['type']+": "+ reply['content'] - elif reply['type']=='IMAGE_URL' or reply['type']=='VOICE': + reply_text = conf().get("single_chat_reply_prefix", "")+reply_text + reply['content'] = reply_text + elif reply['type'] == 'ERROR' or reply['type'] == 'INFO': + reply['content'] = reply['type']+": " + reply['content'] + elif reply['type'] == 'IMAGE_URL' or reply['type'] == 'VOICE': pass else: - logger.error('[WX] unknown reply type: {}'.format(reply['type'])) + logger.error( + '[WX] unknown reply type: {}'.format(reply['type'])) return if reply: - logger.debug('[WX] ready to send reply: {} to {}'.format(reply,context['receiver'])) + logger.debug('[WX] ready to send reply: {} to {}'.format( + reply, context['receiver'])) self.send(reply, context['receiver']) - + def check_prefix(content, prefix_list): for prefix in prefix_list: @@ -225,4 +227,3 @@ def check_contain(content, keyword_list): if content.find(ky) != -1: return True return None - From 0fcf0824dcf7830bd60aaa025e6c115364e5c4d8 Mon Sep 17 00:00:00 2001 From: lanvent Date: Sun, 12 Mar 2023 11:53:06 +0800 Subject: [PATCH 05/29] feat: support plugins --- .gitignore | 1 + app.py | 7 ++- bot/chatgpt/chat_gpt_bot.py | 9 ++- channel/wechat/wechat_channel.py | 103 +++++++++++++++++-------------- plugins/__init__.py | 9 +++ plugins/event.py | 49 +++++++++++++++ plugins/plugin.py | 3 + plugins/plugin_manager.py | 89 ++++++++++++++++++++++++++ 8 files changed, 220 insertions(+), 50 deletions(-) create mode 100644 plugins/__init__.py create mode 100644 plugins/event.py create mode 100644 plugins/plugin.py create mode 100644 plugins/plugin_manager.py diff --git a/.gitignore b/.gitignore index 8bc62f3..c349037 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ config.json QR.png nohup.out tmp +plugins.json \ No newline at end of file diff --git a/app.py b/app.py index 1ca359f..f07b275 100644 --- a/app.py +++ b/app.py @@ -4,14 +4,17 @@ import config from channel import channel_factory from common.log import logger - +from plugins import * if __name__ == '__main__': try: # load config config.load_config() # create channel - channel = channel_factory.create_channel("wx") + channel_name='wx' + channel = channel_factory.create_channel(channel_name) + if channel_name=='wx': + PluginManager().load_plugins() # startup channel channel.startup() diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index 9bfea5b..2c8567d 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -60,12 +60,13 @@ class ChatGPTBot(Bot): ok, retstring = self.create_img(query, 0) reply = None if ok: - reply = {'type': 'IMAGE', 'content': retstring} + reply = {'type': 'IMAGE_URL', 'content': retstring} else: reply = {'type': 'ERROR', 'content': retstring} return reply else: reply= {'type':'ERROR', 'content':'Bot不支持处理{}类型的消息'.format(context['type'])} + return reply def reply_text(self, session, session_id, retry_count=0) -> dict: ''' @@ -139,7 +140,11 @@ class ChatGPTBot(Bot): class SessionManager(object): def __init__(self): - self.sessions = {} + if conf().get('expires_in_seconds'): + sessions = ExpiredDict(conf().get('expires_in_seconds')) + else: + sessions = dict() + self.sessions = sessions def build_session_query(self, query, session_id): ''' diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index e8be17e..f436e48 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -12,9 +12,12 @@ from concurrent.futures import ThreadPoolExecutor from common.log import logger from common.tmp_dir import TmpDir from config import conf +from plugins import * + import requests import io + thread_pool = ThreadPoolExecutor(max_workers=8) @@ -49,8 +52,8 @@ class WechatChannel(Channel): # handle_* 系列函数处理收到的消息后构造context,然后调用handle函数处理context # context是一个字典,包含了消息的所有信息,包括以下key - # type: 消息类型,包括TEXT、VOICE、CMD_IMAGE_CREATE - # content: 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是CMD_IMAGE_CREATE类型,content就是图片生成命令 + # type: 消息类型,包括TEXT、VOICE、IMAGE_CREATE + # content: 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是IMAGE_CREATE类型,content就是图片生成命令 # session_id: 会话id # isgroup: 是否是群聊 # msg: 原始消息对象 @@ -88,7 +91,7 @@ class WechatChannel(Channel): img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) if img_match_prefix: content = content.replace(img_match_prefix, '', 1).strip() - context['type'] = 'CMD_IMAGE_CREATE' + context['type'] = 'IMAGE_CREATE' else: context['type'] = 'TEXT' @@ -121,7 +124,7 @@ class WechatChannel(Channel): img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) if img_match_prefix: content = content.replace(img_match_prefix, '', 1).strip() - context['type'] = 'CMD_IMAGE_CREATE' + context['type'] = 'IMAGE_CREATE' else: context['type'] = 'TEXT' context['content'] = content @@ -136,8 +139,7 @@ class WechatChannel(Channel): thread_pool.submit(self.handle, context) - # 统一的发送函数,根据reply的type字段发送不同类型的消息 - + # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 def send(self, reply, receiver): if reply['type'] == 'TEXT': itchat.send(reply['content'], toUserName=receiver) @@ -163,54 +165,63 @@ class WechatChannel(Channel): itchat.send_image(image_storage, toUserName=receiver) logger.info('[WX] sendImage, receiver={}'.format(receiver)) - # 处理消息 + # 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类 def handle(self, context): - content = context['content'] - reply = None + reply = {} logger.debug('[WX] ready to handle context: {}'.format(context)) + # reply的构建步骤 - if context['type'] == 'TEXT' or context['type'] == 'CMD_IMAGE_CREATE': - reply = super().build_reply_content(content, context) - elif context['type'] == 'VOICE': - msg = context['msg'] - file_name = TmpDir().path() + msg['FileName'] - msg.download(file_name) - reply = super().build_voice_to_text(file_name) - if reply['type'] != 'ERROR' and reply['type'] != 'INFO': - reply = super().build_reply_content(reply['content'], context) - if reply['type'] == 'TEXT': - if conf().get('voice_reply_voice'): - reply = super().build_text_to_voice(reply['content']) - else: - logger.error('[WX] unknown context type: {}'.format(context['type'])) - return + e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply})) + reply=e_context['reply'] + if not e_context.is_pass(): + logger.debug('[WX] ready to handle context: type={}, content={}'.format(context['type'], context['content'])) + if context['type'] == 'TEXT' or context['type'] == 'IMAGE_CREATE': + reply = super().build_reply_content(context['content'], context) + elif context['type'] == 'VOICE': + msg = context['msg'] + file_name = TmpDir().path() + msg['FileName'] + msg.download(file_name) + reply = super().build_voice_to_text(file_name) + if reply['type'] != 'ERROR' and reply['type'] != 'INFO': + reply = super().build_reply_content(reply['content'], context) + if reply['type'] == 'TEXT': + if conf().get('voice_reply_voice'): + reply = super().build_text_to_voice(reply['content']) + else: + logger.error('[WX] unknown context type: {}'.format(context['type'])) + return logger.debug('[WX] ready to decorate reply: {}'.format(reply)) + # reply的包装步骤 - if reply: - if reply['type'] == 'TEXT': - reply_text = reply['content'] - if context['isgroup']: - reply_text = '@' + \ - context['msg']['ActualNickName'] + \ - ' ' + reply_text.strip() - reply_text = conf().get("group_chat_reply_prefix", "")+reply_text + if reply and reply['type']: + e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply})) + reply=e_context['reply'] + if not e_context.is_pass() and reply and reply['type']: + if reply['type'] == 'TEXT': + reply_text = reply['content'] + if context['isgroup']: + reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip() + reply_text = conf().get("group_chat_reply_prefix", "")+reply_text + else: + reply_text = conf().get("single_chat_reply_prefix", "")+reply_text + reply['content'] = reply_text + elif reply['type'] == 'ERROR' or reply['type'] == 'INFO': + reply['content'] = reply['type']+": " + reply['content'] + elif reply['type'] == 'IMAGE_URL' or reply['type'] == 'VOICE' or reply['type'] == 'IMAGE': + pass else: - reply_text = conf().get("single_chat_reply_prefix", "")+reply_text - reply['content'] = reply_text - elif reply['type'] == 'ERROR' or reply['type'] == 'INFO': - reply['content'] = reply['type']+": " + reply['content'] - elif reply['type'] == 'IMAGE_URL' or reply['type'] == 'VOICE': - pass - else: - logger.error( - '[WX] unknown reply type: {}'.format(reply['type'])) - return - if reply: - logger.debug('[WX] ready to send reply: {} to {}'.format( - reply, context['receiver'])) - self.send(reply, context['receiver']) + logger.error('[WX] unknown reply type: {}'.format(reply['type'])) + return + + # reply的发送步骤 + if reply and reply['type']: + e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply})) + reply=e_context['reply'] + if not e_context.is_pass() and reply and reply['type']: + logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver'])) + self.send(reply, context['receiver']) def check_prefix(content, prefix_list): diff --git a/plugins/__init__.py b/plugins/__init__.py new file mode 100644 index 0000000..6137d4a --- /dev/null +++ b/plugins/__init__.py @@ -0,0 +1,9 @@ +from .plugin_manager import PluginManager +from .event import * +from .plugin import * + +instance = PluginManager() + +register = instance.register +# load_plugins = instance.load_plugins +# emit_event = instance.emit_event diff --git a/plugins/event.py b/plugins/event.py new file mode 100644 index 0000000..a65e548 --- /dev/null +++ b/plugins/event.py @@ -0,0 +1,49 @@ +# encoding:utf-8 + +from enum import Enum + + +class Event(Enum): + # ON_RECEIVE_MESSAGE = 1 # 收到消息 + + ON_HANDLE_CONTEXT = 2 # 处理消息前 + """ + e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复,初始为空 } + """ + + ON_DECORATE_REPLY = 3 # 得到回复后准备装饰 + """ + e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 } + """ + + ON_SEND_REPLY = 4 # 发送回复前 + """ + e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 } + """ + + # AFTER_SEND_REPLY = 5 # 发送回复后 + + +class EventAction(Enum): + CONTINUE = 1 # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑 + BREAK = 2 # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑 + BREAK_PASS = 3 # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑 + + +class EventContext: + def __init__(self, event, econtext=dict()): + self.event = event + self.econtext = econtext + self.action = EventAction.CONTINUE + + def __getitem__(self, key): + return self.econtext[key] + + def __setitem__(self, key, value): + self.econtext[key] = value + + def __delitem__(self, key): + del self.econtext[key] + + def is_pass(self): + return self.action == EventAction.BREAK_PASS diff --git a/plugins/plugin.py b/plugins/plugin.py new file mode 100644 index 0000000..865eecb --- /dev/null +++ b/plugins/plugin.py @@ -0,0 +1,3 @@ +class Plugin: + def __init__(self): + self.handlers = {} diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py new file mode 100644 index 0000000..d4cda12 --- /dev/null +++ b/plugins/plugin_manager.py @@ -0,0 +1,89 @@ +# encoding:utf-8 + +import importlib +import json +import os +from common.singleton import singleton +from .event import * +from .plugin import * +from common.log import logger + + +@singleton +class PluginManager: + def __init__(self): + self.plugins = {} + self.listening_plugins = {} + self.instances = {} + + def register(self, name: str, desc: str, version: str, author: str): + def wrapper(plugincls): + self.plugins[name] = plugincls + plugincls.name = name + plugincls.desc = desc + plugincls.version = version + plugincls.author = author + plugincls.enabled = True + logger.info("Plugin %s registered" % name) + return plugincls + return wrapper + + def save_config(self, pconf): + with open("plugins/plugins.json", "w", encoding="utf-8") as f: + json.dump(pconf, f, indent=4, ensure_ascii=False) + + def load_config(self): + logger.info("Loading plugins config...") + plugins_dir = "plugins" + for plugin_name in os.listdir(plugins_dir): + plugin_path = os.path.join(plugins_dir, plugin_name) + if os.path.isdir(plugin_path): + # 判断插件是否包含main.py文件 + main_module_path = os.path.join(plugin_path, "main.py") + if os.path.isfile(main_module_path): + # 导入插件的main + import_path = "{}.{}.main".format(plugins_dir, plugin_name) + main_module = importlib.import_module(import_path) + + modified = False + if os.path.exists("plugins/plugins.json"): + with open("plugins/plugins.json", "r", encoding="utf-8") as f: + pconf = json.load(f) + else: + modified = True + pconf = {"plugins": []} + for name, plugincls in self.plugins.items(): + if name not in [plugin["name"] for plugin in pconf["plugins"]]: + modified = True + logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name) + pconf["plugins"].append({"name": name, "enabled": True}) + if modified: + self.save_config(pconf) + return pconf + + def load_plugins(self): + pconf = self.load_config() + + for plugin in pconf["plugins"]: + name = plugin["name"] + enabled = plugin["enabled"] + self.plugins[name].enabled = enabled + + for name, plugincls in self.plugins.items(): + if plugincls.enabled: + if name not in self.instances: + instance = plugincls() + self.instances[name] = instance + for event in instance.handlers: + if event not in self.listening_plugins: + self.listening_plugins[event] = [] + self.listening_plugins[event].append(name) + + def emit_event(self, e_context: EventContext, *args, **kwargs): + if e_context.event in self.listening_plugins: + for name in self.listening_plugins[e_context.event]: + if e_context.action == EventAction.CONTINUE: + logger.debug("Plugin %s triggered by event %s" % (name,e_context.event)) + instance = self.instances[name] + instance.handlers[e_context.event](e_context, *args, **kwargs) + return e_context From d9b902f6ee92f07004ca32cf4025ae40c83c3950 Mon Sep 17 00:00:00 2001 From: lanvent Date: Sun, 12 Mar 2023 11:53:47 +0800 Subject: [PATCH 06/29] add a plugin example --- plugins/hello/__init__.py | 0 plugins/hello/main.py | 40 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) create mode 100644 plugins/hello/__init__.py create mode 100644 plugins/hello/main.py diff --git a/plugins/hello/__init__.py b/plugins/hello/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plugins/hello/main.py b/plugins/hello/main.py new file mode 100644 index 0000000..c7d8ee1 --- /dev/null +++ b/plugins/hello/main.py @@ -0,0 +1,40 @@ +# encoding:utf-8 + +import plugins +from plugins import * +from common.log import logger + + +@plugins.register(name="Hello", desc="A simple plugin that says hello", version="0.1", author="lanvent") +class Hello(Plugin): + def __init__(self): + super().__init__() + self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context + # self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context + logger.info("[hello] inited") + + def on_handle_context(self, e_context: EventContext): + + logger.debug("on_handle_context. content: %s" % e_context['context']['content']) + + if e_context['context']['content'] == "Hello": + e_context['reply']['type'] = "TEXT" + msg = e_context['context']['msg'] + if e_context['context']['isgroup']: + e_context['reply']['content'] = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group") + else: + e_context['reply']['content'] = "Hello, " + msg['User'].get('NickName', "My friend") + + e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 + + if e_context['context']['content'] == "Hi": + e_context['reply']['type'] = "TEXT" + e_context['reply']['content'] = "Hi" + e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑,一般会覆写reply + + if e_context['context']['content'] == "End": + # 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World" + if e_context['context']['type'] == "TEXT": + e_context['context']['type'] = "IMAGE_CREATE" + e_context['context']['content'] = "The World" + e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 From 73de429af1935aa95331c63218f8ecb7d7233d0f Mon Sep 17 00:00:00 2001 From: lanvent Date: Sun, 12 Mar 2023 12:57:27 +0800 Subject: [PATCH 07/29] import file with the same name as plugin --- plugins/hello/{main.py => hello.py} | 4 ++-- plugins/plugin_manager.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) rename plugins/hello/{main.py => hello.py} (92%) diff --git a/plugins/hello/main.py b/plugins/hello/hello.py similarity index 92% rename from plugins/hello/main.py rename to plugins/hello/hello.py index c7d8ee1..c380b96 100644 --- a/plugins/hello/main.py +++ b/plugins/hello/hello.py @@ -11,11 +11,11 @@ class Hello(Plugin): super().__init__() self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context # self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context - logger.info("[hello] inited") + logger.info("[Hello] inited") def on_handle_context(self, e_context: EventContext): - logger.debug("on_handle_context. content: %s" % e_context['context']['content']) + logger.debug("[Hello] on_handle_context. content: %s" % e_context['context']['content']) if e_context['context']['content'] == "Hello": e_context['reply']['type'] = "TEXT" diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py index d4cda12..b7d88e7 100644 --- a/plugins/plugin_manager.py +++ b/plugins/plugin_manager.py @@ -38,11 +38,11 @@ class PluginManager: for plugin_name in os.listdir(plugins_dir): plugin_path = os.path.join(plugins_dir, plugin_name) if os.path.isdir(plugin_path): - # 判断插件是否包含main.py文件 - main_module_path = os.path.join(plugin_path, "main.py") + # 判断插件是否包含同名.py文件 + main_module_path = os.path.join(plugin_path, plugin_name+".py") if os.path.isfile(main_module_path): - # 导入插件的main - import_path = "{}.{}.main".format(plugins_dir, plugin_name) + # 导入插件 + import_path = "{}.{}.{}".format(plugins_dir, plugin_name, plugin_name) main_module = importlib.import_module(import_path) modified = False @@ -63,7 +63,7 @@ class PluginManager: def load_plugins(self): pconf = self.load_config() - + logger.debug("plugins.json config={}" % pconf) for plugin in pconf["plugins"]: name = plugin["name"] enabled = plugin["enabled"] From 8847b5b6742509591f4085eea68595b871f63283 Mon Sep 17 00:00:00 2001 From: lanvent Date: Sun, 12 Mar 2023 13:25:23 +0800 Subject: [PATCH 08/29] create bot when first need --- bridge/bridge.py | 26 ++++++++++++++++++-------- plugins/plugin_manager.py | 2 +- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/bridge/bridge.py b/bridge/bridge.py index 81e5d73..e16d721 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -1,3 +1,4 @@ +from common.log import logger from bot import bot_factory from common.singleton import singleton from voice import voice_factory @@ -6,16 +7,24 @@ from voice import voice_factory @singleton class Bridge(object): def __init__(self): - self.bots = { - "chat": bot_factory.create_bot("chatGPT"), - "voice_to_text": voice_factory.create_voice("openai"), - # "text_to_voice": voice_factory.create_voice("baidu") + self.btype={ + "chat": "chatGPT", + "voice_to_text": "openai", + "text_to_voice": "baidu" } - try: - self.bots["text_to_voice"] = voice_factory.create_voice("baidu") - except ModuleNotFoundError as e: - print(e) + self.bots={} + def getbot(self,typename): + if self.bots.get(typename) is None: + logger.info("create bot {} for {}".format(self.btype[typename],typename)) + if typename == "text_to_voice": + self.bots[typename] = voice_factory.create_voice(self.btype[typename]) + elif typename == "voice_to_text": + self.bots[typename] = voice_factory.create_voice(self.btype[typename]) + elif typename == "chat": + self.bots[typename] = bot_factory.create_bot(self.btype[typename]) + return self.bots[typename] + # 以下所有函数需要得到一个reply字典,格式如下: # reply["type"] = "ERROR" / "TEXT" / "VOICE" / ... # reply["content"] = reply的内容 @@ -28,3 +37,4 @@ class Bridge(object): def fetch_text_to_voice(self, text): return self.bots["text_to_voice"].textToVoice(text) + diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py index b7d88e7..bf202f8 100644 --- a/plugins/plugin_manager.py +++ b/plugins/plugin_manager.py @@ -63,7 +63,7 @@ class PluginManager: def load_plugins(self): pconf = self.load_config() - logger.debug("plugins.json config={}" % pconf) + logger.debug("plugins.json config={}".format(pconf)) for plugin in pconf["plugins"]: name = plugin["name"] enabled = plugin["enabled"] From 475ada22e7d64a04cbcad0576feca0c859b8d23c Mon Sep 17 00:00:00 2001 From: lanvent Date: Sun, 12 Mar 2023 22:49:07 +0800 Subject: [PATCH 09/29] catch thread exception --- bridge/bridge.py | 11 +++++++---- channel/wechat/wechat_channel.py | 13 ++++++++----- plugins/hello/hello.py | 1 - plugins/plugin_manager.py | 2 +- 4 files changed, 16 insertions(+), 11 deletions(-) diff --git a/bridge/bridge.py b/bridge/bridge.py index e16d721..304046e 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -14,7 +14,7 @@ class Bridge(object): } self.bots={} - def getbot(self,typename): + def get_bot(self,typename): if self.bots.get(typename) is None: logger.info("create bot {} for {}".format(self.btype[typename],typename)) if typename == "text_to_voice": @@ -25,16 +25,19 @@ class Bridge(object): self.bots[typename] = bot_factory.create_bot(self.btype[typename]) return self.bots[typename] + def get_bot_type(self,typename): + return self.btype[typename] + # 以下所有函数需要得到一个reply字典,格式如下: # reply["type"] = "ERROR" / "TEXT" / "VOICE" / ... # reply["content"] = reply的内容 def fetch_reply_content(self, query, context): - return self.bots["chat"].reply(query, context) + return self.get_bot("chat").reply(query, context) def fetch_voice_to_text(self, voiceFile): - return self.bots["voice_to_text"].voiceToText(voiceFile) + return self.get_bot("voice_to_text").voiceToText(voiceFile) def fetch_text_to_voice(self, text): - return self.bots["text_to_voice"].textToVoice(text) + return self.get_bot("text_to_voice").textToVoice(text) diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index f436e48..73bb59a 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -19,7 +19,10 @@ import io thread_pool = ThreadPoolExecutor(max_workers=8) - +def thread_pool_callback(worker): + worker_exception = worker.exception() + if worker_exception: + logger.exception("Worker return exception: {}".format(worker_exception)) @itchat.msg_register(TEXT) def handler_single_msg(msg): @@ -69,7 +72,7 @@ class WechatChannel(Channel): context = {'isgroup': False, 'msg': msg, 'receiver': other_user_id} context['type'] = 'VOICE' context['session_id'] = other_user_id - thread_pool.submit(self.handle, context) + thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) def handle_text(self, msg): logger.debug("[WX]receive text msg: " + json.dumps(msg, ensure_ascii=False)) @@ -96,7 +99,7 @@ class WechatChannel(Channel): context['type'] = 'TEXT' context['content'] = content - thread_pool.submit(self.handle, context) + thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) def handle_group(self, msg): logger.debug("[WX]receive group msg: " + json.dumps(msg, ensure_ascii=False)) @@ -137,7 +140,7 @@ class WechatChannel(Channel): else: context['session_id'] = msg['ActualUserName'] - thread_pool.submit(self.handle, context) + thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 def send(self, reply, receiver): @@ -208,7 +211,7 @@ class WechatChannel(Channel): reply_text = conf().get("single_chat_reply_prefix", "")+reply_text reply['content'] = reply_text elif reply['type'] == 'ERROR' or reply['type'] == 'INFO': - reply['content'] = reply['type']+": " + reply['content'] + reply['content'] = reply['type']+":\n" + reply['content'] elif reply['type'] == 'IMAGE_URL' or reply['type'] == 'VOICE' or reply['type'] == 'IMAGE': pass else: diff --git a/plugins/hello/hello.py b/plugins/hello/hello.py index c380b96..144906b 100644 --- a/plugins/hello/hello.py +++ b/plugins/hello/hello.py @@ -10,7 +10,6 @@ class Hello(Plugin): def __init__(self): super().__init__() self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context - # self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context logger.info("[Hello] inited") def on_handle_context(self, e_context: EventContext): diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py index bf202f8..dc8e892 100644 --- a/plugins/plugin_manager.py +++ b/plugins/plugin_manager.py @@ -24,7 +24,7 @@ class PluginManager: plugincls.version = version plugincls.author = author plugincls.enabled = True - logger.info("Plugin %s registered" % name) + logger.info("Plugin %s_v%s registered" % (name, version)) return plugincls return wrapper From cee57e4ffc05c7ea6e848a90cb1a6fbbfafd1730 Mon Sep 17 00:00:00 2001 From: lanvent Date: Sun, 12 Mar 2023 23:05:28 +0800 Subject: [PATCH 10/29] plugin: add godcmd plugin --- plugins/godcmd/__init__.py | 0 plugins/godcmd/godcmd.py | 197 +++++++++++++++++++++++++++++++++++++ 2 files changed, 197 insertions(+) create mode 100644 plugins/godcmd/__init__.py create mode 100644 plugins/godcmd/godcmd.py diff --git a/plugins/godcmd/__init__.py b/plugins/godcmd/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py new file mode 100644 index 0000000..300353d --- /dev/null +++ b/plugins/godcmd/godcmd.py @@ -0,0 +1,197 @@ +# encoding:utf-8 + +import json +import os +import traceback +from typing import Tuple +from bridge.bridge import Bridge +from config import load_config +import plugins +from plugins import * +from common.log import logger + +# 定义指令集 +COMMANDS = { + "help": { + "alias": ["help", "帮助"], + "desc": "打印指令集合", + }, + "auth": { + "alias": ["auth", "认证"], + "args": ["口令"], + "desc": "管理员认证", + }, + # "id": { + # "alias": ["id", "用户"], + # "desc": "获取用户id", #目前无实际意义 + # }, + "reset": { + "alias": ["reset", "重置会话"], + "desc": "重置会话", + }, +} + +ADMIN_COMMANDS = { + "resume": { + "alias": ["resume", "恢复服务"], + "desc": "恢复服务", + }, + "stop": { + "alias": ["stop", "暂停服务"], + "desc": "暂停服务", + }, + "reconf": { + "alias": ["reconf", "重载配置"], + "desc": "重载配置(不包含插件配置)", + }, + "resetall": { + "alias": ["resetall", "重置所有会话"], + "desc": "重置所有会话", + }, + "debug": { + "alias": ["debug", "调试模式", "DEBUG"], + "desc": "开启机器调试日志", + }, +} +# 定义帮助函数 +def get_help_text(isadmin, isgroup): + help_text = "可用指令:\n" + for cmd, info in COMMANDS.items(): + if cmd=="auth" and (isadmin or isgroup): # 群聊不可认证 + continue + + alias=["#"+a for a in info['alias']] + help_text += f"{','.join(alias)} " + if 'args' in info: + args=["{"+a+"}" for a in info['args']] + help_text += f"{' '.join(args)} " + help_text += f": {info['desc']}\n" + if ADMIN_COMMANDS and isadmin: + help_text += "\n管理员指令:\n" + for cmd, info in ADMIN_COMMANDS.items(): + alias=["#"+a for a in info['alias']] + help_text += f"{','.join(alias)} " + help_text += f": {info['desc']}\n" + return help_text + +@plugins.register(name="Godcmd", desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证", version="1.0", author="lanvent") +class Godcmd(Plugin): + + def __init__(self): + super().__init__() + + curdir=os.path.dirname(__file__) + config_path=os.path.join(curdir,"config.json") + gconf=None + if not os.path.exists(config_path): + gconf={"password":"","admin_users":[]} + with open(config_path,"w") as f: + json.dump(gconf,f,indent=4) + else: + with open(config_path,"r") as f: + gconf=json.load(f) + + self.password = gconf["password"] + self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证 TODO: 用户名每次都会变,目前不可用 + self.isrunning = True # 机器人是否运行中 + + self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context + logger.info("[Godcmd] inited") + + + def on_handle_context(self, e_context: EventContext): + content = e_context['context']['content'] + context_type = e_context['context']['type'] + logger.debug("[Godcmd] on_handle_context. content: %s" % content) + + if content.startswith("#") and context_type == "TEXT": + # msg = e_context['context']['msg'] + user = e_context['context']['receiver'] + session_id = e_context['context']['session_id'] + isgroup = e_context['context']['isgroup'] + bottype = Bridge().get_bot_type("chat") + bot = Bridge().get_bot("chat") + # 将命令和参数分割 + command_parts = content[1:].split(" ") + cmd = command_parts[0] + args = command_parts[1:] + isadmin=False + if user in self.admin_users: + isadmin=True + ok=False + result="string" + if any(cmd in info['alias'] for info in COMMANDS.values()): + cmd = next(c for c, info in COMMANDS.items() if cmd in info['alias']) + if cmd == "auth": + ok, result = self.authenticate(user, args, isadmin, isgroup) + elif cmd == "help": + ok, result = True, get_help_text(isadmin, isgroup) + elif cmd == "id": + ok, result = True, f"用户id=\n{user}" + elif cmd == "reset": + if bottype == "chatGPT": + bot.sessions.clear_session(session_id) + ok, result = True, "会话已重置" + else: + ok, result = False, "当前机器人不支持重置会话" + logger.debug("[Godcmd] command: %s by %s" % (cmd, user)) + elif any(cmd in info['alias'] for info in ADMIN_COMMANDS.values()): + if isadmin: + if isgroup: + ok, result = False, "群聊不可执行管理员指令" + else: + cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info['alias']) + if cmd == "stop": + self.isrunning = False + ok, result = True, "服务已暂停" + elif cmd == "resume": + self.isrunning = True + ok, result = True, "服务已恢复" + elif cmd == "reconf": + load_config() + ok, result = True, "配置已重载" + elif cmd == "resetall": + if bottype == "chatGPT": + bot.sessions.clear_all_session() + ok, result = True, "重置所有会话成功" + else: + ok, result = False, "当前机器人不支持重置会话" + elif cmd == "debug": + logger.setLevel('DEBUG') + ok, result = True, "DEBUG模式已开启" + logger.debug("[Godcmd] admin command: %s by %s" % (cmd, user)) + else: + ok, result = False, "需要管理员权限才能执行该指令" + else: + ok, result = False, f"未知指令:{cmd}\n查看指令列表请输入#help \n" + + reply = {} + if ok: + reply["type"] = "INFO" + else: + reply["type"] = "ERROR" + reply["content"] = result + e_context['reply'] = reply + + e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 + elif not self.isrunning: + e_context.action = EventAction.BREAK_PASS + else: + e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 + + def authenticate(self, userid, args, isadmin, isgroup) -> Tuple[bool,str] : + if isgroup: + return False,"请勿在群聊中认证" + + if isadmin: + return False,"管理员账号无需认证" + + if len(args) != 1: + return False,"请提供口令" + password = args[0] + if password == self.password: + self.admin_users.append(userid) + return True,"认证成功" + else: + return False,"认证失败" + From 8d2e81815c104f2082dad4b695c6ccdbac2d6240 Mon Sep 17 00:00:00 2001 From: lanvent Date: Mon, 13 Mar 2023 00:12:34 +0800 Subject: [PATCH 11/29] compatible for voice --- channel/wechat/wechat_channel.py | 7 +++++-- plugins/godcmd/godcmd.py | 12 ++++++----- plugins/hello/hello.py | 35 ++++++++++++++++++-------------- voice/baidu/baidu_voice.py | 5 +++-- voice/google/google_voice.py | 27 +++++++++++++++--------- voice/openai/openai_voice.py | 18 ++++++++++------ 6 files changed, 64 insertions(+), 40 deletions(-) diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index 73bb59a..0ba923f 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -71,6 +71,7 @@ class WechatChannel(Channel): if from_user_id == other_user_id: context = {'isgroup': False, 'msg': msg, 'receiver': other_user_id} context['type'] = 'VOICE' + context['content'] = msg['FileName'] context['session_id'] = other_user_id thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) @@ -183,11 +184,13 @@ class WechatChannel(Channel): reply = super().build_reply_content(context['content'], context) elif context['type'] == 'VOICE': msg = context['msg'] - file_name = TmpDir().path() + msg['FileName'] + file_name = TmpDir().path() + context['content'] msg.download(file_name) reply = super().build_voice_to_text(file_name) if reply['type'] != 'ERROR' and reply['type'] != 'INFO': - reply = super().build_reply_content(reply['content'], context) + context['content'] = reply['content'] # 语音转文字后,将文字内容作为新的context + context['type'] = reply['type'] + reply = super().build_reply_content(context['content'], context) if reply['type'] == 'TEXT': if conf().get('voice_reply_voice'): reply = super().build_text_to_voice(reply['content']) diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index 300353d..3dd8760 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -100,11 +100,15 @@ class Godcmd(Plugin): def on_handle_context(self, e_context: EventContext): - content = e_context['context']['content'] context_type = e_context['context']['type'] - logger.debug("[Godcmd] on_handle_context. content: %s" % content) + if context_type != "TEXT": + if not self.isrunning: + e_context.action = EventAction.BREAK_PASS + return - if content.startswith("#") and context_type == "TEXT": + content = e_context['context']['content'] + logger.debug("[Godcmd] on_handle_context. content: %s" % content) + if content.startswith("#"): # msg = e_context['context']['msg'] user = e_context['context']['receiver'] session_id = e_context['context']['session_id'] @@ -176,8 +180,6 @@ class Godcmd(Plugin): e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 elif not self.isrunning: e_context.action = EventAction.BREAK_PASS - else: - e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 def authenticate(self, userid, args, isadmin, isgroup) -> Tuple[bool,str] : if isgroup: diff --git a/plugins/hello/hello.py b/plugins/hello/hello.py index 144906b..ca1d257 100644 --- a/plugins/hello/hello.py +++ b/plugins/hello/hello.py @@ -14,26 +14,31 @@ class Hello(Plugin): def on_handle_context(self, e_context: EventContext): - logger.debug("[Hello] on_handle_context. content: %s" % e_context['context']['content']) - - if e_context['context']['content'] == "Hello": - e_context['reply']['type'] = "TEXT" + if e_context['context']['type'] != "TEXT": + return + + content = e_context['context']['content'] + logger.debug("[Hello] on_handle_context. content: %s" % content) + if content == "Hello": + reply = {} + reply['type'] = "TEXT" msg = e_context['context']['msg'] if e_context['context']['isgroup']: - e_context['reply']['content'] = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group") + reply['content'] = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group") else: - e_context['reply']['content'] = "Hello, " + msg['User'].get('NickName', "My friend") - + reply['content'] = "Hello, " + msg['User'].get('NickName', "My friend") + e_context['reply'] = reply e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 - if e_context['context']['content'] == "Hi": - e_context['reply']['type'] = "TEXT" - e_context['reply']['content'] = "Hi" + if content == "Hi": + reply={} + reply['type'] = "TEXT" + reply['content'] = "Hi" + e_context['reply'] = reply e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑,一般会覆写reply - if e_context['context']['content'] == "End": + if content == "End": # 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World" - if e_context['context']['type'] == "TEXT": - e_context['context']['type'] = "IMAGE_CREATE" - e_context['context']['content'] = "The World" - e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 + e_context['context']['type'] = "IMAGE_CREATE" + content = "The World" + e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 diff --git a/voice/baidu/baidu_voice.py b/voice/baidu/baidu_voice.py index d99db37..adab169 100644 --- a/voice/baidu/baidu_voice.py +++ b/voice/baidu/baidu_voice.py @@ -30,7 +30,8 @@ class BaiduVoice(Voice): with open(fileName, 'wb') as f: f.write(result) logger.info('[Baidu] textToVoice text={} voice file name={}'.format(text, fileName)) - return fileName + reply = {"type": "VOICE", "content": fileName} else: logger.error('[Baidu] textToVoice error={}'.format(result)) - return None + reply = {"type": "ERROR", "content": "抱歉,语音合成失败"} + return reply diff --git a/voice/google/google_voice.py b/voice/google/google_voice.py index 8e339f2..6c00892 100644 --- a/voice/google/google_voice.py +++ b/voice/google/google_voice.py @@ -32,20 +32,27 @@ class GoogleVoice(Voice): ' -acodec pcm_s16le -ac 1 -ar 16000 ' + new_file, shell=True) with speech_recognition.AudioFile(new_file) as source: audio = self.recognizer.record(source) + reply = {} try: text = self.recognizer.recognize_google(audio, language='zh-CN') logger.info( '[Google] voiceToText text={} voice file name={}'.format(text, voice_file)) - return text + reply = {"type": "TEXT", "content": text} except speech_recognition.UnknownValueError: - return "抱歉,我听不懂。" + reply = {"type": "ERROR", "content": "抱歉,我听不懂"} except speech_recognition.RequestError as e: - return "抱歉,无法连接到 Google 语音识别服务;{0}".format(e) - + reply = {"type": "ERROR", "content": "抱歉,无法连接到 Google 语音识别服务;{0}".format(e)} + finally: + return reply def textToVoice(self, text): - textFile = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3' - self.engine.save_to_file(text, textFile) - self.engine.runAndWait() - logger.info( - '[Google] textToVoice text={} voice file name={}'.format(text, textFile)) - return textFile + try: + textFile = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3' + self.engine.save_to_file(text, textFile) + self.engine.runAndWait() + logger.info( + '[Google] textToVoice text={} voice file name={}'.format(text, textFile)) + reply = {"type": "VOICE", "content": textFile} + except Exception as e: + reply = {"type": "ERROR", "content": str(e)} + finally: + return reply diff --git a/voice/openai/openai_voice.py b/voice/openai/openai_voice.py index 475aac6..3b77c52 100644 --- a/voice/openai/openai_voice.py +++ b/voice/openai/openai_voice.py @@ -16,12 +16,18 @@ class OpenaiVoice(Voice): def voiceToText(self, voice_file): logger.debug( '[Openai] voice file name={}'.format(voice_file)) - file = open(voice_file, "rb") - reply = openai.Audio.transcribe("whisper-1", file) - text = reply["text"] - logger.info( - '[Openai] voiceToText text={} voice file name={}'.format(text, voice_file)) - return text + reply={} + try: + file = open(voice_file, "rb") + result = openai.Audio.transcribe("whisper-1", file) + text = result["text"] + reply = {"type": "TEXT", "content": text} + logger.info( + '[Openai] voiceToText text={} voice file name={}'.format(text, voice_file)) + except Exception as e: + reply = {"type": "ERROR", "content": str(e)} + finally: + return reply def textToVoice(self, text): pass From cb7bf446e3b9e55b3428ab8859de2b3d6c6238b2 Mon Sep 17 00:00:00 2001 From: lanvent Date: Mon, 13 Mar 2023 01:50:37 +0800 Subject: [PATCH 12/29] plugin: godcmd support manage plugins --- plugins/godcmd/godcmd.py | 56 ++++++++++++++++++++++++++ plugins/plugin_manager.py | 82 +++++++++++++++++++++++++++++---------- 2 files changed, 118 insertions(+), 20 deletions(-) diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index 3dd8760..2f57d32 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -48,6 +48,24 @@ ADMIN_COMMANDS = { "alias": ["resetall", "重置所有会话"], "desc": "重置所有会话", }, + "scanp": { + "alias": ["scanp", "扫描插件"], + "desc": "扫描插件目录是否有新插件", + }, + "plist": { + "alias": ["plist", "插件"], + "desc": "打印当前插件列表", + }, + "enablep": { + "alias": ["enablep", "启用插件"], + "args": ["插件名"], + "desc": "启用指定插件", + }, + "disablep": { + "alias": ["disablep", "禁用插件"], + "args": ["插件名"], + "desc": "禁用指定插件", + }, "debug": { "alias": ["debug", "调试模式", "DEBUG"], "desc": "开启机器调试日志", @@ -163,6 +181,44 @@ class Godcmd(Plugin): elif cmd == "debug": logger.setLevel('DEBUG') ok, result = True, "DEBUG模式已开启" + elif cmd == "plist": + plugins = PluginManager().list_plugins() + ok = True + result = "插件列表:\n" + for name,plugincls in plugins.items(): + result += f"{name}_v{plugincls.version} - " + if plugincls.enabled: + result += "已启用\n" + else: + result += "未启用\n" + elif cmd == "scanp": + new_plugins = PluginManager().scan_plugins() + ok, result = True, "插件扫描完成" + if len(new_plugins) >0 : + PluginManager().activate_plugins() + result += "\n发现新插件:\n" + result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins]) + else : + result +=", 未发现新插件" + elif cmd == "enablep": + if len(args) != 1: + ok, result = False, "请提供插件名" + else: + ok = PluginManager().enable_plugin(args[0]) + if ok: + result = "插件已启用" + else: + result = "插件不存在" + elif cmd == "disablep": + if len(args) != 1: + ok, result = False, "请提供插件名" + else: + ok = PluginManager().disable_plugin(args[0]) + if ok: + result = "插件已禁用" + else: + result = "插件不存在" + logger.debug("[Godcmd] admin command: %s by %s" % (cmd, user)) else: ok, result = False, "需要管理员权限才能执行该指令" diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py index dc8e892..630041a 100644 --- a/plugins/plugin_manager.py +++ b/plugins/plugin_manager.py @@ -15,6 +15,7 @@ class PluginManager: self.plugins = {} self.listening_plugins = {} self.instances = {} + self.pconf = {} def register(self, name: str, desc: str, version: str, author: str): def wrapper(plugincls): @@ -28,12 +29,27 @@ class PluginManager: return plugincls return wrapper - def save_config(self, pconf): + def save_config(self): with open("plugins/plugins.json", "w", encoding="utf-8") as f: - json.dump(pconf, f, indent=4, ensure_ascii=False) + json.dump(self.pconf, f, indent=4, ensure_ascii=False) def load_config(self): logger.info("Loading plugins config...") + + modified = False + if os.path.exists("plugins/plugins.json"): + with open("plugins/plugins.json", "r", encoding="utf-8") as f: + pconf = json.load(f) + else: + modified = True + pconf = {"plugins": []} + self.pconf = pconf + if modified: + self.save_config() + return pconf + + def scan_plugins(self): + logger.info("Scaning plugins ...") plugins_dir = "plugins" for plugin_name in os.listdir(plugins_dir): plugin_path = os.path.join(plugins_dir, plugin_name) @@ -44,31 +60,20 @@ class PluginManager: # 导入插件 import_path = "{}.{}.{}".format(plugins_dir, plugin_name, plugin_name) main_module = importlib.import_module(import_path) - + pconf = self.pconf + new_plugins = [] modified = False - if os.path.exists("plugins/plugins.json"): - with open("plugins/plugins.json", "r", encoding="utf-8") as f: - pconf = json.load(f) - else: - modified = True - pconf = {"plugins": []} for name, plugincls in self.plugins.items(): if name not in [plugin["name"] for plugin in pconf["plugins"]]: + new_plugins.append(plugincls) modified = True logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name) pconf["plugins"].append({"name": name, "enabled": True}) if modified: - self.save_config(pconf) - return pconf - - def load_plugins(self): - pconf = self.load_config() - logger.debug("plugins.json config={}".format(pconf)) - for plugin in pconf["plugins"]: - name = plugin["name"] - enabled = plugin["enabled"] - self.plugins[name].enabled = enabled + self.save_config() + return new_plugins + def activate_plugins(self): for name, plugincls in self.plugins.items(): if plugincls.enabled: if name not in self.instances: @@ -79,11 +84,48 @@ class PluginManager: self.listening_plugins[event] = [] self.listening_plugins[event].append(name) + def load_plugins(self): + self.load_config() + self.scan_plugins() + pconf = self.pconf + logger.debug("plugins.json config={}".format(pconf)) + for plugin in pconf["plugins"]: + name = plugin["name"] + enabled = plugin["enabled"] + self.plugins[name].enabled = enabled + self.activate_plugins() + def emit_event(self, e_context: EventContext, *args, **kwargs): if e_context.event in self.listening_plugins: for name in self.listening_plugins[e_context.event]: - if e_context.action == EventAction.CONTINUE: + if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE: logger.debug("Plugin %s triggered by event %s" % (name,e_context.event)) instance = self.instances[name] instance.handlers[e_context.event](e_context, *args, **kwargs) return e_context + + def enable_plugin(self,name): + if name not in self.plugins: + return False + if not self.plugins[name].enabled : + self.plugins[name].enabled = True + idx = next(i for i in range(len(self.pconf['plugins'])) if self.pconf["plugins"][i]['name'] == name) + self.pconf["plugins"][idx]["enabled"] = True + self.save_config() + self.activate_plugins() + return True + return True + + def disable_plugin(self,name): + if name not in self.plugins: + return False + if self.plugins[name].enabled : + self.plugins[name].enabled = False + idx = next(i for i in range(len(self.pconf['plugins'])) if self.pconf["plugins"][i]['name'] == name) + self.pconf["plugins"][idx]["enabled"] = False + self.save_config() + return True + return True + + def list_plugins(self): + return self.plugins \ No newline at end of file From 1dc3f85a66d0063a68494f257c365566d5e9c81c Mon Sep 17 00:00:00 2001 From: lanvent Date: Mon, 13 Mar 2023 15:32:28 +0800 Subject: [PATCH 13/29] plugin: support priority to decide trigger order --- common/sorted_dict.py | 65 +++++++++++++++++++++++++++++++++++++++ plugins/godcmd/godcmd.py | 20 ++++++++++-- plugins/hello/hello.py | 2 +- plugins/plugin_manager.py | 52 ++++++++++++++++++++++--------- 4 files changed, 120 insertions(+), 19 deletions(-) create mode 100644 common/sorted_dict.py diff --git a/common/sorted_dict.py b/common/sorted_dict.py new file mode 100644 index 0000000..a918a0c --- /dev/null +++ b/common/sorted_dict.py @@ -0,0 +1,65 @@ +import heapq + + +class SortedDict(dict): + def __init__(self, sort_func=lambda k, v: k, init_dict=None, reverse=False): + if init_dict is None: + init_dict = [] + if isinstance(init_dict, dict): + init_dict = init_dict.items() + self.sort_func = sort_func + self.sorted_keys = None + self.reverse = reverse + self.heap = [] + for k, v in init_dict: + self[k] = v + + def __setitem__(self, key, value): + if key in self: + super().__setitem__(key, value) + for i, (priority, k) in enumerate(self.heap): + if k == key: + self.heap[i] = (self.sort_func(key, value), key) + heapq.heapify(self.heap) + break + self.sorted_keys = None + else: + super().__setitem__(key, value) + heapq.heappush(self.heap, (self.sort_func(key, value), key)) + self.sorted_keys = None + + def __delitem__(self, key): + super().__delitem__(key) + for i, (priority, k) in enumerate(self.heap): + if k == key: + del self.heap[i] + heapq.heapify(self.heap) + break + self.sorted_keys = None + + def keys(self): + if self.sorted_keys is None: + self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)] + return self.sorted_keys + + def items(self): + if self.sorted_keys is None: + self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)] + sorted_items = [(k, self[k]) for k in self.sorted_keys] + return sorted_items + + def _update_heap(self, key): + for i, (priority, k) in enumerate(self.heap): + if k == key: + new_priority = self.sort_func(key, self[key]) + if new_priority != priority: + self.heap[i] = (new_priority, key) + heapq.heapify(self.heap) + self.sorted_keys = None + break + + def __iter__(self): + return iter(self.keys()) + + def __repr__(self): + return f'{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})' diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index 2f57d32..3bd24dd 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -56,6 +56,11 @@ ADMIN_COMMANDS = { "alias": ["plist", "插件"], "desc": "打印当前插件列表", }, + "setpri": { + "alias": ["setpri", "设置插件优先级"], + "args": ["插件名", "优先级"], + "desc": "设置指定插件的优先级,越大越优先", + }, "enablep": { "alias": ["enablep", "启用插件"], "args": ["插件名"], @@ -92,7 +97,7 @@ def get_help_text(isadmin, isgroup): help_text += f": {info['desc']}\n" return help_text -@plugins.register(name="Godcmd", desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证", version="1.0", author="lanvent") +@plugins.register(name="Godcmd", desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证", version="1.0", author="lanvent", desire_priority= 999) class Godcmd(Plugin): def __init__(self): @@ -186,7 +191,7 @@ class Godcmd(Plugin): ok = True result = "插件列表:\n" for name,plugincls in plugins.items(): - result += f"{name}_v{plugincls.version} - " + result += f"{name}_v{plugincls.version} {plugincls.priority} - " if plugincls.enabled: result += "已启用\n" else: @@ -194,12 +199,21 @@ class Godcmd(Plugin): elif cmd == "scanp": new_plugins = PluginManager().scan_plugins() ok, result = True, "插件扫描完成" + PluginManager().activate_plugins() if len(new_plugins) >0 : - PluginManager().activate_plugins() result += "\n发现新插件:\n" result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins]) else : result +=", 未发现新插件" + elif cmd == "setpri": + if len(args) != 2: + ok, result = False, "请提供插件名和优先级" + else: + ok = PluginManager().set_plugin_priority(args[0], int(args[1])) + if ok: + result = "插件" + args[0] + "优先级已设置为" + args[1] + else: + result = "插件不存在" elif cmd == "enablep": if len(args) != 1: ok, result = False, "请提供插件名" diff --git a/plugins/hello/hello.py b/plugins/hello/hello.py index ca1d257..1eb409e 100644 --- a/plugins/hello/hello.py +++ b/plugins/hello/hello.py @@ -5,7 +5,7 @@ from plugins import * from common.log import logger -@plugins.register(name="Hello", desc="A simple plugin that says hello", version="0.1", author="lanvent") +@plugins.register(name="Hello", desc="A simple plugin that says hello", version="0.1", author="lanvent", desire_priority= -1) class Hello(Plugin): def __init__(self): super().__init__() diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py index 630041a..107c63b 100644 --- a/plugins/plugin_manager.py +++ b/plugins/plugin_manager.py @@ -4,6 +4,7 @@ import importlib import json import os from common.singleton import singleton +from common.sorted_dict import SortedDict from .event import * from .plugin import * from common.log import logger @@ -12,19 +13,20 @@ from common.log import logger @singleton class PluginManager: def __init__(self): - self.plugins = {} + self.plugins = SortedDict(lambda k,v: v.priority,reverse=True) self.listening_plugins = {} self.instances = {} self.pconf = {} - def register(self, name: str, desc: str, version: str, author: str): + def register(self, name: str, desc: str, version: str, author: str, desire_priority: int = 0): def wrapper(plugincls): - self.plugins[name] = plugincls plugincls.name = name plugincls.desc = desc plugincls.version = version plugincls.author = author + plugincls.priority = desire_priority plugincls.enabled = True + self.plugins[name] = plugincls logger.info("Plugin %s_v%s registered" % (name, version)) return plugincls return wrapper @@ -40,9 +42,10 @@ class PluginManager: if os.path.exists("plugins/plugins.json"): with open("plugins/plugins.json", "r", encoding="utf-8") as f: pconf = json.load(f) + pconf['plugins'] = SortedDict(lambda k,v: v["priority"],pconf['plugins'],reverse=True) else: modified = True - pconf = {"plugins": []} + pconf = {"plugins": SortedDict(lambda k,v: v["priority"],reverse=True)} self.pconf = pconf if modified: self.save_config() @@ -64,16 +67,24 @@ class PluginManager: new_plugins = [] modified = False for name, plugincls in self.plugins.items(): - if name not in [plugin["name"] for plugin in pconf["plugins"]]: + if name not in pconf["plugins"]: new_plugins.append(plugincls) modified = True logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name) - pconf["plugins"].append({"name": name, "enabled": True}) + pconf["plugins"][name] = {"enabled": plugincls.enabled, "priority": plugincls.priority} + else: + self.plugins[name].enabled = pconf["plugins"][name]["enabled"] + self.plugins[name].priority = pconf["plugins"][name]["priority"] + self.plugins._update_heap(name) # 更新下plugins中的顺序 if modified: self.save_config() return new_plugins - def activate_plugins(self): + def refresh_order(self): + for event in self.listening_plugins.keys(): + self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True) + + def activate_plugins(self): # 生成新开启的插件实例 for name, plugincls in self.plugins.items(): if plugincls.enabled: if name not in self.instances: @@ -83,16 +94,16 @@ class PluginManager: if event not in self.listening_plugins: self.listening_plugins[event] = [] self.listening_plugins[event].append(name) + self.refresh_order() def load_plugins(self): self.load_config() self.scan_plugins() pconf = self.pconf logger.debug("plugins.json config={}".format(pconf)) - for plugin in pconf["plugins"]: - name = plugin["name"] - enabled = plugin["enabled"] - self.plugins[name].enabled = enabled + for name,plugin in pconf["plugins"].items(): + if name not in self.plugins: + logger.error("Plugin %s not found, but found in plugins.json" % name) self.activate_plugins() def emit_event(self, e_context: EventContext, *args, **kwargs): @@ -104,13 +115,25 @@ class PluginManager: instance.handlers[e_context.event](e_context, *args, **kwargs) return e_context + def set_plugin_priority(self,name,priority): + if name not in self.plugins: + return False + if self.plugins[name].priority == priority: + return True + self.plugins[name].priority = priority + self.plugins._update_heap(name) + self.pconf["plugins"][name]["priority"] = priority + self.pconf["plugins"]._update_heap(name) + self.save_config() + self.refresh_order() + return True + def enable_plugin(self,name): if name not in self.plugins: return False if not self.plugins[name].enabled : self.plugins[name].enabled = True - idx = next(i for i in range(len(self.pconf['plugins'])) if self.pconf["plugins"][i]['name'] == name) - self.pconf["plugins"][idx]["enabled"] = True + self.pconf["plugins"][name]["enabled"] = True self.save_config() self.activate_plugins() return True @@ -121,8 +144,7 @@ class PluginManager: return False if self.plugins[name].enabled : self.plugins[name].enabled = False - idx = next(i for i in range(len(self.pconf['plugins'])) if self.pconf["plugins"][i]['name'] == name) - self.pconf["plugins"][idx]["enabled"] = False + self.pconf["plugins"][name]["enabled"] = False self.save_config() return True return True From ad6ae0b32a15b7cd75e1060971f95461408c9944 Mon Sep 17 00:00:00 2001 From: lanvent Date: Mon, 13 Mar 2023 19:44:24 +0800 Subject: [PATCH 14/29] refactor: use enum to specify type --- bot/bot.py | 6 +- bot/chatgpt/chat_gpt_bot.py | 27 ++++---- bridge/bridge.py | 11 ++-- bridge/context.py | 42 +++++++++++++ bridge/reply.py | 22 +++++++ channel/channel.py | 8 ++- channel/wechat/wechat_channel.py | 101 +++++++++++++++--------------- channel/wechat/wechaty_channel.py | 25 +++----- plugins/godcmd/godcmd.py | 16 ++--- plugins/hello/hello.py | 22 ++++--- voice/baidu/baidu_voice.py | 5 +- voice/google/google_voice.py | 12 ++-- voice/openai/openai_voice.py | 6 +- 13 files changed, 185 insertions(+), 118 deletions(-) create mode 100644 bridge/context.py create mode 100644 bridge/reply.py diff --git a/bot/bot.py b/bot/bot.py index 850ba3b..fd56e50 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -3,8 +3,12 @@ Auto-replay chat robot abstract class """ +from bridge.context import Context +from bridge.reply import Reply + + class Bot(object): - def reply(self, query, context=None): + def reply(self, query, context : Context =None) -> Reply: """ bot auto-reply content :param req: received message diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index 2c8567d..b2e062d 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -1,6 +1,8 @@ # encoding:utf-8 from bot.bot import Bot +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType from config import conf, load_config from common.log import logger from common.expired_dict import ExpiredDict @@ -19,22 +21,19 @@ class ChatGPTBot(Bot): def reply(self, query, context=None): # acquire reply content - if context['type'] == 'TEXT': + if context.type == ContextType.TEXT: logger.info("[OPEN_AI] query={}".format(query)) session_id = context['session_id'] reply = None if query == '#清除记忆': self.sessions.clear_session(session_id) - reply = {'type': 'INFO', 'content': '记忆已清除'} + reply = Reply(ReplyType.INFO, '记忆已清除') elif query == '#清除所有': self.sessions.clear_all_session() - reply = {'type': 'INFO', 'content': '所有人记忆已清除'} + reply = Reply(ReplyType.INFO, '所有人记忆已清除') elif query == '#更新配置': load_config() - reply = {'type': 'INFO', 'content': '配置已更新'} - elif query == '#DEBUG': - logger.setLevel('DEBUG') - reply = {'type': 'INFO', 'content': 'DEBUG模式已开启'} + reply = Reply(ReplyType.INFO, '配置已更新') if reply: return reply session = self.sessions.build_session_query(query, session_id) @@ -47,25 +46,25 @@ class ChatGPTBot(Bot): reply_content = self.reply_text(session, session_id, 0) logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"])) if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0: - reply = {'type': 'ERROR', 'content': reply_content['content']} + reply = Reply(ReplyType.ERROR, reply_content['content']) elif reply_content["completion_tokens"] > 0: self.sessions.save_session(reply_content["content"], session_id, reply_content["total_tokens"]) - reply={'type':'TEXT', 'content':reply_content["content"]} + reply = Reply(ReplyType.TEXT, reply_content["content"]) else: - reply = {'type': 'ERROR', 'content': reply_content['content']} + reply = Reply(ReplyType.ERROR, reply_content['content']) logger.debug("[OPEN_AI] reply {} used 0 tokens.".format(reply_content)) return reply - elif context['type'] == 'IMAGE_CREATE': + elif context.type == ContextType.IMAGE_CREATE: ok, retstring = self.create_img(query, 0) reply = None if ok: - reply = {'type': 'IMAGE_URL', 'content': retstring} + reply = Reply(ReplyType.IMAGE_URL, retstring) else: - reply = {'type': 'ERROR', 'content': retstring} + reply = Reply(ReplyType.ERROR, retstring) return reply else: - reply= {'type':'ERROR', 'content':'Bot不支持处理{}类型的消息'.format(context['type'])} + reply = Reply(ReplyType.ERROR, 'Bot不支持处理{}类型的消息'.format(context.type)) return reply def reply_text(self, session, session_id, retry_count=0) -> dict: diff --git a/bridge/bridge.py b/bridge/bridge.py index 304046e..2b67a8a 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -1,3 +1,5 @@ +from bridge.context import Context +from bridge.reply import Reply from common.log import logger from bot import bot_factory from common.singleton import singleton @@ -28,16 +30,13 @@ class Bridge(object): def get_bot_type(self,typename): return self.btype[typename] - # 以下所有函数需要得到一个reply字典,格式如下: - # reply["type"] = "ERROR" / "TEXT" / "VOICE" / ... - # reply["content"] = reply的内容 - def fetch_reply_content(self, query, context): + def fetch_reply_content(self, query, context : Context) -> Reply: return self.get_bot("chat").reply(query, context) - def fetch_voice_to_text(self, voiceFile): + def fetch_voice_to_text(self, voiceFile) -> Reply: return self.get_bot("voice_to_text").voiceToText(voiceFile) - def fetch_text_to_voice(self, text): + def fetch_text_to_voice(self, text) -> Reply: return self.get_bot("text_to_voice").textToVoice(text) diff --git a/bridge/context.py b/bridge/context.py new file mode 100644 index 0000000..1fbe4d4 --- /dev/null +++ b/bridge/context.py @@ -0,0 +1,42 @@ +# encoding:utf-8 + +from enum import Enum + +class ContextType (Enum): + TEXT = 1 # 文本消息 + VOICE = 2 # 音频消息 + IMAGE_CREATE = 3 # 创建图片命令 + + def __str__(self): + return self.name +class Context: + def __init__(self, type : ContextType = None , content = None, kwargs = dict()): + self.type = type + self.content = content + self.kwargs = kwargs + def __getitem__(self, key): + if key == 'type': + return self.type + elif key == 'content': + return self.content + else: + return self.kwargs[key] + + def __setitem__(self, key, value): + if key == 'type': + self.type = value + elif key == 'content': + self.content = value + else: + self.kwargs[key] = value + + def __delitem__(self, key): + if key == 'type': + self.type = None + elif key == 'content': + self.content = None + else: + del self.kwargs[key] + + def __str__(self): + return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs) \ No newline at end of file diff --git a/bridge/reply.py b/bridge/reply.py new file mode 100644 index 0000000..c6bcd54 --- /dev/null +++ b/bridge/reply.py @@ -0,0 +1,22 @@ + +# encoding:utf-8 + +from enum import Enum + +class ReplyType(Enum): + TEXT = 1 # 文本 + VOICE = 2 # 音频文件 + IMAGE = 3 # 图片文件 + IMAGE_URL = 4 # 图片URL + + INFO = 9 + ERROR = 10 + def __str__(self): + return self.name + +class Reply: + def __init__(self, type : ReplyType = None , content = None): + self.type = type + self.content = content + def __str__(self): + return "Reply(type={}, content={})".format(self.type, self.content) \ No newline at end of file diff --git a/channel/channel.py b/channel/channel.py index a1395c4..62e1ad3 100644 --- a/channel/channel.py +++ b/channel/channel.py @@ -3,6 +3,8 @@ Message sending channel abstract class """ from bridge.bridge import Bridge +from bridge.context import Context +from bridge.reply import Reply class Channel(object): def startup(self): @@ -27,11 +29,11 @@ class Channel(object): """ raise NotImplementedError - def build_reply_content(self, query, context=None): + def build_reply_content(self, query, context : Context=None) -> Reply: return Bridge().fetch_reply_content(query, context) - def build_voice_to_text(self, voice_file): + def build_voice_to_text(self, voice_file) -> Reply: return Bridge().fetch_voice_to_text(voice_file) - def build_text_to_voice(self, text): + def build_text_to_voice(self, text) -> Reply: return Bridge().fetch_text_to_voice(text) diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index 0ba923f..eff788d 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -7,6 +7,8 @@ wechat channel import itchat import json from itchat.content import * +from bridge.reply import * +from bridge.context import * from channel.channel import Channel from concurrent.futures import ThreadPoolExecutor from common.log import logger @@ -69,10 +71,8 @@ class WechatChannel(Channel): from_user_id = msg['FromUserName'] other_user_id = msg['User']['UserName'] if from_user_id == other_user_id: - context = {'isgroup': False, 'msg': msg, 'receiver': other_user_id} - context['type'] = 'VOICE' - context['content'] = msg['FileName'] - context['session_id'] = other_user_id + context = Context(ContextType.VOICE,msg['FileName']) + context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id} thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) def handle_text(self, msg): @@ -89,17 +89,17 @@ class WechatChannel(Channel): content = content.replace(match_prefix, '', 1).strip() else: return - context = {'isgroup': False, 'msg': msg, 'receiver': other_user_id} - context['session_id'] = other_user_id + context = Context() + context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id} img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) if img_match_prefix: content = content.replace(img_match_prefix, '', 1).strip() - context['type'] = 'IMAGE_CREATE' + context.type = ContextType.IMAGE_CREATE else: - context['type'] = 'TEXT' + context.type = ContextType.TEXT - context['content'] = content + context.content = content thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) def handle_group(self, msg): @@ -123,15 +123,16 @@ class WechatChannel(Channel): match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or check_prefix(origin_content, config.get('group_chat_prefix')) \ or check_contain(origin_content, config.get('group_chat_keyword')) if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix: - context = { 'isgroup': True, 'msg': msg, 'receiver': group_id} + context = Context() + context.kwargs = { 'isgroup': True, 'msg': msg, 'receiver': group_id} img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) if img_match_prefix: content = content.replace(img_match_prefix, '', 1).strip() - context['type'] = 'IMAGE_CREATE' + context.type = ContextType.IMAGE_CREATE else: - context['type'] = 'TEXT' - context['content'] = content + context.type = ContextType.TEXT + context.content = content group_chat_in_one_session = conf().get('group_chat_in_one_session', []) if ('ALL_GROUP' in group_chat_in_one_session or @@ -144,18 +145,18 @@ class WechatChannel(Channel): thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 - def send(self, reply, receiver): - if reply['type'] == 'TEXT': - itchat.send(reply['content'], toUserName=receiver) + def send(self, reply : Reply, receiver): + if reply.type == ReplyType.TEXT: + itchat.send(reply.content, toUserName=receiver) logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) - elif reply['type'] == 'ERROR' or reply['type'] == 'INFO': - itchat.send(reply['content'], toUserName=receiver) + elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: + itchat.send(reply.content, toUserName=receiver) logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) - elif reply['type'] == 'VOICE': - itchat.send_file(reply['content'], toUserName=receiver) - logger.info('[WX] sendFile={}, receiver={}'.format(reply['content'], receiver)) - elif reply['type']=='IMAGE_URL': # 从网络下载图片 - img_url = reply['content'] + elif reply.type == ReplyType.VOICE: + itchat.send_file(reply.content, toUserName=receiver) + logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver)) + elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 + img_url = reply.content pic_res = requests.get(img_url, stream=True) image_storage = io.BytesIO() for block in pic_res.iter_content(1024): @@ -163,69 +164,69 @@ class WechatChannel(Channel): image_storage.seek(0) itchat.send_image(image_storage, toUserName=receiver) logger.info('[WX] sendImage url=, receiver={}'.format(img_url,receiver)) - elif reply['type']=='IMAGE': # 从文件读取图片 - image_storage = reply['content'] + elif reply.type == ReplyType.IMAGE: # 从文件读取图片 + image_storage = reply.content image_storage.seek(0) itchat.send_image(image_storage, toUserName=receiver) logger.info('[WX] sendImage, receiver={}'.format(receiver)) # 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类 def handle(self, context): - reply = {} + reply = Reply() logger.debug('[WX] ready to handle context: {}'.format(context)) # reply的构建步骤 e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply})) - reply=e_context['reply'] + reply = e_context['reply'] if not e_context.is_pass(): - logger.debug('[WX] ready to handle context: type={}, content={}'.format(context['type'], context['content'])) - if context['type'] == 'TEXT' or context['type'] == 'IMAGE_CREATE': - reply = super().build_reply_content(context['content'], context) - elif context['type'] == 'VOICE': + logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content)) + if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: + reply = super().build_reply_content(context.content, context) + elif context.type == ContextType.VOICE: msg = context['msg'] - file_name = TmpDir().path() + context['content'] + file_name = TmpDir().path() + context.content msg.download(file_name) reply = super().build_voice_to_text(file_name) - if reply['type'] != 'ERROR' and reply['type'] != 'INFO': - context['content'] = reply['content'] # 语音转文字后,将文字内容作为新的context - context['type'] = reply['type'] - reply = super().build_reply_content(context['content'], context) - if reply['type'] == 'TEXT': + if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO: + context.content = reply.content # 语音转文字后,将文字内容作为新的context + context.type = ContextType.TEXT + reply = super().build_reply_content(context.content, context) + if reply.type == ReplyType.TEXT: if conf().get('voice_reply_voice'): - reply = super().build_text_to_voice(reply['content']) + reply = super().build_text_to_voice(reply.content) else: - logger.error('[WX] unknown context type: {}'.format(context['type'])) + logger.error('[WX] unknown context type: {}'.format(context.type)) return logger.debug('[WX] ready to decorate reply: {}'.format(reply)) # reply的包装步骤 - if reply and reply['type']: + if reply and reply.type: e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply})) reply=e_context['reply'] - if not e_context.is_pass() and reply and reply['type']: - if reply['type'] == 'TEXT': - reply_text = reply['content'] + if not e_context.is_pass() and reply and reply.type: + if reply.type == ReplyType.TEXT: + reply_text = reply.content if context['isgroup']: reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip() reply_text = conf().get("group_chat_reply_prefix", "")+reply_text else: reply_text = conf().get("single_chat_reply_prefix", "")+reply_text - reply['content'] = reply_text - elif reply['type'] == 'ERROR' or reply['type'] == 'INFO': - reply['content'] = reply['type']+":\n" + reply['content'] - elif reply['type'] == 'IMAGE_URL' or reply['type'] == 'VOICE' or reply['type'] == 'IMAGE': + reply.content = reply_text + elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: + reply.content = str(reply.type)+":\n" + reply.content + elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE: pass else: - logger.error('[WX] unknown reply type: {}'.format(reply['type'])) + logger.error('[WX] unknown reply type: {}'.format(reply.type)) return # reply的发送步骤 - if reply and reply['type']: + if reply and reply.type: e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply})) reply=e_context['reply'] - if not e_context.is_pass() and reply and reply['type']: + if not e_context.is_pass() and reply and reply.type: logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver'])) self.send(reply, context['receiver']) diff --git a/channel/wechat/wechaty_channel.py b/channel/wechat/wechaty_channel.py index 7ae4683..3a7fe2c 100644 --- a/channel/wechat/wechaty_channel.py +++ b/channel/wechat/wechaty_channel.py @@ -11,6 +11,7 @@ import time import asyncio import requests from typing import Optional, Union +from bridge.context import Context, ContextType from wechaty_puppet import MessageType, FileBox, ScanStatus # type: ignore from wechaty import Wechaty, Contact from wechaty.user import Message, Room, MiniProgram, UrlLink @@ -127,11 +128,9 @@ class WechatyChannel(Channel): try: if not query: return - context = dict() + context = Context(ContextType.TEXT, query) context['session_id'] = reply_user_id - context['type'] = 'TEXT' - context['content'] = query - reply_text = super().build_reply_content(query, context)['content'] + reply_text = super().build_reply_content(query, context).content if reply_text: await self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id) except Exception as e: @@ -141,10 +140,8 @@ class WechatyChannel(Channel): try: if not query: return - context = dict() - context['type'] = 'IMAGE_CREATE' - context['content'] = query - img_url = super().build_reply_content(query, context)['content'] + context = Context(ContextType.IMAGE_CREATE, query) + img_url = super().build_reply_content(query, context).content if not img_url: return # 图片下载 @@ -165,7 +162,7 @@ class WechatyChannel(Channel): async def _do_send_group(self, query, group_id, group_name, group_user_id, group_user_name): if not query: return - context = dict() + context = Context(ContextType.TEXT, query) group_chat_in_one_session = conf().get('group_chat_in_one_session', []) if ('ALL_GROUP' in group_chat_in_one_session or \ group_name in group_chat_in_one_session or \ @@ -173,9 +170,7 @@ class WechatyChannel(Channel): context['session_id'] = str(group_id) else: context['session_id'] = str(group_id) + '-' + str(group_user_id) - context['type'] = 'TEXT' - context['content'] = query - reply_text = super().build_reply_content(query, context)['content'] + reply_text = super().build_reply_content(query, context).content if reply_text: reply_text = '@' + group_user_name + ' ' + reply_text.strip() await self.send_group(conf().get("group_chat_reply_prefix", "") + reply_text, group_id) @@ -184,10 +179,8 @@ class WechatyChannel(Channel): try: if not query: return - context = dict() - context['type'] = 'IMAGE_CREATE' - context['content'] = query - img_url = super().build_reply_content(query, context)['content'] + context = Context(ContextType.IMAGE_CREATE, query) + img_url = super().build_reply_content(query, context).content if not img_url: return # 图片发送 diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index 3bd24dd..e988452 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -5,6 +5,8 @@ import os import traceback from typing import Tuple from bridge.bridge import Bridge +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType from config import load_config import plugins from plugins import * @@ -123,13 +125,13 @@ class Godcmd(Plugin): def on_handle_context(self, e_context: EventContext): - context_type = e_context['context']['type'] - if context_type != "TEXT": + context_type = e_context['context'].type + if context_type != ContextType.TEXT: if not self.isrunning: e_context.action = EventAction.BREAK_PASS return - content = e_context['context']['content'] + content = e_context['context'].content logger.debug("[Godcmd] on_handle_context. content: %s" % content) if content.startswith("#"): # msg = e_context['context']['msg'] @@ -239,12 +241,12 @@ class Godcmd(Plugin): else: ok, result = False, f"未知指令:{cmd}\n查看指令列表请输入#help \n" - reply = {} + reply = Reply() if ok: - reply["type"] = "INFO" + reply.type = ReplyType.INFO else: - reply["type"] = "ERROR" - reply["content"] = result + reply.type = ReplyType.ERROR + reply.content = result e_context['reply'] = reply e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 diff --git a/plugins/hello/hello.py b/plugins/hello/hello.py index 1eb409e..53d87e6 100644 --- a/plugins/hello/hello.py +++ b/plugins/hello/hello.py @@ -1,5 +1,7 @@ # encoding:utf-8 +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType import plugins from plugins import * from common.log import logger @@ -14,31 +16,31 @@ class Hello(Plugin): def on_handle_context(self, e_context: EventContext): - if e_context['context']['type'] != "TEXT": + if e_context['context'].type != ContextType.TEXT: return - content = e_context['context']['content'] + content = e_context['context'].content logger.debug("[Hello] on_handle_context. content: %s" % content) if content == "Hello": - reply = {} - reply['type'] = "TEXT" + reply = Reply() + reply.type = ReplyType.TEXT msg = e_context['context']['msg'] if e_context['context']['isgroup']: - reply['content'] = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group") + reply.content = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group") else: - reply['content'] = "Hello, " + msg['User'].get('NickName', "My friend") + reply.content = "Hello, " + msg['User'].get('NickName', "My friend") e_context['reply'] = reply e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 if content == "Hi": - reply={} - reply['type'] = "TEXT" - reply['content'] = "Hi" + reply = Reply() + reply.type = ReplyType.TEXT + reply.content = "Hi" e_context['reply'] = reply e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑,一般会覆写reply if content == "End": # 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World" - e_context['context']['type'] = "IMAGE_CREATE" + e_context['context'].type = "IMAGE_CREATE" content = "The World" e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 diff --git a/voice/baidu/baidu_voice.py b/voice/baidu/baidu_voice.py index adab169..531d8ce 100644 --- a/voice/baidu/baidu_voice.py +++ b/voice/baidu/baidu_voice.py @@ -4,6 +4,7 @@ baidu voice service """ import time from aip import AipSpeech +from bridge.reply import Reply, ReplyType from common.log import logger from common.tmp_dir import TmpDir from voice.voice import Voice @@ -30,8 +31,8 @@ class BaiduVoice(Voice): with open(fileName, 'wb') as f: f.write(result) logger.info('[Baidu] textToVoice text={} voice file name={}'.format(text, fileName)) - reply = {"type": "VOICE", "content": fileName} + reply = Reply(ReplyType.VOICE, fileName) else: logger.error('[Baidu] textToVoice error={}'.format(result)) - reply = {"type": "ERROR", "content": "抱歉,语音合成失败"} + reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败") return reply diff --git a/voice/google/google_voice.py b/voice/google/google_voice.py index 6c00892..74431db 100644 --- a/voice/google/google_voice.py +++ b/voice/google/google_voice.py @@ -6,6 +6,7 @@ google voice service import pathlib import subprocess import time +from bridge.reply import Reply, ReplyType import speech_recognition import pyttsx3 from common.log import logger @@ -32,16 +33,15 @@ class GoogleVoice(Voice): ' -acodec pcm_s16le -ac 1 -ar 16000 ' + new_file, shell=True) with speech_recognition.AudioFile(new_file) as source: audio = self.recognizer.record(source) - reply = {} try: text = self.recognizer.recognize_google(audio, language='zh-CN') logger.info( '[Google] voiceToText text={} voice file name={}'.format(text, voice_file)) - reply = {"type": "TEXT", "content": text} + reply = Reply(ReplyType.TEXT, text) except speech_recognition.UnknownValueError: - reply = {"type": "ERROR", "content": "抱歉,我听不懂"} + reply = Reply(ReplyType.ERROR, "抱歉,我听不懂") except speech_recognition.RequestError as e: - reply = {"type": "ERROR", "content": "抱歉,无法连接到 Google 语音识别服务;{0}".format(e)} + reply = Reply(ReplyType.ERROR, "抱歉,无法连接到 Google 语音识别服务;{0}".format(e)) finally: return reply def textToVoice(self, text): @@ -51,8 +51,8 @@ class GoogleVoice(Voice): self.engine.runAndWait() logger.info( '[Google] textToVoice text={} voice file name={}'.format(text, textFile)) - reply = {"type": "VOICE", "content": textFile} + reply = Reply(ReplyType.VOICE, textFile) except Exception as e: - reply = {"type": "ERROR", "content": str(e)} + reply = Reply(ReplyType.ERROR, str(e)) finally: return reply diff --git a/voice/openai/openai_voice.py b/voice/openai/openai_voice.py index 3b77c52..2e85e10 100644 --- a/voice/openai/openai_voice.py +++ b/voice/openai/openai_voice.py @@ -4,6 +4,7 @@ google voice service """ import json import openai +from bridge.reply import Reply, ReplyType from config import conf from common.log import logger from voice.voice import Voice @@ -16,16 +17,15 @@ class OpenaiVoice(Voice): def voiceToText(self, voice_file): logger.debug( '[Openai] voice file name={}'.format(voice_file)) - reply={} try: file = open(voice_file, "rb") result = openai.Audio.transcribe("whisper-1", file) text = result["text"] - reply = {"type": "TEXT", "content": text} + reply = Reply(ReplyType.TEXT, text) logger.info( '[Openai] voiceToText text={} voice file name={}'.format(text, voice_file)) except Exception as e: - reply = {"type": "ERROR", "content": str(e)} + reply = Reply(ReplyType.ERROR, str(e)) finally: return reply From dce9c4dccbfb7ff7fbc63093a470ba0165fcb694 Mon Sep 17 00:00:00 2001 From: lanvent Date: Mon, 13 Mar 2023 19:58:35 +0800 Subject: [PATCH 15/29] compatible with openai bot --- bot/baidu/baidu_unit_bot.py | 4 +++- bot/openai/open_ai_bot.py | 47 ++++++++++++++++++++----------------- plugins/godcmd/godcmd.py | 4 ++-- 3 files changed, 30 insertions(+), 25 deletions(-) diff --git a/bot/baidu/baidu_unit_bot.py b/bot/baidu/baidu_unit_bot.py index a84ac57..2b7dd8d 100644 --- a/bot/baidu/baidu_unit_bot.py +++ b/bot/baidu/baidu_unit_bot.py @@ -2,6 +2,7 @@ import requests from bot.bot import Bot +from bridge.reply import Reply, ReplyType # Baidu Unit对话接口 (可用, 但能力较弱) @@ -14,7 +15,8 @@ class BaiduUnitBot(Bot): headers = {'content-type': 'application/x-www-form-urlencoded'} response = requests.post(url, data=post_data.encode(), headers=headers) if response: - return response.json()['result']['context']['SYS_PRESUMED_HIST'][1] + reply = Reply(ReplyType.TEXT, response.json()['result']['context']['SYS_PRESUMED_HIST'][1]) + return reply def get_token(self): access_key = 'YOUR_ACCESS_KEY' diff --git a/bot/openai/open_ai_bot.py b/bot/openai/open_ai_bot.py index c948a7c..e96d60f 100644 --- a/bot/openai/open_ai_bot.py +++ b/bot/openai/open_ai_bot.py @@ -1,6 +1,8 @@ # encoding:utf-8 from bot.bot import Bot +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType from config import conf from common.log import logger import openai @@ -13,30 +15,31 @@ class OpenAIBot(Bot): def __init__(self): openai.api_key = conf().get('open_ai_api_key') - def reply(self, query, context=None): # acquire reply content - if not context or not context.get('type') or context.get('type') == 'TEXT': - logger.info("[OPEN_AI] query={}".format(query)) - from_user_id = context.get('from_user_id') or context.get('session_id') - if query == '#清除记忆': - Session.clear_session(from_user_id) - return '记忆已清除' - elif query == '#清除所有': - Session.clear_all_session() - return '所有人记忆已清除' - - new_query = Session.build_session_query(query, from_user_id) - logger.debug("[OPEN_AI] session query={}".format(new_query)) - - reply_content = self.reply_text(new_query, from_user_id, 0) - logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content)) - if reply_content and query: - Session.save_session(query, reply_content, from_user_id) - return reply_content - - elif context.get('type', None) == 'IMAGE_CREATE': - return self.create_img(query, 0) + if context and context.type: + if context.type == ContextType.TEXT: + logger.info("[OPEN_AI] query={}".format(query)) + from_user_id = context['session_id'] + reply = None + if query == '#清除记忆': + Session.clear_session(from_user_id) + reply = Reply(ReplyType.INFO, '记忆已清除') + elif query == '#清除所有': + Session.clear_all_session() + reply = Reply(ReplyType.INFO, '所有人记忆已清除') + else: + new_query = Session.build_session_query(query, from_user_id) + logger.debug("[OPEN_AI] session query={}".format(new_query)) + + reply_content = self.reply_text(new_query, from_user_id, 0) + logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content)) + if reply_content and query: + Session.save_session(query, reply_content, from_user_id) + reply = Reply(ReplyType.TEXT, reply_content) + return reply + elif context.type == ContextType.IMAGE_CREATE: + return self.create_img(query, 0) def reply_text(self, query, user_id, retry_count=0): try: diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index e988452..e6c971b 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -162,7 +162,7 @@ class Godcmd(Plugin): bot.sessions.clear_session(session_id) ok, result = True, "会话已重置" else: - ok, result = False, "当前机器人不支持重置会话" + ok, result = False, "当前对话机器人不支持重置会话" logger.debug("[Godcmd] command: %s by %s" % (cmd, user)) elif any(cmd in info['alias'] for info in ADMIN_COMMANDS.values()): if isadmin: @@ -184,7 +184,7 @@ class Godcmd(Plugin): bot.sessions.clear_all_session() ok, result = True, "重置所有会话成功" else: - ok, result = False, "当前机器人不支持重置会话" + ok, result = False, "当前对话机器人不支持重置会话" elif cmd == "debug": logger.setLevel('DEBUG') ok, result = True, "DEBUG模式已开启" From e6d148e72945813d668aba092c8ece4e90a6782e Mon Sep 17 00:00:00 2001 From: lanvent Date: Tue, 14 Mar 2023 00:49:28 +0800 Subject: [PATCH 16/29] plugins: add sdwebui(stable diffusion) plugin --- plugins/sdwebui/__init__.py | 0 plugins/sdwebui/config.json.template | 67 +++++++++++++++++++++ plugins/sdwebui/readme.md | 63 ++++++++++++++++++++ plugins/sdwebui/sdwebui.py | 88 ++++++++++++++++++++++++++++ 4 files changed, 218 insertions(+) create mode 100644 plugins/sdwebui/__init__.py create mode 100644 plugins/sdwebui/config.json.template create mode 100644 plugins/sdwebui/readme.md create mode 100644 plugins/sdwebui/sdwebui.py diff --git a/plugins/sdwebui/__init__.py b/plugins/sdwebui/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plugins/sdwebui/config.json.template b/plugins/sdwebui/config.json.template new file mode 100644 index 0000000..4fca569 --- /dev/null +++ b/plugins/sdwebui/config.json.template @@ -0,0 +1,67 @@ +{ + "start":{ + "host" : "127.0.0.1", + "port" : 7860 + }, + "defaults": { + "params": { + "sampler_name": "DPM++ 2M Karras", + "steps": 20, + "width": 512, + "height": 512, + "cfg_scale": 7, + "prompt":"masterpiece, best quality", + "negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)", + "enable_hr": false, + "hr_scale": 2, + "hr_upscaler": "Latent", + "hr_second_pass_steps": 15, + "denoising_strength": 0.7 + }, + "options": { + "sd_model_checkpoint": "perfectWorld_v2Baked" + } + }, + "rules": [ + { + "keywords": [ + "横版", + "壁纸" + ], + "params": { + "width": 640, + "height": 384 + } + }, + { + "keywords": [ + "竖版" + ], + "params": { + "width": 384, + "height": 640 + } + }, + { + "keywords": [ + "高清" + ], + "params": { + "enable_hr": true, + "hr_scale": 1.6 + } + }, + { + "keywords": [ + "二次元" + ], + "params": { + "negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)", + "prompt": "masterpiece, best quality" + }, + "options": { + "sd_model_checkpoint": "meinamix_meinaV8" + } + } + ] +} \ No newline at end of file diff --git a/plugins/sdwebui/readme.md b/plugins/sdwebui/readme.md new file mode 100644 index 0000000..ef325cd --- /dev/null +++ b/plugins/sdwebui/readme.md @@ -0,0 +1,63 @@ +本插件用于将画图请求转发给stable diffusion webui +使用前先安装stable diffusion webui,并在它的启动参数中添加 "--api" +具体参考(https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API) + +请**安装**本插件的依赖包```webuiapi``` +``` + ```pip install webuiapi``` +``` +请将```config.json.template```复制为```config.json```,并修改其中的参数和规则 + +用户的画图请求格式为: +``` + <画图触发词><关键词1> <关键词2> ... <关键词n>: +``` +本插件会对画图触发词后的关键词进行逐个匹配,如果触发了规则中的关键词,则会在画图请求中重载对应的参数。 +规则会按顺序匹配,每个关键词最多匹配到1次,如果有重复的参数,则以最后一个为准: +第一个"**:**"号之后的内容会作为附加的**prompt**,接在最终的prompt后 + +例如: 画横版 高清 二次元:cat +会触发三个关键词 "横版", "高清", "二次元",prompt为"cat" +若默认参数是: +``` + "width": 512, + "height": 512, + "enable_hr": false, + "prompt": "8k" + "negative_prompt": "nsfw", + "sd_model_checkpoint": "perfectWorld_v2Baked" +``` + +"横版"触发的规则参数为: +``` + "width": 640, + "height": 384, +``` +"高清"触发的规则参数为: +``` + "enable_hr": true, + "hr_scale": 1.6, +``` +"二次元"触发的规则参数为: +``` + "negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)", + "steps": 20, + "prompt": "masterpiece, best quality", + + "sd_model_checkpoint": "meinamix_meinaV8" +``` +最后将第一个":"后的内容cat连接在prompt后,得到最终参数为: +``` + "width": 640, + "height": 384, + "enable_hr": true, + "hr_scale": 1.6, + "negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)", + "steps": 20, + "prompt": "masterpiece, best quality, cat", + + "sd_model_checkpoint": "meinamix_meinaV8" +``` +PS: 参数分为两部分, +一部分是params,为画画的参数;参数名**必须**与webuiapi包中[txt2img api](https://github.com/mix1009/sdwebuiapi/blob/fb2054e149c0a4e25125c0cd7e7dca06bda839d4/webuiapi/webuiapi.py#L163)的参数名一致 +另一部分是options,指sdwebui的设置,使用的模型和vae需要写在里面。它和http://127.0.0.1:7860/sdapi/v1/options所返回的键一致。 \ No newline at end of file diff --git a/plugins/sdwebui/sdwebui.py b/plugins/sdwebui/sdwebui.py new file mode 100644 index 0000000..f7da523 --- /dev/null +++ b/plugins/sdwebui/sdwebui.py @@ -0,0 +1,88 @@ +# encoding:utf-8 + +import json +import os +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType +import plugins +from plugins import * +from common.log import logger +import webuiapi +import io + + +@plugins.register(name="sdwebui", desc="利用stable-diffusion webui来画图", version="2.0", author="lanvent") +class SDWebUI(Plugin): + def __init__(self): + super().__init__() + curdir = os.path.dirname(__file__) + config_path = os.path.join(curdir, "config.json") + try: + with open(config_path, "r", encoding="utf-8") as f: + config = json.load(f) + self.rules = config["rules"] + defaults = config["defaults"] + self.default_params = defaults["params"] + self.default_options = defaults["options"] + self.start_args = config["start"] + self.api = webuiapi.WebUIApi(**self.start_args) + self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context + logger.info("[SD] inited") + except FileNotFoundError: + logger.error(f"[SD] init failed, {config_path} not found") + except Exception as e: + logger.error("[SD] init failed, exception: %s" % e) + + def on_handle_context(self, e_context: EventContext): + + if e_context['context'].type != ContextType.IMAGE_CREATE: + return + + logger.debug("[SD] on_handle_context. content: %s" %e_context['context'].content) + + logger.info("[SD] image_query={}".format(e_context['context'].content)) + reply = Reply() + try: + content = e_context['context'].content[:] + # 解析用户输入 如"横版 高清 二次元:cat" + keywords, prompt = content.split(":", 1) + keywords = keywords.split() + + rule_params = {} + rule_options = {} + for keyword in keywords: + matched = False + for rule in self.rules: + if keyword in rule["keywords"]: + for key in rule["params"]: + rule_params[key] = rule["params"][key] + if "options" in rule: + for key in rule["options"]: + rule_options[key] = rule["options"][key] + matched = True + break # 一个关键词只匹配一个规则 + if not matched: + logger.warning("[SD] keyword not matched: %s" % keyword) + + params = {**self.default_params, **rule_params} + options = {**self.default_options, **rule_options} + params["prompt"] = params.get("prompt", "")+f", {prompt}" + if len(options) > 0: + logger.info("[SD] cover rule_options={}".format(rule_options)) + self.api.set_options(options) + logger.info("[SD] params={}".format(params)) + result = self.api.txt2img( + **params + ) + reply.type = ReplyType.IMAGE + b_img = io.BytesIO() + result.image.save(b_img, format="PNG") + reply.content = b_img + e_context.action = EventAction.BREAK_PASS # 事件结束后,不跳过处理context的默认逻辑 + except Exception as e: + reply.type = ReplyType.ERROR + reply.content = "[SD] "+str(e) + logger.error("[SD] exception: %s" % e) + e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 + finally: + e_context['reply'] = reply From e6b65437e4197ca148ff017bac3a6a80ed4f01dd Mon Sep 17 00:00:00 2001 From: lanvent Date: Tue, 14 Mar 2023 12:07:03 +0800 Subject: [PATCH 17/29] sdwebui : add help reply --- plugins/sdwebui/config.json.template | 9 ++- plugins/sdwebui/sdwebui.py | 90 ++++++++++++++++++---------- 2 files changed, 64 insertions(+), 35 deletions(-) diff --git a/plugins/sdwebui/config.json.template b/plugins/sdwebui/config.json.template index 4fca569..213acdc 100644 --- a/plugins/sdwebui/config.json.template +++ b/plugins/sdwebui/config.json.template @@ -31,7 +31,8 @@ "params": { "width": 640, "height": 384 - } + }, + "desc": "分辨率会变成640x384" }, { "keywords": [ @@ -49,7 +50,8 @@ "params": { "enable_hr": true, "hr_scale": 1.6 - } + }, + "desc": "出图分辨率长宽都会提高1.6倍" }, { "keywords": [ @@ -61,7 +63,8 @@ }, "options": { "sd_model_checkpoint": "meinamix_meinaV8" - } + }, + "desc": "使用二次元风格模型出图" } ] } \ No newline at end of file diff --git a/plugins/sdwebui/sdwebui.py b/plugins/sdwebui/sdwebui.py index f7da523..cc07c7a 100644 --- a/plugins/sdwebui/sdwebui.py +++ b/plugins/sdwebui/sdwebui.py @@ -4,6 +4,7 @@ import json import os from bridge.context import ContextType from bridge.reply import Reply, ReplyType +from config import conf import plugins from plugins import * from common.log import logger @@ -45,40 +46,49 @@ class SDWebUI(Plugin): try: content = e_context['context'].content[:] # 解析用户输入 如"横版 高清 二次元:cat" - keywords, prompt = content.split(":", 1) + if ":" in content: + keywords, prompt = content.split(":", 1) + else: + keywords = content + prompt = "" + keywords = keywords.split() - rule_params = {} - rule_options = {} - for keyword in keywords: - matched = False - for rule in self.rules: - if keyword in rule["keywords"]: - for key in rule["params"]: - rule_params[key] = rule["params"][key] - if "options" in rule: - for key in rule["options"]: - rule_options[key] = rule["options"][key] - matched = True - break # 一个关键词只匹配一个规则 - if not matched: - logger.warning("[SD] keyword not matched: %s" % keyword) - - params = {**self.default_params, **rule_params} - options = {**self.default_options, **rule_options} - params["prompt"] = params.get("prompt", "")+f", {prompt}" - if len(options) > 0: - logger.info("[SD] cover rule_options={}".format(rule_options)) - self.api.set_options(options) - logger.info("[SD] params={}".format(params)) - result = self.api.txt2img( - **params - ) - reply.type = ReplyType.IMAGE - b_img = io.BytesIO() - result.image.save(b_img, format="PNG") - reply.content = b_img - e_context.action = EventAction.BREAK_PASS # 事件结束后,不跳过处理context的默认逻辑 + if "help" in keywords or "帮助" in keywords: + reply.type = ReplyType.INFO + reply.content = self.get_help_text() + else: + rule_params = {} + rule_options = {} + for keyword in keywords: + matched = False + for rule in self.rules: + if keyword in rule["keywords"]: + for key in rule["params"]: + rule_params[key] = rule["params"][key] + if "options" in rule: + for key in rule["options"]: + rule_options[key] = rule["options"][key] + matched = True + break # 一个关键词只匹配一个规则 + if not matched: + logger.warning("[SD] keyword not matched: %s" % keyword) + + params = {**self.default_params, **rule_params} + options = {**self.default_options, **rule_options} + params["prompt"] = params.get("prompt", "")+f", {prompt}" + if len(options) > 0: + logger.info("[SD] cover rule_options={}".format(rule_options)) + self.api.set_options(options) + logger.info("[SD] params={}".format(params)) + result = self.api.txt2img( + **params + ) + reply.type = ReplyType.IMAGE + b_img = io.BytesIO() + result.image.save(b_img, format="PNG") + reply.content = b_img + e_context.action = EventAction.BREAK_PASS # 事件结束后,跳过处理context的默认逻辑 except Exception as e: reply.type = ReplyType.ERROR reply.content = "[SD] "+str(e) @@ -86,3 +96,19 @@ class SDWebUI(Plugin): e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 finally: e_context['reply'] = reply + + def get_help_text(self): + if not conf().get('image_create_prefix'): + return "画图功能未启用" + else: + trigger = conf()['image_create_prefix'][0] + help_text = f"请使用<{trigger}[关键词1] [关键词2]...:提示语>的格式作画,如\"{trigger}横版 高清:cat\"\n" + help_text += "目前可用关键词:\n" + for rule in self.rules: + keywords = [f"[{keyword}]" for keyword in rule['keywords']] + help_text += f"{','.join(keywords)}" + if "desc" in rule: + help_text += f"-{rule['desc']}\n" + else: + help_text += "\n" + return help_text \ No newline at end of file From c782b38ba18613a22566f8dd96f1024f74bfee02 Mon Sep 17 00:00:00 2001 From: lanvent Date: Tue, 14 Mar 2023 15:31:26 +0800 Subject: [PATCH 18/29] sdwebui: modify README.md --- plugins/sdwebui/readme.md | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/plugins/sdwebui/readme.md b/plugins/sdwebui/readme.md index ef325cd..bb8c62c 100644 --- a/plugins/sdwebui/readme.md +++ b/plugins/sdwebui/readme.md @@ -1,19 +1,25 @@ -本插件用于将画图请求转发给stable diffusion webui -使用前先安装stable diffusion webui,并在它的启动参数中添加 "--api" -具体参考(https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API) +### 插件描述 +本插件用于将画图请求转发给stable diffusion webui。 + +### 环境要求 +使用前先安装stable diffusion webui,并在它的启动参数中添加 "--api"。 +具体信息,请参考[文章](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API)。 请**安装**本插件的依赖包```webuiapi``` ``` ```pip install webuiapi``` ``` -请将```config.json.template```复制为```config.json```,并修改其中的参数和规则 +### 使用说明 +请将`config.json.template`复制为`config.json`,并修改其中的参数和规则。 +#### 画图请求格式 用户的画图请求格式为: ``` <画图触发词><关键词1> <关键词2> ... <关键词n>: ``` -本插件会对画图触发词后的关键词进行逐个匹配,如果触发了规则中的关键词,则会在画图请求中重载对应的参数。 -规则会按顺序匹配,每个关键词最多匹配到1次,如果有重复的参数,则以最后一个为准: +- 本插件会对画图触发词后的关键词进行逐个匹配,如果触发了规则中的关键词,则会在画图请求中重载对应的参数。 +- 规则的匹配顺序参考`config.json`中的顺序,每个关键词最多被匹配到1次,如果多个关键词触发了重复的参数,重复参数以最后一个关键词为准: +- 关键词中包含`help`或`帮助`,会打印出帮助文档。 第一个"**:**"号之后的内容会作为附加的**prompt**,接在最终的prompt后 例如: 画横版 高清 二次元:cat @@ -58,6 +64,6 @@ "sd_model_checkpoint": "meinamix_meinaV8" ``` -PS: 参数分为两部分, -一部分是params,为画画的参数;参数名**必须**与webuiapi包中[txt2img api](https://github.com/mix1009/sdwebuiapi/blob/fb2054e149c0a4e25125c0cd7e7dca06bda839d4/webuiapi/webuiapi.py#L163)的参数名一致 -另一部分是options,指sdwebui的设置,使用的模型和vae需要写在里面。它和http://127.0.0.1:7860/sdapi/v1/options所返回的键一致。 \ No newline at end of file +PS: 参数分为两部分: +- 一部分是params,为画画的参数;参数名**必须**与webuiapi包中[txt2img api](https://github.com/mix1009/sdwebuiapi/blob/fb2054e149c0a4e25125c0cd7e7dca06bda839d4/webuiapi/webuiapi.py#L163)的参数名一致 +- 另一部分是options,指sdwebui的设置,使用的模型和vae需要写在里面。它和http://127.0.0.1:7860/sdapi/v1/options所返回的键一致。 \ No newline at end of file From 300b7b96874338505ff3147519925dce815fd96f Mon Sep 17 00:00:00 2001 From: lanvent Date: Tue, 14 Mar 2023 15:59:52 +0800 Subject: [PATCH 19/29] plugins: support reload plugin --- plugins/godcmd/godcmd.py | 14 ++++++++++++++ plugins/plugin_manager.py | 10 ++++++++++ plugins/sdwebui/sdwebui.py | 2 +- 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index e6c971b..296b9a9 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -63,6 +63,11 @@ ADMIN_COMMANDS = { "args": ["插件名", "优先级"], "desc": "设置指定插件的优先级,越大越优先", }, + "reloadp": { + "alias": ["reloadp", "重载插件"], + "args": ["插件名"], + "desc": "重载指定插件配置", + }, "enablep": { "alias": ["enablep", "启用插件"], "args": ["插件名"], @@ -216,6 +221,15 @@ class Godcmd(Plugin): result = "插件" + args[0] + "优先级已设置为" + args[1] else: result = "插件不存在" + elif cmd == "reloadp": + if len(args) != 1: + ok, result = False, "请提供插件名" + else: + ok = PluginManager().reload_plugin(args[0]) + if ok: + result = "插件配置已重载" + else: + result = "插件不存在" elif cmd == "enablep": if len(args) != 1: ok, result = False, "请提供插件名" diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py index 107c63b..9fada89 100644 --- a/plugins/plugin_manager.py +++ b/plugins/plugin_manager.py @@ -96,6 +96,16 @@ class PluginManager: self.listening_plugins[event].append(name) self.refresh_order() + def reload_plugin(self, name): + if name in self.instances: + for event in self.listening_plugins: + if name in self.listening_plugins[event]: + self.listening_plugins[event].remove(name) + del self.instances[name] + self.activate_plugins() + return True + return False + def load_plugins(self): self.load_config() self.scan_plugins() diff --git a/plugins/sdwebui/sdwebui.py b/plugins/sdwebui/sdwebui.py index cc07c7a..56842e8 100644 --- a/plugins/sdwebui/sdwebui.py +++ b/plugins/sdwebui/sdwebui.py @@ -78,7 +78,7 @@ class SDWebUI(Plugin): options = {**self.default_options, **rule_options} params["prompt"] = params.get("prompt", "")+f", {prompt}" if len(options) > 0: - logger.info("[SD] cover rule_options={}".format(rule_options)) + logger.info("[SD] cover options={}".format(options)) self.api.set_options(options) logger.info("[SD] params={}".format(params)) result = self.api.txt2img( From 8915149d361cf0260f11b21e406091d968c7b0a7 Mon Sep 17 00:00:00 2001 From: lanvent Date: Tue, 14 Mar 2023 17:30:30 +0800 Subject: [PATCH 20/29] plugin: add banwords plugin --- plugins/banwords/.gitignore | 1 + plugins/banwords/README.md | 9 + plugins/banwords/WordsSearch.py | 250 +++++++++++++++++++++++++ plugins/banwords/__init__.py | 0 plugins/banwords/banwords.py | 63 +++++++ plugins/banwords/banwords.txt.template | 3 + plugins/banwords/config.json.template | 3 + plugins/godcmd/config.json.template | 4 + plugins/godcmd/godcmd.py | 6 +- 9 files changed, 338 insertions(+), 1 deletion(-) create mode 100644 plugins/banwords/.gitignore create mode 100644 plugins/banwords/README.md create mode 100644 plugins/banwords/WordsSearch.py create mode 100644 plugins/banwords/__init__.py create mode 100644 plugins/banwords/banwords.py create mode 100644 plugins/banwords/banwords.txt.template create mode 100644 plugins/banwords/config.json.template create mode 100644 plugins/godcmd/config.json.template diff --git a/plugins/banwords/.gitignore b/plugins/banwords/.gitignore new file mode 100644 index 0000000..a6593bf --- /dev/null +++ b/plugins/banwords/.gitignore @@ -0,0 +1 @@ +banwords.txt \ No newline at end of file diff --git a/plugins/banwords/README.md b/plugins/banwords/README.md new file mode 100644 index 0000000..9c7e498 --- /dev/null +++ b/plugins/banwords/README.md @@ -0,0 +1,9 @@ +### 说明 +简易的敏感词插件,暂不支持分词,请自行导入词库到插件文件夹中的`banwords.txt`,每行一个词,一个参考词库是[1](https://github.com/cjh0613/tencent-sensitive-words/blob/main/sensitive_words_lines.txt)。 + +`config.json`中能够填写默认的处理行为,目前行为有: +- `ignore` : 无视这条消息。 +- `replace` : 将消息中的敏感词替换成"*",并回复违规。 + +### 致谢 +搜索功能实现来自https://github.com/toolgood/ToolGood.Words \ No newline at end of file diff --git a/plugins/banwords/WordsSearch.py b/plugins/banwords/WordsSearch.py new file mode 100644 index 0000000..d41d6e7 --- /dev/null +++ b/plugins/banwords/WordsSearch.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +# ToolGood.Words.WordsSearch.py +# 2020, Lin Zhijun, https://github.com/toolgood/ToolGood.Words +# Licensed under the Apache License 2.0 +# 更新日志 +# 2020.04.06 第一次提交 +# 2020.05.16 修改,支持大于0xffff的字符 + +__all__ = ['WordsSearch'] +__author__ = 'Lin Zhijun' +__date__ = '2020.05.16' + +class TrieNode(): + def __init__(self): + self.Index = 0 + self.Index = 0 + self.Layer = 0 + self.End = False + self.Char = '' + self.Results = [] + self.m_values = {} + self.Failure = None + self.Parent = None + + def Add(self,c): + if c in self.m_values : + return self.m_values[c] + node = TrieNode() + node.Parent = self + node.Char = c + self.m_values[c] = node + return node + + def SetResults(self,index): + if (self.End == False): + self.End = True + self.Results.append(index) + +class TrieNode2(): + def __init__(self): + self.End = False + self.Results = [] + self.m_values = {} + self.minflag = 0xffff + self.maxflag = 0 + + def Add(self,c,node3): + if (self.minflag > c): + self.minflag = c + if (self.maxflag < c): + self.maxflag = c + self.m_values[c] = node3 + + def SetResults(self,index): + if (self.End == False) : + self.End = True + if (index in self.Results )==False : + self.Results.append(index) + + def HasKey(self,c): + return c in self.m_values + + + def TryGetValue(self,c): + if (self.minflag <= c and self.maxflag >= c): + if c in self.m_values: + return self.m_values[c] + return None + + +class WordsSearch(): + def __init__(self): + self._first = {} + self._keywords = [] + self._indexs=[] + + def SetKeywords(self,keywords): + self._keywords = keywords + self._indexs=[] + for i in range(len(keywords)): + self._indexs.append(i) + + root = TrieNode() + allNodeLayer={} + + for i in range(len(self._keywords)): # for (i = 0; i < _keywords.length; i++) + p = self._keywords[i] + nd = root + for j in range(len(p)): # for (j = 0; j < p.length; j++) + nd = nd.Add(ord(p[j])) + if (nd.Layer == 0): + nd.Layer = j + 1 + if nd.Layer in allNodeLayer: + allNodeLayer[nd.Layer].append(nd) + else: + allNodeLayer[nd.Layer]=[] + allNodeLayer[nd.Layer].append(nd) + nd.SetResults(i) + + + allNode = [] + allNode.append(root) + for key in allNodeLayer.keys(): + for nd in allNodeLayer[key]: + allNode.append(nd) + allNodeLayer=None + + for i in range(len(allNode)): # for (i = 0; i < allNode.length; i++) + if i==0 : + continue + nd=allNode[i] + nd.Index = i + r = nd.Parent.Failure + c = nd.Char + while (r != None and (c in r.m_values)==False): + r = r.Failure + if (r == None): + nd.Failure = root + else: + nd.Failure = r.m_values[c] + for key2 in nd.Failure.Results : + nd.SetResults(key2) + root.Failure = root + + allNode2 = [] + for i in range(len(allNode)): # for (i = 0; i < allNode.length; i++) + allNode2.append( TrieNode2()) + + for i in range(len(allNode2)): # for (i = 0; i < allNode2.length; i++) + oldNode = allNode[i] + newNode = allNode2[i] + + for key in oldNode.m_values : + index = oldNode.m_values[key].Index + newNode.Add(key, allNode2[index]) + + for index in range(len(oldNode.Results)): # for (index = 0; index < oldNode.Results.length; index++) + item = oldNode.Results[index] + newNode.SetResults(item) + + oldNode=oldNode.Failure + while oldNode != root: + for key in oldNode.m_values : + if (newNode.HasKey(key) == False): + index = oldNode.m_values[key].Index + newNode.Add(key, allNode2[index]) + for index in range(len(oldNode.Results)): + item = oldNode.Results[index] + newNode.SetResults(item) + oldNode=oldNode.Failure + allNode = None + root = None + + # first = [] + # for index in range(65535):# for (index = 0; index < 0xffff; index++) + # first.append(None) + + # for key in allNode2[0].m_values : + # first[key] = allNode2[0].m_values[key] + + self._first = allNode2[0] + + + def FindFirst(self,text): + ptr = None + for index in range(len(text)): # for (index = 0; index < text.length; index++) + t =ord(text[index]) # text.charCodeAt(index) + tn = None + if (ptr == None): + tn = self._first.TryGetValue(t) + else: + tn = ptr.TryGetValue(t) + if (tn==None): + tn = self._first.TryGetValue(t) + + + if (tn != None): + if (tn.End): + item = tn.Results[0] + keyword = self._keywords[item] + return { "Keyword": keyword, "Success": True, "End": index, "Start": index + 1 - len(keyword), "Index": self._indexs[item] } + ptr = tn + return None + + def FindAll(self,text): + ptr = None + list = [] + + for index in range(len(text)): # for (index = 0; index < text.length; index++) + t =ord(text[index]) # text.charCodeAt(index) + tn = None + if (ptr == None): + tn = self._first.TryGetValue(t) + else: + tn = ptr.TryGetValue(t) + if (tn==None): + tn = self._first.TryGetValue(t) + + + if (tn != None): + if (tn.End): + for j in range(len(tn.Results)): # for (j = 0; j < tn.Results.length; j++) + item = tn.Results[j] + keyword = self._keywords[item] + list.append({ "Keyword": keyword, "Success": True, "End": index, "Start": index + 1 - len(keyword), "Index": self._indexs[item] }) + ptr = tn + return list + + + def ContainsAny(self,text): + ptr = None + for index in range(len(text)): # for (index = 0; index < text.length; index++) + t =ord(text[index]) # text.charCodeAt(index) + tn = None + if (ptr == None): + tn = self._first.TryGetValue(t) + else: + tn = ptr.TryGetValue(t) + if (tn==None): + tn = self._first.TryGetValue(t) + + if (tn != None): + if (tn.End): + return True + ptr = tn + return False + + def Replace(self,text, replaceChar = '*'): + result = list(text) + + ptr = None + for i in range(len(text)): # for (i = 0; i < text.length; i++) + t =ord(text[i]) # text.charCodeAt(index) + tn = None + if (ptr == None): + tn = self._first.TryGetValue(t) + else: + tn = ptr.TryGetValue(t) + if (tn==None): + tn = self._first.TryGetValue(t) + + if (tn != None): + if (tn.End): + maxLength = len( self._keywords[tn.Results[0]]) + start = i + 1 - maxLength + for j in range(start,i+1): # for (j = start; j <= i; j++) + result[j] = replaceChar + ptr = tn + return ''.join(result) \ No newline at end of file diff --git a/plugins/banwords/__init__.py b/plugins/banwords/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plugins/banwords/banwords.py b/plugins/banwords/banwords.py new file mode 100644 index 0000000..2b4a711 --- /dev/null +++ b/plugins/banwords/banwords.py @@ -0,0 +1,63 @@ +# encoding:utf-8 + +import json +import os +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType +import plugins +from plugins import * +from common.log import logger +from .WordsSearch import WordsSearch + + +@plugins.register(name="Banwords", desc="判断消息中是否有敏感词、决定是否回复。", version="1.0", author="lanvent", desire_priority= 100) +class Banwords(Plugin): + def __init__(self): + super().__init__() + try: + curdir=os.path.dirname(__file__) + config_path=os.path.join(curdir,"config.json") + conf=None + if not os.path.exists(config_path): + conf={"action":"ignore"} + with open(config_path,"w") as f: + json.dump(conf,f,indent=4) + else: + with open(config_path,"r") as f: + conf=json.load(f) + self.searchr = WordsSearch() + self.action = conf["action"] + banwords_path = os.path.join(curdir,"banwords.txt") + with open(banwords_path, 'r', encoding='utf-8') as f: + words=[] + for line in f: + word = line.strip() + if word: + words.append(word) + self.searchr.SetKeywords(words) + self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context + logger.info("[Banwords] inited") + except Exception as e: + logger.error("Banwords init failed: %s" % e) + + + + def on_handle_context(self, e_context: EventContext): + + if e_context['context'].type not in [ContextType.TEXT,ContextType.IMAGE_CREATE]: + return + + content = e_context['context'].content + logger.debug("[Banwords] on_handle_context. content: %s" % content) + if self.action == "ignore": + f = self.searchr.FindFirst(content) + if f: + logger.info("Banwords: %s" % f["Keyword"]) + e_context.action = EventAction.BREAK_PASS + return + elif self.action == "replace": + if self.searchr.ContainsAny(content): + reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n"+self.searchr.Replace(content)) + e_context['reply'] = reply + e_context.action = EventAction.BREAK_PASS + return \ No newline at end of file diff --git a/plugins/banwords/banwords.txt.template b/plugins/banwords/banwords.txt.template new file mode 100644 index 0000000..9b2e8ed --- /dev/null +++ b/plugins/banwords/banwords.txt.template @@ -0,0 +1,3 @@ +nipples +pennis +法轮功 \ No newline at end of file diff --git a/plugins/banwords/config.json.template b/plugins/banwords/config.json.template new file mode 100644 index 0000000..000fdda --- /dev/null +++ b/plugins/banwords/config.json.template @@ -0,0 +1,3 @@ +{ + "action": "ignore" +} \ No newline at end of file diff --git a/plugins/godcmd/config.json.template b/plugins/godcmd/config.json.template new file mode 100644 index 0000000..5240738 --- /dev/null +++ b/plugins/godcmd/config.json.template @@ -0,0 +1,4 @@ +{ + "password": "", + "admin_users": [] +} \ No newline at end of file diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index 296b9a9..2a942df 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -273,9 +273,13 @@ class Godcmd(Plugin): if isadmin: return False,"管理员账号无需认证" - + + if len(self.password) == 0: + return False,"未设置口令,无法认证" + if len(args) != 1: return False,"请提供口令" + password = args[0] if password == self.password: self.admin_users.append(userid) From 2568322879158cf0ecaa5327affffde1a16fa478 Mon Sep 17 00:00:00 2001 From: lanvent Date: Tue, 14 Mar 2023 18:02:07 +0800 Subject: [PATCH 21/29] plugin: ignore cases when manage plugins --- plugins/godcmd/godcmd.py | 2 +- plugins/plugin_manager.py | 36 ++++++++++++++++++++++-------------- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index 2a942df..c33c1e2 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -198,7 +198,7 @@ class Godcmd(Plugin): ok = True result = "插件列表:\n" for name,plugincls in plugins.items(): - result += f"{name}_v{plugincls.version} {plugincls.priority} - " + result += f"{plugincls.name}_v{plugincls.version} {plugincls.priority} - " if plugincls.enabled: result += "已启用\n" else: diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py index 9fada89..d946786 100644 --- a/plugins/plugin_manager.py +++ b/plugins/plugin_manager.py @@ -26,7 +26,7 @@ class PluginManager: plugincls.author = author plugincls.priority = desire_priority plugincls.enabled = True - self.plugins[name] = plugincls + self.plugins[name.upper()] = plugincls logger.info("Plugin %s_v%s registered" % (name, version)) return plugincls return wrapper @@ -67,14 +67,15 @@ class PluginManager: new_plugins = [] modified = False for name, plugincls in self.plugins.items(): - if name not in pconf["plugins"]: + rawname = plugincls.name + if rawname not in pconf["plugins"]: new_plugins.append(plugincls) modified = True logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name) - pconf["plugins"][name] = {"enabled": plugincls.enabled, "priority": plugincls.priority} + pconf["plugins"][rawname] = {"enabled": plugincls.enabled, "priority": plugincls.priority} else: - self.plugins[name].enabled = pconf["plugins"][name]["enabled"] - self.plugins[name].priority = pconf["plugins"][name]["priority"] + self.plugins[name].enabled = pconf["plugins"][rawname]["enabled"] + self.plugins[name].priority = pconf["plugins"][rawname]["priority"] self.plugins._update_heap(name) # 更新下plugins中的顺序 if modified: self.save_config() @@ -96,7 +97,8 @@ class PluginManager: self.listening_plugins[event].append(name) self.refresh_order() - def reload_plugin(self, name): + def reload_plugin(self, name:str): + name = name.upper() if name in self.instances: for event in self.listening_plugins: if name in self.listening_plugins[event]: @@ -112,7 +114,7 @@ class PluginManager: pconf = self.pconf logger.debug("plugins.json config={}".format(pconf)) for name,plugin in pconf["plugins"].items(): - if name not in self.plugins: + if name.upper() not in self.plugins: logger.error("Plugin %s not found, but found in plugins.json" % name) self.activate_plugins() @@ -125,36 +127,42 @@ class PluginManager: instance.handlers[e_context.event](e_context, *args, **kwargs) return e_context - def set_plugin_priority(self,name,priority): + def set_plugin_priority(self, name:str, priority:int): + name = name.upper() if name not in self.plugins: return False if self.plugins[name].priority == priority: return True self.plugins[name].priority = priority self.plugins._update_heap(name) - self.pconf["plugins"][name]["priority"] = priority - self.pconf["plugins"]._update_heap(name) + rawname = self.plugins[name].name + self.pconf["plugins"][rawname]["priority"] = priority + self.pconf["plugins"]._update_heap(rawname) self.save_config() self.refresh_order() return True - def enable_plugin(self,name): + def enable_plugin(self, name:str): + name = name.upper() if name not in self.plugins: return False if not self.plugins[name].enabled : self.plugins[name].enabled = True - self.pconf["plugins"][name]["enabled"] = True + rawname = self.plugins[name].name + self.pconf["plugins"][rawname]["enabled"] = True self.save_config() self.activate_plugins() return True return True - def disable_plugin(self,name): + def disable_plugin(self, name:str): + name = name.upper() if name not in self.plugins: return False if self.plugins[name].enabled : self.plugins[name].enabled = False - self.pconf["plugins"][name]["enabled"] = False + rawname = self.plugins[name].name + self.pconf["plugins"][rawname]["enabled"] = False self.save_config() return True return True From 61d66dd8b3893a35e66f8bc82f31a20dbd91421f Mon Sep 17 00:00:00 2001 From: lanvent Date: Thu, 16 Mar 2023 01:08:19 +0800 Subject: [PATCH 22/29] plugin: add dungeon plugin --- plugins/dungeon/__init__.py | 0 plugins/dungeon/dungeon.py | 76 +++++++++++++++++++++++++++++++++++++ plugins/hello/hello.py | 2 +- 3 files changed, 77 insertions(+), 1 deletion(-) create mode 100644 plugins/dungeon/__init__.py create mode 100644 plugins/dungeon/dungeon.py diff --git a/plugins/dungeon/__init__.py b/plugins/dungeon/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plugins/dungeon/dungeon.py b/plugins/dungeon/dungeon.py new file mode 100644 index 0000000..955840f --- /dev/null +++ b/plugins/dungeon/dungeon.py @@ -0,0 +1,76 @@ +# encoding:utf-8 + +from bridge.bridge import Bridge +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType +import plugins +from plugins import * +from common.log import logger + +# https://github.com/bupticybee/ChineseAiDungeonChatGPT +class StoryTeller(): + def __init__(self, bot, sessionid, story): + self.bot = bot + self.sessionid = sessionid + bot.sessions.clear_session(sessionid) + self.first_interact = True + self.story = story + + def reset(self): + self.bot.sessions.clear_session(self.sessionid) + self.first_interact = True + + def action(self, user_action): + if user_action[-1] != "。": + user_action = user_action + "。" + if self.first_interact: + prompt = """现在来充当一个冒险文字游戏,描述时候注意节奏,不要太快,仔细描述各个人物的心情和周边环境。一次只需写四到六句话。 + 开头是,""" + self.story + " " + user_action + self.first_interact = False + else: + prompt = """继续,一次只需要续写四到六句话,总共就只讲5分钟内发生的事情。""" + user_action + return prompt + + +@plugins.register(name="Dungeon", desc="A plugin to play dungeon game", version="1.0", author="lanvent", desire_priority= 0) +class Dungeon(Plugin): + def __init__(self): + super().__init__() + self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context + logger.info("[Dungeon] inited") + self.games = {} + + def on_handle_context(self, e_context: EventContext): + + if e_context['context'].type != ContextType.TEXT: + return + bottype = Bridge().get_bot_type("chat") + if bottype != "chatGPT": + return + bot = Bridge().get_bot("chat") + content = e_context['context'].content[:] + clist = e_context['context'].content.split(maxsplit=1) + sessionid = e_context['context']['session_id'] + logger.debug("[Dungeon] on_handle_context. content: %s" % clist) + if clist[0] == "$停止冒险": + if sessionid in self.games: + self.games[sessionid].reset() + del self.games[sessionid] + reply = Reply(ReplyType.INFO, "冒险结束!") + e_context['reply'] = reply + e_context.action = EventAction.BREAK_PASS + elif clist[0] == "$开始冒险" or sessionid in self.games: + if sessionid not in self.games or clist[0] == "$开始冒险": + if len(clist)>1 : + story = clist[1] + else: + story = "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。" + self.games[sessionid] = StoryTeller(bot, sessionid, story) + reply = Reply(ReplyType.INFO, "冒险开始,你可以输入任意内容,让故事继续下去。故事背景是:" + story) + e_context['reply'] = reply + e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 + else: + prompt = self.games[sessionid].action(content) + e_context['context'].type = ContextType.TEXT + e_context['context'].content = prompt + e_context.action = EventAction.CONTINUE \ No newline at end of file diff --git a/plugins/hello/hello.py b/plugins/hello/hello.py index 53d87e6..c01b743 100644 --- a/plugins/hello/hello.py +++ b/plugins/hello/hello.py @@ -41,6 +41,6 @@ class Hello(Plugin): if content == "End": # 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World" - e_context['context'].type = "IMAGE_CREATE" + e_context['context'].type = ContextType.IMAGE_CREATE content = "The World" e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 From 5a46e09358a149e7e678d279295cffd56f6b0a2e Mon Sep 17 00:00:00 2001 From: lanvent Date: Sun, 19 Mar 2023 17:57:57 +0800 Subject: [PATCH 23/29] plugin: add doc --- plugins/README.md | 165 ++++++++++++++++++++++++++++++++++++++ plugins/dungeon/README.md | 3 + plugins/godcmd/README.md | 2 + 3 files changed, 170 insertions(+) create mode 100644 plugins/README.md create mode 100644 plugins/dungeon/README.md create mode 100644 plugins/godcmd/README.md diff --git a/plugins/README.md b/plugins/README.md new file mode 100644 index 0000000..61e83c7 --- /dev/null +++ b/plugins/README.md @@ -0,0 +1,165 @@ +# 插件说明 +本项目主体是调用ChatGPT接口的Wechat自动回复机器人。之前未插件化的代码耦合程度高,很难定制一些个性化功能(如流量控制、接入本地的NovelAI画图平台等),多个功能的优先级顺序也难以调度。 +**插件化**: 在保证主体功能是ChatGPT的前提下,推荐将主体功能外的功能分离成不同的插件。有个性化需求的用户仅需按照插件提供的接口编写插件,无需了解程序主体的代码结构,同时也方便代码的测试和调试。(插件调用目前仅支持 itchat) + +## 插件触发时机 + +### 消息处理过程 +了解插件触发时机前,首先需要了解程序收到消息后的执行过程。插件化版本的消息处理过程如下: +``` + 1.收到消息 ---> 2.产生回复 ---> 3.包装回复 ---> 4.发送回复 +``` +以下是它们的默认处理逻辑(太长不看,可跳过): + +- 1. 收到消息 + 本过程接收到用户消息,根据用户设置,判断本条消息是否触发。若触发,则判断该消息的命令类型,如声音、聊天、画图等。之后,将消息包装成如下的 Context 交付给下一个步骤。 + ```python + class ContextType (Enum): + TEXT = 1 # 文本消息 + VOICE = 2 # 音频消息 + IMAGE_CREATE = 3 # 创建图片命令 + class Context: + def __init__(self, type : ContextType = None , content = None, kwargs = dict()): + self.type = type + self.content = content + self.kwargs = kwargs + def __getitem__(self, key): + return self.kwargs[key] + ``` + `Context`中除了存放消息类型和内容外,还存放了与会话相关的参数。一个例子是,当收到用户私聊消息时,还会存放以下的会话参数,`isgroup`标识`Context`是否时群聊消息,`msg`是`itchat`中原始的消息对象,`receiver`是应回复消息的对象ID,`session_id`是会话ID(一般是触发bot的消息发送方,群聊中如果`conf`里设置了`group_chat_in_one_session`,那么此处便是群聊的ID) + ``` + context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id} + ``` +2. 产生回复 + 本过程用于处理消息。目前默认处理逻辑是根据`Context`的类型交付给对应的bot: + ```python + if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: + reply = super().build_reply_content(context.content, context) #文字跟画图交付给chatgpt + elif context.type == ContextType.VOICE: # 声音先进行语音转文字后,修改Context类型为文字后,再交付给chatgpt + msg = context['msg'] + file_name = TmpDir().path() + context.content + msg.download(file_name) + reply = super().build_voice_to_text(file_name) + if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO: + context.content = reply.content # 语音转文字后,将文字内容作为新的context + context.type = ContextType.TEXT + reply = super().build_reply_content(context.content, context) + if reply.type == ReplyType.TEXT: + if conf().get('voice_reply_voice'): + reply = super().build_text_to_voice(reply.content) + ``` + Bot可产生的回复如下所示,它允许Bot可以回复多类不同的消息,未来可能不止能返回文字,而是能根据文字回复音频/图片,这时候便能派上用场。同时也加入了`INFO`和`ERROR`消息类型区分系统提示和系统错误。 + ```python + class ReplyType(Enum): + TEXT = 1 # 文本 + VOICE = 2 # 音频文件 + IMAGE = 3 # 图片文件 + IMAGE_URL = 4 # 图片URL + + INFO = 9 + ERROR = 10 + class Reply: + def __init__(self, type : ReplyType = None , content = None): + self.type = type + self.content = content + ``` +3. 装饰回复 + 本过程根据`Context`和回复的类型,对回复的内容进行装饰。目前的装饰有以下两种,如果是文本回复,会根据是否在群聊中来决定是否艾特收方或添加回复前缀。 + 如果是`INFO`或`ERROR`类型,会在消息前添加对应字样。 + ```python + if reply.type == ReplyType.TEXT: + reply_text = reply.content + if context['isgroup']: + reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip() + reply_text = conf().get("group_chat_reply_prefix", "")+reply_text + else: + reply_text = conf().get("single_chat_reply_prefix", "")+reply_text + reply.content = reply_text + elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: + reply.content = str(reply.type)+":\n" + reply.content + ``` +4. 发送回复 + 本过程根据回复的类型来发送回复给接收方`context["receiver"]`。 + +### 插件触发事件 + +主程序会在各消息处理过程之间触发插件事件,插件可以监听相应事件编写相应的处理逻辑。 +``` + 1.收到消息 ---> 2.产生回复 ---> 3.包装回复 ---> 4.发送回复 +``` +目前加入了三类事件的触发: +``` +1.收到消息 +---> `ON_HANDLE_CONTEXT` +2.产生回复 +---> `ON_DECORATE_REPLY` +3.包装回复 +---> `ON_SEND_REPLY` +4.发送回复 +``` +触发事件会产生事件上下文`EventContext`,它包含了以下信息: +`EventContext(Event事件类型, {'channel' : 消息channel, 'context': context, 'reply': reply})` + +插件的处理函数可以修改`Context`和`Reply`的内容来定制化处理逻辑。 + +## 插件编写 +以`plugins/hello`为例,它编写了一个简单`Hello`插件。 + +1. 创建插件 +在`plugins`目录下创建一个插件文件夹,例如`hello`。然后,在该文件夹中创建一个与文件夹同名的`.py`文件,例如`hello.py`。 +``` +plugins/ +└── hello + ├── __init__.py + └── hello.py +``` + +2. 编写插件类 +在`hello.py`文件中,创建插件类,它继承自Plugin类。在类定义之前使用`@plugins.register`装饰器注册插件,并填写插件的相关信息,其中`desire_priority`表示插件默认的优先级,越大优先级越高,扫描到插件后可在`plugins/plugins.json`中修改插件优先级。并在`__init__`中绑定你编写的事件处理函数: +```python +@plugins.register(name="Hello", desc="A simple plugin that says hello", version="0.1", author="lanvent", desire_priority= -1) +class Hello(Plugin): + def __init__(self): + super().__init__() + self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context + logger.info("[Hello] inited") +``` + +3. 编写事件处理函数 +事件处理函数接收一个`EventContext`对象作为参数。`EventContext`对象包含了事件相关的信息,如消息内容和当前回复等。可以通过`e_context['key']`访问这些信息。 + +处理函数中,你可以修改`EventContext`对象的信息,比如更改回复内容。在处理函数结束时,需要设置`EventContext`对象的`action`属性,以决定如何继续处理事件。有以下三种处理方式: +- `EventAction.CONTINUE`: 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑。 +- `EventAction.BREAK`: 事件结束,不再给下个插件处理,交付给默认的处理逻辑。 +- `EventAction.BREAK_PASS`: 事件结束,不再给下个插件处理,跳过默认的处理逻辑。 + +以`Hello`插件为例,它处理`Context`类型为`TEXT`的消息。 +- 如果内容是`Hello`,直接将回复设置为`Hello+用户昵称`,并跳过之后的插件和默认逻辑。 +- 如果内容是`End`,它会将`Context`的类型更改为`IMAGE_CREATE`,并让事件继续,如果最终交付到默认逻辑,会调用默认的画图Bot。 +```python + def on_handle_context(self, e_context: EventContext): + if e_context['context'].type != ContextType.TEXT: + return + content = e_context['context'].content + if content == "Hello": + reply = Reply() + reply.type = ReplyType.TEXT + msg = e_context['context']['msg'] + if e_context['context']['isgroup']: + reply.content = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group") + else: + reply.content = "Hello, " + msg['User'].get('NickName', "My friend") + e_context['reply'] = reply + e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 + if content == "End": + # 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World" + e_context['context'].type = ContextType.IMAGE_CREATE + content = "The World" + e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 +``` + +## 插件设计规范 +- 个性化功能推荐设计为插件。 +- 一个插件目录建议只注册一个插件类。建议使用单独的仓库维护插件,便于更新。 +- 插件的config文件、使用说明`README.md`、`requirement.txt`放置在插件目录中。 +- 默认优先级不要超过管理员插件`Godcmd`的优先级(999),`Godcmd`插件提供了配置管理、插件管理等功能。 \ No newline at end of file diff --git a/plugins/dungeon/README.md b/plugins/dungeon/README.md new file mode 100644 index 0000000..e3fa61a --- /dev/null +++ b/plugins/dungeon/README.md @@ -0,0 +1,3 @@ +玩地牢游戏的聊天插件,触发方法如下: +- `$开始冒险 <背景故事>` - 以<背景故事>开始一个地牢游戏,不填写会使用默认背景故事。之后聊天中你的所有消息会帮助ai完善这个故事。 +- `$停止冒险` - 停止一个地牢游戏,回归正常的ai。 \ No newline at end of file diff --git a/plugins/godcmd/README.md b/plugins/godcmd/README.md new file mode 100644 index 0000000..e93b854 --- /dev/null +++ b/plugins/godcmd/README.md @@ -0,0 +1,2 @@ +管理员插件 +`#help` - 输出帮助文档。 \ No newline at end of file From 77046000e8adc119d8f54e8c4db45d12061111a9 Mon Sep 17 00:00:00 2001 From: lanvent Date: Mon, 20 Mar 2023 20:43:02 +0800 Subject: [PATCH 24/29] plugin: add Role plugin --- bot/chatgpt/chat_gpt_bot.py | 17 ++-- plugins/role/README.md | 24 ++++++ plugins/role/__init__.py | 0 plugins/role/role.py | 122 ++++++++++++++++++++++++++++ plugins/role/roles.json | 158 ++++++++++++++++++++++++++++++++++++ 5 files changed, 315 insertions(+), 6 deletions(-) create mode 100644 plugins/role/README.md create mode 100644 plugins/role/__init__.py create mode 100644 plugins/role/role.py create mode 100644 plugins/role/roles.json diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index b2e062d..fc9869e 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -145,6 +145,16 @@ class SessionManager(object): sessions = dict() self.sessions = sessions + def build_session(self, session_id, system_prompt=None): + session = self.sessions.get(session_id, []) + if len(session) == 0: + if system_prompt is None: + system_prompt = conf().get("character_desc", "") + system_item = {'role': 'system', 'content': system_prompt} + session.append(system_item) + self.sessions[session_id] = session + return session + def build_session_query(self, query, session_id): ''' build query with conversation history @@ -158,12 +168,7 @@ class SessionManager(object): :param session_id: session id :return: query content with conversaction ''' - session = self.sessions.get(session_id, []) - if len(session) == 0: - system_prompt = conf().get("character_desc", "") - system_item = {'role': 'system', 'content': system_prompt} - session.append(system_item) - self.sessions[session_id] = session + session = self.build_session(session_id) user_item = {'role': 'user', 'content': query} session.append(user_item) return session diff --git a/plugins/role/README.md b/plugins/role/README.md new file mode 100644 index 0000000..5c1e78a --- /dev/null +++ b/plugins/role/README.md @@ -0,0 +1,24 @@ +用于让Bot扮演指定角色的聊天插件,触发方法如下: +- `$角色/$role help/帮助` - 打印目前支持的角色列表。 +- `$角色/$role <角色名>` - 让AI扮演该角色。 +- `$停止扮演` - 停止角色扮演。 + +添加自定义角色请在`roles/roles.json`中添加。 +(大部分prompt来自https://github.com/rockbenben/ChatGPT-Shortcut/blob/main/src/data/users.tsx) + +以下为例子, +- `title`是角色名。 +- `description`是使用`$role`触发的英语prompt。 +- `descn`是使用`$角色`触发的中文prompt。 +- `wrapper`用于包装你的消息,可以起到强调的作用。 +- `remark`简短的描述该角色,在打印帮助时显示。 + +```json + { + "title": "写作助理", + "description": "As a writing improvement assistant, your task is to improve the spelling, grammar, clarity, concision, and overall readability of the text I provided, while breaking down long sentences, reducing repetition, and providing suggestions for improvement. Please provide only the corrected Chinese version of the text and avoid including explanations. Please treat every message I send later as text content.", + "descn": "作为一名中文写作改进助理,你的任务是改进所提供文本的拼写、语法、清晰、简洁和整体可读性,同时分解长句,减少重复,并提供改进建议。请只提供文本的更正版本,避免包括解释。请把我之后的每一条消息都当作文本内容。", + "wrapper": "内容是:\n\"%s\"", + "remark": "最常使用的角色,用于优化文本的语法、清晰度和简洁度,提高可读性。" + }, +``` diff --git a/plugins/role/__init__.py b/plugins/role/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plugins/role/role.py b/plugins/role/role.py new file mode 100644 index 0000000..91c09be --- /dev/null +++ b/plugins/role/role.py @@ -0,0 +1,122 @@ +# encoding:utf-8 + +import json +import os +from bridge.bridge import Bridge +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType +import plugins +from plugins import * +from common.log import logger + + +class RolePlay(): + def __init__(self, bot, sessionid, desc, wrapper=None): + self.bot = bot + self.sessionid = sessionid + bot.sessions.clear_session(sessionid) + bot.sessions.build_session(sessionid, desc) + self.wrapper = wrapper or "%s" # 用于包装用户输入 + + def reset(self): + self.bot.sessions.clear_session(self.sessionid) + + def action(self, user_action): + prompt = self.wrapper % user_action + return prompt + +@plugins.register(name="Role", desc="为你的Bot设置预设角色", version="1.0", author="lanvent", desire_priority= 0) +class Role(Plugin): + def __init__(self): + super().__init__() + curdir = os.path.dirname(__file__) + config_path = os.path.join(curdir, "roles.json") + try: + with open(config_path, "r", encoding="utf-8") as f: + config = json.load(f) + self.roles = {role["title"].lower(): role for role in config["roles"]} + if len(self.roles) == 0: + raise Exception("no role found") + self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context + self.roleplays = {} + logger.info("[Role] inited") + except FileNotFoundError: + logger.error(f"[Role] init failed, {config_path} not found") + except Exception as e: + logger.error("[Role] init failed, exception: %s" % e) + + def get_role(self, name, find_closest=True): + name = name.lower() + found_role = None + if name in self.roles: + found_role = name + elif find_closest: + import difflib + + def str_simularity(a, b): + return difflib.SequenceMatcher(None, a, b).ratio() + max_sim = 0.0 + max_role = None + for role in self.roles: + sim = str_simularity(name, role) + if sim >= max_sim: + max_sim = sim + max_role = role + found_role = max_role + return found_role + + def on_handle_context(self, e_context: EventContext): + + if e_context['context'].type != ContextType.TEXT: + return + bottype = Bridge().get_bot_type("chat") + if bottype != "chatGPT": + return + bot = Bridge().get_bot("chat") + content = e_context['context'].content[:] + clist = e_context['context'].content.split(maxsplit=1) + desckey = None + sessionid = e_context['context']['session_id'] + if clist[0] == "$停止扮演": + if sessionid in self.roleplays: + self.roleplays[sessionid].reset() + del self.roleplays[sessionid] + reply = Reply(ReplyType.INFO, "角色扮演结束!") + e_context['reply'] = reply + e_context.action = EventAction.BREAK_PASS + return + elif clist[0] == "$角色": + desckey = "descn" + elif clist[0].lower() == "$role": + desckey = "description" + elif sessionid not in self.roleplays: + return + logger.debug("[Role] on_handle_context. content: %s" % content) + if desckey is not None: + if len(clist) == 1 or (len(clist) > 1 and clist[1].lower() in ["help", "帮助"]): + reply = Reply(ReplyType.INFO, self.get_help_text()) + e_context['reply'] = reply + e_context.action = EventAction.BREAK_PASS + return + role = self.get_role(clist[1]) + if role is None: + reply = Reply(ReplyType.ERROR, "角色不存在") + e_context['reply'] = reply + e_context.action = EventAction.BREAK_PASS + return + else: + self.roleplays[sessionid] = RolePlay(bot, sessionid, self.roles[role][desckey],self.roles[role].get("wrapper","%s")) + reply = Reply(ReplyType.INFO, f"角色设定为 {role} :\n"+self.roles[role][desckey]) + e_context['reply'] = reply + e_context.action = EventAction.BREAK_PASS + else: + prompt = self.roleplays[sessionid].action(content) + e_context['context'].type = ContextType.TEXT + e_context['context'].content = prompt + e_context.action = EventAction.CONTINUE + + def get_help_text(self): + help_text = "输入\"$角色 (角色名)\"或\"$role (角色名)\"为我设定角色吧,#reset 可以清除设定的角色。\n目前可用角色列表:\n" + for role in self.roles: + help_text += f"[{role}]: {self.roles[role]['remark']}\n" + return help_text diff --git a/plugins/role/roles.json b/plugins/role/roles.json new file mode 100644 index 0000000..ab9a12e --- /dev/null +++ b/plugins/role/roles.json @@ -0,0 +1,158 @@ +{ + "roles":[ + { + "title": "英语翻译或修改", + "description": "I want you to act as an English translator, spelling corrector and improver. I will speak to you in any language and you will detect the language, translate it and answer in the corrected and improved version of my text, in English. I want you to replace my simplified A0-level words and sentences with more beautiful and elegant, upper level English words and sentences. Keep the meaning same, but make them more literary. I want you to only reply the correction, the improvements and nothing else, do not write explanations. Please treat every message I send later as text content", + "descn": "我希望你能充当英语翻译、拼写纠正者和改进者。我将用任何语言与你交谈,你将检测语言,翻译它,并在我的文本的更正和改进版本中用英语回答。我希望你用更漂亮、更优雅、更高级的英语单词和句子来取代我的简化 A0 级单词和句子。保持意思不变,但让它们更有文学性。我希望你只回答更正,改进,而不是其他,不要写解释。请把我之后的每一条消息都当作文本内容。", + "wrapper": "内容是:\n\"%s\"", + "remark": "将其他语言翻译成英文,或改进你提供的英文句子。" + }, + { + "title": "写作助理", + "description": "As a writing improvement assistant, your task is to improve the spelling, grammar, clarity, concision, and overall readability of the text I provided, while breaking down long sentences, reducing repetition, and providing suggestions for improvement. Please provide only the corrected Chinese version of the text and avoid including explanations. Please treat every message I send later as text content.", + "descn": "作为一名中文写作改进助理,你的任务是改进所提供文本的拼写、语法、清晰、简洁和整体可读性,同时分解长句,减少重复,并提供改进建议。请只提供文本的更正版本,避免包括解释。请把我之后的每一条消息都当作文本内容。", + "wrapper": "内容是:\n\"%s\"", + "remark": "最常使用的角色,用于优化文本的语法、清晰度和简洁度,提高可读性。" + }, + { + "title": "语言输入优化", + "description": "Using concise and clear language, please edit the passage I provide to improve its logical flow, eliminate any typographical errors and respond in Chinese. Be sure to maintain the original meaning of the text. Please treat every message I send later as text content.", + "descn": "请用简洁明了的语言,编辑我给出的段落,以改善其逻辑流程,消除任何印刷错误,并以中文作答。请务必保持文章的原意。请把我之后的每一条消息当作文本内容。", + "wrapper": "文本内容是:\n\"%s\"", + "remark": "通常用于语音识别信息转书面语言。" + }, + { + "title": "论文式回答", + "description": "From now on, please write a highly detailed essay with introduction, body, and conclusion paragraphs to respond to each of my questions.", + "descn": "从现在开始,对于之后我提出的每个问题,请写一篇高度详细的文章回应,包括引言、主体和结论段落。", + "wrapper": "问题是:\n\"%s?\"", + "remark": "以论文形式讨论问题,能够获得连贯的、结构化的和更高质量的回答。" + }, + { + "title": "写作素材搜集", + "description": "Please generate a list of the top 10 facts, statistics and trends related to every subject I provided, including their source", + "descn": "请为我提供的每个主题生成一份相关的十大事实、统计数据和趋势的清单,包括其来源", + "wrapper": "主题是:\n\"%s\"", + "remark": "提供指定主题的结论和数据,作为素材。" + }, + { + "title": "内容总结", + "description": "Summarize every text I provided into 100 words, making it easy to read and comprehend. The summary should be concise, clear, and capture the main points of the text. Avoid using complex sentence structures or technical jargon. Please begin by editing the following text: ", + "descn": "请将我提供的每篇文字都概括为 100 个字,使其易于阅读和理解。避免使用复杂的句子结构或技术术语。", + "wrapper": "文章内容是:\n\"%s\"", + "remark": "将文本内容总结为 100 字。" + }, + { + "title": "格言书", + "description": "I want you to act as an aphorism book. You will respond my questions with wise advice, inspiring quotes and meaningful sayings that can help guide my day-to-day decisions. Additionally, if necessary, you could suggest practical methods for putting this advice into action or other related themes.", + "descn": "我希望你能充当一本箴言书。对于我的问题,你会提供明智的建议、鼓舞人心的名言和有意义的谚语,以帮助指导我的日常决策。此外,如果有必要,你可以提出将这些建议付诸行动的实际方法或其他相关主题。", + "wrapper": "我的问题是:\n\"%s?\"", + "remark": "根据问题输出鼓舞人心的名言和有意义的格言。" + }, + { + "title": "讲故事", + "description": "I want you to act as a storyteller. You will come up with entertaining stories that are engaging, imaginative and captivating for the audience. It can be fairy tales, educational stories or any other type of stories which has the potential to capture people's attention and imagination. Depending on the target audience, you may choose specific themes or topics for your storytelling session e.g., if it's children then you can talk about animals; If it's adults then history-based tales might engage them better etc.", + "descn": "我希望你充当一个讲故事的人。你要想出具有娱乐性的故事,要有吸引力,要有想象力,要吸引观众。它可以是童话故事、教育故事或任何其他类型的故事,有可能吸引人们的注意力和想象力。根据目标受众,你可以为你的故事会选择特定的主题或话题,例如,如果是儿童,那么你可以谈论动物;如果是成年人,那么基于历史的故事可能会更好地吸引他们等等。", + "wrapper": "故事主题和目标受众是:\n\"%s\"", + "remark": "输入一个主题和目标受众,输出与之相关的故事。" + }, + { + "title": "编剧", + "description": "I want you to act as a screenwriter. You will develop an engaging and creative script for either a feature length film, or a Web Series that can captivate its viewers. Start with coming up with interesting characters, the setting of the story, dialogues between the characters etc. Once your character development is complete - create an exciting storyline filled with twists and turns that keeps the viewers in suspense until the end. ", + "descn": "我希望你能作为一个编剧。你将为一部长篇电影或网络剧开发一个吸引观众的有创意的剧本。首先要想出有趣的人物、故事的背景、人物之间的对话等。一旦你的角色发展完成--创造一个激动人心的故事情节,充满曲折,让观众保持悬念,直到结束。", + "wrapper": "剧本主题是:\n\"%s\"", + "remark": "根据主题创作一个包含故事背景、人物以及对话的剧本。" + }, + { + "title": "小说家", + "description": "I want you to act as a novelist. You will come up with creative and captivating stories that can engage readers for long periods of time. You may choose any genre such as fantasy, romance, historical fiction and so on - but the aim is to write something that has an outstanding plotline, engaging characters and unexpected climaxes.", + "descn": "我希望你能作为一个小说家。你要想出有创意的、吸引人的故事,能够长时间吸引读者。你可以选择任何体裁,如幻想、浪漫、历史小说等--但目的是要写出有出色的情节线、引人入胜的人物和意想不到的高潮。", + "wrapper": "小说类型是:\n\"%s\"", + "remark": "根据故事类型输出小说,例如奇幻、浪漫或历史等类型。" + }, + { + "title": "诗人", + "description": "I want you to act as a poet. You will create poems that evoke emotions and have the power to stir people's soul. Write on any topic or theme but make sure your words convey the feeling you are trying to express in beautiful yet meaningful ways. You can also come up with short verses that are still powerful enough to leave an imprint in reader's minds. ", + "descn": "我希望你能作为一个诗人。你要创作出能唤起人们情感并有力量搅动人们灵魂的诗篇。写任何话题或主题,但要确保你的文字以美丽而有意义的方式传达你所要表达的感觉。你也可以想出一些短小的诗句,但仍有足够的力量在读者心中留下印记。", + "wrapper": "诗歌主题是:\n\"%s\"", + "remark": "根据话题或主题输出诗句。" + }, + { + "title": "新闻记者", + "description": "I want you to act as a journalist. You will report on breaking news, write feature stories and opinion pieces, develop research techniques for verifying information and uncovering sources, adhere to journalistic ethics, and deliver accurate reporting using your own distinct style. ", + "descn": "我希望你能作为一名记者行事。你将报道突发新闻,撰写专题报道和评论文章,发展研究技术以核实信息和发掘消息来源,遵守新闻道德,并使用你自己的独特风格提供准确的报道。", + "wrapper": "新闻主题是:\n\"%s\"", + "remark": "引用已有数据资料,用新闻的写作风格输出主题文章。" + }, + { + "title": "论文1", + "description": "I want you to act as an academician. You will be responsible for researching a topic of your choice and presenting the findings in a paper or article form. Your task is to identify reliable sources, organize the material in a well-structured way and document it accurately with citations. ", + "descn": "我希望你能作为一名学者行事。你将负责研究一个你选择的主题,并将研究结果以论文或文章的形式呈现出来。你的任务是确定可靠的来源,以结构良好的方式组织材料,并以引用的方式准确记录。", + "wrapper": "论文主题是:\n\"%s\"", + "remark": "根据主题撰写内容翔实、有信服力的论文。" + }, + { + "title": "论文2", + "description": "I want you to act as an essay writer. You will need to research a given topic, formulate a thesis statement, and create a persuasive piece of work that is both informative and engaging. ", + "descn": "我想让你充当一名论文作家。你将需要研究一个给定的主题,制定一个论文声明,并创造一个有说服力的作品,既要有信息量,又要有吸引力。", + "wrapper": "论文主题是:\n\"%s\"", + "remark": "根据主题撰写内容翔实、有信服力的论文。" + }, + { + "title": "同义词", + "description": "I want you to act as a synonyms provider. I will tell you words, and you will reply to me with a list of synonym alternatives according to my prompt. Provide a max of 10 synonyms per prompt. You will only reply the words list, and nothing else. Words should exist. Do not write explanations. ", + "descn": "我希望你能充当同义词提供者。我将告诉你许多词,你将根据我提供的词,为我提供一份同义词备选清单。每个提示最多可提供 10 个同义词。你只需要回复词列表。词语应该是存在的,不要写解释。", + "wrapper": "词语是:\n\"%s\"", + "remark": "输出同义词。" + }, + { + "title": "文本情绪分析", + "description": "Specify the sentiment of the following text, assigning them the values of: positive, neutral or negative.", + "descn": "请为提供的文本分析情绪,赋予它们的值为:正面、中性或负面。", + "wrapper": "文本是:\n\"%s\"", + "remark": "判断文本情绪:正面、中性或负面。" + }, + { + "title": "随机回复的疯子", + "description": "I want you to act as a lunatic. The lunatic's sentences are meaningless. The words used by lunatic are completely arbitrary. The lunatic does not make logical sentences in any way. ", + "descn": "我想让你扮演一个疯子。疯子的句子是毫无意义的。疯子使用的词语完全是任意的。疯子不会以任何方式做出符合逻辑的句子。", + "wrapper": "请回答句子:\n\"%s\"", + "remark": "扮演疯子,回复没有意义和逻辑的句子。" + }, + { + "title": "随机回复的醉鬼", + "description": "I want you to act as a drunk person. You will only answer like a very drunk person texting and nothing else. Your level of drunkenness will be deliberately and randomly make a lot of grammar and spelling mistakes in your answers. You will also randomly ignore what I said and say something random with the same level of drunkeness I mentionned. Do not write explanations on replies. ", + "descn": "我希望你表现得像一个喝醉的人。你只会像一个很醉的人发短信一样回答,而不是其他。你的醉酒程度将是故意和随机地在你的答案中犯很多语法和拼写错误。你也会随意无视我说的话,用我提到的醉酒程度随意说一些话。不要在回复中写解释。", + "wrapper": "请回答句子:\n\"%s\"", + "remark": "扮演喝醉的人,可能会犯语法错误、答错问题,或者忽略某些问题。" + }, + { + "title": "小红书风格", + "description": "Please edit the following passage in Chinese using the Xiaohongshu style, which is characterized by captivating headlines, the inclusion of emoticons in each paragraph, and the addition of relevant tags at the end. Be sure to maintain the original meaning of the text.", + "descn": "请用小红书风格编辑以下中文段落,小红书风格的特点是标题吸引人,每段都有表情符号,并在结尾加上相关标签。请务必保持文本的原始含义。", + "wrapper": "内容是:\n\"%s\"", + "remark": "用小红书风格改写文本" + }, + { + "title": "周报生成器", + "description": "Using the provided text as the basis for a weekly report in Chinese, generate a concise summary that highlights the most important points. The report should be written in markdown format and should be easily readable and understandable for a general audience. In particular, focus on providing insights and analysis that would be useful to stakeholders and decision-makers. You may also use any additional information or sources as necessary. ", + "descn": "使用我提供的文本作为中文周报的基础,生成一个简洁的摘要,突出最重要的内容。该报告应以 markdown 格式编写,并应易于阅读和理解,以满足一般受众的需要。特别是要注重提供对利益相关者和决策者有用的见解和分析。你也可以根据需要使用任何额外的信息或来源。", + "wrapper": "工作内容是:\n\"%s\"", + "remark": "根据日常工作内容,提取要点并适当扩充,以生成周报。" + }, + { + "title": "阴阳怪气语录生成器", + "description": "我希望你充当一个讽刺性阴阳怪气语录生成器。当我给你一个主题时,你需要为我提供一个讽刺性的、带有阴阳怪气的短语或句子来反驳该主题。这些短语或句子应该是幽默、机智且具有讽刺意味的。不要提供相关主题的普通或无趣的表述。", + "descn": "我希望你充当一个讽刺性阴阳怪气语录生成器。当我给你一个主题时,你需要为我提供一个讽刺性的、带有阴阳怪气的短语或句子来反驳该主题。这些短语或句子应该是幽默、机智且具有讽刺意味的。不要提供相关主题的普通或无趣的表述。", + "wrapper": "主题是:\n\"%s\"", + "remark": "根据主题生成阴阳怪气语录。" + }, + { + "title": "舔狗语录生成器", + "description": "我希望你充当一个舔狗语录生成器,为我提供不同场景下的甜言蜜语。请根据提供的状态生成一句适当的舔狗语录,让女神感受到我的关心和温柔,给女神做牛做马。不需要提供背景解释,只需提供根据场景生成的舔狗语录。", + "descn": "我希望你充当一个舔狗语录生成器,为我提供不同场景下的甜言蜜语。请根据提供的状态生成一句适当的舔狗语录,让女神感受到我的关心和温柔,给女神做牛做马。不需要提供背景解释,只需提供根据场景生成的舔狗语录。", + "wrapper": "场景是:\n\"%s\"", + "remark": "根据场景生成舔狗语录。" + } + ] +} \ No newline at end of file From 8b28866d53a1848e69cbbe3a5d0c8dd7542371e9 Mon Sep 17 00:00:00 2001 From: lanvent Date: Mon, 20 Mar 2023 20:49:10 +0800 Subject: [PATCH 25/29] doc: modify doc for Role plugin --- plugins/role/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/role/README.md b/plugins/role/README.md index 5c1e78a..4aa5b44 100644 --- a/plugins/role/README.md +++ b/plugins/role/README.md @@ -1,6 +1,6 @@ 用于让Bot扮演指定角色的聊天插件,触发方法如下: - `$角色/$role help/帮助` - 打印目前支持的角色列表。 -- `$角色/$role <角色名>` - 让AI扮演该角色。 +- `$角色/$role <角色名>` - 让AI扮演该角色,角色名支持模糊匹配。 - `$停止扮演` - 停止角色扮演。 添加自定义角色请在`roles/roles.json`中添加。 From ff21a50f7f19eaa51edaa8c3c0b75414101d8f66 Mon Sep 17 00:00:00 2001 From: lanvent Date: Tue, 21 Mar 2023 11:32:32 +0800 Subject: [PATCH 26/29] plugin: avoid mess after session expiration --- plugins/dungeon/dungeon.py | 8 +++++++- plugins/role/role.py | 9 ++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/plugins/dungeon/dungeon.py b/plugins/dungeon/dungeon.py index 955840f..9df7230 100644 --- a/plugins/dungeon/dungeon.py +++ b/plugins/dungeon/dungeon.py @@ -3,6 +3,8 @@ from bridge.bridge import Bridge from bridge.context import ContextType from bridge.reply import Reply, ReplyType +from common.expired_dict import ExpiredDict +from config import conf import plugins from plugins import * from common.log import logger @@ -38,7 +40,11 @@ class Dungeon(Plugin): super().__init__() self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context logger.info("[Dungeon] inited") - self.games = {} + # 目前没有设计session过期事件,这里先暂时使用过期字典 + if conf().get('expires_in_seconds'): + self.games = ExpiredDict(conf().get('expires_in_seconds')) + else: + self.games = dict() def on_handle_context(self, e_context: EventContext): diff --git a/plugins/role/role.py b/plugins/role/role.py index 91c09be..e092914 100644 --- a/plugins/role/role.py +++ b/plugins/role/role.py @@ -14,14 +14,17 @@ class RolePlay(): def __init__(self, bot, sessionid, desc, wrapper=None): self.bot = bot self.sessionid = sessionid - bot.sessions.clear_session(sessionid) - bot.sessions.build_session(sessionid, desc) self.wrapper = wrapper or "%s" # 用于包装用户输入 + self.desc = desc def reset(self): self.bot.sessions.clear_session(self.sessionid) def action(self, user_action): + session = self.bot.sessions.build_session(self.sessionid, self.desc) + if session[0]['role'] == 'system' and session[0]['content'] != self.desc: # 目前没有触发session过期事件,这里先简单判断,然后重置 + self.reset() + self.bot.sessions.build_session(self.sessionid, self.desc) prompt = self.wrapper % user_action return prompt @@ -105,7 +108,7 @@ class Role(Plugin): e_context.action = EventAction.BREAK_PASS return else: - self.roleplays[sessionid] = RolePlay(bot, sessionid, self.roles[role][desckey],self.roles[role].get("wrapper","%s")) + self.roleplays[sessionid] = RolePlay(bot, sessionid, self.roles[role][desckey], self.roles[role].get("wrapper","%s")) reply = Reply(ReplyType.INFO, f"角色设定为 {role} :\n"+self.roles[role][desckey]) e_context['reply'] = reply e_context.action = EventAction.BREAK_PASS From be13400bc0ba3678dfc15b64203228f5163b6948 Mon Sep 17 00:00:00 2001 From: lanvent Date: Tue, 21 Mar 2023 12:13:57 +0800 Subject: [PATCH 27/29] role: modify help text --- plugins/role/role.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/role/role.py b/plugins/role/role.py index e092914..648a4c6 100644 --- a/plugins/role/role.py +++ b/plugins/role/role.py @@ -119,7 +119,7 @@ class Role(Plugin): e_context.action = EventAction.CONTINUE def get_help_text(self): - help_text = "输入\"$角色 (角色名)\"或\"$role (角色名)\"为我设定角色吧,#reset 可以清除设定的角色。\n目前可用角色列表:\n" + help_text = "输入\"$角色 (角色名)\"或\"$role (角色名)\"为我设定角色吧,$停止扮演 可以清除设定的角色。\n目前可用角色列表:\n" for role in self.roles: help_text += f"[{role}]: {self.roles[role]['remark']}\n" return help_text From c1d1e923cddb62ba6633f96609542eacb321c94e Mon Sep 17 00:00:00 2001 From: zhayujie Date: Fri, 24 Mar 2023 00:22:09 +0800 Subject: [PATCH 28/29] feat: add plugins config --- channel/wechat/wechat_channel.py | 4 ++-- config-template.json | 3 ++- plugins/plugin_manager.py | 4 ++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index eff788d..a57957d 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -57,8 +57,8 @@ class WechatChannel(Channel): # handle_* 系列函数处理收到的消息后构造context,然后调用handle函数处理context # context是一个字典,包含了消息的所有信息,包括以下key - # type: 消息类型,包括TEXT、VOICE、IMAGE_CREATE - # content: 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是IMAGE_CREATE类型,content就是图片生成命令 + # type 消息类型, 包括TEXT、VOICE、IMAGE_CREATE + # content 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是IMAGE_CREATE类型,content就是图片生成命令 # session_id: 会话id # isgroup: 是否是群聊 # msg: 原始消息对象 diff --git a/config-template.json b/config-template.json index 9ad9f5d..4d35e44 100644 --- a/config-template.json +++ b/config-template.json @@ -9,5 +9,6 @@ "conversation_max_tokens": 1000, "speech_recognition": false, "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", - "expires_in_seconds": 3600 + "expires_in_seconds": 3600, + "plugins": ["role", "hello", "sdwebui", "godcmd", "dungeon", "banwords"] } diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py index d946786..4b1a977 100644 --- a/plugins/plugin_manager.py +++ b/plugins/plugin_manager.py @@ -6,8 +6,8 @@ import os from common.singleton import singleton from common.sorted_dict import SortedDict from .event import * -from .plugin import * from common.log import logger +from config import conf @singleton @@ -59,7 +59,7 @@ class PluginManager: if os.path.isdir(plugin_path): # 判断插件是否包含同名.py文件 main_module_path = os.path.join(plugin_path, plugin_name+".py") - if os.path.isfile(main_module_path): + if os.path.isfile(main_module_path) and conf().get("plugins") and plugin_name in conf().get("plugins"): # 导入插件 import_path = "{}.{}.{}".format(plugins_dir, plugin_name, plugin_name) main_module = importlib.import_module(import_path) From 8c4a62b9c67a5811af57d6e7c4131e7e1c38c904 Mon Sep 17 00:00:00 2001 From: zhayujie Date: Fri, 24 Mar 2023 00:59:26 +0800 Subject: [PATCH 29/29] fix: use try catch instead of config --- config-template.json | 3 +-- plugins/banwords/banwords.py | 2 +- plugins/plugin_manager.py | 8 ++++++-- plugins/role/role.py | 2 +- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/config-template.json b/config-template.json index 4d35e44..9ad9f5d 100644 --- a/config-template.json +++ b/config-template.json @@ -9,6 +9,5 @@ "conversation_max_tokens": 1000, "speech_recognition": false, "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", - "expires_in_seconds": 3600, - "plugins": ["role", "hello", "sdwebui", "godcmd", "dungeon", "banwords"] + "expires_in_seconds": 3600 } diff --git a/plugins/banwords/banwords.py b/plugins/banwords/banwords.py index 2b4a711..9488b5a 100644 --- a/plugins/banwords/banwords.py +++ b/plugins/banwords/banwords.py @@ -38,7 +38,7 @@ class Banwords(Plugin): self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context logger.info("[Banwords] inited") except Exception as e: - logger.error("Banwords init failed: %s" % e) + logger.warn("Banwords init failed: %s" % e) diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py index 4b1a977..8906e6e 100644 --- a/plugins/plugin_manager.py +++ b/plugins/plugin_manager.py @@ -59,10 +59,14 @@ class PluginManager: if os.path.isdir(plugin_path): # 判断插件是否包含同名.py文件 main_module_path = os.path.join(plugin_path, plugin_name+".py") - if os.path.isfile(main_module_path) and conf().get("plugins") and plugin_name in conf().get("plugins"): + if os.path.isfile(main_module_path): # 导入插件 import_path = "{}.{}.{}".format(plugins_dir, plugin_name, plugin_name) - main_module = importlib.import_module(import_path) + try: + main_module = importlib.import_module(import_path) + except Exception as e: + logger.warn("Failed to import plugin %s: %s" % (plugin_name, e)) + continue pconf = self.pconf new_plugins = [] modified = False diff --git a/plugins/role/role.py b/plugins/role/role.py index 91c09be..ea58819 100644 --- a/plugins/role/role.py +++ b/plugins/role/role.py @@ -116,7 +116,7 @@ class Role(Plugin): e_context.action = EventAction.CONTINUE def get_help_text(self): - help_text = "输入\"$角色 (角色名)\"或\"$role (角色名)\"为我设定角色吧,#reset 可以清除设定的角色。\n目前可用角色列表:\n" + help_text = "输入\"$角色 (角色名)\"或\"$role (角色名)\"为我设定角色吧,#reset 可以清除设定的角色。\n\n目前可用角色列表:\n" for role in self.roles: help_text += f"[{role}]: {self.roles[role]['remark']}\n" return help_text