From 1a981ea97095d12f2a6000a14e005f77e52359e3 Mon Sep 17 00:00:00 2001 From: JS00000 Date: Wed, 5 Apr 2023 20:55:24 +0800 Subject: [PATCH] Refactor: inherit ChatChannel --- channel/channel_factory.py | 4 +- channel/wechatmp/receive.py | 68 +++---- channel/wechatmp/wechatmp_channel.py | 284 ++++++++++----------------- 3 files changed, 135 insertions(+), 221 deletions(-) diff --git a/channel/channel_factory.py b/channel/channel_factory.py index 3d06154..3303ded 100644 --- a/channel/channel_factory.py +++ b/channel/channel_factory.py @@ -18,6 +18,6 @@ def create_channel(channel_type): from channel.terminal.terminal_channel import TerminalChannel return TerminalChannel() elif channel_type == 'wechatmp': - from channel.wechatmp.wechatmp_channel import WechatMPServer - return WechatMPServer() + from channel.wechatmp.wechatmp_channel import WechatMPChannel + return WechatMPChannel() raise RuntimeError diff --git a/channel/wechatmp/receive.py b/channel/wechatmp/receive.py index 40fc35f..64d2106 100644 --- a/channel/wechatmp/receive.py +++ b/channel/wechatmp/receive.py @@ -1,47 +1,43 @@ # -*- coding: utf-8 -*-# # filename: receive.py import xml.etree.ElementTree as ET +from bridge.context import ContextType +from channel.chat_message import ChatMessage +from common.tmp_dir import TmpDir +from common.log import logger def parse_xml(web_data): if len(web_data) == 0: return None xmlData = ET.fromstring(web_data) - msg_type = xmlData.find('MsgType').text - if msg_type == 'text': - return TextMsg(xmlData) - elif msg_type == 'image': - return ImageMsg(xmlData) - elif msg_type == 'event': - return Event(xmlData) + return WeChatMPMessage(xmlData) - -class Msg(object): - def __init__(self, xmlData): - self.ToUserName = xmlData.find('ToUserName').text - self.FromUserName = xmlData.find('FromUserName').text - self.CreateTime = xmlData.find('CreateTime').text - self.MsgType = xmlData.find('MsgType').text - self.MsgId = xmlData.find('MsgId').text - - -class TextMsg(Msg): - def __init__(self, xmlData): - Msg.__init__(self, xmlData) - self.Content = xmlData.find('Content').text.encode("utf-8") - - -class ImageMsg(Msg): - def __init__(self, xmlData): - Msg.__init__(self, xmlData) - self.PicUrl = xmlData.find('PicUrl').text - self.MediaId = xmlData.find('MediaId').text - - -class Event(object): +class WeChatMPMessage(ChatMessage): def __init__(self, xmlData): - self.ToUserName = xmlData.find('ToUserName').text - self.FromUserName = xmlData.find('FromUserName').text - self.CreateTime = xmlData.find('CreateTime').text - self.MsgType = xmlData.find('MsgType').text - self.Event = xmlData.find('Event').text + super().__init__(xmlData) + self.to_user_id = xmlData.find('ToUserName').text + self.from_user_id = xmlData.find('FromUserName').text + self.create_time = xmlData.find('CreateTime').text + self.msg_type = xmlData.find('MsgType').text + self.msg_id = xmlData.find('MsgId').text + self.is_group = False + + # reply to other_user_id + self.other_user_id = self.from_user_id + + if self.msg_type == 'text': + self.ctype = ContextType.TEXT + self.content = xmlData.find('Content').text.encode("utf-8") + elif self.msg_type == 'voice': + self.ctype = ContextType.TEXT + self.content = xmlData.find('Recognition').text.encode("utf-8") # 接收语音识别结果 + elif self.msg_type == 'image': + # not implemented + self.pic_url = xmlData.find('PicUrl').text + self.media_id = xmlData.find('MediaId').text + elif self.msg_type == 'event': + self.event = xmlData.find('Event').text + else: # video, shortvideo, location, link + # not implemented + pass \ No newline at end of file diff --git a/channel/wechatmp/wechatmp_channel.py b/channel/wechatmp/wechatmp_channel.py index 0da9085..cdd8673 100644 --- a/channel/wechatmp/wechatmp_channel.py +++ b/channel/wechatmp/wechatmp_channel.py @@ -4,9 +4,10 @@ import time import math import hashlib import textwrap -from channel.channel import Channel +from channel.chat_channel import ChatChannel import channel.wechatmp.reply as reply import channel.wechatmp.receive as receive +from common.singleton import singleton from common.log import logger from config import conf from bridge.reply import * @@ -21,202 +22,125 @@ import traceback # certificate='/ssl/cert.pem', # private_key='/ssl/cert.key') -class WechatMPServer(): + +# from concurrent.futures import ThreadPoolExecutor +# thread_pool = ThreadPoolExecutor(max_workers=8) + +@singleton +class WechatMPChannel(ChatChannel): def __init__(self): - pass + super().__init__() + self.cache_dict = dict() + self.query1 = dict() + self.query2 = dict() + self.query3 = dict() + - def startup(self): + def startup(self): urls = ( - '/wx', 'WechatMPChannel', + '/wx', 'SubsribeAccountQuery', ) app = web.application(urls, globals()) app.run() -cache_dict = dict() -query1 = dict() -query2 = dict() -query3 = dict() - -from concurrent.futures import ThreadPoolExecutor -thread_pool = ThreadPoolExecutor(max_workers=8) - -class WechatMPChannel(Channel): - def GET(self): - try: - data = web.input() - if len(data) == 0: - return "hello, this is handle view" - signature = data.signature - timestamp = data.timestamp - nonce = data.nonce - echostr = data.echostr - token = conf().get('wechatmp_token') #请按照公众平台官网\基本配置中信息填写 - - data_list = [token, timestamp, nonce] - data_list.sort() - sha1 = hashlib.sha1() - # map(sha1.update, data_list) #python2 - sha1.update("".join(data_list).encode('utf-8')) - hashcode = sha1.hexdigest() - print("handle/GET func: hashcode, signature: ", hashcode, signature) - if hashcode == signature: - return echostr - else: - return "" - except Exception as Argument: - return Argument - - - def _do_build_reply(self, cache_key, fromUser, message): - context = dict() - context['session_id'] = fromUser - reply_text = super().build_reply_content(message, context) - # The query is done, record the cache - logger.info("[threaded] Get reply for {}: {} \nA: {}".format(fromUser, message, reply_text)) - global cache_dict - reply_cnt = math.ceil(len(reply_text) / 600) - cache_dict[cache_key] = (reply_cnt, reply_text) - - - def send(self, reply : Reply, cache_key): - global cache_dict + def send(self, reply: Reply, context: Context): reply_cnt = math.ceil(len(reply.content) / 600) - cache_dict[cache_key] = (reply_cnt, reply.content) - - - def handle(self, context): - global cache_dict - try: - reply = Reply() - logger.debug('[wechatmp] ready to handle context: {}'.format(context)) - - # reply的构建步骤 - e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply})) - reply = e_context['reply'] - if not e_context.is_pass(): - logger.debug('[wechatmp] ready to handle context: type={}, content={}'.format(context.type, context.content)) - if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: - reply = super().build_reply_content(context.content, context) - # elif context.type == ContextType.VOICE: - # msg = context['msg'] - # file_name = TmpDir().path() + context.content - # msg.download(file_name) - # reply = super().build_voice_to_text(file_name) - # if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO: - # context.content = reply.content # 语音转文字后,将文字内容作为新的context - # context.type = ContextType.TEXT - # reply = super().build_reply_content(context.content, context) - # if reply.type == ReplyType.TEXT: - # if conf().get('voice_reply_voice'): - # reply = super().build_text_to_voice(reply.content) - else: - logger.error('[wechatmp] unknown context type: {}'.format(context.type)) - return - - logger.debug('[wechatmp] ready to decorate reply: {}'.format(reply)) - - # reply的包装步骤 - if reply and reply.type: - e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply})) - reply=e_context['reply'] - if not e_context.is_pass() and reply and reply.type: - if reply.type == ReplyType.TEXT: - pass - elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: - reply.content = str(reply.type)+":\n" + reply.content - elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE: - pass - else: - logger.error('[wechatmp] unknown reply type: {}'.format(reply.type)) - return - - # reply的发送步骤 - if reply and reply.type: - e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply})) - reply=e_context['reply'] - if not e_context.is_pass() and reply and reply.type: - logger.debug('[wechatmp] ready to send reply: {} to {}'.format(reply, context['receiver'])) - self.send(reply, context['receiver']) - else: - cache_dict[context['receiver']] = (1, "No reply") - - logger.info("[threaded] Get reply for {}: {} \nA: {}".format(context['receiver'], context.content, reply.content)) - except Exception as exc: - print(traceback.format_exc()) - cache_dict[context['receiver']] = (1, "ERROR") - + receiver = context["receiver"] + self.cache_dict[receiver] = (reply_cnt, reply.content) + logger.debug("[send] reply to {} saved to cache: {}".format(receiver, reply)) + + +def verify_server(): + try: + data = web.input() + if len(data) == 0: + return "None" + signature = data.signature + timestamp = data.timestamp + nonce = data.nonce + echostr = data.echostr + token = conf().get('wechatmp_token') #请按照公众平台官网\基本配置中信息填写 + + data_list = [token, timestamp, nonce] + data_list.sort() + sha1 = hashlib.sha1() + # map(sha1.update, data_list) #python2 + sha1.update("".join(data_list).encode('utf-8')) + hashcode = sha1.hexdigest() + print("handle/GET func: hashcode, signature: ", hashcode, signature) + if hashcode == signature: + return echostr + else: + return "" + except Exception as Argument: + return Argument + + +# This class is instantiated once per query +class SubsribeAccountQuery(): + def GET(self): + return verify_server() def POST(self): + channel_instance = WechatMPChannel() try: - queryTime = time.time() + query_time = time.time() webData = web.data() # logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8")) - recMsg = receive.parse_xml(webData) - if isinstance(recMsg, receive.Msg) and recMsg.MsgType == 'text': - fromUser = recMsg.FromUserName - toUser = recMsg.ToUserName - createTime = recMsg.CreateTime - message = recMsg.Content.decode("utf-8") - message_id = recMsg.MsgId + wechat_msg = receive.parse_xml(webData) + if wechat_msg.msg_type == 'text': + from_user = wechat_msg.from_user_id + to_user = wechat_msg.to_user_id + message = wechat_msg.content.decode("utf-8") + message_id = wechat_msg.msg_id - logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), fromUser, message_id, message)) + logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message)) - global cache_dict - global query1 - global query2 - global query3 - cache_key = fromUser - cache = cache_dict.get(cache_key) + cache_key = from_user + cache = channel_instance.cache_dict.get(cache_key) reply_text = "" # New request if cache == None: # The first query begin, reset the cache - cache_dict[cache_key] = (0, "") - # thread_pool.submit(self._do_build_reply, cache_key, fromUser, message) - - context = Context() - context.kwargs = {'isgroup': False, 'receiver': fromUser, 'session_id': fromUser} + channel_instance.cache_dict[cache_key] = (0, "") - user_data = conf().get_user_data(fromUser) - context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key + context = channel_instance._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechat_msg) + if context: + # set private openai_api_key + # if from_user is not changed in itchat, this can be placed at chat_channel + user_data = conf().get_user_data(from_user) + context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key + channel_instance.produce(context) - img_match_prefix = check_prefix(message, conf().get('image_create_prefix')) - if img_match_prefix: - message = message.replace(img_match_prefix, '', 1).strip() - context.type = ContextType.IMAGE_CREATE - else: - context.type = ContextType.TEXT - context.content = message - thread_pool.submit(self.handle, context) - query1[cache_key] = False - query2[cache_key] = False - query3[cache_key] = False + channel_instance.query1[cache_key] = False + channel_instance.query2[cache_key] = False + channel_instance.query3[cache_key] = False # Request again - elif cache[0] == 0 and query1.get(cache_key) == True and query2.get(cache_key) == True and query3.get(cache_key) == True: - query1[cache_key] = False #To improve waiting experience, this can be set to True. - query2[cache_key] = False #To improve waiting experience, this can be set to True. - query3[cache_key] = False + elif cache[0] == 0 and channel_instance.query1.get(cache_key) == True and channel_instance.query2.get(cache_key) == True and channel_instance.query3.get(cache_key) == True: + channel_instance.query1[cache_key] = False #To improve waiting experience, this can be set to True. + channel_instance.query2[cache_key] = False #To improve waiting experience, this can be set to True. + channel_instance.query3[cache_key] = False elif cache[0] >= 1: # Skip the waiting phase - query1[cache_key] = True - query2[cache_key] = True - query3[cache_key] = True + channel_instance.query1[cache_key] = True + channel_instance.query2[cache_key] = True + channel_instance.query3[cache_key] = True - cache = cache_dict.get(cache_key) - if query1.get(cache_key) == False: + cache = channel_instance.cache_dict.get(cache_key) + if channel_instance.query1.get(cache_key) == False: # The first query from wechat official server logger.debug("[wechatmp] query1 {}".format(cache_key)) - query1[cache_key] = True + channel_instance.query1[cache_key] = True cnt = 0 while cache[0] == 0 and cnt < 45: cnt = cnt + 1 time.sleep(0.1) - cache = cache_dict.get(cache_key) + cache = channel_instance.cache_dict.get(cache_key) if cnt == 45: # waiting for timeout (the POST query will be closed by wechat official server) time.sleep(5) @@ -224,15 +148,15 @@ class WechatMPChannel(Channel): return else: pass - elif query2.get(cache_key) == False: + elif channel_instance.query2.get(cache_key) == False: # The second query from wechat official server logger.debug("[wechatmp] query2 {}".format(cache_key)) - query2[cache_key] = True + channel_instance.query2[cache_key] = True cnt = 0 while cache[0] == 0 and cnt < 45: cnt = cnt + 1 time.sleep(0.1) - cache = cache_dict.get(cache_key) + cache = channel_instance.cache_dict.get(cache_key) if cnt == 45: # waiting for timeout (the POST query will be closed by wechat official server) time.sleep(5) @@ -240,42 +164,42 @@ class WechatMPChannel(Channel): return else: pass - elif query3.get(cache_key) == False: + elif channel_instance.query3.get(cache_key) == False: # The third query from wechat official server logger.debug("[wechatmp] query3 {}".format(cache_key)) - query3[cache_key] = True + channel_instance.query3[cache_key] = True cnt = 0 while cache[0] == 0 and cnt < 45: cnt = cnt + 1 time.sleep(0.1) - cache = cache_dict.get(cache_key) + cache = channel_instance.cache_dict.get(cache_key) if cnt == 45: # Have waiting for 3x5 seconds # return timeout message reply_text = "【正在响应中,回复任意文字尝试获取回复】" - logger.info("[wechatmp] Three queries has finished For {}: {}".format(fromUser, message_id)) - replyPost = reply.TextMsg(fromUser, toUser, reply_text).send() + logger.info("[wechatmp] Three queries has finished For {}: {}".format(from_user, message_id)) + replyPost = reply.TextMsg(from_user, to_user, reply_text).send() return replyPost else: pass - if float(time.time()) - float(queryTime) > 4.8: - logger.info("[wechatmp] Timeout for {} {}".format(fromUser, message_id)) + if float(time.time()) - float(query_time) > 4.8: + logger.info("[wechatmp] Timeout for {} {}".format(from_user, message_id)) return if cache[0] > 1: reply_text = cache[1][:600] + "\n【未完待续,回复任意文字以继续】" #wechatmp auto_reply length limit - cache_dict[cache_key] = (cache[0] - 1, cache[1][600:]) + channel_instance.cache_dict[cache_key] = (cache[0] - 1, cache[1][600:]) elif cache[0] == 1: reply_text = cache[1] - cache_dict.pop(cache_key) + channel_instance.cache_dict.pop(cache_key) logger.info("[wechatmp] {}:{} Do send {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), reply_text)) - replyPost = reply.TextMsg(fromUser, toUser, reply_text).send() + replyPost = reply.TextMsg(from_user, to_user, reply_text).send() return replyPost - elif isinstance(recMsg, receive.Event) and recMsg.MsgType == 'event': - logger.info("[wechatmp] Event {} from {}".format(recMsg.Event, recMsg.FromUserName)) + elif wechat_msg.msg_type == 'event': + logger.info("[wechatmp] Event {} from {}".format(wechat_msg.Event, wechat_msg.from_user_id)) content = textwrap.dedent("""\ 感谢您的关注! 这里是ChatGPT,可以自由对话。 @@ -285,7 +209,7 @@ class WechatMPChannel(Channel): 支持图片输出,画字开头的问题将回复图片链接。 支持角色扮演和文字冒险两种定制模式对话。 输入'#帮助' 查看详细指令。""") - replyMsg = reply.TextMsg(recMsg.FromUserName, recMsg.ToUserName, content) + replyMsg = reply.TextMsg(wechat_msg.from_user_id, wechat_msg.to_user_id, content) return replyMsg.send() else: logger.info("暂且不处理") @@ -294,9 +218,3 @@ class WechatMPChannel(Channel): logger.exception(exc) return exc - -def check_prefix(content, prefix_list): - for prefix in prefix_list: - if content.startswith(prefix): - return prefix - return None