diff --git a/.gitignore b/.gitignore index 8bc62f3..c349037 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ config.json QR.png nohup.out tmp +plugins.json \ No newline at end of file diff --git a/app.py b/app.py index 1ca359f..f07b275 100644 --- a/app.py +++ b/app.py @@ -4,14 +4,17 @@ import config from channel import channel_factory from common.log import logger - +from plugins import * if __name__ == '__main__': try: # load config config.load_config() # create channel - channel = channel_factory.create_channel("wx") + channel_name='wx' + channel = channel_factory.create_channel(channel_name) + if channel_name=='wx': + PluginManager().load_plugins() # startup channel channel.startup() diff --git a/bot/baidu/baidu_unit_bot.py b/bot/baidu/baidu_unit_bot.py index a84ac57..2b7dd8d 100644 --- a/bot/baidu/baidu_unit_bot.py +++ b/bot/baidu/baidu_unit_bot.py @@ -2,6 +2,7 @@ import requests from bot.bot import Bot +from bridge.reply import Reply, ReplyType # Baidu Unit对话接口 (可用, 但能力较弱) @@ -14,7 +15,8 @@ class BaiduUnitBot(Bot): headers = {'content-type': 'application/x-www-form-urlencoded'} response = requests.post(url, data=post_data.encode(), headers=headers) if response: - return response.json()['result']['context']['SYS_PRESUMED_HIST'][1] + reply = Reply(ReplyType.TEXT, response.json()['result']['context']['SYS_PRESUMED_HIST'][1]) + return reply def get_token(self): access_key = 'YOUR_ACCESS_KEY' diff --git a/bot/bot.py b/bot/bot.py index 850ba3b..fd56e50 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -3,8 +3,12 @@ Auto-replay chat robot abstract class """ +from bridge.context import Context +from bridge.reply import Reply + + class Bot(object): - def reply(self, query, context=None): + def reply(self, query, context : Context =None) -> Reply: """ bot auto-reply content :param req: received message diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index 78e344f..b2e062d 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -1,41 +1,42 @@ # encoding:utf-8 from bot.bot import Bot +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType from config import conf, load_config from common.log import logger from common.expired_dict import ExpiredDict import openai import time -if conf().get('expires_in_seconds'): - all_sessions = ExpiredDict(conf().get('expires_in_seconds')) -else: - all_sessions = dict() # OpenAI对话模型API (可用) class ChatGPTBot(Bot): def __init__(self): openai.api_key = conf().get('open_ai_api_key') proxy = conf().get('proxy') + self.sessions = SessionManager() if proxy: openai.proxy = proxy def reply(self, query, context=None): # acquire reply content - if not context or not context.get('type') or context.get('type') == 'TEXT': + if context.type == ContextType.TEXT: logger.info("[OPEN_AI] query={}".format(query)) - session_id = context.get('session_id') or context.get('from_user_id') + session_id = context['session_id'] + reply = None if query == '#清除记忆': - Session.clear_session(session_id) - return '记忆已清除' + self.sessions.clear_session(session_id) + reply = Reply(ReplyType.INFO, '记忆已清除') elif query == '#清除所有': - Session.clear_all_session() - return '所有人记忆已清除' + self.sessions.clear_all_session() + reply = Reply(ReplyType.INFO, '所有人记忆已清除') elif query == '#更新配置': load_config() - return '配置已更新' - - session = Session.build_session_query(query, session_id) + reply = Reply(ReplyType.INFO, '配置已更新') + if reply: + return reply + session = self.sessions.build_session_query(query, session_id) logger.debug("[OPEN_AI] session query={}".format(session)) # if context.get('stream'): @@ -44,14 +45,29 @@ class ChatGPTBot(Bot): reply_content = self.reply_text(session, session_id, 0) logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"])) - if reply_content["completion_tokens"] > 0: - Session.save_session(reply_content["content"], session_id, reply_content["total_tokens"]) - return reply_content["content"] - - elif context.get('type', None) == 'IMAGE_CREATE': - return self.create_img(query, 0) + if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0: + reply = Reply(ReplyType.ERROR, reply_content['content']) + elif reply_content["completion_tokens"] > 0: + self.sessions.save_session(reply_content["content"], session_id, reply_content["total_tokens"]) + reply = Reply(ReplyType.TEXT, reply_content["content"]) + else: + reply = Reply(ReplyType.ERROR, reply_content['content']) + logger.debug("[OPEN_AI] reply {} used 0 tokens.".format(reply_content)) + return reply + + elif context.type == ContextType.IMAGE_CREATE: + ok, retstring = self.create_img(query, 0) + reply = None + if ok: + reply = Reply(ReplyType.IMAGE_URL, retstring) + else: + reply = Reply(ReplyType.ERROR, retstring) + return reply + else: + reply = Reply(ReplyType.ERROR, 'Bot不支持处理{}类型的消息'.format(context.type)) + return reply - def reply_text(self, session, session_id, retry_count=0) ->dict: + def reply_text(self, session, session_id, retry_count=0) -> dict: ''' call openai's ChatCompletion to get the answer :param session: a conversation session @@ -70,8 +86,8 @@ class ChatGPTBot(Bot): presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容 ) # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) - return {"total_tokens": response["usage"]["total_tokens"], - "completion_tokens": response["usage"]["completion_tokens"], + return {"total_tokens": response["usage"]["total_tokens"], + "completion_tokens": response["usage"]["completion_tokens"], "content": response.choices[0]['message']['content']} except openai.error.RateLimitError as e: # rate limit exception @@ -86,15 +102,15 @@ class ChatGPTBot(Bot): # api connection exception logger.warn(e) logger.warn("[OPEN_AI] APIConnection failed") - return {"completion_tokens": 0, "content":"我连接不到你的网络"} + return {"completion_tokens": 0, "content": "我连接不到你的网络"} except openai.error.Timeout as e: logger.warn(e) logger.warn("[OPEN_AI] Timeout") - return {"completion_tokens": 0, "content":"我没有收到你的消息"} + return {"completion_tokens": 0, "content": "我没有收到你的消息"} except Exception as e: # unknown exception logger.exception(e) - Session.clear_session(session_id) + self.sessions.clear_session(session_id) return {"completion_tokens": 0, "content": "请再问我一次吧"} def create_img(self, query, retry_count=0): @@ -107,7 +123,7 @@ class ChatGPTBot(Bot): ) image_url = response['data'][0]['url'] logger.info("[OPEN_AI] image_url={}".format(image_url)) - return image_url + return True, image_url except openai.error.RateLimitError as e: logger.warn(e) if retry_count < 1: @@ -115,14 +131,21 @@ class ChatGPTBot(Bot): logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1)) return self.create_img(query, retry_count+1) else: - return "提问太快啦,请休息一下再问我吧" + return False, "提问太快啦,请休息一下再问我吧" except Exception as e: logger.exception(e) - return None + return False, str(e) + + +class SessionManager(object): + def __init__(self): + if conf().get('expires_in_seconds'): + sessions = ExpiredDict(conf().get('expires_in_seconds')) + else: + sessions = dict() + self.sessions = sessions -class Session(object): - @staticmethod - def build_session_query(query, session_id): + def build_session_query(self, query, session_id): ''' build query with conversation history e.g. [ @@ -135,36 +158,33 @@ class Session(object): :param session_id: session id :return: query content with conversaction ''' - session = all_sessions.get(session_id, []) + session = self.sessions.get(session_id, []) if len(session) == 0: system_prompt = conf().get("character_desc", "") system_item = {'role': 'system', 'content': system_prompt} session.append(system_item) - all_sessions[session_id] = session + self.sessions[session_id] = session user_item = {'role': 'user', 'content': query} session.append(user_item) return session - @staticmethod - def save_session(answer, session_id, total_tokens): + def save_session(self, answer, session_id, total_tokens): max_tokens = conf().get("conversation_max_tokens") if not max_tokens: # default 3000 max_tokens = 1000 - max_tokens=int(max_tokens) + max_tokens = int(max_tokens) - session = all_sessions.get(session_id) + session = self.sessions.get(session_id) if session: # append conversation gpt_item = {'role': 'assistant', 'content': answer} session.append(gpt_item) # discard exceed limit conversation - Session.discard_exceed_conversation(session, max_tokens, total_tokens) - + self.discard_exceed_conversation(session, max_tokens, total_tokens) - @staticmethod - def discard_exceed_conversation(session, max_tokens, total_tokens): + def discard_exceed_conversation(self, session, max_tokens, total_tokens): dec_tokens = int(total_tokens) # logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens)) while dec_tokens > max_tokens: @@ -173,13 +193,11 @@ class Session(object): session.pop(1) session.pop(1) else: - break + break dec_tokens = dec_tokens - max_tokens - @staticmethod - def clear_session(session_id): - all_sessions[session_id] = [] + def clear_session(self, session_id): + self.sessions[session_id] = [] - @staticmethod - def clear_all_session(): - all_sessions.clear() + def clear_all_session(self): + self.sessions.clear() diff --git a/bot/openai/open_ai_bot.py b/bot/openai/open_ai_bot.py index c948a7c..e96d60f 100644 --- a/bot/openai/open_ai_bot.py +++ b/bot/openai/open_ai_bot.py @@ -1,6 +1,8 @@ # encoding:utf-8 from bot.bot import Bot +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType from config import conf from common.log import logger import openai @@ -13,30 +15,31 @@ class OpenAIBot(Bot): def __init__(self): openai.api_key = conf().get('open_ai_api_key') - def reply(self, query, context=None): # acquire reply content - if not context or not context.get('type') or context.get('type') == 'TEXT': - logger.info("[OPEN_AI] query={}".format(query)) - from_user_id = context.get('from_user_id') or context.get('session_id') - if query == '#清除记忆': - Session.clear_session(from_user_id) - return '记忆已清除' - elif query == '#清除所有': - Session.clear_all_session() - return '所有人记忆已清除' - - new_query = Session.build_session_query(query, from_user_id) - logger.debug("[OPEN_AI] session query={}".format(new_query)) - - reply_content = self.reply_text(new_query, from_user_id, 0) - logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content)) - if reply_content and query: - Session.save_session(query, reply_content, from_user_id) - return reply_content - - elif context.get('type', None) == 'IMAGE_CREATE': - return self.create_img(query, 0) + if context and context.type: + if context.type == ContextType.TEXT: + logger.info("[OPEN_AI] query={}".format(query)) + from_user_id = context['session_id'] + reply = None + if query == '#清除记忆': + Session.clear_session(from_user_id) + reply = Reply(ReplyType.INFO, '记忆已清除') + elif query == '#清除所有': + Session.clear_all_session() + reply = Reply(ReplyType.INFO, '所有人记忆已清除') + else: + new_query = Session.build_session_query(query, from_user_id) + logger.debug("[OPEN_AI] session query={}".format(new_query)) + + reply_content = self.reply_text(new_query, from_user_id, 0) + logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content)) + if reply_content and query: + Session.save_session(query, reply_content, from_user_id) + reply = Reply(ReplyType.TEXT, reply_content) + return reply + elif context.type == ContextType.IMAGE_CREATE: + return self.create_img(query, 0) def reply_text(self, query, user_id, retry_count=0): try: diff --git a/bridge/bridge.py b/bridge/bridge.py index e739a7f..2b67a8a 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -1,16 +1,42 @@ +from bridge.context import Context +from bridge.reply import Reply +from common.log import logger from bot import bot_factory +from common.singleton import singleton from voice import voice_factory +@singleton class Bridge(object): def __init__(self): - pass + self.btype={ + "chat": "chatGPT", + "voice_to_text": "openai", + "text_to_voice": "baidu" + } + self.bots={} - def fetch_reply_content(self, query, context): - return bot_factory.create_bot("chatGPT").reply(query, context) + def get_bot(self,typename): + if self.bots.get(typename) is None: + logger.info("create bot {} for {}".format(self.btype[typename],typename)) + if typename == "text_to_voice": + self.bots[typename] = voice_factory.create_voice(self.btype[typename]) + elif typename == "voice_to_text": + self.bots[typename] = voice_factory.create_voice(self.btype[typename]) + elif typename == "chat": + self.bots[typename] = bot_factory.create_bot(self.btype[typename]) + return self.bots[typename] + + def get_bot_type(self,typename): + return self.btype[typename] - def fetch_voice_to_text(self, voiceFile): - return voice_factory.create_voice("openai").voiceToText(voiceFile) - def fetch_text_to_voice(self, text): - return voice_factory.create_voice("baidu").textToVoice(text) \ No newline at end of file + def fetch_reply_content(self, query, context : Context) -> Reply: + return self.get_bot("chat").reply(query, context) + + def fetch_voice_to_text(self, voiceFile) -> Reply: + return self.get_bot("voice_to_text").voiceToText(voiceFile) + + def fetch_text_to_voice(self, text) -> Reply: + return self.get_bot("text_to_voice").textToVoice(text) + diff --git a/bridge/context.py b/bridge/context.py new file mode 100644 index 0000000..1fbe4d4 --- /dev/null +++ b/bridge/context.py @@ -0,0 +1,42 @@ +# encoding:utf-8 + +from enum import Enum + +class ContextType (Enum): + TEXT = 1 # 文本消息 + VOICE = 2 # 音频消息 + IMAGE_CREATE = 3 # 创建图片命令 + + def __str__(self): + return self.name +class Context: + def __init__(self, type : ContextType = None , content = None, kwargs = dict()): + self.type = type + self.content = content + self.kwargs = kwargs + def __getitem__(self, key): + if key == 'type': + return self.type + elif key == 'content': + return self.content + else: + return self.kwargs[key] + + def __setitem__(self, key, value): + if key == 'type': + self.type = value + elif key == 'content': + self.content = value + else: + self.kwargs[key] = value + + def __delitem__(self, key): + if key == 'type': + self.type = None + elif key == 'content': + self.content = None + else: + del self.kwargs[key] + + def __str__(self): + return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs) \ No newline at end of file diff --git a/bridge/reply.py b/bridge/reply.py new file mode 100644 index 0000000..c6bcd54 --- /dev/null +++ b/bridge/reply.py @@ -0,0 +1,22 @@ + +# encoding:utf-8 + +from enum import Enum + +class ReplyType(Enum): + TEXT = 1 # 文本 + VOICE = 2 # 音频文件 + IMAGE = 3 # 图片文件 + IMAGE_URL = 4 # 图片URL + + INFO = 9 + ERROR = 10 + def __str__(self): + return self.name + +class Reply: + def __init__(self, type : ReplyType = None , content = None): + self.type = type + self.content = content + def __str__(self): + return "Reply(type={}, content={})".format(self.type, self.content) \ No newline at end of file diff --git a/channel/channel.py b/channel/channel.py index a1395c4..62e1ad3 100644 --- a/channel/channel.py +++ b/channel/channel.py @@ -3,6 +3,8 @@ Message sending channel abstract class """ from bridge.bridge import Bridge +from bridge.context import Context +from bridge.reply import Reply class Channel(object): def startup(self): @@ -27,11 +29,11 @@ class Channel(object): """ raise NotImplementedError - def build_reply_content(self, query, context=None): + def build_reply_content(self, query, context : Context=None) -> Reply: return Bridge().fetch_reply_content(query, context) - def build_voice_to_text(self, voice_file): + def build_voice_to_text(self, voice_file) -> Reply: return Bridge().fetch_voice_to_text(voice_file) - def build_text_to_voice(self, text): + def build_text_to_voice(self, text) -> Reply: return Bridge().fetch_text_to_voice(text) diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index d43a0a3..eff788d 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -7,16 +7,24 @@ wechat channel import itchat import json from itchat.content import * +from bridge.reply import * +from bridge.context import * from channel.channel import Channel from concurrent.futures import ThreadPoolExecutor from common.log import logger from common.tmp_dir import TmpDir from config import conf +from plugins import * + import requests import io -thread_pool = ThreadPoolExecutor(max_workers=8) +thread_pool = ThreadPoolExecutor(max_workers=8) +def thread_pool_callback(worker): + worker_exception = worker.exception() + if worker_exception: + logger.exception("Worker return exception: {}".format(worker_exception)) @itchat.msg_register(TEXT) def handler_single_msg(msg): @@ -47,62 +55,52 @@ class WechatChannel(Channel): # start message listener itchat.run() + # handle_* 系列函数处理收到的消息后构造context,然后调用handle函数处理context + # context是一个字典,包含了消息的所有信息,包括以下key + # type: 消息类型,包括TEXT、VOICE、IMAGE_CREATE + # content: 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是IMAGE_CREATE类型,content就是图片生成命令 + # session_id: 会话id + # isgroup: 是否是群聊 + # msg: 原始消息对象 + # receiver: 需要回复的对象 + def handle_voice(self, msg): - if conf().get('speech_recognition') != True : + if conf().get('speech_recognition') != True: return logger.debug("[WX]receive voice msg: " + msg['FileName']) - thread_pool.submit(self._do_handle_voice, msg) - - def _do_handle_voice(self, msg): from_user_id = msg['FromUserName'] other_user_id = msg['User']['UserName'] if from_user_id == other_user_id: - file_name = TmpDir().path() + msg['FileName'] - msg.download(file_name) - query = super().build_voice_to_text(file_name) - if conf().get('voice_reply_voice'): - self._do_send_voice(query, from_user_id) - else: - self._do_send_text(query, from_user_id) + context = Context(ContextType.VOICE,msg['FileName']) + context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id} + thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) def handle_text(self, msg): logger.debug("[WX]receive text msg: " + json.dumps(msg, ensure_ascii=False)) content = msg['Text'] - self._handle_single_msg(msg, content) - - def _handle_single_msg(self, msg, content): from_user_id = msg['FromUserName'] to_user_id = msg['ToUserName'] # 接收人id other_user_id = msg['User']['UserName'] # 对手方id - match_prefix = self.check_prefix(content, conf().get('single_chat_prefix')) + match_prefix = check_prefix(content, conf().get('single_chat_prefix')) if "」\n- - - - - - - - - - - - - - -" in content: logger.debug("[WX]reference query skipped") return - if from_user_id == other_user_id and match_prefix is not None: - # 好友向自己发送消息 - if match_prefix != '': - str_list = content.split(match_prefix, 1) - if len(str_list) == 2: - content = str_list[1].strip() - - img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix')) - if img_match_prefix: - content = content.split(img_match_prefix, 1)[1].strip() - thread_pool.submit(self._do_send_img, content, from_user_id) - else : - thread_pool.submit(self._do_send_text, content, from_user_id) - elif to_user_id == other_user_id and match_prefix: - # 自己给好友发送消息 - str_list = content.split(match_prefix, 1) - if len(str_list) == 2: - content = str_list[1].strip() - img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix')) - if img_match_prefix: - content = content.split(img_match_prefix, 1)[1].strip() - thread_pool.submit(self._do_send_img, content, to_user_id) - else: - thread_pool.submit(self._do_send_text, content, to_user_id) + if match_prefix: + content = content.replace(match_prefix, '', 1).strip() + else: + return + context = Context() + context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id} + + img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) + if img_match_prefix: + content = content.replace(img_match_prefix, '', 1).strip() + context.type = ContextType.IMAGE_CREATE + else: + context.type = ContextType.TEXT + context.content = content + thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) def handle_group(self, msg): logger.debug("[WX]receive group msg: " + json.dumps(msg, ensure_ascii=False)) @@ -122,100 +120,128 @@ class WechatChannel(Channel): logger.debug("[WX]reference query skipped") return "" config = conf() - match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or self.check_prefix(origin_content, config.get('group_chat_prefix')) \ - or self.check_contain(origin_content, config.get('group_chat_keyword')) - if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or self.check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix: - img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix')) + match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or check_prefix(origin_content, config.get('group_chat_prefix')) \ + or check_contain(origin_content, config.get('group_chat_keyword')) + if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix: + context = Context() + context.kwargs = { 'isgroup': True, 'msg': msg, 'receiver': group_id} + + img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) if img_match_prefix: - content = content.split(img_match_prefix, 1)[1].strip() - thread_pool.submit(self._do_send_img, content, group_id) + content = content.replace(img_match_prefix, '', 1).strip() + context.type = ContextType.IMAGE_CREATE else: - thread_pool.submit(self._do_send_group, content, msg) - - def send(self, msg, receiver): - itchat.send(msg, toUserName=receiver) - logger.info('[WX] sendMsg={}, receiver={}'.format(msg, receiver)) - - def _do_send_voice(self, query, reply_user_id): - try: - if not query: - return - context = dict() - context['from_user_id'] = reply_user_id - reply_text = super().build_reply_content(query, context) - if reply_text: - replyFile = super().build_text_to_voice(reply_text) - itchat.send_file(replyFile, toUserName=reply_user_id) - logger.info('[WX] sendFile={}, receiver={}'.format(replyFile, reply_user_id)) - except Exception as e: - logger.exception(e) - - def _do_send_text(self, query, reply_user_id): - try: - if not query: - return - context = dict() - context['session_id'] = reply_user_id - reply_text = super().build_reply_content(query, context) - if reply_text: - self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id) - except Exception as e: - logger.exception(e) - - def _do_send_img(self, query, reply_user_id): - try: - if not query: - return - context = dict() - context['type'] = 'IMAGE_CREATE' - img_url = super().build_reply_content(query, context) - if not img_url: - return - - # 图片下载 + context.type = ContextType.TEXT + context.content = content + + group_chat_in_one_session = conf().get('group_chat_in_one_session', []) + if ('ALL_GROUP' in group_chat_in_one_session or + group_name in group_chat_in_one_session or + check_contain(group_name, group_chat_in_one_session)): + context['session_id'] = group_id + else: + context['session_id'] = msg['ActualUserName'] + + thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) + + # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 + def send(self, reply : Reply, receiver): + if reply.type == ReplyType.TEXT: + itchat.send(reply.content, toUserName=receiver) + logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) + elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: + itchat.send(reply.content, toUserName=receiver) + logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) + elif reply.type == ReplyType.VOICE: + itchat.send_file(reply.content, toUserName=receiver) + logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver)) + elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 + img_url = reply.content pic_res = requests.get(img_url, stream=True) image_storage = io.BytesIO() for block in pic_res.iter_content(1024): image_storage.write(block) image_storage.seek(0) + itchat.send_image(image_storage, toUserName=receiver) + logger.info('[WX] sendImage url=, receiver={}'.format(img_url,receiver)) + elif reply.type == ReplyType.IMAGE: # 从文件读取图片 + image_storage = reply.content + image_storage.seek(0) + itchat.send_image(image_storage, toUserName=receiver) + logger.info('[WX] sendImage, receiver={}'.format(receiver)) + + # 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类 + def handle(self, context): + reply = Reply() + + logger.debug('[WX] ready to handle context: {}'.format(context)) + + # reply的构建步骤 + e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply})) + reply = e_context['reply'] + if not e_context.is_pass(): + logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content)) + if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: + reply = super().build_reply_content(context.content, context) + elif context.type == ContextType.VOICE: + msg = context['msg'] + file_name = TmpDir().path() + context.content + msg.download(file_name) + reply = super().build_voice_to_text(file_name) + if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO: + context.content = reply.content # 语音转文字后,将文字内容作为新的context + context.type = ContextType.TEXT + reply = super().build_reply_content(context.content, context) + if reply.type == ReplyType.TEXT: + if conf().get('voice_reply_voice'): + reply = super().build_text_to_voice(reply.content) + else: + logger.error('[WX] unknown context type: {}'.format(context.type)) + return - # 图片发送 - itchat.send_image(image_storage, reply_user_id) - logger.info('[WX] sendImage, receiver={}'.format(reply_user_id)) - except Exception as e: - logger.exception(e) - - def _do_send_group(self, query, msg): - if not query: - return - context = dict() - group_name = msg['User']['NickName'] - group_id = msg['User']['UserName'] - group_chat_in_one_session = conf().get('group_chat_in_one_session', []) - if ('ALL_GROUP' in group_chat_in_one_session or \ - group_name in group_chat_in_one_session or \ - self.check_contain(group_name, group_chat_in_one_session)): - context['session_id'] = group_id - else: - context['session_id'] = msg['ActualUserName'] - reply_text = super().build_reply_content(query, context) - if reply_text: - reply_text = '@' + msg['ActualNickName'] + ' ' + reply_text.strip() - self.send(conf().get("group_chat_reply_prefix", "") + reply_text, group_id) - - - def check_prefix(self, content, prefix_list): - for prefix in prefix_list: - if content.startswith(prefix): - return prefix - return None + logger.debug('[WX] ready to decorate reply: {}'.format(reply)) + + # reply的包装步骤 + if reply and reply.type: + e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply})) + reply=e_context['reply'] + if not e_context.is_pass() and reply and reply.type: + if reply.type == ReplyType.TEXT: + reply_text = reply.content + if context['isgroup']: + reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip() + reply_text = conf().get("group_chat_reply_prefix", "")+reply_text + else: + reply_text = conf().get("single_chat_reply_prefix", "")+reply_text + reply.content = reply_text + elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: + reply.content = str(reply.type)+":\n" + reply.content + elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE: + pass + else: + logger.error('[WX] unknown reply type: {}'.format(reply.type)) + return + + # reply的发送步骤 + if reply and reply.type: + e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply})) + reply=e_context['reply'] + if not e_context.is_pass() and reply and reply.type: + logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver'])) + self.send(reply, context['receiver']) + + +def check_prefix(content, prefix_list): + for prefix in prefix_list: + if content.startswith(prefix): + return prefix + return None - def check_contain(self, content, keyword_list): - if not keyword_list: - return None - for ky in keyword_list: - if content.find(ky) != -1: - return True +def check_contain(content, keyword_list): + if not keyword_list: return None - + for ky in keyword_list: + if content.find(ky) != -1: + return True + return None diff --git a/channel/wechat/wechaty_channel.py b/channel/wechat/wechaty_channel.py index 5e01464..3a7fe2c 100644 --- a/channel/wechat/wechaty_channel.py +++ b/channel/wechat/wechaty_channel.py @@ -11,6 +11,7 @@ import time import asyncio import requests from typing import Optional, Union +from bridge.context import Context, ContextType from wechaty_puppet import MessageType, FileBox, ScanStatus # type: ignore from wechaty import Wechaty, Contact from wechaty.user import Message, Room, MiniProgram, UrlLink @@ -127,9 +128,9 @@ class WechatyChannel(Channel): try: if not query: return - context = dict() + context = Context(ContextType.TEXT, query) context['session_id'] = reply_user_id - reply_text = super().build_reply_content(query, context) + reply_text = super().build_reply_content(query, context).content if reply_text: await self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id) except Exception as e: @@ -139,9 +140,8 @@ class WechatyChannel(Channel): try: if not query: return - context = dict() - context['type'] = 'IMAGE_CREATE' - img_url = super().build_reply_content(query, context) + context = Context(ContextType.IMAGE_CREATE, query) + img_url = super().build_reply_content(query, context).content if not img_url: return # 图片下载 @@ -162,7 +162,7 @@ class WechatyChannel(Channel): async def _do_send_group(self, query, group_id, group_name, group_user_id, group_user_name): if not query: return - context = dict() + context = Context(ContextType.TEXT, query) group_chat_in_one_session = conf().get('group_chat_in_one_session', []) if ('ALL_GROUP' in group_chat_in_one_session or \ group_name in group_chat_in_one_session or \ @@ -170,7 +170,7 @@ class WechatyChannel(Channel): context['session_id'] = str(group_id) else: context['session_id'] = str(group_id) + '-' + str(group_user_id) - reply_text = super().build_reply_content(query, context) + reply_text = super().build_reply_content(query, context).content if reply_text: reply_text = '@' + group_user_name + ' ' + reply_text.strip() await self.send_group(conf().get("group_chat_reply_prefix", "") + reply_text, group_id) @@ -179,9 +179,8 @@ class WechatyChannel(Channel): try: if not query: return - context = dict() - context['type'] = 'IMAGE_CREATE' - img_url = super().build_reply_content(query, context) + context = Context(ContextType.IMAGE_CREATE, query) + img_url = super().build_reply_content(query, context).content if not img_url: return # 图片发送 diff --git a/common/singleton.py b/common/singleton.py new file mode 100644 index 0000000..b46095c --- /dev/null +++ b/common/singleton.py @@ -0,0 +1,9 @@ +def singleton(cls): + instances = {} + + def get_instance(*args, **kwargs): + if cls not in instances: + instances[cls] = cls(*args, **kwargs) + return instances[cls] + + return get_instance diff --git a/common/sorted_dict.py b/common/sorted_dict.py new file mode 100644 index 0000000..a918a0c --- /dev/null +++ b/common/sorted_dict.py @@ -0,0 +1,65 @@ +import heapq + + +class SortedDict(dict): + def __init__(self, sort_func=lambda k, v: k, init_dict=None, reverse=False): + if init_dict is None: + init_dict = [] + if isinstance(init_dict, dict): + init_dict = init_dict.items() + self.sort_func = sort_func + self.sorted_keys = None + self.reverse = reverse + self.heap = [] + for k, v in init_dict: + self[k] = v + + def __setitem__(self, key, value): + if key in self: + super().__setitem__(key, value) + for i, (priority, k) in enumerate(self.heap): + if k == key: + self.heap[i] = (self.sort_func(key, value), key) + heapq.heapify(self.heap) + break + self.sorted_keys = None + else: + super().__setitem__(key, value) + heapq.heappush(self.heap, (self.sort_func(key, value), key)) + self.sorted_keys = None + + def __delitem__(self, key): + super().__delitem__(key) + for i, (priority, k) in enumerate(self.heap): + if k == key: + del self.heap[i] + heapq.heapify(self.heap) + break + self.sorted_keys = None + + def keys(self): + if self.sorted_keys is None: + self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)] + return self.sorted_keys + + def items(self): + if self.sorted_keys is None: + self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)] + sorted_items = [(k, self[k]) for k in self.sorted_keys] + return sorted_items + + def _update_heap(self, key): + for i, (priority, k) in enumerate(self.heap): + if k == key: + new_priority = self.sort_func(key, self[key]) + if new_priority != priority: + self.heap[i] = (new_priority, key) + heapq.heapify(self.heap) + self.sorted_keys = None + break + + def __iter__(self): + return iter(self.keys()) + + def __repr__(self): + return f'{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})' diff --git a/plugins/__init__.py b/plugins/__init__.py new file mode 100644 index 0000000..6137d4a --- /dev/null +++ b/plugins/__init__.py @@ -0,0 +1,9 @@ +from .plugin_manager import PluginManager +from .event import * +from .plugin import * + +instance = PluginManager() + +register = instance.register +# load_plugins = instance.load_plugins +# emit_event = instance.emit_event diff --git a/plugins/banwords/.gitignore b/plugins/banwords/.gitignore new file mode 100644 index 0000000..a6593bf --- /dev/null +++ b/plugins/banwords/.gitignore @@ -0,0 +1 @@ +banwords.txt \ No newline at end of file diff --git a/plugins/banwords/README.md b/plugins/banwords/README.md new file mode 100644 index 0000000..9c7e498 --- /dev/null +++ b/plugins/banwords/README.md @@ -0,0 +1,9 @@ +### 说明 +简易的敏感词插件,暂不支持分词,请自行导入词库到插件文件夹中的`banwords.txt`,每行一个词,一个参考词库是[1](https://github.com/cjh0613/tencent-sensitive-words/blob/main/sensitive_words_lines.txt)。 + +`config.json`中能够填写默认的处理行为,目前行为有: +- `ignore` : 无视这条消息。 +- `replace` : 将消息中的敏感词替换成"*",并回复违规。 + +### 致谢 +搜索功能实现来自https://github.com/toolgood/ToolGood.Words \ No newline at end of file diff --git a/plugins/banwords/WordsSearch.py b/plugins/banwords/WordsSearch.py new file mode 100644 index 0000000..d41d6e7 --- /dev/null +++ b/plugins/banwords/WordsSearch.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +# ToolGood.Words.WordsSearch.py +# 2020, Lin Zhijun, https://github.com/toolgood/ToolGood.Words +# Licensed under the Apache License 2.0 +# 更新日志 +# 2020.04.06 第一次提交 +# 2020.05.16 修改,支持大于0xffff的字符 + +__all__ = ['WordsSearch'] +__author__ = 'Lin Zhijun' +__date__ = '2020.05.16' + +class TrieNode(): + def __init__(self): + self.Index = 0 + self.Index = 0 + self.Layer = 0 + self.End = False + self.Char = '' + self.Results = [] + self.m_values = {} + self.Failure = None + self.Parent = None + + def Add(self,c): + if c in self.m_values : + return self.m_values[c] + node = TrieNode() + node.Parent = self + node.Char = c + self.m_values[c] = node + return node + + def SetResults(self,index): + if (self.End == False): + self.End = True + self.Results.append(index) + +class TrieNode2(): + def __init__(self): + self.End = False + self.Results = [] + self.m_values = {} + self.minflag = 0xffff + self.maxflag = 0 + + def Add(self,c,node3): + if (self.minflag > c): + self.minflag = c + if (self.maxflag < c): + self.maxflag = c + self.m_values[c] = node3 + + def SetResults(self,index): + if (self.End == False) : + self.End = True + if (index in self.Results )==False : + self.Results.append(index) + + def HasKey(self,c): + return c in self.m_values + + + def TryGetValue(self,c): + if (self.minflag <= c and self.maxflag >= c): + if c in self.m_values: + return self.m_values[c] + return None + + +class WordsSearch(): + def __init__(self): + self._first = {} + self._keywords = [] + self._indexs=[] + + def SetKeywords(self,keywords): + self._keywords = keywords + self._indexs=[] + for i in range(len(keywords)): + self._indexs.append(i) + + root = TrieNode() + allNodeLayer={} + + for i in range(len(self._keywords)): # for (i = 0; i < _keywords.length; i++) + p = self._keywords[i] + nd = root + for j in range(len(p)): # for (j = 0; j < p.length; j++) + nd = nd.Add(ord(p[j])) + if (nd.Layer == 0): + nd.Layer = j + 1 + if nd.Layer in allNodeLayer: + allNodeLayer[nd.Layer].append(nd) + else: + allNodeLayer[nd.Layer]=[] + allNodeLayer[nd.Layer].append(nd) + nd.SetResults(i) + + + allNode = [] + allNode.append(root) + for key in allNodeLayer.keys(): + for nd in allNodeLayer[key]: + allNode.append(nd) + allNodeLayer=None + + for i in range(len(allNode)): # for (i = 0; i < allNode.length; i++) + if i==0 : + continue + nd=allNode[i] + nd.Index = i + r = nd.Parent.Failure + c = nd.Char + while (r != None and (c in r.m_values)==False): + r = r.Failure + if (r == None): + nd.Failure = root + else: + nd.Failure = r.m_values[c] + for key2 in nd.Failure.Results : + nd.SetResults(key2) + root.Failure = root + + allNode2 = [] + for i in range(len(allNode)): # for (i = 0; i < allNode.length; i++) + allNode2.append( TrieNode2()) + + for i in range(len(allNode2)): # for (i = 0; i < allNode2.length; i++) + oldNode = allNode[i] + newNode = allNode2[i] + + for key in oldNode.m_values : + index = oldNode.m_values[key].Index + newNode.Add(key, allNode2[index]) + + for index in range(len(oldNode.Results)): # for (index = 0; index < oldNode.Results.length; index++) + item = oldNode.Results[index] + newNode.SetResults(item) + + oldNode=oldNode.Failure + while oldNode != root: + for key in oldNode.m_values : + if (newNode.HasKey(key) == False): + index = oldNode.m_values[key].Index + newNode.Add(key, allNode2[index]) + for index in range(len(oldNode.Results)): + item = oldNode.Results[index] + newNode.SetResults(item) + oldNode=oldNode.Failure + allNode = None + root = None + + # first = [] + # for index in range(65535):# for (index = 0; index < 0xffff; index++) + # first.append(None) + + # for key in allNode2[0].m_values : + # first[key] = allNode2[0].m_values[key] + + self._first = allNode2[0] + + + def FindFirst(self,text): + ptr = None + for index in range(len(text)): # for (index = 0; index < text.length; index++) + t =ord(text[index]) # text.charCodeAt(index) + tn = None + if (ptr == None): + tn = self._first.TryGetValue(t) + else: + tn = ptr.TryGetValue(t) + if (tn==None): + tn = self._first.TryGetValue(t) + + + if (tn != None): + if (tn.End): + item = tn.Results[0] + keyword = self._keywords[item] + return { "Keyword": keyword, "Success": True, "End": index, "Start": index + 1 - len(keyword), "Index": self._indexs[item] } + ptr = tn + return None + + def FindAll(self,text): + ptr = None + list = [] + + for index in range(len(text)): # for (index = 0; index < text.length; index++) + t =ord(text[index]) # text.charCodeAt(index) + tn = None + if (ptr == None): + tn = self._first.TryGetValue(t) + else: + tn = ptr.TryGetValue(t) + if (tn==None): + tn = self._first.TryGetValue(t) + + + if (tn != None): + if (tn.End): + for j in range(len(tn.Results)): # for (j = 0; j < tn.Results.length; j++) + item = tn.Results[j] + keyword = self._keywords[item] + list.append({ "Keyword": keyword, "Success": True, "End": index, "Start": index + 1 - len(keyword), "Index": self._indexs[item] }) + ptr = tn + return list + + + def ContainsAny(self,text): + ptr = None + for index in range(len(text)): # for (index = 0; index < text.length; index++) + t =ord(text[index]) # text.charCodeAt(index) + tn = None + if (ptr == None): + tn = self._first.TryGetValue(t) + else: + tn = ptr.TryGetValue(t) + if (tn==None): + tn = self._first.TryGetValue(t) + + if (tn != None): + if (tn.End): + return True + ptr = tn + return False + + def Replace(self,text, replaceChar = '*'): + result = list(text) + + ptr = None + for i in range(len(text)): # for (i = 0; i < text.length; i++) + t =ord(text[i]) # text.charCodeAt(index) + tn = None + if (ptr == None): + tn = self._first.TryGetValue(t) + else: + tn = ptr.TryGetValue(t) + if (tn==None): + tn = self._first.TryGetValue(t) + + if (tn != None): + if (tn.End): + maxLength = len( self._keywords[tn.Results[0]]) + start = i + 1 - maxLength + for j in range(start,i+1): # for (j = start; j <= i; j++) + result[j] = replaceChar + ptr = tn + return ''.join(result) \ No newline at end of file diff --git a/plugins/banwords/__init__.py b/plugins/banwords/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plugins/banwords/banwords.py b/plugins/banwords/banwords.py new file mode 100644 index 0000000..2b4a711 --- /dev/null +++ b/plugins/banwords/banwords.py @@ -0,0 +1,63 @@ +# encoding:utf-8 + +import json +import os +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType +import plugins +from plugins import * +from common.log import logger +from .WordsSearch import WordsSearch + + +@plugins.register(name="Banwords", desc="判断消息中是否有敏感词、决定是否回复。", version="1.0", author="lanvent", desire_priority= 100) +class Banwords(Plugin): + def __init__(self): + super().__init__() + try: + curdir=os.path.dirname(__file__) + config_path=os.path.join(curdir,"config.json") + conf=None + if not os.path.exists(config_path): + conf={"action":"ignore"} + with open(config_path,"w") as f: + json.dump(conf,f,indent=4) + else: + with open(config_path,"r") as f: + conf=json.load(f) + self.searchr = WordsSearch() + self.action = conf["action"] + banwords_path = os.path.join(curdir,"banwords.txt") + with open(banwords_path, 'r', encoding='utf-8') as f: + words=[] + for line in f: + word = line.strip() + if word: + words.append(word) + self.searchr.SetKeywords(words) + self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context + logger.info("[Banwords] inited") + except Exception as e: + logger.error("Banwords init failed: %s" % e) + + + + def on_handle_context(self, e_context: EventContext): + + if e_context['context'].type not in [ContextType.TEXT,ContextType.IMAGE_CREATE]: + return + + content = e_context['context'].content + logger.debug("[Banwords] on_handle_context. content: %s" % content) + if self.action == "ignore": + f = self.searchr.FindFirst(content) + if f: + logger.info("Banwords: %s" % f["Keyword"]) + e_context.action = EventAction.BREAK_PASS + return + elif self.action == "replace": + if self.searchr.ContainsAny(content): + reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n"+self.searchr.Replace(content)) + e_context['reply'] = reply + e_context.action = EventAction.BREAK_PASS + return \ No newline at end of file diff --git a/plugins/banwords/banwords.txt.template b/plugins/banwords/banwords.txt.template new file mode 100644 index 0000000..9b2e8ed --- /dev/null +++ b/plugins/banwords/banwords.txt.template @@ -0,0 +1,3 @@ +nipples +pennis +法轮功 \ No newline at end of file diff --git a/plugins/banwords/config.json.template b/plugins/banwords/config.json.template new file mode 100644 index 0000000..000fdda --- /dev/null +++ b/plugins/banwords/config.json.template @@ -0,0 +1,3 @@ +{ + "action": "ignore" +} \ No newline at end of file diff --git a/plugins/event.py b/plugins/event.py new file mode 100644 index 0000000..a65e548 --- /dev/null +++ b/plugins/event.py @@ -0,0 +1,49 @@ +# encoding:utf-8 + +from enum import Enum + + +class Event(Enum): + # ON_RECEIVE_MESSAGE = 1 # 收到消息 + + ON_HANDLE_CONTEXT = 2 # 处理消息前 + """ + e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复,初始为空 } + """ + + ON_DECORATE_REPLY = 3 # 得到回复后准备装饰 + """ + e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 } + """ + + ON_SEND_REPLY = 4 # 发送回复前 + """ + e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 } + """ + + # AFTER_SEND_REPLY = 5 # 发送回复后 + + +class EventAction(Enum): + CONTINUE = 1 # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑 + BREAK = 2 # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑 + BREAK_PASS = 3 # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑 + + +class EventContext: + def __init__(self, event, econtext=dict()): + self.event = event + self.econtext = econtext + self.action = EventAction.CONTINUE + + def __getitem__(self, key): + return self.econtext[key] + + def __setitem__(self, key, value): + self.econtext[key] = value + + def __delitem__(self, key): + del self.econtext[key] + + def is_pass(self): + return self.action == EventAction.BREAK_PASS diff --git a/plugins/godcmd/__init__.py b/plugins/godcmd/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plugins/godcmd/config.json.template b/plugins/godcmd/config.json.template new file mode 100644 index 0000000..5240738 --- /dev/null +++ b/plugins/godcmd/config.json.template @@ -0,0 +1,4 @@ +{ + "password": "", + "admin_users": [] +} \ No newline at end of file diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py new file mode 100644 index 0000000..c33c1e2 --- /dev/null +++ b/plugins/godcmd/godcmd.py @@ -0,0 +1,289 @@ +# encoding:utf-8 + +import json +import os +import traceback +from typing import Tuple +from bridge.bridge import Bridge +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType +from config import load_config +import plugins +from plugins import * +from common.log import logger + +# 定义指令集 +COMMANDS = { + "help": { + "alias": ["help", "帮助"], + "desc": "打印指令集合", + }, + "auth": { + "alias": ["auth", "认证"], + "args": ["口令"], + "desc": "管理员认证", + }, + # "id": { + # "alias": ["id", "用户"], + # "desc": "获取用户id", #目前无实际意义 + # }, + "reset": { + "alias": ["reset", "重置会话"], + "desc": "重置会话", + }, +} + +ADMIN_COMMANDS = { + "resume": { + "alias": ["resume", "恢复服务"], + "desc": "恢复服务", + }, + "stop": { + "alias": ["stop", "暂停服务"], + "desc": "暂停服务", + }, + "reconf": { + "alias": ["reconf", "重载配置"], + "desc": "重载配置(不包含插件配置)", + }, + "resetall": { + "alias": ["resetall", "重置所有会话"], + "desc": "重置所有会话", + }, + "scanp": { + "alias": ["scanp", "扫描插件"], + "desc": "扫描插件目录是否有新插件", + }, + "plist": { + "alias": ["plist", "插件"], + "desc": "打印当前插件列表", + }, + "setpri": { + "alias": ["setpri", "设置插件优先级"], + "args": ["插件名", "优先级"], + "desc": "设置指定插件的优先级,越大越优先", + }, + "reloadp": { + "alias": ["reloadp", "重载插件"], + "args": ["插件名"], + "desc": "重载指定插件配置", + }, + "enablep": { + "alias": ["enablep", "启用插件"], + "args": ["插件名"], + "desc": "启用指定插件", + }, + "disablep": { + "alias": ["disablep", "禁用插件"], + "args": ["插件名"], + "desc": "禁用指定插件", + }, + "debug": { + "alias": ["debug", "调试模式", "DEBUG"], + "desc": "开启机器调试日志", + }, +} +# 定义帮助函数 +def get_help_text(isadmin, isgroup): + help_text = "可用指令:\n" + for cmd, info in COMMANDS.items(): + if cmd=="auth" and (isadmin or isgroup): # 群聊不可认证 + continue + + alias=["#"+a for a in info['alias']] + help_text += f"{','.join(alias)} " + if 'args' in info: + args=["{"+a+"}" for a in info['args']] + help_text += f"{' '.join(args)} " + help_text += f": {info['desc']}\n" + if ADMIN_COMMANDS and isadmin: + help_text += "\n管理员指令:\n" + for cmd, info in ADMIN_COMMANDS.items(): + alias=["#"+a for a in info['alias']] + help_text += f"{','.join(alias)} " + help_text += f": {info['desc']}\n" + return help_text + +@plugins.register(name="Godcmd", desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证", version="1.0", author="lanvent", desire_priority= 999) +class Godcmd(Plugin): + + def __init__(self): + super().__init__() + + curdir=os.path.dirname(__file__) + config_path=os.path.join(curdir,"config.json") + gconf=None + if not os.path.exists(config_path): + gconf={"password":"","admin_users":[]} + with open(config_path,"w") as f: + json.dump(gconf,f,indent=4) + else: + with open(config_path,"r") as f: + gconf=json.load(f) + + self.password = gconf["password"] + self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证 TODO: 用户名每次都会变,目前不可用 + self.isrunning = True # 机器人是否运行中 + + self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context + logger.info("[Godcmd] inited") + + + def on_handle_context(self, e_context: EventContext): + context_type = e_context['context'].type + if context_type != ContextType.TEXT: + if not self.isrunning: + e_context.action = EventAction.BREAK_PASS + return + + content = e_context['context'].content + logger.debug("[Godcmd] on_handle_context. content: %s" % content) + if content.startswith("#"): + # msg = e_context['context']['msg'] + user = e_context['context']['receiver'] + session_id = e_context['context']['session_id'] + isgroup = e_context['context']['isgroup'] + bottype = Bridge().get_bot_type("chat") + bot = Bridge().get_bot("chat") + # 将命令和参数分割 + command_parts = content[1:].split(" ") + cmd = command_parts[0] + args = command_parts[1:] + isadmin=False + if user in self.admin_users: + isadmin=True + ok=False + result="string" + if any(cmd in info['alias'] for info in COMMANDS.values()): + cmd = next(c for c, info in COMMANDS.items() if cmd in info['alias']) + if cmd == "auth": + ok, result = self.authenticate(user, args, isadmin, isgroup) + elif cmd == "help": + ok, result = True, get_help_text(isadmin, isgroup) + elif cmd == "id": + ok, result = True, f"用户id=\n{user}" + elif cmd == "reset": + if bottype == "chatGPT": + bot.sessions.clear_session(session_id) + ok, result = True, "会话已重置" + else: + ok, result = False, "当前对话机器人不支持重置会话" + logger.debug("[Godcmd] command: %s by %s" % (cmd, user)) + elif any(cmd in info['alias'] for info in ADMIN_COMMANDS.values()): + if isadmin: + if isgroup: + ok, result = False, "群聊不可执行管理员指令" + else: + cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info['alias']) + if cmd == "stop": + self.isrunning = False + ok, result = True, "服务已暂停" + elif cmd == "resume": + self.isrunning = True + ok, result = True, "服务已恢复" + elif cmd == "reconf": + load_config() + ok, result = True, "配置已重载" + elif cmd == "resetall": + if bottype == "chatGPT": + bot.sessions.clear_all_session() + ok, result = True, "重置所有会话成功" + else: + ok, result = False, "当前对话机器人不支持重置会话" + elif cmd == "debug": + logger.setLevel('DEBUG') + ok, result = True, "DEBUG模式已开启" + elif cmd == "plist": + plugins = PluginManager().list_plugins() + ok = True + result = "插件列表:\n" + for name,plugincls in plugins.items(): + result += f"{plugincls.name}_v{plugincls.version} {plugincls.priority} - " + if plugincls.enabled: + result += "已启用\n" + else: + result += "未启用\n" + elif cmd == "scanp": + new_plugins = PluginManager().scan_plugins() + ok, result = True, "插件扫描完成" + PluginManager().activate_plugins() + if len(new_plugins) >0 : + result += "\n发现新插件:\n" + result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins]) + else : + result +=", 未发现新插件" + elif cmd == "setpri": + if len(args) != 2: + ok, result = False, "请提供插件名和优先级" + else: + ok = PluginManager().set_plugin_priority(args[0], int(args[1])) + if ok: + result = "插件" + args[0] + "优先级已设置为" + args[1] + else: + result = "插件不存在" + elif cmd == "reloadp": + if len(args) != 1: + ok, result = False, "请提供插件名" + else: + ok = PluginManager().reload_plugin(args[0]) + if ok: + result = "插件配置已重载" + else: + result = "插件不存在" + elif cmd == "enablep": + if len(args) != 1: + ok, result = False, "请提供插件名" + else: + ok = PluginManager().enable_plugin(args[0]) + if ok: + result = "插件已启用" + else: + result = "插件不存在" + elif cmd == "disablep": + if len(args) != 1: + ok, result = False, "请提供插件名" + else: + ok = PluginManager().disable_plugin(args[0]) + if ok: + result = "插件已禁用" + else: + result = "插件不存在" + + logger.debug("[Godcmd] admin command: %s by %s" % (cmd, user)) + else: + ok, result = False, "需要管理员权限才能执行该指令" + else: + ok, result = False, f"未知指令:{cmd}\n查看指令列表请输入#help \n" + + reply = Reply() + if ok: + reply.type = ReplyType.INFO + else: + reply.type = ReplyType.ERROR + reply.content = result + e_context['reply'] = reply + + e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 + elif not self.isrunning: + e_context.action = EventAction.BREAK_PASS + + def authenticate(self, userid, args, isadmin, isgroup) -> Tuple[bool,str] : + if isgroup: + return False,"请勿在群聊中认证" + + if isadmin: + return False,"管理员账号无需认证" + + if len(self.password) == 0: + return False,"未设置口令,无法认证" + + if len(args) != 1: + return False,"请提供口令" + + password = args[0] + if password == self.password: + self.admin_users.append(userid) + return True,"认证成功" + else: + return False,"认证失败" + diff --git a/plugins/hello/__init__.py b/plugins/hello/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plugins/hello/hello.py b/plugins/hello/hello.py new file mode 100644 index 0000000..53d87e6 --- /dev/null +++ b/plugins/hello/hello.py @@ -0,0 +1,46 @@ +# encoding:utf-8 + +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType +import plugins +from plugins import * +from common.log import logger + + +@plugins.register(name="Hello", desc="A simple plugin that says hello", version="0.1", author="lanvent", desire_priority= -1) +class Hello(Plugin): + def __init__(self): + super().__init__() + self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context + logger.info("[Hello] inited") + + def on_handle_context(self, e_context: EventContext): + + if e_context['context'].type != ContextType.TEXT: + return + + content = e_context['context'].content + logger.debug("[Hello] on_handle_context. content: %s" % content) + if content == "Hello": + reply = Reply() + reply.type = ReplyType.TEXT + msg = e_context['context']['msg'] + if e_context['context']['isgroup']: + reply.content = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group") + else: + reply.content = "Hello, " + msg['User'].get('NickName', "My friend") + e_context['reply'] = reply + e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 + + if content == "Hi": + reply = Reply() + reply.type = ReplyType.TEXT + reply.content = "Hi" + e_context['reply'] = reply + e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑,一般会覆写reply + + if content == "End": + # 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World" + e_context['context'].type = "IMAGE_CREATE" + content = "The World" + e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 diff --git a/plugins/plugin.py b/plugins/plugin.py new file mode 100644 index 0000000..865eecb --- /dev/null +++ b/plugins/plugin.py @@ -0,0 +1,3 @@ +class Plugin: + def __init__(self): + self.handlers = {} diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py new file mode 100644 index 0000000..d946786 --- /dev/null +++ b/plugins/plugin_manager.py @@ -0,0 +1,171 @@ +# encoding:utf-8 + +import importlib +import json +import os +from common.singleton import singleton +from common.sorted_dict import SortedDict +from .event import * +from .plugin import * +from common.log import logger + + +@singleton +class PluginManager: + def __init__(self): + self.plugins = SortedDict(lambda k,v: v.priority,reverse=True) + self.listening_plugins = {} + self.instances = {} + self.pconf = {} + + def register(self, name: str, desc: str, version: str, author: str, desire_priority: int = 0): + def wrapper(plugincls): + plugincls.name = name + plugincls.desc = desc + plugincls.version = version + plugincls.author = author + plugincls.priority = desire_priority + plugincls.enabled = True + self.plugins[name.upper()] = plugincls + logger.info("Plugin %s_v%s registered" % (name, version)) + return plugincls + return wrapper + + def save_config(self): + with open("plugins/plugins.json", "w", encoding="utf-8") as f: + json.dump(self.pconf, f, indent=4, ensure_ascii=False) + + def load_config(self): + logger.info("Loading plugins config...") + + modified = False + if os.path.exists("plugins/plugins.json"): + with open("plugins/plugins.json", "r", encoding="utf-8") as f: + pconf = json.load(f) + pconf['plugins'] = SortedDict(lambda k,v: v["priority"],pconf['plugins'],reverse=True) + else: + modified = True + pconf = {"plugins": SortedDict(lambda k,v: v["priority"],reverse=True)} + self.pconf = pconf + if modified: + self.save_config() + return pconf + + def scan_plugins(self): + logger.info("Scaning plugins ...") + plugins_dir = "plugins" + for plugin_name in os.listdir(plugins_dir): + plugin_path = os.path.join(plugins_dir, plugin_name) + if os.path.isdir(plugin_path): + # 判断插件是否包含同名.py文件 + main_module_path = os.path.join(plugin_path, plugin_name+".py") + if os.path.isfile(main_module_path): + # 导入插件 + import_path = "{}.{}.{}".format(plugins_dir, plugin_name, plugin_name) + main_module = importlib.import_module(import_path) + pconf = self.pconf + new_plugins = [] + modified = False + for name, plugincls in self.plugins.items(): + rawname = plugincls.name + if rawname not in pconf["plugins"]: + new_plugins.append(plugincls) + modified = True + logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name) + pconf["plugins"][rawname] = {"enabled": plugincls.enabled, "priority": plugincls.priority} + else: + self.plugins[name].enabled = pconf["plugins"][rawname]["enabled"] + self.plugins[name].priority = pconf["plugins"][rawname]["priority"] + self.plugins._update_heap(name) # 更新下plugins中的顺序 + if modified: + self.save_config() + return new_plugins + + def refresh_order(self): + for event in self.listening_plugins.keys(): + self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True) + + def activate_plugins(self): # 生成新开启的插件实例 + for name, plugincls in self.plugins.items(): + if plugincls.enabled: + if name not in self.instances: + instance = plugincls() + self.instances[name] = instance + for event in instance.handlers: + if event not in self.listening_plugins: + self.listening_plugins[event] = [] + self.listening_plugins[event].append(name) + self.refresh_order() + + def reload_plugin(self, name:str): + name = name.upper() + if name in self.instances: + for event in self.listening_plugins: + if name in self.listening_plugins[event]: + self.listening_plugins[event].remove(name) + del self.instances[name] + self.activate_plugins() + return True + return False + + def load_plugins(self): + self.load_config() + self.scan_plugins() + pconf = self.pconf + logger.debug("plugins.json config={}".format(pconf)) + for name,plugin in pconf["plugins"].items(): + if name.upper() not in self.plugins: + logger.error("Plugin %s not found, but found in plugins.json" % name) + self.activate_plugins() + + def emit_event(self, e_context: EventContext, *args, **kwargs): + if e_context.event in self.listening_plugins: + for name in self.listening_plugins[e_context.event]: + if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE: + logger.debug("Plugin %s triggered by event %s" % (name,e_context.event)) + instance = self.instances[name] + instance.handlers[e_context.event](e_context, *args, **kwargs) + return e_context + + def set_plugin_priority(self, name:str, priority:int): + name = name.upper() + if name not in self.plugins: + return False + if self.plugins[name].priority == priority: + return True + self.plugins[name].priority = priority + self.plugins._update_heap(name) + rawname = self.plugins[name].name + self.pconf["plugins"][rawname]["priority"] = priority + self.pconf["plugins"]._update_heap(rawname) + self.save_config() + self.refresh_order() + return True + + def enable_plugin(self, name:str): + name = name.upper() + if name not in self.plugins: + return False + if not self.plugins[name].enabled : + self.plugins[name].enabled = True + rawname = self.plugins[name].name + self.pconf["plugins"][rawname]["enabled"] = True + self.save_config() + self.activate_plugins() + return True + return True + + def disable_plugin(self, name:str): + name = name.upper() + if name not in self.plugins: + return False + if self.plugins[name].enabled : + self.plugins[name].enabled = False + rawname = self.plugins[name].name + self.pconf["plugins"][rawname]["enabled"] = False + self.save_config() + return True + return True + + def list_plugins(self): + return self.plugins \ No newline at end of file diff --git a/plugins/sdwebui/__init__.py b/plugins/sdwebui/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plugins/sdwebui/config.json.template b/plugins/sdwebui/config.json.template new file mode 100644 index 0000000..213acdc --- /dev/null +++ b/plugins/sdwebui/config.json.template @@ -0,0 +1,70 @@ +{ + "start":{ + "host" : "127.0.0.1", + "port" : 7860 + }, + "defaults": { + "params": { + "sampler_name": "DPM++ 2M Karras", + "steps": 20, + "width": 512, + "height": 512, + "cfg_scale": 7, + "prompt":"masterpiece, best quality", + "negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)", + "enable_hr": false, + "hr_scale": 2, + "hr_upscaler": "Latent", + "hr_second_pass_steps": 15, + "denoising_strength": 0.7 + }, + "options": { + "sd_model_checkpoint": "perfectWorld_v2Baked" + } + }, + "rules": [ + { + "keywords": [ + "横版", + "壁纸" + ], + "params": { + "width": 640, + "height": 384 + }, + "desc": "分辨率会变成640x384" + }, + { + "keywords": [ + "竖版" + ], + "params": { + "width": 384, + "height": 640 + } + }, + { + "keywords": [ + "高清" + ], + "params": { + "enable_hr": true, + "hr_scale": 1.6 + }, + "desc": "出图分辨率长宽都会提高1.6倍" + }, + { + "keywords": [ + "二次元" + ], + "params": { + "negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)", + "prompt": "masterpiece, best quality" + }, + "options": { + "sd_model_checkpoint": "meinamix_meinaV8" + }, + "desc": "使用二次元风格模型出图" + } + ] +} \ No newline at end of file diff --git a/plugins/sdwebui/readme.md b/plugins/sdwebui/readme.md new file mode 100644 index 0000000..bb8c62c --- /dev/null +++ b/plugins/sdwebui/readme.md @@ -0,0 +1,69 @@ +### 插件描述 +本插件用于将画图请求转发给stable diffusion webui。 + +### 环境要求 +使用前先安装stable diffusion webui,并在它的启动参数中添加 "--api"。 +具体信息,请参考[文章](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API)。 + +请**安装**本插件的依赖包```webuiapi``` +``` + ```pip install webuiapi``` +``` +### 使用说明 +请将`config.json.template`复制为`config.json`,并修改其中的参数和规则。 + +#### 画图请求格式 +用户的画图请求格式为: +``` + <画图触发词><关键词1> <关键词2> ... <关键词n>: +``` +- 本插件会对画图触发词后的关键词进行逐个匹配,如果触发了规则中的关键词,则会在画图请求中重载对应的参数。 +- 规则的匹配顺序参考`config.json`中的顺序,每个关键词最多被匹配到1次,如果多个关键词触发了重复的参数,重复参数以最后一个关键词为准: +- 关键词中包含`help`或`帮助`,会打印出帮助文档。 +第一个"**:**"号之后的内容会作为附加的**prompt**,接在最终的prompt后 + +例如: 画横版 高清 二次元:cat +会触发三个关键词 "横版", "高清", "二次元",prompt为"cat" +若默认参数是: +``` + "width": 512, + "height": 512, + "enable_hr": false, + "prompt": "8k" + "negative_prompt": "nsfw", + "sd_model_checkpoint": "perfectWorld_v2Baked" +``` + +"横版"触发的规则参数为: +``` + "width": 640, + "height": 384, +``` +"高清"触发的规则参数为: +``` + "enable_hr": true, + "hr_scale": 1.6, +``` +"二次元"触发的规则参数为: +``` + "negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)", + "steps": 20, + "prompt": "masterpiece, best quality", + + "sd_model_checkpoint": "meinamix_meinaV8" +``` +最后将第一个":"后的内容cat连接在prompt后,得到最终参数为: +``` + "width": 640, + "height": 384, + "enable_hr": true, + "hr_scale": 1.6, + "negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)", + "steps": 20, + "prompt": "masterpiece, best quality, cat", + + "sd_model_checkpoint": "meinamix_meinaV8" +``` +PS: 参数分为两部分: +- 一部分是params,为画画的参数;参数名**必须**与webuiapi包中[txt2img api](https://github.com/mix1009/sdwebuiapi/blob/fb2054e149c0a4e25125c0cd7e7dca06bda839d4/webuiapi/webuiapi.py#L163)的参数名一致 +- 另一部分是options,指sdwebui的设置,使用的模型和vae需要写在里面。它和http://127.0.0.1:7860/sdapi/v1/options所返回的键一致。 \ No newline at end of file diff --git a/plugins/sdwebui/sdwebui.py b/plugins/sdwebui/sdwebui.py new file mode 100644 index 0000000..56842e8 --- /dev/null +++ b/plugins/sdwebui/sdwebui.py @@ -0,0 +1,114 @@ +# encoding:utf-8 + +import json +import os +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType +from config import conf +import plugins +from plugins import * +from common.log import logger +import webuiapi +import io + + +@plugins.register(name="sdwebui", desc="利用stable-diffusion webui来画图", version="2.0", author="lanvent") +class SDWebUI(Plugin): + def __init__(self): + super().__init__() + curdir = os.path.dirname(__file__) + config_path = os.path.join(curdir, "config.json") + try: + with open(config_path, "r", encoding="utf-8") as f: + config = json.load(f) + self.rules = config["rules"] + defaults = config["defaults"] + self.default_params = defaults["params"] + self.default_options = defaults["options"] + self.start_args = config["start"] + self.api = webuiapi.WebUIApi(**self.start_args) + self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context + logger.info("[SD] inited") + except FileNotFoundError: + logger.error(f"[SD] init failed, {config_path} not found") + except Exception as e: + logger.error("[SD] init failed, exception: %s" % e) + + def on_handle_context(self, e_context: EventContext): + + if e_context['context'].type != ContextType.IMAGE_CREATE: + return + + logger.debug("[SD] on_handle_context. content: %s" %e_context['context'].content) + + logger.info("[SD] image_query={}".format(e_context['context'].content)) + reply = Reply() + try: + content = e_context['context'].content[:] + # 解析用户输入 如"横版 高清 二次元:cat" + if ":" in content: + keywords, prompt = content.split(":", 1) + else: + keywords = content + prompt = "" + + keywords = keywords.split() + + if "help" in keywords or "帮助" in keywords: + reply.type = ReplyType.INFO + reply.content = self.get_help_text() + else: + rule_params = {} + rule_options = {} + for keyword in keywords: + matched = False + for rule in self.rules: + if keyword in rule["keywords"]: + for key in rule["params"]: + rule_params[key] = rule["params"][key] + if "options" in rule: + for key in rule["options"]: + rule_options[key] = rule["options"][key] + matched = True + break # 一个关键词只匹配一个规则 + if not matched: + logger.warning("[SD] keyword not matched: %s" % keyword) + + params = {**self.default_params, **rule_params} + options = {**self.default_options, **rule_options} + params["prompt"] = params.get("prompt", "")+f", {prompt}" + if len(options) > 0: + logger.info("[SD] cover options={}".format(options)) + self.api.set_options(options) + logger.info("[SD] params={}".format(params)) + result = self.api.txt2img( + **params + ) + reply.type = ReplyType.IMAGE + b_img = io.BytesIO() + result.image.save(b_img, format="PNG") + reply.content = b_img + e_context.action = EventAction.BREAK_PASS # 事件结束后,跳过处理context的默认逻辑 + except Exception as e: + reply.type = ReplyType.ERROR + reply.content = "[SD] "+str(e) + logger.error("[SD] exception: %s" % e) + e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 + finally: + e_context['reply'] = reply + + def get_help_text(self): + if not conf().get('image_create_prefix'): + return "画图功能未启用" + else: + trigger = conf()['image_create_prefix'][0] + help_text = f"请使用<{trigger}[关键词1] [关键词2]...:提示语>的格式作画,如\"{trigger}横版 高清:cat\"\n" + help_text += "目前可用关键词:\n" + for rule in self.rules: + keywords = [f"[{keyword}]" for keyword in rule['keywords']] + help_text += f"{','.join(keywords)}" + if "desc" in rule: + help_text += f"-{rule['desc']}\n" + else: + help_text += "\n" + return help_text \ No newline at end of file diff --git a/voice/baidu/baidu_voice.py b/voice/baidu/baidu_voice.py index d99db37..531d8ce 100644 --- a/voice/baidu/baidu_voice.py +++ b/voice/baidu/baidu_voice.py @@ -4,6 +4,7 @@ baidu voice service """ import time from aip import AipSpeech +from bridge.reply import Reply, ReplyType from common.log import logger from common.tmp_dir import TmpDir from voice.voice import Voice @@ -30,7 +31,8 @@ class BaiduVoice(Voice): with open(fileName, 'wb') as f: f.write(result) logger.info('[Baidu] textToVoice text={} voice file name={}'.format(text, fileName)) - return fileName + reply = Reply(ReplyType.VOICE, fileName) else: logger.error('[Baidu] textToVoice error={}'.format(result)) - return None + reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败") + return reply diff --git a/voice/google/google_voice.py b/voice/google/google_voice.py index 8e339f2..74431db 100644 --- a/voice/google/google_voice.py +++ b/voice/google/google_voice.py @@ -6,6 +6,7 @@ google voice service import pathlib import subprocess import time +from bridge.reply import Reply, ReplyType import speech_recognition import pyttsx3 from common.log import logger @@ -36,16 +37,22 @@ class GoogleVoice(Voice): text = self.recognizer.recognize_google(audio, language='zh-CN') logger.info( '[Google] voiceToText text={} voice file name={}'.format(text, voice_file)) - return text + reply = Reply(ReplyType.TEXT, text) except speech_recognition.UnknownValueError: - return "抱歉,我听不懂。" + reply = Reply(ReplyType.ERROR, "抱歉,我听不懂") except speech_recognition.RequestError as e: - return "抱歉,无法连接到 Google 语音识别服务;{0}".format(e) - + reply = Reply(ReplyType.ERROR, "抱歉,无法连接到 Google 语音识别服务;{0}".format(e)) + finally: + return reply def textToVoice(self, text): - textFile = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3' - self.engine.save_to_file(text, textFile) - self.engine.runAndWait() - logger.info( - '[Google] textToVoice text={} voice file name={}'.format(text, textFile)) - return textFile + try: + textFile = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3' + self.engine.save_to_file(text, textFile) + self.engine.runAndWait() + logger.info( + '[Google] textToVoice text={} voice file name={}'.format(text, textFile)) + reply = Reply(ReplyType.VOICE, textFile) + except Exception as e: + reply = Reply(ReplyType.ERROR, str(e)) + finally: + return reply diff --git a/voice/openai/openai_voice.py b/voice/openai/openai_voice.py index 475aac6..2e85e10 100644 --- a/voice/openai/openai_voice.py +++ b/voice/openai/openai_voice.py @@ -4,6 +4,7 @@ google voice service """ import json import openai +from bridge.reply import Reply, ReplyType from config import conf from common.log import logger from voice.voice import Voice @@ -16,12 +17,17 @@ class OpenaiVoice(Voice): def voiceToText(self, voice_file): logger.debug( '[Openai] voice file name={}'.format(voice_file)) - file = open(voice_file, "rb") - reply = openai.Audio.transcribe("whisper-1", file) - text = reply["text"] - logger.info( - '[Openai] voiceToText text={} voice file name={}'.format(text, voice_file)) - return text + try: + file = open(voice_file, "rb") + result = openai.Audio.transcribe("whisper-1", file) + text = result["text"] + reply = Reply(ReplyType.TEXT, text) + logger.info( + '[Openai] voiceToText text={} voice file name={}'.format(text, voice_file)) + except Exception as e: + reply = Reply(ReplyType.ERROR, str(e)) + finally: + return reply def textToVoice(self, text): pass