瀏覽代碼

Merge pull request #56 from zhayujie/fix-ratelimit

fix: handle RateLimitError #50 #54
master
zhayujie GitHub 2 年之前
父節點
當前提交
80f57267ec
沒有發現已知的金鑰在資料庫的簽署中 GPG 金鑰 ID: 4AEE18F83AFDEB23
共有 2 個檔案被更改,包括 31 行新增50 行删除
  1. +29
    -48
      bot/openai/open_ai_bot.py
  2. +2
    -2
      channel/wechat/wechat_channel.py

+ 29
- 48
bot/openai/open_ai_bot.py 查看文件

@@ -4,7 +4,7 @@ from bot.bot import Bot
from config import conf from config import conf
from common.log import logger from common.log import logger
import openai import openai
from datetime import date
import time


user_session = dict() user_session = dict()


@@ -26,16 +26,16 @@ class OpenAIBot(Bot):
new_query = Session.build_session_query(query, from_user_id) new_query = Session.build_session_query(query, from_user_id)
logger.debug("[OPEN_AI] session query={}".format(new_query)) logger.debug("[OPEN_AI] session query={}".format(new_query))


reply_content = self.reply_text(new_query, from_user_id)
logger.debug("[OPEN_AI] new_query={}, user={}".format(new_query, from_user_id))
reply_content = self.reply_text(new_query, from_user_id, 0)
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
if reply_content and query: if reply_content and query:
Session.save_session(query, reply_content, from_user_id) Session.save_session(query, reply_content, from_user_id)
return reply_content return reply_content


elif context.get('type', None) == 'IMAGE_CREATE': elif context.get('type', None) == 'IMAGE_CREATE':
return self.create_img(query)
return self.create_img(query, 0)


def reply_text(self, query, user_id):
def reply_text(self, query, user_id, retry_count=0):
try: try:
response = openai.Completion.create( response = openai.Completion.create(
model="text-davinci-003", # 对话模型的名称 model="text-davinci-003", # 对话模型的名称
@@ -48,14 +48,25 @@ class OpenAIBot(Bot):
stop=["#"] stop=["#"]
) )
res_content = response.choices[0]["text"].strip().rstrip("<|im_end|>") res_content = response.choices[0]["text"].strip().rstrip("<|im_end|>")
logger.info("[OPEN_AI] reply={}".format(res_content))
return res_content
except openai.error.RateLimitError as e:
# rate limit exception
logger.warn(e)
if retry_count < 1:
time.sleep(5)
logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
return self.reply_text(query, user_id, retry_count+1)
else:
return "提问太快啦,请休息一下再问我吧"
except Exception as e: except Exception as e:
# unknown exception
logger.exception(e) logger.exception(e)
Session.clear_session(user_id) Session.clear_session(user_id)
return None
logger.info("[OPEN_AI] reply={}".format(res_content))
return res_content
return "请再问我一次吧"



def create_img(self, query):
def create_img(self, query, retry_count=0):
try: try:
logger.info("[OPEN_AI] image_query={}".format(query)) logger.info("[OPEN_AI] image_query={}".format(query))
response = openai.Image.create( response = openai.Image.create(
@@ -65,48 +76,18 @@ class OpenAIBot(Bot):
) )
image_url = response['data'][0]['url'] image_url = response['data'][0]['url']
logger.info("[OPEN_AI] image_url={}".format(image_url)) logger.info("[OPEN_AI] image_url={}".format(image_url))
return image_url
except openai.error.RateLimitError as e:
logger.warn(e)
if retry_count < 1:
time.sleep(5)
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
return self.reply_text(query, retry_count+1)
else:
return "提问太快啦,请休息一下再问我吧"
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
return None return None
return image_url

def edit_img(self, query, src_img):
try:
response = openai.Image.create_edit(
image=open(src_img, 'rb'),
mask=open('cat-mask.png', 'rb'),
prompt=query,
n=1,
size='512x512'
)
image_url = response['data'][0]['url']
logger.info("[OPEN_AI] image_url={}".format(image_url))
except Exception as e:
logger.exception(e)
return None
return image_url

def migration_img(self, query, src_img):

try:
response = openai.Image.create_variation(
image=open(src_img, 'rb'),
n=1,
size="512x512"
)
image_url = response['data'][0]['url']
logger.info("[OPEN_AI] image_url={}".format(image_url))
except Exception as e:
logger.exception(e)
return None
return image_url

def append_question_mark(self, query):
end_symbols = [".", "。", "?", "?", "!", "!"]
for symbol in end_symbols:
if query.endswith(symbol):
return query
return query + "?"




class Session(object): class Session(object):


+ 2
- 2
channel/wechat/wechat_channel.py 查看文件

@@ -109,7 +109,7 @@ class WechatChannel(Channel):
return return
context = dict() context = dict()
context['from_user_id'] = reply_user_id context['from_user_id'] = reply_user_id
reply_text = super().build_reply_content(query, context).strip()
reply_text = super().build_reply_content(query, context)
if reply_text: if reply_text:
self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id) self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id)
except Exception as e: except Exception as e:
@@ -144,8 +144,8 @@ class WechatChannel(Channel):
context = dict() context = dict()
context['from_user_id'] = msg['ActualUserName'] context['from_user_id'] = msg['ActualUserName']
reply_text = super().build_reply_content(query, context) reply_text = super().build_reply_content(query, context)
reply_text = '@' + msg['ActualNickName'] + ' ' + reply_text.strip()
if reply_text: if reply_text:
reply_text = '@' + msg['ActualNickName'] + ' ' + reply_text.strip()
self.send(conf().get("group_chat_reply_prefix", "") + reply_text, msg['User']['UserName']) self.send(conf().get("group_chat_reply_prefix", "") + reply_text, msg['User']['UserName'])






Loading…
取消
儲存