@@ -3,6 +3,7 @@ | |||||
import os | import os | ||||
import signal | import signal | ||||
import sys | import sys | ||||
import time | |||||
from channel import channel_factory | from channel import channel_factory | ||||
from common import const | from common import const | ||||
@@ -24,6 +25,21 @@ def sigterm_handler_wrap(_signo): | |||||
signal.signal(_signo, func) | signal.signal(_signo, func) | ||||
def start_channel(channel_name: str): | |||||
channel = channel_factory.create_channel(channel_name) | |||||
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework", | |||||
const.FEISHU, const.DINGTALK]: | |||||
PluginManager().load_plugins() | |||||
if conf().get("use_linkai"): | |||||
try: | |||||
from common import linkai_client | |||||
threading.Thread(target=linkai_client.start, args=(channel,)).start() | |||||
except Exception as e: | |||||
pass | |||||
channel.startup() | |||||
def run(): | def run(): | ||||
try: | try: | ||||
# load config | # load config | ||||
@@ -41,22 +57,11 @@ def run(): | |||||
if channel_name == "wxy": | if channel_name == "wxy": | ||||
os.environ["WECHATY_LOG"] = "warn" | os.environ["WECHATY_LOG"] = "warn" | ||||
# os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001' | |||||
channel = channel_factory.create_channel(channel_name) | |||||
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework", const.FEISHU,const.DINGTALK]: | |||||
PluginManager().load_plugins() | |||||
if conf().get("use_linkai"): | |||||
try: | |||||
from common import linkai_client | |||||
threading.Thread(target=linkai_client.start, args=(channel, )).start() | |||||
except Exception as e: | |||||
pass | |||||
# startup channel | |||||
channel.startup() | |||||
start_channel(channel_name) | |||||
while True: | |||||
time.sleep(1) | |||||
except Exception as e: | except Exception as e: | ||||
logger.error("App startup failed!") | logger.error("App startup failed!") | ||||
logger.exception(e) | logger.exception(e) | ||||
@@ -400,7 +400,7 @@ class LinkAIBot(Bot): | |||||
i += 1 | i += 1 | ||||
if url.endswith(".mp4"): | if url.endswith(".mp4"): | ||||
reply_type = ReplyType.VIDEO_URL | reply_type = ReplyType.VIDEO_URL | ||||
elif url.endswith(".pdf") or url.endswith(".doc") or url.endswith(".docx"): | |||||
elif url.endswith(".pdf") or url.endswith(".doc") or url.endswith(".docx") or url.endswith(".csv"): | |||||
reply_type = ReplyType.FILE | reply_type = ReplyType.FILE | ||||
url = _download_file(url) | url = _download_file(url) | ||||
if not url: | if not url: | ||||
@@ -4,6 +4,7 @@ import threading | |||||
import time | import time | ||||
from asyncio import CancelledError | from asyncio import CancelledError | ||||
from concurrent.futures import Future, ThreadPoolExecutor | from concurrent.futures import Future, ThreadPoolExecutor | ||||
from concurrent import futures | |||||
from bridge.context import * | from bridge.context import * | ||||
from bridge.reply import * | from bridge.reply import * | ||||
@@ -17,6 +18,8 @@ try: | |||||
except Exception as e: | except Exception as e: | ||||
pass | pass | ||||
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池 | |||||
# 抽象类, 它包含了与消息通道无关的通用处理逻辑 | # 抽象类, 它包含了与消息通道无关的通用处理逻辑 | ||||
class ChatChannel(Channel): | class ChatChannel(Channel): | ||||
@@ -25,7 +28,6 @@ class ChatChannel(Channel): | |||||
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消 | futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消 | ||||
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理 | sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理 | ||||
lock = threading.Lock() # 用于控制对sessions的访问 | lock = threading.Lock() # 用于控制对sessions的访问 | ||||
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池 | |||||
def __init__(self): | def __init__(self): | ||||
_thread = threading.Thread(target=self.consume) | _thread = threading.Thread(target=self.consume) | ||||
@@ -339,7 +341,7 @@ class ChatChannel(Channel): | |||||
if not context_queue.empty(): | if not context_queue.empty(): | ||||
context = context_queue.get() | context = context_queue.get() | ||||
logger.debug("[WX] consume context: {}".format(context)) | logger.debug("[WX] consume context: {}".format(context)) | ||||
future: Future = self.handler_pool.submit(self._handle, context) | |||||
future: Future = handler_pool.submit(self._handle, context) | |||||
future.add_done_callback(self._thread_pool_callback(session_id, context=context)) | future.add_done_callback(self._thread_pool_callback(session_id, context=context)) | ||||
if session_id not in self.futures: | if session_id not in self.futures: | ||||
self.futures[session_id] = [] | self.futures[session_id] = [] | ||||
@@ -15,6 +15,7 @@ import requests | |||||
from bridge.context import * | from bridge.context import * | ||||
from bridge.reply import * | from bridge.reply import * | ||||
from channel.chat_channel import ChatChannel | from channel.chat_channel import ChatChannel | ||||
from channel import chat_channel | |||||
from channel.wechat.wechat_message import * | from channel.wechat.wechat_message import * | ||||
from common.expired_dict import ExpiredDict | from common.expired_dict import ExpiredDict | ||||
from common.log import logger | from common.log import logger | ||||
@@ -112,30 +113,39 @@ class WechatChannel(ChatChannel): | |||||
self.auto_login_times = 0 | self.auto_login_times = 0 | ||||
def startup(self): | def startup(self): | ||||
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间 | |||||
# login by scan QRCode | |||||
hotReload = conf().get("hot_reload", False) | |||||
status_path = os.path.join(get_appdata_dir(), "itchat.pkl") | |||||
itchat.auto_login( | |||||
enableCmdQR=2, | |||||
hotReload=hotReload, | |||||
statusStorageDir=status_path, | |||||
qrCallback=qrCallback, | |||||
exitCallback=self.exitCallback, | |||||
loginCallback=self.loginCallback | |||||
) | |||||
self.user_id = itchat.instance.storageClass.userName | |||||
self.name = itchat.instance.storageClass.nickName | |||||
logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name)) | |||||
# start message listener | |||||
itchat.run() | |||||
try: | |||||
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间 | |||||
# login by scan QRCode | |||||
hotReload = conf().get("hot_reload", False) | |||||
status_path = os.path.join(get_appdata_dir(), "itchat.pkl") | |||||
itchat.auto_login( | |||||
enableCmdQR=2, | |||||
hotReload=hotReload, | |||||
statusStorageDir=status_path, | |||||
qrCallback=qrCallback, | |||||
exitCallback=self.exitCallback, | |||||
loginCallback=self.loginCallback | |||||
) | |||||
self.user_id = itchat.instance.storageClass.userName | |||||
self.name = itchat.instance.storageClass.nickName | |||||
logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name)) | |||||
# start message listener | |||||
itchat.run() | |||||
except Exception as e: | |||||
logger.error(e) | |||||
def exitCallback(self): | def exitCallback(self): | ||||
_send_logout() | |||||
time.sleep(3) | |||||
self.auto_login_times += 1 | |||||
if self.auto_login_times < 100: | |||||
self.startup() | |||||
try: | |||||
from common.linkai_client import chat_client | |||||
if chat_client.client_id and conf().get("use_linkai"): | |||||
_send_logout() | |||||
time.sleep(2) | |||||
self.auto_login_times += 1 | |||||
if self.auto_login_times < 100: | |||||
chat_channel.handler_pool._shutdown = False | |||||
self.startup() | |||||
except Exception as e: | |||||
pass | |||||
def loginCallback(self): | def loginCallback(self): | ||||
logger.debug("Login success") | logger.debug("Login success") | ||||
@@ -259,7 +269,6 @@ def _send_login_success(): | |||||
def _send_logout(): | def _send_logout(): | ||||
try: | try: | ||||
from common.linkai_client import chat_client | from common.linkai_client import chat_client | ||||
time.sleep(2) | |||||
if chat_client.client_id: | if chat_client.client_id: | ||||
chat_client.send_logout() | chat_client.send_logout() | ||||
except Exception as e: | except Exception as e: | ||||
@@ -268,7 +277,6 @@ def _send_logout(): | |||||
def _send_qr_code(qrcode_list: list): | def _send_qr_code(qrcode_list: list): | ||||
try: | try: | ||||
from common.linkai_client import chat_client | from common.linkai_client import chat_client | ||||
time.sleep(2) | |||||
if chat_client.client_id: | if chat_client.client_id: | ||||
chat_client.send_qrcode(qrcode_list) | chat_client.send_qrcode(qrcode_list) | ||||
except Exception as e: | except Exception as e: | ||||
@@ -2,7 +2,9 @@ from bridge.context import Context, ContextType | |||||
from bridge.reply import Reply, ReplyType | from bridge.reply import Reply, ReplyType | ||||
from common.log import logger | from common.log import logger | ||||
from linkai import LinkAIClient, PushMsg | from linkai import LinkAIClient, PushMsg | ||||
from config import conf | |||||
from config import conf, pconf, plugin_config | |||||
from plugins import PluginManager | |||||
chat_client: LinkAIClient | chat_client: LinkAIClient | ||||
@@ -22,6 +24,29 @@ class ChatClient(LinkAIClient): | |||||
context["isgroup"] = push_msg.is_group | context["isgroup"] = push_msg.is_group | ||||
self.channel.send(Reply(ReplyType.TEXT, content=msg_content), context) | self.channel.send(Reply(ReplyType.TEXT, content=msg_content), context) | ||||
def on_config(self, config: dict): | |||||
if not self.client_id: | |||||
return | |||||
logger.info(f"从控制台加载配置: {config}") | |||||
local_config = conf() | |||||
for key in local_config.keys(): | |||||
if config.get(key) is not None: | |||||
local_config[key] = config.get(key) | |||||
if config.get("reply_voice_mode"): | |||||
if config.get("reply_voice_mode") == "voice_reply_voice": | |||||
local_config["voice_reply_voice"] = True | |||||
elif config.get("reply_voice_mode") == "always_reply_voice": | |||||
local_config["always_reply_voice"] = True | |||||
# if config.get("admin_password") and plugin_config["Godcmd"]: | |||||
# plugin_config["Godcmd"]["password"] = config.get("admin_password") | |||||
# PluginManager().instances["Godcmd"].reload() | |||||
# if config.get("group_app_map") and pconf("linkai"): | |||||
# local_group_map = {} | |||||
# for mapping in config.get("group_app_map"): | |||||
# local_group_map[mapping.get("group_name")] = mapping.get("app_code") | |||||
# pconf("linkai")["group_app_map"] = local_group_map | |||||
# PluginManager().instances["linkai"].reload() | |||||
def start(channel): | def start(channel): | ||||
global chat_client | global chat_client | ||||
@@ -475,3 +475,11 @@ class Godcmd(Plugin): | |||||
if model == "gpt-4-turbo": | if model == "gpt-4-turbo": | ||||
return const.GPT4_TURBO_PREVIEW | return const.GPT4_TURBO_PREVIEW | ||||
return model | return model | ||||
def reload(self): | |||||
gconf = plugin_config[self.name] | |||||
if gconf: | |||||
if gconf.get("password"): | |||||
self.password = gconf["password"] | |||||
if gconf.get("admin_users"): | |||||
self.admin_users = gconf["admin_users"] |
@@ -46,3 +46,6 @@ class Plugin: | |||||
def get_help_text(self, **kwargs): | def get_help_text(self, **kwargs): | ||||
return "暂无帮助信息" | return "暂无帮助信息" | ||||
def reload(self): | |||||
pass |