@@ -243,9 +243,9 @@ class ChatChannel(Channel): | |||||
session_id = context['session_id'] | session_id = context['session_id'] | ||||
with self.lock: | with self.lock: | ||||
if session_id not in self.sessions: | if session_id not in self.sessions: | ||||
self.sessions[session_id] = (Dequeue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 1))) | |||||
self.sessions[session_id] = [Dequeue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 1))] | |||||
if context.type == ContextType.TEXT and context.content.startswith("#"): | if context.type == ContextType.TEXT and context.content.startswith("#"): | ||||
self.sessions[session_id][0].putleft(context) # 优先处理命令 | |||||
self.sessions[session_id][0].putleft(context) # 优先处理管理命令 | |||||
else: | else: | ||||
self.sessions[session_id][0].put(context) | self.sessions[session_id][0].put(context) | ||||
@@ -273,12 +273,26 @@ class ChatChannel(Channel): | |||||
semaphore.release() | semaphore.release() | ||||
time.sleep(0.1) | time.sleep(0.1) | ||||
def cancel(self, session_id): | |||||
# 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务 | |||||
def cancel_session(self, session_id): | |||||
with self.lock: | with self.lock: | ||||
if session_id in self.sessions: | if session_id in self.sessions: | ||||
for future in self.futures[session_id]: | for future in self.futures[session_id]: | ||||
future.cancel() | future.cancel() | ||||
self.sessions[session_id][0]=Dequeue() | |||||
cnt = self.sessions[session_id][0].qsize() | |||||
if cnt>0: | |||||
logger.info("Cancel {} messages in session {}".format(cnt, session_id)) | |||||
self.sessions[session_id][0] = Dequeue() | |||||
def cancel_all_session(self): | |||||
with self.lock: | |||||
for session_id in self.sessions: | |||||
for future in self.futures[session_id]: | |||||
future.cancel() | |||||
cnt = self.sessions[session_id][0].qsize() | |||||
if cnt>0: | |||||
logger.info("Cancel {} messages in session {}".format(cnt, session_id)) | |||||
self.sessions[session_id][0] = Dequeue() | |||||
def check_prefix(content, prefix_list): | def check_prefix(content, prefix_list): | ||||
@@ -146,6 +146,7 @@ class Godcmd(Plugin): | |||||
logger.debug("[Godcmd] on_handle_context. content: %s" % content) | logger.debug("[Godcmd] on_handle_context. content: %s" % content) | ||||
if content.startswith("#"): | if content.startswith("#"): | ||||
# msg = e_context['context']['msg'] | # msg = e_context['context']['msg'] | ||||
channel = e_context['channel'] | |||||
user = e_context['context']['receiver'] | user = e_context['context']['receiver'] | ||||
session_id = e_context['context']['session_id'] | session_id = e_context['context']['session_id'] | ||||
isgroup = e_context['context']['isgroup'] | isgroup = e_context['context']['isgroup'] | ||||
@@ -181,6 +182,7 @@ class Godcmd(Plugin): | |||||
elif cmd == "reset": | elif cmd == "reset": | ||||
if bottype in (const.CHATGPT, const.OPEN_AI): | if bottype in (const.CHATGPT, const.OPEN_AI): | ||||
bot.sessions.clear_session(session_id) | bot.sessions.clear_session(session_id) | ||||
channel.cancel_session(session_id) | |||||
ok, result = True, "会话已重置" | ok, result = True, "会话已重置" | ||||
else: | else: | ||||
ok, result = False, "当前对话机器人不支持重置会话" | ok, result = False, "当前对话机器人不支持重置会话" | ||||
@@ -202,6 +204,7 @@ class Godcmd(Plugin): | |||||
ok, result = True, "配置已重载" | ok, result = True, "配置已重载" | ||||
elif cmd == "resetall": | elif cmd == "resetall": | ||||
if bottype in (const.CHATGPT, const.OPEN_AI): | if bottype in (const.CHATGPT, const.OPEN_AI): | ||||
channel.cancel_all_session() | |||||
bot.sessions.clear_all_session() | bot.sessions.clear_all_session() | ||||
ok, result = True, "重置所有会话成功" | ok, result = True, "重置所有会话成功" | ||||
else: | else: | ||||