From b476085110b19319d455bef9115f6f715b81ecd2 Mon Sep 17 00:00:00 2001 From: lanvent Date: Tue, 30 May 2023 23:42:06 +0800 Subject: [PATCH] fix: custom GPT model bug --- bot/chatgpt/chat_gpt_bot.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index 9c54faf..60fc3f8 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -66,12 +66,16 @@ class ChatGPTBot(Bot, OpenAIImage): logger.debug("[CHATGPT] session query={}".format(session.messages)) api_key = context.get("openai_api_key") - self.args['model'] = context.get('gpt_model') or "gpt-3.5-turbo" + model = context.get("gpt_model") + new_args = None + if model: + new_args = self.args.copy() + new_args["model"] = model # if context.get('stream'): # # reply in stream # return self.reply_text_stream(query, new_query, session_id) - reply_content = self.reply_text(session, api_key) + reply_content = self.reply_text(session, api_key, args=new_args) logger.debug( "[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format( session.messages, @@ -102,7 +106,7 @@ class ChatGPTBot(Bot, OpenAIImage): reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type)) return reply - def reply_text(self, session: ChatGPTSession, api_key=None, retry_count=0) -> dict: + def reply_text(self, session: ChatGPTSession, api_key=None, args=None, retry_count=0) -> dict: """ call openai's ChatCompletion to get the answer :param session: a conversation session @@ -114,7 +118,9 @@ class ChatGPTBot(Bot, OpenAIImage): if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token(): raise openai.error.RateLimitError("RateLimitError: rate limit exceeded") # if api_key == None, the default openai.api_key will be used - response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **self.args) + if args is None: + args = self.args + response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args) # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) return { "total_tokens": response["usage"]["total_tokens"], @@ -150,7 +156,7 @@ class ChatGPTBot(Bot, OpenAIImage): if need_retry: logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1)) - return self.reply_text(session, api_key, retry_count + 1) + return self.reply_text(session, api_key, args, retry_count + 1) else: return result