@@ -2,6 +2,7 @@ | |||
channel factory | |||
""" | |||
from common import const | |||
from common.log import logger | |||
def create_bot(bot_type): | |||
@@ -43,7 +44,9 @@ def create_bot(bot_type): | |||
elif bot_type == const.CLAUDEAI: | |||
from bot.claude.claude_ai_bot import ClaudeAIBot | |||
return ClaudeAIBot() | |||
elif bot_type == const.CLAUDEAPI: | |||
from bot.claudeapi.claude_api_bot import ClaudeAPIBot | |||
return ClaudeAPIBot() | |||
elif bot_type == const.QWEN: | |||
from bot.ali.ali_qwen_bot import AliQwenBot | |||
return AliQwenBot() | |||
@@ -0,0 +1,125 @@ | |||
# encoding:utf-8 | |||
import time | |||
import openai | |||
import openai.error | |||
import anthropic | |||
from bot.bot import Bot | |||
from bot.openai.open_ai_image import OpenAIImage | |||
from bot.claudeapi.claude_api_session import ClaudeAPISession | |||
from bot.session_manager import SessionManager | |||
from bridge.context import ContextType | |||
from bridge.reply import Reply, ReplyType | |||
from common.log import logger | |||
from config import conf | |||
user_session = dict() | |||
# OpenAI对话模型API (可用) | |||
class ClaudeAPIBot(Bot, OpenAIImage): | |||
def __init__(self): | |||
super().__init__() | |||
self.claudeClient = anthropic.Anthropic( | |||
api_key=conf().get("claude_api_key") | |||
) | |||
openai.api_key = conf().get("open_ai_api_key") | |||
if conf().get("open_ai_api_base"): | |||
openai.api_base = conf().get("open_ai_api_base") | |||
proxy = conf().get("proxy") | |||
if proxy: | |||
openai.proxy = proxy | |||
self.sessions = SessionManager(ClaudeAPISession, model=conf().get("model") or "text-davinci-003") | |||
def reply(self, query, context=None): | |||
# acquire reply content | |||
if context and context.type: | |||
if context.type == ContextType.TEXT: | |||
logger.info("[CLAUDE_API] query={}".format(query)) | |||
session_id = context["session_id"] | |||
reply = None | |||
if query == "#清除记忆": | |||
self.sessions.clear_session(session_id) | |||
reply = Reply(ReplyType.INFO, "记忆已清除") | |||
elif query == "#清除所有": | |||
self.sessions.clear_all_session() | |||
reply = Reply(ReplyType.INFO, "所有人记忆已清除") | |||
else: | |||
session = self.sessions.session_query(query, session_id) | |||
result = self.reply_text(session) | |||
logger.info(result) | |||
total_tokens, completion_tokens, reply_content = ( | |||
result["total_tokens"], | |||
result["completion_tokens"], | |||
result["content"], | |||
) | |||
logger.debug( | |||
"[CLAUDE_API] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens) | |||
) | |||
if total_tokens == 0: | |||
reply = Reply(ReplyType.ERROR, reply_content) | |||
else: | |||
self.sessions.session_reply(reply_content, session_id, total_tokens) | |||
reply = Reply(ReplyType.TEXT, reply_content) | |||
return reply | |||
elif context.type == ContextType.IMAGE_CREATE: | |||
ok, retstring = self.create_img(query, 0) | |||
reply = None | |||
if ok: | |||
reply = Reply(ReplyType.IMAGE_URL, retstring) | |||
else: | |||
reply = Reply(ReplyType.ERROR, retstring) | |||
return reply | |||
def reply_text(self, session: ClaudeAPISession, retry_count=0): | |||
try: | |||
logger.info("[CLAUDE_API] sendMessage={}".format(str(session))) | |||
response = self.claudeClient.messages.create( | |||
model=conf().get("model"), | |||
max_tokens=1024, | |||
# system=conf().get("system"), | |||
messages=[ | |||
{"role": "user", "content": "{}".format(str(session))} | |||
] | |||
) | |||
# response = openai.Completion.create(prompt=str(session), **self.args) | |||
res_content = response.content[0].text.strip().replace("<|endoftext|>", "") | |||
total_tokens = response.usage.input_tokens+response.usage.output_tokens | |||
completion_tokens = response.usage.output_tokens | |||
logger.info("[CLAUDE_API] reply={}".format(res_content)) | |||
return { | |||
"total_tokens": total_tokens, | |||
"completion_tokens": completion_tokens, | |||
"content": res_content, | |||
} | |||
except Exception as e: | |||
need_retry = retry_count < 2 | |||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} | |||
if isinstance(e, openai.error.RateLimitError): | |||
logger.warn("[CLAUDE_API] RateLimitError: {}".format(e)) | |||
result["content"] = "提问太快啦,请休息一下再问我吧" | |||
if need_retry: | |||
time.sleep(20) | |||
elif isinstance(e, openai.error.Timeout): | |||
logger.warn("[CLAUDE_API] Timeout: {}".format(e)) | |||
result["content"] = "我没有收到你的消息" | |||
if need_retry: | |||
time.sleep(5) | |||
elif isinstance(e, openai.error.APIConnectionError): | |||
logger.warn("[CLAUDE_API] APIConnectionError: {}".format(e)) | |||
need_retry = False | |||
result["content"] = "我连接不到你的网络" | |||
else: | |||
logger.warn("[CLAUDE_API] Exception: {}".format(e)) | |||
need_retry = False | |||
self.sessions.clear_session(session.session_id) | |||
if need_retry: | |||
logger.warn("[CLAUDE_API] 第{}次重试".format(retry_count + 1)) | |||
return self.reply_text(session, retry_count + 1) | |||
else: | |||
return result |
@@ -0,0 +1,74 @@ | |||
from bot.session_manager import Session | |||
from common.log import logger | |||
class ClaudeAPISession(Session): | |||
def __init__(self, session_id, system_prompt=None, model="text-davinci-003"): | |||
super().__init__(session_id, system_prompt) | |||
self.model = model | |||
self.reset() | |||
def __str__(self): | |||
# 构造对话模型的输入 | |||
""" | |||
e.g. Q: xxx | |||
A: xxx | |||
Q: xxx | |||
""" | |||
prompt = "" | |||
for item in self.messages: | |||
if item["role"] == "system": | |||
prompt += item["content"] + "<|endoftext|>\n\n\n" | |||
elif item["role"] == "user": | |||
prompt += "Q: " + item["content"] + "\n" | |||
elif item["role"] == "assistant": | |||
prompt += "\n\nA: " + item["content"] + "<|endoftext|>\n" | |||
if len(self.messages) > 0 and self.messages[-1]["role"] == "user": | |||
prompt += "A: " | |||
return prompt | |||
def discard_exceeding(self, max_tokens, cur_tokens=None): | |||
precise = True | |||
try: | |||
cur_tokens = self.calc_tokens() | |||
except Exception as e: | |||
precise = False | |||
if cur_tokens is None: | |||
raise e | |||
logger.debug("Exception when counting tokens precisely for query: {}".format(e)) | |||
while cur_tokens > max_tokens: | |||
if len(self.messages) > 1: | |||
self.messages.pop(0) | |||
elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant": | |||
self.messages.pop(0) | |||
if precise: | |||
cur_tokens = self.calc_tokens() | |||
else: | |||
cur_tokens = len(str(self)) | |||
break | |||
elif len(self.messages) == 1 and self.messages[0]["role"] == "user": | |||
logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens)) | |||
break | |||
else: | |||
logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages))) | |||
break | |||
if precise: | |||
cur_tokens = self.calc_tokens() | |||
else: | |||
cur_tokens = len(str(self)) | |||
return cur_tokens | |||
def calc_tokens(self): | |||
return num_tokens_from_string(str(self), self.model) | |||
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb | |||
def num_tokens_from_string(string: str, model: str) -> int: | |||
"""Returns the number of tokens in a text string.""" | |||
num_tokens = len(string) | |||
return num_tokens | |||
@@ -18,6 +18,7 @@ class Bridge(object): | |||
"text_to_voice": conf().get("text_to_voice", "google"), | |||
"translate": conf().get("translate", "baidu"), | |||
} | |||
# 这边取配置的模型 | |||
model_type = conf().get("model") or const.GPT35 | |||
if model_type in ["text-davinci-003"]: | |||
self.btype["chat"] = const.OPEN_AI | |||
@@ -33,6 +34,8 @@ class Bridge(object): | |||
self.btype["chat"] = const.GEMINI | |||
if model_type in [const.ZHIPU_AI]: | |||
self.btype["chat"] = const.ZHIPU_AI | |||
if model_type in [const.CLAUDE3]: | |||
self.btype["chat"] = const.CLAUDEAPI | |||
if conf().get("use_linkai") and conf().get("linkai_api_key"): | |||
self.btype["chat"] = const.LINKAI | |||
@@ -40,12 +43,12 @@ class Bridge(object): | |||
self.btype["voice_to_text"] = const.LINKAI | |||
if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]: | |||
self.btype["text_to_voice"] = const.LINKAI | |||
if model_type in ["claude"]: | |||
self.btype["chat"] = const.CLAUDEAI | |||
self.bots = {} | |||
self.chat_bots = {} | |||
# 模型对应的接口 | |||
def get_bot(self, typename): | |||
if self.bots.get(typename) is None: | |||
logger.info("create bot {} for {}".format(self.btype[typename], typename)) | |||
@@ -6,12 +6,14 @@ XUNFEI = "xunfei" | |||
CHATGPTONAZURE = "chatGPTOnAzure" | |||
LINKAI = "linkai" | |||
CLAUDEAI = "claude" | |||
CLAUDEAPI= "claudeAPI" | |||
QWEN = "qwen" | |||
GEMINI = "gemini" | |||
ZHIPU_AI = "glm-4" | |||
# model | |||
CLAUDE3="claude-3-opus-20240229" | |||
GPT35 = "gpt-3.5-turbo" | |||
GPT4 = "gpt-4" | |||
GPT4_TURBO_PREVIEW = "gpt-4-0125-preview" | |||
@@ -20,7 +22,7 @@ WHISPER_1 = "whisper-1" | |||
TTS_1 = "tts-1" | |||
TTS_1_HD = "tts-1-hd" | |||
MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude", "gpt-4-turbo", | |||
MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude","claude-3-opus-20240229", "gpt-4-turbo", | |||
"gpt-4-turbo-preview", "gpt-4-1106-preview", GPT4_TURBO_PREVIEW, QWEN, GEMINI, ZHIPU_AI] | |||
# channel | |||
@@ -2,6 +2,7 @@ | |||
"channel_type": "wx", | |||
"model": "", | |||
"open_ai_api_key": "YOUR API KEY", | |||
"claude_api_key": "YOUR API KEY", | |||
"text_to_image": "dall-e-2", | |||
"voice_to_text": "openai", | |||
"text_to_voice": "openai", | |||
@@ -67,6 +67,8 @@ available_setting = { | |||
# claude 配置 | |||
"claude_api_cookie": "", | |||
"claude_uuid": "", | |||
# claude api key | |||
"claude_api_key":"", | |||
# 通义千问API, 获取方式查看文档 https://help.aliyun.com/document_detail/2587494.html | |||
"qwen_access_key_id": "", | |||
"qwen_access_key_secret": "", | |||
@@ -7,3 +7,4 @@ chardet>=5.1.0 | |||
Pillow | |||
pre-commit | |||
web.py | |||
anthropic |