From 721b36c7f7c041b1220e0ab43c1e7f9dedcc1c28 Mon Sep 17 00:00:00 2001 From: lanvent Date: Sun, 26 Mar 2023 20:08:04 +0800 Subject: [PATCH] refactor: reuse openai image interface --- bot/chatgpt/chat_gpt_bot.py | 29 ++--------------------------- bot/openai/open_ai_bot.py | 36 +++++++++--------------------------- bot/openai/open_ai_image.py | 37 +++++++++++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 54 deletions(-) create mode 100644 bot/openai/open_ai_image.py diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index 0db31d2..657aa6b 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -1,6 +1,7 @@ # encoding:utf-8 from bot.bot import Bot +from bot.openai.open_ai_image import OpenAIImage from bridge.context import ContextType from bridge.reply import Reply, ReplyType from config import conf, load_config @@ -12,7 +13,7 @@ import time # OpenAI对话模型API (可用) -class ChatGPTBot(Bot): +class ChatGPTBot(Bot,OpenAIImage): def __init__(self): openai.api_key = conf().get('open_ai_api_key') if conf().get('open_ai_api_base'): @@ -23,8 +24,6 @@ class ChatGPTBot(Bot): openai.proxy = proxy if conf().get('rate_limit_chatgpt'): self.tb4chatgpt = TokenBucket(conf().get('rate_limit_chatgpt', 20)) - if conf().get('rate_limit_dalle'): - self.tb4dalle = TokenBucket(conf().get('rate_limit_dalle', 50)) def reply(self, query, context=None): # acquire reply content @@ -128,30 +127,6 @@ class ChatGPTBot(Bot): self.sessions.clear_session(session_id) return {"completion_tokens": 0, "content": "请再问我一次吧"} - def create_img(self, query, retry_count=0): - try: - if conf().get('rate_limit_dalle') and not self.tb4dalle.get_token(): - return False, "请求太快了,请休息一下再问我吧" - logger.info("[OPEN_AI] image_query={}".format(query)) - response = openai.Image.create( - prompt=query, #图片描述 - n=1, #每次生成图片的数量 - size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024 - ) - image_url = response['data'][0]['url'] - logger.info("[OPEN_AI] image_url={}".format(image_url)) - return True, image_url - except openai.error.RateLimitError as e: - logger.warn(e) - if retry_count < 1: - time.sleep(5) - logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1)) - return self.create_img(query, retry_count+1) - else: - return False, "提问太快啦,请休息一下再问我吧" - except Exception as e: - logger.exception(e) - return False, str(e) class AzureChatGPTBot(ChatGPTBot): diff --git a/bot/openai/open_ai_bot.py b/bot/openai/open_ai_bot.py index e38af15..553b20a 100644 --- a/bot/openai/open_ai_bot.py +++ b/bot/openai/open_ai_bot.py @@ -1,6 +1,7 @@ # encoding:utf-8 from bot.bot import Bot +from bot.openai.open_ai_image import OpenAIImage from bridge.context import ContextType from bridge.reply import Reply, ReplyType from config import conf @@ -11,7 +12,7 @@ import time user_session = dict() # OpenAI对话模型API (可用) -class OpenAIBot(Bot): +class OpenAIBot(Bot, OpenAIImage): def __init__(self): openai.api_key = conf().get('open_ai_api_key') if conf().get('open_ai_api_base'): @@ -45,7 +46,13 @@ class OpenAIBot(Bot): reply = Reply(ReplyType.TEXT, reply_content) return reply elif context.type == ContextType.IMAGE_CREATE: - return self.create_img(query, 0) + ok, retstring = self.create_img(query, 0) + reply = None + if ok: + reply = Reply(ReplyType.IMAGE_URL, retstring) + else: + reply = Reply(ReplyType.ERROR, retstring) + return reply def reply_text(self, query, user_id, retry_count=0): try: @@ -77,31 +84,6 @@ class OpenAIBot(Bot): Session.clear_session(user_id) return "请再问我一次吧" - - def create_img(self, query, retry_count=0): - try: - logger.info("[OPEN_AI] image_query={}".format(query)) - response = openai.Image.create( - prompt=query, #图片描述 - n=1, #每次生成图片的数量 - size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024 - ) - image_url = response['data'][0]['url'] - logger.info("[OPEN_AI] image_url={}".format(image_url)) - return image_url - except openai.error.RateLimitError as e: - logger.warn(e) - if retry_count < 1: - time.sleep(5) - logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1)) - return self.reply_text(query, retry_count+1) - else: - return "提问太快啦,请休息一下再问我吧" - except Exception as e: - logger.exception(e) - return None - - class Session(object): @staticmethod def build_session_query(query, user_id): diff --git a/bot/openai/open_ai_image.py b/bot/openai/open_ai_image.py new file mode 100644 index 0000000..4fa02de --- /dev/null +++ b/bot/openai/open_ai_image.py @@ -0,0 +1,37 @@ +import time +import openai +from common.token_bucket import TokenBucket +from common.log import logger +from config import conf + +# OPENAI提供的画图接口 +class OpenAIImage(object): + def __init__(self): + openai.api_key = conf().get('open_ai_api_key') + if conf().get('rate_limit_dalle'): + self.tb4dalle = TokenBucket(conf().get('rate_limit_dalle', 50)) + + def create_img(self, query, retry_count=0): + try: + if conf().get('rate_limit_dalle') and not self.tb4dalle.get_token(): + return False, "请求太快了,请休息一下再问我吧" + logger.info("[OPEN_AI] image_query={}".format(query)) + response = openai.Image.create( + prompt=query, #图片描述 + n=1, #每次生成图片的数量 + size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024 + ) + image_url = response['data'][0]['url'] + logger.info("[OPEN_AI] image_url={}".format(image_url)) + return True, image_url + except openai.error.RateLimitError as e: + logger.warn(e) + if retry_count < 1: + time.sleep(5) + logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1)) + return self.create_img(query, retry_count+1) + else: + return False, "提问太快啦,请休息一下再问我吧" + except Exception as e: + logger.exception(e) + return False, str(e) \ No newline at end of file