Pārlūkot izejas kodu

fix: custom GPT model bug

master
lanvent pirms 1 gada
vecāks
revīzija
b476085110
1 mainītis faili ar 11 papildinājumiem un 5 dzēšanām
  1. +11
    -5
      bot/chatgpt/chat_gpt_bot.py

+ 11
- 5
bot/chatgpt/chat_gpt_bot.py Parādīt failu

@@ -66,12 +66,16 @@ class ChatGPTBot(Bot, OpenAIImage):
logger.debug("[CHATGPT] session query={}".format(session.messages))

api_key = context.get("openai_api_key")
self.args['model'] = context.get('gpt_model') or "gpt-3.5-turbo"
model = context.get("gpt_model")
new_args = None
if model:
new_args = self.args.copy()
new_args["model"] = model
# if context.get('stream'):
# # reply in stream
# return self.reply_text_stream(query, new_query, session_id)

reply_content = self.reply_text(session, api_key)
reply_content = self.reply_text(session, api_key, args=new_args)
logger.debug(
"[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
session.messages,
@@ -102,7 +106,7 @@ class ChatGPTBot(Bot, OpenAIImage):
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
return reply

def reply_text(self, session: ChatGPTSession, api_key=None, retry_count=0) -> dict:
def reply_text(self, session: ChatGPTSession, api_key=None, args=None, retry_count=0) -> dict:
"""
call openai's ChatCompletion to get the answer
:param session: a conversation session
@@ -114,7 +118,9 @@ class ChatGPTBot(Bot, OpenAIImage):
if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
# if api_key == None, the default openai.api_key will be used
response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **self.args)
if args is None:
args = self.args
response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args)
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
return {
"total_tokens": response["usage"]["total_tokens"],
@@ -150,7 +156,7 @@ class ChatGPTBot(Bot, OpenAIImage):

if need_retry:
logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1))
return self.reply_text(session, api_key, retry_count + 1)
return self.reply_text(session, api_key, args, retry_count + 1)
else:
return result



Notiek ielāde…
Atcelt
Saglabāt