|
|
@@ -6,6 +6,7 @@ from common.log import logger |
|
|
|
from common.expired_dict import ExpiredDict |
|
|
|
import openai |
|
|
|
import time |
|
|
|
import json |
|
|
|
|
|
|
|
if conf().get('expires_in_seconds'): |
|
|
|
user_session = ExpiredDict(conf().get('expires_in_seconds')) |
|
|
@@ -28,6 +29,9 @@ class ChatGPTBot(Bot): |
|
|
|
if query == '#清除记忆': |
|
|
|
Session.clear_session(from_user_id) |
|
|
|
return '记忆已清除' |
|
|
|
elif query == '#清除所有': |
|
|
|
Session.clear_all_session() |
|
|
|
return '所有人记忆已清除' |
|
|
|
|
|
|
|
new_query = Session.build_session_query(query, from_user_id) |
|
|
|
logger.debug("[OPEN_AI] session query={}".format(new_query)) |
|
|
@@ -134,13 +138,40 @@ class Session(object): |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def save_session(query, answer, user_id): |
|
|
|
max_tokens = conf().get("conversation_max_tokens") |
|
|
|
if not max_tokens: |
|
|
|
# default 3000 |
|
|
|
max_tokens = 1000 |
|
|
|
|
|
|
|
session = user_session.get(user_id) |
|
|
|
if session: |
|
|
|
# append conversation |
|
|
|
gpt_item = {'role': 'assistant', 'content': answer} |
|
|
|
session.append(gpt_item) |
|
|
|
|
|
|
|
# discard exceed limit conversation |
|
|
|
Session.discard_exceed_conversation(user_session[user_id], max_tokens) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def discard_exceed_conversation(session, max_tokens): |
|
|
|
count = 0 |
|
|
|
count_list = list() |
|
|
|
for i in range(len(session)-1, -1, -1): |
|
|
|
# count tokens of conversation list |
|
|
|
history_conv = session[i] |
|
|
|
tokens=json.dumps(history_conv).split() |
|
|
|
count += len(tokens) |
|
|
|
count_list.append(count) |
|
|
|
|
|
|
|
for c in count_list: |
|
|
|
if c > max_tokens: |
|
|
|
# pop first conversation |
|
|
|
session.pop(0) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def clear_session(user_id): |
|
|
|
user_session[user_id] = [] |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def clear_all_session(): |
|
|
|
user_session.clear() |