From 3314b05648dda09b12dbb4b914fd6c3b46a0fc34 Mon Sep 17 00:00:00 2001 From: lanvent Date: Thu, 27 Apr 2023 22:16:42 +0800 Subject: [PATCH] feat: add support for azure dalle --- bot/chatgpt/chat_gpt_bot.py | 25 +++++++++++++++++++++++++ bot/openai/open_ai_image.py | 3 ++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index b045311..f7c1d65 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -4,6 +4,7 @@ import time import openai import openai.error +import requests from bot.bot import Bot from bot.chatgpt.chat_gpt_session import ChatGPTSession @@ -155,3 +156,27 @@ class AzureChatGPTBot(ChatGPTBot): openai.api_type = "azure" openai.api_version = "2023-03-15-preview" self.args["deployment_id"] = conf().get("azure_deployment_id") + + def create_img(self, query, retry_count=0, api_key=None): + api_base = "https://a-wxf.openai.azure.com/" + api_version = "2022-08-03-preview" + url = "{}dalle/text-to-image?api-version={}".format(api_base, api_version) + api_key = api_key or openai.api_key + headers = {"api-key": api_key, "Content-Type": "application/json"} + try: + body = {"caption": query, "resolution": conf().get("image_create_size", "256x256")} + submission = requests.post(url, headers=headers, json=body) + operation_location = submission.headers["Operation-Location"] + retry_after = submission.headers["Retry-after"] + status = "" + image_url = "" + while status != "Succeeded": + logger.info("waiting for image create..., " + status + ",retry after " + retry_after + " seconds") + time.sleep(int(retry_after)) + response = requests.get(operation_location, headers=headers) + status = response.json()["status"] + image_url = response.json()["result"]["contentUrl"] + return True, image_url + except Exception as e: + logger.error("create image error: {}".format(e)) + return False, "图片生成失败" diff --git a/bot/openai/open_ai_image.py b/bot/openai/open_ai_image.py index b188557..89449a2 100644 --- a/bot/openai/open_ai_image.py +++ b/bot/openai/open_ai_image.py @@ -15,12 +15,13 @@ class OpenAIImage(object): if conf().get("rate_limit_dalle"): self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50)) - def create_img(self, query, retry_count=0): + def create_img(self, query, retry_count=0, api_key=None): try: if conf().get("rate_limit_dalle") and not self.tb4dalle.get_token(): return False, "请求太快了,请休息一下再问我吧" logger.info("[OPEN_AI] image_query={}".format(query)) response = openai.Image.create( + api_key=api_key, prompt=query, # 图片描述 n=1, # 每次生成图片的数量 size=conf().get("image_create_size", "256x256"), # 图片大小,可选有 256x256, 512x512, 1024x1024