You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

85 lines
2.9KB

  1. # encoding:utf-8
  2. from bot.bot import Bot
  3. from config import conf
  4. from common.log import logger
  5. import openai
  6. # OpenAI对话模型API (可用)
  7. class OpenAIBot(Bot):
  8. def __init__(self):
  9. openai.api_key = conf().get('open_ai_api_key')
  10. def reply(self, query, context=None):
  11. if not context or not context.get('type') or context.get('type') == 'TEXT':
  12. return self.reply_text(query)
  13. elif context.get('type', None) == 'IMAGE_CREATE':
  14. return self.create_img(query)
  15. def reply_text(self, query):
  16. logger.info("[OPEN_AI] query={}".format(query))
  17. try:
  18. response = openai.Completion.create(
  19. model="text-davinci-003", # 对话模型的名称
  20. prompt=query,
  21. temperature=0.9, # 值在[0,1]之间,越大表示回复越具有不确定性
  22. max_tokens=1200, # 回复最大的字符数
  23. top_p=1,
  24. frequency_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
  25. presence_penalty=0.6, # [-2,2]之间,该值越大则更倾向于产生不同的内容
  26. stop=["#"]
  27. )
  28. res_content = response.choices[0]["text"].strip()
  29. except Exception as e:
  30. logger.exception(e)
  31. return None
  32. logger.info("[OPEN_AI] reply={}".format(res_content))
  33. return res_content
  34. def create_img(self, query):
  35. try:
  36. logger.info("[OPEN_AI] image_query={}".format(query))
  37. response = openai.Image.create(
  38. prompt=query, #图片描述
  39. n=1, #每次生成图片的数量
  40. size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024
  41. )
  42. image_url = response['data'][0]['url']
  43. logger.info("[OPEN_AI] image_url={}".format(image_url))
  44. except Exception as e:
  45. logger.exception(e)
  46. return None
  47. return image_url
  48. def edit_img(self, query, src_img):
  49. try:
  50. response = openai.Image.create_edit(
  51. image=open(src_img, 'rb'),
  52. mask=open('cat-mask.png', 'rb'),
  53. prompt=query,
  54. n=1,
  55. size='512x512'
  56. )
  57. image_url = response['data'][0]['url']
  58. logger.info("[OPEN_AI] image_url={}".format(image_url))
  59. except Exception as e:
  60. logger.exception(e)
  61. return None
  62. return image_url
  63. def migration_img(self, query, src_img):
  64. try:
  65. response = openai.Image.create_variation(
  66. image=open(src_img, 'rb'),
  67. n=1,
  68. size="512x512"
  69. )
  70. image_url = response['data'][0]['url']
  71. logger.info("[OPEN_AI] image_url={}".format(image_url))
  72. except Exception as e:
  73. logger.exception(e)
  74. return None
  75. return image_url