Bläddra i källkod

Merge pull request #1553 from zhayujie/feat-11-27

feat: add image chat and fix session discard
master
zhayujie GitHub 1 år sedan
förälder
incheckning
293c659053
Ingen känd nyckel hittad för denna signaturen i databasen GPG-nyckel ID: 4AEE18F83AFDEB23
9 ändrade filer med 137 tillägg och 21 borttagningar
  1. +117
    -7
      bot/linkai/link_ai_bot.py
  2. +2
    -2
      bot/session_manager.py
  3. +6
    -5
      channel/chat_channel.py
  4. +3
    -0
      common/memory.py
  5. +6
    -1
      common/utils.py
  6. +1
    -1
      plugins/linkai/README.md
  7. +1
    -1
      plugins/linkai/config.json.template
  8. +1
    -3
      plugins/linkai/linkai.py
  9. +0
    -1
      plugins/linkai/summary.py

+ 117
- 7
bot/linkai/link_ai_bot.py Visa fil

@@ -13,6 +13,9 @@ from bridge.reply import Reply, ReplyType
from common.log import logger from common.log import logger
from config import conf, pconf from config import conf, pconf
import threading import threading
from common import memory, utils
import base64



class LinkAIBot(Bot): class LinkAIBot(Bot):
# authentication failed # authentication failed
@@ -21,7 +24,7 @@ class LinkAIBot(Bot):


def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
self.sessions = LinkAISessionManager(LinkAISession, model=conf().get("model") or "gpt-3.5-turbo")
self.args = {} self.args = {}


def reply(self, query, context: Context = None) -> Reply: def reply(self, query, context: Context = None) -> Reply:
@@ -61,17 +64,25 @@ class LinkAIBot(Bot):
linkai_api_key = conf().get("linkai_api_key") linkai_api_key = conf().get("linkai_api_key")


session_id = context["session_id"] session_id = context["session_id"]
session_message = self.sessions.session_msg_query(query, session_id)
logger.debug(f"[LinkAI] session={session_message}, session_id={session_id}")

# image process
img_cache = memory.USER_IMAGE_CACHE.get(session_id)
if img_cache:
messages = self._process_image_msg(app_code=app_code, session_id=session_id, query=query, img_cache=img_cache)
if messages:
session_message = messages


session = self.sessions.session_query(query, session_id)
model = conf().get("model") model = conf().get("model")
# remove system message # remove system message
if session.messages[0].get("role") == "system":
if session_message[0].get("role") == "system":
if app_code or model == "wenxin": if app_code or model == "wenxin":
session.messages.pop(0)
session_message.pop(0)


body = { body = {
"app_code": app_code, "app_code": app_code,
"messages": session.messages,
"messages": session_message,
"model": model, # 对话模型的名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei "model": model, # 对话模型的名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
"temperature": conf().get("temperature"), "temperature": conf().get("temperature"),
"top_p": conf().get("top_p", 1), "top_p": conf().get("top_p", 1),
@@ -94,7 +105,7 @@ class LinkAIBot(Bot):
reply_content = response["choices"][0]["message"]["content"] reply_content = response["choices"][0]["message"]["content"]
total_tokens = response["usage"]["total_tokens"] total_tokens = response["usage"]["total_tokens"]
logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}") logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
self.sessions.session_reply(reply_content, session_id, total_tokens)
self.sessions.session_reply(reply_content, session_id, total_tokens, query=query)
agent_suffix = self._fetch_agent_suffix(response) agent_suffix = self._fetch_agent_suffix(response)
if agent_suffix: if agent_suffix:
@@ -130,6 +141,54 @@ class LinkAIBot(Bot):
logger.warn(f"[LINKAI] do retry, times={retry_count}") logger.warn(f"[LINKAI] do retry, times={retry_count}")
return self._chat(query, context, retry_count + 1) return self._chat(query, context, retry_count + 1)


def _process_image_msg(self, app_code: str, session_id: str, query:str, img_cache: dict):
try:
enable_image_input = False
app_info = self._fetch_app_info(app_code)
if not app_info:
logger.debug(f"[LinkAI] not found app, can't process images, app_code={app_code}")
return None
plugins = app_info.get("data").get("plugins")
for plugin in plugins:
if plugin.get("input_type") and "IMAGE" in plugin.get("input_type"):
enable_image_input = True
if not enable_image_input:
return
msg = img_cache.get("msg")
path = img_cache.get("path")
msg.prepare()
logger.info(f"[LinkAI] query with images, path={path}")
messages = self._build_vision_msg(query, path)
memory.USER_IMAGE_CACHE[session_id] = None
return messages
except Exception as e:
logger.exception(e)


def _build_vision_msg(self, query: str, path: str):
try:
suffix = utils.get_path_suffix(path)
with open(path, "rb") as file:
base64_str = base64.b64encode(file.read()).decode('utf-8')
messages = [{
"role": "user",
"content": [
{
"type": "text",
"text": query
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/{suffix};base64,{base64_str}"
}
}
]
}]
return messages
except Exception as e:
logger.exception(e)

def reply_text(self, session: ChatGPTSession, app_code="", retry_count=0) -> dict: def reply_text(self, session: ChatGPTSession, app_code="", retry_count=0) -> dict:
if retry_count >= 2: if retry_count >= 2:
# exit from retry 2 times # exit from retry 2 times
@@ -195,6 +254,16 @@ class LinkAIBot(Bot):
logger.warn(f"[LINKAI] do retry, times={retry_count}") logger.warn(f"[LINKAI] do retry, times={retry_count}")
return self.reply_text(session, app_code, retry_count + 1) return self.reply_text(session, app_code, retry_count + 1)


def _fetch_app_info(self, app_code: str):
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
# do http request
base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
params = {"app_code": app_code}
res = requests.get(url=base_url + "/v1/app/info", params=params, headers=headers, timeout=(5, 10))
if res.status_code == 200:
return res.json()
else:
logger.warning(f"[LinkAI] find app info exception, res={res}")


def create_img(self, query, retry_count=0, api_key=None): def create_img(self, query, retry_count=0, api_key=None):
try: try:
@@ -239,6 +308,7 @@ class LinkAIBot(Bot):
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)



def _fetch_agent_suffix(self, response): def _fetch_agent_suffix(self, response):
try: try:
plugin_list = [] plugin_list = []
@@ -275,4 +345,44 @@ class LinkAIBot(Bot):
reply = Reply(ReplyType.IMAGE_URL, url) reply = Reply(ReplyType.IMAGE_URL, url)
channel.send(reply, context) channel.send(reply, context)
except Exception as e: except Exception as e:
logger.error(e)
logger.error(e)


class LinkAISessionManager(SessionManager):
def session_msg_query(self, query, session_id):
session = self.build_session(session_id)
messages = session.messages + [{"role": "user", "content": query}]
return messages

def session_reply(self, reply, session_id, total_tokens=None, query=None):
session = self.build_session(session_id)
if query:
session.add_query(query)
session.add_reply(reply)
try:
max_tokens = conf().get("conversation_max_tokens", 2500)
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
logger.info(f"[LinkAI] chat history discard, before tokens={total_tokens}, now tokens={tokens_cnt}")
except Exception as e:
logger.warning("Exception when counting tokens precisely for session: {}".format(str(e)))
return session


class LinkAISession(ChatGPTSession):
def calc_tokens(self):
try:
cur_tokens = super().calc_tokens()
except Exception as e:
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
cur_tokens = len(str(self.messages))
return cur_tokens

def discard_exceeding(self, max_tokens, cur_tokens=None):
cur_tokens = self.calc_tokens()
if cur_tokens > max_tokens:
for i in range(0, len(self.messages)):
if i > 0 and self.messages[i].get("role") == "assistant" and self.messages[i - 1].get("role") == "user":
self.messages.pop(i)
self.messages.pop(i - 1)
return self.calc_tokens()
return cur_tokens

+ 2
- 2
bot/session_manager.py Visa fil

@@ -69,7 +69,7 @@ class SessionManager(object):
total_tokens = session.discard_exceeding(max_tokens, None) total_tokens = session.discard_exceeding(max_tokens, None)
logger.debug("prompt tokens used={}".format(total_tokens)) logger.debug("prompt tokens used={}".format(total_tokens))
except Exception as e: except Exception as e:
logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e)))
logger.warning("Exception when counting tokens precisely for prompt: {}".format(str(e)))
return session return session


def session_reply(self, reply, session_id, total_tokens=None): def session_reply(self, reply, session_id, total_tokens=None):
@@ -80,7 +80,7 @@ class SessionManager(object):
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens) tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt)) logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
except Exception as e: except Exception as e:
logger.debug("Exception when counting tokens precisely for session: {}".format(str(e)))
logger.warning("Exception when counting tokens precisely for session: {}".format(str(e)))
return session return session


def clear_session(self, session_id): def clear_session(self, session_id):


+ 6
- 5
channel/chat_channel.py Visa fil

@@ -9,8 +9,7 @@ from bridge.context import *
from bridge.reply import * from bridge.reply import *
from channel.channel import Channel from channel.channel import Channel
from common.dequeue import Dequeue from common.dequeue import Dequeue
from common.log import logger
from config import conf
from common import memory
from plugins import * from plugins import *


try: try:
@@ -205,14 +204,16 @@ class ChatChannel(Channel):
else: else:
return return
elif context.type == ContextType.IMAGE: # 图片消息,当前仅做下载保存到本地的逻辑 elif context.type == ContextType.IMAGE: # 图片消息,当前仅做下载保存到本地的逻辑
cmsg = context["msg"]
cmsg.prepare()
memory.USER_IMAGE_CACHE[context["session_id"]] = {
"path": context.content,
"msg": context.get("msg")
}
elif context.type == ContextType.SHARING: # 分享信息,当前无默认逻辑 elif context.type == ContextType.SHARING: # 分享信息,当前无默认逻辑
pass pass
elif context.type == ContextType.FUNCTION or context.type == ContextType.FILE: # 文件消息及函数调用等,当前无默认逻辑 elif context.type == ContextType.FUNCTION or context.type == ContextType.FILE: # 文件消息及函数调用等,当前无默认逻辑
pass pass
else: else:
logger.error("[WX] unknown context type: {}".format(context.type))
logger.warning("[WX] unknown context type: {}".format(context.type))
return return
return reply return reply




+ 3
- 0
common/memory.py Visa fil

@@ -0,0 +1,3 @@
from common.expired_dict import ExpiredDict

USER_IMAGE_CACHE = ExpiredDict(60 * 3)

+ 6
- 1
common/utils.py Visa fil

@@ -1,6 +1,6 @@
import io import io
import os import os
from urllib.parse import urlparse
from PIL import Image from PIL import Image




@@ -49,3 +49,8 @@ def split_string_by_utf8_length(string, max_length, max_split=0):
result.append(encoded[start:end].decode("utf-8")) result.append(encoded[start:end].decode("utf-8"))
start = end start = end
return result return result


def get_path_suffix(path):
path = urlparse(path).path
return os.path.splitext(path)[-1].lstrip('.')

+ 1
- 1
plugins/linkai/README.md Visa fil

@@ -26,7 +26,7 @@
"enabled": true, # 文档总结和对话功能开关 "enabled": true, # 文档总结和对话功能开关
"group_enabled": true, # 是否支持群聊开启 "group_enabled": true, # 是否支持群聊开启
"max_file_size": 5000, # 文件的大小限制,单位KB,默认为5M,超过该大小直接忽略 "max_file_size": 5000, # 文件的大小限制,单位KB,默认为5M,超过该大小直接忽略
"type": ["FILE", "SHARING", "IMAGE"] # 支持总结的类型,分别表示 文件、分享链接、图片
"type": ["FILE", "SHARING", "IMAGE"] # 支持总结的类型,分别表示 文件、分享链接、图片,其中文件和链接默认打开,图片默认关闭
} }
} }
``` ```


+ 1
- 1
plugins/linkai/config.json.template Visa fil

@@ -15,6 +15,6 @@
"enabled": true, "enabled": true,
"group_enabled": true, "group_enabled": true,
"max_file_size": 5000, "max_file_size": 5000,
"type": ["FILE", "SHARING", "IMAGE"]
"type": ["FILE", "SHARING"]
} }
} }

+ 1
- 3
plugins/linkai/linkai.py Visa fil

@@ -192,9 +192,7 @@ class LinkAI(Plugin):
return False return False
if context.kwargs.get("isgroup") and not self.sum_config.get("group_enabled"): if context.kwargs.get("isgroup") and not self.sum_config.get("group_enabled"):
return False return False
support_type = self.sum_config.get("type")
if not support_type:
return True
support_type = self.sum_config.get("type") or ["FILE", "SHARING"]
if context.type.name not in support_type: if context.type.name not in support_type:
return False return False
return True return True


+ 0
- 1
plugins/linkai/summary.py Visa fil

@@ -91,5 +91,4 @@ class LinkSummary:
for support_url in support_list: for support_url in support_list:
if url.strip().startswith(support_url): if url.strip().startswith(support_url):
return True return True
logger.debug(f"[LinkSum] unsupported url, no need to process, url={url}")
return False return False

Laddar…
Avbryt
Spara