@@ -58,7 +58,7 @@ class ChatGPTBot(Bot,OpenAIImage): | |||||
# # reply in stream | # # reply in stream | ||||
# return self.reply_text_stream(query, new_query, session_id) | # return self.reply_text_stream(query, new_query, session_id) | ||||
reply_content = self.reply_text(session, session_id, api_key, 0) | |||||
reply_content = self.reply_text(session, api_key) | |||||
logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"])) | logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"])) | ||||
if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0: | if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0: | ||||
reply = Reply(ReplyType.ERROR, reply_content['content']) | reply = Reply(ReplyType.ERROR, reply_content['content']) | ||||
@@ -94,7 +94,7 @@ class ChatGPTBot(Bot,OpenAIImage): | |||||
"timeout": conf().get('request_timeout', None), #重试超时时间,在这个时间内,将会自动重试 | "timeout": conf().get('request_timeout', None), #重试超时时间,在这个时间内,将会自动重试 | ||||
} | } | ||||
def reply_text(self, session:ChatGPTSession, session_id, api_key, retry_count=0) -> dict: | |||||
def reply_text(self, session:ChatGPTSession, api_key=None, retry_count=0) -> dict: | |||||
''' | ''' | ||||
call openai's ChatCompletion to get the answer | call openai's ChatCompletion to get the answer | ||||
:param session: a conversation session | :param session: a conversation session | ||||
@@ -133,11 +133,11 @@ class ChatGPTBot(Bot,OpenAIImage): | |||||
else: | else: | ||||
logger.warn("[CHATGPT] Exception: {}".format(e)) | logger.warn("[CHATGPT] Exception: {}".format(e)) | ||||
need_retry = False | need_retry = False | ||||
self.sessions.clear_session(session_id) | |||||
self.sessions.clear_session(session.session_id) | |||||
if need_retry: | if need_retry: | ||||
logger.warn("[CHATGPT] 第{}次重试".format(retry_count+1)) | logger.warn("[CHATGPT] 第{}次重试".format(retry_count+1)) | ||||
return self.reply_text(session, session_id, api_key, retry_count+1) | |||||
return self.reply_text(session, api_key, retry_count+1) | |||||
else: | else: | ||||
return result | return result | ||||
@@ -17,7 +17,7 @@ class ChatGPTSession(Session): | |||||
def discard_exceeding(self, max_tokens, cur_tokens= None): | def discard_exceeding(self, max_tokens, cur_tokens= None): | ||||
precise = True | precise = True | ||||
try: | try: | ||||
cur_tokens = num_tokens_from_messages(self.messages, self.model) | |||||
cur_tokens = self.calc_tokens() | |||||
except Exception as e: | except Exception as e: | ||||
precise = False | precise = False | ||||
if cur_tokens is None: | if cur_tokens is None: | ||||
@@ -29,7 +29,7 @@ class ChatGPTSession(Session): | |||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant": | elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant": | ||||
self.messages.pop(1) | self.messages.pop(1) | ||||
if precise: | if precise: | ||||
cur_tokens = num_tokens_from_messages(self.messages, self.model) | |||||
cur_tokens = self.calc_tokens() | |||||
else: | else: | ||||
cur_tokens = cur_tokens - max_tokens | cur_tokens = cur_tokens - max_tokens | ||||
break | break | ||||
@@ -40,11 +40,14 @@ class ChatGPTSession(Session): | |||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages))) | logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages))) | ||||
break | break | ||||
if precise: | if precise: | ||||
cur_tokens = num_tokens_from_messages(self.messages, self.model) | |||||
cur_tokens = self.calc_tokens() | |||||
else: | else: | ||||
cur_tokens = cur_tokens - max_tokens | cur_tokens = cur_tokens - max_tokens | ||||
return cur_tokens | return cur_tokens | ||||
def calc_tokens(self): | |||||
return num_tokens_from_messages(self.messages, self.model) | |||||
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb | # refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb | ||||
def num_tokens_from_messages(messages, model): | def num_tokens_from_messages(messages, model): | ||||
@@ -42,11 +42,9 @@ class OpenAIBot(Bot, OpenAIImage): | |||||
reply = Reply(ReplyType.INFO, '所有人记忆已清除') | reply = Reply(ReplyType.INFO, '所有人记忆已清除') | ||||
else: | else: | ||||
session = self.sessions.session_query(query, session_id) | session = self.sessions.session_query(query, session_id) | ||||
new_query = str(session) | |||||
logger.debug("[OPEN_AI] session query={}".format(new_query)) | |||||
total_tokens, completion_tokens, reply_content = self.reply_text(new_query, session_id, 0) | |||||
logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(new_query, session_id, reply_content, completion_tokens)) | |||||
result = self.reply_text(session) | |||||
total_tokens, completion_tokens, reply_content = result['total_tokens'], result['completion_tokens'], result['content'] | |||||
logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)) | |||||
if total_tokens == 0 : | if total_tokens == 0 : | ||||
reply = Reply(ReplyType.ERROR, reply_content) | reply = Reply(ReplyType.ERROR, reply_content) | ||||
@@ -63,11 +61,11 @@ class OpenAIBot(Bot, OpenAIImage): | |||||
reply = Reply(ReplyType.ERROR, retstring) | reply = Reply(ReplyType.ERROR, retstring) | ||||
return reply | return reply | ||||
def reply_text(self, query, session_id, retry_count=0): | |||||
def reply_text(self, session:OpenAISession, retry_count=0): | |||||
try: | try: | ||||
response = openai.Completion.create( | response = openai.Completion.create( | ||||
model= conf().get("model") or "text-davinci-003", # 对话模型的名称 | model= conf().get("model") or "text-davinci-003", # 对话模型的名称 | ||||
prompt=query, | |||||
prompt=str(session), | |||||
temperature=0.9, # 值在[0,1]之间,越大表示回复越具有不确定性 | temperature=0.9, # 值在[0,1]之间,越大表示回复越具有不确定性 | ||||
max_tokens=1200, # 回复最大的字符数 | max_tokens=1200, # 回复最大的字符数 | ||||
top_p=1, | top_p=1, | ||||
@@ -79,31 +77,33 @@ class OpenAIBot(Bot, OpenAIImage): | |||||
total_tokens = response["usage"]["total_tokens"] | total_tokens = response["usage"]["total_tokens"] | ||||
completion_tokens = response["usage"]["completion_tokens"] | completion_tokens = response["usage"]["completion_tokens"] | ||||
logger.info("[OPEN_AI] reply={}".format(res_content)) | logger.info("[OPEN_AI] reply={}".format(res_content)) | ||||
return total_tokens, completion_tokens, res_content | |||||
return {"total_tokens": total_tokens, | |||||
"completion_tokens": completion_tokens, | |||||
"content": res_content} | |||||
except Exception as e: | except Exception as e: | ||||
need_retry = retry_count < 2 | need_retry = retry_count < 2 | ||||
result = [0,0,"我现在有点累了,等会再来吧"] | |||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} | |||||
if isinstance(e, openai.error.RateLimitError): | if isinstance(e, openai.error.RateLimitError): | ||||
logger.warn("[OPEN_AI] RateLimitError: {}".format(e)) | logger.warn("[OPEN_AI] RateLimitError: {}".format(e)) | ||||
result[2] = "提问太快啦,请休息一下再问我吧" | |||||
result['content'] = "提问太快啦,请休息一下再问我吧" | |||||
if need_retry: | if need_retry: | ||||
time.sleep(5) | time.sleep(5) | ||||
elif isinstance(e, openai.error.Timeout): | elif isinstance(e, openai.error.Timeout): | ||||
logger.warn("[OPEN_AI] Timeout: {}".format(e)) | logger.warn("[OPEN_AI] Timeout: {}".format(e)) | ||||
result[2] = "我没有收到你的消息" | |||||
result['content'] = "我没有收到你的消息" | |||||
if need_retry: | if need_retry: | ||||
time.sleep(5) | time.sleep(5) | ||||
elif isinstance(e, openai.error.APIConnectionError): | elif isinstance(e, openai.error.APIConnectionError): | ||||
logger.warn("[OPEN_AI] APIConnectionError: {}".format(e)) | logger.warn("[OPEN_AI] APIConnectionError: {}".format(e)) | ||||
need_retry = False | need_retry = False | ||||
result[2] = "我连接不到你的网络" | |||||
result['content'] = "我连接不到你的网络" | |||||
else: | else: | ||||
logger.warn("[OPEN_AI] Exception: {}".format(e)) | logger.warn("[OPEN_AI] Exception: {}".format(e)) | ||||
need_retry = False | need_retry = False | ||||
self.sessions.clear_session(session_id) | |||||
self.sessions.clear_session(session.session_id) | |||||
if need_retry: | if need_retry: | ||||
logger.warn("[OPEN_AI] 第{}次重试".format(retry_count+1)) | logger.warn("[OPEN_AI] 第{}次重试".format(retry_count+1)) | ||||
return self.reply_text(query, session_id, retry_count+1) | |||||
return self.reply_text(session, retry_count+1) | |||||
else: | else: | ||||
return result | return result |
@@ -29,7 +29,7 @@ class OpenAISession(Session): | |||||
def discard_exceeding(self, max_tokens, cur_tokens= None): | def discard_exceeding(self, max_tokens, cur_tokens= None): | ||||
precise = True | precise = True | ||||
try: | try: | ||||
cur_tokens = num_tokens_from_string(str(self), self.model) | |||||
cur_tokens = self.calc_tokens() | |||||
except Exception as e: | except Exception as e: | ||||
precise = False | precise = False | ||||
if cur_tokens is None: | if cur_tokens is None: | ||||
@@ -41,7 +41,7 @@ class OpenAISession(Session): | |||||
elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant": | elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant": | ||||
self.messages.pop(0) | self.messages.pop(0) | ||||
if precise: | if precise: | ||||
cur_tokens = num_tokens_from_string(str(self), self.model) | |||||
cur_tokens = self.calc_tokens() | |||||
else: | else: | ||||
cur_tokens = len(str(self)) | cur_tokens = len(str(self)) | ||||
break | break | ||||
@@ -52,11 +52,13 @@ class OpenAISession(Session): | |||||
logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages))) | logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages))) | ||||
break | break | ||||
if precise: | if precise: | ||||
cur_tokens = num_tokens_from_string(str(self), self.model) | |||||
cur_tokens = self.calc_tokens() | |||||
else: | else: | ||||
cur_tokens = len(str(self)) | cur_tokens = len(str(self)) | ||||
return cur_tokens | return cur_tokens | ||||
def calc_tokens(self): | |||||
return num_tokens_from_string(str(self), self.model) | |||||
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb | # refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb | ||||
def num_tokens_from_string(string: str, model: str) -> int: | def num_tokens_from_string(string: str, model: str) -> int: | ||||
@@ -31,6 +31,8 @@ class Session(object): | |||||
def discard_exceeding(self, max_tokens=None, cur_tokens=None): | def discard_exceeding(self, max_tokens=None, cur_tokens=None): | ||||
raise NotImplementedError | raise NotImplementedError | ||||
def calc_tokens(self): | |||||
raise NotImplementedError | |||||
class SessionManager(object): | class SessionManager(object): | ||||