浏览代码

fix: handle RateLimitError #50 #51 #54

master
ubuntu 1年前
父节点
当前提交
c7d1e77ae6
共有 2 个文件被更改,包括 31 次插入50 次删除
  1. +29
    -48
      bot/openai/open_ai_bot.py
  2. +2
    -2
      channel/wechat/wechat_channel.py

+ 29
- 48
bot/openai/open_ai_bot.py 查看文件

@@ -4,7 +4,7 @@ from bot.bot import Bot
from config import conf
from common.log import logger
import openai
from datetime import date
import time

user_session = dict()

@@ -26,16 +26,16 @@ class OpenAIBot(Bot):
new_query = Session.build_session_query(query, from_user_id)
logger.debug("[OPEN_AI] session query={}".format(new_query))

reply_content = self.reply_text(new_query, from_user_id)
logger.debug("[OPEN_AI] new_query={}, user={}".format(new_query, from_user_id))
reply_content = self.reply_text(new_query, from_user_id, 0)
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
if reply_content and query:
Session.save_session(query, reply_content, from_user_id)
return reply_content

elif context.get('type', None) == 'IMAGE_CREATE':
return self.create_img(query)
return self.create_img(query, 0)

def reply_text(self, query, user_id):
def reply_text(self, query, user_id, retry_count=0):
try:
response = openai.Completion.create(
model="text-davinci-003", # 对话模型的名称
@@ -48,14 +48,25 @@ class OpenAIBot(Bot):
stop=["#"]
)
res_content = response.choices[0]["text"].strip().rstrip("<|im_end|>")
logger.info("[OPEN_AI] reply={}".format(res_content))
return res_content
except openai.error.RateLimitError as e:
# rate limit exception
logger.warn(e)
if retry_count < 1:
time.sleep(5)
logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
return self.reply_text(query, user_id, retry_count+1)
else:
return "提问太快啦,请休息一下再问我吧"
except Exception as e:
# unknown exception
logger.exception(e)
Session.clear_session(user_id)
return None
logger.info("[OPEN_AI] reply={}".format(res_content))
return res_content
return "请再问我一次吧"


def create_img(self, query):
def create_img(self, query, retry_count=0):
try:
logger.info("[OPEN_AI] image_query={}".format(query))
response = openai.Image.create(
@@ -65,48 +76,18 @@ class OpenAIBot(Bot):
)
image_url = response['data'][0]['url']
logger.info("[OPEN_AI] image_url={}".format(image_url))
return image_url
except openai.error.RateLimitError as e:
logger.warn(e)
if retry_count < 1:
time.sleep(5)
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
return self.reply_text(query, retry_count+1)
else:
return "提问太快啦,请休息一下再问我吧"
except Exception as e:
logger.exception(e)
return None
return image_url

def edit_img(self, query, src_img):
try:
response = openai.Image.create_edit(
image=open(src_img, 'rb'),
mask=open('cat-mask.png', 'rb'),
prompt=query,
n=1,
size='512x512'
)
image_url = response['data'][0]['url']
logger.info("[OPEN_AI] image_url={}".format(image_url))
except Exception as e:
logger.exception(e)
return None
return image_url

def migration_img(self, query, src_img):

try:
response = openai.Image.create_variation(
image=open(src_img, 'rb'),
n=1,
size="512x512"
)
image_url = response['data'][0]['url']
logger.info("[OPEN_AI] image_url={}".format(image_url))
except Exception as e:
logger.exception(e)
return None
return image_url

def append_question_mark(self, query):
end_symbols = [".", "。", "?", "?", "!", "!"]
for symbol in end_symbols:
if query.endswith(symbol):
return query
return query + "?"


class Session(object):


+ 2
- 2
channel/wechat/wechat_channel.py 查看文件

@@ -109,7 +109,7 @@ class WechatChannel(Channel):
return
context = dict()
context['from_user_id'] = reply_user_id
reply_text = super().build_reply_content(query, context).strip()
reply_text = super().build_reply_content(query, context)
if reply_text:
self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id)
except Exception as e:
@@ -144,8 +144,8 @@ class WechatChannel(Channel):
context = dict()
context['from_user_id'] = msg['ActualUserName']
reply_text = super().build_reply_content(query, context)
reply_text = '@' + msg['ActualNickName'] + ' ' + reply_text.strip()
if reply_text:
reply_text = '@' + msg['ActualNickName'] + ' ' + reply_text.strip()
self.send(conf().get("group_chat_reply_prefix", "") + reply_text, msg['User']['UserName'])




正在加载...
取消
保存