@@ -3,6 +3,7 @@ | |||
import os | |||
import signal | |||
import sys | |||
import time | |||
from channel import channel_factory | |||
from common import const | |||
@@ -24,6 +25,21 @@ def sigterm_handler_wrap(_signo): | |||
signal.signal(_signo, func) | |||
def start_channel(channel_name: str): | |||
channel = channel_factory.create_channel(channel_name) | |||
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework", | |||
const.FEISHU, const.DINGTALK]: | |||
PluginManager().load_plugins() | |||
if conf().get("use_linkai"): | |||
try: | |||
from common import linkai_client | |||
threading.Thread(target=linkai_client.start, args=(channel,)).start() | |||
except Exception as e: | |||
pass | |||
channel.startup() | |||
def run(): | |||
try: | |||
# load config | |||
@@ -41,22 +57,11 @@ def run(): | |||
if channel_name == "wxy": | |||
os.environ["WECHATY_LOG"] = "warn" | |||
# os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001' | |||
channel = channel_factory.create_channel(channel_name) | |||
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework", const.FEISHU,const.DINGTALK]: | |||
PluginManager().load_plugins() | |||
if conf().get("use_linkai"): | |||
try: | |||
from common import linkai_client | |||
threading.Thread(target=linkai_client.start, args=(channel, )).start() | |||
except Exception as e: | |||
pass | |||
# startup channel | |||
channel.startup() | |||
start_channel(channel_name) | |||
while True: | |||
time.sleep(1) | |||
except Exception as e: | |||
logger.error("App startup failed!") | |||
logger.exception(e) | |||
@@ -400,7 +400,7 @@ class LinkAIBot(Bot): | |||
i += 1 | |||
if url.endswith(".mp4"): | |||
reply_type = ReplyType.VIDEO_URL | |||
elif url.endswith(".pdf") or url.endswith(".doc") or url.endswith(".docx"): | |||
elif url.endswith(".pdf") or url.endswith(".doc") or url.endswith(".docx") or url.endswith(".csv"): | |||
reply_type = ReplyType.FILE | |||
url = _download_file(url) | |||
if not url: | |||
@@ -4,6 +4,7 @@ import threading | |||
import time | |||
from asyncio import CancelledError | |||
from concurrent.futures import Future, ThreadPoolExecutor | |||
from concurrent import futures | |||
from bridge.context import * | |||
from bridge.reply import * | |||
@@ -17,6 +18,8 @@ try: | |||
except Exception as e: | |||
pass | |||
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池 | |||
# 抽象类, 它包含了与消息通道无关的通用处理逻辑 | |||
class ChatChannel(Channel): | |||
@@ -25,7 +28,6 @@ class ChatChannel(Channel): | |||
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消 | |||
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理 | |||
lock = threading.Lock() # 用于控制对sessions的访问 | |||
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池 | |||
def __init__(self): | |||
_thread = threading.Thread(target=self.consume) | |||
@@ -339,7 +341,7 @@ class ChatChannel(Channel): | |||
if not context_queue.empty(): | |||
context = context_queue.get() | |||
logger.debug("[WX] consume context: {}".format(context)) | |||
future: Future = self.handler_pool.submit(self._handle, context) | |||
future: Future = handler_pool.submit(self._handle, context) | |||
future.add_done_callback(self._thread_pool_callback(session_id, context=context)) | |||
if session_id not in self.futures: | |||
self.futures[session_id] = [] | |||
@@ -15,6 +15,7 @@ import requests | |||
from bridge.context import * | |||
from bridge.reply import * | |||
from channel.chat_channel import ChatChannel | |||
from channel import chat_channel | |||
from channel.wechat.wechat_message import * | |||
from common.expired_dict import ExpiredDict | |||
from common.log import logger | |||
@@ -112,30 +113,39 @@ class WechatChannel(ChatChannel): | |||
self.auto_login_times = 0 | |||
def startup(self): | |||
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间 | |||
# login by scan QRCode | |||
hotReload = conf().get("hot_reload", False) | |||
status_path = os.path.join(get_appdata_dir(), "itchat.pkl") | |||
itchat.auto_login( | |||
enableCmdQR=2, | |||
hotReload=hotReload, | |||
statusStorageDir=status_path, | |||
qrCallback=qrCallback, | |||
exitCallback=self.exitCallback, | |||
loginCallback=self.loginCallback | |||
) | |||
self.user_id = itchat.instance.storageClass.userName | |||
self.name = itchat.instance.storageClass.nickName | |||
logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name)) | |||
# start message listener | |||
itchat.run() | |||
try: | |||
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间 | |||
# login by scan QRCode | |||
hotReload = conf().get("hot_reload", False) | |||
status_path = os.path.join(get_appdata_dir(), "itchat.pkl") | |||
itchat.auto_login( | |||
enableCmdQR=2, | |||
hotReload=hotReload, | |||
statusStorageDir=status_path, | |||
qrCallback=qrCallback, | |||
exitCallback=self.exitCallback, | |||
loginCallback=self.loginCallback | |||
) | |||
self.user_id = itchat.instance.storageClass.userName | |||
self.name = itchat.instance.storageClass.nickName | |||
logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name)) | |||
# start message listener | |||
itchat.run() | |||
except Exception as e: | |||
logger.error(e) | |||
def exitCallback(self): | |||
_send_logout() | |||
time.sleep(3) | |||
self.auto_login_times += 1 | |||
if self.auto_login_times < 100: | |||
self.startup() | |||
try: | |||
from common.linkai_client import chat_client | |||
if chat_client.client_id and conf().get("use_linkai"): | |||
_send_logout() | |||
time.sleep(2) | |||
self.auto_login_times += 1 | |||
if self.auto_login_times < 100: | |||
chat_channel.handler_pool._shutdown = False | |||
self.startup() | |||
except Exception as e: | |||
pass | |||
def loginCallback(self): | |||
logger.debug("Login success") | |||
@@ -259,7 +269,6 @@ def _send_login_success(): | |||
def _send_logout(): | |||
try: | |||
from common.linkai_client import chat_client | |||
time.sleep(2) | |||
if chat_client.client_id: | |||
chat_client.send_logout() | |||
except Exception as e: | |||
@@ -268,7 +277,6 @@ def _send_logout(): | |||
def _send_qr_code(qrcode_list: list): | |||
try: | |||
from common.linkai_client import chat_client | |||
time.sleep(2) | |||
if chat_client.client_id: | |||
chat_client.send_qrcode(qrcode_list) | |||
except Exception as e: | |||
@@ -2,7 +2,9 @@ from bridge.context import Context, ContextType | |||
from bridge.reply import Reply, ReplyType | |||
from common.log import logger | |||
from linkai import LinkAIClient, PushMsg | |||
from config import conf | |||
from config import conf, pconf, plugin_config | |||
from plugins import PluginManager | |||
chat_client: LinkAIClient | |||
@@ -22,6 +24,29 @@ class ChatClient(LinkAIClient): | |||
context["isgroup"] = push_msg.is_group | |||
self.channel.send(Reply(ReplyType.TEXT, content=msg_content), context) | |||
def on_config(self, config: dict): | |||
if not self.client_id: | |||
return | |||
logger.info(f"从控制台加载配置: {config}") | |||
local_config = conf() | |||
for key in local_config.keys(): | |||
if config.get(key) is not None: | |||
local_config[key] = config.get(key) | |||
if config.get("reply_voice_mode"): | |||
if config.get("reply_voice_mode") == "voice_reply_voice": | |||
local_config["voice_reply_voice"] = True | |||
elif config.get("reply_voice_mode") == "always_reply_voice": | |||
local_config["always_reply_voice"] = True | |||
# if config.get("admin_password") and plugin_config["Godcmd"]: | |||
# plugin_config["Godcmd"]["password"] = config.get("admin_password") | |||
# PluginManager().instances["Godcmd"].reload() | |||
# if config.get("group_app_map") and pconf("linkai"): | |||
# local_group_map = {} | |||
# for mapping in config.get("group_app_map"): | |||
# local_group_map[mapping.get("group_name")] = mapping.get("app_code") | |||
# pconf("linkai")["group_app_map"] = local_group_map | |||
# PluginManager().instances["linkai"].reload() | |||
def start(channel): | |||
global chat_client | |||
@@ -475,3 +475,11 @@ class Godcmd(Plugin): | |||
if model == "gpt-4-turbo": | |||
return const.GPT4_TURBO_PREVIEW | |||
return model | |||
def reload(self): | |||
gconf = plugin_config[self.name] | |||
if gconf: | |||
if gconf.get("password"): | |||
self.password = gconf["password"] | |||
if gconf.get("admin_users"): | |||
self.admin_users = gconf["admin_users"] |
@@ -46,3 +46,6 @@ class Plugin: | |||
def get_help_text(self, **kwargs): | |||
return "暂无帮助信息" | |||
def reload(self): | |||
pass |