You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

104 lines
4.2KB

  1. from bot.session_manager import Session
  2. from common.log import logger
  3. from common import const
  4. """
  5. e.g. [
  6. {"role": "system", "content": "You are a helpful assistant."},
  7. {"role": "user", "content": "Who won the world series in 2020?"},
  8. {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
  9. {"role": "user", "content": "Where was it played?"}
  10. ]
  11. """
  12. class ChatGPTSession(Session):
  13. def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"):
  14. super().__init__(session_id, system_prompt)
  15. self.model = model
  16. self.reset()
  17. def discard_exceeding(self, max_tokens, cur_tokens=None):
  18. precise = True
  19. try:
  20. cur_tokens = self.calc_tokens()
  21. except Exception as e:
  22. precise = False
  23. if cur_tokens is None:
  24. raise e
  25. logger.debug("Exception when counting tokens precisely for query: {}".format(e))
  26. while cur_tokens > max_tokens:
  27. if len(self.messages) > 2:
  28. self.messages.pop(1)
  29. elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
  30. self.messages.pop(1)
  31. if precise:
  32. cur_tokens = self.calc_tokens()
  33. else:
  34. cur_tokens = cur_tokens - max_tokens
  35. break
  36. elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
  37. logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
  38. break
  39. else:
  40. logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
  41. break
  42. if precise:
  43. cur_tokens = self.calc_tokens()
  44. else:
  45. cur_tokens = cur_tokens - max_tokens
  46. return cur_tokens
  47. def calc_tokens(self):
  48. return num_tokens_from_messages(self.messages, self.model)
  49. # refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
  50. def num_tokens_from_messages(messages, model):
  51. """Returns the number of tokens used by a list of messages."""
  52. if model in ["wenxin", "xunfei", const.GEMINI]:
  53. return num_tokens_by_character(messages)
  54. import tiktoken
  55. if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo", "gpt-3.5-turbo-1106", "moonshot"]:
  56. return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
  57. elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613",
  58. "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k", "gpt-4-turbo-preview",
  59. "gpt-4-1106-preview", const.GPT4_TURBO_PREVIEW, const.GPT4_VISION_PREVIEW]:
  60. return num_tokens_from_messages(messages, model="gpt-4")
  61. elif model.startswith("claude-3"):
  62. return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
  63. try:
  64. encoding = tiktoken.encoding_for_model(model)
  65. except KeyError:
  66. logger.debug("Warning: model not found. Using cl100k_base encoding.")
  67. encoding = tiktoken.get_encoding("cl100k_base")
  68. if model == "gpt-3.5-turbo":
  69. tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
  70. tokens_per_name = -1 # if there's a name, the role is omitted
  71. elif model == "gpt-4":
  72. tokens_per_message = 3
  73. tokens_per_name = 1
  74. else:
  75. logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo.")
  76. return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
  77. num_tokens = 0
  78. for message in messages:
  79. num_tokens += tokens_per_message
  80. for key, value in message.items():
  81. num_tokens += len(encoding.encode(value))
  82. if key == "name":
  83. num_tokens += tokens_per_name
  84. num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
  85. return num_tokens
  86. def num_tokens_by_character(messages):
  87. """Returns the number of tokens used by a list of messages."""
  88. tokens = 0
  89. for msg in messages:
  90. tokens += len(msg["content"])
  91. return tokens