@@ -1,6 +1,8 @@ | |||||
.DS_Store | .DS_Store | ||||
.idea | .idea | ||||
.vscode | .vscode | ||||
.venv | |||||
.vs | |||||
.wechaty/ | .wechaty/ | ||||
__pycache__/ | __pycache__/ | ||||
venv* | venv* | ||||
@@ -22,6 +24,8 @@ plugins/**/ | |||||
!plugins/tool | !plugins/tool | ||||
!plugins/banwords | !plugins/banwords | ||||
!plugins/banwords/**/ | !plugins/banwords/**/ | ||||
plugins/banwords/__pycache__ | |||||
plugins/banwords/lib/__pycache__ | |||||
!plugins/hello | !plugins/hello | ||||
!plugins/role | !plugins/role | ||||
!plugins/keyword | !plugins/keyword |
@@ -0,0 +1,104 @@ | |||||
# encoding:utf-8 | |||||
import requests, json | |||||
from bot.bot import Bot | |||||
from bridge.reply import Reply, ReplyType | |||||
from bot.session_manager import SessionManager | |||||
from bridge.context import ContextType | |||||
from bridge.reply import Reply, ReplyType | |||||
from common.log import logger | |||||
from config import conf | |||||
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession | |||||
# Baidu Wenxin credentials, read from the project configuration once at import time.
BAIDU_API_KEY, BAIDU_SECRET_KEY = (
    conf().get("baidu_wenxin_api_key"),
    conf().get("baidu_wenxin_secret_key"),
)
class BaiduWenxinBot(Bot):
    """Chat bot that answers text queries through the Baidu Wenxin chat HTTP API.

    Conversation history is kept per session id in a ``SessionManager`` of
    ``BaiduWenxinSession`` objects.
    """

    def __init__(self):
        super().__init__()
        # Fall back to "eb-instant" when no model is configured.
        self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("baidu_wenxin_model") or "eb-instant")

    def reply(self, query, context=None):
        """Build a ``Reply`` for ``query`` according to ``context.type``.

        :param query: the text sent by the user
        :param context: the message context; ``context["session_id"]`` selects the session
        :return: a ``Reply`` for TEXT and IMAGE_CREATE contexts, otherwise ``None``
        """
        # acquire reply content
        if context and context.type:
            if context.type == ContextType.TEXT:
                logger.info("[BAIDU] query={}".format(query))
                session_id = context["session_id"]
                reply = None
                if query == "#清除记忆":
                    self.sessions.clear_session(session_id)
                    reply = Reply(ReplyType.INFO, "记忆已清除")
                elif query == "#清除所有":
                    self.sessions.clear_all_session()
                    reply = Reply(ReplyType.INFO, "所有人记忆已清除")
                else:
                    session = self.sessions.session_query(query, session_id)
                    result = self.reply_text(session)
                    total_tokens, completion_tokens, reply_content = (
                        result["total_tokens"],
                        result["completion_tokens"],
                        result["content"],
                    )
                    logger.debug(
                        "[BAIDU] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content, completion_tokens)
                    )
                    # reply_text reports every failure with total_tokens == 0.
                    if total_tokens == 0:
                        reply = Reply(ReplyType.ERROR, reply_content)
                    else:
                        self.sessions.session_reply(reply_content, session_id, total_tokens)
                        reply = Reply(ReplyType.TEXT, reply_content)
                return reply
            elif context.type == ContextType.IMAGE_CREATE:
                # NOTE(review): create_img is not defined in this class; it must come
                # from the Bot base class or a mixin -- confirm it exists.
                ok, retstring = self.create_img(query, 0)
                if ok:
                    reply = Reply(ReplyType.IMAGE_URL, retstring)
                else:
                    reply = Reply(ReplyType.ERROR, retstring)
                return reply

    def reply_text(self, session: BaiduWenxinSession, retry_count=0):
        """Send the session's messages to the Wenxin chat endpoint.

        :param session: session whose ``model`` and ``messages`` are sent
        :param retry_count: accepted for interface compatibility; no retry is performed
        :return: dict with keys ``total_tokens``, ``completion_tokens`` and ``content``.
                 On any failure ``total_tokens`` is 0 and ``content`` holds an error message.
        """
        try:
            logger.info("[BAIDU] model={}".format(session.model))
            access_token = self.get_access_token()
            # get_access_token() stringifies its result, so a missing token is "None".
            if access_token == "None":
                logger.warn("[BAIDU] access token 获取失败")
                return {
                    "total_tokens": 0,
                    "completion_tokens": 0,
                    # A readable message instead of the integer 0, since this
                    # value becomes the text of the error reply.
                    "content": "access token 获取失败",
                }
            url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + session.model + "?access_token=" + access_token
            headers = {"Content-Type": "application/json"}
            payload = {"messages": session.messages}
            response = requests.request("POST", url, headers=headers, data=json.dumps(payload))
            response_text = json.loads(response.text)
            # A response without "result"/"usage" raises KeyError and is handled below.
            res_content = response_text["result"]
            total_tokens = response_text["usage"]["total_tokens"]
            completion_tokens = response_text["usage"]["completion_tokens"]
            logger.info("[BAIDU] reply={}".format(res_content))
            return {
                "total_tokens": total_tokens,
                "completion_tokens": completion_tokens,
                "content": res_content,
            }
        except Exception as e:
            logger.warn("[BAIDU] Exception: {}".format(e))
            # Drop the session history so a broken conversation is not replayed.
            self.sessions.clear_session(session.session_id)
            # "total_tokens" must be present: reply() indexes it unconditionally.
            return {
                "total_tokens": 0,
                "completion_tokens": 0,
                "content": "出错了: {}".format(e),
            }

    def get_access_token(self):
        """
        Generate the auth signature (Access Token) from the API key and secret key.

        :return: the access token as a string; the literal string "None" when the
                 token response contains no "access_token" field
        """
        url = "https://aip.baidubce.com/oauth/2.0/token"
        params = {"grant_type": "client_credentials", "client_id": BAIDU_API_KEY, "client_secret": BAIDU_SECRET_KEY}
        return str(requests.post(url, params=params).json().get("access_token"))
@@ -0,0 +1,87 @@ | |||||
from bot.session_manager import Session | |||||
from common.log import logger | |||||
""" | |||||
e.g. [ | |||||
{"role": "user", "content": "Who won the world series in 2020?"}, | |||||
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."}, | |||||
{"role": "user", "content": "Where was it played?"} | |||||
] | |||||
""" | |||||
class BaiduWenxinSession(Session):
    """Conversation session for the Baidu Wenxin bot; remembers which model to use."""

    def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"):
        super().__init__(session_id, system_prompt)
        self.model = model
        # Baidu Wenxin does not support a system prompt, so reset() is not called here.

    def discard_exceeding(self, max_tokens, cur_tokens=None):
        """Drop old messages until the session fits into ``max_tokens``.

        :param max_tokens: token budget for the session
        :param cur_tokens: caller-supplied token count, used only when counting fails
        :return: the token count after trimming
        :raises Exception: re-raises the counting error when no ``cur_tokens`` was given
        """
        try:
            cur_tokens = self.calc_tokens()
            exact = True
        except Exception as err:
            exact = False
            if cur_tokens is None:
                raise err
            logger.debug("Exception when counting tokens precisely for query: {}".format(err))

        def recount(previous):
            # Recount precisely when possible; otherwise fall back to the estimate.
            return self.calc_tokens() if exact else previous - max_tokens

        while cur_tokens > max_tokens:
            count = len(self.messages)
            if count > 2:
                # NOTE(review): pop(1) always keeps the message at index 0 -- confirm
                # this is intended given that no system prompt is stored.
                self.messages.pop(1)
                cur_tokens = recount(cur_tokens)
                continue
            if count == 2:
                role = self.messages[1]["role"]
                if role == "assistant":
                    self.messages.pop(1)
                    cur_tokens = recount(cur_tokens)
                    break
                if role == "user":
                    logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
                    break
            logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
            break
        return cur_tokens

    def calc_tokens(self):
        """Return the token count of the current messages for this session's model."""
        return num_tokens_from_messages(self.messages, self.model)
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def num_tokens_from_messages(messages, model):
    """Return the number of tokens used by a list of chat messages.

    Model aliases are mapped to "gpt-3.5-turbo" or "gpt-4"; any other model is
    counted as "gpt-3.5-turbo" after logging a warning.
    """
    import tiktoken

    gpt35_aliases = ("gpt-3.5-turbo-0301", "gpt-35-turbo")
    gpt4_family = (
        "gpt-4-0314",
        "gpt-4-0613",
        "gpt-4-32k",
        "gpt-4-32k-0613",
        "gpt-3.5-turbo-0613",
        "gpt-3.5-turbo-16k",
        "gpt-3.5-turbo-16k-0613",
        "gpt-35-turbo-16k",
    )
    if model in gpt35_aliases:
        return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
    if model in gpt4_family:
        return num_tokens_from_messages(messages, model="gpt-4")

    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        logger.debug("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")

    # model -> (tokens_per_message, tokens_per_name)
    # gpt-3.5-turbo: every message follows <|start|>{role/name}\n{content}<|end|>\n,
    # and if there's a name, the role is omitted (hence -1).
    overheads = {"gpt-3.5-turbo": (4, -1), "gpt-4": (3, 1)}
    if model not in overheads:
        logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo.")
        return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
    per_message, per_name = overheads[model]

    total = 3  # every reply is primed with <|start|>assistant<|message|>
    for message in messages:
        total += per_message
        total += sum(len(encoding.encode(text)) for text in message.values())
        if "name" in message:
            total += per_name
    return total
@@ -11,10 +11,13 @@ def create_bot(bot_type): | |||||
:return: bot instance | :return: bot instance | ||||
""" | """ | ||||
if bot_type == const.BAIDU: | if bot_type == const.BAIDU: | ||||
# Baidu Unit对话接口 | |||||
from bot.baidu.baidu_unit_bot import BaiduUnitBot | |||||
# 替换Baidu Unit为Baidu文心千帆对话接口 | |||||
# from bot.baidu.baidu_unit_bot import BaiduUnitBot | |||||
# return BaiduUnitBot() | |||||
return BaiduUnitBot() | |||||
from bot.baidu.baidu_wenxin import BaiduWenxinBot | |||||
return BaiduWenxinBot() | |||||
elif bot_type == const.CHATGPT: | elif bot_type == const.CHATGPT: | ||||
# ChatGPT 网页端web接口 | # ChatGPT 网页端web接口 | ||||
@@ -23,6 +23,8 @@ class Bridge(object): | |||||
self.btype["chat"] = const.OPEN_AI | self.btype["chat"] = const.OPEN_AI | ||||
if conf().get("use_azure_chatgpt", False): | if conf().get("use_azure_chatgpt", False): | ||||
self.btype["chat"] = const.CHATGPTONAZURE | self.btype["chat"] = const.CHATGPTONAZURE | ||||
if conf().get("use_baidu_wenxin", False): | |||||
self.btype["chat"] = const.BAIDU | |||||
if conf().get("use_linkai") and conf().get("linkai_api_key"): | if conf().get("use_linkai") and conf().get("linkai_api_key"): | ||||
self.btype["chat"] = const.LINKAI | self.btype["chat"] = const.LINKAI | ||||
self.bots = {} | self.bots = {} | ||||
@@ -19,6 +19,7 @@ available_setting = { | |||||
"model": "gpt-3.5-turbo", | "model": "gpt-3.5-turbo", | ||||
"use_azure_chatgpt": False, # 是否使用azure的chatgpt | "use_azure_chatgpt": False, # 是否使用azure的chatgpt | ||||
"azure_deployment_id": "", # azure 模型部署名称 | "azure_deployment_id": "", # azure 模型部署名称 | ||||
"use_baidu_wenxin": False, # 是否使用baidu文心一言,优先级次于azure | |||||
"azure_api_version": "", # azure api版本 | "azure_api_version": "", # azure api版本 | ||||
# Bot触发配置 | # Bot触发配置 | ||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复 | "single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复 | ||||
@@ -51,6 +52,10 @@ available_setting = { | |||||
"presence_penalty": 0, | "presence_penalty": 0, | ||||
"request_timeout": 60, # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 | "request_timeout": 60, # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 | ||||
"timeout": 120, # chatgpt重试超时时间,在这个时间内,将会自动重试 | "timeout": 120, # chatgpt重试超时时间,在这个时间内,将会自动重试 | ||||
# Baidu 文心一言参数 | |||||
"baidu_wenxin_model": "eb-instant", # 默认使用ERNIE-Bot-turbo模型 | |||||
"baidu_wenxin_api_key": "", # Baidu api key | |||||
"baidu_wenxin_secret_key": "", # Baidu secret key | |||||
# 语音设置 | # 语音设置 | ||||
"speech_recognition": False, # 是否开启语音识别 | "speech_recognition": False, # 是否开启语音识别 | ||||
"group_speech_recognition": False, # 是否开启群组语音识别 | "group_speech_recognition": False, # 是否开启群组语音识别 | ||||