@@ -233,9 +233,6 @@ class ChatChannel(Channel): | |||||
time.sleep(3+3*retry_cnt) | time.sleep(3+3*retry_cnt) | ||||
self._send(reply, context, retry_cnt+1) | self._send(reply, context, retry_cnt+1) | ||||
def _success_callback(self, session_id, **kwargs): # 线程正常结束时的回调函数 | |||||
pass | |||||
def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数 | def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数 | ||||
logger.exception("Worker return exception: {}".format(exception)) | logger.exception("Worker return exception: {}".format(exception)) | ||||
@@ -245,8 +242,6 @@ class ChatChannel(Channel): | |||||
worker_exception = worker.exception() | worker_exception = worker.exception() | ||||
if worker_exception: | if worker_exception: | ||||
self._fail_callback(session_id, exception = worker_exception, **kwargs) | self._fail_callback(session_id, exception = worker_exception, **kwargs) | ||||
else: | |||||
self._success_callback(session_id, **kwargs) | |||||
except CancelledError as e: | except CancelledError as e: | ||||
logger.info("Worker cancelled, session_id = {}".format(session_id)) | logger.info("Worker cancelled, session_id = {}".format(session_id)) | ||||
except Exception as e: | except Exception as e: | ||||
@@ -34,7 +34,6 @@ class Query(): | |||||
user_data = conf().get_user_data(from_user) | user_data = conf().get_user_data(from_user) | ||||
context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key | context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key | ||||
channel.produce(context) | channel.produce(context) | ||||
channel.running.add(from_user) | |||||
# The reply will be sent by channel.send() in another thread | # The reply will be sent by channel.send() in another thread | ||||
return "success" | return "success" | ||||
@@ -97,6 +97,7 @@ class WechatMPChannel(ChatChannel): | |||||
if self.passive_reply: | if self.passive_reply: | ||||
receiver = context["receiver"] | receiver = context["receiver"] | ||||
self.cache_dict[receiver] = reply.content | self.cache_dict[receiver] = reply.content | ||||
self.running.remove(receiver) | |||||
logger.info("[send] reply to {} saved to cache: {}".format(receiver, reply)) | logger.info("[send] reply to {} saved to cache: {}".format(receiver, reply)) | ||||
else: | else: | ||||
receiver = context["receiver"] | receiver = context["receiver"] | ||||
@@ -114,12 +115,10 @@ class WechatMPChannel(ChatChannel): | |||||
logger.info("[send] Do send to {}: {}".format(receiver, reply_text)) | logger.info("[send] Do send to {}: {}".format(receiver, reply_text)) | ||||
return | return | ||||
def _success_callback(self, session_id, **kwargs): # 线程正常结束时的回调函数 | |||||
self.running.remove(session_id) | |||||
def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数 | def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数 | ||||
logger.exception("[wechatmp] Fail to generation message to user, msgId={}, exception={}".format(context['msg'].msg_id, exception)) | logger.exception("[wechatmp] Fail to generation message to user, msgId={}, exception={}".format(context['msg'].msg_id, exception)) | ||||
if self.passive_reply: | if self.passive_reply: | ||||
assert session_id not in self.cache_dict | assert session_id not in self.cache_dict | ||||
self.running.remove(session_id) | |||||
self.running.remove(session_id) | |||||