You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

open_ai_bot.py 5.9KB

2 yıl önce
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. # encoding:utf-8
  2. from bot.bot import Bot
  3. from bot.openai.open_ai_image import OpenAIImage
  4. from bridge.context import ContextType
  5. from bridge.reply import Reply, ReplyType
  6. from config import conf
  7. from common.log import logger
  8. import openai
  9. import time
  10. user_session = dict()
  11. # OpenAI对话模型API (可用)
  12. class OpenAIBot(Bot, OpenAIImage):
  13. def __init__(self):
  14. openai.api_key = conf().get('open_ai_api_key')
  15. if conf().get('open_ai_api_base'):
  16. openai.api_base = conf().get('open_ai_api_base')
  17. proxy = conf().get('proxy')
  18. if proxy:
  19. openai.proxy = proxy
  20. def reply(self, query, context=None):
  21. # acquire reply content
  22. if context and context.type:
  23. if context.type == ContextType.TEXT:
  24. logger.info("[OPEN_AI] query={}".format(query))
  25. from_user_id = context['session_id']
  26. reply = None
  27. if query == '#清除记忆':
  28. Session.clear_session(from_user_id)
  29. reply = Reply(ReplyType.INFO, '记忆已清除')
  30. elif query == '#清除所有':
  31. Session.clear_all_session()
  32. reply = Reply(ReplyType.INFO, '所有人记忆已清除')
  33. else:
  34. new_query = Session.build_session_query(query, from_user_id)
  35. logger.debug("[OPEN_AI] session query={}".format(new_query))
  36. reply_content = self.reply_text(new_query, from_user_id, 0)
  37. logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
  38. if reply_content and query:
  39. Session.save_session(query, reply_content, from_user_id)
  40. reply = Reply(ReplyType.TEXT, reply_content)
  41. return reply
  42. elif context.type == ContextType.IMAGE_CREATE:
  43. ok, retstring = self.create_img(query, 0)
  44. reply = None
  45. if ok:
  46. reply = Reply(ReplyType.IMAGE_URL, retstring)
  47. else:
  48. reply = Reply(ReplyType.ERROR, retstring)
  49. return reply
  50. def reply_text(self, query, user_id, retry_count=0):
  51. try:
  52. response = openai.Completion.create(
  53. model= conf().get("model") or "text-davinci-003", # 对话模型的名称
  54. prompt=query,
  55. temperature=0.9, # 值在[0,1]之间,越大表示回复越具有不确定性
  56. max_tokens=1200, # 回复最大的字符数
  57. top_p=1,
  58. frequency_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
  59. presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
  60. stop=["\n\n\n"]
  61. )
  62. res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
  63. logger.info("[OPEN_AI] reply={}".format(res_content))
  64. return res_content
  65. except openai.error.RateLimitError as e:
  66. # rate limit exception
  67. logger.warn(e)
  68. if retry_count < 1:
  69. time.sleep(5)
  70. logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
  71. return self.reply_text(query, user_id, retry_count+1)
  72. else:
  73. return "提问太快啦,请休息一下再问我吧"
  74. except Exception as e:
  75. # unknown exception
  76. logger.exception(e)
  77. Session.clear_session(user_id)
  78. return "请再问我一次吧"
  79. class Session(object):
  80. @staticmethod
  81. def build_session_query(query, user_id):
  82. '''
  83. build query with conversation history
  84. e.g. Q: xxx
  85. A: xxx
  86. Q: xxx
  87. :param query: query content
  88. :param user_id: from user id
  89. :return: query content with conversaction
  90. '''
  91. prompt = conf().get("character_desc", "")
  92. if prompt:
  93. prompt += "<|endoftext|>\n\n\n"
  94. session = user_session.get(user_id, None)
  95. if session:
  96. for conversation in session:
  97. prompt += "Q: " + conversation["question"] + "\n\n\nA: " + conversation["answer"] + "<|endoftext|>\n"
  98. prompt += "Q: " + query + "\nA: "
  99. return prompt
  100. else:
  101. return prompt + "Q: " + query + "\nA: "
  102. @staticmethod
  103. def save_session(query, answer, user_id):
  104. max_tokens = conf().get("conversation_max_tokens")
  105. if not max_tokens:
  106. # default 3000
  107. max_tokens = 1000
  108. conversation = dict()
  109. conversation["question"] = query
  110. conversation["answer"] = answer
  111. session = user_session.get(user_id)
  112. logger.debug(conversation)
  113. logger.debug(session)
  114. if session:
  115. # append conversation
  116. session.append(conversation)
  117. else:
  118. # create session
  119. queue = list()
  120. queue.append(conversation)
  121. user_session[user_id] = queue
  122. # discard exceed limit conversation
  123. Session.discard_exceed_conversation(user_session[user_id], max_tokens)
  124. @staticmethod
  125. def discard_exceed_conversation(session, max_tokens):
  126. count = 0
  127. count_list = list()
  128. for i in range(len(session)-1, -1, -1):
  129. # count tokens of conversation list
  130. history_conv = session[i]
  131. count += len(history_conv["question"]) + len(history_conv["answer"])
  132. count_list.append(count)
  133. for c in count_list:
  134. if c > max_tokens:
  135. # pop first conversation
  136. session.pop(0)
  137. @staticmethod
  138. def clear_session(user_id):
  139. user_session[user_id] = []
  140. @staticmethod
  141. def clear_all_session():
  142. user_session.clear()