@@ -68,7 +68,8 @@ def num_tokens_from_messages(messages, model): | |||||
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k", "gpt-4-turbo-preview", | "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k", "gpt-4-turbo-preview", | ||||
"gpt-4-1106-preview", const.GPT4_TURBO_PREVIEW, const.GPT4_VISION_PREVIEW]: | "gpt-4-1106-preview", const.GPT4_TURBO_PREVIEW, const.GPT4_VISION_PREVIEW]: | ||||
return num_tokens_from_messages(messages, model="gpt-4") | return num_tokens_from_messages(messages, model="gpt-4") | ||||
elif model.startswith("claude-3"): | |||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo") | |||||
try: | try: | ||||
encoding = tiktoken.encoding_for_model(model) | encoding = tiktoken.encoding_for_model(model) | ||||
except KeyError: | except KeyError: | ||||
@@ -9,6 +9,7 @@ import anthropic | |||||
from bot.bot import Bot | from bot.bot import Bot | ||||
from bot.openai.open_ai_image import OpenAIImage | from bot.openai.open_ai_image import OpenAIImage | ||||
from bot.claudeapi.claude_api_session import ClaudeAPISession | from bot.claudeapi.claude_api_session import ClaudeAPISession | ||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession | |||||
from bot.session_manager import SessionManager | from bot.session_manager import SessionManager | ||||
from bridge.context import ContextType | from bridge.context import ContextType | ||||
from bridge.reply import Reply, ReplyType | from bridge.reply import Reply, ReplyType | ||||
@@ -32,7 +33,7 @@ class ClaudeAPIBot(Bot, OpenAIImage): | |||||
if proxy: | if proxy: | ||||
openai.proxy = proxy | openai.proxy = proxy | ||||
self.sessions = SessionManager(ClaudeAPISession, model=conf().get("model") or "text-davinci-003") | |||||
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "text-davinci-003") | |||||
def reply(self, query, context=None): | def reply(self, query, context=None): | ||||
# acquire reply content | # acquire reply content | ||||
@@ -75,16 +76,17 @@ class ClaudeAPIBot(Bot, OpenAIImage): | |||||
reply = Reply(ReplyType.ERROR, retstring) | reply = Reply(ReplyType.ERROR, retstring) | ||||
return reply | return reply | ||||
def reply_text(self, session: ClaudeAPISession, retry_count=0): | |||||
def reply_text(self, session: ChatGPTSession, retry_count=0): | |||||
try: | try: | ||||
logger.info("[CLAUDE_API] sendMessage={}".format(str(session))) | |||||
if session.messages[0].get("role") == "system": | |||||
system = session.messages[0].get("content") | |||||
session.messages.pop(0) | |||||
actual_model = self._model_mapping(conf().get("model")) | |||||
response = self.claudeClient.messages.create( | response = self.claudeClient.messages.create( | ||||
model=conf().get("model"), | |||||
model=actual_model, | |||||
max_tokens=1024, | max_tokens=1024, | ||||
# system=conf().get("system"), | # system=conf().get("system"), | ||||
messages=[ | |||||
{"role": "user", "content": "{}".format(str(session))} | |||||
] | |||||
messages=session.messages | |||||
) | ) | ||||
# response = openai.Completion.create(prompt=str(session), **self.args) | # response = openai.Completion.create(prompt=str(session), **self.args) | ||||
res_content = response.content[0].text.strip().replace("<|endoftext|>", "") | res_content = response.content[0].text.strip().replace("<|endoftext|>", "") | ||||
@@ -123,3 +125,12 @@ class ClaudeAPIBot(Bot, OpenAIImage): | |||||
return self.reply_text(session, retry_count + 1) | return self.reply_text(session, retry_count + 1) | ||||
else: | else: | ||||
return result | return result | ||||
def _model_mapping(self, model) -> str: | |||||
if model == "claude-3-opus": | |||||
return "claude-3-opus-20240229" | |||||
elif model == "claude-3-sonnet": | |||||
return "claude-3-sonnet-20240229" | |||||
elif model == "claude-3-haiku": | |||||
return "claude-3-haiku-20240307" | |||||
return model |
@@ -130,9 +130,12 @@ class LinkAIBot(Bot): | |||||
response = res.json() | response = res.json() | ||||
reply_content = response["choices"][0]["message"]["content"] | reply_content = response["choices"][0]["message"]["content"] | ||||
total_tokens = response["usage"]["total_tokens"] | total_tokens = response["usage"]["total_tokens"] | ||||
logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}") | |||||
self.sessions.session_reply(reply_content, session_id, total_tokens, query=query) | |||||
res_code = response.get('code') | |||||
logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}, res_code={res_code}") | |||||
if res_code == 429: | |||||
logger.warn(f"[LINKAI] 用户访问超出限流配置,sender_id={body.get('sender_id')}") | |||||
else: | |||||
self.sessions.session_reply(reply_content, session_id, total_tokens, query=query) | |||||
agent_suffix = self._fetch_agent_suffix(response) | agent_suffix = self._fetch_agent_suffix(response) | ||||
if agent_suffix: | if agent_suffix: | ||||
reply_content += agent_suffix | reply_content += agent_suffix | ||||
@@ -161,7 +164,10 @@ class LinkAIBot(Bot): | |||||
logger.warn(f"[LINKAI] do retry, times={retry_count}") | logger.warn(f"[LINKAI] do retry, times={retry_count}") | ||||
return self._chat(query, context, retry_count + 1) | return self._chat(query, context, retry_count + 1) | ||||
return Reply(ReplyType.TEXT, "提问太快啦,请休息一下再问我吧") | |||||
error_reply = "提问太快啦,请休息一下再问我吧" | |||||
if res.status_code == 409: | |||||
error_reply = "这个问题我还没有学会,请问我其它问题吧" | |||||
return Reply(ReplyType.TEXT, error_reply) | |||||
except Exception as e: | except Exception as e: | ||||
logger.exception(e) | logger.exception(e) | ||||
@@ -34,7 +34,7 @@ class Bridge(object): | |||||
self.btype["chat"] = const.GEMINI | self.btype["chat"] = const.GEMINI | ||||
if model_type in [const.ZHIPU_AI]: | if model_type in [const.ZHIPU_AI]: | ||||
self.btype["chat"] = const.ZHIPU_AI | self.btype["chat"] = const.ZHIPU_AI | ||||
if model_type in [const.CLAUDE3]: | |||||
if model_type and model_type.startswith("claude-3"): | |||||
self.btype["chat"] = const.CLAUDEAPI | self.btype["chat"] = const.CLAUDEAPI | ||||
if conf().get("use_linkai") and conf().get("linkai_api_key"): | if conf().get("use_linkai") and conf().get("linkai_api_key"): | ||||
@@ -26,6 +26,8 @@ websocket-client==1.2.0 | |||||
# claude bot | # claude bot | ||||
curl_cffi | curl_cffi | ||||
# claude API | |||||
anthropic | |||||
# tongyi qwen | # tongyi qwen | ||||
broadscope_bailian | broadscope_bailian | ||||
@@ -6,5 +6,4 @@ requests>=2.28.2 | |||||
chardet>=5.1.0 | chardet>=5.1.0 | ||||
Pillow | Pillow | ||||
pre-commit | pre-commit | ||||
web.py | |||||
anthropic | |||||
web.py |