From 1817a972c635f939d47e8a4c88d2437353ba653a Mon Sep 17 00:00:00 2001
From: Kevin Li
Date: Tue, 25 Jul 2023 09:52:47 +0800
Subject: [PATCH] Add Baidu Wenxin Bot

---
 .gitignore                        |  4 ++
 bot/baidu/baidu_wenxin.py         | 96 +++++++++++++++++++++++++++++++
 bot/baidu/baidu_wenxin_session.py | 88 ++++++++++++++++++++++++++++
 bot/bot_factory.py                |  9 ++++++---
 bridge/bridge.py                  |  2 ++
 config.py                         |  5 +++++
 6 files changed, 201 insertions(+), 3 deletions(-)
 create mode 100644 bot/baidu/baidu_wenxin.py
 create mode 100644 bot/baidu/baidu_wenxin_session.py

diff --git a/.gitignore b/.gitignore
index 4eb71e5..ca76cab 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,8 @@
 .DS_Store
 .idea
 .vscode
+.venv
+.vs
 .wechaty/
 __pycache__/
 venv*
@@ -22,6 +24,8 @@ plugins/**/
 !plugins/tool
 !plugins/banwords
 !plugins/banwords/**/
+plugins/banwords/__pycache__
+plugins/banwords/lib/__pycache__
 !plugins/hello
 !plugins/role
 !plugins/keyword
\ No newline at end of file
diff --git a/bot/baidu/baidu_wenxin.py b/bot/baidu/baidu_wenxin.py
new file mode 100644
index 0000000..8835f7a
--- /dev/null
+++ b/bot/baidu/baidu_wenxin.py
@@ -0,0 +1,96 @@
+# encoding:utf-8
+
+import json
+
+import requests
+
+from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
+from bot.bot import Bot
+from bot.session_manager import SessionManager
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from config import conf
+
+BAIDU_API_KEY = conf().get("baidu_wenxin_api_key")
+BAIDU_SECRET_KEY = conf().get("baidu_wenxin_secret_key")
+
+
+class BaiduWenxinBot(Bot):
+
+    def __init__(self):
+        super().__init__()
+        self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("baidu_wenxin_model") or "eb-instant")
+
+    def reply(self, query, context=None):
+        # acquire reply content
+        if context and context.type:
+            if context.type == ContextType.TEXT:
+                logger.info("[BAIDU] query={}".format(query))
+                session_id = context["session_id"]
+                reply = None
+                if query == "#清除记忆":
+                    self.sessions.clear_session(session_id)
+                    reply = Reply(ReplyType.INFO, "记忆已清除")
+                elif query == "#清除所有":
+                    self.sessions.clear_all_session()
+                    reply = Reply(ReplyType.INFO, "所有人记忆已清除")
+                else:
+                    session = self.sessions.session_query(query, session_id)
+                    result = self.reply_text(session)
+                    total_tokens, completion_tokens, reply_content = (
+                        result["total_tokens"],
+                        result["completion_tokens"],
+                        result["content"],
+                    )
+                    logger.debug(
+                        "[BAIDU] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content, completion_tokens)
+                    )
+
+                    if total_tokens == 0:
+                        reply = Reply(ReplyType.ERROR, reply_content)
+                    else:
+                        self.sessions.session_reply(reply_content, session_id, total_tokens)
+                        reply = Reply(ReplyType.TEXT, reply_content)
+                return reply
+            elif context.type == ContextType.IMAGE_CREATE:
+                ok, retstring = self.create_img(query, 0)
+                reply = None
+                if ok:
+                    reply = Reply(ReplyType.IMAGE_URL, retstring)
+                else:
+                    reply = Reply(ReplyType.ERROR, retstring)
+                return reply
+
+    def reply_text(self, session: BaiduWenxinSession):
+        try:
+            access_token = self.get_access_token()
+            url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + session.model + "?access_token=" + access_token
+            headers = {"Content-Type": "application/json"}
+            payload = {"messages": session.messages}
+            response = requests.request("POST", url, headers=headers, data=json.dumps(payload))
+            response_text = json.loads(response.text)
+            res_content = response_text["result"]
+            total_tokens = response_text["usage"]["total_tokens"]
+            completion_tokens = response_text["usage"]["completion_tokens"]
+            logger.info("[BAIDU] reply={}".format(res_content))
+            return {
+                "total_tokens": total_tokens,
+                "completion_tokens": completion_tokens,
+                "content": res_content,
+            }
+        except Exception as e:
+            logger.warn("[BAIDU] Exception: {}".format(e))
+            self.sessions.clear_session(session.session_id)
+            # a total_tokens of 0 tells the caller this is an error reply
+            result = {"total_tokens": 0, "completion_tokens": 0, "content": "出错了: {}".format(e)}
+            return result
+
+    def get_access_token(self):
+        """
+        Generate an auth token (access_token) from the API key (AK) and secret key (SK).
+        :return: the access_token string, or the string "None" if the request failed
+        """
+        url = "https://aip.baidubce.com/oauth/2.0/token"
+        params = {"grant_type": "client_credentials", "client_id": BAIDU_API_KEY, "client_secret": BAIDU_SECRET_KEY}
+        return str(requests.post(url, params=params).json().get("access_token"))
diff --git a/bot/baidu/baidu_wenxin_session.py b/bot/baidu/baidu_wenxin_session.py
new file mode 100644
index 0000000..aad8f71
--- /dev/null
+++ b/bot/baidu/baidu_wenxin_session.py
@@ -0,0 +1,88 @@
+from bot.session_manager import Session
+from common.log import logger
+
+"""
+    e.g. [
+        {"role": "user", "content": "Who won the world series in 2020?"},
+        {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
+        {"role": "user", "content": "Where was it played?"}
+    ]
+"""
+
+
+class BaiduWenxinSession(Session):
+    def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"):
+        super().__init__(session_id, system_prompt)
+        self.model = model
+        # Baidu Wenxin does not support a system prompt
+        # self.reset()
+
+    def discard_exceeding(self, max_tokens, cur_tokens=None):
+        precise = True
+        try:
+            cur_tokens = self.calc_tokens()
+        except Exception as e:
+            precise = False
+            if cur_tokens is None:
+                raise e
+            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
+        while cur_tokens > max_tokens:
+            if len(self.messages) > 2:
+                self.messages.pop(1)
+            elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
+                self.messages.pop(1)
+                if precise:
+                    cur_tokens = self.calc_tokens()
+                else:
+                    cur_tokens = cur_tokens - max_tokens
+                break
+            elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
+                logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
+                break
+            else:
+                logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
+                break
+            if precise:
+                cur_tokens = self.calc_tokens()
+            else:
+                cur_tokens = cur_tokens - max_tokens
+        return cur_tokens
+
+    def calc_tokens(self):
+        # tiktoken has no encoding for Wenxin models, so counts fall back to a gpt-3.5-turbo approximation
+        return num_tokens_from_messages(self.messages, self.model)
+
+
+# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+def num_tokens_from_messages(messages, model):
+    """Returns the number of tokens used by a list of messages."""
+    import tiktoken
+
+    if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo"]:
+        return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
+    elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k"]:
+        return num_tokens_from_messages(messages, model="gpt-4")
+
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        logger.debug("Warning: model not found. Using cl100k_base encoding.")
+        encoding = tiktoken.get_encoding("cl100k_base")
+    if model == "gpt-3.5-turbo":
+        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
+        tokens_per_name = -1  # if there's a name, the role is omitted
+    elif model == "gpt-4":
+        tokens_per_message = 3
+        tokens_per_name = 1
+    else:
+        logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo.")
+        return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
+    num_tokens = 0
+    for message in messages:
+        num_tokens += tokens_per_message
+        for key, value in message.items():
+            num_tokens += len(encoding.encode(value))
+            if key == "name":
+                num_tokens += tokens_per_name
+    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
+    return num_tokens
diff --git a/bot/bot_factory.py b/bot/bot_factory.py
index 2e9cb2d..e0e07e4 100644
--- a/bot/bot_factory.py
+++ b/bot/bot_factory.py
@@ -11,10 +11,13 @@ def create_bot(bot_type):
     :return: bot instance
     """
     if bot_type == const.BAIDU:
-        # Baidu Unit对话接口
-        from bot.baidu.baidu_unit_bot import BaiduUnitBot
+        # Baidu Unit is replaced by the Baidu Wenxin Qianfan chat API
+        # from bot.baidu.baidu_unit_bot import BaiduUnitBot
+        # return BaiduUnitBot()
 
-        return BaiduUnitBot()
+        from bot.baidu.baidu_wenxin import BaiduWenxinBot
+
+        return BaiduWenxinBot()
 
     elif bot_type == const.CHATGPT:
         # ChatGPT 网页端web接口
diff --git a/bridge/bridge.py b/bridge/bridge.py
index d3fbd95..3e01511 100644
--- a/bridge/bridge.py
+++ b/bridge/bridge.py
@@ -23,6 +23,8 @@ class Bridge(object):
         self.btype["chat"] = const.OPEN_AI
         if conf().get("use_azure_chatgpt", False):
             self.btype["chat"] = const.CHATGPTONAZURE
+        if conf().get("use_baidu_wenxin", False):
+            self.btype["chat"] = const.BAIDU
         if conf().get("use_linkai") and conf().get("linkai_api_key"):
             self.btype["chat"] = const.LINKAI
         self.bots = {}
diff --git a/config.py b/config.py
index 85c5436..dc32f74 100644
--- a/config.py
+++ b/config.py
@@ -19,6 +19,7 @@ available_setting = {
     "model": "gpt-3.5-turbo",
     "use_azure_chatgpt": False,  # 是否使用azure的chatgpt
     "azure_deployment_id": "",  # azure 模型部署名称
+    "use_baidu_wenxin": False,  # whether to use Baidu Wenxin; lower priority than use_azure_chatgpt
     # Bot触发配置
     "single_chat_prefix": ["bot", "@bot"],  # 私聊时文本需要包含该前缀才能触发机器人回复
     "single_chat_reply_prefix": "[bot] ",  # 私聊时自动回复的前缀,用于区分真人
@@ -50,6 +51,10 @@ available_setting = {
     "presence_penalty": 0,
     "request_timeout": 60,  # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
     "timeout": 120,  # chatgpt重试超时时间,在这个时间内,将会自动重试
+    # Baidu Wenxin (ERNIE Bot) settings
+    "baidu_wenxin_model": "eb-instant",  # ERNIE-Bot-turbo model by default
+    "baidu_wenxin_api_key": "",  # Baidu API key
+    "baidu_wenxin_secret_key": "",  # Baidu secret key
     # 语音设置
     "speech_recognition": False,  # 是否开启语音识别
     "group_speech_recognition": False,  # 是否开启群组语音识别
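
Note on usage: to route chat traffic to the new bot, set "use_baidu_wenxin" to
true and fill in "baidu_wenxin_api_key" / "baidu_wenxin_secret_key" in the
config (see the bridge/bridge.py and config.py hunks above). Below is a
minimal, standalone sketch of the two-step ERNIE Bot call that BaiduWenxinBot
wraps: exchange the AK/SK pair for an access token, then POST the message
history to the chat endpoint. The two key values are placeholders,
"eb-instant" is just the default model from config.py, and error handling is
omitted for brevity.

    import json

    import requests

    API_KEY = "your-baidu-api-key"        # placeholder
    SECRET_KEY = "your-baidu-secret-key"  # placeholder

    # step 1: exchange the AK/SK pair for a short-lived access token
    token = requests.post(
        "https://aip.baidubce.com/oauth/2.0/token",
        params={
            "grant_type": "client_credentials",
            "client_id": API_KEY,
            "client_secret": SECRET_KEY,
        },
    ).json()["access_token"]

    # step 2: POST the message history to the chat endpoint of the chosen model
    url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token=" + token
    payload = {"messages": [{"role": "user", "content": "你好"}]}
    resp = requests.post(url, headers={"Content-Type": "application/json"}, data=json.dumps(payload)).json()

    # the reply text is in "result"; "usage" carries the token counts
    print(resp["result"])
    print(resp["usage"]["total_tokens"], resp["usage"]["completion_tokens"])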