From efbc9de9d16e53c3852a3d66c673787cf0dbb6ae Mon Sep 17 00:00:00 2001 From: zwssunny Date: Sat, 4 Mar 2023 23:44:57 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BC=9A=E8=AF=9D?= =?UTF-8?q?=E8=B6=85=E9=95=BF=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bot/chatgpt/chat_gpt_bot.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index 279ca80..cbfc736 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -26,6 +26,9 @@ class ChatGPTBot(Bot): if query == '#清除记忆': Session.clear_session(from_user_id) return '记忆已清除' + elif query == '#清除所有': + Session.clear_all_session() + return '所有人记忆已清除' new_query = Session.build_session_query(query, from_user_id) logger.debug("[OPEN_AI] session query={}".format(new_query)) @@ -132,13 +135,40 @@ class Session(object): @staticmethod def save_session(query, answer, user_id): + max_tokens = conf().get("conversation_max_tokens") + if not max_tokens: + # default 3000 + max_tokens = 1000 + session = user_session.get(user_id) if session: # append conversation gpt_item = {'role': 'assistant', 'content': answer} session.append(gpt_item) + # discard exceed limit conversation + Session.discard_exceed_conversation(user_session[user_id], max_tokens) + + @staticmethod + def discard_exceed_conversation(session, max_tokens): + count = 0 + count_list = list() + for i in range(len(session)-1, -1, -1): + # count tokens of conversation list + history_conv = session[i] + tokens=history_conv.split() + count += len(tokens) + count_list.append(count) + + for c in count_list: + if c > max_tokens: + # pop first conversation + session.pop(0) + @staticmethod def clear_session(user_id): user_session[user_id] = [] + @staticmethod + def clear_all_session(): + user_session.clear() \ No newline at end of file From 3d4d1c734a88f808ac15d95a29520b1fa99e19f4 Mon Sep 17 00:00:00 2001 From: zwssunny Date: Sun, 5 Mar 2023 09:43:59 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BC=9A=E8=AF=9D?= =?UTF-8?q?=E8=B6=85=E9=95=BF=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bot/chatgpt/chat_gpt_bot.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index cbfc736..00588d0 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -6,6 +6,7 @@ from common.log import logger from common.expired_dict import ExpiredDict import openai import time +import json if conf().get('expires_in_seconds'): user_session = ExpiredDict(conf().get('expires_in_seconds')) @@ -139,7 +140,7 @@ class Session(object): if not max_tokens: # default 3000 max_tokens = 1000 - + session = user_session.get(user_id) if session: # append conversation @@ -156,7 +157,7 @@ class Session(object): for i in range(len(session)-1, -1, -1): # count tokens of conversation list history_conv = session[i] - tokens=history_conv.split() + tokens=json.dumps(history_conv).split() count += len(tokens) count_list.append(count)