@@ -50,7 +50,9 @@ def create_bot(bot_type): | |||||
elif bot_type == const.QWEN: | elif bot_type == const.QWEN: | ||||
from bot.ali.ali_qwen_bot import AliQwenBot | from bot.ali.ali_qwen_bot import AliQwenBot | ||||
return AliQwenBot() | return AliQwenBot() | ||||
elif bot_type == const.QWEN_DASHSCOPE: | |||||
from bot.dashscope.dashscope_bot import DashscopeBot | |||||
return DashscopeBot() | |||||
elif bot_type == const.GEMINI: | elif bot_type == const.GEMINI: | ||||
from bot.gemini.google_gemini_bot import GoogleGeminiBot | from bot.gemini.google_gemini_bot import GoogleGeminiBot | ||||
return GoogleGeminiBot() | return GoogleGeminiBot() | ||||
@@ -0,0 +1,117 @@ | |||||
# encoding:utf-8 | |||||
from bot.bot import Bot | |||||
from bot.session_manager import SessionManager | |||||
from bridge.context import ContextType | |||||
from bridge.reply import Reply, ReplyType | |||||
from common.log import logger | |||||
from config import conf, load_config | |||||
from .dashscope_session import DashscopeSession | |||||
import os | |||||
import dashscope | |||||
from http import HTTPStatus | |||||
# Map the user-facing model names (as they appear in config "model") to the
# DashScope SDK's Generation model constants.
dashscope_models = {
    name: getattr(dashscope.Generation.Models, attr_name)
    for name, attr_name in [
        ("qwen-turbo", "qwen_turbo"),
        ("qwen-plus", "qwen_plus"),
        ("qwen-max", "qwen_max"),
        ("qwen-bailian-v1", "bailian_v1"),
    ]
}
# 阿里云灵积(DashScope)对话模型API
class DashscopeBot(Bot):
    """Chat bot backed by Aliyun DashScope (Tongyi Qianwen) Generation API."""

    def __init__(self):
        super().__init__()
        # Fall back to "qwen-plus" when no model is configured.
        self.sessions = SessionManager(DashscopeSession, model=conf().get("model") or "qwen-plus")
        self.model_name = conf().get("model") or "qwen-plus"
        self.api_key = conf().get("dashscope_api_key")
        # The dashscope SDK can also read the key from this environment variable.
        os.environ["DASHSCOPE_API_KEY"] = self.api_key
        self.client = dashscope.Generation

    def reply(self, query, context=None):
        """Produce a Reply for the given query; only TEXT contexts are supported."""
        if context.type == ContextType.TEXT:
            logger.info("[DASHSCOPE] query={}".format(query))

            session_id = context["session_id"]
            reply = None
            # Built-in admin commands are handled before calling the model.
            clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
            if query in clear_memory_commands:
                self.sessions.clear_session(session_id)
                reply = Reply(ReplyType.INFO, "记忆已清除")
            elif query == "#清除所有":
                self.sessions.clear_all_session()
                reply = Reply(ReplyType.INFO, "所有人记忆已清除")
            elif query == "#更新配置":
                load_config()
                reply = Reply(ReplyType.INFO, "配置已更新")
            if reply:
                return reply

            session = self.sessions.session_query(query, session_id)
            logger.debug("[DASHSCOPE] session query={}".format(session.messages))

            reply_content = self.reply_text(session)
            logger.debug(
                "[DASHSCOPE] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
                    session.messages,
                    session_id,
                    reply_content["content"],
                    reply_content["completion_tokens"],
                )
            )
            if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
                # An error message was produced without any completion tokens.
                reply = Reply(ReplyType.ERROR, reply_content["content"])
            elif reply_content["completion_tokens"] > 0:
                # Successful completion: record it in the session history.
                self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
                reply = Reply(ReplyType.TEXT, reply_content["content"])
            else:
                reply = Reply(ReplyType.ERROR, reply_content["content"])
                logger.debug("[DASHSCOPE] reply {} used 0 tokens.".format(reply_content))
            return reply
        else:
            reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
            return reply

    def reply_text(self, session: DashscopeSession, retry_count=0) -> dict:
        """
        Call the DashScope Generation API to get the answer.
        Retries up to 2 times on a non-OK status or an exception.
        :param session: a conversation session
        :param retry_count: retry count
        :return: dict with "content" and "completion_tokens" (plus "total_tokens" on success)
        """
        try:
            dashscope.api_key = self.api_key
            response = self.client.call(
                dashscope_models[self.model_name],
                messages=session.messages,
                result_format="message"
            )
            if response.status_code == HTTPStatus.OK:
                content = response.output.choices[0]["message"]["content"]
                return {
                    "total_tokens": response.usage["total_tokens"],
                    "completion_tokens": response.usage["output_tokens"],
                    "content": content,
                }
            logger.error('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
                response.request_id, response.status_code,
                response.code, response.message
            ))
            if retry_count < 2:
                return self.reply_text(session, retry_count + 1)
            return {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
        except Exception as e:
            logger.exception(e)
            if retry_count < 2:
                return self.reply_text(session, retry_count + 1)
            return {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
@@ -0,0 +1,51 @@ | |||||
from bot.session_manager import Session | |||||
from common.log import logger | |||||
class DashscopeSession(Session):
    """Conversation session for the DashScope (Tongyi Qianwen) bot."""

    def __init__(self, session_id, system_prompt=None, model="qwen-turbo"):
        # Forward system_prompt so a caller-supplied character prompt is not
        # silently dropped (the original passed only session_id to the base).
        super().__init__(session_id, system_prompt)
        self.reset()

    def discard_exceeding(self, max_tokens, cur_tokens=None):
        """Trim oldest non-system messages until the session fits max_tokens.

        :param max_tokens: token budget for the whole message list
        :param cur_tokens: fallback token count used when precise counting fails
        :return: the (possibly approximate) token count after trimming
        """
        precise = True
        try:
            cur_tokens = self.calc_tokens()
        except Exception as e:
            precise = False
            if cur_tokens is None:
                raise e
            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
        while cur_tokens > max_tokens:
            if len(self.messages) > 2:
                # Drop the oldest non-system message and recount below.
                self.messages.pop(1)
            elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
                self.messages.pop(1)
                if precise:
                    cur_tokens = self.calc_tokens()
                else:
                    cur_tokens = cur_tokens - max_tokens
                break
            elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
                logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
                break
            else:
                logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens,
                                                                                       len(self.messages)))
                break
            # Recount inside the loop so trimming stops as soon as the budget
            # is met (recounting only after the loop would over-trim).
            if precise:
                cur_tokens = self.calc_tokens()
            else:
                cur_tokens = cur_tokens - max_tokens
        return cur_tokens

    def calc_tokens(self):
        # Approximate count via the module-level helper.
        return num_tokens_from_messages(self.messages)
def num_tokens_from_messages(messages):
    """Rough token estimate: one token per character of message content.

    只是大概,具体计算规则:
    https://help.aliyun.com/zh/dashscope/developer-reference/token-api?spm=a2c4g.11186623.0.0.4d8b12b0BkP3K9
    """
    return sum(len(message["content"]) for message in messages)
@@ -30,6 +30,8 @@ class Bridge(object): | |||||
self.btype["chat"] = const.XUNFEI | self.btype["chat"] = const.XUNFEI | ||||
if model_type in [const.QWEN]: | if model_type in [const.QWEN]: | ||||
self.btype["chat"] = const.QWEN | self.btype["chat"] = const.QWEN | ||||
if model_type in [const.QWEN_TURBO, const.QWEN_PLUS, const.QWEN_MAX]: | |||||
self.btype["chat"] = const.QWEN_DASHSCOPE | |||||
if model_type in [const.GEMINI]: | if model_type in [const.GEMINI]: | ||||
self.btype["chat"] = const.GEMINI | self.btype["chat"] = const.GEMINI | ||||
if model_type in [const.ZHIPU_AI]: | if model_type in [const.ZHIPU_AI]: | ||||
@@ -8,6 +8,12 @@ LINKAI = "linkai" | |||||
CLAUDEAI = "claude" | CLAUDEAI = "claude" | ||||
CLAUDEAPI= "claudeAPI" | CLAUDEAPI= "claudeAPI" | ||||
QWEN = "qwen" | QWEN = "qwen" | ||||
QWEN_DASHSCOPE = "dashscope" | |||||
QWEN_TURBO = "qwen-turbo" | |||||
QWEN_PLUS = "qwen-plus" | |||||
QWEN_MAX = "qwen-max" | |||||
GEMINI = "gemini" | GEMINI = "gemini" | ||||
ZHIPU_AI = "glm-4" | ZHIPU_AI = "glm-4" | ||||
MOONSHOT = "moonshot" | MOONSHOT = "moonshot" | ||||
@@ -24,7 +30,8 @@ TTS_1 = "tts-1" | |||||
TTS_1_HD = "tts-1-hd" | TTS_1_HD = "tts-1-hd" | ||||
MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude","claude-3-opus-20240229", "gpt-4-turbo", | MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude","claude-3-opus-20240229", "gpt-4-turbo", | ||||
"gpt-4-turbo-preview", "gpt-4-1106-preview", GPT4_TURBO_PREVIEW, QWEN, GEMINI, ZHIPU_AI, MOONSHOT] | |||||
"gpt-4-turbo-preview", "gpt-4-1106-preview", GPT4_TURBO_PREVIEW, QWEN, GEMINI, ZHIPU_AI, MOONSHOT, | |||||
QWEN_TURBO, QWEN_PLUS, QWEN_MAX] | |||||
# channel | # channel | ||||
FEISHU = "feishu" | FEISHU = "feishu" | ||||
@@ -75,6 +75,8 @@ available_setting = { | |||||
"qwen_agent_key": "", | "qwen_agent_key": "", | ||||
"qwen_app_id": "", | "qwen_app_id": "", | ||||
"qwen_node_id": "", # 流程编排模型用到的id,如果没有用到qwen_node_id,请务必保持为空字符串 | "qwen_node_id": "", # 流程编排模型用到的id,如果没有用到qwen_node_id,请务必保持为空字符串 | ||||
# 阿里灵积模型api key | |||||
"dashscope_api_key": "", | |||||
# Google Gemini Api Key | # Google Gemini Api Key | ||||
"gemini_api_key": "", | "gemini_api_key": "", | ||||
# wework的通用配置 | # wework的通用配置 | ||||
@@ -43,3 +43,6 @@ dingtalk_stream | |||||
# zhipuai | # zhipuai | ||||
zhipuai>=2.0.1 | zhipuai>=2.0.1 | ||||
# tongyi qwen new sdk | |||||
dashscope |