diff --git a/bot/linkai/link_ai_bot.py b/bot/linkai/link_ai_bot.py index 1dc5df2..22f5172 100644 --- a/bot/linkai/link_ai_bot.py +++ b/bot/linkai/link_ai_bot.py @@ -13,6 +13,9 @@ from bridge.reply import Reply, ReplyType from common.log import logger from config import conf, pconf import threading +from common import memory, utils +import base64 + class LinkAIBot(Bot): # authentication failed @@ -21,7 +24,7 @@ class LinkAIBot(Bot): def __init__(self): super().__init__() - self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo") + self.sessions = LinkAISessionManager(LinkAISession, model=conf().get("model") or "gpt-3.5-turbo") self.args = {} def reply(self, query, context: Context = None) -> Reply: @@ -61,17 +64,25 @@ class LinkAIBot(Bot): linkai_api_key = conf().get("linkai_api_key") session_id = context["session_id"] + session_message = self.sessions.session_msg_query(query, session_id) + logger.debug(f"[LinkAI] session={session_message}, session_id={session_id}") + + # image process + img_cache = memory.USER_IMAGE_CACHE.get(session_id) + if img_cache: + messages = self._process_image_msg(app_code=app_code, session_id=session_id, query=query, img_cache=img_cache) + if messages: + session_message = messages - session = self.sessions.session_query(query, session_id) model = conf().get("model") # remove system message - if session.messages[0].get("role") == "system": + if session_message[0].get("role") == "system": if app_code or model == "wenxin": - session.messages.pop(0) + session_message.pop(0) body = { "app_code": app_code, - "messages": session.messages, + "messages": session_message, "model": model, # 对话模型的名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei "temperature": conf().get("temperature"), "top_p": conf().get("top_p", 1), @@ -94,7 +105,7 @@ class LinkAIBot(Bot): reply_content = response["choices"][0]["message"]["content"] total_tokens = response["usage"]["total_tokens"] logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}") - self.sessions.session_reply(reply_content, session_id, total_tokens) + self.sessions.session_reply(reply_content, session_id, total_tokens, query=query) agent_suffix = self._fetch_agent_suffix(response) if agent_suffix: @@ -130,6 +141,54 @@ class LinkAIBot(Bot): logger.warn(f"[LINKAI] do retry, times={retry_count}") return self._chat(query, context, retry_count + 1) + def _process_image_msg(self, app_code: str, session_id: str, query:str, img_cache: dict): + try: + enable_image_input = False + app_info = self._fetch_app_info(app_code) + if not app_info: + logger.debug(f"[LinkAI] not found app, can't process images, app_code={app_code}") + return None + plugins = app_info.get("data").get("plugins") + for plugin in plugins: + if plugin.get("input_type") and "IMAGE" in plugin.get("input_type"): + enable_image_input = True + if not enable_image_input: + return + msg = img_cache.get("msg") + path = img_cache.get("path") + msg.prepare() + logger.info(f"[LinkAI] query with images, path={path}") + messages = self._build_vision_msg(query, path) + memory.USER_IMAGE_CACHE[session_id] = None + return messages + except Exception as e: + logger.exception(e) + + + def _build_vision_msg(self, query: str, path: str): + try: + suffix = utils.get_path_suffix(path) + with open(path, "rb") as file: + base64_str = base64.b64encode(file.read()).decode('utf-8') + messages = [{ + "role": "user", + "content": [ + { + "type": "text", + "text": query + }, + { + "type": "image_url", + "image_url": { + "url": f"data:image/{suffix};base64,{base64_str}" + } + } + ] + }] + return messages + except Exception as e: + logger.exception(e) + def reply_text(self, session: ChatGPTSession, app_code="", retry_count=0) -> dict: if retry_count >= 2: # exit from retry 2 times @@ -195,6 +254,16 @@ class LinkAIBot(Bot): logger.warn(f"[LINKAI] do retry, times={retry_count}") return self.reply_text(session, app_code, retry_count + 1) + def _fetch_app_info(self, app_code: str): + headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")} + # do http request + base_url = conf().get("linkai_api_base", "https://api.link-ai.chat") + params = {"app_code": app_code} + res = requests.get(url=base_url + "/v1/app/info", params=params, headers=headers, timeout=(5, 10)) + if res.status_code == 200: + return res.json() + else: + logger.warning(f"[LinkAI] find app info exception, res={res}") def create_img(self, query, retry_count=0, api_key=None): try: @@ -239,6 +308,7 @@ class LinkAIBot(Bot): except Exception as e: logger.exception(e) + def _fetch_agent_suffix(self, response): try: plugin_list = [] @@ -275,4 +345,44 @@ class LinkAIBot(Bot): reply = Reply(ReplyType.IMAGE_URL, url) channel.send(reply, context) except Exception as e: - logger.error(e) \ No newline at end of file + logger.error(e) + + +class LinkAISessionManager(SessionManager): + def session_msg_query(self, query, session_id): + session = self.build_session(session_id) + messages = session.messages + [{"role": "user", "content": query}] + return messages + + def session_reply(self, reply, session_id, total_tokens=None, query=None): + session = self.build_session(session_id) + if query: + session.add_query(query) + session.add_reply(reply) + try: + max_tokens = conf().get("conversation_max_tokens", 2500) + tokens_cnt = session.discard_exceeding(max_tokens, total_tokens) + logger.info(f"[LinkAI] chat history discard, before tokens={total_tokens}, now tokens={tokens_cnt}") + except Exception as e: + logger.warning("Exception when counting tokens precisely for session: {}".format(str(e))) + return session + + +class LinkAISession(ChatGPTSession): + def calc_tokens(self): + try: + cur_tokens = super().calc_tokens() + except Exception as e: + logger.debug("Exception when counting tokens precisely for query: {}".format(e)) + cur_tokens = len(str(self.messages)) + return cur_tokens + + def discard_exceeding(self, max_tokens, cur_tokens=None): + cur_tokens = self.calc_tokens() + if cur_tokens > max_tokens: + for i in range(0, len(self.messages)): + if i > 0 and self.messages[i].get("role") == "assistant" and self.messages[i - 1].get("role") == "user": + self.messages.pop(i) + self.messages.pop(i - 1) + return self.calc_tokens() + return cur_tokens diff --git a/bot/session_manager.py b/bot/session_manager.py index 8d70886..a6e89f9 100644 --- a/bot/session_manager.py +++ b/bot/session_manager.py @@ -69,7 +69,7 @@ class SessionManager(object): total_tokens = session.discard_exceeding(max_tokens, None) logger.debug("prompt tokens used={}".format(total_tokens)) except Exception as e: - logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e))) + logger.warning("Exception when counting tokens precisely for prompt: {}".format(str(e))) return session def session_reply(self, reply, session_id, total_tokens=None): @@ -80,7 +80,7 @@ class SessionManager(object): tokens_cnt = session.discard_exceeding(max_tokens, total_tokens) logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt)) except Exception as e: - logger.debug("Exception when counting tokens precisely for session: {}".format(str(e))) + logger.warning("Exception when counting tokens precisely for session: {}".format(str(e))) return session def clear_session(self, session_id): diff --git a/channel/chat_channel.py b/channel/chat_channel.py index 8ed5f4f..ab574c6 100644 --- a/channel/chat_channel.py +++ b/channel/chat_channel.py @@ -9,8 +9,7 @@ from bridge.context import * from bridge.reply import * from channel.channel import Channel from common.dequeue import Dequeue -from common.log import logger -from config import conf +from common import memory from plugins import * try: @@ -205,14 +204,16 @@ class ChatChannel(Channel): else: return elif context.type == ContextType.IMAGE: # 图片消息,当前仅做下载保存到本地的逻辑 - cmsg = context["msg"] - cmsg.prepare() + memory.USER_IMAGE_CACHE[context["session_id"]] = { + "path": context.content, + "msg": context.get("msg") + } elif context.type == ContextType.SHARING: # 分享信息,当前无默认逻辑 pass elif context.type == ContextType.FUNCTION or context.type == ContextType.FILE: # 文件消息及函数调用等,当前无默认逻辑 pass else: - logger.error("[WX] unknown context type: {}".format(context.type)) + logger.warning("[WX] unknown context type: {}".format(context.type)) return return reply diff --git a/common/memory.py b/common/memory.py new file mode 100644 index 0000000..026bed2 --- /dev/null +++ b/common/memory.py @@ -0,0 +1,3 @@ +from common.expired_dict import ExpiredDict + +USER_IMAGE_CACHE = ExpiredDict(60 * 3) \ No newline at end of file diff --git a/common/utils.py b/common/utils.py index 966a7cf..dd69c9d 100644 --- a/common/utils.py +++ b/common/utils.py @@ -1,6 +1,6 @@ import io import os - +from urllib.parse import urlparse from PIL import Image @@ -49,3 +49,8 @@ def split_string_by_utf8_length(string, max_length, max_split=0): result.append(encoded[start:end].decode("utf-8")) start = end return result + + +def get_path_suffix(path): + path = urlparse(path).path + return os.path.splitext(path)[-1].lstrip('.') diff --git a/plugins/linkai/summary.py b/plugins/linkai/summary.py index c945896..5711fd9 100644 --- a/plugins/linkai/summary.py +++ b/plugins/linkai/summary.py @@ -91,5 +91,4 @@ class LinkSummary: for support_url in support_list: if url.strip().startswith(support_url): return True - logger.debug(f"[LinkSum] unsupported url, no need to process, url={url}") return False