You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

176 satır
6.6KB

  1. # encoding:utf-8
  2. from bot.bot import Bot
  3. from bridge.context import ContextType
  4. from bridge.reply import Reply, ReplyType
  5. from config import conf
  6. from common.log import logger
  7. import openai
  8. import time
  9. user_session = dict()
  10. # OpenAI对话模型API (可用)
  11. class OpenAIBot(Bot):
  12. def __init__(self):
  13. openai.api_key = conf().get('open_ai_api_key')
  14. if conf().get('open_ai_api_base'):
  15. openai.api_base = conf().get('open_ai_api_base')
  16. proxy = conf().get('proxy')
  17. if proxy:
  18. openai.proxy = proxy
  19. def reply(self, query, context=None):
  20. # acquire reply content
  21. if context and context.type:
  22. if context.type == ContextType.TEXT:
  23. logger.info("[OPEN_AI] query={}".format(query))
  24. from_user_id = context['session_id']
  25. reply = None
  26. if query == '#清除记忆':
  27. Session.clear_session(from_user_id)
  28. reply = Reply(ReplyType.INFO, '记忆已清除')
  29. elif query == '#清除所有':
  30. Session.clear_all_session()
  31. reply = Reply(ReplyType.INFO, '所有人记忆已清除')
  32. else:
  33. new_query = Session.build_session_query(query, from_user_id)
  34. logger.debug("[OPEN_AI] session query={}".format(new_query))
  35. reply_content = self.reply_text(new_query, from_user_id, 0)
  36. logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
  37. if reply_content and query:
  38. Session.save_session(query, reply_content, from_user_id)
  39. reply = Reply(ReplyType.TEXT, reply_content)
  40. return reply
  41. elif context.type == ContextType.IMAGE_CREATE:
  42. return self.create_img(query, 0)
  43. def reply_text(self, query, user_id, retry_count=0):
  44. try:
  45. response = openai.Completion.create(
  46. model= conf().get("model") or "text-davinci-003", # 对话模型的名称
  47. prompt=query,
  48. temperature=0.9, # 值在[0,1]之间,越大表示回复越具有不确定性
  49. max_tokens=1200, # 回复最大的字符数
  50. top_p=1,
  51. frequency_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
  52. presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
  53. stop=["\n\n\n"]
  54. )
  55. res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
  56. logger.info("[OPEN_AI] reply={}".format(res_content))
  57. return res_content
  58. except openai.error.RateLimitError as e:
  59. # rate limit exception
  60. logger.warn(e)
  61. if retry_count < 1:
  62. time.sleep(5)
  63. logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
  64. return self.reply_text(query, user_id, retry_count+1)
  65. else:
  66. return "提问太快啦,请休息一下再问我吧"
  67. except Exception as e:
  68. # unknown exception
  69. logger.exception(e)
  70. Session.clear_session(user_id)
  71. return "请再问我一次吧"
  72. def create_img(self, query, retry_count=0):
  73. try:
  74. logger.info("[OPEN_AI] image_query={}".format(query))
  75. response = openai.Image.create(
  76. prompt=query, #图片描述
  77. n=1, #每次生成图片的数量
  78. size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024
  79. )
  80. image_url = response['data'][0]['url']
  81. logger.info("[OPEN_AI] image_url={}".format(image_url))
  82. return image_url
  83. except openai.error.RateLimitError as e:
  84. logger.warn(e)
  85. if retry_count < 1:
  86. time.sleep(5)
  87. logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
  88. return self.reply_text(query, retry_count+1)
  89. else:
  90. return "提问太快啦,请休息一下再问我吧"
  91. except Exception as e:
  92. logger.exception(e)
  93. return None
  94. class Session(object):
  95. @staticmethod
  96. def build_session_query(query, user_id):
  97. '''
  98. build query with conversation history
  99. e.g. Q: xxx
  100. A: xxx
  101. Q: xxx
  102. :param query: query content
  103. :param user_id: from user id
  104. :return: query content with conversaction
  105. '''
  106. prompt = conf().get("character_desc", "")
  107. if prompt:
  108. prompt += "<|endoftext|>\n\n\n"
  109. session = user_session.get(user_id, None)
  110. if session:
  111. for conversation in session:
  112. prompt += "Q: " + conversation["question"] + "\n\n\nA: " + conversation["answer"] + "<|endoftext|>\n"
  113. prompt += "Q: " + query + "\nA: "
  114. return prompt
  115. else:
  116. return prompt + "Q: " + query + "\nA: "
  117. @staticmethod
  118. def save_session(query, answer, user_id):
  119. max_tokens = conf().get("conversation_max_tokens")
  120. if not max_tokens:
  121. # default 3000
  122. max_tokens = 1000
  123. conversation = dict()
  124. conversation["question"] = query
  125. conversation["answer"] = answer
  126. session = user_session.get(user_id)
  127. logger.debug(conversation)
  128. logger.debug(session)
  129. if session:
  130. # append conversation
  131. session.append(conversation)
  132. else:
  133. # create session
  134. queue = list()
  135. queue.append(conversation)
  136. user_session[user_id] = queue
  137. # discard exceed limit conversation
  138. Session.discard_exceed_conversation(user_session[user_id], max_tokens)
  139. @staticmethod
  140. def discard_exceed_conversation(session, max_tokens):
  141. count = 0
  142. count_list = list()
  143. for i in range(len(session)-1, -1, -1):
  144. # count tokens of conversation list
  145. history_conv = session[i]
  146. count += len(history_conv["question"]) + len(history_conv["answer"])
  147. count_list.append(count)
  148. for c in count_list:
  149. if c > max_tokens:
  150. # pop first conversation
  151. session.pop(0)
  152. @staticmethod
  153. def clear_session(user_id):
  154. user_session[user_id] = []
  155. @staticmethod
  156. def clear_all_session():
  157. user_session.clear()