Du kan inte välja fler än 25 ämnen. Ämnen måste börja med en bokstav eller siffra, kan innehålla bindestreck ('-') och får vara högst 35 tecken långa.

minimax_session.py 2.8KB

5 månader sedan
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. from bot.session_manager import Session
  2. from common.log import logger
  # MiniMax chat session: history storage and token-budget trimming.
class MinimaxSession(Session):
    """Conversation history kept in the MiniMax message format.

    Items appended by :meth:`add_query` and :meth:`add_reply` use the keys
    ``sender_type``, ``sender_name`` and ``text`` (not the OpenAI-style
    ``role`` / ``content`` keys), e.g.::

        [
            {"sender_type": "USER", "sender_name": "<session_id>", "text": "Who won the world series in 2020?"},
            {"sender_type": "BOT", "sender_name": "MM智能助理", "text": "The Los Angeles Dodgers won the World Series in 2020."},
            {"sender_type": "USER", "sender_name": "<session_id>", "text": "Where was it played?"},
        ]
    """

    def __init__(self, session_id, system_prompt=None, model="minimax"):
        """Create a MiniMax session.

        :param session_id: forwarded to the base ``Session``.
        :param system_prompt: forwarded to the base ``Session`` unchanged.
        :param model: model name handed to the token estimator by
            :meth:`calc_tokens`.
        """
        super().__init__(session_id, system_prompt)
        self.model = model
        # NOTE(review): upstream left `self.reset()` commented out, so this class
        # never seeds a system item into self.messages itself. Confirm the base
        # Session does not insert an OpenAI-style {"role", "content"} item: such an
        # item has no "text" key and num_tokens_from_messages could not count it.

    def add_query(self, query):
        """Append a user turn (``sender_type`` "USER") whose text is ``query``."""
        user_item = {"sender_type": "USER", "sender_name": self.session_id, "text": query}
        self.messages.append(user_item)

    def add_reply(self, reply):
        """Append a bot turn (``sender_type`` "BOT") whose text is ``reply``."""
        # sender_name is a fixed bot display name, not derived from the session.
        assistant_item = {"sender_type": "BOT", "sender_name": "MM智能助理", "text": reply}
        self.messages.append(assistant_item)

    def discard_exceeding(self, max_tokens, cur_tokens=None):
        """Drop old history until the token count no longer exceeds ``max_tokens``.

        Messages are removed starting at index 1; ``self.messages[0]`` is never
        discarded.

        :param max_tokens: token budget for the whole history.
        :param cur_tokens: caller-supplied token count, used only as a fallback
            when :meth:`calc_tokens` raises.
        :return: the token count after trimming (an approximation when the
            fallback path was taken).
        :raises Exception: whatever :meth:`calc_tokens` raised, when it fails and
            no ``cur_tokens`` fallback was supplied.
        """
        # NOTE(review): index 0 is pinned as if it were a system prompt, but this
        # class only appends USER/BOT items, so messages[0] may be the oldest user
        # query. Confirm that keeping it is intended.
        precise = True
        try:
            cur_tokens = self.calc_tokens()
        except Exception as e:
            precise = False
            if cur_tokens is None:
                # No fallback count: re-raise with the original traceback intact.
                raise
            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
        while cur_tokens > max_tokens:
            if len(self.messages) > 2:
                self.messages.pop(1)
            elif len(self.messages) == 2 and self.messages[1]["sender_type"] == "BOT":
                # Only one bot reply left after messages[0]: drop it and stop.
                self.messages.pop(1)
                if precise:
                    cur_tokens = self.calc_tokens()
                else:
                    cur_tokens = cur_tokens - max_tokens
                break
            elif len(self.messages) == 2 and self.messages[1]["sender_type"] == "USER":
                # The remaining user turn alone is over budget; keep it and warn.
                logger.warning("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
                break
            else:
                logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
                break
            if precise:
                cur_tokens = self.calc_tokens()
            else:
                # NOTE(review): crude fallback estimate; it does not account for the
                # size of the message just removed.
                cur_tokens = cur_tokens - max_tokens
        return cur_tokens

    def calc_tokens(self):
        """Return the estimated token count of the whole history."""
        return num_tokens_from_messages(self.messages, self.model)
  # Character-count token estimator used by MinimaxSession.calc_tokens.
def num_tokens_from_messages(messages, model):
    """Estimate the number of tokens used by a list of MiniMax messages.

    The estimate is the total character count of every message's ``text``
    field. ``model`` is accepted for signature compatibility but is not used.

    Background (translated from the original comments): the official rule
    quoted there is "for Chinese text one token usually corresponds to one
    Chinese character; for English text one token usually corresponds to 3-4
    letters or one word". Counting characters is therefore only a rough
    estimate, which the original notes say does not affect normal use.
    Reference cited by the original comment:
    https://help.aliyun.com/document_detail/2586397.html
    NOTE(review): that link is Alibaba Cloud documentation; confirm it actually
    describes MiniMax tokenization.

    :raises KeyError: if any message lacks a ``"text"`` key.
    """
    return sum(len(message["text"]) for message in messages)