diff --git a/bot/chatgpt/chat_gpt_session.py b/bot/chatgpt/chat_gpt_session.py
index e6c319b..dd35832 100644
--- a/bot/chatgpt/chat_gpt_session.py
+++ b/bot/chatgpt/chat_gpt_session.py
@@ -57,16 +57,17 @@ def num_tokens_from_messages(messages, model):
     """Returns the number of tokens used by a list of messages."""
     import tiktoken

+    if model == "gpt-3.5-turbo" or model == "gpt-35-turbo":
+        return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
+    elif model == "gpt-4":
+        return num_tokens_from_messages(messages, model="gpt-4-0314")
+
     try:
         encoding = tiktoken.encoding_for_model(model)
     except KeyError:
         logger.debug("Warning: model not found. Using cl100k_base encoding.")
         encoding = tiktoken.get_encoding("cl100k_base")
-    if model == "gpt-3.5-turbo" or model == "gpt-35-turbo":
-        return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
-    elif model == "gpt-4":
-        return num_tokens_from_messages(messages, model="gpt-4-0314")
-    elif model == "gpt-3.5-turbo-0301":
+    if model == "gpt-3.5-turbo-0301":
         tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
         tokens_per_name = -1 # if there's a name, the role is omitted
     elif model == "gpt-4-0314":
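
A minimal usage sketch (not part of the patch) illustrating the intent of moving the alias check ahead of the tiktoken lookup: alias names such as "gpt-3.5-turbo" and "gpt-4" are resolved to the pinned "-0301"/"-0314" snapshots by a recursive call before any encoding is chosen, so the per-message token rules further down only ever see concrete model names. The import path is inferred from the file location in this diff, and the message list is purely illustrative.

    from bot.chatgpt.chat_gpt_session import num_tokens_from_messages

    # Illustrative messages (hypothetical content).
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]

    # Both calls take the same path: the alias is redirected to the pinned
    # snapshot first, so they should report the same token count.
    print(num_tokens_from_messages(messages, model="gpt-3.5-turbo"))
    print(num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301"))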