@@ -62,7 +62,7 @@ def num_tokens_from_messages(messages, model): | |||||
import tiktoken | import tiktoken | ||||
if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo", "gpt-3.5-turbo-1106"]: | |||||
if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo", "gpt-3.5-turbo-1106", "moonshot"]: | |||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo") | return num_tokens_from_messages(messages, model="gpt-3.5-turbo") | ||||
elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613", | elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613", | ||||
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k", "gpt-4-turbo-preview", | "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k", "gpt-4-turbo-preview", | ||||
@@ -8,8 +8,8 @@ import anthropic | |||||
from bot.bot import Bot | from bot.bot import Bot | ||||
from bot.openai.open_ai_image import OpenAIImage | from bot.openai.open_ai_image import OpenAIImage | ||||
from bot.claudeapi.claude_api_session import ClaudeAPISession | |||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession | from bot.chatgpt.chat_gpt_session import ChatGPTSession | ||||
from bot.gemini.google_gemini_bot import GoogleGeminiBot | |||||
from bot.session_manager import SessionManager | from bot.session_manager import SessionManager | ||||
from bridge.context import ContextType | from bridge.context import ContextType | ||||
from bridge.reply import Reply, ReplyType | from bridge.reply import Reply, ReplyType | ||||
@@ -78,15 +78,12 @@ class ClaudeAPIBot(Bot, OpenAIImage): | |||||
def reply_text(self, session: ChatGPTSession, retry_count=0): | def reply_text(self, session: ChatGPTSession, retry_count=0): | ||||
try: | try: | ||||
if session.messages[0].get("role") == "system": | |||||
system = session.messages[0].get("content") | |||||
session.messages.pop(0) | |||||
actual_model = self._model_mapping(conf().get("model")) | actual_model = self._model_mapping(conf().get("model")) | ||||
response = self.claudeClient.messages.create( | response = self.claudeClient.messages.create( | ||||
model=actual_model, | model=actual_model, | ||||
max_tokens=1024, | max_tokens=1024, | ||||
# system=conf().get("system"), | # system=conf().get("system"), | ||||
messages=session.messages | |||||
messages=GoogleGeminiBot.filter_messages(session.messages) | |||||
) | ) | ||||
# response = openai.Completion.create(prompt=str(session), **self.args) | # response = openai.Completion.create(prompt=str(session), **self.args) | ||||
res_content = response.content[0].text.strip().replace("<|endoftext|>", "") | res_content = response.content[0].text.strip().replace("<|endoftext|>", "") | ||||
@@ -1,74 +0,0 @@ | |||||
from bot.session_manager import Session | |||||
from common.log import logger | |||||
class ClaudeAPISession(Session):
    """Conversation session that renders its history as a Q/A text prompt.

    The prompt size is measured through ``calc_tokens`` and
    ``discard_exceeding`` drops the oldest messages until the rendered
    prompt fits a given budget.

    NOTE(review): ``self.messages`` and ``self.reset`` are presumably provided
    by the ``Session`` base class, which is not visible here -- confirm.
    """

    def __init__(self, session_id, system_prompt=None, model="text-davinci-003"):
        """Initialise the base session, remember ``model`` and reset the history."""
        super().__init__(session_id, system_prompt)
        self.model = model
        self.reset()

    def __str__(self):
        """Build the text input for the dialogue model.

        e.g.  Q: xxx
              A: xxx
              Q: xxx

        System content and assistant answers are terminated with
        ``<|endoftext|>``; when the last message comes from the user a
        trailing ``"A: "`` is appended so the model continues with an answer.
        Messages with any other role are ignored.
        """
        parts = []
        for item in self.messages:
            role = item["role"]
            if role == "system":
                parts.append(item["content"] + "<|endoftext|>\n\n\n")
            elif role == "user":
                parts.append("Q: " + item["content"] + "\n")
            elif role == "assistant":
                parts.append("\n\nA: " + item["content"] + "<|endoftext|>\n")
        if self.messages and self.messages[-1]["role"] == "user":
            parts.append("A: ")
        # Join once instead of growing a string with += inside the loop.
        return "".join(parts)

    def discard_exceeding(self, max_tokens, cur_tokens=None):
        """Drop the oldest messages until the session fits in ``max_tokens``.

        Args:
            max_tokens: Upper bound for the measured session size.
            cur_tokens: Fallback size used only when the precise count
                raises; if it is ``None`` in that case the exception
                propagates.

        Returns:
            The session size after trimming.
        """
        precise = True
        try:
            cur_tokens = self.calc_tokens()
        except Exception as e:
            precise = False
            if cur_tokens is None:
                # No fallback supplied: re-raise with the original traceback.
                raise
            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
        while cur_tokens > max_tokens:
            if len(self.messages) > 1:
                self.messages.pop(0)
            elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant":
                self.messages.pop(0)
                # Recount here because ``break`` skips the recount at the
                # bottom of the loop.
                cur_tokens = self._current_size(precise)
                break
            elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
                # A single user question cannot be trimmed any further.
                logger.warning("user question exceed max_tokens. total_tokens={}".format(cur_tokens))
                break
            else:
                logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
                break
            cur_tokens = self._current_size(precise)
        return cur_tokens

    def _current_size(self, precise):
        """Return ``calc_tokens()`` when ``precise``, else the rendered prompt length."""
        return self.calc_tokens() if precise else len(str(self))

    def calc_tokens(self):
        """Return the size of the rendered prompt as reported by ``num_tokens_from_string``."""
        return num_tokens_from_string(str(self), self.model)
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb | |||||
def num_tokens_from_string(string: str, model: str) -> int: | |||||
"""Returns the number of tokens in a text string.""" | |||||
num_tokens = len(string) | |||||
return num_tokens | |||||
@@ -33,7 +33,7 @@ class GoogleGeminiBot(Bot): | |||||
logger.info(f"[Gemini] query={query}") | logger.info(f"[Gemini] query={query}") | ||||
session_id = context["session_id"] | session_id = context["session_id"] | ||||
session = self.sessions.session_query(query, session_id) | session = self.sessions.session_query(query, session_id) | ||||
gemini_messages = self._convert_to_gemini_messages(self._filter_messages(session.messages)) | |||||
gemini_messages = self._convert_to_gemini_messages(self.filter_messages(session.messages)) | |||||
genai.configure(api_key=self.api_key) | genai.configure(api_key=self.api_key) | ||||
model = genai.GenerativeModel('gemini-pro') | model = genai.GenerativeModel('gemini-pro') | ||||
response = model.generate_content(gemini_messages) | response = model.generate_content(gemini_messages) | ||||
@@ -61,7 +61,8 @@ class GoogleGeminiBot(Bot): | |||||
}) | }) | ||||
return res | return res | ||||
def _filter_messages(self, messages: list): | |||||
@staticmethod | |||||
def filter_messages(messages: list): | |||||
res = [] | res = [] | ||||
turn = "user" | turn = "user" | ||||
if not messages: | if not messages: | ||||
@@ -7,6 +7,7 @@ import requests | |||||
import config | import config | ||||
from bot.bot import Bot | from bot.bot import Bot | ||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession | from bot.chatgpt.chat_gpt_session import ChatGPTSession | ||||
from bot.gemini.google_gemini_bot import GoogleGeminiBot | |||||
from bot.session_manager import SessionManager | from bot.session_manager import SessionManager | ||||
from bridge.context import Context, ContextType | from bridge.context import Context, ContextType | ||||
from bridge.reply import Reply, ReplyType | from bridge.reply import Reply, ReplyType | ||||
@@ -10,10 +10,11 @@ CLAUDEAPI= "claudeAPI" | |||||
QWEN = "qwen" | QWEN = "qwen" | ||||
GEMINI = "gemini" | GEMINI = "gemini" | ||||
ZHIPU_AI = "glm-4" | ZHIPU_AI = "glm-4" | ||||
MOONSHOT = "moonshot" | |||||
# model | # model | ||||
CLAUDE3="claude-3-opus-20240229" | |||||
CLAUDE3 = "claude-3-opus-20240229" | |||||
GPT35 = "gpt-3.5-turbo" | GPT35 = "gpt-3.5-turbo" | ||||
GPT4 = "gpt-4" | GPT4 = "gpt-4" | ||||
GPT4_TURBO_PREVIEW = "gpt-4-0125-preview" | GPT4_TURBO_PREVIEW = "gpt-4-0125-preview" | ||||
@@ -23,7 +24,7 @@ TTS_1 = "tts-1" | |||||
TTS_1_HD = "tts-1-hd" | TTS_1_HD = "tts-1-hd" | ||||
MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude","claude-3-opus-20240229", "gpt-4-turbo", | MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude","claude-3-opus-20240229", "gpt-4-turbo", | ||||
"gpt-4-turbo-preview", "gpt-4-1106-preview", GPT4_TURBO_PREVIEW, QWEN, GEMINI, ZHIPU_AI] | |||||
"gpt-4-turbo-preview", "gpt-4-1106-preview", GPT4_TURBO_PREVIEW, QWEN, GEMINI, ZHIPU_AI, MOONSHOT] | |||||
# channel | # channel | ||||
FEISHU = "feishu" | FEISHU = "feishu" | ||||
@@ -339,7 +339,7 @@ class Godcmd(Plugin): | |||||
ok, result = True, "配置已重载" | ok, result = True, "配置已重载" | ||||
elif cmd == "resetall": | elif cmd == "resetall": | ||||
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, | if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, | ||||
const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI, const.ZHIPU_AI]: | |||||
const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI, const.ZHIPU_AI, const.MOONSHOT]: | |||||
channel.cancel_all_session() | channel.cancel_all_session() | ||||
bot.sessions.clear_all_session() | bot.sessions.clear_all_session() | ||||
ok, result = True, "重置所有会话成功" | ok, result = True, "重置所有会话成功" | ||||