diff --git a/channel/chat_channel.py b/channel/chat_channel.py index 24d9cca..78b2775 100644 --- a/channel/chat_channel.py +++ b/channel/chat_channel.py @@ -1,9 +1,13 @@ - +from asyncio import CancelledError +import queue +from concurrent.futures import Future, ThreadPoolExecutor import os import re +import threading import time +from channel.chat_message import ChatMessage from common.expired_dict import ExpiredDict from channel.channel import Channel from bridge.reply import * @@ -20,8 +24,16 @@ except Exception as e: class ChatChannel(Channel): name = None # 登录的用户名 user_id = None # 登录的用户id + futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消 + sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理 + lock = threading.Lock() # 用于控制对sessions的访问 + handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池 + def __init__(self): - pass + _thread = threading.Thread(target=self.consume) + _thread.setDaemon(True) + _thread.start() + # 根据消息构造context,消息内容相关的触发项写在这里 def _compose_context(self, ctype: ContextType, content, **kwargs): @@ -215,6 +227,57 @@ class ChatChannel(Channel): time.sleep(3+3*retry_cnt) self._send(reply, context, retry_cnt+1) + def thread_pool_callback(self, session_id): + def func(worker:Future): + try: + worker_exception = worker.exception() + if worker_exception: + logger.exception("Worker return exception: {}".format(worker_exception)) + except CancelledError as e: + logger.info("Worker cancelled, session_id = {}".format(session_id)) + except Exception as e: + logger.exception("Worker raise exception: {}".format(e)) + with self.lock: + self.sessions[session_id][1].release() + return func + + def produce(self, context: Context): + session_id = context['session_id'] + with self.lock: + if session_id not in self.sessions: + self.sessions[session_id] = (queue.Queue(), threading.BoundedSemaphore(1)) + self.sessions[session_id][0].put(context) + + # 消费者函数,单独线程,用于从消息队列中取出消息并处理 + def consume(self): + while True: + with self.lock: + session_ids = list(self.sessions.keys()) + for session_id in session_ids: + context_queue, semaphore = self.sessions[session_id] + if semaphore.acquire(blocking = False): # 等线程处理完毕才能删除 + if not context_queue.empty(): + context = context_queue.get() + logger.debug("[WX] consume context: {}".format(context)) + future:Future = self.handler_pool.submit(self._handle, context) + future.add_done_callback(self.thread_pool_callback(session_id)) + if session_id not in self.futures: + self.futures[session_id] = [] + self.futures[session_id].append(future) + elif semaphore._initial_value == semaphore._value+1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕 + self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()] + assert len(self.futures[session_id]) == 0, "thread pool error" + del self.sessions[session_id] + else: + semaphore.release() + time.sleep(0.1) + + def cancel(self, session_id): + with self.lock: + if session_id in self.sessions: + for future in self.futures[session_id]: + future.cancel() + self.sessions[session_id][0]=queue.Queue() def check_prefix(content, prefix_list): diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index 3c90c2f..c70e056 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -5,6 +5,7 @@ wechat channel """ import os +import threading import requests import io import time @@ -17,18 +18,10 @@ from lib import itchat from lib.itchat.content import * from bridge.reply import * from bridge.context import * -from concurrent.futures import ThreadPoolExecutor from config import conf from common.time_check import time_checker from common.expired_dict import ExpiredDict from plugins import * -thread_pool = ThreadPoolExecutor(max_workers=8) - -def thread_pool_callback(worker): - worker_exception = worker.exception() - if worker_exception: - logger.exception("Worker return exception: {}".format(worker_exception)) - @itchat.msg_register(TEXT) def handler_single_msg(msg): @@ -73,7 +66,9 @@ def qrCallback(uuid,status,qrcode): try: from PIL import Image img = Image.open(io.BytesIO(qrcode)) - thread_pool.submit(img.show,"QRCode") + _thread = threading.Thread(target=img.show, args=("QRCode",)) + _thread.setDaemon(True) + _thread.start() except Exception as e: pass @@ -142,7 +137,7 @@ class WechatChannel(ChatChannel): logger.debug("[WX]receive voice msg: {}".format(cmsg.content)) context = self._compose_context(ContextType.VOICE, cmsg.content, isgroup=False, msg=cmsg) if context: - thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback) + self.produce(context) @time_checker @_check @@ -150,7 +145,7 @@ class WechatChannel(ChatChannel): logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg)) context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=False, msg=cmsg) if context: - thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback) + self.produce(context) @time_checker @_check @@ -158,7 +153,7 @@ class WechatChannel(ChatChannel): logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg)) context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=True, msg=cmsg) if context: - thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback) + self.produce(context) @time_checker @_check @@ -168,7 +163,7 @@ class WechatChannel(ChatChannel): logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content)) context = self._compose_context(ContextType.VOICE, cmsg.content, isgroup=True, msg=cmsg) if context: - thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback) + self.produce(context) # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 def send(self, reply: Reply, context: Context): diff --git a/channel/wechat/wechaty_channel.py b/channel/wechat/wechaty_channel.py index 85742bd..6478202 100644 --- a/channel/wechat/wechaty_channel.py +++ b/channel/wechat/wechaty_channel.py @@ -5,7 +5,6 @@ wechaty channel Python Wechaty - https://github.com/wechaty/python-wechaty """ import base64 -from concurrent.futures import ThreadPoolExecutor import os import time import asyncio @@ -18,21 +17,18 @@ from bridge.context import * from channel.chat_channel import ChatChannel from channel.wechat.wechaty_message import WechatyMessage from common.log import logger +from common.singleton import singleton from config import conf try: from voice.audio_convert import any_to_sil except Exception as e: pass -thread_pool = ThreadPoolExecutor(max_workers=8) -def thread_pool_callback(worker): - worker_exception = worker.exception() - if worker_exception: - logger.exception("Worker return exception: {}".format(worker_exception)) +@singleton class WechatyChannel(ChatChannel): def __init__(self): - pass + super().__init__() def startup(self): config = conf() @@ -41,6 +37,10 @@ class WechatyChannel(ChatChannel): asyncio.run(self.main()) async def main(self): + + loop = asyncio.get_event_loop() + #将asyncio的loop传入处理线程 + self.handler_pool._initializer= lambda: asyncio.set_event_loop(loop) self.bot = Wechaty() self.bot.on('login', self.on_login) self.bot.on('message', self.on_message) @@ -122,8 +122,4 @@ class WechatyChannel(ChatChannel): context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg) if context: logger.info('[WX] receiveMsg={}, context={}'.format(cmsg, context)) - thread_pool.submit(self._handle_loop, context, asyncio.get_event_loop()).add_done_callback(thread_pool_callback) - - def _handle_loop(self,context,loop): - asyncio.set_event_loop(loop) - self._handle(context) \ No newline at end of file + self.produce(context) \ No newline at end of file