From 1171b04e93f8d1e78d1357969df7d79ddd7843b0 Mon Sep 17 00:00:00 2001
From: zhayujie
Date: Fri, 25 Aug 2023 12:24:16 +0800
Subject: [PATCH] fix: wenxin token discard bug

---
 bot/baidu/baidu_wenxin.py         |  1 -
 bot/baidu/baidu_wenxin_session.py | 54 ++++++------------------
 2 files changed, 10 insertions(+), 45 deletions(-)

diff --git a/bot/baidu/baidu_wenxin.py b/bot/baidu/baidu_wenxin.py
index 315459e..54997f0 100644
--- a/bot/baidu/baidu_wenxin.py
+++ b/bot/baidu/baidu_wenxin.py
@@ -2,7 +2,6 @@
 import requests, json
 
 from bot.bot import Bot
-from bridge.reply import Reply, ReplyType
 from bot.session_manager import SessionManager
 from bridge.context import ContextType
 from bridge.reply import Reply, ReplyType
diff --git a/bot/baidu/baidu_wenxin_session.py b/bot/baidu/baidu_wenxin_session.py
index aad8f71..5ba2f17 100644
--- a/bot/baidu/baidu_wenxin_session.py
+++ b/bot/baidu/baidu_wenxin_session.py
@@ -9,6 +9,7 @@ from common.log import logger
 ]
 """
 
+
 class BaiduWenxinSession(Session):
     def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"):
         super().__init__(session_id, system_prompt)
@@ -17,7 +18,6 @@ class BaiduWenxinSession(Session):
         # self.reset()
 
     def discard_exceeding(self, max_tokens, cur_tokens=None):
-        # pdb.set_trace()
         precise = True
         try:
             cur_tokens = self.calc_tokens()
@@ -27,18 +27,9 @@
                 raise e
             logger.debug("Exception when counting tokens precisely for query: {}".format(e))
         while cur_tokens > max_tokens:
-            if len(self.messages) > 2:
-                self.messages.pop(1)
-            elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
-                self.messages.pop(1)
-                if precise:
-                    cur_tokens = self.calc_tokens()
-                else:
-                    cur_tokens = cur_tokens - max_tokens
-                break
-            elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
-                logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
-                break
+            if len(self.messages) >= 2:
+                self.messages.pop(0)
+                self.messages.pop(0)
             else:
                 logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
                 break
@@ -52,36 +43,11 @@
         return num_tokens_from_messages(self.messages, self.model)
 
 
-# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 def num_tokens_from_messages(messages, model):
     """Returns the number of tokens used by a list of messages."""
-    import tiktoken
-
-    if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo"]:
-        return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
-    elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k"]:
-        return num_tokens_from_messages(messages, model="gpt-4")
-
-    try:
-        encoding = tiktoken.encoding_for_model(model)
-    except KeyError:
-        logger.debug("Warning: model not found. Using cl100k_base encoding.")
-        encoding = tiktoken.get_encoding("cl100k_base")
-    if model == "gpt-3.5-turbo":
-        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
-        tokens_per_name = -1  # if there's a name, the role is omitted
-    elif model == "gpt-4":
-        tokens_per_message = 3
-        tokens_per_name = 1
-    else:
-        logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo.")
-        return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
-    num_tokens = 0
-    for message in messages:
-        num_tokens += tokens_per_message
-        for key, value in message.items():
-            num_tokens += len(encoding.encode(value))
-            if key == "name":
-                num_tokens += tokens_per_name
-    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
-    return num_tokens
+    tokens = 0
+    for msg in messages:
+        # The official token counting rule is not yet clear; roughly: tokens = Chinese characters + other-language words x 1.3
+        # For now, estimate roughly by character count; this does not affect normal use, and only causes some deviation when deciding whether to discard session history
+        tokens += len(msg["content"])
+    return tokens
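
Note (illustration, not part of the patch): dropping the two oldest messages at a
time keeps the history starting with a user turn, which matters because
BaiduWenxinSession keeps no system prompt at messages[0]; the old code popped
index 1 as if one were there. Below is a minimal standalone sketch of the new
behavior, with the Session class and logger stubbed out and the patch's
character-count token estimate inlined; the function and variable names here
are hypothetical, and the per-iteration recount simplifies the precise branch
of the real discard_exceeding:

    def estimate_tokens(messages):
        # Rough estimate from the patch: one token per character of content.
        return sum(len(msg["content"]) for msg in messages)

    def discard_exceeding(messages, max_tokens):
        # Mirror of the patched loop: drop the oldest user/assistant pair
        # until the estimate fits the budget or nothing is left to drop.
        cur_tokens = estimate_tokens(messages)
        while cur_tokens > max_tokens:
            if len(messages) >= 2:
                messages.pop(0)
                messages.pop(0)
            else:
                break
            cur_tokens = estimate_tokens(messages)
        return cur_tokens

    history = [
        {"role": "user", "content": "Who won the world series in 2020?"},
        {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
        {"role": "user", "content": "Where was it played?"},
    ]
    print(discard_exceeding(history, max_tokens=60))  # oldest pair dropped, prints 20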