From 28eb67bc24b8f4f90a60a9c94a9e701bd978b00d Mon Sep 17 00:00:00 2001 From: lanvent Date: Tue, 4 Apr 2023 14:57:38 +0800 Subject: [PATCH] feat: reset will cancel unprocessed messages --- channel/chat_channel.py | 22 ++++++++++++++++++---- plugins/godcmd/godcmd.py | 3 +++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/channel/chat_channel.py b/channel/chat_channel.py index af2c299..f178130 100644 --- a/channel/chat_channel.py +++ b/channel/chat_channel.py @@ -243,9 +243,9 @@ class ChatChannel(Channel): session_id = context['session_id'] with self.lock: if session_id not in self.sessions: - self.sessions[session_id] = (Dequeue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 1))) + self.sessions[session_id] = [Dequeue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 1))] if context.type == ContextType.TEXT and context.content.startswith("#"): - self.sessions[session_id][0].putleft(context) # 优先处理命令 + self.sessions[session_id][0].putleft(context) # 优先处理管理命令 else: self.sessions[session_id][0].put(context) @@ -273,12 +273,26 @@ class ChatChannel(Channel): semaphore.release() time.sleep(0.1) - def cancel(self, session_id): + # 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务 + def cancel_session(self, session_id): with self.lock: if session_id in self.sessions: for future in self.futures[session_id]: future.cancel() - self.sessions[session_id][0]=Dequeue() + cnt = self.sessions[session_id][0].qsize() + if cnt>0: + logger.info("Cancel {} messages in session {}".format(cnt, session_id)) + self.sessions[session_id][0] = Dequeue() + + def cancel_all_session(self): + with self.lock: + for session_id in self.sessions: + for future in self.futures[session_id]: + future.cancel() + cnt = self.sessions[session_id][0].qsize() + if cnt>0: + logger.info("Cancel {} messages in session {}".format(cnt, session_id)) + self.sessions[session_id][0] = Dequeue() def check_prefix(content, prefix_list): diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index 0d574d0..d29e7fc 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -146,6 +146,7 @@ class Godcmd(Plugin): logger.debug("[Godcmd] on_handle_context. content: %s" % content) if content.startswith("#"): # msg = e_context['context']['msg'] + channel = e_context['channel'] user = e_context['context']['receiver'] session_id = e_context['context']['session_id'] isgroup = e_context['context']['isgroup'] @@ -181,6 +182,7 @@ class Godcmd(Plugin): elif cmd == "reset": if bottype in (const.CHATGPT, const.OPEN_AI): bot.sessions.clear_session(session_id) + channel.cancel_session(session_id) ok, result = True, "会话已重置" else: ok, result = False, "当前对话机器人不支持重置会话" @@ -202,6 +204,7 @@ class Godcmd(Plugin): ok, result = True, "配置已重载" elif cmd == "resetall": if bottype in (const.CHATGPT, const.OPEN_AI): + channel.cancel_all_session() bot.sessions.clear_all_session() ok, result = True, "重置所有会话成功" else: