您最多可以选择 25 个主题。主题必须以字母或数字开头，可以包含连字符（-），并且长度不得超过 35 个字符。

open_ai_bot.py 5.4KB

Lines 1–113
  1. # encoding:utf-8
  2. from bot.bot import Bot
  3. from bot.openai.open_ai_image import OpenAIImage
  4. from bot.openai.open_ai_session import OpenAISession
  5. from bot.session_manager import SessionManager
  6. from bridge.context import ContextType
  7. from bridge.reply import Reply, ReplyType
  8. from config import conf
  9. from common.log import logger
  10. import openai
  11. import openai.error
  12. import time
  13. user_session = dict()
  14. # OpenAI对话模型API (可用)
  15. class OpenAIBot(Bot, OpenAIImage):
  16. def __init__(self):
  17. super().__init__()
  18. openai.api_key = conf().get('open_ai_api_key')
  19. if conf().get('open_ai_api_base'):
  20. openai.api_base = conf().get('open_ai_api_base')
  21. proxy = conf().get('proxy')
  22. if proxy:
  23. openai.proxy = proxy
  24. self.sessions = SessionManager(OpenAISession, model= conf().get("model") or "text-davinci-003")
  25. self.args = {
  26. "model": conf().get("model") or "text-davinci-003", # 对话模型的名称
  27. "temperature":conf().get('temperature', 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
  28. "max_tokens":1200, # 回复最大的字符数
  29. "top_p":1,
  30. "frequency_penalty":conf().get('frequency_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
  31. "presence_penalty":conf().get('presence_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
  32. "request_timeout": conf().get('request_timeout', None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
  33. "timeout": conf().get('request_timeout', None), #重试超时时间,在这个时间内,将会自动重试
  34. "stop":["\n\n\n"]
  35. }
  36. def reply(self, query, context=None):
  37. # acquire reply content
  38. if context and context.type:
  39. if context.type == ContextType.TEXT:
  40. logger.info("[OPEN_AI] query={}".format(query))
  41. session_id = context['session_id']
  42. reply = None
  43. if query == '#清除记忆':
  44. self.sessions.clear_session(session_id)
  45. reply = Reply(ReplyType.INFO, '记忆已清除')
  46. elif query == '#清除所有':
  47. self.sessions.clear_all_session()
  48. reply = Reply(ReplyType.INFO, '所有人记忆已清除')
  49. else:
  50. session = self.sessions.session_query(query, session_id)
  51. result = self.reply_text(session)
  52. total_tokens, completion_tokens, reply_content = result['total_tokens'], result['completion_tokens'], result['content']
  53. logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens))
  54. if total_tokens == 0 :
  55. reply = Reply(ReplyType.ERROR, reply_content)
  56. else:
  57. self.sessions.session_reply(reply_content, session_id, total_tokens)
  58. reply = Reply(ReplyType.TEXT, reply_content)
  59. return reply
  60. elif context.type == ContextType.IMAGE_CREATE:
  61. ok, retstring = self.create_img(query, 0)
  62. reply = None
  63. if ok:
  64. reply = Reply(ReplyType.IMAGE_URL, retstring)
  65. else:
  66. reply = Reply(ReplyType.ERROR, retstring)
  67. return reply
  68. def reply_text(self, session:OpenAISession, retry_count=0):
  69. try:
  70. response = openai.Completion.create(
  71. prompt=str(session), **self.args
  72. )
  73. res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
  74. total_tokens = response["usage"]["total_tokens"]
  75. completion_tokens = response["usage"]["completion_tokens"]
  76. logger.info("[OPEN_AI] reply={}".format(res_content))
  77. return {"total_tokens": total_tokens,
  78. "completion_tokens": completion_tokens,
  79. "content": res_content}
  80. except Exception as e:
  81. need_retry = retry_count < 2
  82. result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
  83. if isinstance(e, openai.error.RateLimitError):
  84. logger.warn("[OPEN_AI] RateLimitError: {}".format(e))
  85. result['content'] = "提问太快啦,请休息一下再问我吧"
  86. if need_retry:
  87. time.sleep(5)
  88. elif isinstance(e, openai.error.Timeout):
  89. logger.warn("[OPEN_AI] Timeout: {}".format(e))
  90. result['content'] = "我没有收到你的消息"
  91. if need_retry:
  92. time.sleep(5)
  93. elif isinstance(e, openai.error.APIConnectionError):
  94. logger.warn("[OPEN_AI] APIConnectionError: {}".format(e))
  95. need_retry = False
  96. result['content'] = "我连接不到你的网络"
  97. else:
  98. logger.warn("[OPEN_AI] Exception: {}".format(e))
  99. need_retry = False
  100. self.sessions.clear_session(session.session_id)
  101. if need_retry:
  102. logger.warn("[OPEN_AI] 第{}次重试".format(retry_count+1))
  103. return self.reply_text(session, retry_count+1)
  104. else:
  105. return result