@@ -29,4 +29,5 @@ plugins/banwords/lib/__pycache__ | |||
!plugins/hello | |||
!plugins/role | |||
!plugins/keyword | |||
!plugins/linkai | |||
!plugins/linkai | |||
client_config.json |
@@ -8,6 +8,7 @@ from channel import channel_factory | |||
from common import const | |||
from config import load_config | |||
from plugins import * | |||
import threading | |||
def sigterm_handler_wrap(_signo): | |||
@@ -46,8 +47,16 @@ def run(): | |||
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework", const.FEISHU]: | |||
PluginManager().load_plugins() | |||
if conf().get("use_linkai"): | |||
try: | |||
from common import linkai_client | |||
threading.Thread(target=linkai_client.start, args=(channel, )).start() | |||
except Exception as e: | |||
pass | |||
# startup channel | |||
channel.startup() | |||
except Exception as e: | |||
logger.error("App startup failed!") | |||
logger.exception(e) | |||
@@ -17,7 +17,6 @@ import threading | |||
from common import memory, utils | |||
import base64 | |||
class LinkAIBot(Bot): | |||
# authentication failed | |||
AUTH_FAILED_CODE = 401 | |||
@@ -84,7 +83,6 @@ class LinkAIBot(Bot): | |||
if session_message[0].get("role") == "system": | |||
if app_code or model == "wenxin": | |||
session_message.pop(0) | |||
body = { | |||
"app_code": app_code, | |||
"messages": session_message, | |||
@@ -93,7 +91,25 @@ class LinkAIBot(Bot): | |||
"top_p": conf().get("top_p", 1), | |||
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | |||
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | |||
"session_id": session_id, | |||
"channel_type": conf().get("channel_type") | |||
} | |||
try: | |||
from linkai import LinkAIClient | |||
client_id = LinkAIClient.fetch_client_id() | |||
if client_id: | |||
body["client_id"] = client_id | |||
# start: client info deliver | |||
if context.kwargs.get("msg"): | |||
body["session_id"] = context.kwargs.get("msg").from_user_id | |||
if context.kwargs.get("msg").is_group: | |||
body["is_group"] = True | |||
body["group_name"] = context.kwargs.get("msg").from_user_nickname | |||
body["sender_name"] = context.kwargs.get("msg").actual_user_nickname | |||
else: | |||
body["sender_name"] = context.kwargs.get("msg").from_user_nickname | |||
except Exception as e: | |||
pass | |||
file_id = context.kwargs.get("file_id") | |||
if file_id: | |||
body["file_id"] = file_id | |||
@@ -230,7 +246,7 @@ class LinkAIBot(Bot): | |||
} | |||
if self.args.get("max_tokens"): | |||
body["max_tokens"] = self.args.get("max_tokens") | |||
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")} | |||
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")} | |||
# do http request | |||
base_url = conf().get("linkai_api_base", "https://api.link-ai.chat") | |||
@@ -8,6 +8,7 @@ from bridge.reply import * | |||
class Channel(object): | |||
channel_type = "" | |||
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE] | |||
def startup(self): | |||
@@ -2,43 +2,41 @@ | |||
channel factory | |||
""" | |||
from common import const | |||
from .channel import Channel | |||
def create_channel(channel_type): | |||
def create_channel(channel_type) -> Channel: | |||
""" | |||
create a channel instance | |||
:param channel_type: channel type code | |||
:return: channel instance | |||
""" | |||
ch = Channel() | |||
if channel_type == "wx": | |||
from channel.wechat.wechat_channel import WechatChannel | |||
return WechatChannel() | |||
ch = WechatChannel() | |||
elif channel_type == "wxy": | |||
from channel.wechat.wechaty_channel import WechatyChannel | |||
return WechatyChannel() | |||
ch = WechatyChannel() | |||
elif channel_type == "terminal": | |||
from channel.terminal.terminal_channel import TerminalChannel | |||
return TerminalChannel() | |||
ch = TerminalChannel() | |||
elif channel_type == "wechatmp": | |||
from channel.wechatmp.wechatmp_channel import WechatMPChannel | |||
return WechatMPChannel(passive_reply=True) | |||
ch = WechatMPChannel(passive_reply=True) | |||
elif channel_type == "wechatmp_service": | |||
from channel.wechatmp.wechatmp_channel import WechatMPChannel | |||
return WechatMPChannel(passive_reply=False) | |||
ch = WechatMPChannel(passive_reply=False) | |||
elif channel_type == "wechatcom_app": | |||
from channel.wechatcom.wechatcomapp_channel import WechatComAppChannel | |||
return WechatComAppChannel() | |||
ch = WechatComAppChannel() | |||
elif channel_type == "wework": | |||
from channel.wework.wework_channel import WeworkChannel | |||
return WeworkChannel() | |||
ch = WeworkChannel() | |||
elif channel_type == const.FEISHU: | |||
from channel.feishu.feishu_channel import FeiShuChanel | |||
return FeiShuChanel() | |||
raise RuntimeError | |||
ch = FeiShuChanel() | |||
else: | |||
raise RuntimeError | |||
ch.channel_type = channel_type | |||
return ch |
@@ -51,10 +51,14 @@ class FeiShuChanel(ChatChannel): | |||
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port)) | |||
def send(self, reply: Reply, context: Context): | |||
msg = context["msg"] | |||
msg = context.get("msg") | |||
is_group = context["isgroup"] | |||
if msg: | |||
access_token = msg.access_token | |||
else: | |||
access_token = self.fetch_access_token() | |||
headers = { | |||
"Authorization": "Bearer " + msg.access_token, | |||
"Authorization": "Bearer " + access_token, | |||
"Content-Type": "application/json", | |||
} | |||
msg_type = "text" | |||
@@ -63,7 +67,7 @@ class FeiShuChanel(ChatChannel): | |||
content_key = "text" | |||
if reply.type == ReplyType.IMAGE_URL: | |||
# 图片上传 | |||
reply_content = self._upload_image_url(reply.content, msg.access_token) | |||
reply_content = self._upload_image_url(reply.content, access_token) | |||
if not reply_content: | |||
logger.warning("[FeiShu] upload file failed") | |||
return | |||
@@ -79,7 +83,7 @@ class FeiShuChanel(ChatChannel): | |||
res = requests.post(url=url, headers=headers, json=data, timeout=(5, 10)) | |||
else: | |||
url = "https://open.feishu.cn/open-apis/im/v1/messages" | |||
params = {"receive_id_type": context.get("receive_id_type")} | |||
params = {"receive_id_type": context.get("receive_id_type") or "open_id"} | |||
data = { | |||
"receive_id": context.get("receiver"), | |||
"msg_type": msg_type, | |||
@@ -109,6 +109,7 @@ class WechatChannel(ChatChannel): | |||
def __init__(self): | |||
super().__init__() | |||
self.receivedMsgs = ExpiredDict(60 * 60) | |||
self.auto_login_times = 0 | |||
def startup(self): | |||
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间 | |||
@@ -120,6 +121,8 @@ class WechatChannel(ChatChannel): | |||
hotReload=hotReload, | |||
statusStorageDir=status_path, | |||
qrCallback=qrCallback, | |||
exitCallback=self.exitCallback, | |||
loginCallback=self.loginCallback | |||
) | |||
self.user_id = itchat.instance.storageClass.userName | |||
self.name = itchat.instance.storageClass.nickName | |||
@@ -127,6 +130,14 @@ class WechatChannel(ChatChannel): | |||
# start message listener | |||
itchat.run() | |||
def exitCallback(self): | |||
self.auto_login_times += 1 | |||
if self.auto_login_times < 100: | |||
self.startup() | |||
def loginCallback(self): | |||
pass | |||
# handle_* 系列函数处理收到的消息后构造Context,然后传入produce函数中处理Context和发送回复 | |||
# Context包含了消息的所有信息,包括以下属性 | |||
# type 消息类型, 包括TEXT、VOICE、IMAGE_CREATE | |||
@@ -0,0 +1,28 @@ | |||
from bridge.context import Context, ContextType | |||
from bridge.reply import Reply, ReplyType | |||
from common.log import logger | |||
from linkai import LinkAIClient, PushMsg | |||
from config import conf | |||
class ChatClient(LinkAIClient): | |||
def __init__(self, api_key, host, channel): | |||
super().__init__(api_key, host) | |||
self.channel = channel | |||
self.client_type = channel.channel_type | |||
def on_message(self, push_msg: PushMsg): | |||
session_id = push_msg.session_id | |||
msg_content = push_msg.msg_content | |||
logger.info(f"receive msg push, session_id={session_id}, msg_content={msg_content}") | |||
context = Context() | |||
context.type = ContextType.TEXT | |||
context["receiver"] = session_id | |||
context["isgroup"] = push_msg.is_group | |||
self.channel.send(Reply(ReplyType.TEXT, content=msg_content), context) | |||
def start(channel): | |||
client = ChatClient(api_key=conf().get("linkai_api_key"), | |||
host="link-ai.chat", channel=channel) | |||
client.start() |