Nevar pievienot vairāk kā 25 tēmas. Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.

96 rindas
4.3KB

  1. # encoding:utf-8
  2. from bot.bot import Bot
  3. from bot.openai.open_ai_image import OpenAIImage
  4. from bot.openai.open_ai_session import OpenAISession
  5. from bot.session_manager import SessionManager
  6. from bridge.context import ContextType
  7. from bridge.reply import Reply, ReplyType
  8. from config import conf
  9. from common.log import logger
  10. import openai
  11. import time
  12. user_session = dict()
  13. # OpenAI对话模型API (可用)
  14. class OpenAIBot(Bot, OpenAIImage):
  15. def __init__(self):
  16. super().__init__()
  17. openai.api_key = conf().get('open_ai_api_key')
  18. if conf().get('open_ai_api_base'):
  19. openai.api_base = conf().get('open_ai_api_base')
  20. proxy = conf().get('proxy')
  21. if proxy:
  22. openai.proxy = proxy
  23. self.sessions = SessionManager(OpenAISession, model= conf().get("model") or "text-davinci-003")
  24. def reply(self, query, context=None):
  25. # acquire reply content
  26. if context and context.type:
  27. if context.type == ContextType.TEXT:
  28. logger.info("[OPEN_AI] query={}".format(query))
  29. session_id = context['session_id']
  30. reply = None
  31. if query == '#清除记忆':
  32. self.sessions.clear_session(session_id)
  33. reply = Reply(ReplyType.INFO, '记忆已清除')
  34. elif query == '#清除所有':
  35. self.sessions.clear_all_session()
  36. reply = Reply(ReplyType.INFO, '所有人记忆已清除')
  37. else:
  38. session = self.sessions.session_query(query, session_id)
  39. new_query = str(session)
  40. logger.debug("[OPEN_AI] session query={}".format(new_query))
  41. total_tokens, completion_tokens, reply_content = self.reply_text(new_query, session_id, 0)
  42. logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(new_query, session_id, reply_content, completion_tokens))
  43. if total_tokens == 0 :
  44. reply = Reply(ReplyType.ERROR, reply_content)
  45. else:
  46. self.sessions.session_reply(reply_content, session_id, total_tokens)
  47. reply = Reply(ReplyType.TEXT, reply_content)
  48. return reply
  49. elif context.type == ContextType.IMAGE_CREATE:
  50. ok, retstring = self.create_img(query, 0)
  51. reply = None
  52. if ok:
  53. reply = Reply(ReplyType.IMAGE_URL, retstring)
  54. else:
  55. reply = Reply(ReplyType.ERROR, retstring)
  56. return reply
  57. def reply_text(self, query, user_id, retry_count=0):
  58. try:
  59. response = openai.Completion.create(
  60. model= conf().get("model") or "text-davinci-003", # 对话模型的名称
  61. prompt=query,
  62. temperature=0.9, # 值在[0,1]之间,越大表示回复越具有不确定性
  63. max_tokens=1200, # 回复最大的字符数
  64. top_p=1,
  65. frequency_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
  66. presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
  67. stop=["\n\n\n"]
  68. )
  69. res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
  70. total_tokens = response["usage"]["total_tokens"]
  71. completion_tokens = response["usage"]["completion_tokens"]
  72. logger.info("[OPEN_AI] reply={}".format(res_content))
  73. return total_tokens, completion_tokens, res_content
  74. except openai.error.RateLimitError as e:
  75. # rate limit exception
  76. logger.warn(e)
  77. if retry_count < 1:
  78. time.sleep(5)
  79. logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
  80. return self.reply_text(query, user_id, retry_count+1)
  81. else:
  82. return 0,0, "提问太快啦,请休息一下再问我吧"
  83. except Exception as e:
  84. # unknown exception
  85. logger.exception(e)
  86. self.sessions.clear_session(user_id)
  87. return 0,0, "请再问我一次吧"