Quellcode durchsuchen

feat: add support for azure dalle

master
lanvent vor 1 Jahr
Ursprung
Commit
3314b05648
2 geänderte Dateien mit 27 neuen und 1 gelöschten Zeilen
  1. +25
    -0
      bot/chatgpt/chat_gpt_bot.py
  2. +2
    -1
      bot/openai/open_ai_image.py

+ 25
- 0
bot/chatgpt/chat_gpt_bot.py Datei anzeigen

@@ -4,6 +4,7 @@ import time

import openai
import openai.error
import requests

from bot.bot import Bot
from bot.chatgpt.chat_gpt_session import ChatGPTSession
@@ -155,3 +156,27 @@ class AzureChatGPTBot(ChatGPTBot):
openai.api_type = "azure"
openai.api_version = "2023-03-15-preview"
self.args["deployment_id"] = conf().get("azure_deployment_id")

def create_img(self, query, retry_count=0, api_key=None):
api_base = "https://a-wxf.openai.azure.com/"
api_version = "2022-08-03-preview"
url = "{}dalle/text-to-image?api-version={}".format(api_base, api_version)
api_key = api_key or openai.api_key
headers = {"api-key": api_key, "Content-Type": "application/json"}
try:
body = {"caption": query, "resolution": conf().get("image_create_size", "256x256")}
submission = requests.post(url, headers=headers, json=body)
operation_location = submission.headers["Operation-Location"]
retry_after = submission.headers["Retry-after"]
status = ""
image_url = ""
while status != "Succeeded":
logger.info("waiting for image create..., " + status + ",retry after " + retry_after + " seconds")
time.sleep(int(retry_after))
response = requests.get(operation_location, headers=headers)
status = response.json()["status"]
image_url = response.json()["result"]["contentUrl"]
return True, image_url
except Exception as e:
logger.error("create image error: {}".format(e))
return False, "图片生成失败"

+ 2
- 1
bot/openai/open_ai_image.py Datei anzeigen

@@ -15,12 +15,13 @@ class OpenAIImage(object):
if conf().get("rate_limit_dalle"):
self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50))

def create_img(self, query, retry_count=0):
def create_img(self, query, retry_count=0, api_key=None):
try:
if conf().get("rate_limit_dalle") and not self.tb4dalle.get_token():
return False, "请求太快了,请休息一下再问我吧"
logger.info("[OPEN_AI] image_query={}".format(query))
response = openai.Image.create(
api_key=api_key,
prompt=query, # 图片描述
n=1, # 每次生成图片的数量
size=conf().get("image_create_size", "256x256"), # 图片大小,可选有 256x256, 512x512, 1024x1024


Laden…
Abbrechen
Speichern