|
@@ -6,7 +6,6 @@ from common.log import logger |
|
|
from common.expired_dict import ExpiredDict |
|
|
from common.expired_dict import ExpiredDict |
|
|
import openai |
|
|
import openai |
|
|
import time |
|
|
import time |
|
|
import json |
|
|
|
|
|
|
|
|
|
|
|
if conf().get('expires_in_seconds'): |
|
|
if conf().get('expires_in_seconds'): |
|
|
user_session = ExpiredDict(conf().get('expires_in_seconds')) |
|
|
user_session = ExpiredDict(conf().get('expires_in_seconds')) |
|
@@ -41,15 +40,22 @@ class ChatGPTBot(Bot): |
|
|
# return self.reply_text_stream(query, new_query, from_user_id) |
|
|
# return self.reply_text_stream(query, new_query, from_user_id) |
|
|
|
|
|
|
|
|
reply_content = self.reply_text(new_query, from_user_id, 0) |
|
|
reply_content = self.reply_text(new_query, from_user_id, 0) |
|
|
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content)) |
|
|
|
|
|
if reply_content: |
|
|
|
|
|
Session.save_session(query, reply_content, from_user_id) |
|
|
|
|
|
return reply_content |
|
|
|
|
|
|
|
|
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content["content"])) |
|
|
|
|
|
if reply_content["completion_tokens"] > 0: |
|
|
|
|
|
Session.save_session(reply_content["content"], from_user_id, reply_content["total_tokens"]) |
|
|
|
|
|
return reply_content["content"] |
|
|
|
|
|
|
|
|
elif context.get('type', None) == 'IMAGE_CREATE': |
|
|
elif context.get('type', None) == 'IMAGE_CREATE': |
|
|
return self.create_img(query, 0) |
|
|
return self.create_img(query, 0) |
|
|
|
|
|
|
|
|
def reply_text(self, query, user_id, retry_count=0): |
|
|
|
|
|
|
|
|
def reply_text(self, query, user_id, retry_count=0) ->dict: |
|
|
|
|
|
''' |
|
|
|
|
|
call openai's ChatCompletion to get the answer |
|
|
|
|
|
:param query: query content |
|
|
|
|
|
:param user_id: from user id |
|
|
|
|
|
:param retry_count: retry count |
|
|
|
|
|
:return: {} |
|
|
|
|
|
''' |
|
|
try: |
|
|
try: |
|
|
response = openai.ChatCompletion.create( |
|
|
response = openai.ChatCompletion.create( |
|
|
model="gpt-3.5-turbo", # 对话模型的名称 |
|
|
model="gpt-3.5-turbo", # 对话模型的名称 |
|
@@ -62,8 +68,9 @@ class ChatGPTBot(Bot): |
|
|
) |
|
|
) |
|
|
# res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '') |
|
|
# res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '') |
|
|
logger.info(response.choices[0]['message']['content']) |
|
|
logger.info(response.choices[0]['message']['content']) |
|
|
# log.info("[OPEN_AI] reply={}".format(res_content)) |
|
|
|
|
|
return response.choices[0]['message']['content'] |
|
|
|
|
|
|
|
|
return {"total_tokens": response["usage"]["total_tokens"], |
|
|
|
|
|
"completion_tokens": response["usage"]["completion_tokens"], |
|
|
|
|
|
"content": response.choices[0]['message']['content']} |
|
|
except openai.error.RateLimitError as e: |
|
|
except openai.error.RateLimitError as e: |
|
|
# rate limit exception |
|
|
# rate limit exception |
|
|
logger.warn(e) |
|
|
logger.warn(e) |
|
@@ -72,21 +79,21 @@ class ChatGPTBot(Bot): |
|
|
logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1)) |
|
|
logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1)) |
|
|
return self.reply_text(query, user_id, retry_count+1) |
|
|
return self.reply_text(query, user_id, retry_count+1) |
|
|
else: |
|
|
else: |
|
|
return "提问太快啦,请休息一下再问我吧" |
|
|
|
|
|
|
|
|
return {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"} |
|
|
except openai.error.APIConnectionError as e: |
|
|
except openai.error.APIConnectionError as e: |
|
|
# api connection exception |
|
|
# api connection exception |
|
|
logger.warn(e) |
|
|
logger.warn(e) |
|
|
logger.warn("[OPEN_AI] APIConnection failed") |
|
|
logger.warn("[OPEN_AI] APIConnection failed") |
|
|
return "我连接不到你的网络" |
|
|
|
|
|
|
|
|
return {"completion_tokens": 0, "content":"我连接不到你的网络"} |
|
|
except openai.error.Timeout as e: |
|
|
except openai.error.Timeout as e: |
|
|
logger.warn(e) |
|
|
logger.warn(e) |
|
|
logger.warn("[OPEN_AI] Timeout") |
|
|
logger.warn("[OPEN_AI] Timeout") |
|
|
return "我没有收到你的消息" |
|
|
|
|
|
|
|
|
return {"completion_tokens": 0, "content":"我没有收到你的消息"} |
|
|
except Exception as e: |
|
|
except Exception as e: |
|
|
# unknown exception |
|
|
# unknown exception |
|
|
logger.exception(e) |
|
|
logger.exception(e) |
|
|
Session.clear_session(user_id) |
|
|
Session.clear_session(user_id) |
|
|
return "请再问我一次吧" |
|
|
|
|
|
|
|
|
return {"completion_tokens": 0, "content": "请再问我一次吧"} |
|
|
|
|
|
|
|
|
def create_img(self, query, retry_count=0): |
|
|
def create_img(self, query, retry_count=0): |
|
|
try: |
|
|
try: |
|
@@ -137,11 +144,12 @@ class Session(object): |
|
|
return session |
|
|
return session |
|
|
|
|
|
|
|
|
@staticmethod |
|
|
@staticmethod |
|
|
def save_session(query, answer, user_id): |
|
|
|
|
|
|
|
|
def save_session(answer, user_id, total_tokens): |
|
|
max_tokens = conf().get("conversation_max_tokens") |
|
|
max_tokens = conf().get("conversation_max_tokens") |
|
|
if not max_tokens: |
|
|
if not max_tokens: |
|
|
# default 3000 |
|
|
# default 3000 |
|
|
max_tokens = 1000 |
|
|
max_tokens = 1000 |
|
|
|
|
|
max_tokens=int(max_tokens) |
|
|
|
|
|
|
|
|
session = user_session.get(user_id) |
|
|
session = user_session.get(user_id) |
|
|
if session: |
|
|
if session: |
|
@@ -150,23 +158,19 @@ class Session(object): |
|
|
session.append(gpt_item) |
|
|
session.append(gpt_item) |
|
|
|
|
|
|
|
|
# discard exceed limit conversation |
|
|
# discard exceed limit conversation |
|
|
Session.discard_exceed_conversation(user_session[user_id], max_tokens) |
|
|
|
|
|
|
|
|
Session.discard_exceed_conversation(session, max_tokens, total_tokens) |
|
|
|
|
|
|
|
|
@staticmethod |
|
|
@staticmethod |
|
|
def discard_exceed_conversation(session, max_tokens): |
|
|
|
|
|
count = 0 |
|
|
|
|
|
count_list = list() |
|
|
|
|
|
for i in range(len(session)-1, -1, -1): |
|
|
|
|
|
# count tokens of conversation list |
|
|
|
|
|
history_conv = session[i] |
|
|
|
|
|
tokens=json.dumps(history_conv).split() |
|
|
|
|
|
count += len(tokens) |
|
|
|
|
|
count_list.append(count) |
|
|
|
|
|
|
|
|
|
|
|
for c in count_list: |
|
|
|
|
|
if c > max_tokens: |
|
|
|
|
|
# pop first conversation |
|
|
|
|
|
session.pop(0) |
|
|
|
|
|
|
|
|
def discard_exceed_conversation(session, max_tokens, total_tokens): |
|
|
|
|
|
dec_tokens=int(total_tokens) |
|
|
|
|
|
# logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens)) |
|
|
|
|
|
while dec_tokens > max_tokens: |
|
|
|
|
|
# pop first conversation |
|
|
|
|
|
if len(session) > 0: |
|
|
|
|
|
session.pop(0) |
|
|
|
|
|
else: |
|
|
|
|
|
break |
|
|
|
|
|
dec_tokens=dec_tokens-max_tokens |
|
|
|
|
|
|
|
|
@staticmethod |
|
|
@staticmethod |
|
|
def clear_session(user_id): |
|
|
def clear_session(user_id): |
|
|