diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index 279ca80..cbfc736 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -26,6 +26,9 @@ class ChatGPTBot(Bot): if query == '#清除记忆': Session.clear_session(from_user_id) return '记忆已清除' + elif query == '#清除所有': + Session.clear_all_session() + return '所有人记忆已清除' new_query = Session.build_session_query(query, from_user_id) logger.debug("[OPEN_AI] session query={}".format(new_query)) @@ -132,13 +135,40 @@ class Session(object): @staticmethod def save_session(query, answer, user_id): + max_tokens = conf().get("conversation_max_tokens") + if not max_tokens: + # default 3000 + max_tokens = 1000 + session = user_session.get(user_id) if session: # append conversation gpt_item = {'role': 'assistant', 'content': answer} session.append(gpt_item) + # discard exceed limit conversation + Session.discard_exceed_conversation(user_session[user_id], max_tokens) + + @staticmethod + def discard_exceed_conversation(session, max_tokens): + count = 0 + count_list = list() + for i in range(len(session)-1, -1, -1): + # count tokens of conversation list + history_conv = session[i] + tokens=history_conv.split() + count += len(tokens) + count_list.append(count) + + for c in count_list: + if c > max_tokens: + # pop first conversation + session.pop(0) + @staticmethod def clear_session(user_id): user_session[user_id] = [] + @staticmethod + def clear_all_session(): + user_session.clear() \ No newline at end of file