From 2989249e4b05c8dd033d368836f46a51342f838d Mon Sep 17 00:00:00 2001 From: lanvent Date: Thu, 13 Apr 2023 20:06:33 +0800 Subject: [PATCH] chore: add calc_tokens method on session --- bot/chatgpt/chat_gpt_bot.py | 8 ++++---- bot/chatgpt/chat_gpt_session.py | 9 ++++++--- bot/openai/open_ai_bot.py | 28 ++++++++++++++-------------- bot/openai/open_ai_session.py | 8 +++++--- bot/session_manager.py | 2 ++ 5 files changed, 31 insertions(+), 24 deletions(-) diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index b392659..9e99809 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -58,7 +58,7 @@ class ChatGPTBot(Bot,OpenAIImage): # # reply in stream # return self.reply_text_stream(query, new_query, session_id) - reply_content = self.reply_text(session, session_id, api_key, 0) + reply_content = self.reply_text(session, api_key) logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"])) if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0: reply = Reply(ReplyType.ERROR, reply_content['content']) @@ -94,7 +94,7 @@ class ChatGPTBot(Bot,OpenAIImage): "timeout": conf().get('request_timeout', None), #重试超时时间,在这个时间内,将会自动重试 } - def reply_text(self, session:ChatGPTSession, session_id, api_key, retry_count=0) -> dict: + def reply_text(self, session:ChatGPTSession, api_key=None, retry_count=0) -> dict: ''' call openai's ChatCompletion to get the answer :param session: a conversation session @@ -133,11 +133,11 @@ class ChatGPTBot(Bot,OpenAIImage): else: logger.warn("[CHATGPT] Exception: {}".format(e)) need_retry = False - self.sessions.clear_session(session_id) + self.sessions.clear_session(session.session_id) if need_retry: logger.warn("[CHATGPT] 第{}次重试".format(retry_count+1)) - return self.reply_text(session, session_id, api_key, retry_count+1) + return self.reply_text(session, api_key, retry_count+1) else: return result diff --git a/bot/chatgpt/chat_gpt_session.py b/bot/chatgpt/chat_gpt_session.py index 90fe064..ed986a7 100644 --- a/bot/chatgpt/chat_gpt_session.py +++ b/bot/chatgpt/chat_gpt_session.py @@ -17,7 +17,7 @@ class ChatGPTSession(Session): def discard_exceeding(self, max_tokens, cur_tokens= None): precise = True try: - cur_tokens = num_tokens_from_messages(self.messages, self.model) + cur_tokens = self.calc_tokens() except Exception as e: precise = False if cur_tokens is None: @@ -29,7 +29,7 @@ class ChatGPTSession(Session): elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant": self.messages.pop(1) if precise: - cur_tokens = num_tokens_from_messages(self.messages, self.model) + cur_tokens = self.calc_tokens() else: cur_tokens = cur_tokens - max_tokens break @@ -40,11 +40,14 @@ class ChatGPTSession(Session): logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages))) break if precise: - cur_tokens = num_tokens_from_messages(self.messages, self.model) + cur_tokens = self.calc_tokens() else: cur_tokens = cur_tokens - max_tokens return cur_tokens + def calc_tokens(self): + return num_tokens_from_messages(self.messages, self.model) + # refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb def num_tokens_from_messages(messages, model): diff --git a/bot/openai/open_ai_bot.py b/bot/openai/open_ai_bot.py index 422c7ea..4d88b99 100644 --- a/bot/openai/open_ai_bot.py +++ b/bot/openai/open_ai_bot.py @@ -42,11 +42,9 @@ class OpenAIBot(Bot, OpenAIImage): reply = Reply(ReplyType.INFO, '所有人记忆已清除') else: session = self.sessions.session_query(query, session_id) - new_query = str(session) - logger.debug("[OPEN_AI] session query={}".format(new_query)) - - total_tokens, completion_tokens, reply_content = self.reply_text(new_query, session_id, 0) - logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(new_query, session_id, reply_content, completion_tokens)) + result = self.reply_text(session) + total_tokens, completion_tokens, reply_content = result['total_tokens'], result['completion_tokens'], result['content'] + logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)) if total_tokens == 0 : reply = Reply(ReplyType.ERROR, reply_content) @@ -63,11 +61,11 @@ class OpenAIBot(Bot, OpenAIImage): reply = Reply(ReplyType.ERROR, retstring) return reply - def reply_text(self, query, session_id, retry_count=0): + def reply_text(self, session:OpenAISession, retry_count=0): try: response = openai.Completion.create( model= conf().get("model") or "text-davinci-003", # 对话模型的名称 - prompt=query, + prompt=str(session), temperature=0.9, # 值在[0,1]之间,越大表示回复越具有不确定性 max_tokens=1200, # 回复最大的字符数 top_p=1, @@ -79,31 +77,33 @@ class OpenAIBot(Bot, OpenAIImage): total_tokens = response["usage"]["total_tokens"] completion_tokens = response["usage"]["completion_tokens"] logger.info("[OPEN_AI] reply={}".format(res_content)) - return total_tokens, completion_tokens, res_content + return {"total_tokens": total_tokens, + "completion_tokens": completion_tokens, + "content": res_content} except Exception as e: need_retry = retry_count < 2 - result = [0,0,"我现在有点累了,等会再来吧"] + result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} if isinstance(e, openai.error.RateLimitError): logger.warn("[OPEN_AI] RateLimitError: {}".format(e)) - result[2] = "提问太快啦,请休息一下再问我吧" + result['content'] = "提问太快啦,请休息一下再问我吧" if need_retry: time.sleep(5) elif isinstance(e, openai.error.Timeout): logger.warn("[OPEN_AI] Timeout: {}".format(e)) - result[2] = "我没有收到你的消息" + result['content'] = "我没有收到你的消息" if need_retry: time.sleep(5) elif isinstance(e, openai.error.APIConnectionError): logger.warn("[OPEN_AI] APIConnectionError: {}".format(e)) need_retry = False - result[2] = "我连接不到你的网络" + result['content'] = "我连接不到你的网络" else: logger.warn("[OPEN_AI] Exception: {}".format(e)) need_retry = False - self.sessions.clear_session(session_id) + self.sessions.clear_session(session.session_id) if need_retry: logger.warn("[OPEN_AI] 第{}次重试".format(retry_count+1)) - return self.reply_text(query, session_id, retry_count+1) + return self.reply_text(session, retry_count+1) else: return result \ No newline at end of file diff --git a/bot/openai/open_ai_session.py b/bot/openai/open_ai_session.py index 597611c..28dd7ec 100644 --- a/bot/openai/open_ai_session.py +++ b/bot/openai/open_ai_session.py @@ -29,7 +29,7 @@ class OpenAISession(Session): def discard_exceeding(self, max_tokens, cur_tokens= None): precise = True try: - cur_tokens = num_tokens_from_string(str(self), self.model) + cur_tokens = self.calc_tokens() except Exception as e: precise = False if cur_tokens is None: @@ -41,7 +41,7 @@ class OpenAISession(Session): elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant": self.messages.pop(0) if precise: - cur_tokens = num_tokens_from_string(str(self), self.model) + cur_tokens = self.calc_tokens() else: cur_tokens = len(str(self)) break @@ -52,11 +52,13 @@ class OpenAISession(Session): logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages))) break if precise: - cur_tokens = num_tokens_from_string(str(self), self.model) + cur_tokens = self.calc_tokens() else: cur_tokens = len(str(self)) return cur_tokens + def calc_tokens(self): + return num_tokens_from_string(str(self), self.model) # refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb def num_tokens_from_string(string: str, model: str) -> int: diff --git a/bot/session_manager.py b/bot/session_manager.py index 0e20cd7..cb05b68 100644 --- a/bot/session_manager.py +++ b/bot/session_manager.py @@ -31,6 +31,8 @@ class Session(object): def discard_exceeding(self, max_tokens=None, cur_tokens=None): raise NotImplementedError + def calc_tokens(self): + raise NotImplementedError class SessionManager(object):