From 0fcf0824dcf7830bd60aaa025e6c115364e5c4d8 Mon Sep 17 00:00:00 2001 From: lanvent Date: Sun, 12 Mar 2023 11:53:06 +0800 Subject: [PATCH] feat: support plugins --- .gitignore | 1 + app.py | 7 ++- bot/chatgpt/chat_gpt_bot.py | 9 ++- channel/wechat/wechat_channel.py | 103 +++++++++++++++++-------------- plugins/__init__.py | 9 +++ plugins/event.py | 49 +++++++++++++++ plugins/plugin.py | 3 + plugins/plugin_manager.py | 89 ++++++++++++++++++++++++++ 8 files changed, 220 insertions(+), 50 deletions(-) create mode 100644 plugins/__init__.py create mode 100644 plugins/event.py create mode 100644 plugins/plugin.py create mode 100644 plugins/plugin_manager.py diff --git a/.gitignore b/.gitignore index 8bc62f3..c349037 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ config.json QR.png nohup.out tmp +plugins.json \ No newline at end of file diff --git a/app.py b/app.py index 1ca359f..f07b275 100644 --- a/app.py +++ b/app.py @@ -4,14 +4,17 @@ import config from channel import channel_factory from common.log import logger - +from plugins import * if __name__ == '__main__': try: # load config config.load_config() # create channel - channel = channel_factory.create_channel("wx") + channel_name='wx' + channel = channel_factory.create_channel(channel_name) + if channel_name=='wx': + PluginManager().load_plugins() # startup channel channel.startup() diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index 9bfea5b..2c8567d 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -60,12 +60,13 @@ class ChatGPTBot(Bot): ok, retstring = self.create_img(query, 0) reply = None if ok: - reply = {'type': 'IMAGE', 'content': retstring} + reply = {'type': 'IMAGE_URL', 'content': retstring} else: reply = {'type': 'ERROR', 'content': retstring} return reply else: reply= {'type':'ERROR', 'content':'Bot不支持处理{}类型的消息'.format(context['type'])} + return reply def reply_text(self, session, session_id, retry_count=0) -> dict: ''' @@ -139,7 +140,11 @@ class ChatGPTBot(Bot): class SessionManager(object): def __init__(self): - self.sessions = {} + if conf().get('expires_in_seconds'): + sessions = ExpiredDict(conf().get('expires_in_seconds')) + else: + sessions = dict() + self.sessions = sessions def build_session_query(self, query, session_id): ''' diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index e8be17e..f436e48 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -12,9 +12,12 @@ from concurrent.futures import ThreadPoolExecutor from common.log import logger from common.tmp_dir import TmpDir from config import conf +from plugins import * + import requests import io + thread_pool = ThreadPoolExecutor(max_workers=8) @@ -49,8 +52,8 @@ class WechatChannel(Channel): # handle_* 系列函数处理收到的消息后构造context,然后调用handle函数处理context # context是一个字典,包含了消息的所有信息,包括以下key - # type: 消息类型,包括TEXT、VOICE、CMD_IMAGE_CREATE - # content: 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是CMD_IMAGE_CREATE类型,content就是图片生成命令 + # type: 消息类型,包括TEXT、VOICE、IMAGE_CREATE + # content: 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是IMAGE_CREATE类型,content就是图片生成命令 # session_id: 会话id # isgroup: 是否是群聊 # msg: 原始消息对象 @@ -88,7 +91,7 @@ class WechatChannel(Channel): img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) if img_match_prefix: content = content.replace(img_match_prefix, '', 1).strip() - context['type'] = 'CMD_IMAGE_CREATE' + context['type'] = 'IMAGE_CREATE' else: context['type'] = 'TEXT' @@ -121,7 +124,7 @@ class WechatChannel(Channel): img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) if img_match_prefix: content = content.replace(img_match_prefix, '', 1).strip() - context['type'] = 'CMD_IMAGE_CREATE' + context['type'] = 'IMAGE_CREATE' else: context['type'] = 'TEXT' context['content'] = content @@ -136,8 +139,7 @@ class WechatChannel(Channel): thread_pool.submit(self.handle, context) - # 统一的发送函数,根据reply的type字段发送不同类型的消息 - + # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 def send(self, reply, receiver): if reply['type'] == 'TEXT': itchat.send(reply['content'], toUserName=receiver) @@ -163,54 +165,63 @@ class WechatChannel(Channel): itchat.send_image(image_storage, toUserName=receiver) logger.info('[WX] sendImage, receiver={}'.format(receiver)) - # 处理消息 + # 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类 def handle(self, context): - content = context['content'] - reply = None + reply = {} logger.debug('[WX] ready to handle context: {}'.format(context)) + # reply的构建步骤 - if context['type'] == 'TEXT' or context['type'] == 'CMD_IMAGE_CREATE': - reply = super().build_reply_content(content, context) - elif context['type'] == 'VOICE': - msg = context['msg'] - file_name = TmpDir().path() + msg['FileName'] - msg.download(file_name) - reply = super().build_voice_to_text(file_name) - if reply['type'] != 'ERROR' and reply['type'] != 'INFO': - reply = super().build_reply_content(reply['content'], context) - if reply['type'] == 'TEXT': - if conf().get('voice_reply_voice'): - reply = super().build_text_to_voice(reply['content']) - else: - logger.error('[WX] unknown context type: {}'.format(context['type'])) - return + e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply})) + reply=e_context['reply'] + if not e_context.is_pass(): + logger.debug('[WX] ready to handle context: type={}, content={}'.format(context['type'], context['content'])) + if context['type'] == 'TEXT' or context['type'] == 'IMAGE_CREATE': + reply = super().build_reply_content(context['content'], context) + elif context['type'] == 'VOICE': + msg = context['msg'] + file_name = TmpDir().path() + msg['FileName'] + msg.download(file_name) + reply = super().build_voice_to_text(file_name) + if reply['type'] != 'ERROR' and reply['type'] != 'INFO': + reply = super().build_reply_content(reply['content'], context) + if reply['type'] == 'TEXT': + if conf().get('voice_reply_voice'): + reply = super().build_text_to_voice(reply['content']) + else: + logger.error('[WX] unknown context type: {}'.format(context['type'])) + return logger.debug('[WX] ready to decorate reply: {}'.format(reply)) + # reply的包装步骤 - if reply: - if reply['type'] == 'TEXT': - reply_text = reply['content'] - if context['isgroup']: - reply_text = '@' + \ - context['msg']['ActualNickName'] + \ - ' ' + reply_text.strip() - reply_text = conf().get("group_chat_reply_prefix", "")+reply_text + if reply and reply['type']: + e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply})) + reply=e_context['reply'] + if not e_context.is_pass() and reply and reply['type']: + if reply['type'] == 'TEXT': + reply_text = reply['content'] + if context['isgroup']: + reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip() + reply_text = conf().get("group_chat_reply_prefix", "")+reply_text + else: + reply_text = conf().get("single_chat_reply_prefix", "")+reply_text + reply['content'] = reply_text + elif reply['type'] == 'ERROR' or reply['type'] == 'INFO': + reply['content'] = reply['type']+": " + reply['content'] + elif reply['type'] == 'IMAGE_URL' or reply['type'] == 'VOICE' or reply['type'] == 'IMAGE': + pass else: - reply_text = conf().get("single_chat_reply_prefix", "")+reply_text - reply['content'] = reply_text - elif reply['type'] == 'ERROR' or reply['type'] == 'INFO': - reply['content'] = reply['type']+": " + reply['content'] - elif reply['type'] == 'IMAGE_URL' or reply['type'] == 'VOICE': - pass - else: - logger.error( - '[WX] unknown reply type: {}'.format(reply['type'])) - return - if reply: - logger.debug('[WX] ready to send reply: {} to {}'.format( - reply, context['receiver'])) - self.send(reply, context['receiver']) + logger.error('[WX] unknown reply type: {}'.format(reply['type'])) + return + + # reply的发送步骤 + if reply and reply['type']: + e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply})) + reply=e_context['reply'] + if not e_context.is_pass() and reply and reply['type']: + logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver'])) + self.send(reply, context['receiver']) def check_prefix(content, prefix_list): diff --git a/plugins/__init__.py b/plugins/__init__.py new file mode 100644 index 0000000..6137d4a --- /dev/null +++ b/plugins/__init__.py @@ -0,0 +1,9 @@ +from .plugin_manager import PluginManager +from .event import * +from .plugin import * + +instance = PluginManager() + +register = instance.register +# load_plugins = instance.load_plugins +# emit_event = instance.emit_event diff --git a/plugins/event.py b/plugins/event.py new file mode 100644 index 0000000..a65e548 --- /dev/null +++ b/plugins/event.py @@ -0,0 +1,49 @@ +# encoding:utf-8 + +from enum import Enum + + +class Event(Enum): + # ON_RECEIVE_MESSAGE = 1 # 收到消息 + + ON_HANDLE_CONTEXT = 2 # 处理消息前 + """ + e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复,初始为空 } + """ + + ON_DECORATE_REPLY = 3 # 得到回复后准备装饰 + """ + e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 } + """ + + ON_SEND_REPLY = 4 # 发送回复前 + """ + e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 } + """ + + # AFTER_SEND_REPLY = 5 # 发送回复后 + + +class EventAction(Enum): + CONTINUE = 1 # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑 + BREAK = 2 # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑 + BREAK_PASS = 3 # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑 + + +class EventContext: + def __init__(self, event, econtext=dict()): + self.event = event + self.econtext = econtext + self.action = EventAction.CONTINUE + + def __getitem__(self, key): + return self.econtext[key] + + def __setitem__(self, key, value): + self.econtext[key] = value + + def __delitem__(self, key): + del self.econtext[key] + + def is_pass(self): + return self.action == EventAction.BREAK_PASS diff --git a/plugins/plugin.py b/plugins/plugin.py new file mode 100644 index 0000000..865eecb --- /dev/null +++ b/plugins/plugin.py @@ -0,0 +1,3 @@ +class Plugin: + def __init__(self): + self.handlers = {} diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py new file mode 100644 index 0000000..d4cda12 --- /dev/null +++ b/plugins/plugin_manager.py @@ -0,0 +1,89 @@ +# encoding:utf-8 + +import importlib +import json +import os +from common.singleton import singleton +from .event import * +from .plugin import * +from common.log import logger + + +@singleton +class PluginManager: + def __init__(self): + self.plugins = {} + self.listening_plugins = {} + self.instances = {} + + def register(self, name: str, desc: str, version: str, author: str): + def wrapper(plugincls): + self.plugins[name] = plugincls + plugincls.name = name + plugincls.desc = desc + plugincls.version = version + plugincls.author = author + plugincls.enabled = True + logger.info("Plugin %s registered" % name) + return plugincls + return wrapper + + def save_config(self, pconf): + with open("plugins/plugins.json", "w", encoding="utf-8") as f: + json.dump(pconf, f, indent=4, ensure_ascii=False) + + def load_config(self): + logger.info("Loading plugins config...") + plugins_dir = "plugins" + for plugin_name in os.listdir(plugins_dir): + plugin_path = os.path.join(plugins_dir, plugin_name) + if os.path.isdir(plugin_path): + # 判断插件是否包含main.py文件 + main_module_path = os.path.join(plugin_path, "main.py") + if os.path.isfile(main_module_path): + # 导入插件的main + import_path = "{}.{}.main".format(plugins_dir, plugin_name) + main_module = importlib.import_module(import_path) + + modified = False + if os.path.exists("plugins/plugins.json"): + with open("plugins/plugins.json", "r", encoding="utf-8") as f: + pconf = json.load(f) + else: + modified = True + pconf = {"plugins": []} + for name, plugincls in self.plugins.items(): + if name not in [plugin["name"] for plugin in pconf["plugins"]]: + modified = True + logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name) + pconf["plugins"].append({"name": name, "enabled": True}) + if modified: + self.save_config(pconf) + return pconf + + def load_plugins(self): + pconf = self.load_config() + + for plugin in pconf["plugins"]: + name = plugin["name"] + enabled = plugin["enabled"] + self.plugins[name].enabled = enabled + + for name, plugincls in self.plugins.items(): + if plugincls.enabled: + if name not in self.instances: + instance = plugincls() + self.instances[name] = instance + for event in instance.handlers: + if event not in self.listening_plugins: + self.listening_plugins[event] = [] + self.listening_plugins[event].append(name) + + def emit_event(self, e_context: EventContext, *args, **kwargs): + if e_context.event in self.listening_plugins: + for name in self.listening_plugins[e_context.event]: + if e_context.action == EventAction.CONTINUE: + logger.debug("Plugin %s triggered by event %s" % (name,e_context.event)) + instance = self.instances[name] + instance.handlers[e_context.event](e_context, *args, **kwargs) + return e_context