
feat: merge chat related sessions

master
lanvent committed 1 year ago
commit d62b7d1a99
3 changed files with 21 additions and 40 deletions:

  1. bot/chatgpt/chat_gpt_session.py   (+0, -13)
  2. bot/openai/open_ai_session.py     (+14, -24)
  3. bot/session_manager.py            (+7, -3)

bot/chatgpt/chat_gpt_session.py  (+0, -13)

@@ -11,22 +11,9 @@ from common.log import logger
 class ChatGPTSession(Session):
     def __init__(self, session_id, system_prompt=None, model= "gpt-3.5-turbo"):
         super().__init__(session_id, system_prompt)
-        self.messages = []
         self.model = model
         self.reset()
 
-    def reset(self):
-        system_item = {'role': 'system', 'content': self.system_prompt}
-        self.messages = [system_item]
-
-    def add_query(self, query):
-        user_item = {'role': 'user', 'content': query}
-        self.messages.append(user_item)
-
-    def add_reply(self, reply):
-        assistant_item = {'role': 'assistant', 'content': reply}
-        self.messages.append(assistant_item)
-
     def discard_exceeding(self, max_tokens, cur_tokens= None):
         precise = True
         try:
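
Note: with the duplicated helpers deleted, ChatGPTSession keeps only its model name and inherits reset/add_query/add_reply from the base Session class (see the bot/session_manager.py hunk below). A minimal sketch of the message list a session accumulates after this change; the session id, query and reply strings are invented for illustration:

    # Hypothetical values, for illustration only.
    session = ChatGPTSession("session-1", system_prompt="You are a helpful assistant.")
    session.add_query("What is the capital of France?")   # inherited from Session
    session.add_reply("The capital of France is Paris.")  # inherited from Session

    # session.messages now holds the chat format shared by both bots:
    # [{'role': 'system',    'content': 'You are a helpful assistant.'},
    #  {'role': 'user',      'content': 'What is the capital of France?'},
    #  {'role': 'assistant', 'content': 'The capital of France is Paris.'}]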


bot/openai/open_ai_session.py  (+14, -24)

@@ -3,36 +3,26 @@ from common.log import logger
 class OpenAISession(Session):
     def __init__(self, session_id, system_prompt=None, model= "text-davinci-003"):
         super().__init__(session_id, system_prompt)
-        self.conversation = []
         self.model = model
         self.reset()
 
-    def reset(self):
-        pass
-
-    def add_query(self, query):
-        question = {'type': 'question', 'content': query}
-        self.conversation.append(question)
-
-    def add_reply(self, reply):
-        answer = {'type': 'answer', 'content': reply}
-        self.conversation.append(answer)
-
     def __str__(self):
         # Build the prompt fed to the completion model
         '''
         e.g. Q: xxx
         A: xxx
         Q: xxx
         '''
-        prompt = self.system_prompt
-        if prompt:
-            prompt += "<|endoftext|>\n\n\n"
-        for item in self.conversation:
-            if item['type'] == 'question':
+        prompt = ""
+        for item in self.messages:
+            if item['role'] == 'system':
+                prompt += item['content'] + "<|endoftext|>\n\n\n"
+            elif item['role'] == 'user':
                 prompt += "Q: " + item['content'] + "\n"
-            elif item['type'] == 'answer':
+            elif item['role'] == 'assistant':
                 prompt += "\n\nA: " + item['content'] + "<|endoftext|>\n"
 
-        if len(self.conversation) > 0 and self.conversation[-1]['type'] == 'question':
+        if len(self.messages) > 0 and self.messages[-1]['role'] == 'user':
             prompt += "A: "
         return prompt


@@ -46,20 +36,20 @@ class OpenAISession(Session):
                 raise e
             logger.debug("Exception when counting tokens precisely for query: {}".format(e))
         while cur_tokens > max_tokens:
-            if len(self.conversation) > 1:
-                self.conversation.pop(0)
-            elif len(self.conversation) == 1 and self.conversation[0]["type"] == "answer":
-                self.conversation.pop(0)
+            if len(self.messages) > 1:
+                self.messages.pop(0)
+            elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant":
+                self.messages.pop(0)
                 if precise:
                     cur_tokens = num_tokens_from_string(str(self), self.model)
                 else:
                     cur_tokens = len(str(self))
                 break
-            elif len(self.conversation) == 1 and self.conversation[0]["type"] == "question":
+            elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
                 logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens))
                 break
             else:
-                logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.conversation)))
+                logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
                 break
         if precise:
             cur_tokens = num_tokens_from_string(str(self), self.model)
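
The rewritten __str__ renders the shared messages list into the Q/A completion prompt that text-davinci-003 expects, replacing the old type/content conversation list. A rough sketch of the output for a hypothetical two-turn session (the contents are invented; only the layout produced by __str__ matters here):

    session = OpenAISession("session-1", system_prompt="You are a helpful assistant.")
    session.add_query("Hi")
    session.add_reply("Hello!")
    session.add_query("How are you?")
    print(str(session))
    # You are a helpful assistant.<|endoftext|>
    #
    #
    # Q: Hi
    #
    #
    # A: Hello!<|endoftext|>
    # Q: How are you?
    # A: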


bot/session_manager.py  (+7, -3)

@@ -5,6 +5,7 @@ from config import conf
 class Session(object):
     def __init__(self, session_id, system_prompt=None):
         self.session_id = session_id
+        self.messages = []
         if system_prompt is None:
             self.system_prompt = conf().get("character_desc", "")
         else:
@@ -12,17 +13,20 @@ class Session(object):
 
     # Reset the session
     def reset(self):
-        raise NotImplementedError
+        system_item = {'role': 'system', 'content': self.system_prompt}
+        self.messages = [system_item]
 
     def set_system_prompt(self, system_prompt):
         self.system_prompt = system_prompt
         self.reset()
 
     def add_query(self, query):
-        raise NotImplementedError
+        user_item = {'role': 'user', 'content': query}
+        self.messages.append(user_item)
 
     def add_reply(self, reply):
-        raise NotImplementedError
+        assistant_item = {'role': 'assistant', 'content': reply}
+        self.messages.append(assistant_item)
 
     def discard_exceeding(self, max_tokens=None, cur_tokens=None):
         raise NotImplementedError
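
Taken together, the shared conversation handling now lives on the base class, while the model-specific subclasses keep only their model name, prompt rendering and token trimming. A condensed sketch of the Session class in bot/session_manager.py after this commit, assembled from the hunks above (the import and the else branch of __init__ are assumed from the hunk context, not shown in the diff):

    from config import conf

    class Session(object):
        def __init__(self, session_id, system_prompt=None):
            self.session_id = session_id
            self.messages = []
            if system_prompt is None:
                self.system_prompt = conf().get("character_desc", "")
            else:
                self.system_prompt = system_prompt  # assumed: pre-existing else branch

        # Reset the session to just the system prompt.
        def reset(self):
            self.messages = [{'role': 'system', 'content': self.system_prompt}]

        def set_system_prompt(self, system_prompt):
            self.system_prompt = system_prompt
            self.reset()

        def add_query(self, query):
            self.messages.append({'role': 'user', 'content': query})

        def add_reply(self, reply):
            self.messages.append({'role': 'assistant', 'content': reply})

        # Still model-specific: subclasses trim self.messages to their token budget.
        def discard_exceeding(self, max_tokens=None, cur_tokens=None):
            raise NotImplementedError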

