From 5346dfdd8b52a2291f59d8250eb0f17dedd30b93 Mon Sep 17 00:00:00 2001 From: zhayujie Date: Mon, 5 Feb 2024 12:21:50 +0800 Subject: [PATCH] feat: code tidying up --- bot/zhipu/chat_glm_bot.py | 155 ---------------------------------- bot/zhipu/chat_glm_session.py | 48 ----------- 2 files changed, 203 deletions(-) delete mode 100644 bot/zhipu/chat_glm_bot.py delete mode 100644 bot/zhipu/chat_glm_session.py diff --git a/bot/zhipu/chat_glm_bot.py b/bot/zhipu/chat_glm_bot.py deleted file mode 100644 index e2127dd..0000000 --- a/bot/zhipu/chat_glm_bot.py +++ /dev/null @@ -1,155 +0,0 @@ -# encoding:utf-8 - -import time - -import openai -import openai.error -import requests - -from bot.bot import Bot -from bot.zhipu.chat_glm_session import ChatGLMSession -from bot.openai.open_ai_image import OpenAIImage -from bot.session_manager import SessionManager -from bridge.context import ContextType -from bridge.reply import Reply, ReplyType -from common.log import logger -# from common.token_bucket import TokenBucket -from config import conf, load_config -from zhipuai import ZhipuAI - - -# ZhipuAI对话模型API -class ChatGLMBot(Bot): - def __init__(self): - super().__init__() - # set the default api_key - self.api_key = conf().get("zhipu_ai_api_key") - if conf().get("zhipu_ai_api_base"): - openai.api_base = conf().get("zhipu_ai_api_base") - # if conf().get("rate_limit_chatgpt"): - # self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20)) - - self.sessions = SessionManager(ChatGLMSession, model=conf().get("model") or "chatglm") - self.args = { - "model": "glm-4", # 对话模型的名称 - "temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 - # "max_tokens":4096, # 回复最大的字符数 - "top_p": conf().get("top_p", 0.7), - # "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 - # "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 - # "request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 - # "timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试 - } - self.client = ZhipuAI(api_key=self.api_key) - - def reply(self, query, context=None): - # acquire reply content - if context.type == ContextType.TEXT: - logger.info("[CHATGLM] query={}".format(query)) - - session_id = context["session_id"] - reply = None - clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"]) - if query in clear_memory_commands: - self.sessions.clear_session(session_id) - reply = Reply(ReplyType.INFO, "记忆已清除") - elif query == "#清除所有": - self.sessions.clear_all_session() - reply = Reply(ReplyType.INFO, "所有人记忆已清除") - elif query == "#更新配置": - load_config() - reply = Reply(ReplyType.INFO, "配置已更新") - if reply: - return reply - session = self.sessions.session_query(query, session_id) - logger.debug("[CHATGLM] session query={}".format(session.messages)) - - api_key = context.get("openai_api_key") or openai.api_key - model = context.get("gpt_model") - new_args = None - if model: - new_args = self.args.copy() - new_args["model"] = model - # if context.get('stream'): - # # reply in stream - # return self.reply_text_stream(query, new_query, session_id) - - reply_content = self.reply_text(session, api_key, args=new_args) - logger.debug( - "[CHATGLM] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format( - session.messages, - session_id, - reply_content["content"], - reply_content["completion_tokens"], - ) - ) - if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0: - reply = Reply(ReplyType.ERROR, reply_content["content"]) - elif reply_content["completion_tokens"] > 0: - self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"]) - reply = Reply(ReplyType.TEXT, reply_content["content"]) - else: - reply = Reply(ReplyType.ERROR, reply_content["content"]) - logger.debug("[CHATGLM] reply {} used 0 tokens.".format(reply_content)) - return reply - else: - reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type)) - return reply - - def reply_text(self, session: ChatGLMSession, api_key=None, args=None, retry_count=0) -> dict: - """ - call openai's ChatCompletion to get the answer - :param session: a conversation session - :param session_id: session id - :param retry_count: retry count - :return: {} - """ - try: - # if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token(): - # raise openai.error.RateLimitError("RateLimitError: rate limit exceeded") - # if api_key == None, the default openai.api_key will be used - if args is None: - args = self.args - # response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args) - response = self.client.chat.completions.create(messages=session.messages, **args) - # logger.debug("[CHATGLM] response={}".format(response)) - # logger.info("[CHATGLM] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) - return { - "total_tokens": response.usage.total_tokens, - "completion_tokens": response.usage.completion_tokens, - "content": response.choices[0].message.content, - } - except Exception as e: - need_retry = retry_count < 2 - result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} - if isinstance(e, openai.error.RateLimitError): - logger.warn("[CHATGLM] RateLimitError: {}".format(e)) - result["content"] = "提问太快啦,请休息一下再问我吧" - if need_retry: - time.sleep(20) - elif isinstance(e, openai.error.Timeout): - logger.warn("[CHATGLM] Timeout: {}".format(e)) - result["content"] = "我没有收到你的消息" - if need_retry: - time.sleep(5) - elif isinstance(e, openai.error.APIError): - logger.warn("[CHATGLM] Bad Gateway: {}".format(e)) - result["content"] = "请再问我一次" - if need_retry: - time.sleep(10) - elif isinstance(e, openai.error.APIConnectionError): - logger.warn("[CHATGLM] APIConnectionError: {}".format(e)) - result["content"] = "我连接不到你的网络" - if need_retry: - time.sleep(5) - else: - logger.exception("[CHATGLM] Exception: {}".format(e), e) - need_retry = False - self.sessions.clear_session(session.session_id) - - if need_retry: - logger.warn("[CHATGLM] 第{}次重试".format(retry_count + 1)) - return self.reply_text(session, api_key, args, retry_count + 1) - else: - return result - diff --git a/bot/zhipu/chat_glm_session.py b/bot/zhipu/chat_glm_session.py deleted file mode 100644 index ab3d62b..0000000 --- a/bot/zhipu/chat_glm_session.py +++ /dev/null @@ -1,48 +0,0 @@ -from bot.session_manager import Session -from common.log import logger - -class ChatGLMSession(Session): - def __init__(self, session_id, system_prompt=None, model="glm-4"): - super().__init__(session_id, system_prompt) - self.model = model - self.reset() - - def discard_exceeding(self, max_tokens, cur_tokens=None): - precise = True - try: - cur_tokens = self.calc_tokens() - except Exception as e: - precise = False - if cur_tokens is None: - raise e - logger.debug("Exception when counting tokens precisely for query: {}".format(e)) - while cur_tokens > max_tokens: - if len(self.messages) > 2: - self.messages.pop(1) - elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant": - self.messages.pop(1) - if precise: - cur_tokens = self.calc_tokens() - else: - cur_tokens = cur_tokens - max_tokens - break - elif len(self.messages) == 2 and self.messages[1]["role"] == "user": - logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens)) - break - else: - logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages))) - break - if precise: - cur_tokens = self.calc_tokens() - else: - cur_tokens = cur_tokens - max_tokens - return cur_tokens - - def calc_tokens(self): - return num_tokens_from_messages(self.messages, self.model) - -def num_tokens_from_messages(messages, model): - tokens = 0 - for msg in messages: - tokens += len(msg["content"]) - return tokens