Browse Source

feat: image input and session optimize

master
zhayujie 1 year ago
parent
commit
4e675b84fb
6 changed files with 134 additions and 16 deletions
  1. +117
    -7
      bot/linkai/link_ai_bot.py
  2. +2
    -2
      bot/session_manager.py
  3. +6
    -5
      channel/chat_channel.py
  4. +3
    -0
      common/memory.py
  5. +6
    -1
      common/utils.py
  6. +0
    -1
      plugins/linkai/summary.py

+ 117
- 7
bot/linkai/link_ai_bot.py View File

@@ -13,6 +13,9 @@ from bridge.reply import Reply, ReplyType
from common.log import logger from common.log import logger
from config import conf, pconf from config import conf, pconf
import threading import threading
from common import memory, utils
import base64



class LinkAIBot(Bot): class LinkAIBot(Bot):
# authentication failed # authentication failed
@@ -21,7 +24,7 @@ class LinkAIBot(Bot):


def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
self.sessions = LinkAISessionManager(LinkAISession, model=conf().get("model") or "gpt-3.5-turbo")
self.args = {} self.args = {}


def reply(self, query, context: Context = None) -> Reply: def reply(self, query, context: Context = None) -> Reply:
@@ -61,17 +64,25 @@ class LinkAIBot(Bot):
linkai_api_key = conf().get("linkai_api_key") linkai_api_key = conf().get("linkai_api_key")


session_id = context["session_id"] session_id = context["session_id"]
session_message = self.sessions.session_msg_query(query, session_id)
logger.debug(f"[LinkAI] session={session_message}, session_id={session_id}")

# image process
img_cache = memory.USER_IMAGE_CACHE.get(session_id)
if img_cache:
messages = self._process_image_msg(app_code=app_code, session_id=session_id, query=query, img_cache=img_cache)
if messages:
session_message = messages


session = self.sessions.session_query(query, session_id)
model = conf().get("model") model = conf().get("model")
# remove system message # remove system message
if session.messages[0].get("role") == "system":
if session_message[0].get("role") == "system":
if app_code or model == "wenxin": if app_code or model == "wenxin":
session.messages.pop(0)
session_message.pop(0)


body = { body = {
"app_code": app_code, "app_code": app_code,
"messages": session.messages,
"messages": session_message,
"model": model, # 对话模型的名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei "model": model, # 对话模型的名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
"temperature": conf().get("temperature"), "temperature": conf().get("temperature"),
"top_p": conf().get("top_p", 1), "top_p": conf().get("top_p", 1),
@@ -94,7 +105,7 @@ class LinkAIBot(Bot):
reply_content = response["choices"][0]["message"]["content"] reply_content = response["choices"][0]["message"]["content"]
total_tokens = response["usage"]["total_tokens"] total_tokens = response["usage"]["total_tokens"]
logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}") logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
self.sessions.session_reply(reply_content, session_id, total_tokens)
self.sessions.session_reply(reply_content, session_id, total_tokens, query=query)
agent_suffix = self._fetch_agent_suffix(response) agent_suffix = self._fetch_agent_suffix(response)
if agent_suffix: if agent_suffix:
@@ -130,6 +141,54 @@ class LinkAIBot(Bot):
logger.warn(f"[LINKAI] do retry, times={retry_count}") logger.warn(f"[LINKAI] do retry, times={retry_count}")
return self._chat(query, context, retry_count + 1) return self._chat(query, context, retry_count + 1)


def _process_image_msg(self, app_code: str, session_id: str, query:str, img_cache: dict):
try:
enable_image_input = False
app_info = self._fetch_app_info(app_code)
if not app_info:
logger.debug(f"[LinkAI] not found app, can't process images, app_code={app_code}")
return None
plugins = app_info.get("data").get("plugins")
for plugin in plugins:
if plugin.get("input_type") and "IMAGE" in plugin.get("input_type"):
enable_image_input = True
if not enable_image_input:
return
msg = img_cache.get("msg")
path = img_cache.get("path")
msg.prepare()
logger.info(f"[LinkAI] query with images, path={path}")
messages = self._build_vision_msg(query, path)
memory.USER_IMAGE_CACHE[session_id] = None
return messages
except Exception as e:
logger.exception(e)


def _build_vision_msg(self, query: str, path: str):
try:
suffix = utils.get_path_suffix(path)
with open(path, "rb") as file:
base64_str = base64.b64encode(file.read()).decode('utf-8')
messages = [{
"role": "user",
"content": [
{
"type": "text",
"text": query
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/{suffix};base64,{base64_str}"
}
}
]
}]
return messages
except Exception as e:
logger.exception(e)

def reply_text(self, session: ChatGPTSession, app_code="", retry_count=0) -> dict: def reply_text(self, session: ChatGPTSession, app_code="", retry_count=0) -> dict:
if retry_count >= 2: if retry_count >= 2:
# exit from retry 2 times # exit from retry 2 times
@@ -195,6 +254,16 @@ class LinkAIBot(Bot):
logger.warn(f"[LINKAI] do retry, times={retry_count}") logger.warn(f"[LINKAI] do retry, times={retry_count}")
return self.reply_text(session, app_code, retry_count + 1) return self.reply_text(session, app_code, retry_count + 1)


def _fetch_app_info(self, app_code: str):
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
# do http request
base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
params = {"app_code": app_code}
res = requests.get(url=base_url + "/v1/app/info", params=params, headers=headers, timeout=(5, 10))
if res.status_code == 200:
return res.json()
else:
logger.warning(f"[LinkAI] find app info exception, res={res}")


def create_img(self, query, retry_count=0, api_key=None): def create_img(self, query, retry_count=0, api_key=None):
try: try:
@@ -239,6 +308,7 @@ class LinkAIBot(Bot):
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)



def _fetch_agent_suffix(self, response): def _fetch_agent_suffix(self, response):
try: try:
plugin_list = [] plugin_list = []
@@ -275,4 +345,44 @@ class LinkAIBot(Bot):
reply = Reply(ReplyType.IMAGE_URL, url) reply = Reply(ReplyType.IMAGE_URL, url)
channel.send(reply, context) channel.send(reply, context)
except Exception as e: except Exception as e:
logger.error(e)
logger.error(e)


class LinkAISessionManager(SessionManager):
def session_msg_query(self, query, session_id):
session = self.build_session(session_id)
messages = session.messages + [{"role": "user", "content": query}]
return messages

def session_reply(self, reply, session_id, total_tokens=None, query=None):
session = self.build_session(session_id)
if query:
session.add_query(query)
session.add_reply(reply)
try:
max_tokens = conf().get("conversation_max_tokens", 2500)
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
logger.info(f"[LinkAI] chat history discard, before tokens={total_tokens}, now tokens={tokens_cnt}")
except Exception as e:
logger.warning("Exception when counting tokens precisely for session: {}".format(str(e)))
return session


class LinkAISession(ChatGPTSession):
def calc_tokens(self):
try:
cur_tokens = super().calc_tokens()
except Exception as e:
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
cur_tokens = len(str(self.messages))
return cur_tokens

def discard_exceeding(self, max_tokens, cur_tokens=None):
cur_tokens = self.calc_tokens()
if cur_tokens > max_tokens:
for i in range(0, len(self.messages)):
if i > 0 and self.messages[i].get("role") == "assistant" and self.messages[i - 1].get("role") == "user":
self.messages.pop(i)
self.messages.pop(i - 1)
return self.calc_tokens()
return cur_tokens

+ 2
- 2
bot/session_manager.py View File

@@ -69,7 +69,7 @@ class SessionManager(object):
total_tokens = session.discard_exceeding(max_tokens, None) total_tokens = session.discard_exceeding(max_tokens, None)
logger.debug("prompt tokens used={}".format(total_tokens)) logger.debug("prompt tokens used={}".format(total_tokens))
except Exception as e: except Exception as e:
logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e)))
logger.warning("Exception when counting tokens precisely for prompt: {}".format(str(e)))
return session return session


def session_reply(self, reply, session_id, total_tokens=None): def session_reply(self, reply, session_id, total_tokens=None):
@@ -80,7 +80,7 @@ class SessionManager(object):
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens) tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt)) logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
except Exception as e: except Exception as e:
logger.debug("Exception when counting tokens precisely for session: {}".format(str(e)))
logger.warning("Exception when counting tokens precisely for session: {}".format(str(e)))
return session return session


def clear_session(self, session_id): def clear_session(self, session_id):


+ 6
- 5
channel/chat_channel.py View File

@@ -9,8 +9,7 @@ from bridge.context import *
from bridge.reply import * from bridge.reply import *
from channel.channel import Channel from channel.channel import Channel
from common.dequeue import Dequeue from common.dequeue import Dequeue
from common.log import logger
from config import conf
from common import memory
from plugins import * from plugins import *


try: try:
@@ -205,14 +204,16 @@ class ChatChannel(Channel):
else: else:
return return
elif context.type == ContextType.IMAGE: # 图片消息,当前仅做下载保存到本地的逻辑 elif context.type == ContextType.IMAGE: # 图片消息,当前仅做下载保存到本地的逻辑
cmsg = context["msg"]
cmsg.prepare()
memory.USER_IMAGE_CACHE[context["session_id"]] = {
"path": context.content,
"msg": context.get("msg")
}
elif context.type == ContextType.SHARING: # 分享信息,当前无默认逻辑 elif context.type == ContextType.SHARING: # 分享信息,当前无默认逻辑
pass pass
elif context.type == ContextType.FUNCTION or context.type == ContextType.FILE: # 文件消息及函数调用等,当前无默认逻辑 elif context.type == ContextType.FUNCTION or context.type == ContextType.FILE: # 文件消息及函数调用等,当前无默认逻辑
pass pass
else: else:
logger.error("[WX] unknown context type: {}".format(context.type))
logger.warning("[WX] unknown context type: {}".format(context.type))
return return
return reply return reply




+ 3
- 0
common/memory.py View File

@@ -0,0 +1,3 @@
from common.expired_dict import ExpiredDict

USER_IMAGE_CACHE = ExpiredDict(60 * 3)

+ 6
- 1
common/utils.py View File

@@ -1,6 +1,6 @@
import io import io
import os import os
from urllib.parse import urlparse
from PIL import Image from PIL import Image




@@ -49,3 +49,8 @@ def split_string_by_utf8_length(string, max_length, max_split=0):
result.append(encoded[start:end].decode("utf-8")) result.append(encoded[start:end].decode("utf-8"))
start = end start = end
return result return result


def get_path_suffix(path):
path = urlparse(path).path
return os.path.splitext(path)[-1].lstrip('.')

+ 0
- 1
plugins/linkai/summary.py View File

@@ -91,5 +91,4 @@ class LinkSummary:
for support_url in support_list: for support_url in support_list:
if url.strip().startswith(support_url): if url.strip().startswith(support_url):
return True return True
logger.debug(f"[LinkSum] unsupported url, no need to process, url={url}")
return False return False

Loading…
Cancel
Save