From a086f1989f6b0c43bc063168f22281d1d05d64d2 Mon Sep 17 00:00:00 2001 From: zhayujie Date: Fri, 25 Aug 2023 16:06:55 +0800 Subject: [PATCH] feat: add xunfei spark bot --- README.md | 15 +- bot/bot_factory.py | 10 +- bot/chatgpt/chat_gpt_session.py | 15 +- bot/xunfei/xunfei_spark_bot.py | 246 ++++++++++++++++++++++++++++++++ bridge/bridge.py | 2 + common/const.py | 1 + config.py | 12 +- plugins/godcmd/godcmd.py | 5 +- 8 files changed, 290 insertions(+), 16 deletions(-) create mode 100644 bot/xunfei/xunfei_spark_bot.py diff --git a/README.md b/README.md index 8bbd9c0..934a944 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ 最新版本支持的功能如下: - [x] **多端部署:** 有多种部署方式可选择且功能完备,目前已支持个人微信,微信公众号和企业微信应用等部署方式 -- [x] **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3, GPT-3.5, GPT-4, 文心一言模型 +- [x] **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3, GPT-3.5, GPT-4, 文心一言, 讯飞星火 - [x] **语音识别:** 可识别语音消息,通过文字或语音回复,支持 azure, baidu, google, openai等多种语音模型 - [x] **图片生成:** 支持图片生成 和 图生图(如照片修复),可选择 Dell-E, stable diffusion, replicate, midjourney模型 - [x] **丰富插件:** 支持个性化插件扩展,已实现多角色切换、文字冒险、敏感词过滤、聊天记录总结等插件 @@ -113,7 +113,7 @@ pip3 install azure-cognitiveservices-speech # config.json文件内容示例 { "open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY - "model": "gpt-3.5-turbo", # 模型名称。当use_azure_chatgpt为true时,其名称为Azure上model deployment名称 + "model": "gpt-3.5-turbo", # 模型名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei "proxy": "", # 代理客户端的ip和端口,国内环境开启代理的需要填写该项,如 "127.0.0.1:7890" "single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复 "single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人 @@ -129,7 +129,10 @@ pip3 install azure-cognitiveservices-speech "azure_api_version": "", # 采用Azure ChatGPT时,API版本 "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述 # 订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复,可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。 - "subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持角色扮演和文字冒险等丰富插件。\n输入{trigger_prefix}#help 查看详细指令。" + "subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持角色扮演和文字冒险等丰富插件。\n输入{trigger_prefix}#help 查看详细指令。", + "use_linkai": false, # 是否使用LinkAI接口,默认关闭,开启后可国内访问,使用知识库和MJ + "linkai_api_key": "", # LinkAI Api Key + "linkai_app_code": "" # LinkAI 应用code } ``` **配置说明:** @@ -166,6 +169,12 @@ pip3 install azure-cognitiveservices-speech + `character_desc` 配置中保存着你对机器人说的一段话,他会记住这段话并作为他的设定,你可以为他定制任何人格 (关于会话上下文的更多内容参考该 [issue](https://github.com/zhayujie/chatgpt-on-wechat/issues/43)) + `subscribe_msg`:订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复, 可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。 +**5.LinkAI配置 (可选)** + ++ `use_linkai`: 是否使用LinkAI接口,开启后可国内访问,使用知识库和 `Midjourney` 绘画, 参考 [文档](https://link-ai.tech/platform/link-app/wechat) ++ `linkai_api_key`: LinkAI Api Key,可在 [控制台](https://chat.link-ai.tech/console/interface) 创建 ++ `linkai_app_code`: LinkAI 应用code,选填 + **本说明文档可能会未及时更新,当前所有可选的配置项均在该[`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py)中列出。** ## 运行 diff --git a/bot/bot_factory.py b/bot/bot_factory.py index e0e07e4..513eb78 100644 --- a/bot/bot_factory.py +++ b/bot/bot_factory.py @@ -14,31 +14,29 @@ def create_bot(bot_type): # 替换Baidu Unit为Baidu文心千帆对话接口 # from bot.baidu.baidu_unit_bot import BaiduUnitBot # return BaiduUnitBot() - from bot.baidu.baidu_wenxin import BaiduWenxinBot - return BaiduWenxinBot() elif bot_type == const.CHATGPT: # ChatGPT 网页端web接口 from bot.chatgpt.chat_gpt_bot import ChatGPTBot - return ChatGPTBot() elif bot_type == const.OPEN_AI: # OpenAI 官方对话模型API from bot.openai.open_ai_bot import OpenAIBot - return OpenAIBot() elif bot_type == const.CHATGPTONAZURE: # Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/ from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot - return AzureChatGPTBot() + elif bot_type == const.XUNFEI: + from bot.xunfei.xunfei_spark_bot import XunFeiBot + return XunFeiBot() + elif bot_type == const.LINKAI: from bot.linkai.link_ai_bot import LinkAIBot return LinkAIBot() - raise RuntimeError diff --git a/bot/chatgpt/chat_gpt_session.py b/bot/chatgpt/chat_gpt_session.py index 299ae19..794dac3 100644 --- a/bot/chatgpt/chat_gpt_session.py +++ b/bot/chatgpt/chat_gpt_session.py @@ -55,11 +55,16 @@ class ChatGPTSession(Session): # refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb def num_tokens_from_messages(messages, model): """Returns the number of tokens used by a list of messages.""" + + if model in ["wenxin", "xunfei"]: + return num_tokens_by_character(messages) + import tiktoken if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo"]: return num_tokens_from_messages(messages, model="gpt-3.5-turbo") - elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k"]: + elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k"]: return num_tokens_from_messages(messages, model="gpt-4") try: @@ -85,3 +90,11 @@ def num_tokens_from_messages(messages, model): num_tokens += tokens_per_name num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> return num_tokens + + +def num_tokens_by_character(messages): + """Returns the number of tokens used by a list of messages.""" + tokens = 0 + for msg in messages: + tokens += len(msg["content"]) + return tokens diff --git a/bot/xunfei/xunfei_spark_bot.py b/bot/xunfei/xunfei_spark_bot.py new file mode 100644 index 0000000..c1564b2 --- /dev/null +++ b/bot/xunfei/xunfei_spark_bot.py @@ -0,0 +1,246 @@ +# encoding:utf-8 + +import requests, json +from bot.bot import Bot +from bot.session_manager import SessionManager +from bot.baidu.baidu_wenxin_session import BaiduWenxinSession +from bridge.context import ContextType, Context +from bridge.reply import Reply, ReplyType +from common.log import logger +from config import conf +from common import const +import time +import _thread as thread +import datetime +from datetime import datetime +from wsgiref.handlers import format_date_time +from urllib.parse import urlencode +import base64 +import ssl +import hashlib +import hmac +import json +from time import mktime +from urllib.parse import urlparse +import websocket +import queue +import threading +import random + +# 消息队列 map +queue_map = dict() + + +class XunFeiBot(Bot): + def __init__(self): + super().__init__() + self.app_id = conf().get("xunfei_app_id") + self.api_key = conf().get("xunfei_api_key") + self.api_secret = conf().get("xunfei_api_secret") + # 默认使用v2.0版本,1.5版本可设置为 general + self.domain = "generalv2" + # 默认使用v2.0版本,1.5版本可设置为 "ws://spark-api.xf-yun.com/v1.1/chat" + self.spark_url = "ws://spark-api.xf-yun.com/v2.1/chat" + self.host = urlparse(self.spark_url).netloc + self.path = urlparse(self.spark_url).path + self.answer = "" + # 和wenxin使用相同的session机制 + self.sessions = SessionManager(BaiduWenxinSession, model=const.XUNFEI) + + def reply(self, query, context: Context = None) -> Reply: + if context.type == ContextType.TEXT: + logger.info("[XunFei] query={}".format(query)) + session_id = context["session_id"] + request_id = self.gen_request_id(session_id) + session = self.sessions.session_query(query, session_id) + threading.Thread(target=self.create_web_socket, args=(session.messages, request_id)).start() + depth = 0 + time.sleep(0.1) + t1 = time.time() + usage = {} + while depth <= 300: + try: + data_queue = queue_map.get(request_id) + if not data_queue: + depth += 1 + time.sleep(0.1) + continue + data_item = data_queue.get(block=True, timeout=0.1) + if data_item.is_end: + # 请求结束 + del queue_map[request_id] + if data_item.reply: + self.answer += data_item.reply + usage = data_item.usage + break + + self.answer += data_item.reply + depth += 1 + except Exception as e: + depth += 1 + continue + t2 = time.time() + logger.info(f"[XunFei-API] response={self.answer}, time={t2 - t1}s, usage={usage}") + self.sessions.session_reply(self.answer, session_id, usage.get("total_tokens")) + reply = Reply(ReplyType.TEXT, self.answer) + return reply + else: + reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type)) + return reply + + def create_web_socket(self, prompt, session_id, temperature=0.5): + logger.info(f"[XunFei] start connect, prompt={prompt}") + websocket.enableTrace(False) + wsUrl = self.create_url() + ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, + on_open=on_open) + data_queue = queue.Queue(1000) + queue_map[session_id] = data_queue + ws.appid = self.app_id + ws.question = prompt + ws.domain = self.domain + ws.session_id = session_id + ws.temperature = temperature + ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) + + def gen_request_id(self, session_id: str): + return session_id + "_" + str(int(time.time())) + "" + str(random.randint(0, 100)) + + # 生成url + def create_url(self): + # 生成RFC1123格式的时间戳 + now = datetime.now() + date = format_date_time(mktime(now.timetuple())) + + # 拼接字符串 + signature_origin = "host: " + self.host + "\n" + signature_origin += "date: " + date + "\n" + signature_origin += "GET " + self.path + " HTTP/1.1" + + # 进行hmac-sha256进行加密 + signature_sha = hmac.new(self.api_secret.encode('utf-8'), signature_origin.encode('utf-8'), + digestmod=hashlib.sha256).digest() + + signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8') + + authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", ' \ + f'signature="{signature_sha_base64}"' + + authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') + + # 将请求的鉴权参数组合为字典 + v = { + "authorization": authorization, + "date": date, + "host": self.host + } + # 拼接鉴权参数,生成url + url = self.spark_url + '?' + urlencode(v) + # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致 + return url + + def gen_params(self, appid, domain, question): + """ + 通过appid和用户的提问来生成请参数 + """ + data = { + "header": { + "app_id": appid, + "uid": "1234" + }, + "parameter": { + "chat": { + "domain": domain, + "random_threshold": 0.5, + "max_tokens": 2048, + "auditing": "default" + } + }, + "payload": { + "message": { + "text": question + } + } + } + return data + + +class ReplyItem: + def __init__(self, reply, usage=None, is_end=False): + self.is_end = is_end + self.reply = reply + self.usage = usage + + +# 收到websocket错误的处理 +def on_error(ws, error): + logger.error("[XunFei] error:", error) + + +# 收到websocket关闭的处理 +def on_close(ws, one, two): + data_queue = queue_map.get(ws.session_id) + data_queue.put("END") + + +# 收到websocket连接建立的处理 +def on_open(ws): + logger.info(f"[XunFei] Start websocket, session_id={ws.session_id}") + thread.start_new_thread(run, (ws,)) + + +def run(ws, *args): + data = json.dumps(gen_params(appid=ws.appid, domain=ws.domain, question=ws.question, temperature=ws.temperature)) + ws.send(data) + + +# Websocket 操作 +# 收到websocket消息的处理 +def on_message(ws, message): + data = json.loads(message) + code = data['header']['code'] + if code != 0: + logger.error(f'请求错误: {code}, {data}') + ws.close() + else: + choices = data["payload"]["choices"] + status = choices["status"] + content = choices["text"][0]["content"] + data_queue = queue_map.get(ws.session_id) + if not data_queue: + logger.error(f"[XunFei] can't find data queue, session_id={ws.session_id}") + return + reply_item = ReplyItem(content) + if status == 2: + usage = data["payload"].get("usage") + reply_item = ReplyItem(content, usage) + reply_item.is_end = True + ws.close() + data_queue.put(reply_item) + + +def gen_params(appid, domain, question, temperature=0.5): + """ + 通过appid和用户的提问来生成请参数 + """ + data = { + "header": { + "app_id": appid, + "uid": "1234" + }, + "parameter": { + "chat": { + "domain": domain, + "temperature": temperature, + "random_threshold": 0.5, + "max_tokens": 2048, + "auditing": "default" + } + }, + "payload": { + "message": { + "text": question + } + } + } + return data diff --git a/bridge/bridge.py b/bridge/bridge.py index c58b975..2022438 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -25,6 +25,8 @@ class Bridge(object): self.btype["chat"] = const.CHATGPTONAZURE if model_type in ["wenxin"]: self.btype["chat"] = const.BAIDU + if model_type in ["xunfei"]: + self.btype["chat"] = const.XUNFEI if conf().get("use_linkai") and conf().get("linkai_api_key"): self.btype["chat"] = const.LINKAI self.bots = {} diff --git a/common/const.py b/common/const.py index 481741d..4e87034 100644 --- a/common/const.py +++ b/common/const.py @@ -2,6 +2,7 @@ OPEN_AI = "openAI" CHATGPT = "chatGPT" BAIDU = "baidu" +XUNFEI = "xunfei" CHATGPTONAZURE = "chatGPTOnAzure" LINKAI = "linkai" diff --git a/config.py b/config.py index 373dd42..048f667 100644 --- a/config.py +++ b/config.py @@ -16,7 +16,7 @@ available_setting = { "open_ai_api_base": "https://api.openai.com/v1", "proxy": "", # openai使用的代理 # chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称 - "model": "gpt-3.5-turbo", # 还支持 gpt-3.5-turbo-16k, gpt-4, wenxin + "model": "gpt-3.5-turbo", # 还支持 gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei "use_azure_chatgpt": False, # 是否使用azure的chatgpt "azure_deployment_id": "", # azure 模型部署名称 "azure_api_version": "", # azure api版本 @@ -52,9 +52,13 @@ available_setting = { "request_timeout": 60, # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 "timeout": 120, # chatgpt重试超时时间,在这个时间内,将会自动重试 # Baidu 文心一言参数 - "baidu_wenxin_model": "eb-instant", # 默认使用ERNIE-Bot-turbo模型 - "baidu_wenxin_api_key": "", # Baidu api key - "baidu_wenxin_secret_key": "", # Baidu secret key + "baidu_wenxin_model": "eb-instant", # 默认使用ERNIE-Bot-turbo模型 + "baidu_wenxin_api_key": "", # Baidu api key + "baidu_wenxin_secret_key": "", # Baidu secret key + # 讯飞星火API + "xunfei_app_id": "", # 讯飞应用ID + "xunfei_api_key": "", # 讯飞 API key + "xunfei_api_secret": "", # 讯飞 API secret # 语音设置 "speech_recognition": False, # 是否开启语音识别 "group_speech_recognition": False, # 是否开启群组语音识别 diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index 0ff204e..61830ba 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -294,7 +294,7 @@ class Godcmd(Plugin): except Exception as e: ok, result = False, "你没有设置私有GPT模型" elif cmd == "reset": - if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI]: + if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI]: bot.sessions.clear_session(session_id) channel.cancel_session(session_id) ok, result = True, "会话已重置" @@ -317,7 +317,8 @@ class Godcmd(Plugin): load_config() ok, result = True, "配置已重载" elif cmd == "resetall": - if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI]: + if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, + const.BAIDU, const.XUNFEI]: channel.cancel_all_session() bot.sessions.clear_all_session() ok, result = True, "重置所有会话成功"