diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index 405499a..7b9012e 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -6,6 +6,7 @@ from common.log import logger from common.expired_dict import ExpiredDict import openai import time +import json if conf().get('expires_in_seconds'): user_session = ExpiredDict(conf().get('expires_in_seconds')) @@ -28,6 +29,9 @@ class ChatGPTBot(Bot): if query == '#清除记忆': Session.clear_session(from_user_id) return '记忆已清除' + elif query == '#清除所有': + Session.clear_all_session() + return '所有人记忆已清除' new_query = Session.build_session_query(query, from_user_id) logger.debug("[OPEN_AI] session query={}".format(new_query)) @@ -134,13 +138,40 @@ class Session(object): @staticmethod def save_session(query, answer, user_id): + max_tokens = conf().get("conversation_max_tokens") + if not max_tokens: + # default 3000 + max_tokens = 1000 + session = user_session.get(user_id) if session: # append conversation gpt_item = {'role': 'assistant', 'content': answer} session.append(gpt_item) + # discard exceed limit conversation + Session.discard_exceed_conversation(user_session[user_id], max_tokens) + + @staticmethod + def discard_exceed_conversation(session, max_tokens): + count = 0 + count_list = list() + for i in range(len(session)-1, -1, -1): + # count tokens of conversation list + history_conv = session[i] + tokens=json.dumps(history_conv).split() + count += len(tokens) + count_list.append(count) + + for c in count_list: + if c > max_tokens: + # pop first conversation + session.pop(0) + @staticmethod def clear_session(user_id): user_session[user_id] = [] + @staticmethod + def clear_all_session(): + user_session.clear() \ No newline at end of file