@@ -233,6 +233,9 @@ class ChatChannel(Channel): | |||||
time.sleep(3+3*retry_cnt) | time.sleep(3+3*retry_cnt) | ||||
self._send(reply, context, retry_cnt+1) | self._send(reply, context, retry_cnt+1) | ||||
def _success_callback(self, session_id, **kwargs):# 线程正常结束时的回调函数 | |||||
logger.debug("Worker return success, session_id = {}".format(session_id)) | |||||
def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数 | def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数 | ||||
logger.exception("Worker return exception: {}".format(exception)) | logger.exception("Worker return exception: {}".format(exception)) | ||||
@@ -242,6 +245,8 @@ class ChatChannel(Channel): | |||||
worker_exception = worker.exception() | worker_exception = worker.exception() | ||||
if worker_exception: | if worker_exception: | ||||
self._fail_callback(session_id, exception = worker_exception, **kwargs) | self._fail_callback(session_id, exception = worker_exception, **kwargs) | ||||
else: | |||||
self._success_callback(session_id, **kwargs) | |||||
except CancelledError as e: | except CancelledError as e: | ||||
logger.info("Worker cancelled, session_id = {}".format(session_id)) | logger.info("Worker cancelled, session_id = {}".format(session_id)) | ||||
except Exception as e: | except Exception as e: | ||||
@@ -16,7 +16,7 @@ class Query(): | |||||
def POST(self): | def POST(self): | ||||
# Make sure to return the instance that first created, @singleton will do that. | # Make sure to return the instance that first created, @singleton will do that. | ||||
channel_instance = WechatMPChannel() | |||||
channel = WechatMPChannel() | |||||
try: | try: | ||||
webData = web.data() | webData = web.data() | ||||
# logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8")) | # logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8")) | ||||
@@ -27,14 +27,14 @@ class Query(): | |||||
message_id = wechatmp_msg.msg_id | message_id = wechatmp_msg.msg_id | ||||
logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message)) | logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message)) | ||||
context = channel_instance._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg) | |||||
context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg) | |||||
if context: | if context: | ||||
# set private openai_api_key | # set private openai_api_key | ||||
# if from_user is not changed in itchat, this can be placed at chat_channel | # if from_user is not changed in itchat, this can be placed at chat_channel | ||||
user_data = conf().get_user_data(from_user) | user_data = conf().get_user_data(from_user) | ||||
context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key | context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key | ||||
channel_instance.produce(context) | |||||
# The reply will be sent by channel_instance.send() in another thread | |||||
channel.produce(context) | |||||
# The reply will be sent by channel.send() in another thread | |||||
return "success" | return "success" | ||||
elif wechatmp_msg.msg_type == 'event': | elif wechatmp_msg.msg_type == 'event': | ||||
@@ -41,7 +41,8 @@ class Query(): | |||||
context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg) | context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg) | ||||
logger.debug("[wechatmp] context: {} {}".format(context, wechatmp_msg)) | logger.debug("[wechatmp] context: {} {}".format(context, wechatmp_msg)) | ||||
if message_id in channel.received_msgs: # received and finished | if message_id in channel.received_msgs: # received and finished | ||||
return | |||||
# no return because of bandwords or other reasons | |||||
return "success" | |||||
if supported and context: | if supported and context: | ||||
# set private openai_api_key | # set private openai_api_key | ||||
# if from_user is not changed in itchat, this can be placed at chat_channel | # if from_user is not changed in itchat, this can be placed at chat_channel | ||||
@@ -71,11 +72,12 @@ class Query(): | |||||
channel.query1[cache_key] = False | channel.query1[cache_key] = False | ||||
channel.query2[cache_key] = False | channel.query2[cache_key] = False | ||||
channel.query3[cache_key] = False | channel.query3[cache_key] = False | ||||
# Request again | |||||
# User request again, and the answer is not ready | |||||
elif cache_key in channel.running and channel.query1.get(cache_key) == True and channel.query2.get(cache_key) == True and channel.query3.get(cache_key) == True: | elif cache_key in channel.running and channel.query1.get(cache_key) == True and channel.query2.get(cache_key) == True and channel.query3.get(cache_key) == True: | ||||
channel.query1[cache_key] = False #To improve waiting experience, this can be set to True. | channel.query1[cache_key] = False #To improve waiting experience, this can be set to True. | ||||
channel.query2[cache_key] = False #To improve waiting experience, this can be set to True. | channel.query2[cache_key] = False #To improve waiting experience, this can be set to True. | ||||
channel.query3[cache_key] = False | channel.query3[cache_key] = False | ||||
# User request again, and the answer is ready | |||||
elif cache_key in channel.cache_dict: | elif cache_key in channel.cache_dict: | ||||
# Skip the waiting phase | # Skip the waiting phase | ||||
channel.query1[cache_key] = True | channel.query1[cache_key] = True | ||||
@@ -89,7 +91,7 @@ class Query(): | |||||
logger.debug("[wechatmp] query1 {}".format(cache_key)) | logger.debug("[wechatmp] query1 {}".format(cache_key)) | ||||
channel.query1[cache_key] = True | channel.query1[cache_key] = True | ||||
cnt = 0 | cnt = 0 | ||||
while cache_key not in channel.cache_dict and cnt < 45: | |||||
while cache_key in channel.running and cnt < 45: | |||||
cnt = cnt + 1 | cnt = cnt + 1 | ||||
time.sleep(0.1) | time.sleep(0.1) | ||||
if cnt == 45: | if cnt == 45: | ||||
@@ -104,7 +106,7 @@ class Query(): | |||||
logger.debug("[wechatmp] query2 {}".format(cache_key)) | logger.debug("[wechatmp] query2 {}".format(cache_key)) | ||||
channel.query2[cache_key] = True | channel.query2[cache_key] = True | ||||
cnt = 0 | cnt = 0 | ||||
while cache_key not in channel.cache_dict and cnt < 45: | |||||
while cache_key in channel.running and cnt < 45: | |||||
cnt = cnt + 1 | cnt = cnt + 1 | ||||
time.sleep(0.1) | time.sleep(0.1) | ||||
if cnt == 45: | if cnt == 45: | ||||
@@ -119,7 +121,7 @@ class Query(): | |||||
logger.debug("[wechatmp] query3 {}".format(cache_key)) | logger.debug("[wechatmp] query3 {}".format(cache_key)) | ||||
channel.query3[cache_key] = True | channel.query3[cache_key] = True | ||||
cnt = 0 | cnt = 0 | ||||
while cache_key not in channel.cache_dict and cnt < 40: | |||||
while cache_key in channel.running and cnt < 40: | |||||
cnt = cnt + 1 | cnt = cnt + 1 | ||||
time.sleep(0.1) | time.sleep(0.1) | ||||
if cnt == 40: | if cnt == 40: | ||||
@@ -132,12 +134,17 @@ class Query(): | |||||
else: | else: | ||||
pass | pass | ||||
if float(time.time()) - float(query_time) > 4.8: | |||||
reply_text = "【正在思考中,回复任意文字尝试获取回复】" | |||||
logger.info("[wechatmp] Timeout for {} {}, return".format(from_user, message_id)) | |||||
replyPost = reply.TextMsg(from_user, to_user, reply_text).send() | |||||
return replyPost | |||||
if cache_key not in channel.cache_dict and cache_key not in channel.running: | |||||
# no return because of bandwords or other reasons | |||||
return "success" | |||||
# if float(time.time()) - float(query_time) > 4.8: | |||||
# reply_text = "【正在思考中,回复任意文字尝试获取回复】" | |||||
# logger.info("[wechatmp] Timeout for {} {}, return".format(from_user, message_id)) | |||||
# replyPost = reply.TextMsg(from_user, to_user, reply_text).send() | |||||
# return replyPost | |||||
if cache_key in channel.cache_dict: | if cache_key in channel.cache_dict: | ||||
content = channel.cache_dict[cache_key] | content = channel.cache_dict[cache_key] | ||||
if len(content.encode('utf8'))<=MAX_UTF8_LEN: | if len(content.encode('utf8'))<=MAX_UTF8_LEN: | ||||
@@ -97,8 +97,7 @@ class WechatMPChannel(ChatChannel): | |||||
if self.passive_reply: | if self.passive_reply: | ||||
receiver = context["receiver"] | receiver = context["receiver"] | ||||
self.cache_dict[receiver] = reply.content | self.cache_dict[receiver] = reply.content | ||||
self.running.remove(receiver) | |||||
logger.debug("[send] reply to {} saved to cache: {}".format(receiver, reply)) | |||||
logger.info("[send] reply to {} saved to cache: {}".format(receiver, reply)) | |||||
else: | else: | ||||
receiver = context["receiver"] | receiver = context["receiver"] | ||||
reply_text = reply.content | reply_text = reply.content | ||||
@@ -116,10 +115,15 @@ class WechatMPChannel(ChatChannel): | |||||
return | return | ||||
def _fail_callback(self, session_id, exception, context, **kwargs): | |||||
logger.exception("[wechatmp] Fail to generation message to user, msgId={}, exception={}".format(context['msg'].msg_id, exception)) | |||||
assert session_id not in self.cache_dict | |||||
self.running.remove(session_id) | |||||
def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数 | |||||
logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context['msg'].msg_id)) | |||||
if self.passive_reply: | |||||
self.running.remove(session_id) | |||||
def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数 | |||||
logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context['msg'].msg_id, exception)) | |||||
if self.passive_reply: | |||||
assert session_id not in self.cache_dict | |||||
self.running.remove(session_id) | |||||