浏览代码

fix: handle RateLimitError #50 #51 #54

master
ubuntu 2 年前
父节点
当前提交
c7d1e77ae6
共有 2 个文件被更改,包括 31 次插入50 次删除
  1. +29
    -48
      bot/openai/open_ai_bot.py
  2. +2
    -2
      channel/wechat/wechat_channel.py

+ 29
- 48
bot/openai/open_ai_bot.py 查看文件

@@ -4,7 +4,7 @@ from bot.bot import Bot
from config import conf from config import conf
from common.log import logger from common.log import logger
import openai import openai
from datetime import date
import time


user_session = dict() user_session = dict()


@@ -26,16 +26,16 @@ class OpenAIBot(Bot):
new_query = Session.build_session_query(query, from_user_id) new_query = Session.build_session_query(query, from_user_id)
logger.debug("[OPEN_AI] session query={}".format(new_query)) logger.debug("[OPEN_AI] session query={}".format(new_query))


reply_content = self.reply_text(new_query, from_user_id)
logger.debug("[OPEN_AI] new_query={}, user={}".format(new_query, from_user_id))
reply_content = self.reply_text(new_query, from_user_id, 0)
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
if reply_content and query: if reply_content and query:
Session.save_session(query, reply_content, from_user_id) Session.save_session(query, reply_content, from_user_id)
return reply_content return reply_content


elif context.get('type', None) == 'IMAGE_CREATE': elif context.get('type', None) == 'IMAGE_CREATE':
return self.create_img(query)
return self.create_img(query, 0)


def reply_text(self, query, user_id):
def reply_text(self, query, user_id, retry_count=0):
try: try:
response = openai.Completion.create( response = openai.Completion.create(
model="text-davinci-003", # 对话模型的名称 model="text-davinci-003", # 对话模型的名称
@@ -48,14 +48,25 @@ class OpenAIBot(Bot):
stop=["#"] stop=["#"]
) )
res_content = response.choices[0]["text"].strip().rstrip("<|im_end|>") res_content = response.choices[0]["text"].strip().rstrip("<|im_end|>")
logger.info("[OPEN_AI] reply={}".format(res_content))
return res_content
except openai.error.RateLimitError as e:
# rate limit exception
logger.warn(e)
if retry_count < 1:
time.sleep(5)
logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
return self.reply_text(query, user_id, retry_count+1)
else:
return "提问太快啦,请休息一下再问我吧"
except Exception as e: except Exception as e:
# unknown exception
logger.exception(e) logger.exception(e)
Session.clear_session(user_id) Session.clear_session(user_id)
return None
logger.info("[OPEN_AI] reply={}".format(res_content))
return res_content
return "请再问我一次吧"



def create_img(self, query):
def create_img(self, query, retry_count=0):
try: try:
logger.info("[OPEN_AI] image_query={}".format(query)) logger.info("[OPEN_AI] image_query={}".format(query))
response = openai.Image.create( response = openai.Image.create(
@@ -65,48 +76,18 @@ class OpenAIBot(Bot):
) )
image_url = response['data'][0]['url'] image_url = response['data'][0]['url']
logger.info("[OPEN_AI] image_url={}".format(image_url)) logger.info("[OPEN_AI] image_url={}".format(image_url))
return image_url
except openai.error.RateLimitError as e:
logger.warn(e)
if retry_count < 1:
time.sleep(5)
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
return self.reply_text(query, retry_count+1)
else:
return "提问太快啦,请休息一下再问我吧"
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
return None return None
return image_url

def edit_img(self, query, src_img):
try:
response = openai.Image.create_edit(
image=open(src_img, 'rb'),
mask=open('cat-mask.png', 'rb'),
prompt=query,
n=1,
size='512x512'
)
image_url = response['data'][0]['url']
logger.info("[OPEN_AI] image_url={}".format(image_url))
except Exception as e:
logger.exception(e)
return None
return image_url

def migration_img(self, query, src_img):

try:
response = openai.Image.create_variation(
image=open(src_img, 'rb'),
n=1,
size="512x512"
)
image_url = response['data'][0]['url']
logger.info("[OPEN_AI] image_url={}".format(image_url))
except Exception as e:
logger.exception(e)
return None
return image_url

def append_question_mark(self, query):
end_symbols = [".", "。", "?", "?", "!", "!"]
for symbol in end_symbols:
if query.endswith(symbol):
return query
return query + "?"




class Session(object): class Session(object):


+ 2
- 2
channel/wechat/wechat_channel.py 查看文件

@@ -109,7 +109,7 @@ class WechatChannel(Channel):
return return
context = dict() context = dict()
context['from_user_id'] = reply_user_id context['from_user_id'] = reply_user_id
reply_text = super().build_reply_content(query, context).strip()
reply_text = super().build_reply_content(query, context)
if reply_text: if reply_text:
self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id) self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id)
except Exception as e: except Exception as e:
@@ -144,8 +144,8 @@ class WechatChannel(Channel):
context = dict() context = dict()
context['from_user_id'] = msg['ActualUserName'] context['from_user_id'] = msg['ActualUserName']
reply_text = super().build_reply_content(query, context) reply_text = super().build_reply_content(query, context)
reply_text = '@' + msg['ActualNickName'] + ' ' + reply_text.strip()
if reply_text: if reply_text:
reply_text = '@' + msg['ActualNickName'] + ' ' + reply_text.strip()
self.send(conf().get("group_chat_reply_prefix", "") + reply_text, msg['User']['UserName']) self.send(conf().get("group_chat_reply_prefix", "") + reply_text, msg['User']['UserName'])






正在加载...
取消
保存