|
|
@@ -66,12 +66,16 @@ class ChatGPTBot(Bot, OpenAIImage): |
|
|
|
logger.debug("[CHATGPT] session query={}".format(session.messages)) |
|
|
|
|
|
|
|
api_key = context.get("openai_api_key") |
|
|
|
self.args['model'] = context.get('gpt_model') or "gpt-3.5-turbo" |
|
|
|
model = context.get("gpt_model") |
|
|
|
new_args = None |
|
|
|
if model: |
|
|
|
new_args = self.args.copy() |
|
|
|
new_args["model"] = model |
|
|
|
# if context.get('stream'): |
|
|
|
# # reply in stream |
|
|
|
# return self.reply_text_stream(query, new_query, session_id) |
|
|
|
|
|
|
|
reply_content = self.reply_text(session, api_key) |
|
|
|
reply_content = self.reply_text(session, api_key, args=new_args) |
|
|
|
logger.debug( |
|
|
|
"[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format( |
|
|
|
session.messages, |
|
|
@@ -102,7 +106,7 @@ class ChatGPTBot(Bot, OpenAIImage): |
|
|
|
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type)) |
|
|
|
return reply |
|
|
|
|
|
|
|
def reply_text(self, session: ChatGPTSession, api_key=None, retry_count=0) -> dict: |
|
|
|
def reply_text(self, session: ChatGPTSession, api_key=None, args=None, retry_count=0) -> dict: |
|
|
|
""" |
|
|
|
call openai's ChatCompletion to get the answer |
|
|
|
:param session: a conversation session |
|
|
@@ -114,7 +118,9 @@ class ChatGPTBot(Bot, OpenAIImage): |
|
|
|
if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token(): |
|
|
|
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded") |
|
|
|
# if api_key == None, the default openai.api_key will be used |
|
|
|
response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **self.args) |
|
|
|
if args is None: |
|
|
|
args = self.args |
|
|
|
response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args) |
|
|
|
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) |
|
|
|
return { |
|
|
|
"total_tokens": response["usage"]["total_tokens"], |
|
|
@@ -150,7 +156,7 @@ class ChatGPTBot(Bot, OpenAIImage): |
|
|
|
|
|
|
|
if need_retry: |
|
|
|
logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1)) |
|
|
|
return self.reply_text(session, api_key, retry_count + 1) |
|
|
|
return self.reply_text(session, api_key, args, retry_count + 1) |
|
|
|
else: |
|
|
|
return result |
|
|
|
|
|
|
|