
feat: support character description of ali qwen model

master
Han Fangyuan, 11 months ago
Commit: bfacdb9c3b
2 files changed, 74 insertions(+), 3 deletions(-)
  1. bot/tongyi/ali_qwen_session.py  (+62, -0)
  2. bot/tongyi/tongyi_qwen_bot.py   (+12, -3)

bot/tongyi/ali_qwen_session.py  (+62, -0)

@@ -0,0 +1,62 @@
from bot.session_manager import Session
from common.log import logger

"""
e.g.
[
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Who won the world series in 2020?"},
    {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
    {"role": "user", "content": "Where was it played?"}
]
"""

class AliQwenSession(Session):
    def __init__(self, session_id, system_prompt=None, model="qianwen"):
        super().__init__(session_id, system_prompt)
        self.model = model
        self.reset()

    def discard_exceeding(self, max_tokens, cur_tokens=None):
        precise = True
        try:
            cur_tokens = self.calc_tokens()
        except Exception as e:
            precise = False
            if cur_tokens is None:
                raise e
            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
        while cur_tokens > max_tokens:
            if len(self.messages) > 2:
                self.messages.pop(1)
            elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
                self.messages.pop(1)
                if precise:
                    cur_tokens = self.calc_tokens()
                else:
                    cur_tokens = cur_tokens - max_tokens
                break
            elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
                logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
                break
            else:
                logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
                break
            if precise:
                cur_tokens = self.calc_tokens()
            else:
                cur_tokens = cur_tokens - max_tokens
        return cur_tokens

    def calc_tokens(self):
        return num_tokens_from_messages(self.messages, self.model)

def num_tokens_from_messages(messages, model):
    """Returns the number of tokens used by a list of messages."""
    # Official token-counting rule: "for Chinese text, 1 token usually corresponds to one Chinese character; for English text, 1 token usually corresponds to 3-4 letters or one word"
    # For details see the documentation: https://help.aliyun.com/document_detail/2586397.html
    # For now the token count is roughly estimated from string length, which does not affect normal use
    tokens = 0
    for msg in messages:
        tokens += len(msg["content"])
    return tokens
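
For reference, here is a minimal usage sketch of the new session class (not part of the commit). It assumes it is run from the repository root so that bot/ and common/ are importable; the session id "demo-session" and the max_tokens budget of 60 are arbitrary illustration values, and the messages are the docstring example above.

from bot.tongyi.ali_qwen_session import AliQwenSession, num_tokens_from_messages

# Build a session with a system prompt, then append the example turns by hand.
session = AliQwenSession("demo-session", system_prompt="You are a helpful assistant.")
session.messages.append({"role": "user", "content": "Who won the world series in 2020?"})
session.messages.append({"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."})
session.messages.append({"role": "user", "content": "Where was it played?"})

# The estimate is simply the sum of the content string lengths.
print(num_tokens_from_messages(session.messages, "qwen"))

# Trim history until the estimate fits the budget: the system prompt (messages[0]) is kept,
# and older turns are popped starting from index 1.
remaining = session.discard_exceeding(max_tokens=60)
print(remaining, len(session.messages))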

bot/tongyi/tongyi_qwen_bot.py  (+12, -3)

@@ -10,7 +10,7 @@ import broadscope_bailian
 from broadscope_bailian import ChatQaMessage
 
 from bot.bot import Bot
-from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
+from bot.tongyi.ali_qwen_session import AliQwenSession
 from bot.session_manager import SessionManager
 from bridge.context import ContextType
 from bridge.reply import Reply, ReplyType
@@ -27,7 +27,7 @@ class TongyiQwenBot(Bot):
         self.node_id = conf().get("qwen_node_id") or ""
         self.api_key_client = broadscope_bailian.AccessTokenClient(access_key_id=self.access_key_id, access_key_secret=self.access_key_secret)
         self.api_key_expired_time = self.set_api_key()
-        self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "qwen")
+        self.sessions = SessionManager(AliQwenSession, model=conf().get("model") or "qwen")
         self.temperature = conf().get("temperature", 0.2)  # value in [0, 1]; larger values make replies less deterministic
         self.top_p = conf().get("top_p", 1)
@@ -76,7 +76,7 @@ class TongyiQwenBot(Bot):
             reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
             return reply
 
-    def reply_text(self, session: BaiduWenxinSession, retry_count=0) -> dict:
+    def reply_text(self, session: AliQwenSession, retry_count=0) -> dict:
         """
         call bailian's ChatCompletion to get the answer
         :param session: a conversation session
@@ -140,6 +140,7 @@ class TongyiQwenBot(Bot):
         history = []
         user_content = ''
         assistant_content = ''
+        system_content = ''
         for message in messages:
             role = message.get('role')
             if role == 'user':
@@ -149,8 +150,16 @@ class TongyiQwenBot(Bot):
                 history.append(ChatQaMessage(user_content, assistant_content))
                 user_content = ''
                 assistant_content = ''
+            elif role == 'system':
+                system_content += message.get('content')
         if user_content == '':
             raise Exception('no user message')
+        if system_content != '':
+            # NOTE: simulate the system message. Testing shows a persona description starting with "你需要扮演ChatGPT" ("you need to play the role of ChatGPT") takes effect, while one starting with "你是ChatGPT" ("you are ChatGPT") makes the model flatly deny the persona
+            system_qa = ChatQaMessage(system_content, '好的,我会严格按照你的设定回答问题')
+            history.insert(0, system_qa)
+        logger.debug("[TONGYI] converted qa messages: {}".format([item.to_dict() for item in history]))
+        logger.debug("[TONGYI] user content as prompt: {}".format(user_content))
         return user_content, history
 
     def get_completion_content(self, response, node_id):
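
To make the converted request shape concrete, here is a small standalone illustration of the logic in the last hunk (not the project's code: plain tuples stand in for broadscope_bailian.ChatQaMessage so it runs without the SDK, and the helper name to_qa_history is made up for this sketch).

def to_qa_history(messages):
    """Mimic the conversion above: the system prompt becomes a simulated first QA pair,
    earlier user/assistant turns become (question, answer) pairs, and the latest user
    turn is returned separately as the prompt."""
    history, user_content, assistant_content, system_content = [], '', '', ''
    for message in messages:
        role = message.get('role')
        if role == 'user':
            user_content += message.get('content')
        elif role == 'assistant':
            assistant_content = message.get('content')
            history.append((user_content, assistant_content))
            user_content = ''
            assistant_content = ''
        elif role == 'system':
            system_content += message.get('content')
    if user_content == '':
        raise Exception('no user message')
    if system_content != '':
        # The character description is injected as a fake first exchange the model "agreed" to.
        history.insert(0, (system_content, '好的,我会严格按照你的设定回答问题'))
    return user_content, history

prompt, history = to_qa_history([
    {"role": "system", "content": "你需要扮演一位中世纪诗人"},
    {"role": "user", "content": "写两句关于月亮的诗"},
    {"role": "assistant", "content": "明月出天山,苍茫云海间。"},
    {"role": "user", "content": "再写两句关于太阳的"},
])
print(prompt)   # the latest user turn, sent as the prompt
print(history)  # [(character description, canned agreement), (earlier question, earlier answer)]

Inserting the simulated exchange at index 0 puts the character description ahead of the real conversation, so the model reads it before any user turn.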

