|
|
@@ -171,24 +171,48 @@ class AzureChatGPTBot(ChatGPTBot): |
|
|
|
self.args["deployment_id"] = conf().get("azure_deployment_id") |
|
|
|
|
|
|
|
def create_img(self, query, retry_count=0, api_key=None): |
|
|
|
api_version = "2022-08-03-preview" |
|
|
|
url = "{}dalle/text-to-image?api-version={}".format(openai.api_base, api_version) |
|
|
|
api_key = api_key or openai.api_key |
|
|
|
headers = {"api-key": api_key, "Content-Type": "application/json"} |
|
|
|
try: |
|
|
|
body = {"caption": query, "resolution": conf().get("image_create_size", "256x256")} |
|
|
|
submission = requests.post(url, headers=headers, json=body) |
|
|
|
operation_location = submission.headers["Operation-Location"] |
|
|
|
retry_after = submission.headers["Retry-after"] |
|
|
|
status = "" |
|
|
|
image_url = "" |
|
|
|
while status != "Succeeded": |
|
|
|
logger.info("waiting for image create..., " + status + ",retry after " + retry_after + " seconds") |
|
|
|
time.sleep(int(retry_after)) |
|
|
|
response = requests.get(operation_location, headers=headers) |
|
|
|
status = response.json()["status"] |
|
|
|
image_url = response.json()["result"]["contentUrl"] |
|
|
|
return True, image_url |
|
|
|
except Exception as e: |
|
|
|
logger.error("create image error: {}".format(e)) |
|
|
|
return False, "图片生成失败" |
|
|
|
text_to_image_model = conf().get("text_to_image") |
|
|
|
if text_to_image_model == "dall-e-2": |
|
|
|
api_version = "2023-06-01-preview" |
|
|
|
endpoint = conf().get("azure_openai_dalle_api_base","open_ai_api_base") |
|
|
|
# 检查endpoint是否以/结尾 |
|
|
|
if not endpoint.endswith("/"): |
|
|
|
endpoint = endpoint + "/" |
|
|
|
url = "{}openai/images/generations:submit?api-version={}".format(endpoint, api_version) |
|
|
|
api_key = conf().get("azure_openai_dalle_api_key","open_ai_api_key") |
|
|
|
headers = {"api-key": api_key, "Content-Type": "application/json"} |
|
|
|
try: |
|
|
|
body = {"prompt": query, "size": conf().get("image_create_size", "256x256"),"n": 1} |
|
|
|
submission = requests.post(url, headers=headers, json=body) |
|
|
|
operation_location = submission.headers['operation-location'] |
|
|
|
status = "" |
|
|
|
while (status != "succeeded"): |
|
|
|
if retry_count > 3: |
|
|
|
return False, "图片生成失败" |
|
|
|
response = requests.get(operation_location, headers=headers) |
|
|
|
status = response.json()['status'] |
|
|
|
retry_count += 1 |
|
|
|
image_url = response.json()['result']['data'][0]['url'] |
|
|
|
return True, image_url |
|
|
|
except Exception as e: |
|
|
|
logger.error("create image error: {}".format(e)) |
|
|
|
return False, "图片生成失败" |
|
|
|
elif text_to_image_model == "dall-e-3": |
|
|
|
api_version = conf().get("azure_api_version", "2024-02-15-preview") |
|
|
|
endpoint = conf().get("azure_openai_dalle_api_base","open_ai_api_base") |
|
|
|
# 检查endpoint是否以/结尾 |
|
|
|
if not endpoint.endswith("/"): |
|
|
|
endpoint = endpoint + "/" |
|
|
|
url = "{}openai/deployments/{}/images/generations?api-version={}".format(endpoint, conf().get("azure_openai_dalle_deployment_id","text_to_image"),api_version) |
|
|
|
api_key = conf().get("azure_openai_dalle_api_key","open_ai_api_key") |
|
|
|
headers = {"api-key": api_key, "Content-Type": "application/json"} |
|
|
|
try: |
|
|
|
body = {"prompt": query, "size": conf().get("image_create_size", "1024x1024"), "quality": conf().get("dalle3_image_quality", "standard")} |
|
|
|
submission = requests.post(url, headers=headers, json=body) |
|
|
|
image_url = submission.json()['data'][0]['url'] |
|
|
|
return True, image_url |
|
|
|
except Exception as e: |
|
|
|
logger.error("create image error: {}".format(e)) |
|
|
|
return False, "图片生成失败" |
|
|
|
else: |
|
|
|
return False, "图片生成失败,未配置text_to_image参数" |