@@ -1,9 +1,13 @@ | |||||
from asyncio import CancelledError | |||||
import queue | |||||
from concurrent.futures import Future, ThreadPoolExecutor | |||||
import os | import os | ||||
import re | import re | ||||
import threading | |||||
import time | import time | ||||
from channel.chat_message import ChatMessage | |||||
from common.expired_dict import ExpiredDict | from common.expired_dict import ExpiredDict | ||||
from channel.channel import Channel | from channel.channel import Channel | ||||
from bridge.reply import * | from bridge.reply import * | ||||
@@ -20,8 +24,16 @@ except Exception as e: | |||||
class ChatChannel(Channel): | class ChatChannel(Channel): | ||||
name = None # 登录的用户名 | name = None # 登录的用户名 | ||||
user_id = None # 登录的用户id | user_id = None # 登录的用户id | ||||
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消 | |||||
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理 | |||||
lock = threading.Lock() # 用于控制对sessions的访问 | |||||
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池 | |||||
def __init__(self): | def __init__(self): | ||||
pass | |||||
_thread = threading.Thread(target=self.consume) | |||||
_thread.setDaemon(True) | |||||
_thread.start() | |||||
# 根据消息构造context,消息内容相关的触发项写在这里 | # 根据消息构造context,消息内容相关的触发项写在这里 | ||||
def _compose_context(self, ctype: ContextType, content, **kwargs): | def _compose_context(self, ctype: ContextType, content, **kwargs): | ||||
@@ -215,6 +227,57 @@ class ChatChannel(Channel): | |||||
time.sleep(3+3*retry_cnt) | time.sleep(3+3*retry_cnt) | ||||
self._send(reply, context, retry_cnt+1) | self._send(reply, context, retry_cnt+1) | ||||
def thread_pool_callback(self, session_id): | |||||
def func(worker:Future): | |||||
try: | |||||
worker_exception = worker.exception() | |||||
if worker_exception: | |||||
logger.exception("Worker return exception: {}".format(worker_exception)) | |||||
except CancelledError as e: | |||||
logger.info("Worker cancelled, session_id = {}".format(session_id)) | |||||
except Exception as e: | |||||
logger.exception("Worker raise exception: {}".format(e)) | |||||
with self.lock: | |||||
self.sessions[session_id][1].release() | |||||
return func | |||||
def produce(self, context: Context): | |||||
session_id = context['session_id'] | |||||
with self.lock: | |||||
if session_id not in self.sessions: | |||||
self.sessions[session_id] = (queue.Queue(), threading.BoundedSemaphore(1)) | |||||
self.sessions[session_id][0].put(context) | |||||
# 消费者函数,单独线程,用于从消息队列中取出消息并处理 | |||||
def consume(self): | |||||
while True: | |||||
with self.lock: | |||||
session_ids = list(self.sessions.keys()) | |||||
for session_id in session_ids: | |||||
context_queue, semaphore = self.sessions[session_id] | |||||
if semaphore.acquire(blocking = False): # 等线程处理完毕才能删除 | |||||
if not context_queue.empty(): | |||||
context = context_queue.get() | |||||
logger.debug("[WX] consume context: {}".format(context)) | |||||
future:Future = self.handler_pool.submit(self._handle, context) | |||||
future.add_done_callback(self.thread_pool_callback(session_id)) | |||||
if session_id not in self.futures: | |||||
self.futures[session_id] = [] | |||||
self.futures[session_id].append(future) | |||||
elif semaphore._initial_value == semaphore._value+1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕 | |||||
self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()] | |||||
assert len(self.futures[session_id]) == 0, "thread pool error" | |||||
del self.sessions[session_id] | |||||
else: | |||||
semaphore.release() | |||||
time.sleep(0.1) | |||||
def cancel(self, session_id): | |||||
with self.lock: | |||||
if session_id in self.sessions: | |||||
for future in self.futures[session_id]: | |||||
future.cancel() | |||||
self.sessions[session_id][0]=queue.Queue() | |||||
def check_prefix(content, prefix_list): | def check_prefix(content, prefix_list): | ||||
@@ -5,6 +5,7 @@ wechat channel | |||||
""" | """ | ||||
import os | import os | ||||
import threading | |||||
import requests | import requests | ||||
import io | import io | ||||
import time | import time | ||||
@@ -17,18 +18,10 @@ from lib import itchat | |||||
from lib.itchat.content import * | from lib.itchat.content import * | ||||
from bridge.reply import * | from bridge.reply import * | ||||
from bridge.context import * | from bridge.context import * | ||||
from concurrent.futures import ThreadPoolExecutor | |||||
from config import conf | from config import conf | ||||
from common.time_check import time_checker | from common.time_check import time_checker | ||||
from common.expired_dict import ExpiredDict | from common.expired_dict import ExpiredDict | ||||
from plugins import * | from plugins import * | ||||
thread_pool = ThreadPoolExecutor(max_workers=8) | |||||
def thread_pool_callback(worker): | |||||
worker_exception = worker.exception() | |||||
if worker_exception: | |||||
logger.exception("Worker return exception: {}".format(worker_exception)) | |||||
@itchat.msg_register(TEXT) | @itchat.msg_register(TEXT) | ||||
def handler_single_msg(msg): | def handler_single_msg(msg): | ||||
@@ -73,7 +66,9 @@ def qrCallback(uuid,status,qrcode): | |||||
try: | try: | ||||
from PIL import Image | from PIL import Image | ||||
img = Image.open(io.BytesIO(qrcode)) | img = Image.open(io.BytesIO(qrcode)) | ||||
thread_pool.submit(img.show,"QRCode") | |||||
_thread = threading.Thread(target=img.show, args=("QRCode",)) | |||||
_thread.setDaemon(True) | |||||
_thread.start() | |||||
except Exception as e: | except Exception as e: | ||||
pass | pass | ||||
@@ -142,7 +137,7 @@ class WechatChannel(ChatChannel): | |||||
logger.debug("[WX]receive voice msg: {}".format(cmsg.content)) | logger.debug("[WX]receive voice msg: {}".format(cmsg.content)) | ||||
context = self._compose_context(ContextType.VOICE, cmsg.content, isgroup=False, msg=cmsg) | context = self._compose_context(ContextType.VOICE, cmsg.content, isgroup=False, msg=cmsg) | ||||
if context: | if context: | ||||
thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback) | |||||
self.produce(context) | |||||
@time_checker | @time_checker | ||||
@_check | @_check | ||||
@@ -150,7 +145,7 @@ class WechatChannel(ChatChannel): | |||||
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg)) | logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg)) | ||||
context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=False, msg=cmsg) | context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=False, msg=cmsg) | ||||
if context: | if context: | ||||
thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback) | |||||
self.produce(context) | |||||
@time_checker | @time_checker | ||||
@_check | @_check | ||||
@@ -158,7 +153,7 @@ class WechatChannel(ChatChannel): | |||||
logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg)) | logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg)) | ||||
context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=True, msg=cmsg) | context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=True, msg=cmsg) | ||||
if context: | if context: | ||||
thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback) | |||||
self.produce(context) | |||||
@time_checker | @time_checker | ||||
@_check | @_check | ||||
@@ -168,7 +163,7 @@ class WechatChannel(ChatChannel): | |||||
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content)) | logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content)) | ||||
context = self._compose_context(ContextType.VOICE, cmsg.content, isgroup=True, msg=cmsg) | context = self._compose_context(ContextType.VOICE, cmsg.content, isgroup=True, msg=cmsg) | ||||
if context: | if context: | ||||
thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback) | |||||
self.produce(context) | |||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 | # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 | ||||
def send(self, reply: Reply, context: Context): | def send(self, reply: Reply, context: Context): | ||||
@@ -5,7 +5,6 @@ wechaty channel | |||||
Python Wechaty - https://github.com/wechaty/python-wechaty | Python Wechaty - https://github.com/wechaty/python-wechaty | ||||
""" | """ | ||||
import base64 | import base64 | ||||
from concurrent.futures import ThreadPoolExecutor | |||||
import os | import os | ||||
import time | import time | ||||
import asyncio | import asyncio | ||||
@@ -18,21 +17,18 @@ from bridge.context import * | |||||
from channel.chat_channel import ChatChannel | from channel.chat_channel import ChatChannel | ||||
from channel.wechat.wechaty_message import WechatyMessage | from channel.wechat.wechaty_message import WechatyMessage | ||||
from common.log import logger | from common.log import logger | ||||
from common.singleton import singleton | |||||
from config import conf | from config import conf | ||||
try: | try: | ||||
from voice.audio_convert import any_to_sil | from voice.audio_convert import any_to_sil | ||||
except Exception as e: | except Exception as e: | ||||
pass | pass | ||||
thread_pool = ThreadPoolExecutor(max_workers=8) | |||||
def thread_pool_callback(worker): | |||||
worker_exception = worker.exception() | |||||
if worker_exception: | |||||
logger.exception("Worker return exception: {}".format(worker_exception)) | |||||
@singleton | |||||
class WechatyChannel(ChatChannel): | class WechatyChannel(ChatChannel): | ||||
def __init__(self): | def __init__(self): | ||||
pass | |||||
super().__init__() | |||||
def startup(self): | def startup(self): | ||||
config = conf() | config = conf() | ||||
@@ -41,6 +37,10 @@ class WechatyChannel(ChatChannel): | |||||
asyncio.run(self.main()) | asyncio.run(self.main()) | ||||
async def main(self): | async def main(self): | ||||
loop = asyncio.get_event_loop() | |||||
#将asyncio的loop传入处理线程 | |||||
self.handler_pool._initializer= lambda: asyncio.set_event_loop(loop) | |||||
self.bot = Wechaty() | self.bot = Wechaty() | ||||
self.bot.on('login', self.on_login) | self.bot.on('login', self.on_login) | ||||
self.bot.on('message', self.on_message) | self.bot.on('message', self.on_message) | ||||
@@ -122,8 +122,4 @@ class WechatyChannel(ChatChannel): | |||||
context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg) | context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg) | ||||
if context: | if context: | ||||
logger.info('[WX] receiveMsg={}, context={}'.format(cmsg, context)) | logger.info('[WX] receiveMsg={}, context={}'.format(cmsg, context)) | ||||
thread_pool.submit(self._handle_loop, context, asyncio.get_event_loop()).add_done_callback(thread_pool_callback) | |||||
def _handle_loop(self,context,loop): | |||||
asyncio.set_event_loop(loop) | |||||
self._handle(context) | |||||
self.produce(context) |