Du kan inte välja fler än 25 ämnen. Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

92 lines
3.9KB

  1. from bot.session_manager import Session
  2. from common.log import logger
  3. '''
  4. e.g. [
  5. {"role": "system", "content": "You are a helpful assistant."},
  6. {"role": "user", "content": "Who won the world series in 2020?"},
  7. {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
  8. {"role": "user", "content": "Where was it played?"}
  9. ]
  10. '''
  11. class ChatGPTSession(Session):
  12. def __init__(self, session_id, system_prompt=None, model= "gpt-3.5-turbo"):
  13. super().__init__(session_id, system_prompt)
  14. self.messages = []
  15. self.model = model
  16. self.reset()
  17. def reset(self):
  18. system_item = {'role': 'system', 'content': self.system_prompt}
  19. self.messages = [system_item]
  20. def add_query(self, query):
  21. user_item = {'role': 'user', 'content': query}
  22. self.messages.append(user_item)
  23. def add_reply(self, reply):
  24. assistant_item = {'role': 'assistant', 'content': reply}
  25. self.messages.append(assistant_item)
  26. def discard_exceeding(self, max_tokens, cur_tokens= None):
  27. precise = True
  28. try:
  29. cur_tokens = num_tokens_from_messages(self.messages, self.model)
  30. except Exception as e:
  31. precise = False
  32. if cur_tokens is None:
  33. raise e
  34. logger.debug("Exception when counting tokens precisely for query: {}".format(e))
  35. while cur_tokens > max_tokens:
  36. if len(self.messages) > 2:
  37. self.messages.pop(1)
  38. elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
  39. self.messages.pop(1)
  40. if precise:
  41. cur_tokens = num_tokens_from_messages(self.messages, self.model)
  42. else:
  43. cur_tokens = cur_tokens - max_tokens
  44. break
  45. elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
  46. logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
  47. break
  48. else:
  49. logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
  50. break
  51. if precise:
  52. cur_tokens = num_tokens_from_messages(self.messages, self.model)
  53. else:
  54. cur_tokens = cur_tokens - max_tokens
  55. return cur_tokens
  56. # refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
  57. def num_tokens_from_messages(messages, model):
  58. """Returns the number of tokens used by a list of messages."""
  59. import tiktoken
  60. try:
  61. encoding = tiktoken.encoding_for_model(model)
  62. except KeyError:
  63. logger.debug("Warning: model not found. Using cl100k_base encoding.")
  64. encoding = tiktoken.get_encoding("cl100k_base")
  65. if model == "gpt-3.5-turbo":
  66. return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
  67. elif model == "gpt-4":
  68. return num_tokens_from_messages(messages, model="gpt-4-0314")
  69. elif model == "gpt-3.5-turbo-0301":
  70. tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
  71. tokens_per_name = -1 # if there's a name, the role is omitted
  72. elif model == "gpt-4-0314":
  73. tokens_per_message = 3
  74. tokens_per_name = 1
  75. else:
  76. logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301.")
  77. return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
  78. num_tokens = 0
  79. for message in messages:
  80. num_tokens += tokens_per_message
  81. for key, value in message.items():
  82. num_tokens += len(encoding.encode(value))
  83. if key == "name":
  84. num_tokens += tokens_per_name
  85. num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
  86. return num_tokens