|
@@ -13,10 +13,13 @@ from common.expired_dict import ExpiredDict |
|
|
import openai |
|
|
import openai |
|
|
import openai.error |
|
|
import openai.error |
|
|
import time |
|
|
import time |
|
|
|
|
|
import redis |
|
|
|
|
|
|
|
|
# OpenAI对话模型API (可用) |
|
|
# OpenAI对话模型API (可用) |
|
|
class ChatGPTBot(Bot,OpenAIImage): |
|
|
class ChatGPTBot(Bot,OpenAIImage): |
|
|
def __init__(self): |
|
|
def __init__(self): |
|
|
super().__init__() |
|
|
super().__init__() |
|
|
|
|
|
# set the default api_key |
|
|
openai.api_key = conf().get('open_ai_api_key') |
|
|
openai.api_key = conf().get('open_ai_api_key') |
|
|
if conf().get('open_ai_api_base'): |
|
|
if conf().get('open_ai_api_base'): |
|
|
openai.api_base = conf().get('open_ai_api_base') |
|
|
openai.api_base = conf().get('open_ai_api_base') |
|
@@ -33,6 +36,7 @@ class ChatGPTBot(Bot,OpenAIImage): |
|
|
if context.type == ContextType.TEXT: |
|
|
if context.type == ContextType.TEXT: |
|
|
logger.info("[CHATGPT] query={}".format(query)) |
|
|
logger.info("[CHATGPT] query={}".format(query)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
session_id = context['session_id'] |
|
|
session_id = context['session_id'] |
|
|
reply = None |
|
|
reply = None |
|
|
clear_memory_commands = conf().get('clear_memory_commands', ['#清除记忆']) |
|
|
clear_memory_commands = conf().get('clear_memory_commands', ['#清除记忆']) |
|
@@ -50,11 +54,13 @@ class ChatGPTBot(Bot,OpenAIImage): |
|
|
session = self.sessions.session_query(query, session_id) |
|
|
session = self.sessions.session_query(query, session_id) |
|
|
logger.debug("[CHATGPT] session query={}".format(session.messages)) |
|
|
logger.debug("[CHATGPT] session query={}".format(session.messages)) |
|
|
|
|
|
|
|
|
|
|
|
api_key = context.get('openai_api_key') |
|
|
|
|
|
|
|
|
# if context.get('stream'): |
|
|
# if context.get('stream'): |
|
|
# # reply in stream |
|
|
# # reply in stream |
|
|
# return self.reply_text_stream(query, new_query, session_id) |
|
|
# return self.reply_text_stream(query, new_query, session_id) |
|
|
|
|
|
|
|
|
reply_content = self.reply_text(session, session_id, 0) |
|
|
|
|
|
|
|
|
reply_content = self.reply_text(session, session_id, api_key, 0) |
|
|
logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"])) |
|
|
logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"])) |
|
|
if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0: |
|
|
if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0: |
|
|
reply = Reply(ReplyType.ERROR, reply_content['content']) |
|
|
reply = Reply(ReplyType.ERROR, reply_content['content']) |
|
@@ -90,7 +96,7 @@ class ChatGPTBot(Bot,OpenAIImage): |
|
|
"timeout": 120, #重试超时时间,在这个时间内,将会自动重试 |
|
|
"timeout": 120, #重试超时时间,在这个时间内,将会自动重试 |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
def reply_text(self, session:ChatGPTSession, session_id, retry_count=0) -> dict: |
|
|
|
|
|
|
|
|
def reply_text(self, session:ChatGPTSession, session_id, api_key, retry_count=0) -> dict: |
|
|
''' |
|
|
''' |
|
|
call openai's ChatCompletion to get the answer |
|
|
call openai's ChatCompletion to get the answer |
|
|
:param session: a conversation session |
|
|
:param session: a conversation session |
|
@@ -101,8 +107,9 @@ class ChatGPTBot(Bot,OpenAIImage): |
|
|
try: |
|
|
try: |
|
|
if conf().get('rate_limit_chatgpt') and not self.tb4chatgpt.get_token(): |
|
|
if conf().get('rate_limit_chatgpt') and not self.tb4chatgpt.get_token(): |
|
|
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded") |
|
|
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded") |
|
|
|
|
|
# if api_key == None, the default openai.api_key will be used |
|
|
response = openai.ChatCompletion.create( |
|
|
response = openai.ChatCompletion.create( |
|
|
messages=session.messages, **self.compose_args() |
|
|
|
|
|
|
|
|
api_key=api_key, messages=session.messages, **self.compose_args() |
|
|
) |
|
|
) |
|
|
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) |
|
|
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) |
|
|
return {"total_tokens": response["usage"]["total_tokens"], |
|
|
return {"total_tokens": response["usage"]["total_tokens"], |
|
@@ -118,21 +125,21 @@ class ChatGPTBot(Bot,OpenAIImage): |
|
|
time.sleep(5) |
|
|
time.sleep(5) |
|
|
elif isinstance(e, openai.error.Timeout): |
|
|
elif isinstance(e, openai.error.Timeout): |
|
|
logger.warn("[CHATGPT] Timeout: {}".format(e)) |
|
|
logger.warn("[CHATGPT] Timeout: {}".format(e)) |
|
|
result['content'] = "我没有收到你的消息" |
|
|
|
|
|
|
|
|
result['content'] = "服务器出现问题" |
|
|
if need_retry: |
|
|
if need_retry: |
|
|
time.sleep(5) |
|
|
time.sleep(5) |
|
|
elif isinstance(e, openai.error.APIConnectionError): |
|
|
elif isinstance(e, openai.error.APIConnectionError): |
|
|
logger.warn("[CHATGPT] APIConnectionError: {}".format(e)) |
|
|
logger.warn("[CHATGPT] APIConnectionError: {}".format(e)) |
|
|
need_retry = False |
|
|
need_retry = False |
|
|
result['content'] = "我连接不到你的网络" |
|
|
|
|
|
|
|
|
result['content'] = "网络连接出现问题" |
|
|
else: |
|
|
else: |
|
|
logger.warn("[CHATGPT] Exception: {}".format(e)) |
|
|
logger.warn("[CHATGPT] Exception: {}".format(e)) |
|
|
need_retry = False |
|
|
need_retry = False |
|
|
self.sessions.clear_session(session_id) |
|
|
self.sessions.clear_session(session_id) |
|
|
|
|
|
|
|
|
|
|
|
result['content'] = str(e) |
|
|
if need_retry: |
|
|
if need_retry: |
|
|
logger.warn("[CHATGPT] 第{}次重试".format(retry_count+1)) |
|
|
logger.warn("[CHATGPT] 第{}次重试".format(retry_count+1)) |
|
|
return self.reply_text(session, session_id, retry_count+1) |
|
|
|
|
|
|
|
|
return self.reply_text(session, session_id, api_key, retry_count+1) |
|
|
else: |
|
|
else: |
|
|
return result |
|
|
return result |
|
|
|
|
|
|
|
|