@@ -1,6 +1,7 @@ | |||||
# encoding:utf-8 | # encoding:utf-8 | ||||
from bot.bot import Bot | from bot.bot import Bot | ||||
from bot.openai.open_ai_image import OpenAIImage | |||||
from bridge.context import ContextType | from bridge.context import ContextType | ||||
from bridge.reply import Reply, ReplyType | from bridge.reply import Reply, ReplyType | ||||
from config import conf, load_config | from config import conf, load_config | ||||
@@ -12,7 +13,7 @@ import time | |||||
# OpenAI对话模型API (可用) | # OpenAI对话模型API (可用) | ||||
class ChatGPTBot(Bot): | |||||
class ChatGPTBot(Bot,OpenAIImage): | |||||
def __init__(self): | def __init__(self): | ||||
openai.api_key = conf().get('open_ai_api_key') | openai.api_key = conf().get('open_ai_api_key') | ||||
if conf().get('open_ai_api_base'): | if conf().get('open_ai_api_base'): | ||||
@@ -23,8 +24,6 @@ class ChatGPTBot(Bot): | |||||
openai.proxy = proxy | openai.proxy = proxy | ||||
if conf().get('rate_limit_chatgpt'): | if conf().get('rate_limit_chatgpt'): | ||||
self.tb4chatgpt = TokenBucket(conf().get('rate_limit_chatgpt', 20)) | self.tb4chatgpt = TokenBucket(conf().get('rate_limit_chatgpt', 20)) | ||||
if conf().get('rate_limit_dalle'): | |||||
self.tb4dalle = TokenBucket(conf().get('rate_limit_dalle', 50)) | |||||
def reply(self, query, context=None): | def reply(self, query, context=None): | ||||
# acquire reply content | # acquire reply content | ||||
@@ -128,30 +127,6 @@ class ChatGPTBot(Bot): | |||||
self.sessions.clear_session(session_id) | self.sessions.clear_session(session_id) | ||||
return {"completion_tokens": 0, "content": "请再问我一次吧"} | return {"completion_tokens": 0, "content": "请再问我一次吧"} | ||||
def create_img(self, query, retry_count=0): | |||||
try: | |||||
if conf().get('rate_limit_dalle') and not self.tb4dalle.get_token(): | |||||
return False, "请求太快了,请休息一下再问我吧" | |||||
logger.info("[OPEN_AI] image_query={}".format(query)) | |||||
response = openai.Image.create( | |||||
prompt=query, #图片描述 | |||||
n=1, #每次生成图片的数量 | |||||
size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024 | |||||
) | |||||
image_url = response['data'][0]['url'] | |||||
logger.info("[OPEN_AI] image_url={}".format(image_url)) | |||||
return True, image_url | |||||
except openai.error.RateLimitError as e: | |||||
logger.warn(e) | |||||
if retry_count < 1: | |||||
time.sleep(5) | |||||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1)) | |||||
return self.create_img(query, retry_count+1) | |||||
else: | |||||
return False, "提问太快啦,请休息一下再问我吧" | |||||
except Exception as e: | |||||
logger.exception(e) | |||||
return False, str(e) | |||||
class AzureChatGPTBot(ChatGPTBot): | class AzureChatGPTBot(ChatGPTBot): | ||||
@@ -1,6 +1,7 @@ | |||||
# encoding:utf-8 | # encoding:utf-8 | ||||
from bot.bot import Bot | from bot.bot import Bot | ||||
from bot.openai.open_ai_image import OpenAIImage | |||||
from bridge.context import ContextType | from bridge.context import ContextType | ||||
from bridge.reply import Reply, ReplyType | from bridge.reply import Reply, ReplyType | ||||
from config import conf | from config import conf | ||||
@@ -11,7 +12,7 @@ import time | |||||
user_session = dict() | user_session = dict() | ||||
# OpenAI对话模型API (可用) | # OpenAI对话模型API (可用) | ||||
class OpenAIBot(Bot): | |||||
class OpenAIBot(Bot, OpenAIImage): | |||||
def __init__(self): | def __init__(self): | ||||
openai.api_key = conf().get('open_ai_api_key') | openai.api_key = conf().get('open_ai_api_key') | ||||
if conf().get('open_ai_api_base'): | if conf().get('open_ai_api_base'): | ||||
@@ -45,7 +46,13 @@ class OpenAIBot(Bot): | |||||
reply = Reply(ReplyType.TEXT, reply_content) | reply = Reply(ReplyType.TEXT, reply_content) | ||||
return reply | return reply | ||||
elif context.type == ContextType.IMAGE_CREATE: | elif context.type == ContextType.IMAGE_CREATE: | ||||
return self.create_img(query, 0) | |||||
ok, retstring = self.create_img(query, 0) | |||||
reply = None | |||||
if ok: | |||||
reply = Reply(ReplyType.IMAGE_URL, retstring) | |||||
else: | |||||
reply = Reply(ReplyType.ERROR, retstring) | |||||
return reply | |||||
def reply_text(self, query, user_id, retry_count=0): | def reply_text(self, query, user_id, retry_count=0): | ||||
try: | try: | ||||
@@ -77,31 +84,6 @@ class OpenAIBot(Bot): | |||||
Session.clear_session(user_id) | Session.clear_session(user_id) | ||||
return "请再问我一次吧" | return "请再问我一次吧" | ||||
def create_img(self, query, retry_count=0): | |||||
try: | |||||
logger.info("[OPEN_AI] image_query={}".format(query)) | |||||
response = openai.Image.create( | |||||
prompt=query, #图片描述 | |||||
n=1, #每次生成图片的数量 | |||||
size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024 | |||||
) | |||||
image_url = response['data'][0]['url'] | |||||
logger.info("[OPEN_AI] image_url={}".format(image_url)) | |||||
return image_url | |||||
except openai.error.RateLimitError as e: | |||||
logger.warn(e) | |||||
if retry_count < 1: | |||||
time.sleep(5) | |||||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1)) | |||||
return self.reply_text(query, retry_count+1) | |||||
else: | |||||
return "提问太快啦,请休息一下再问我吧" | |||||
except Exception as e: | |||||
logger.exception(e) | |||||
return None | |||||
class Session(object): | class Session(object): | ||||
@staticmethod | @staticmethod | ||||
def build_session_query(query, user_id): | def build_session_query(query, user_id): | ||||
@@ -0,0 +1,37 @@ | |||||
import time | |||||
import openai | |||||
from common.token_bucket import TokenBucket | |||||
from common.log import logger | |||||
from config import conf | |||||
# OPENAI提供的画图接口 | |||||
class OpenAIImage(object): | |||||
def __init__(self): | |||||
openai.api_key = conf().get('open_ai_api_key') | |||||
if conf().get('rate_limit_dalle'): | |||||
self.tb4dalle = TokenBucket(conf().get('rate_limit_dalle', 50)) | |||||
def create_img(self, query, retry_count=0): | |||||
try: | |||||
if conf().get('rate_limit_dalle') and not self.tb4dalle.get_token(): | |||||
return False, "请求太快了,请休息一下再问我吧" | |||||
logger.info("[OPEN_AI] image_query={}".format(query)) | |||||
response = openai.Image.create( | |||||
prompt=query, #图片描述 | |||||
n=1, #每次生成图片的数量 | |||||
size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024 | |||||
) | |||||
image_url = response['data'][0]['url'] | |||||
logger.info("[OPEN_AI] image_url={}".format(image_url)) | |||||
return True, image_url | |||||
except openai.error.RateLimitError as e: | |||||
logger.warn(e) | |||||
if retry_count < 1: | |||||
time.sleep(5) | |||||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1)) | |||||
return self.create_img(query, retry_count+1) | |||||
else: | |||||
return False, "提问太快啦,请休息一下再问我吧" | |||||
except Exception as e: | |||||
logger.exception(e) | |||||
return False, str(e) |