@@ -233,6 +233,9 @@ class ChatChannel(Channel): | |||
time.sleep(3+3*retry_cnt) | |||
self._send(reply, context, retry_cnt+1) | |||
def _success_callback(self, session_id, **kwargs): # 线程正常结束时的回调函数 | |||
pass | |||
def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数 | |||
logger.exception("Worker return exception: {}".format(exception)) | |||
@@ -242,6 +245,8 @@ class ChatChannel(Channel): | |||
worker_exception = worker.exception() | |||
if worker_exception: | |||
self._fail_callback(session_id, exception = worker_exception, **kwargs) | |||
else: | |||
self._success_callback(session_id, **kwargs) | |||
except CancelledError as e: | |||
logger.info("Worker cancelled, session_id = {}".format(session_id)) | |||
except Exception as e: | |||
@@ -16,7 +16,7 @@ class Query(): | |||
def POST(self): | |||
# Make sure to return the instance that first created, @singleton will do that. | |||
channel_instance = WechatMPChannel() | |||
channel = WechatMPChannel() | |||
try: | |||
webData = web.data() | |||
# logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8")) | |||
@@ -27,14 +27,15 @@ class Query(): | |||
message_id = wechatmp_msg.msg_id | |||
logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message)) | |||
context = channel_instance._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg) | |||
context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg) | |||
if context: | |||
# set private openai_api_key | |||
# if from_user is not changed in itchat, this can be placed at chat_channel | |||
user_data = conf().get_user_data(from_user) | |||
context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key | |||
channel_instance.produce(context) | |||
# The reply will be sent by channel_instance.send() in another thread | |||
channel.produce(context) | |||
channel.running.add(from_user) | |||
# The reply will be sent by channel.send() in another thread | |||
return "success" | |||
elif wechatmp_msg.msg_type == 'event': | |||
@@ -41,7 +41,8 @@ class Query(): | |||
context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg) | |||
logger.debug("[wechatmp] context: {} {}".format(context, wechatmp_msg)) | |||
if message_id in channel.received_msgs: # received and finished | |||
return | |||
# no return because of bandwords or other reasons | |||
return "success" | |||
if supported and context: | |||
# set private openai_api_key | |||
# if from_user is not changed in itchat, this can be placed at chat_channel | |||
@@ -71,11 +72,12 @@ class Query(): | |||
channel.query1[cache_key] = False | |||
channel.query2[cache_key] = False | |||
channel.query3[cache_key] = False | |||
# Request again | |||
# User request again, and the answer is not ready | |||
elif cache_key in channel.running and channel.query1.get(cache_key) == True and channel.query2.get(cache_key) == True and channel.query3.get(cache_key) == True: | |||
channel.query1[cache_key] = False #To improve waiting experience, this can be set to True. | |||
channel.query2[cache_key] = False #To improve waiting experience, this can be set to True. | |||
channel.query3[cache_key] = False | |||
# User request again, and the answer is ready | |||
elif cache_key in channel.cache_dict: | |||
# Skip the waiting phase | |||
channel.query1[cache_key] = True | |||
@@ -89,7 +91,7 @@ class Query(): | |||
logger.debug("[wechatmp] query1 {}".format(cache_key)) | |||
channel.query1[cache_key] = True | |||
cnt = 0 | |||
while cache_key not in channel.cache_dict and cnt < 45: | |||
while cache_key in channel.running and cnt < 45: | |||
cnt = cnt + 1 | |||
time.sleep(0.1) | |||
if cnt == 45: | |||
@@ -104,7 +106,7 @@ class Query(): | |||
logger.debug("[wechatmp] query2 {}".format(cache_key)) | |||
channel.query2[cache_key] = True | |||
cnt = 0 | |||
while cache_key not in channel.cache_dict and cnt < 45: | |||
while cache_key in channel.running and cnt < 45: | |||
cnt = cnt + 1 | |||
time.sleep(0.1) | |||
if cnt == 45: | |||
@@ -119,7 +121,7 @@ class Query(): | |||
logger.debug("[wechatmp] query3 {}".format(cache_key)) | |||
channel.query3[cache_key] = True | |||
cnt = 0 | |||
while cache_key not in channel.cache_dict and cnt < 40: | |||
while cache_key in channel.running and cnt < 40: | |||
cnt = cnt + 1 | |||
time.sleep(0.1) | |||
if cnt == 40: | |||
@@ -132,12 +134,17 @@ class Query(): | |||
else: | |||
pass | |||
if float(time.time()) - float(query_time) > 4.8: | |||
reply_text = "【正在思考中,回复任意文字尝试获取回复】" | |||
logger.info("[wechatmp] Timeout for {} {}, return".format(from_user, message_id)) | |||
replyPost = reply.TextMsg(from_user, to_user, reply_text).send() | |||
return replyPost | |||
if cache_key not in channel.cache_dict and cache_key not in channel.running: | |||
# no return because of bandwords or other reasons | |||
return "success" | |||
# if float(time.time()) - float(query_time) > 4.8: | |||
# reply_text = "【正在思考中,回复任意文字尝试获取回复】" | |||
# logger.info("[wechatmp] Timeout for {} {}, return".format(from_user, message_id)) | |||
# replyPost = reply.TextMsg(from_user, to_user, reply_text).send() | |||
# return replyPost | |||
if cache_key in channel.cache_dict: | |||
content = channel.cache_dict[cache_key] | |||
if len(content.encode('utf8'))<=MAX_UTF8_LEN: | |||
@@ -97,8 +97,7 @@ class WechatMPChannel(ChatChannel): | |||
if self.passive_reply: | |||
receiver = context["receiver"] | |||
self.cache_dict[receiver] = reply.content | |||
self.running.remove(receiver) | |||
logger.debug("[send] reply to {} saved to cache: {}".format(receiver, reply)) | |||
logger.info("[send] reply to {} saved to cache: {}".format(receiver, reply)) | |||
else: | |||
receiver = context["receiver"] | |||
reply_text = reply.content | |||
@@ -115,11 +114,12 @@ class WechatMPChannel(ChatChannel): | |||
logger.info("[send] Do send to {}: {}".format(receiver, reply_text)) | |||
return | |||
def _success_callback(self, session_id, **kwargs): # 线程正常结束时的回调函数 | |||
self.running.remove(session_id) | |||
def _fail_callback(self, session_id, exception, context, **kwargs): | |||
def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数 | |||
logger.exception("[wechatmp] Fail to generation message to user, msgId={}, exception={}".format(context['msg'].msg_id, exception)) | |||
assert session_id not in self.cache_dict | |||
if self.passive_reply: | |||
assert session_id not in self.cache_dict | |||
self.running.remove(session_id) | |||