|
|
@@ -4,7 +4,7 @@ from bot.bot import Bot |
|
|
|
from config import conf |
|
|
|
from common.log import logger |
|
|
|
import openai |
|
|
|
from datetime import date |
|
|
|
import time |
|
|
|
|
|
|
|
user_session = dict() |
|
|
|
|
|
|
@@ -26,16 +26,16 @@ class OpenAIBot(Bot): |
|
|
|
new_query = Session.build_session_query(query, from_user_id) |
|
|
|
logger.debug("[OPEN_AI] session query={}".format(new_query)) |
|
|
|
|
|
|
|
reply_content = self.reply_text(new_query, from_user_id) |
|
|
|
logger.debug("[OPEN_AI] new_query={}, user={}".format(new_query, from_user_id)) |
|
|
|
reply_content = self.reply_text(new_query, from_user_id, 0) |
|
|
|
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content)) |
|
|
|
if reply_content and query: |
|
|
|
Session.save_session(query, reply_content, from_user_id) |
|
|
|
return reply_content |
|
|
|
|
|
|
|
elif context.get('type', None) == 'IMAGE_CREATE': |
|
|
|
return self.create_img(query) |
|
|
|
return self.create_img(query, 0) |
|
|
|
|
|
|
|
def reply_text(self, query, user_id): |
|
|
|
def reply_text(self, query, user_id, retry_count=0): |
|
|
|
try: |
|
|
|
response = openai.Completion.create( |
|
|
|
model="text-davinci-003", # 对话模型的名称 |
|
|
@@ -48,14 +48,25 @@ class OpenAIBot(Bot): |
|
|
|
stop=["#"] |
|
|
|
) |
|
|
|
res_content = response.choices[0]["text"].strip().rstrip("<|im_end|>") |
|
|
|
logger.info("[OPEN_AI] reply={}".format(res_content)) |
|
|
|
return res_content |
|
|
|
except openai.error.RateLimitError as e: |
|
|
|
# rate limit exception |
|
|
|
logger.warn(e) |
|
|
|
if retry_count < 1: |
|
|
|
time.sleep(5) |
|
|
|
logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1)) |
|
|
|
return self.reply_text(query, user_id, retry_count+1) |
|
|
|
else: |
|
|
|
return "提问太快啦,请休息一下再问我吧" |
|
|
|
except Exception as e: |
|
|
|
# unknown exception |
|
|
|
logger.exception(e) |
|
|
|
Session.clear_session(user_id) |
|
|
|
return None |
|
|
|
logger.info("[OPEN_AI] reply={}".format(res_content)) |
|
|
|
return res_content |
|
|
|
return "请再问我一次吧" |
|
|
|
|
|
|
|
|
|
|
|
def create_img(self, query): |
|
|
|
def create_img(self, query, retry_count=0): |
|
|
|
try: |
|
|
|
logger.info("[OPEN_AI] image_query={}".format(query)) |
|
|
|
response = openai.Image.create( |
|
|
@@ -65,48 +76,18 @@ class OpenAIBot(Bot): |
|
|
|
) |
|
|
|
image_url = response['data'][0]['url'] |
|
|
|
logger.info("[OPEN_AI] image_url={}".format(image_url)) |
|
|
|
return image_url |
|
|
|
except openai.error.RateLimitError as e: |
|
|
|
logger.warn(e) |
|
|
|
if retry_count < 1: |
|
|
|
time.sleep(5) |
|
|
|
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1)) |
|
|
|
return self.reply_text(query, retry_count+1) |
|
|
|
else: |
|
|
|
return "提问太快啦,请休息一下再问我吧" |
|
|
|
except Exception as e: |
|
|
|
logger.exception(e) |
|
|
|
return None |
|
|
|
return image_url |
|
|
|
|
|
|
|
def edit_img(self, query, src_img): |
|
|
|
try: |
|
|
|
response = openai.Image.create_edit( |
|
|
|
image=open(src_img, 'rb'), |
|
|
|
mask=open('cat-mask.png', 'rb'), |
|
|
|
prompt=query, |
|
|
|
n=1, |
|
|
|
size='512x512' |
|
|
|
) |
|
|
|
image_url = response['data'][0]['url'] |
|
|
|
logger.info("[OPEN_AI] image_url={}".format(image_url)) |
|
|
|
except Exception as e: |
|
|
|
logger.exception(e) |
|
|
|
return None |
|
|
|
return image_url |
|
|
|
|
|
|
|
def migration_img(self, query, src_img): |
|
|
|
|
|
|
|
try: |
|
|
|
response = openai.Image.create_variation( |
|
|
|
image=open(src_img, 'rb'), |
|
|
|
n=1, |
|
|
|
size="512x512" |
|
|
|
) |
|
|
|
image_url = response['data'][0]['url'] |
|
|
|
logger.info("[OPEN_AI] image_url={}".format(image_url)) |
|
|
|
except Exception as e: |
|
|
|
logger.exception(e) |
|
|
|
return None |
|
|
|
return image_url |
|
|
|
|
|
|
|
def append_question_mark(self, query): |
|
|
|
end_symbols = [".", "。", "?", "?", "!", "!"] |
|
|
|
for symbol in end_symbols: |
|
|
|
if query.endswith(symbol): |
|
|
|
return query |
|
|
|
return query + "?" |
|
|
|
|
|
|
|
|
|
|
|
class Session(object): |
|
|
|