diff --git a/bot/bot_factory.py b/bot/bot_factory.py index 2046da7..b5936c4 100644 --- a/bot/bot_factory.py +++ b/bot/bot_factory.py @@ -2,6 +2,7 @@ channel factory """ from common import const +from common.log import logger def create_bot(bot_type): @@ -43,7 +44,9 @@ def create_bot(bot_type): elif bot_type == const.CLAUDEAI: from bot.claude.claude_ai_bot import ClaudeAIBot return ClaudeAIBot() - + elif bot_type == const.CLAUDEAPI: + from bot.claudeapi.claude_api_bot import ClaudeAPIBot + return ClaudeAPIBot() elif bot_type == const.QWEN: from bot.ali.ali_qwen_bot import AliQwenBot return AliQwenBot() diff --git a/bot/claudeapi/claude_api_bot.py b/bot/claudeapi/claude_api_bot.py new file mode 100644 index 0000000..e9cf46d --- /dev/null +++ b/bot/claudeapi/claude_api_bot.py @@ -0,0 +1,125 @@ +# encoding:utf-8 + +import time + +import openai +import openai.error +import anthropic + +from bot.bot import Bot +from bot.openai.open_ai_image import OpenAIImage +from bot.claudeapi.claude_api_session import ClaudeAPISession +from bot.session_manager import SessionManager +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType +from common.log import logger +from config import conf + +user_session = dict() + + +# OpenAI对话模型API (可用) +class ClaudeAPIBot(Bot, OpenAIImage): + def __init__(self): + super().__init__() + self.claudeClient = anthropic.Anthropic( + api_key=conf().get("claude_api_key") + ) + openai.api_key = conf().get("open_ai_api_key") + if conf().get("open_ai_api_base"): + openai.api_base = conf().get("open_ai_api_base") + proxy = conf().get("proxy") + if proxy: + openai.proxy = proxy + + self.sessions = SessionManager(ClaudeAPISession, model=conf().get("model") or "text-davinci-003") + + def reply(self, query, context=None): + # acquire reply content + if context and context.type: + if context.type == ContextType.TEXT: + logger.info("[CLAUDE_API] query={}".format(query)) + session_id = context["session_id"] + reply = None + if query == "#清除记忆": + self.sessions.clear_session(session_id) + reply = Reply(ReplyType.INFO, "记忆已清除") + elif query == "#清除所有": + self.sessions.clear_all_session() + reply = Reply(ReplyType.INFO, "所有人记忆已清除") + else: + session = self.sessions.session_query(query, session_id) + result = self.reply_text(session) + logger.info(result) + total_tokens, completion_tokens, reply_content = ( + result["total_tokens"], + result["completion_tokens"], + result["content"], + ) + logger.debug( + "[CLAUDE_API] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens) + ) + + if total_tokens == 0: + reply = Reply(ReplyType.ERROR, reply_content) + else: + self.sessions.session_reply(reply_content, session_id, total_tokens) + reply = Reply(ReplyType.TEXT, reply_content) + return reply + elif context.type == ContextType.IMAGE_CREATE: + ok, retstring = self.create_img(query, 0) + reply = None + if ok: + reply = Reply(ReplyType.IMAGE_URL, retstring) + else: + reply = Reply(ReplyType.ERROR, retstring) + return reply + + def reply_text(self, session: ClaudeAPISession, retry_count=0): + try: + logger.info("[CLAUDE_API] sendMessage={}".format(str(session))) + response = self.claudeClient.messages.create( + model=conf().get("model"), + max_tokens=1024, + # system=conf().get("system"), + messages=[ + {"role": "user", "content": "{}".format(str(session))} + ] + ) + # response = openai.Completion.create(prompt=str(session), **self.args) + res_content = response.content[0].text.strip().replace("<|endoftext|>", "") + total_tokens = response.usage.input_tokens+response.usage.output_tokens + completion_tokens = response.usage.output_tokens + logger.info("[CLAUDE_API] reply={}".format(res_content)) + return { + "total_tokens": total_tokens, + "completion_tokens": completion_tokens, + "content": res_content, + } + except Exception as e: + need_retry = retry_count < 2 + result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} + if isinstance(e, openai.error.RateLimitError): + logger.warn("[CLAUDE_API] RateLimitError: {}".format(e)) + result["content"] = "提问太快啦,请休息一下再问我吧" + if need_retry: + time.sleep(20) + elif isinstance(e, openai.error.Timeout): + logger.warn("[CLAUDE_API] Timeout: {}".format(e)) + result["content"] = "我没有收到你的消息" + if need_retry: + time.sleep(5) + elif isinstance(e, openai.error.APIConnectionError): + logger.warn("[CLAUDE_API] APIConnectionError: {}".format(e)) + need_retry = False + result["content"] = "我连接不到你的网络" + else: + logger.warn("[CLAUDE_API] Exception: {}".format(e)) + need_retry = False + self.sessions.clear_session(session.session_id) + + if need_retry: + logger.warn("[CLAUDE_API] 第{}次重试".format(retry_count + 1)) + return self.reply_text(session, retry_count + 1) + else: + return result diff --git a/bot/claudeapi/claude_api_session.py b/bot/claudeapi/claude_api_session.py new file mode 100644 index 0000000..a5e9b54 --- /dev/null +++ b/bot/claudeapi/claude_api_session.py @@ -0,0 +1,74 @@ +from bot.session_manager import Session +from common.log import logger + + +class ClaudeAPISession(Session): + def __init__(self, session_id, system_prompt=None, model="text-davinci-003"): + super().__init__(session_id, system_prompt) + self.model = model + self.reset() + + def __str__(self): + # 构造对话模型的输入 + """ + e.g. Q: xxx + A: xxx + Q: xxx + """ + prompt = "" + for item in self.messages: + if item["role"] == "system": + prompt += item["content"] + "<|endoftext|>\n\n\n" + elif item["role"] == "user": + prompt += "Q: " + item["content"] + "\n" + elif item["role"] == "assistant": + prompt += "\n\nA: " + item["content"] + "<|endoftext|>\n" + + if len(self.messages) > 0 and self.messages[-1]["role"] == "user": + prompt += "A: " + return prompt + + def discard_exceeding(self, max_tokens, cur_tokens=None): + precise = True + + try: + cur_tokens = self.calc_tokens() + except Exception as e: + precise = False + if cur_tokens is None: + raise e + logger.debug("Exception when counting tokens precisely for query: {}".format(e)) + while cur_tokens > max_tokens: + if len(self.messages) > 1: + self.messages.pop(0) + elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant": + self.messages.pop(0) + if precise: + cur_tokens = self.calc_tokens() + else: + cur_tokens = len(str(self)) + break + elif len(self.messages) == 1 and self.messages[0]["role"] == "user": + logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens)) + break + else: + logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages))) + break + if precise: + cur_tokens = self.calc_tokens() + else: + cur_tokens = len(str(self)) + return cur_tokens + def calc_tokens(self): + return num_tokens_from_string(str(self), self.model) + + +# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb +def num_tokens_from_string(string: str, model: str) -> int: + """Returns the number of tokens in a text string.""" + num_tokens = len(string) + return num_tokens + + + + diff --git a/bridge/bridge.py b/bridge/bridge.py index 88e6b18..2c76844 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -18,6 +18,7 @@ class Bridge(object): "text_to_voice": conf().get("text_to_voice", "google"), "translate": conf().get("translate", "baidu"), } + # 这边取配置的模型 model_type = conf().get("model") or const.GPT35 if model_type in ["text-davinci-003"]: self.btype["chat"] = const.OPEN_AI @@ -33,6 +34,8 @@ class Bridge(object): self.btype["chat"] = const.GEMINI if model_type in [const.ZHIPU_AI]: self.btype["chat"] = const.ZHIPU_AI + if model_type in [const.CLAUDE3]: + self.btype["chat"] = const.CLAUDEAPI if conf().get("use_linkai") and conf().get("linkai_api_key"): self.btype["chat"] = const.LINKAI @@ -40,12 +43,12 @@ class Bridge(object): self.btype["voice_to_text"] = const.LINKAI if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]: self.btype["text_to_voice"] = const.LINKAI - if model_type in ["claude"]: self.btype["chat"] = const.CLAUDEAI + self.bots = {} self.chat_bots = {} - + # 模型对应的接口 def get_bot(self, typename): if self.bots.get(typename) is None: logger.info("create bot {} for {}".format(self.btype[typename], typename)) diff --git a/common/const.py b/common/const.py index aeb9dcc..ce11398 100644 --- a/common/const.py +++ b/common/const.py @@ -6,12 +6,14 @@ XUNFEI = "xunfei" CHATGPTONAZURE = "chatGPTOnAzure" LINKAI = "linkai" CLAUDEAI = "claude" +CLAUDEAPI= "claudeAPI" QWEN = "qwen" GEMINI = "gemini" ZHIPU_AI = "glm-4" # model +CLAUDE3="claude-3-opus-20240229" GPT35 = "gpt-3.5-turbo" GPT4 = "gpt-4" GPT4_TURBO_PREVIEW = "gpt-4-0125-preview" @@ -20,7 +22,7 @@ WHISPER_1 = "whisper-1" TTS_1 = "tts-1" TTS_1_HD = "tts-1-hd" -MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude", "gpt-4-turbo", +MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude","claude-3-opus-20240229", "gpt-4-turbo", "gpt-4-turbo-preview", "gpt-4-1106-preview", GPT4_TURBO_PREVIEW, QWEN, GEMINI, ZHIPU_AI] # channel diff --git a/config-template.json b/config-template.json index bdaadde..f3b253d 100644 --- a/config-template.json +++ b/config-template.json @@ -2,6 +2,7 @@ "channel_type": "wx", "model": "", "open_ai_api_key": "YOUR API KEY", + "claude_api_key": "YOUR API KEY", "text_to_image": "dall-e-2", "voice_to_text": "openai", "text_to_voice": "openai", diff --git a/config.py b/config.py index 154c633..29a8f54 100644 --- a/config.py +++ b/config.py @@ -67,6 +67,8 @@ available_setting = { # claude 配置 "claude_api_cookie": "", "claude_uuid": "", + # claude api key + "claude_api_key":"", # 通义千问API, 获取方式查看文档 https://help.aliyun.com/document_detail/2587494.html "qwen_access_key_id": "", "qwen_access_key_secret": "", diff --git a/requirements.txt b/requirements.txt index c032e08..f49bdfa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ chardet>=5.1.0 Pillow pre-commit web.py +anthropic \ No newline at end of file