From d62b7d1a9954f1ea461ad1e14d79cdd88461f2e3 Mon Sep 17 00:00:00 2001 From: lanvent Date: Wed, 29 Mar 2023 12:25:31 +0800 Subject: [PATCH] feat: merge chat related sessions --- bot/chatgpt/chat_gpt_session.py | 13 ----------- bot/openai/open_ai_session.py | 38 ++++++++++++--------------------- bot/session_manager.py | 10 ++++++--- 3 files changed, 21 insertions(+), 40 deletions(-) diff --git a/bot/chatgpt/chat_gpt_session.py b/bot/chatgpt/chat_gpt_session.py index faf06b3..afbe6a5 100644 --- a/bot/chatgpt/chat_gpt_session.py +++ b/bot/chatgpt/chat_gpt_session.py @@ -11,22 +11,9 @@ from common.log import logger class ChatGPTSession(Session): def __init__(self, session_id, system_prompt=None, model= "gpt-3.5-turbo"): super().__init__(session_id, system_prompt) - self.messages = [] self.model = model self.reset() - def reset(self): - system_item = {'role': 'system', 'content': self.system_prompt} - self.messages = [system_item] - - def add_query(self, query): - user_item = {'role': 'user', 'content': query} - self.messages.append(user_item) - - def add_reply(self, reply): - assistant_item = {'role': 'assistant', 'content': reply} - self.messages.append(assistant_item) - def discard_exceeding(self, max_tokens, cur_tokens= None): precise = True try: diff --git a/bot/openai/open_ai_session.py b/bot/openai/open_ai_session.py index 9eb6b32..597611c 100644 --- a/bot/openai/open_ai_session.py +++ b/bot/openai/open_ai_session.py @@ -3,36 +3,26 @@ from common.log import logger class OpenAISession(Session): def __init__(self, session_id, system_prompt=None, model= "text-davinci-003"): super().__init__(session_id, system_prompt) - self.conversation = [] self.model = model self.reset() - - def reset(self): - pass - - def add_query(self, query): - question = {'type': 'question', 'content': query} - self.conversation.append(question) - def add_reply(self, reply): - answer = {'type': 'answer', 'content': reply} - self.conversation.append(answer) def __str__(self): + # 构造对话模型的输入 ''' e.g. Q: xxx A: xxx Q: xxx ''' - prompt = self.system_prompt - if prompt: - prompt += "<|endoftext|>\n\n\n" - for item in self.conversation: - if item['type'] == 'question': + prompt = "" + for item in self.messages: + if item['role'] == 'system': + prompt += item['content'] + "<|endoftext|>\n\n\n" + elif item['role'] == 'user': prompt += "Q: " + item['content'] + "\n" - elif item['type'] == 'answer': + elif item['role'] == 'assistant': prompt += "\n\nA: " + item['content'] + "<|endoftext|>\n" - if len(self.conversation) > 0 and self.conversation[-1]['type'] == 'question': + if len(self.messages) > 0 and self.messages[-1]['role'] == 'user': prompt += "A: " return prompt @@ -46,20 +36,20 @@ class OpenAISession(Session): raise e logger.debug("Exception when counting tokens precisely for query: {}".format(e)) while cur_tokens > max_tokens: - if len(self.conversation) > 1: - self.conversation.pop(0) - elif len(self.conversation) == 1 and self.conversation[0]["type"] == "answer": - self.conversation.pop(0) + if len(self.messages) > 1: + self.messages.pop(0) + elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant": + self.messages.pop(0) if precise: cur_tokens = num_tokens_from_string(str(self), self.model) else: cur_tokens = len(str(self)) break - elif len(self.conversation) == 1 and self.conversation[0]["type"] == "question": + elif len(self.messages) == 1 and self.messages[0]["role"] == "user": logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens)) break else: - logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.conversation))) + logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages))) break if precise: cur_tokens = num_tokens_from_string(str(self), self.model) diff --git a/bot/session_manager.py b/bot/session_manager.py index 3bde7f4..1114730 100644 --- a/bot/session_manager.py +++ b/bot/session_manager.py @@ -5,6 +5,7 @@ from config import conf class Session(object): def __init__(self, session_id, system_prompt=None): self.session_id = session_id + self.messages = [] if system_prompt is None: self.system_prompt = conf().get("character_desc", "") else: @@ -12,17 +13,20 @@ class Session(object): # 重置会话 def reset(self): - raise NotImplementedError + system_item = {'role': 'system', 'content': self.system_prompt} + self.messages = [system_item] def set_system_prompt(self, system_prompt): self.system_prompt = system_prompt self.reset() def add_query(self, query): - raise NotImplementedError + user_item = {'role': 'user', 'content': query} + self.messages.append(user_item) def add_reply(self, reply): - raise NotImplementedError + assistant_item = {'role': 'assistant', 'content': reply} + self.messages.append(assistant_item) def discard_exceeding(self, max_tokens=None, cur_tokens=None): raise NotImplementedError