Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

152 lines
6.9KB

  1. # encoding:utf-8
  2. import time
  3. import openai
  4. import openai.error
  5. from bot.bot import Bot
  6. from bot.minimax.minimax_session import MinimaxSession
  7. from bot.session_manager import SessionManager
  8. from bridge.context import Context, ContextType
  9. from bridge.reply import Reply, ReplyType
  10. from common.log import logger
  11. from config import conf, load_config
  12. from bot.chatgpt.chat_gpt_session import ChatGPTSession
  13. import requests
  14. from common import const
  15. # ZhipuAI对话模型API
  16. class MinimaxBot(Bot):
  17. def __init__(self):
  18. super().__init__()
  19. self.args = {
  20. "model": conf().get("model") or "abab6.5", # 对话模型的名称
  21. "temperature": conf().get("temperature", 0.3), # 如果设置,值域须为 [0, 1] 我们推荐 0.3,以达到较合适的效果。
  22. "top_p": conf().get("top_p", 0.95), # 使用默认值
  23. }
  24. self.api_key = conf().get("Minimax_api_key")
  25. self.group_id = conf().get("Minimax_group_id")
  26. self.base_url = conf().get("Minimax_base_url", f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={self.group_id}")
  27. # tokens_to_generate/bot_setting/reply_constraints可自行修改
  28. self.request_body = {
  29. "model": self.args["model"],
  30. "tokens_to_generate": 2048,
  31. "reply_constraints": {"sender_type": "BOT", "sender_name": "MM智能助理"},
  32. "messages": [],
  33. "bot_setting": [
  34. {
  35. "bot_name": "MM智能助理",
  36. "content": "MM智能助理是一款由MiniMax自研的,没有调用其他产品的接口的大型语言模型。MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。",
  37. }
  38. ],
  39. }
  40. self.sessions = SessionManager(MinimaxSession, model=const.MiniMax)
  41. def reply(self, query, context: Context = None) -> Reply:
  42. # acquire reply content
  43. logger.info("[Minimax_AI] query={}".format(query))
  44. if context.type == ContextType.TEXT:
  45. session_id = context["session_id"]
  46. reply = None
  47. clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
  48. if query in clear_memory_commands:
  49. self.sessions.clear_session(session_id)
  50. reply = Reply(ReplyType.INFO, "记忆已清除")
  51. elif query == "#清除所有":
  52. self.sessions.clear_all_session()
  53. reply = Reply(ReplyType.INFO, "所有人记忆已清除")
  54. elif query == "#更新配置":
  55. load_config()
  56. reply = Reply(ReplyType.INFO, "配置已更新")
  57. if reply:
  58. return reply
  59. session = self.sessions.session_query(query, session_id)
  60. logger.debug("[Minimax_AI] session query={}".format(session))
  61. model = context.get("Minimax_model")
  62. new_args = self.args.copy()
  63. if model:
  64. new_args["model"] = model
  65. # if context.get('stream'):
  66. # # reply in stream
  67. # return self.reply_text_stream(query, new_query, session_id)
  68. reply_content = self.reply_text(session, args=new_args)
  69. logger.debug(
  70. "[Minimax_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
  71. session.messages,
  72. session_id,
  73. reply_content["content"],
  74. reply_content["completion_tokens"],
  75. )
  76. )
  77. if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
  78. reply = Reply(ReplyType.ERROR, reply_content["content"])
  79. elif reply_content["completion_tokens"] > 0:
  80. self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
  81. reply = Reply(ReplyType.TEXT, reply_content["content"])
  82. else:
  83. reply = Reply(ReplyType.ERROR, reply_content["content"])
  84. logger.debug("[Minimax_AI] reply {} used 0 tokens.".format(reply_content))
  85. return reply
  86. else:
  87. reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
  88. return reply
  89. def reply_text(self, session: MinimaxSession, args=None, retry_count=0) -> dict:
  90. """
  91. call openai's ChatCompletion to get the answer
  92. :param session: a conversation session
  93. :param session_id: session id
  94. :param retry_count: retry count
  95. :return: {}
  96. """
  97. try:
  98. headers = {"Content-Type": "application/json", "Authorization": "Bearer " + self.api_key}
  99. self.request_body["messages"].extend(session.messages)
  100. logger.info("[Minimax_AI] request_body={}".format(self.request_body))
  101. # logger.info("[Minimax_AI] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
  102. res = requests.post(self.base_url, headers=headers, json=self.request_body)
  103. # self.request_body["messages"].extend(response.json()["choices"][0]["messages"])
  104. if res.status_code == 200:
  105. response = res.json()
  106. return {
  107. "total_tokens": response["usage"]["total_tokens"],
  108. "completion_tokens": response["usage"]["total_tokens"],
  109. "content": response["reply"],
  110. }
  111. else:
  112. response = res.json()
  113. error = response.get("error")
  114. logger.error(f"[Minimax_AI] chat failed, status_code={res.status_code}, " f"msg={error.get('message')}, type={error.get('type')}")
  115. result = {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
  116. need_retry = False
  117. if res.status_code >= 500:
  118. # server error, need retry
  119. logger.warn(f"[Minimax_AI] do retry, times={retry_count}")
  120. need_retry = retry_count < 2
  121. elif res.status_code == 401:
  122. result["content"] = "授权失败,请检查API Key是否正确"
  123. elif res.status_code == 429:
  124. result["content"] = "请求过于频繁,请稍后再试"
  125. need_retry = retry_count < 2
  126. else:
  127. need_retry = False
  128. if need_retry:
  129. time.sleep(3)
  130. return self.reply_text(session, args, retry_count + 1)
  131. else:
  132. return result
  133. except Exception as e:
  134. logger.exception(e)
  135. need_retry = retry_count < 2
  136. result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
  137. if need_retry:
  138. return self.reply_text(session, args, retry_count + 1)
  139. else:
  140. return result