From 160a028705db610bd35103f52e7ab3b46da4d577 Mon Sep 17 00:00:00 2001 From: H Vs Date: Fri, 24 Jan 2025 15:01:11 +0800 Subject: [PATCH] =?UTF-8?q?=E8=B0=83=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bot/bot.py | 17 -- bot/bot_factory.py | 72 ------- bot/chatgpt/chat_gpt_bot.py | 323 -------------------------------- bot/chatgpt/chat_gpt_session.py | 104 ---------- bot/openai/open_ai_bot.py | 122 ------------ bot/openai/open_ai_image.py | 43 ----- bot/openai/open_ai_session.py | 73 -------- bot/session_manager.py | 105 ----------- 8 files changed, 859 deletions(-) delete mode 100644 bot/bot.py delete mode 100644 bot/bot_factory.py delete mode 100644 bot/chatgpt/chat_gpt_bot.py delete mode 100644 bot/chatgpt/chat_gpt_session.py delete mode 100644 bot/openai/open_ai_bot.py delete mode 100644 bot/openai/open_ai_image.py delete mode 100644 bot/openai/open_ai_session.py delete mode 100644 bot/session_manager.py diff --git a/bot/bot.py b/bot/bot.py deleted file mode 100644 index ca6e1aa..0000000 --- a/bot/bot.py +++ /dev/null @@ -1,17 +0,0 @@ -""" -Auto-replay chat robot abstract class -""" - - -from bridge.context import Context -from bridge.reply import Reply - - -class Bot(object): - def reply(self, query, context: Context = None) -> Reply: - """ - bot auto-reply content - :param req: received message - :return: reply content - """ - raise NotImplementedError diff --git a/bot/bot_factory.py b/bot/bot_factory.py deleted file mode 100644 index 50b4d3b..0000000 --- a/bot/bot_factory.py +++ /dev/null @@ -1,72 +0,0 @@ -""" -channel factory -""" -from common import const - - -def create_bot(bot_type): - """ - create a bot_type instance - :param bot_type: bot type code - :return: bot instance - """ - # if bot_type == const.BAIDU: - # # 替换Baidu Unit为Baidu文心千帆对话接口 - # # from bot.baidu.baidu_unit_bot import BaiduUnitBot - # # return BaiduUnitBot() - # from bot.baidu.baidu_wenxin import BaiduWenxinBot - # return BaiduWenxinBot() - - if bot_type == const.CHATGPT: - # ChatGPT 网页端web接口 - from bot.chatgpt.chat_gpt_bot import ChatGPTBot - return ChatGPTBot() - - elif bot_type == const.OPEN_AI: - # OpenAI 官方对话模型API - from bot.openai.open_ai_bot import OpenAIBot - return OpenAIBot() - - # elif bot_type == const.CHATGPTONAZURE: - # # Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/ - # from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot - # return AzureChatGPTBot() - - # elif bot_type == const.XUNFEI: - # from bot.xunfei.xunfei_spark_bot import XunFeiBot - # return XunFeiBot() - - # elif bot_type == const.LINKAI: - # from bot.linkai.link_ai_bot import LinkAIBot - # return LinkAIBot() - - # elif bot_type == const.CLAUDEAI: - # from bot.claude.claude_ai_bot import ClaudeAIBot - # return ClaudeAIBot() - # elif bot_type == const.CLAUDEAPI: - # from bot.claudeapi.claude_api_bot import ClaudeAPIBot - # return ClaudeAPIBot() - # elif bot_type == const.QWEN: - # from bot.ali.ali_qwen_bot import AliQwenBot - # return AliQwenBot() - # elif bot_type == const.QWEN_DASHSCOPE: - # from bot.dashscope.dashscope_bot import DashscopeBot - # return DashscopeBot() - # elif bot_type == const.GEMINI: - # from bot.gemini.google_gemini_bot import GoogleGeminiBot - # return GoogleGeminiBot() - - # elif bot_type == const.ZHIPU_AI: - # from bot.zhipuai.zhipuai_bot import ZHIPUAIBot - # return ZHIPUAIBot() - - # elif bot_type == const.MOONSHOT: - # from bot.moonshot.moonshot_bot import MoonshotBot - # return MoonshotBot() - - # elif bot_type == const.MiniMax: - # from bot.minimax.minimax_bot import MinimaxBot - # return MinimaxBot() - - - raise RuntimeError diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py deleted file mode 100644 index 2a8a091..0000000 --- a/bot/chatgpt/chat_gpt_bot.py +++ /dev/null @@ -1,323 +0,0 @@ -# encoding:utf-8 - -import time - -import openai -import openai.error -import requests -import json - -from bot.bot import Bot -from bot.chatgpt.chat_gpt_session import ChatGPTSession -from bot.openai.open_ai_image import OpenAIImage -from bot.session_manager import SessionManager -from bridge.context import ContextType -from bridge.reply import Reply, ReplyType -from common.log import logger -from common.token_bucket import TokenBucket -from config import conf, load_config -from channel.chat_message import ChatMessage - -from common import memory - - -# OpenAI对话模型API (可用) -class ChatGPTBot(Bot, OpenAIImage): - def __init__(self): - super().__init__() - # set the default api_key - openai.api_key = conf().get("open_ai_api_key") - if conf().get("open_ai_api_base"): - openai.api_base = conf().get("open_ai_api_base") - proxy = conf().get("proxy") - if proxy: - openai.proxy = proxy - if conf().get("rate_limit_chatgpt"): - self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20)) - - self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo") - self.args = { - "model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称 - "temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 - # "max_tokens":4096, # 回复最大的字符数 - "top_p": conf().get("top_p", 1), - "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 - "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 - "request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 - "timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试 - } - - def reply(self, query, context=None): - # acquire reply content - if context.type == ContextType.TEXT: - # print(context.__dict__) - msg: ChatMessage = context.kwargs['msg'] - # print(msg.from_user_nickname) - logger.info("[CHATGPT] {} query={}".format(msg.from_user_nickname,query)) - - session_id = context["session_id"] - # print(f'会话id:{session_id}') - reply = None - clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"]) - if query in clear_memory_commands: - self.sessions.clear_session(session_id) - reply = Reply(ReplyType.INFO, "记忆已清除") - elif query == "#清除所有": - self.sessions.clear_all_session() - reply = Reply(ReplyType.INFO, "所有人记忆已清除") - elif query == "#更新配置": - load_config() - reply = Reply(ReplyType.INFO, "配置已更新") - if reply: - return reply - session = self.sessions.session_query(query, session_id) - logger.debug("[CHATGPT] session query={}".format(session.messages)) - - api_key = context.get("openai_api_key") - model = context.get("gpt_model") - new_args = None - if model: - new_args = self.args.copy() - new_args["model"] = model - # if context.get('stream'): - # # reply in stream - # return self.reply_text_stream(query, new_query, session_id) - - reply_content = self.reply_text(session, api_key, args=new_args) - logger.debug( - "[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format( - session.messages, - session_id, - reply_content["content"], - reply_content["completion_tokens"], - ) - ) - if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0: - reply = Reply(ReplyType.ERROR, reply_content["content"]) - elif reply_content["completion_tokens"] > 0: - self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"]) - reply = Reply(ReplyType.TEXT, reply_content["content"]) - else: - reply = Reply(ReplyType.ERROR, reply_content["content"]) - logger.debug("[CHATGPT] reply {} used 0 tokens.".format(reply_content)) - return reply - - elif context.type == ContextType.IMAGE_CREATE: - ok, retstring = self.create_img(query, 0) - reply = None - if ok: - reply = Reply(ReplyType.IMAGE_URL, retstring) - else: - reply = Reply(ReplyType.ERROR, retstring) - return reply - else: - reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type)) - return reply - - def reply_text(self, session: ChatGPTSession, api_key=None, args=None, retry_count=0) -> dict: - """ - call openai's ChatCompletion to get the answer - :param session: a conversation session - :param session_id: session id - :param retry_count: retry count - :return: {} - """ - try: - if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token(): - raise openai.error.RateLimitError("RateLimitError: rate limit exceeded") - # if api_key == None, the default openai.api_key will be used - if args is None: - args = self.args - - # Define additional parameters - additional_params = { - "chatId": session.session_id, - "detail": True - } - - # Combine the additional params with the existing args (if any) - args.update(additional_params) - # msgs=session.messages - - - # cache_data = memory.USER_INTERACTIVE_CACHE.get(session.session_id) - - # # Determine messages to send based on cache data - # messages_to_send = msgs[-1] if cache_data and cache_data.get('interactive') else msgs - # print(msgs[-1]) - # print('----------------') - # # Send the response using OpenAI API - # response = openai.ChatCompletion.create(api_key=api_key, messages=messages_to_send, **args) - messages_to_send=session.messages - - cache_data = memory.USER_INTERACTIVE_CACHE.get(session.session_id) - if cache_data and cache_data.get('interactive'): - messages_to_send=[session.messages[-1]] - print(messages_to_send) - response = openai.ChatCompletion.create(api_key=api_key, messages=messages_to_send, **args) - # print("{}".format(session.__dict__)) - logger.info("[CHATGPT] 请求={}".format(messages_to_send)) - # print(f'会话id:{session.session_id}') - # logger.info("[CHATGPT] 响应={}".format(response)) - logger.info("[CHATGPT] 响应={}".format(json.dumps(response, separators=(',', ':'),ensure_ascii=False))) - # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) - content=response.choices[0]["message"]["content"] - description = '' - userSelectOptions = [] - if isinstance(content, list) and any(item.get("type") == "interactive" for item in content): - # print(content) - for item in content: - if item["type"] == "interactive" and item["interactive"]["type"] == "userSelect": - params = item["interactive"]["params"] - description = params.get("description") - userSelectOptions = params.get("userSelectOptions", []) - values_string = "\n".join(option["value"] for option in userSelectOptions) - if description is not None: - memory.USER_INTERACTIVE_CACHE[session.session_id] = { - "interactive":True - } - return { - "total_tokens": response["usage"]["total_tokens"], - "completion_tokens": response["usage"]["completion_tokens"], - "content": description + '------------------------------\n'+values_string, - } - - elif isinstance(content, list) and any(item.get("type") == "text" for item in content): - memory.USER_INTERACTIVE_CACHE[session.session_id] = { - "interactive":False - } - text='' - for item in content: - if item["type"] == "text": - text=item["text"]["content"] - - if text=='': - args.pop('chatId', None) # The second argument (None) is the default return value if the key doesn't exist - args.pop('detail', None) - response = openai.ChatCompletion.create(api_key=api_key, messages=messages_to_send, **args) - text=response.choices[0]["message"]["content"] - return { - "total_tokens": response["usage"]["total_tokens"], - "completion_tokens": response["usage"]["completion_tokens"], - "content": text, - } - - else: - memory.USER_INTERACTIVE_CACHE[session.session_id] = { - "interactive":False - } - return { - "total_tokens": response["usage"]["total_tokens"], - "completion_tokens": response["usage"]["completion_tokens"], - "content": content.lstrip("\n"), - } - except Exception as e: - need_retry = retry_count < 2 - result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} - if isinstance(e, openai.error.RateLimitError): - logger.warn("[CHATGPT] RateLimitError: {}".format(e)) - result["content"] = "提问太快啦,请休息一下再问我吧" - if need_retry: - time.sleep(20) - elif isinstance(e, openai.error.Timeout): - logger.warn("[CHATGPT] Timeout: {}".format(e)) - result["content"] = "我没有收到你的消息" - if need_retry: - time.sleep(5) - elif isinstance(e, openai.error.APIError): - logger.warn("[CHATGPT] Bad Gateway: {}".format(e)) - result["content"] = "请再问我一次" - if need_retry: - time.sleep(10) - elif isinstance(e, openai.error.APIConnectionError): - logger.warn("[CHATGPT] APIConnectionError: {}".format(e)) - result["content"] = "我连接不到你的网络" - if need_retry: - time.sleep(5) - else: - logger.exception("[CHATGPT] Exception: {}".format(e)) - need_retry = False - self.sessions.clear_session(session.session_id) - - if need_retry: - logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1)) - return self.reply_text(session, api_key, args, retry_count + 1) - else: - return result - - -class AzureChatGPTBot(ChatGPTBot): - def __init__(self): - super().__init__() - openai.api_type = "azure" - openai.api_version = conf().get("azure_api_version", "2023-06-01-preview") - self.args["deployment_id"] = conf().get("azure_deployment_id") - - def create_img(self, query, retry_count=0, api_key=None): - text_to_image_model = conf().get("text_to_image") - if text_to_image_model == "dall-e-2": - api_version = "2023-06-01-preview" - endpoint = conf().get("azure_openai_dalle_api_base","open_ai_api_base") - # 检查endpoint是否以/结尾 - if not endpoint.endswith("/"): - endpoint = endpoint + "/" - url = "{}openai/images/generations:submit?api-version={}".format(endpoint, api_version) - api_key = conf().get("azure_openai_dalle_api_key","open_ai_api_key") - headers = {"api-key": api_key, "Content-Type": "application/json"} - try: - body = {"prompt": query, "size": conf().get("image_create_size", "256x256"),"n": 1} - submission = requests.post(url, headers=headers, json=body) - operation_location = submission.headers['operation-location'] - status = "" - while (status != "succeeded"): - if retry_count > 3: - return False, "图片生成失败" - response = requests.get(operation_location, headers=headers) - status = response.json()['status'] - retry_count += 1 - image_url = response.json()['result']['data'][0]['url'] - return True, image_url - except Exception as e: - logger.error("create image error: {}".format(e)) - return False, "图片生成失败" - elif text_to_image_model == "dall-e-3": - api_version = conf().get("azure_api_version", "2024-02-15-preview") - endpoint = conf().get("azure_openai_dalle_api_base","open_ai_api_base") - # 检查endpoint是否以/结尾 - if not endpoint.endswith("/"): - endpoint = endpoint + "/" - url = "{}openai/deployments/{}/images/generations?api-version={}".format(endpoint, conf().get("azure_openai_dalle_deployment_id","text_to_image"),api_version) - api_key = conf().get("azure_openai_dalle_api_key","open_ai_api_key") - headers = {"api-key": api_key, "Content-Type": "application/json"} - try: - body = {"prompt": query, "size": conf().get("image_create_size", "1024x1024"), "quality": conf().get("dalle3_image_quality", "standard")} - response = requests.post(url, headers=headers, json=body) - response.raise_for_status() # 检查请求是否成功 - data = response.json() - - # 检查响应中是否包含图像 URL - if 'data' in data and len(data['data']) > 0 and 'url' in data['data'][0]: - image_url = data['data'][0]['url'] - return True, image_url - else: - error_message = "响应中没有图像 URL" - logger.error(error_message) - return False, "图片生成失败" - - except requests.exceptions.RequestException as e: - # 捕获所有请求相关的异常 - try: - error_detail = response.json().get('error', {}).get('message', str(e)) - except ValueError: - error_detail = str(e) - error_message = f"{error_detail}" - logger.error(error_message) - return False, error_message - - except Exception as e: - # 捕获所有其他异常 - error_message = f"生成图像时发生错误: {e}" - logger.error(error_message) - return False, "图片生成失败" - else: - return False, "图片生成失败,未配置text_to_image参数" diff --git a/bot/chatgpt/chat_gpt_session.py b/bot/chatgpt/chat_gpt_session.py deleted file mode 100644 index 4d50297..0000000 --- a/bot/chatgpt/chat_gpt_session.py +++ /dev/null @@ -1,104 +0,0 @@ -from bot.session_manager import Session -from common.log import logger -from common import const - -""" - e.g. [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Who won the world series in 2020?"}, - {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."}, - {"role": "user", "content": "Where was it played?"} - ] -""" - - -class ChatGPTSession(Session): - def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"): - super().__init__(session_id, system_prompt) - self.model = model - self.reset() - - def discard_exceeding(self, max_tokens, cur_tokens=None): - precise = True - try: - cur_tokens = self.calc_tokens() - except Exception as e: - precise = False - if cur_tokens is None: - raise e - logger.debug("Exception when counting tokens precisely for query: {}".format(e)) - while cur_tokens > max_tokens: - if len(self.messages) > 2: - self.messages.pop(1) - elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant": - self.messages.pop(1) - if precise: - cur_tokens = self.calc_tokens() - else: - cur_tokens = cur_tokens - max_tokens - break - elif len(self.messages) == 2 and self.messages[1]["role"] == "user": - logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens)) - break - else: - logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages))) - break - if precise: - cur_tokens = self.calc_tokens() - else: - cur_tokens = cur_tokens - max_tokens - return cur_tokens - - def calc_tokens(self): - return num_tokens_from_messages(self.messages, self.model) - - -# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb -def num_tokens_from_messages(messages, model): - """Returns the number of tokens used by a list of messages.""" - - if model in ["wenxin", "xunfei", const.GEMINI]: - return num_tokens_by_character(messages) - - import tiktoken - - if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo", "gpt-3.5-turbo-1106", "moonshot", const.LINKAI_35]: - return num_tokens_from_messages(messages, model="gpt-3.5-turbo") - elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k", "gpt-4-turbo-preview", - "gpt-4-1106-preview",const.GPT4_TURBO_PREVIEW, const.GPT4_VISION_PREVIEW, const.GPT4_TURBO_01_25, - const.GPT_4o, const.GPT_4O_0806, const.GPT_4o_MINI, const.LINKAI_4o, const.LINKAI_4_TURBO]: - return num_tokens_from_messages(messages, model="gpt-4") - elif model.startswith("claude-3"): - return num_tokens_from_messages(messages, model="gpt-3.5-turbo") - try: - encoding = tiktoken.encoding_for_model(model) - except KeyError: - logger.debug("Warning: model not found. Using cl100k_base encoding.") - encoding = tiktoken.get_encoding("cl100k_base") - if model == "gpt-3.5-turbo": - tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n - tokens_per_name = -1 # if there's a name, the role is omitted - elif model == "gpt-4": - tokens_per_message = 3 - tokens_per_name = 1 - else: - logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo.") - return num_tokens_from_messages(messages, model="gpt-3.5-turbo") - num_tokens = 0 - for message in messages: - num_tokens += tokens_per_message - for key, value in message.items(): - num_tokens += len(encoding.encode(value)) - if key == "name": - num_tokens += tokens_per_name - num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> - return num_tokens - - -def num_tokens_by_character(messages): - """Returns the number of tokens used by a list of messages.""" - tokens = 0 - for msg in messages: - tokens += len(msg["content"]) - return tokens diff --git a/bot/openai/open_ai_bot.py b/bot/openai/open_ai_bot.py deleted file mode 100644 index 1605625..0000000 --- a/bot/openai/open_ai_bot.py +++ /dev/null @@ -1,122 +0,0 @@ -# encoding:utf-8 - -import time - -import openai -import openai.error - -from bot.bot import Bot -from bot.openai.open_ai_image import OpenAIImage -from bot.openai.open_ai_session import OpenAISession -from bot.session_manager import SessionManager -from bridge.context import ContextType -from bridge.reply import Reply, ReplyType -from common.log import logger -from config import conf - -user_session = dict() - - -# OpenAI对话模型API (可用) -class OpenAIBot(Bot, OpenAIImage): - def __init__(self): - super().__init__() - openai.api_key = conf().get("open_ai_api_key") - if conf().get("open_ai_api_base"): - openai.api_base = conf().get("open_ai_api_base") - proxy = conf().get("proxy") - if proxy: - openai.proxy = proxy - - self.sessions = SessionManager(OpenAISession, model=conf().get("model") or "text-davinci-003") - self.args = { - "model": conf().get("model") or "text-davinci-003", # 对话模型的名称 - "temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 - "max_tokens": 1200, # 回复最大的字符数 - "top_p": 1, - "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 - "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 - "request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 - "timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试 - "stop": ["\n\n\n"], - } - - def reply(self, query, context=None): - # acquire reply content - if context and context.type: - if context.type == ContextType.TEXT: - logger.info("[OPEN_AI] query={}".format(query)) - session_id = context["session_id"] - reply = None - if query == "#清除记忆": - self.sessions.clear_session(session_id) - reply = Reply(ReplyType.INFO, "记忆已清除") - elif query == "#清除所有": - self.sessions.clear_all_session() - reply = Reply(ReplyType.INFO, "所有人记忆已清除") - else: - session = self.sessions.session_query(query, session_id) - result = self.reply_text(session) - total_tokens, completion_tokens, reply_content = ( - result["total_tokens"], - result["completion_tokens"], - result["content"], - ) - logger.debug( - "[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens) - ) - - if total_tokens == 0: - reply = Reply(ReplyType.ERROR, reply_content) - else: - self.sessions.session_reply(reply_content, session_id, total_tokens) - reply = Reply(ReplyType.TEXT, reply_content) - return reply - elif context.type == ContextType.IMAGE_CREATE: - ok, retstring = self.create_img(query, 0) - reply = None - if ok: - reply = Reply(ReplyType.IMAGE_URL, retstring) - else: - reply = Reply(ReplyType.ERROR, retstring) - return reply - - def reply_text(self, session: OpenAISession, retry_count=0): - try: - response = openai.Completion.create(prompt=str(session), **self.args) - res_content = response.choices[0]["text"].strip().replace("<|endoftext|>", "") - total_tokens = response["usage"]["total_tokens"] - completion_tokens = response["usage"]["completion_tokens"] - logger.info("[OPEN_AI] reply={}".format(res_content)) - return { - "total_tokens": total_tokens, - "completion_tokens": completion_tokens, - "content": res_content, - } - except Exception as e: - need_retry = retry_count < 2 - result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} - if isinstance(e, openai.error.RateLimitError): - logger.warn("[OPEN_AI] RateLimitError: {}".format(e)) - result["content"] = "提问太快啦,请休息一下再问我吧" - if need_retry: - time.sleep(20) - elif isinstance(e, openai.error.Timeout): - logger.warn("[OPEN_AI] Timeout: {}".format(e)) - result["content"] = "我没有收到你的消息" - if need_retry: - time.sleep(5) - elif isinstance(e, openai.error.APIConnectionError): - logger.warn("[OPEN_AI] APIConnectionError: {}".format(e)) - need_retry = False - result["content"] = "我连接不到你的网络" - else: - logger.warn("[OPEN_AI] Exception: {}".format(e)) - need_retry = False - self.sessions.clear_session(session.session_id) - - if need_retry: - logger.warn("[OPEN_AI] 第{}次重试".format(retry_count + 1)) - return self.reply_text(session, retry_count + 1) - else: - return result diff --git a/bot/openai/open_ai_image.py b/bot/openai/open_ai_image.py deleted file mode 100644 index 3ff56c1..0000000 --- a/bot/openai/open_ai_image.py +++ /dev/null @@ -1,43 +0,0 @@ -import time - -import openai -import openai.error - -from common.log import logger -from common.token_bucket import TokenBucket -from config import conf - - -# OPENAI提供的画图接口 -class OpenAIImage(object): - def __init__(self): - openai.api_key = conf().get("open_ai_api_key") - if conf().get("rate_limit_dalle"): - self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50)) - - def create_img(self, query, retry_count=0, api_key=None, api_base=None): - try: - if conf().get("rate_limit_dalle") and not self.tb4dalle.get_token(): - return False, "请求太快了,请休息一下再问我吧" - logger.info("[OPEN_AI] image_query={}".format(query)) - response = openai.Image.create( - api_key=api_key, - prompt=query, # 图片描述 - n=1, # 每次生成图片的数量 - model=conf().get("text_to_image") or "dall-e-2", - # size=conf().get("image_create_size", "256x256"), # 图片大小,可选有 256x256, 512x512, 1024x1024 - ) - image_url = response["data"][0]["url"] - logger.info("[OPEN_AI] image_url={}".format(image_url)) - return True, image_url - except openai.error.RateLimitError as e: - logger.warn(e) - if retry_count < 1: - time.sleep(5) - logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1)) - return self.create_img(query, retry_count + 1) - else: - return False, "画图出现问题,请休息一下再问我吧" - except Exception as e: - logger.exception(e) - return False, "画图出现问题,请休息一下再问我吧" diff --git a/bot/openai/open_ai_session.py b/bot/openai/open_ai_session.py deleted file mode 100644 index 8f6aa4f..0000000 --- a/bot/openai/open_ai_session.py +++ /dev/null @@ -1,73 +0,0 @@ -from bot.session_manager import Session -from common.log import logger - - -class OpenAISession(Session): - def __init__(self, session_id, system_prompt=None, model="text-davinci-003"): - super().__init__(session_id, system_prompt) - self.model = model - self.reset() - - def __str__(self): - # 构造对话模型的输入 - """ - e.g. Q: xxx - A: xxx - Q: xxx - """ - prompt = "" - for item in self.messages: - if item["role"] == "system": - prompt += item["content"] + "<|endoftext|>\n\n\n" - elif item["role"] == "user": - prompt += "Q: " + item["content"] + "\n" - elif item["role"] == "assistant": - prompt += "\n\nA: " + item["content"] + "<|endoftext|>\n" - - if len(self.messages) > 0 and self.messages[-1]["role"] == "user": - prompt += "A: " - return prompt - - def discard_exceeding(self, max_tokens, cur_tokens=None): - precise = True - try: - cur_tokens = self.calc_tokens() - except Exception as e: - precise = False - if cur_tokens is None: - raise e - logger.debug("Exception when counting tokens precisely for query: {}".format(e)) - while cur_tokens > max_tokens: - if len(self.messages) > 1: - self.messages.pop(0) - elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant": - self.messages.pop(0) - if precise: - cur_tokens = self.calc_tokens() - else: - cur_tokens = len(str(self)) - break - elif len(self.messages) == 1 and self.messages[0]["role"] == "user": - logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens)) - break - else: - logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages))) - break - if precise: - cur_tokens = self.calc_tokens() - else: - cur_tokens = len(str(self)) - return cur_tokens - - def calc_tokens(self): - return num_tokens_from_string(str(self), self.model) - - -# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb -def num_tokens_from_string(string: str, model: str) -> int: - """Returns the number of tokens in a text string.""" - import tiktoken - - encoding = tiktoken.encoding_for_model(model) - num_tokens = len(encoding.encode(string, disallowed_special=())) - return num_tokens diff --git a/bot/session_manager.py b/bot/session_manager.py deleted file mode 100644 index 80cd459..0000000 --- a/bot/session_manager.py +++ /dev/null @@ -1,105 +0,0 @@ -from common.expired_dict import ExpiredDict -from common.log import logger -from config import conf -import json - - -class Session(object): - def __init__(self, session_id, system_prompt=None): - self.session_id = session_id - self.messages = [] - if system_prompt is None: - self.system_prompt = conf().get("character_desc", "") - else: - self.system_prompt = system_prompt - - # 重置会话 - def reset(self): - system_item = {"role": "system", "content": self.system_prompt} - self.messages = [system_item] - - def set_system_prompt(self, system_prompt): - self.system_prompt = system_prompt - self.reset() - - # def add_query(self, query): - # user_item = {"role": "user", "content": query} - # self.messages.append(user_item) - - def add_query(self, query): - try: - # 判断是否为 JSON 字符串,如果是则转换为 Python 字典 - json_data = json.loads(query) - if isinstance(json_data, dict) or isinstance(json_data, list): # 检查是否为字典格式 - user_item = {"role": "user", "content": json_data} - else: - user_item = {"role": "user", "content": query} - except json.JSONDecodeError: - # 如果不是 JSON 字符串,直接保存为字符串 - user_item = {"role": "user", "content": query} - self.messages.append(user_item) - - def add_reply(self, reply): - assistant_item = {"role": "assistant", "content": reply} - self.messages.append(assistant_item) - - def discard_exceeding(self, max_tokens=None, cur_tokens=None): - raise NotImplementedError - - def calc_tokens(self): - raise NotImplementedError - - -class SessionManager(object): - def __init__(self, sessioncls, **session_args): - if conf().get("expires_in_seconds"): - sessions = ExpiredDict(conf().get("expires_in_seconds")) - else: - sessions = dict() - self.sessions = sessions - self.sessioncls = sessioncls - self.session_args = session_args - - def build_session(self, session_id, system_prompt=None): - """ - 如果session_id不在sessions中,创建一个新的session并添加到sessions中 - 如果system_prompt不会空,会更新session的system_prompt并重置session - """ - if session_id is None: - return self.sessioncls(session_id, system_prompt, **self.session_args) - - if session_id not in self.sessions: - self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args) - elif system_prompt is not None: # 如果有新的system_prompt,更新并重置session - self.sessions[session_id].set_system_prompt(system_prompt) - session = self.sessions[session_id] - return session - - def session_query(self, query, session_id): - session = self.build_session(session_id) - session.add_query(query) - try: - max_tokens = conf().get("conversation_max_tokens", 1000) - total_tokens = session.discard_exceeding(max_tokens, None) - logger.debug("prompt tokens used={}".format(total_tokens)) - except Exception as e: - logger.warning("Exception when counting tokens precisely for prompt: {}".format(str(e))) - return session - - def session_reply(self, reply, session_id, total_tokens=None): - session = self.build_session(session_id) - session.add_reply(reply) - try: - max_tokens = conf().get("conversation_max_tokens", 1000) - tokens_cnt = session.discard_exceeding(max_tokens, total_tokens) - logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt)) - except Exception as e: - logger.warning("Exception when counting tokens precisely for session: {}".format(str(e))) - return session - - def clear_session(self, session_id): - if session_id in self.sessions: - del self.sessions[session_id] - - def clear_all_session(self): - self.sessions.clear()