@@ -1,6 +1,7 @@ | |||
# encoding:utf-8 | |||
from bot.bot import Bot | |||
from bot.openai.open_ai_image import OpenAIImage | |||
from bridge.context import ContextType | |||
from bridge.reply import Reply, ReplyType | |||
from config import conf, load_config | |||
@@ -12,7 +13,7 @@ import time | |||
# OpenAI对话模型API (可用) | |||
class ChatGPTBot(Bot): | |||
class ChatGPTBot(Bot,OpenAIImage): | |||
def __init__(self): | |||
openai.api_key = conf().get('open_ai_api_key') | |||
if conf().get('open_ai_api_base'): | |||
@@ -23,8 +24,6 @@ class ChatGPTBot(Bot): | |||
openai.proxy = proxy | |||
if conf().get('rate_limit_chatgpt'): | |||
self.tb4chatgpt = TokenBucket(conf().get('rate_limit_chatgpt', 20)) | |||
if conf().get('rate_limit_dalle'): | |||
self.tb4dalle = TokenBucket(conf().get('rate_limit_dalle', 50)) | |||
def reply(self, query, context=None): | |||
# acquire reply content | |||
@@ -128,30 +127,6 @@ class ChatGPTBot(Bot): | |||
self.sessions.clear_session(session_id) | |||
return {"completion_tokens": 0, "content": "请再问我一次吧"} | |||
def create_img(self, query, retry_count=0): | |||
try: | |||
if conf().get('rate_limit_dalle') and not self.tb4dalle.get_token(): | |||
return False, "请求太快了,请休息一下再问我吧" | |||
logger.info("[OPEN_AI] image_query={}".format(query)) | |||
response = openai.Image.create( | |||
prompt=query, #图片描述 | |||
n=1, #每次生成图片的数量 | |||
size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024 | |||
) | |||
image_url = response['data'][0]['url'] | |||
logger.info("[OPEN_AI] image_url={}".format(image_url)) | |||
return True, image_url | |||
except openai.error.RateLimitError as e: | |||
logger.warn(e) | |||
if retry_count < 1: | |||
time.sleep(5) | |||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1)) | |||
return self.create_img(query, retry_count+1) | |||
else: | |||
return False, "提问太快啦,请休息一下再问我吧" | |||
except Exception as e: | |||
logger.exception(e) | |||
return False, str(e) | |||
class AzureChatGPTBot(ChatGPTBot): | |||
@@ -1,6 +1,7 @@ | |||
# encoding:utf-8 | |||
from bot.bot import Bot | |||
from bot.openai.open_ai_image import OpenAIImage | |||
from bridge.context import ContextType | |||
from bridge.reply import Reply, ReplyType | |||
from config import conf | |||
@@ -11,7 +12,7 @@ import time | |||
user_session = dict() | |||
# OpenAI对话模型API (可用) | |||
class OpenAIBot(Bot): | |||
class OpenAIBot(Bot, OpenAIImage): | |||
def __init__(self): | |||
openai.api_key = conf().get('open_ai_api_key') | |||
if conf().get('open_ai_api_base'): | |||
@@ -45,7 +46,13 @@ class OpenAIBot(Bot): | |||
reply = Reply(ReplyType.TEXT, reply_content) | |||
return reply | |||
elif context.type == ContextType.IMAGE_CREATE: | |||
return self.create_img(query, 0) | |||
ok, retstring = self.create_img(query, 0) | |||
reply = None | |||
if ok: | |||
reply = Reply(ReplyType.IMAGE_URL, retstring) | |||
else: | |||
reply = Reply(ReplyType.ERROR, retstring) | |||
return reply | |||
def reply_text(self, query, user_id, retry_count=0): | |||
try: | |||
@@ -77,31 +84,6 @@ class OpenAIBot(Bot): | |||
Session.clear_session(user_id) | |||
return "请再问我一次吧" | |||
def create_img(self, query, retry_count=0): | |||
try: | |||
logger.info("[OPEN_AI] image_query={}".format(query)) | |||
response = openai.Image.create( | |||
prompt=query, #图片描述 | |||
n=1, #每次生成图片的数量 | |||
size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024 | |||
) | |||
image_url = response['data'][0]['url'] | |||
logger.info("[OPEN_AI] image_url={}".format(image_url)) | |||
return image_url | |||
except openai.error.RateLimitError as e: | |||
logger.warn(e) | |||
if retry_count < 1: | |||
time.sleep(5) | |||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1)) | |||
return self.reply_text(query, retry_count+1) | |||
else: | |||
return "提问太快啦,请休息一下再问我吧" | |||
except Exception as e: | |||
logger.exception(e) | |||
return None | |||
class Session(object): | |||
@staticmethod | |||
def build_session_query(query, user_id): | |||
@@ -0,0 +1,37 @@ | |||
import time | |||
import openai | |||
from common.token_bucket import TokenBucket | |||
from common.log import logger | |||
from config import conf | |||
# OPENAI提供的画图接口 | |||
class OpenAIImage(object): | |||
def __init__(self): | |||
openai.api_key = conf().get('open_ai_api_key') | |||
if conf().get('rate_limit_dalle'): | |||
self.tb4dalle = TokenBucket(conf().get('rate_limit_dalle', 50)) | |||
def create_img(self, query, retry_count=0): | |||
try: | |||
if conf().get('rate_limit_dalle') and not self.tb4dalle.get_token(): | |||
return False, "请求太快了,请休息一下再问我吧" | |||
logger.info("[OPEN_AI] image_query={}".format(query)) | |||
response = openai.Image.create( | |||
prompt=query, #图片描述 | |||
n=1, #每次生成图片的数量 | |||
size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024 | |||
) | |||
image_url = response['data'][0]['url'] | |||
logger.info("[OPEN_AI] image_url={}".format(image_url)) | |||
return True, image_url | |||
except openai.error.RateLimitError as e: | |||
logger.warn(e) | |||
if retry_count < 1: | |||
time.sleep(5) | |||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1)) | |||
return self.create_img(query, retry_count+1) | |||
else: | |||
return False, "提问太快啦,请休息一下再问我吧" | |||
except Exception as e: | |||
logger.exception(e) | |||
return False, str(e) |