From d8b75206fea1703f8e07d72a80828d82ae8695e7 Mon Sep 17 00:00:00 2001 From: lanvent Date: Fri, 7 Apr 2023 12:15:29 +0800 Subject: [PATCH] feat: maxmize message length --- channel/chat_channel.py | 9 ++- channel/wechatmp/wechatmp_channel.py | 104 ++++++++++++++++----------- 2 files changed, 68 insertions(+), 45 deletions(-) diff --git a/channel/chat_channel.py b/channel/chat_channel.py index 6ab60e9..a7f8694 100644 --- a/channel/chat_channel.py +++ b/channel/chat_channel.py @@ -231,12 +231,15 @@ class ChatChannel(Channel): time.sleep(3+3*retry_cnt) self._send(reply, context, retry_cnt+1) - def thread_pool_callback(self, session_id): + def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数 + logger.exception("Worker return exception: {}".format(exception)) + + def _thread_pool_callback(self, session_id, **kwargs): def func(worker:Future): try: worker_exception = worker.exception() if worker_exception: - logger.exception("Worker return exception: {}".format(worker_exception)) + self._fail_callback(session_id, exception = worker_exception, **kwargs) except CancelledError as e: logger.info("Worker cancelled, session_id = {}".format(session_id)) except Exception as e: @@ -267,7 +270,7 @@ class ChatChannel(Channel): context = context_queue.get() logger.debug("[WX] consume context: {}".format(context)) future:Future = self.handler_pool.submit(self._handle, context) - future.add_done_callback(self.thread_pool_callback(session_id)) + future.add_done_callback(self._thread_pool_callback(session_id, context = context)) if session_id not in self.futures: self.futures[session_id] = [] self.futures[session_id].append(future) diff --git a/channel/wechatmp/wechatmp_channel.py b/channel/wechatmp/wechatmp_channel.py index 1967245..3bf2df4 100644 --- a/channel/wechatmp/wechatmp_channel.py +++ b/channel/wechatmp/wechatmp_channel.py @@ -26,12 +26,14 @@ import traceback # from concurrent.futures import ThreadPoolExecutor # thread_pool = ThreadPoolExecutor(max_workers=8) +MAX_UTF8_LEN = 2048 @singleton class WechatMPChannel(ChatChannel): NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE] def __init__(self): super().__init__() self.cache_dict = dict() + self.running = set() self.query1 = dict() self.query2 = dict() self.query3 = dict() @@ -47,11 +49,16 @@ class WechatMPChannel(ChatChannel): def send(self, reply: Reply, context: Context): - reply_cnt = math.ceil(len(reply.content) / 600) receiver = context["receiver"] - self.cache_dict[receiver] = (reply_cnt, reply.content) + self.cache_dict[receiver] = reply.content + self.running.remove(receiver) logger.debug("[send] reply to {} saved to cache: {}".format(receiver, reply)) + def _fail_callback(self, session_id, exception, context, **kwargs): + logger.exception("[wechatmp] Fail to generation message to user, msgId={}, exception={}".format(context['msg'].msg_id, exception)) + assert session_id not in self.cache_dict + self.running.remove(session_id) + def verify_server(): try: @@ -86,11 +93,11 @@ class SubsribeAccountQuery(): return verify_server() def POST(self): - channel_instance = WechatMPChannel() + channel = WechatMPChannel() try: query_time = time.time() webData = web.data() - # logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8")) + logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8")) wechat_msg = receive.parse_xml(webData) if wechat_msg.msg_type == 'text': from_user = wechat_msg.from_user_id @@ -101,21 +108,20 @@ class SubsribeAccountQuery(): logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message)) cache_key = from_user - cache = channel_instance.cache_dict.get(cache_key) reply_text = "" # New request - if cache == None: + if cache_key not in channel.cache_dict and cache_key not in channel.running: # The first query begin, reset the cache - context = channel_instance._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechat_msg) + context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechat_msg) logger.debug("[wechatmp] context: {} {}".format(context, wechat_msg)) if context: # set private openai_api_key # if from_user is not changed in itchat, this can be placed at chat_channel user_data = conf().get_user_data(from_user) context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key - channel_instance.cache_dict[cache_key] = (0, "") - channel_instance.produce(context) + channel.running.add(cache_key) + channel.produce(context) else: trigger_prefix = conf().get('single_chat_prefix',[''])[0] if trigger_prefix: @@ -129,31 +135,28 @@ class SubsribeAccountQuery(): 未知错误,请稍后再试""") replyMsg = reply.TextMsg(wechat_msg.from_user_id, wechat_msg.to_user_id, content) return replyMsg.send() - channel_instance.query1[cache_key] = False - channel_instance.query2[cache_key] = False - channel_instance.query3[cache_key] = False + channel.query1[cache_key] = False + channel.query2[cache_key] = False + channel.query3[cache_key] = False # Request again - elif cache[0] == 0 and channel_instance.query1.get(cache_key) == True and channel_instance.query2.get(cache_key) == True and channel_instance.query3.get(cache_key) == True: - channel_instance.query1[cache_key] = False #To improve waiting experience, this can be set to True. - channel_instance.query2[cache_key] = False #To improve waiting experience, this can be set to True. - channel_instance.query3[cache_key] = False - elif cache[0] >= 1: + elif cache_key in channel.running and channel.query1.get(cache_key) == True and channel.query2.get(cache_key) == True and channel.query3.get(cache_key) == True: + channel.query1[cache_key] = False #To improve waiting experience, this can be set to True. + channel.query2[cache_key] = False #To improve waiting experience, this can be set to True. + channel.query3[cache_key] = False + elif cache_key in channel.cache_dict: # Skip the waiting phase - channel_instance.query1[cache_key] = True - channel_instance.query2[cache_key] = True - channel_instance.query3[cache_key] = True - + channel.query1[cache_key] = True + channel.query2[cache_key] = True + channel.query3[cache_key] = True - cache = channel_instance.cache_dict.get(cache_key) - if channel_instance.query1.get(cache_key) == False: + if channel.query1.get(cache_key) == False: # The first query from wechat official server logger.debug("[wechatmp] query1 {}".format(cache_key)) - channel_instance.query1[cache_key] = True + channel.query1[cache_key] = True cnt = 0 - while cache[0] == 0 and cnt < 45: + while cache_key not in channel.cache_dict and cnt < 45: cnt = cnt + 1 time.sleep(0.1) - cache = channel_instance.cache_dict.get(cache_key) if cnt == 45: # waiting for timeout (the POST query will be closed by wechat official server) time.sleep(5) @@ -161,15 +164,14 @@ class SubsribeAccountQuery(): return else: pass - elif channel_instance.query2.get(cache_key) == False: + elif channel.query2.get(cache_key) == False: # The second query from wechat official server logger.debug("[wechatmp] query2 {}".format(cache_key)) - channel_instance.query2[cache_key] = True + channel.query2[cache_key] = True cnt = 0 - while cache[0] == 0 and cnt < 45: + while cache_key not in channel.cache_dict and cnt < 45: cnt = cnt + 1 time.sleep(0.1) - cache = channel_instance.cache_dict.get(cache_key) if cnt == 45: # waiting for timeout (the POST query will be closed by wechat official server) time.sleep(5) @@ -177,15 +179,14 @@ class SubsribeAccountQuery(): return else: pass - elif channel_instance.query3.get(cache_key) == False: + elif channel.query3.get(cache_key) == False: # The third query from wechat official server logger.debug("[wechatmp] query3 {}".format(cache_key)) - channel_instance.query3[cache_key] = True + channel.query3[cache_key] = True cnt = 0 - while cache[0] == 0 and cnt < 40: + while cache_key not in channel.cache_dict and cnt < 40: cnt = cnt + 1 time.sleep(0.1) - cache = channel_instance.cache_dict.get(cache_key) if cnt == 40: # Have waiting for 3x5 seconds # return timeout message @@ -198,15 +199,19 @@ class SubsribeAccountQuery(): if float(time.time()) - float(query_time) > 4.8: logger.info("[wechatmp] Timeout for {} {}".format(from_user, message_id)) + time.sleep(1) return - - - if cache[0] > 1: - reply_text = cache[1][:600] + "\n【未完待续,回复任意文字以继续】" #wechatmp auto_reply length limit - channel_instance.cache_dict[cache_key] = (cache[0] - 1, cache[1][600:]) - elif cache[0] == 1: - reply_text = cache[1] - channel_instance.cache_dict.pop(cache_key) + + if cache_key in channel.cache_dict: + content = channel.cache_dict[cache_key] + if len(content.encode('utf8'))<=MAX_UTF8_LEN: + reply_text = channel.cache_dict[cache_key] + channel.cache_dict.pop(cache_key) + else: + continue_text = "\n【未完待续,回复任意文字以继续】" + splits = split_string_by_utf8_length(content, MAX_UTF8_LEN - len(continue_text.encode('utf-8'))) + reply_text = splits[0] + continue_text + channel.cache_dict[cache_key] = splits[1] logger.info("[wechatmp] {}:{} Do send {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), reply_text)) replyPost = reply.TextMsg(from_user, to_user, reply_text).send() return replyPost @@ -232,3 +237,18 @@ class SubsribeAccountQuery(): logger.exception(exc) return exc +def split_string_by_utf8_length(string, max_length, max_split=0): + encoded = string.encode('utf-8') + start, end = 0, 0 + result = [] + while end < len(encoded): + if max_split > 0 and len(result) >= max_split: + result.append(encoded[start:].decode('utf-8')) + break + end = start + max_length + # 如果当前字节不是 UTF-8 编码的开始字节,则向前查找直到找到开始字节为止 + while end < len(encoded) and (encoded[end] & 0b11000000) == 0b10000000: + end -= 1 + result.append(encoded[start:end].decode('utf-8')) + start = end + return result \ No newline at end of file