diff --git a/app.py b/app.py index 7f766d3..ff2a6c7 100644 --- a/app.py +++ b/app.py @@ -3,6 +3,7 @@ import os import signal import sys +import time from channel import channel_factory from common import const @@ -24,6 +25,21 @@ def sigterm_handler_wrap(_signo): signal.signal(_signo, func) +def start_channel(channel_name: str): + channel = channel_factory.create_channel(channel_name) + if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework", + const.FEISHU, const.DINGTALK]: + PluginManager().load_plugins() + + if conf().get("use_linkai"): + try: + from common import linkai_client + threading.Thread(target=linkai_client.start, args=(channel,)).start() + except Exception as e: + pass + channel.startup() + + def run(): try: # load config @@ -41,22 +57,11 @@ def run(): if channel_name == "wxy": os.environ["WECHATY_LOG"] = "warn" - # os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001' - - channel = channel_factory.create_channel(channel_name) - if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework", const.FEISHU,const.DINGTALK]: - PluginManager().load_plugins() - - if conf().get("use_linkai"): - try: - from common import linkai_client - threading.Thread(target=linkai_client.start, args=(channel, )).start() - except Exception as e: - pass - # startup channel - channel.startup() + start_channel(channel_name) + while True: + time.sleep(1) except Exception as e: logger.error("App startup failed!") logger.exception(e) diff --git a/bot/linkai/link_ai_bot.py b/bot/linkai/link_ai_bot.py index a0b92c1..d37d82a 100644 --- a/bot/linkai/link_ai_bot.py +++ b/bot/linkai/link_ai_bot.py @@ -400,7 +400,7 @@ class LinkAIBot(Bot): i += 1 if url.endswith(".mp4"): reply_type = ReplyType.VIDEO_URL - elif url.endswith(".pdf") or url.endswith(".doc") or url.endswith(".docx"): + elif url.endswith(".pdf") or url.endswith(".doc") or url.endswith(".docx") or url.endswith(".csv"): reply_type = ReplyType.FILE url = _download_file(url) if not url: diff --git a/channel/chat_channel.py b/channel/chat_channel.py index ba017af..fe71207 100644 --- a/channel/chat_channel.py +++ b/channel/chat_channel.py @@ -4,6 +4,7 @@ import threading import time from asyncio import CancelledError from concurrent.futures import Future, ThreadPoolExecutor +from concurrent import futures from bridge.context import * from bridge.reply import * @@ -17,6 +18,8 @@ try: except Exception as e: pass +handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池 + # 抽象类, 它包含了与消息通道无关的通用处理逻辑 class ChatChannel(Channel): @@ -25,7 +28,6 @@ class ChatChannel(Channel): futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消 sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理 lock = threading.Lock() # 用于控制对sessions的访问 - handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池 def __init__(self): _thread = threading.Thread(target=self.consume) @@ -339,7 +341,7 @@ class ChatChannel(Channel): if not context_queue.empty(): context = context_queue.get() logger.debug("[WX] consume context: {}".format(context)) - future: Future = self.handler_pool.submit(self._handle, context) + future: Future = handler_pool.submit(self._handle, context) future.add_done_callback(self._thread_pool_callback(session_id, context=context)) if session_id not in self.futures: self.futures[session_id] = [] diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index fe14daa..717b068 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -15,6 +15,7 @@ import requests from bridge.context import * from bridge.reply import * from channel.chat_channel import ChatChannel +from channel import chat_channel from channel.wechat.wechat_message import * from common.expired_dict import ExpiredDict from common.log import logger @@ -112,30 +113,39 @@ class WechatChannel(ChatChannel): self.auto_login_times = 0 def startup(self): - itchat.instance.receivingRetryCount = 600 # 修改断线超时时间 - # login by scan QRCode - hotReload = conf().get("hot_reload", False) - status_path = os.path.join(get_appdata_dir(), "itchat.pkl") - itchat.auto_login( - enableCmdQR=2, - hotReload=hotReload, - statusStorageDir=status_path, - qrCallback=qrCallback, - exitCallback=self.exitCallback, - loginCallback=self.loginCallback - ) - self.user_id = itchat.instance.storageClass.userName - self.name = itchat.instance.storageClass.nickName - logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name)) - # start message listener - itchat.run() + try: + itchat.instance.receivingRetryCount = 600 # 修改断线超时时间 + # login by scan QRCode + hotReload = conf().get("hot_reload", False) + status_path = os.path.join(get_appdata_dir(), "itchat.pkl") + itchat.auto_login( + enableCmdQR=2, + hotReload=hotReload, + statusStorageDir=status_path, + qrCallback=qrCallback, + exitCallback=self.exitCallback, + loginCallback=self.loginCallback + ) + self.user_id = itchat.instance.storageClass.userName + self.name = itchat.instance.storageClass.nickName + logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name)) + # start message listener + itchat.run() + except Exception as e: + logger.error(e) def exitCallback(self): - _send_logout() - time.sleep(3) - self.auto_login_times += 1 - if self.auto_login_times < 100: - self.startup() + try: + from common.linkai_client import chat_client + if chat_client.client_id and conf().get("use_linkai"): + _send_logout() + time.sleep(2) + self.auto_login_times += 1 + if self.auto_login_times < 100: + chat_channel.handler_pool._shutdown = False + self.startup() + except Exception as e: + pass def loginCallback(self): logger.debug("Login success") @@ -259,7 +269,6 @@ def _send_login_success(): def _send_logout(): try: from common.linkai_client import chat_client - time.sleep(2) if chat_client.client_id: chat_client.send_logout() except Exception as e: @@ -268,7 +277,6 @@ def _send_logout(): def _send_qr_code(qrcode_list: list): try: from common.linkai_client import chat_client - time.sleep(2) if chat_client.client_id: chat_client.send_qrcode(qrcode_list) except Exception as e: diff --git a/common/linkai_client.py b/common/linkai_client.py index 80f3168..ad7d213 100644 --- a/common/linkai_client.py +++ b/common/linkai_client.py @@ -2,7 +2,9 @@ from bridge.context import Context, ContextType from bridge.reply import Reply, ReplyType from common.log import logger from linkai import LinkAIClient, PushMsg -from config import conf +from config import conf, pconf, plugin_config +from plugins import PluginManager + chat_client: LinkAIClient @@ -22,6 +24,29 @@ class ChatClient(LinkAIClient): context["isgroup"] = push_msg.is_group self.channel.send(Reply(ReplyType.TEXT, content=msg_content), context) + def on_config(self, config: dict): + if not self.client_id: + return + logger.info(f"从控制台加载配置: {config}") + local_config = conf() + for key in local_config.keys(): + if config.get(key) is not None: + local_config[key] = config.get(key) + if config.get("reply_voice_mode"): + if config.get("reply_voice_mode") == "voice_reply_voice": + local_config["voice_reply_voice"] = True + elif config.get("reply_voice_mode") == "always_reply_voice": + local_config["always_reply_voice"] = True + # if config.get("admin_password") and plugin_config["Godcmd"]: + # plugin_config["Godcmd"]["password"] = config.get("admin_password") + # PluginManager().instances["Godcmd"].reload() + # if config.get("group_app_map") and pconf("linkai"): + # local_group_map = {} + # for mapping in config.get("group_app_map"): + # local_group_map[mapping.get("group_name")] = mapping.get("app_code") + # pconf("linkai")["group_app_map"] = local_group_map + # PluginManager().instances["linkai"].reload() + def start(channel): global chat_client diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index 1c8b09c..a965a68 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -475,3 +475,11 @@ class Godcmd(Plugin): if model == "gpt-4-turbo": return const.GPT4_TURBO_PREVIEW return model + + def reload(self): + gconf = plugin_config[self.name] + if gconf: + if gconf.get("password"): + self.password = gconf["password"] + if gconf.get("admin_users"): + self.admin_users = gconf["admin_users"] diff --git a/plugins/plugin.py b/plugins/plugin.py index 801997b..f4c9618 100644 --- a/plugins/plugin.py +++ b/plugins/plugin.py @@ -46,3 +46,6 @@ class Plugin: def get_help_text(self, **kwargs): return "暂无帮助信息" + + def reload(self): + pass