diff --git a/.gitignore b/.gitignore index 9b3bcdf..560e615 100644 --- a/.gitignore +++ b/.gitignore @@ -29,4 +29,5 @@ plugins/banwords/lib/__pycache__ !plugins/hello !plugins/role !plugins/keyword -!plugins/linkai \ No newline at end of file +!plugins/linkai +client_config.json diff --git a/app.py b/app.py index 19acdcd..3c337b6 100644 --- a/app.py +++ b/app.py @@ -8,6 +8,7 @@ from channel import channel_factory from common import const from config import load_config from plugins import * +import threading def sigterm_handler_wrap(_signo): @@ -46,8 +47,16 @@ def run(): if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework", const.FEISHU]: PluginManager().load_plugins() + if conf().get("use_linkai"): + try: + from common import linkai_client + threading.Thread(target=linkai_client.start, args=(channel, )).start() + except Exception as e: + pass + # startup channel channel.startup() + except Exception as e: logger.error("App startup failed!") logger.exception(e) diff --git a/bot/linkai/link_ai_bot.py b/bot/linkai/link_ai_bot.py index c9afa71..fc22ba8 100644 --- a/bot/linkai/link_ai_bot.py +++ b/bot/linkai/link_ai_bot.py @@ -17,7 +17,6 @@ import threading from common import memory, utils import base64 - class LinkAIBot(Bot): # authentication failed AUTH_FAILED_CODE = 401 @@ -84,7 +83,6 @@ class LinkAIBot(Bot): if session_message[0].get("role") == "system": if app_code or model == "wenxin": session_message.pop(0) - body = { "app_code": app_code, "messages": session_message, @@ -93,7 +91,25 @@ class LinkAIBot(Bot): "top_p": conf().get("top_p", 1), "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 + "session_id": session_id, + "channel_type": conf().get("channel_type") } + try: + from linkai import LinkAIClient + client_id = LinkAIClient.fetch_client_id() + if client_id: + body["client_id"] = client_id + # start: client info deliver + if context.kwargs.get("msg"): + body["session_id"] = context.kwargs.get("msg").from_user_id + if context.kwargs.get("msg").is_group: + body["is_group"] = True + body["group_name"] = context.kwargs.get("msg").from_user_nickname + body["sender_name"] = context.kwargs.get("msg").actual_user_nickname + else: + body["sender_name"] = context.kwargs.get("msg").from_user_nickname + except Exception as e: + pass file_id = context.kwargs.get("file_id") if file_id: body["file_id"] = file_id @@ -230,7 +246,7 @@ class LinkAIBot(Bot): } if self.args.get("max_tokens"): body["max_tokens"] = self.args.get("max_tokens") - headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")} + headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")} # do http request base_url = conf().get("linkai_api_base", "https://api.link-ai.chat") diff --git a/channel/channel.py b/channel/channel.py index 6464d77..c225342 100644 --- a/channel/channel.py +++ b/channel/channel.py @@ -8,6 +8,7 @@ from bridge.reply import * class Channel(object): + channel_type = "" NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE] def startup(self): diff --git a/channel/channel_factory.py b/channel/channel_factory.py index 7044b9a..5a49eea 100644 --- a/channel/channel_factory.py +++ b/channel/channel_factory.py @@ -2,43 +2,41 @@ channel factory """ from common import const +from .channel import Channel -def create_channel(channel_type): + +def create_channel(channel_type) -> Channel: """ create a channel instance :param channel_type: channel type code :return: channel instance """ + ch = Channel() if channel_type == "wx": from channel.wechat.wechat_channel import WechatChannel - - return WechatChannel() + ch = WechatChannel() elif channel_type == "wxy": from channel.wechat.wechaty_channel import WechatyChannel - - return WechatyChannel() + ch = WechatyChannel() elif channel_type == "terminal": from channel.terminal.terminal_channel import TerminalChannel - - return TerminalChannel() + ch = TerminalChannel() elif channel_type == "wechatmp": from channel.wechatmp.wechatmp_channel import WechatMPChannel - - return WechatMPChannel(passive_reply=True) + ch = WechatMPChannel(passive_reply=True) elif channel_type == "wechatmp_service": from channel.wechatmp.wechatmp_channel import WechatMPChannel - - return WechatMPChannel(passive_reply=False) + ch = WechatMPChannel(passive_reply=False) elif channel_type == "wechatcom_app": from channel.wechatcom.wechatcomapp_channel import WechatComAppChannel - - return WechatComAppChannel() + ch = WechatComAppChannel() elif channel_type == "wework": from channel.wework.wework_channel import WeworkChannel - return WeworkChannel() - + ch = WeworkChannel() elif channel_type == const.FEISHU: from channel.feishu.feishu_channel import FeiShuChanel - return FeiShuChanel() - - raise RuntimeError + ch = FeiShuChanel() + else: + raise RuntimeError + ch.channel_type = channel_type + return ch diff --git a/channel/feishu/feishu_channel.py b/channel/feishu/feishu_channel.py index 85e40d7..76fbbf1 100644 --- a/channel/feishu/feishu_channel.py +++ b/channel/feishu/feishu_channel.py @@ -51,10 +51,14 @@ class FeiShuChanel(ChatChannel): web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port)) def send(self, reply: Reply, context: Context): - msg = context["msg"] + msg = context.get("msg") is_group = context["isgroup"] + if msg: + access_token = msg.access_token + else: + access_token = self.fetch_access_token() headers = { - "Authorization": "Bearer " + msg.access_token, + "Authorization": "Bearer " + access_token, "Content-Type": "application/json", } msg_type = "text" @@ -63,7 +67,7 @@ class FeiShuChanel(ChatChannel): content_key = "text" if reply.type == ReplyType.IMAGE_URL: # 图片上传 - reply_content = self._upload_image_url(reply.content, msg.access_token) + reply_content = self._upload_image_url(reply.content, access_token) if not reply_content: logger.warning("[FeiShu] upload file failed") return @@ -79,7 +83,7 @@ class FeiShuChanel(ChatChannel): res = requests.post(url=url, headers=headers, json=data, timeout=(5, 10)) else: url = "https://open.feishu.cn/open-apis/im/v1/messages" - params = {"receive_id_type": context.get("receive_id_type")} + params = {"receive_id_type": context.get("receive_id_type") or "open_id"} data = { "receive_id": context.get("receiver"), "msg_type": msg_type, diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index db77d83..5e616f3 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -109,6 +109,7 @@ class WechatChannel(ChatChannel): def __init__(self): super().__init__() self.receivedMsgs = ExpiredDict(60 * 60) + self.auto_login_times = 0 def startup(self): itchat.instance.receivingRetryCount = 600 # 修改断线超时时间 @@ -120,6 +121,8 @@ class WechatChannel(ChatChannel): hotReload=hotReload, statusStorageDir=status_path, qrCallback=qrCallback, + exitCallback=self.exitCallback, + loginCallback=self.loginCallback ) self.user_id = itchat.instance.storageClass.userName self.name = itchat.instance.storageClass.nickName @@ -127,6 +130,14 @@ class WechatChannel(ChatChannel): # start message listener itchat.run() + def exitCallback(self): + self.auto_login_times += 1 + if self.auto_login_times < 100: + self.startup() + + def loginCallback(self): + pass + # handle_* 系列函数处理收到的消息后构造Context,然后传入produce函数中处理Context和发送回复 # Context包含了消息的所有信息,包括以下属性 # type 消息类型, 包括TEXT、VOICE、IMAGE_CREATE diff --git a/common/linkai_client.py b/common/linkai_client.py new file mode 100644 index 0000000..a6469f7 --- /dev/null +++ b/common/linkai_client.py @@ -0,0 +1,28 @@ +from bridge.context import Context, ContextType +from bridge.reply import Reply, ReplyType +from common.log import logger +from linkai import LinkAIClient, PushMsg +from config import conf + + +class ChatClient(LinkAIClient): + def __init__(self, api_key, host, channel): + super().__init__(api_key, host) + self.channel = channel + self.client_type = channel.channel_type + + def on_message(self, push_msg: PushMsg): + session_id = push_msg.session_id + msg_content = push_msg.msg_content + logger.info(f"receive msg push, session_id={session_id}, msg_content={msg_content}") + context = Context() + context.type = ContextType.TEXT + context["receiver"] = session_id + context["isgroup"] = push_msg.is_group + self.channel.send(Reply(ReplyType.TEXT, content=msg_content), context) + + +def start(channel): + client = ChatClient(api_key=conf().get("linkai_api_key"), + host="link-ai.chat", channel=channel) + client.start()