|
|
@@ -3,6 +3,7 @@ |
|
|
|
from bot.bot import Bot |
|
|
|
from config import conf, load_config |
|
|
|
from common.log import logger |
|
|
|
from common.token_bucket import TokenBucket |
|
|
|
from common.expired_dict import ExpiredDict |
|
|
|
import openai |
|
|
|
import time |
|
|
@@ -21,6 +22,10 @@ class ChatGPTBot(Bot): |
|
|
|
proxy = conf().get('proxy') |
|
|
|
if proxy: |
|
|
|
openai.proxy = proxy |
|
|
|
if conf().get('rate_limit_chatgpt'): |
|
|
|
self.tb4chatgpt = TokenBucket(conf().get('rate_limit_chatgpt', 20)) |
|
|
|
if conf().get('rate_limit_dalle'): |
|
|
|
self.tb4dalle = TokenBucket(conf().get('rate_limit_dalle', 50)) |
|
|
|
|
|
|
|
def reply(self, query, context=None): |
|
|
|
# acquire reply content |
|
|
@@ -63,6 +68,8 @@ class ChatGPTBot(Bot): |
|
|
|
:return: {} |
|
|
|
''' |
|
|
|
try: |
|
|
|
if conf().get('rate_limit_chatgpt') and not self.tb4chatgpt.get_token(): |
|
|
|
return {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"} |
|
|
|
response = openai.ChatCompletion.create( |
|
|
|
model= conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称 |
|
|
|
messages=session, |
|
|
@@ -102,6 +109,8 @@ class ChatGPTBot(Bot): |
|
|
|
|
|
|
|
def create_img(self, query, retry_count=0): |
|
|
|
try: |
|
|
|
if conf().get('rate_limit_dalle') and not self.tb4dalle.get_token(): |
|
|
|
return "请求太快了,请休息一下再问我吧" |
|
|
|
logger.info("[OPEN_AI] image_query={}".format(query)) |
|
|
|
response = openai.Image.create( |
|
|
|
prompt=query, #图片描述 |
|
|
@@ -118,7 +127,7 @@ class ChatGPTBot(Bot): |
|
|
|
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1)) |
|
|
|
return self.create_img(query, retry_count+1) |
|
|
|
else: |
|
|
|
return "提问太快啦,请休息一下再问我吧" |
|
|
|
return "请求太快啦,请休息一下再问我吧" |
|
|
|
except Exception as e: |
|
|
|
logger.exception(e) |
|
|
|
return None |
|
|
|