ソースを参照

fix: wechatmp's deadloop when reply is None from @JS00000 #789

master
lanvent 1年前
コミット
fcfafb05f1
4個のファイルの変更37行の追加21行の削除
  1. +5
    -0
      channel/chat_channel.py
  2. +4
    -4
      channel/wechatmp/ServiceAccount.py
  3. +18
    -11
      channel/wechatmp/SubscribeAccount.py
  4. +10
    -6
      channel/wechatmp/wechatmp_channel.py

+ 5
- 0
channel/chat_channel.py ファイルの表示

@@ -233,6 +233,9 @@ class ChatChannel(Channel):
time.sleep(3+3*retry_cnt)
self._send(reply, context, retry_cnt+1)

def _success_callback(self, session_id, **kwargs):# 线程正常结束时的回调函数
logger.debug("Worker return success, session_id = {}".format(session_id))

def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数
logger.exception("Worker return exception: {}".format(exception))

@@ -242,6 +245,8 @@ class ChatChannel(Channel):
worker_exception = worker.exception()
if worker_exception:
self._fail_callback(session_id, exception = worker_exception, **kwargs)
else:
self._success_callback(session_id, **kwargs)
except CancelledError as e:
logger.info("Worker cancelled, session_id = {}".format(session_id))
except Exception as e:


+ 4
- 4
channel/wechatmp/ServiceAccount.py ファイルの表示

@@ -16,7 +16,7 @@ class Query():

def POST(self):
# Make sure to return the instance that first created, @singleton will do that.
channel_instance = WechatMPChannel()
channel = WechatMPChannel()
try:
webData = web.data()
# logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8"))
@@ -27,14 +27,14 @@ class Query():
message_id = wechatmp_msg.msg_id

logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message))
context = channel_instance._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg)
context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg)
if context:
# set private openai_api_key
# if from_user is not changed in itchat, this can be placed at chat_channel
user_data = conf().get_user_data(from_user)
context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key
channel_instance.produce(context)
# The reply will be sent by channel_instance.send() in another thread
channel.produce(context)
# The reply will be sent by channel.send() in another thread
return "success"

elif wechatmp_msg.msg_type == 'event':


+ 18
- 11
channel/wechatmp/SubscribeAccount.py ファイルの表示

@@ -41,7 +41,8 @@ class Query():
context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg)
logger.debug("[wechatmp] context: {} {}".format(context, wechatmp_msg))
if message_id in channel.received_msgs: # received and finished
return
# no return because of bandwords or other reasons
return "success"
if supported and context:
# set private openai_api_key
# if from_user is not changed in itchat, this can be placed at chat_channel
@@ -71,11 +72,12 @@ class Query():
channel.query1[cache_key] = False
channel.query2[cache_key] = False
channel.query3[cache_key] = False
# Request again
# User request again, and the answer is not ready
elif cache_key in channel.running and channel.query1.get(cache_key) == True and channel.query2.get(cache_key) == True and channel.query3.get(cache_key) == True:
channel.query1[cache_key] = False #To improve waiting experience, this can be set to True.
channel.query2[cache_key] = False #To improve waiting experience, this can be set to True.
channel.query3[cache_key] = False
# User request again, and the answer is ready
elif cache_key in channel.cache_dict:
# Skip the waiting phase
channel.query1[cache_key] = True
@@ -89,7 +91,7 @@ class Query():
logger.debug("[wechatmp] query1 {}".format(cache_key))
channel.query1[cache_key] = True
cnt = 0
while cache_key not in channel.cache_dict and cnt < 45:
while cache_key in channel.running and cnt < 45:
cnt = cnt + 1
time.sleep(0.1)
if cnt == 45:
@@ -104,7 +106,7 @@ class Query():
logger.debug("[wechatmp] query2 {}".format(cache_key))
channel.query2[cache_key] = True
cnt = 0
while cache_key not in channel.cache_dict and cnt < 45:
while cache_key in channel.running and cnt < 45:
cnt = cnt + 1
time.sleep(0.1)
if cnt == 45:
@@ -119,7 +121,7 @@ class Query():
logger.debug("[wechatmp] query3 {}".format(cache_key))
channel.query3[cache_key] = True
cnt = 0
while cache_key not in channel.cache_dict and cnt < 40:
while cache_key in channel.running and cnt < 40:
cnt = cnt + 1
time.sleep(0.1)
if cnt == 40:
@@ -132,12 +134,17 @@ class Query():
else:
pass

if float(time.time()) - float(query_time) > 4.8:
reply_text = "【正在思考中,回复任意文字尝试获取回复】"
logger.info("[wechatmp] Timeout for {} {}, return".format(from_user, message_id))
replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
return replyPost

if cache_key not in channel.cache_dict and cache_key not in channel.running:
# no return because of bandwords or other reasons
return "success"

# if float(time.time()) - float(query_time) > 4.8:
# reply_text = "【正在思考中,回复任意文字尝试获取回复】"
# logger.info("[wechatmp] Timeout for {} {}, return".format(from_user, message_id))
# replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
# return replyPost

if cache_key in channel.cache_dict:
content = channel.cache_dict[cache_key]
if len(content.encode('utf8'))<=MAX_UTF8_LEN:


+ 10
- 6
channel/wechatmp/wechatmp_channel.py ファイルの表示

@@ -97,8 +97,7 @@ class WechatMPChannel(ChatChannel):
if self.passive_reply:
receiver = context["receiver"]
self.cache_dict[receiver] = reply.content
self.running.remove(receiver)
logger.debug("[send] reply to {} saved to cache: {}".format(receiver, reply))
logger.info("[send] reply to {} saved to cache: {}".format(receiver, reply))
else:
receiver = context["receiver"]
reply_text = reply.content
@@ -116,10 +115,15 @@ class WechatMPChannel(ChatChannel):
return


def _fail_callback(self, session_id, exception, context, **kwargs):
logger.exception("[wechatmp] Fail to generation message to user, msgId={}, exception={}".format(context['msg'].msg_id, exception))
assert session_id not in self.cache_dict
self.running.remove(session_id)
def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数
logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context['msg'].msg_id))
if self.passive_reply:
self.running.remove(session_id)


def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数
logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context['msg'].msg_id, exception))
if self.passive_reply:
assert session_id not in self.cache_dict
self.running.remove(session_id)


読み込み中…
キャンセル
保存