Browse Source

feat: optimize consumer thread pool

master
zhayujie 9 months ago
parent
commit
af5bc73dc0
7 changed files with 93 additions and 42 deletions
  1. +19
    -14
      app.py
  2. +1
    -1
      bot/linkai/link_ai_bot.py
  3. +4
    -2
      channel/chat_channel.py
  4. +32
    -24
      channel/wechat/wechat_channel.py
  5. +26
    -1
      common/linkai_client.py
  6. +8
    -0
      plugins/godcmd/godcmd.py
  7. +3
    -0
      plugins/plugin.py

+ 19
- 14
app.py View File

@@ -3,6 +3,7 @@
import os import os
import signal import signal
import sys import sys
import time


from channel import channel_factory from channel import channel_factory
from common import const from common import const
@@ -24,6 +25,21 @@ def sigterm_handler_wrap(_signo):
signal.signal(_signo, func) signal.signal(_signo, func)




def start_channel(channel_name: str):
channel = channel_factory.create_channel(channel_name)
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework",
const.FEISHU, const.DINGTALK]:
PluginManager().load_plugins()

if conf().get("use_linkai"):
try:
from common import linkai_client
threading.Thread(target=linkai_client.start, args=(channel,)).start()
except Exception as e:
pass
channel.startup()


def run(): def run():
try: try:
# load config # load config
@@ -41,22 +57,11 @@ def run():


if channel_name == "wxy": if channel_name == "wxy":
os.environ["WECHATY_LOG"] = "warn" os.environ["WECHATY_LOG"] = "warn"
# os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'

channel = channel_factory.create_channel(channel_name)
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework", const.FEISHU,const.DINGTALK]:
PluginManager().load_plugins()

if conf().get("use_linkai"):
try:
from common import linkai_client
threading.Thread(target=linkai_client.start, args=(channel, )).start()
except Exception as e:
pass


# startup channel
channel.startup()
start_channel(channel_name)


while True:
time.sleep(1)
except Exception as e: except Exception as e:
logger.error("App startup failed!") logger.error("App startup failed!")
logger.exception(e) logger.exception(e)


+ 1
- 1
bot/linkai/link_ai_bot.py View File

@@ -400,7 +400,7 @@ class LinkAIBot(Bot):
i += 1 i += 1
if url.endswith(".mp4"): if url.endswith(".mp4"):
reply_type = ReplyType.VIDEO_URL reply_type = ReplyType.VIDEO_URL
elif url.endswith(".pdf") or url.endswith(".doc") or url.endswith(".docx"):
elif url.endswith(".pdf") or url.endswith(".doc") or url.endswith(".docx") or url.endswith(".csv"):
reply_type = ReplyType.FILE reply_type = ReplyType.FILE
url = _download_file(url) url = _download_file(url)
if not url: if not url:


+ 4
- 2
channel/chat_channel.py View File

@@ -4,6 +4,7 @@ import threading
import time import time
from asyncio import CancelledError from asyncio import CancelledError
from concurrent.futures import Future, ThreadPoolExecutor from concurrent.futures import Future, ThreadPoolExecutor
from concurrent import futures


from bridge.context import * from bridge.context import *
from bridge.reply import * from bridge.reply import *
@@ -17,6 +18,8 @@ try:
except Exception as e: except Exception as e:
pass pass


handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池



# 抽象类, 它包含了与消息通道无关的通用处理逻辑 # 抽象类, 它包含了与消息通道无关的通用处理逻辑
class ChatChannel(Channel): class ChatChannel(Channel):
@@ -25,7 +28,6 @@ class ChatChannel(Channel):
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消 futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理 sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
lock = threading.Lock() # 用于控制对sessions的访问 lock = threading.Lock() # 用于控制对sessions的访问
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池


def __init__(self): def __init__(self):
_thread = threading.Thread(target=self.consume) _thread = threading.Thread(target=self.consume)
@@ -339,7 +341,7 @@ class ChatChannel(Channel):
if not context_queue.empty(): if not context_queue.empty():
context = context_queue.get() context = context_queue.get()
logger.debug("[WX] consume context: {}".format(context)) logger.debug("[WX] consume context: {}".format(context))
future: Future = self.handler_pool.submit(self._handle, context)
future: Future = handler_pool.submit(self._handle, context)
future.add_done_callback(self._thread_pool_callback(session_id, context=context)) future.add_done_callback(self._thread_pool_callback(session_id, context=context))
if session_id not in self.futures: if session_id not in self.futures:
self.futures[session_id] = [] self.futures[session_id] = []


+ 32
- 24
channel/wechat/wechat_channel.py View File

@@ -15,6 +15,7 @@ import requests
from bridge.context import * from bridge.context import *
from bridge.reply import * from bridge.reply import *
from channel.chat_channel import ChatChannel from channel.chat_channel import ChatChannel
from channel import chat_channel
from channel.wechat.wechat_message import * from channel.wechat.wechat_message import *
from common.expired_dict import ExpiredDict from common.expired_dict import ExpiredDict
from common.log import logger from common.log import logger
@@ -112,30 +113,39 @@ class WechatChannel(ChatChannel):
self.auto_login_times = 0 self.auto_login_times = 0


def startup(self): def startup(self):
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
# login by scan QRCode
hotReload = conf().get("hot_reload", False)
status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
itchat.auto_login(
enableCmdQR=2,
hotReload=hotReload,
statusStorageDir=status_path,
qrCallback=qrCallback,
exitCallback=self.exitCallback,
loginCallback=self.loginCallback
)
self.user_id = itchat.instance.storageClass.userName
self.name = itchat.instance.storageClass.nickName
logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
# start message listener
itchat.run()
try:
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
# login by scan QRCode
hotReload = conf().get("hot_reload", False)
status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
itchat.auto_login(
enableCmdQR=2,
hotReload=hotReload,
statusStorageDir=status_path,
qrCallback=qrCallback,
exitCallback=self.exitCallback,
loginCallback=self.loginCallback
)
self.user_id = itchat.instance.storageClass.userName
self.name = itchat.instance.storageClass.nickName
logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
# start message listener
itchat.run()
except Exception as e:
logger.error(e)


def exitCallback(self): def exitCallback(self):
_send_logout()
time.sleep(3)
self.auto_login_times += 1
if self.auto_login_times < 100:
self.startup()
try:
from common.linkai_client import chat_client
if chat_client.client_id and conf().get("use_linkai"):
_send_logout()
time.sleep(2)
self.auto_login_times += 1
if self.auto_login_times < 100:
chat_channel.handler_pool._shutdown = False
self.startup()
except Exception as e:
pass


def loginCallback(self): def loginCallback(self):
logger.debug("Login success") logger.debug("Login success")
@@ -259,7 +269,6 @@ def _send_login_success():
def _send_logout(): def _send_logout():
try: try:
from common.linkai_client import chat_client from common.linkai_client import chat_client
time.sleep(2)
if chat_client.client_id: if chat_client.client_id:
chat_client.send_logout() chat_client.send_logout()
except Exception as e: except Exception as e:
@@ -268,7 +277,6 @@ def _send_logout():
def _send_qr_code(qrcode_list: list): def _send_qr_code(qrcode_list: list):
try: try:
from common.linkai_client import chat_client from common.linkai_client import chat_client
time.sleep(2)
if chat_client.client_id: if chat_client.client_id:
chat_client.send_qrcode(qrcode_list) chat_client.send_qrcode(qrcode_list)
except Exception as e: except Exception as e:


+ 26
- 1
common/linkai_client.py View File

@@ -2,7 +2,9 @@ from bridge.context import Context, ContextType
from bridge.reply import Reply, ReplyType from bridge.reply import Reply, ReplyType
from common.log import logger from common.log import logger
from linkai import LinkAIClient, PushMsg from linkai import LinkAIClient, PushMsg
from config import conf
from config import conf, pconf, plugin_config
from plugins import PluginManager



chat_client: LinkAIClient chat_client: LinkAIClient


@@ -22,6 +24,29 @@ class ChatClient(LinkAIClient):
context["isgroup"] = push_msg.is_group context["isgroup"] = push_msg.is_group
self.channel.send(Reply(ReplyType.TEXT, content=msg_content), context) self.channel.send(Reply(ReplyType.TEXT, content=msg_content), context)


def on_config(self, config: dict):
if not self.client_id:
return
logger.info(f"从控制台加载配置: {config}")
local_config = conf()
for key in local_config.keys():
if config.get(key) is not None:
local_config[key] = config.get(key)
if config.get("reply_voice_mode"):
if config.get("reply_voice_mode") == "voice_reply_voice":
local_config["voice_reply_voice"] = True
elif config.get("reply_voice_mode") == "always_reply_voice":
local_config["always_reply_voice"] = True
# if config.get("admin_password") and plugin_config["Godcmd"]:
# plugin_config["Godcmd"]["password"] = config.get("admin_password")
# PluginManager().instances["Godcmd"].reload()
# if config.get("group_app_map") and pconf("linkai"):
# local_group_map = {}
# for mapping in config.get("group_app_map"):
# local_group_map[mapping.get("group_name")] = mapping.get("app_code")
# pconf("linkai")["group_app_map"] = local_group_map
# PluginManager().instances["linkai"].reload()



def start(channel): def start(channel):
global chat_client global chat_client


+ 8
- 0
plugins/godcmd/godcmd.py View File

@@ -475,3 +475,11 @@ class Godcmd(Plugin):
if model == "gpt-4-turbo": if model == "gpt-4-turbo":
return const.GPT4_TURBO_PREVIEW return const.GPT4_TURBO_PREVIEW
return model return model

def reload(self):
gconf = plugin_config[self.name]
if gconf:
if gconf.get("password"):
self.password = gconf["password"]
if gconf.get("admin_users"):
self.admin_users = gconf["admin_users"]

+ 3
- 0
plugins/plugin.py View File

@@ -46,3 +46,6 @@ class Plugin:


def get_help_text(self, **kwargs): def get_help_text(self, **kwargs):
return "暂无帮助信息" return "暂无帮助信息"

def reload(self):
pass

Loading…
Cancel
Save