Browse Source

fix: handle RateLimitError #50 #51 #54

develop
ubuntu 1 year ago
parent
commit
c7d1e77ae6
2 changed files with 31 additions and 50 deletions
  1. +29
    -48
      bot/openai/open_ai_bot.py
  2. +2
    -2
      channel/wechat/wechat_channel.py

+ 29
- 48
bot/openai/open_ai_bot.py View File

@@ -4,7 +4,7 @@ from bot.bot import Bot
from config import conf from config import conf
from common.log import logger from common.log import logger
import openai import openai
from datetime import date
import time


user_session = dict() user_session = dict()


@@ -26,16 +26,16 @@ class OpenAIBot(Bot):
new_query = Session.build_session_query(query, from_user_id) new_query = Session.build_session_query(query, from_user_id)
logger.debug("[OPEN_AI] session query={}".format(new_query)) logger.debug("[OPEN_AI] session query={}".format(new_query))


reply_content = self.reply_text(new_query, from_user_id)
logger.debug("[OPEN_AI] new_query={}, user={}".format(new_query, from_user_id))
reply_content = self.reply_text(new_query, from_user_id, 0)
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
if reply_content and query: if reply_content and query:
Session.save_session(query, reply_content, from_user_id) Session.save_session(query, reply_content, from_user_id)
return reply_content return reply_content


elif context.get('type', None) == 'IMAGE_CREATE': elif context.get('type', None) == 'IMAGE_CREATE':
return self.create_img(query)
return self.create_img(query, 0)


def reply_text(self, query, user_id):
def reply_text(self, query, user_id, retry_count=0):
try: try:
response = openai.Completion.create( response = openai.Completion.create(
model="text-davinci-003", # 对话模型的名称 model="text-davinci-003", # 对话模型的名称
@@ -48,14 +48,25 @@ class OpenAIBot(Bot):
stop=["#"] stop=["#"]
) )
res_content = response.choices[0]["text"].strip().rstrip("<|im_end|>") res_content = response.choices[0]["text"].strip().rstrip("<|im_end|>")
logger.info("[OPEN_AI] reply={}".format(res_content))
return res_content
except openai.error.RateLimitError as e:
# rate limit exception
logger.warn(e)
if retry_count < 1:
time.sleep(5)
logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
return self.reply_text(query, user_id, retry_count+1)
else:
return "提问太快啦,请休息一下再问我吧"
except Exception as e: except Exception as e:
# unknown exception
logger.exception(e) logger.exception(e)
Session.clear_session(user_id) Session.clear_session(user_id)
return None
logger.info("[OPEN_AI] reply={}".format(res_content))
return res_content
return "请再问我一次吧"



def create_img(self, query):
def create_img(self, query, retry_count=0):
try: try:
logger.info("[OPEN_AI] image_query={}".format(query)) logger.info("[OPEN_AI] image_query={}".format(query))
response = openai.Image.create( response = openai.Image.create(
@@ -65,48 +76,18 @@ class OpenAIBot(Bot):
) )
image_url = response['data'][0]['url'] image_url = response['data'][0]['url']
logger.info("[OPEN_AI] image_url={}".format(image_url)) logger.info("[OPEN_AI] image_url={}".format(image_url))
return image_url
except openai.error.RateLimitError as e:
logger.warn(e)
if retry_count < 1:
time.sleep(5)
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
return self.reply_text(query, retry_count+1)
else:
return "提问太快啦,请休息一下再问我吧"
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
return None return None
return image_url

def edit_img(self, query, src_img):
try:
response = openai.Image.create_edit(
image=open(src_img, 'rb'),
mask=open('cat-mask.png', 'rb'),
prompt=query,
n=1,
size='512x512'
)
image_url = response['data'][0]['url']
logger.info("[OPEN_AI] image_url={}".format(image_url))
except Exception as e:
logger.exception(e)
return None
return image_url

def migration_img(self, query, src_img):

try:
response = openai.Image.create_variation(
image=open(src_img, 'rb'),
n=1,
size="512x512"
)
image_url = response['data'][0]['url']
logger.info("[OPEN_AI] image_url={}".format(image_url))
except Exception as e:
logger.exception(e)
return None
return image_url

def append_question_mark(self, query):
end_symbols = [".", "。", "?", "?", "!", "!"]
for symbol in end_symbols:
if query.endswith(symbol):
return query
return query + "?"




class Session(object): class Session(object):


+ 2
- 2
channel/wechat/wechat_channel.py View File

@@ -109,7 +109,7 @@ class WechatChannel(Channel):
return return
context = dict() context = dict()
context['from_user_id'] = reply_user_id context['from_user_id'] = reply_user_id
reply_text = super().build_reply_content(query, context).strip()
reply_text = super().build_reply_content(query, context)
if reply_text: if reply_text:
self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id) self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id)
except Exception as e: except Exception as e:
@@ -144,8 +144,8 @@ class WechatChannel(Channel):
context = dict() context = dict()
context['from_user_id'] = msg['ActualUserName'] context['from_user_id'] = msg['ActualUserName']
reply_text = super().build_reply_content(query, context) reply_text = super().build_reply_content(query, context)
reply_text = '@' + msg['ActualNickName'] + ' ' + reply_text.strip()
if reply_text: if reply_text:
reply_text = '@' + msg['ActualNickName'] + ' ' + reply_text.strip()
self.send(conf().get("group_chat_reply_prefix", "") + reply_text, msg['User']['UserName']) self.send(conf().get("group_chat_reply_prefix", "") + reply_text, msg['User']['UserName'])






Loading…
Cancel
Save