You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

144 lines
6.2KB

  1. # encoding:utf-8
  2. import time
  3. import openai
  4. import openai.error
  5. from bot.bot import Bot
  6. from bot.session_manager import SessionManager
  7. from bridge.context import ContextType
  8. from bridge.reply import Reply, ReplyType
  9. from common.log import logger
  10. from config import conf, load_config
  11. from .moonshot_session import MoonshotSession
  12. import requests
  13. # ZhipuAI对话模型API
  14. class MoonshotBot(Bot):
  15. def __init__(self):
  16. super().__init__()
  17. self.sessions = SessionManager(MoonshotSession, model=conf().get("model") or "moonshot-v1-128k")
  18. self.args = {
  19. "model": conf().get("model") or "moonshot-v1-128k", # 对话模型的名称
  20. "temperature": conf().get("temperature", 0.3), # 如果设置,值域须为 [0, 1] 我们推荐 0.3,以达到较合适的效果。
  21. "top_p": conf().get("top_p", 1.0), # 使用默认值
  22. }
  23. self.api_key = conf().get("moonshot_api_key")
  24. self.base_url = conf().get("moonshot_base_url", "https://api.moonshot.cn/v1/chat/completions")
  25. def reply(self, query, context=None):
  26. # acquire reply content
  27. if context.type == ContextType.TEXT:
  28. logger.info("[MOONSHOT_AI] query={}".format(query))
  29. session_id = context["session_id"]
  30. reply = None
  31. clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
  32. if query in clear_memory_commands:
  33. self.sessions.clear_session(session_id)
  34. reply = Reply(ReplyType.INFO, "记忆已清除")
  35. elif query == "#清除所有":
  36. self.sessions.clear_all_session()
  37. reply = Reply(ReplyType.INFO, "所有人记忆已清除")
  38. elif query == "#更新配置":
  39. load_config()
  40. reply = Reply(ReplyType.INFO, "配置已更新")
  41. if reply:
  42. return reply
  43. session = self.sessions.session_query(query, session_id)
  44. logger.debug("[MOONSHOT_AI] session query={}".format(session.messages))
  45. model = context.get("moonshot_model")
  46. new_args = self.args.copy()
  47. if model:
  48. new_args["model"] = model
  49. # if context.get('stream'):
  50. # # reply in stream
  51. # return self.reply_text_stream(query, new_query, session_id)
  52. reply_content = self.reply_text(session, args=new_args)
  53. logger.debug(
  54. "[MOONSHOT_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
  55. session.messages,
  56. session_id,
  57. reply_content["content"],
  58. reply_content["completion_tokens"],
  59. )
  60. )
  61. if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
  62. reply = Reply(ReplyType.ERROR, reply_content["content"])
  63. elif reply_content["completion_tokens"] > 0:
  64. self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
  65. reply = Reply(ReplyType.TEXT, reply_content["content"])
  66. else:
  67. reply = Reply(ReplyType.ERROR, reply_content["content"])
  68. logger.debug("[MOONSHOT_AI] reply {} used 0 tokens.".format(reply_content))
  69. return reply
  70. else:
  71. reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
  72. return reply
  73. def reply_text(self, session: MoonshotSession, args=None, retry_count=0) -> dict:
  74. """
  75. call openai's ChatCompletion to get the answer
  76. :param session: a conversation session
  77. :param session_id: session id
  78. :param retry_count: retry count
  79. :return: {}
  80. """
  81. try:
  82. headers = {
  83. "Content-Type": "application/json",
  84. "Authorization": "Bearer " + self.api_key
  85. }
  86. body = args
  87. body["messages"] = session.messages
  88. # logger.debug("[MOONSHOT_AI] response={}".format(response))
  89. # logger.info("[MOONSHOT_AI] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
  90. res = requests.post(
  91. self.base_url,
  92. headers=headers,
  93. json=body
  94. )
  95. if res.status_code == 200:
  96. response = res.json()
  97. return {
  98. "total_tokens": response["usage"]["total_tokens"],
  99. "completion_tokens": response["usage"]["completion_tokens"],
  100. "content": response["choices"][0]["message"]["content"]
  101. }
  102. else:
  103. response = res.json()
  104. error = response.get("error")
  105. logger.error(f"[MOONSHOT_AI] chat failed, status_code={res.status_code}, "
  106. f"msg={error.get('message')}, type={error.get('type')}")
  107. result = {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
  108. need_retry = False
  109. if res.status_code >= 500:
  110. # server error, need retry
  111. logger.warn(f"[MOONSHOT_AI] do retry, times={retry_count}")
  112. need_retry = retry_count < 2
  113. elif res.status_code == 401:
  114. result["content"] = "授权失败,请检查API Key是否正确"
  115. elif res.status_code == 429:
  116. result["content"] = "请求过于频繁,请稍后再试"
  117. need_retry = retry_count < 2
  118. else:
  119. need_retry = False
  120. if need_retry:
  121. time.sleep(3)
  122. return self.reply_text(session, args, retry_count + 1)
  123. else:
  124. return result
  125. except Exception as e:
  126. logger.exception(e)
  127. need_retry = retry_count < 2
  128. result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
  129. if need_retry:
  130. return self.reply_text(session, args, retry_count + 1)
  131. else:
  132. return result