@@ -62,7 +62,7 @@ def num_tokens_from_messages(messages, model): | |||
import tiktoken | |||
if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo", "gpt-3.5-turbo-1106"]: | |||
if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo", "gpt-3.5-turbo-1106", "moonshot"]: | |||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo") | |||
elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613", | |||
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k", "gpt-4-turbo-preview", | |||
@@ -8,8 +8,8 @@ import anthropic | |||
from bot.bot import Bot | |||
from bot.openai.open_ai_image import OpenAIImage | |||
from bot.claudeapi.claude_api_session import ClaudeAPISession | |||
from bot.chatgpt.chat_gpt_session import ChatGPTSession | |||
from bot.gemini.google_gemini_bot import GoogleGeminiBot | |||
from bot.session_manager import SessionManager | |||
from bridge.context import ContextType | |||
from bridge.reply import Reply, ReplyType | |||
@@ -78,15 +78,12 @@ class ClaudeAPIBot(Bot, OpenAIImage): | |||
def reply_text(self, session: ChatGPTSession, retry_count=0): | |||
try: | |||
if session.messages[0].get("role") == "system": | |||
system = session.messages[0].get("content") | |||
session.messages.pop(0) | |||
actual_model = self._model_mapping(conf().get("model")) | |||
response = self.claudeClient.messages.create( | |||
model=actual_model, | |||
max_tokens=1024, | |||
# system=conf().get("system"), | |||
messages=session.messages | |||
messages=GoogleGeminiBot.filter_messages(session.messages) | |||
) | |||
# response = openai.Completion.create(prompt=str(session), **self.args) | |||
res_content = response.content[0].text.strip().replace("<|endoftext|>", "") | |||
@@ -1,74 +0,0 @@ | |||
from bot.session_manager import Session | |||
from common.log import logger | |||
class ClaudeAPISession(Session): | |||
def __init__(self, session_id, system_prompt=None, model="text-davinci-003"): | |||
super().__init__(session_id, system_prompt) | |||
self.model = model | |||
self.reset() | |||
def __str__(self): | |||
# 构造对话模型的输入 | |||
""" | |||
e.g. Q: xxx | |||
A: xxx | |||
Q: xxx | |||
""" | |||
prompt = "" | |||
for item in self.messages: | |||
if item["role"] == "system": | |||
prompt += item["content"] + "<|endoftext|>\n\n\n" | |||
elif item["role"] == "user": | |||
prompt += "Q: " + item["content"] + "\n" | |||
elif item["role"] == "assistant": | |||
prompt += "\n\nA: " + item["content"] + "<|endoftext|>\n" | |||
if len(self.messages) > 0 and self.messages[-1]["role"] == "user": | |||
prompt += "A: " | |||
return prompt | |||
def discard_exceeding(self, max_tokens, cur_tokens=None): | |||
precise = True | |||
try: | |||
cur_tokens = self.calc_tokens() | |||
except Exception as e: | |||
precise = False | |||
if cur_tokens is None: | |||
raise e | |||
logger.debug("Exception when counting tokens precisely for query: {}".format(e)) | |||
while cur_tokens > max_tokens: | |||
if len(self.messages) > 1: | |||
self.messages.pop(0) | |||
elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant": | |||
self.messages.pop(0) | |||
if precise: | |||
cur_tokens = self.calc_tokens() | |||
else: | |||
cur_tokens = len(str(self)) | |||
break | |||
elif len(self.messages) == 1 and self.messages[0]["role"] == "user": | |||
logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens)) | |||
break | |||
else: | |||
logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages))) | |||
break | |||
if precise: | |||
cur_tokens = self.calc_tokens() | |||
else: | |||
cur_tokens = len(str(self)) | |||
return cur_tokens | |||
def calc_tokens(self): | |||
return num_tokens_from_string(str(self), self.model) | |||
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb | |||
def num_tokens_from_string(string: str, model: str) -> int: | |||
"""Returns the number of tokens in a text string.""" | |||
num_tokens = len(string) | |||
return num_tokens | |||
@@ -33,7 +33,7 @@ class GoogleGeminiBot(Bot): | |||
logger.info(f"[Gemini] query={query}") | |||
session_id = context["session_id"] | |||
session = self.sessions.session_query(query, session_id) | |||
gemini_messages = self._convert_to_gemini_messages(self._filter_messages(session.messages)) | |||
gemini_messages = self._convert_to_gemini_messages(self.filter_messages(session.messages)) | |||
genai.configure(api_key=self.api_key) | |||
model = genai.GenerativeModel('gemini-pro') | |||
response = model.generate_content(gemini_messages) | |||
@@ -61,7 +61,8 @@ class GoogleGeminiBot(Bot): | |||
}) | |||
return res | |||
def _filter_messages(self, messages: list): | |||
@staticmethod | |||
def filter_messages(messages: list): | |||
res = [] | |||
turn = "user" | |||
if not messages: | |||
@@ -7,6 +7,7 @@ import requests | |||
import config | |||
from bot.bot import Bot | |||
from bot.chatgpt.chat_gpt_session import ChatGPTSession | |||
from bot.gemini.google_gemini_bot import GoogleGeminiBot | |||
from bot.session_manager import SessionManager | |||
from bridge.context import Context, ContextType | |||
from bridge.reply import Reply, ReplyType | |||
@@ -10,10 +10,11 @@ CLAUDEAPI= "claudeAPI" | |||
QWEN = "qwen" | |||
GEMINI = "gemini" | |||
ZHIPU_AI = "glm-4" | |||
MOONSHOT = "moonshot" | |||
# model | |||
CLAUDE3="claude-3-opus-20240229" | |||
CLAUDE3 = "claude-3-opus-20240229" | |||
GPT35 = "gpt-3.5-turbo" | |||
GPT4 = "gpt-4" | |||
GPT4_TURBO_PREVIEW = "gpt-4-0125-preview" | |||
@@ -23,7 +24,7 @@ TTS_1 = "tts-1" | |||
TTS_1_HD = "tts-1-hd" | |||
MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude","claude-3-opus-20240229", "gpt-4-turbo", | |||
"gpt-4-turbo-preview", "gpt-4-1106-preview", GPT4_TURBO_PREVIEW, QWEN, GEMINI, ZHIPU_AI] | |||
"gpt-4-turbo-preview", "gpt-4-1106-preview", GPT4_TURBO_PREVIEW, QWEN, GEMINI, ZHIPU_AI, MOONSHOT] | |||
# channel | |||
FEISHU = "feishu" | |||
@@ -339,7 +339,7 @@ class Godcmd(Plugin): | |||
ok, result = True, "配置已重载" | |||
elif cmd == "resetall": | |||
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, | |||
const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI, const.ZHIPU_AI]: | |||
const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI, const.ZHIPU_AI, const.MOONSHOT]: | |||
channel.cancel_all_session() | |||
bot.sessions.clear_all_session() | |||
ok, result = True, "重置所有会话成功" | |||