import requests
import json
import plugins
from bridge.reply import Reply, ReplyType
from bridge.context import ContextType
from channel.chat_message import ChatMessage
from plugins import *
from common.log import logger
from common.expired_dict import ExpiredDict
import os
import base64
from pathlib import Path
from PIL import Image
import oss2
from lib import itchat
from lib.itchat.content import *
import re
from bot.session_manager import Session
from bot.session_manager import SessionManager
from bot.chatgpt.chat_gpt_session import ChatGPTSession

# Keywords that mark a cached text prompt as a request to analyse the uploaded material.
_ANALYSIS_KEYWORDS = ("分析", "总结", "报告", "描述", "说说", "讲述", "讲讲", "讲一下", "图片")
# URL suffixes (compared case-insensitively) that are sent to the LLM as "image_url" parts.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
# Moonshot file endpoint used by extract_content_by_llm.
_MOONSHOT_FILES_URL = 'https://api.moonshot.cn/v1/files'


@plugins.register(
    name="healthai",
    desire_priority=-1,
    desc="A plugin for upload",
    version="0.0.01",
    author="",
)
class healthai(Plugin):
    """Collect a user's prompt and uploads, then forward them to the LLM as one message.

    Per-user state lives in ``self.params_cache`` (an ``ExpiredDict(300)``). Each entry is a
    dict that may hold:

    * ``"previous_prompt"`` -- the user's last text message;
    * ``"last_content"`` -- either a list of OSS URLs (images) or a str of extracted file text.
    """

    def __init__(self):
        super().__init__()
        try:
            curdir = os.path.dirname(__file__)
            config_path = os.path.join(curdir, "config.json")
            if os.path.exists(config_path):
                with open(config_path, "r", encoding="utf-8") as f:
                    self.config = json.load(f)
            else:
                # Fall back to the parent class's config loader.
                self.config = super().load_config()
            if not self.config:
                raise Exception("config.json not found")

            # Per-user prompt/upload state; see the class docstring for the entry layout.
            self.params_cache = ExpiredDict(300)

            # Credentials are read from config.json, falling back to environment variables.
            # They must never be hard-coded in source (the previously embedded keys should be rotated).
            self.oss = self.config.get("oss") or {}
            self.oss_access_key_id = self.oss.get("access_key_id") or os.environ.get("OSS_ACCESS_KEY_ID", "")
            self.oss_access_key_secret = (
                self.oss.get("access_key_secret") or os.environ.get("OSS_ACCESS_KEY_SECRET", "")
            )
            self.oss_endpoint = self.oss.get("endpoint", "http://oss-cn-shanghai.aliyuncs.com")
            self.oss_bucket_name = self.oss.get("bucket_name", "cow-agent")
            self.moonshot_api_key = self.config.get("moonshot_api_key") or os.environ.get("MOONSHOT_API_KEY", "")
            if not (self.oss_access_key_id and self.oss_access_key_secret):
                logger.warning("[file4upload] OSS credentials are not configured; image uploads are likely to fail.")
            if not self.moonshot_api_key:
                logger.warning("[file4upload] moonshot_api_key is not configured; file extraction is likely to fail.")

            # Not read by this class (prompts are cached per user); kept for compatibility.
            self.previous_prompt = ''
            self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")

            # Register the handler last so it never runs against a half-initialised plugin.
            self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
            logger.info("[file4upload] inited.")
        except Exception as e:
            logger.warning(f"file4upload init failed: {e}")

    def on_handle_context(self, e_context: EventContext):
        """Cache the user's prompt/uploads and, once both are present, hand them to the LLM.

        * TEXT: stored as the user's ``previous_prompt``.
        * IMAGE: uploaded to OSS; the URL is appended to ``last_content`` (a list).
        * FILE: text extracted through the Moonshot file API and stored as ``last_content`` (a str).

        When a cached prompt containing an analysis keyword and some cached content are both
        present, an acknowledgement is sent via itchat. The context is then rewritten into a TEXT
        message carrying ``generate_openai_messages_content(...)``, the action is set to CONTINUE,
        and this user's cache entry is dropped. If content is cached but the prompt has no
        keyword, private chats receive a canned follow-up question and processing stops (BREAK).

        :param e_context: plugin event context; ``e_context["context"]`` is the incoming message.
        """
        context = e_context["context"]
        if context.type not in [ContextType.TEXT, ContextType.SHARING, ContextType.FILE, ContextType.IMAGE]:
            return
        msg: ChatMessage = context["msg"]
        user_id = msg.from_user_id
        content = context.content
        isgroup = context.get("isgroup", False)
        # NOTE(review): presumably downloads attachments so context.content is a local path -- confirm.
        msg.prepare()
        logger.info(f'当前缓存:self.params_cache:{self.params_cache}')
        logger.debug(f'输入内容:{content}')
        logger.debug(f'类型:{context.type}')

        if user_id not in self.params_cache:
            self.params_cache[user_id] = {}
            logger.info(f'初始化缓存:{self.params_cache}')
        # The inner dict is mutated in place, so changes below are visible through params_cache.
        user_cache = self.params_cache[user_id]

        if context.type == ContextType.TEXT:
            user_cache['previous_prompt'] = content
            logger.info(f'上次提示缓存:{self.params_cache}')

        # The plugin-local session is only used for debug logging.
        session_id = context["session_id"]
        logger.debug(f'会话id:{session_id}')
        session = self.sessions.build_session(session_id)
        logger.debug(f'session 消息{session.messages}')

        if context.type == ContextType.IMAGE:
            self._cache_uploaded_image(e_context, user_cache, isgroup)
        elif context.type == ContextType.FILE and not self._cache_extracted_file(e_context, user_cache, isgroup):
            # Extraction failed; a "cannot process" reply has already been set.
            return

        if 'previous_prompt' not in user_cache or 'last_content' not in user_cache:
            return
        if contains_keywords(user_cache['previous_prompt']):
            self._forward_to_llm(e_context, msg, user_id, user_cache, isgroup)
        elif not isgroup:
            _set_text_reply(e_context, "您刚刚上传了,请问我有什么可以帮您的呢?")

    def _cache_uploaded_image(self, e_context, user_cache, isgroup):
        """Upload the downloaded image to OSS, append its URL to ``last_content`` and delete the local copy."""
        file_path = e_context["context"].content
        logger.info('处理上传')
        logger.info(f"on_handle_context: 获取到图片路径 {file_path}")
        if 'previous_prompt' not in user_cache and not isgroup:
            _set_text_reply(e_context, "您刚刚上传图片,请问我有什么可以帮您的呢?")
        try:
            file_url = upload_oss(self.oss_access_key_id, self.oss_access_key_secret, self.oss_endpoint,
                                  self.oss_bucket_name, file_path, f'cow/{os.path.basename(file_path)}')
        finally:
            # Always delete the temp file, even if the upload raised.
            logger.info('删除图片')
            _remove_local_file(file_path)
        # A previous FILE upload stores a str here; start a fresh URL list instead of crashing on .append.
        if not isinstance(user_cache.get('last_content'), list):
            user_cache['last_content'] = []
        user_cache['last_content'].append(file_url)

    def _cache_extracted_file(self, e_context, user_cache, isgroup):
        """Extract the downloaded file's text into ``last_content`` and delete the local copy.

        :return: False if extraction failed (a "cannot process" reply has been set), True otherwise.
        """
        file_path = e_context["context"].content
        logger.info('处理文件')
        logger.info(f"on_handle_context: 获取到文件路径 {file_path}")
        if 'previous_prompt' not in user_cache and not isgroup:
            _set_text_reply(e_context, "您刚刚上传了一份文件,请问我有什么可以帮您的呢?")
        logger.debug('准备抽取文字')
        try:
            file_content = extract_content_by_llm(file_path, self.moonshot_api_key)
        finally:
            # Delete the temp file on every path, including failed extraction.
            logger.info('删除文件')
            _remove_local_file(file_path)
        if file_content is None:
            _set_text_reply(e_context, "不能处理这份文件")
            return False
        user_cache['last_content'] = file_content
        return True

    def _forward_to_llm(self, e_context, msg, user_id, user_cache, isgroup):
        """Acknowledge the request, rewrite the context into the multi-part prompt and drop the user's cache."""
        prompt = user_cache['previous_prompt']
        # Acknowledge immediately: the LLM answer that follows may take a while.
        logger.info('先回应')
        ack = f'@{msg.actual_user_nickname}' if isgroup else '[小蕴]'
        ack += "已经收到,立刻为您服务"
        logger.debug('发送' + ack)
        itchat.send(ack, toUserName=user_id)

        context = e_context["context"]
        context.type = ContextType.TEXT
        context.content = self.generate_openai_messages_content(user_cache['last_content'], prompt)
        logger.info('插件处理上传文件或图片')
        # CONTINUE so processing carries on with the rewritten TEXT context.
        e_context.action = EventAction.CONTINUE
        # Drop only this user's state; clearing the whole cache would discard other users' pending uploads.
        self._forget_user(user_id)
        logger.info(f'清空缓存后:{self.params_cache}')

    def _forget_user(self, user_id):
        """Remove ``user_id``'s entry from ``params_cache`` if it is still present."""
        if user_id in self.params_cache:
            del self.params_cache[user_id]

    def on_handle_context2(self, e_context: EventContext):
        """Alternative, currently unregistered variant of ``on_handle_context``.

        NOTE(review): ``__init__`` registers only ``on_handle_context``, so this method is never
        invoked. It stores upload URLs under ``'urls'`` rather than ``'last_content'``. It only
        updates ``previous_prompt`` when a keyword prompt is already cached, so it never stores
        the first prompt. Confirm intent before wiring it up.
        """
        context = e_context["context"]
        if context.type not in {ContextType.TEXT, ContextType.SHARING, ContextType.FILE, ContextType.IMAGE}:
            return
        msg: ChatMessage = context["msg"]
        user_id = msg.from_user_id
        is_group = context.get("isgroup", False)
        msg.prepare()
        logger.info(f'当前缓存:self.params_cache:{self.params_cache}')

        # Use item assignment (not dict.setdefault) so ExpiredDict's own __setitem__ is used,
        # matching on_handle_context.
        if user_id not in self.params_cache:
            self.params_cache[user_id] = {}
            logger.info(f'初始化缓存:{self.params_cache}')
        user_cache = self.params_cache[user_id]
        previous_prompt = user_cache.get('previous_prompt')
        last_content = user_cache.get('last_content')

        if context.type == ContextType.TEXT and previous_prompt and contains_keywords(previous_prompt):
            user_cache['previous_prompt'] = msg.content

        if previous_prompt and last_content and contains_keywords(previous_prompt):
            logger.info('先回应')
            ack = (f'@{msg.actual_user_nickname}' if is_group else '[小蕴]') + "已经收到,立刻为您服务"
            logger.info(f'发送消息: {ack}')
            itchat.send(ack, toUserName=user_id)
            e_context.action = EventAction.BREAK
            # Drop only this user's state.
            self._forget_user(user_id)
            logger.info(f'清空缓存后:{self.params_cache}')
        elif not is_group:
            _set_text_reply(e_context, "您刚刚上传了,请问我有什么可以帮您的呢?")

        if context.type in (ContextType.FILE, ContextType.IMAGE):
            logger.info('处理上传')
            file_path = context.content
            logger.info(f"on_handle_context: 获取到图片路径 {file_path},{user_id in self.params_cache}")
            if user_id in self.params_cache:
                cache = self.params_cache[user_id]
                if 'previous_prompt' not in cache and not is_group:
                    kind = "文件" if context.type == ContextType.FILE else "图片"
                    _set_text_reply(e_context, f"您刚刚上传{kind},请问我有什么可以帮您的呢?")
                try:
                    file_url = upload_oss(self.oss_access_key_id, self.oss_access_key_secret, self.oss_endpoint,
                                          self.oss_bucket_name, file_path, f'cow/{os.path.basename(file_path)}')
                finally:
                    logger.info('删除图片')
                    _remove_local_file(file_path)
                cache.setdefault('urls', []).append(file_url)

    def generate_openai_messages_content(self, last_content, prompt):
        """Build an OpenAI-style multi-part message content list and return it as a JSON string.

        :param last_content: a list of uploaded URLs, or a str of extracted file text. In a list,
            image URLs become ``image_url`` parts and anything else ``file_url`` parts. A str is
            prepended to the prompt, separated by a newline and a tab.
        :param prompt: the user's question text.
        :return: JSON array (``ensure_ascii=False``) whose last element is the ``text`` part.
        """
        content = []
        if isinstance(last_content, list):
            for url in last_content:
                # Case-insensitive so uploads such as "IMG.PNG" are still sent as images.
                if url.lower().endswith(_IMAGE_EXTENSIONS):
                    content.append({"type": "image_url", "image_url": {"url": url}})
                else:
                    content.append({"type": "file_url", "file_url": {"url": url}})
        else:
            prompt = f"{last_content}\n\t{prompt}"
        content.append({"type": "text", "text": prompt})
        return json.dumps(content, ensure_ascii=False)


def _set_text_reply(e_context, text):
    """Attach a plain-text reply to ``e_context`` and set the action to ``EventAction.BREAK``."""
    reply = Reply()
    reply.type = ReplyType.TEXT
    reply.content = text
    e_context["reply"] = reply
    e_context.action = EventAction.BREAK


def _remove_local_file(path):
    """Delete a downloaded temp file, logging (not raising) if removal fails."""
    try:
        os.remove(path)
    except OSError as e:
        logger.warning(f"failed to remove {path}: {e}")


def remove_markdown(text):
    """Strip Markdown bold markers and ``#``/``##``/``###`` heading prefixes from ``text``."""
    text = text.replace("**", "")
    for marker in ("### ", "## ", "# "):
        text = text.replace(marker, "")
    return text


def extract_content_by_llm(file_path: str, api_key: str, timeout: float = 60) -> "str | None":
    """Extract the text of a local file through the Moonshot ``/v1/files`` API.

    The file is uploaded with ``purpose=file-extract`` and its content is then fetched from
    ``/v1/files/{id}/content``.

    :param file_path: path of the local file to upload.
    :param api_key: Moonshot API key, sent as a Bearer token.
    :param timeout: per-request timeout in seconds.
    :return: the ``content`` field of the response. Returns None if the file cannot be read, a
        request fails, the upload returns no file id, or a response is not valid JSON.
    """
    logger.info(f'大模型开始抽取文字')
    headers = {'Authorization': f'Bearer {api_key}'}
    data = {'purpose': 'file-extract'}
    try:
        # Context manager closes the handle once the upload finishes.
        with open(Path(file_path), 'rb') as fh:
            files = {'file': (os.path.basename(file_path), fh)}
            response = requests.post(_MOONSHOT_FILES_URL, headers=headers, files=files, data=data, timeout=timeout)
        response.raise_for_status()
        file_id = response.json().get('id')
        if not file_id:
            logger.error(f"Moonshot upload returned no file id: {response.text}")
            return None
        response = requests.get(f"{_MOONSHOT_FILES_URL}/{file_id}/content", headers=headers, timeout=timeout)
        response.raise_for_status()
        logger.debug(response.text)
        return response.json().get('content')
    except (requests.exceptions.RequestException, OSError, ValueError) as e:
        # ValueError covers JSON decoding failures on older requests versions.
        logger.error(f"Error calling LLM API: {e}")
        return None


def upload_oss(access_key_id, access_key_secret, endpoint, bucket_name, local_file_path, oss_file_name, expiration_days=7):
    """Upload a file to Aliyun OSS, set an expiry lifecycle rule, and return the object's URL.

    :param access_key_id: Aliyun AccessKey ID.
    :param access_key_secret: Aliyun AccessKey Secret.
    :param endpoint: OSS region endpoint, with or without an ``http://``/``https://`` scheme.
    :param bucket_name: name of the OSS bucket.
    :param local_file_path: local path of the file to upload.
    :param oss_file_name: object key inside the bucket.
    :param expiration_days: days after which objects under the key's top-level prefix are deleted (default 7).
    :return: ``<scheme>://<bucket>.<endpoint host>/<key>``. The scheme is taken from ``endpoint``
        and defaults to http. NOTE(review): this URL is only readable if the bucket/object is
        public -- confirm.
    """
    auth = oss2.Auth(access_key_id, access_key_secret)
    bucket = oss2.Bucket(auth, endpoint, bucket_name)

    # 1. Lifecycle rule: expire objects under the key's top-level directory.
    # NOTE(review): per the OSS PutBucketLifecycle docs this call replaces the bucket's whole
    # lifecycle configuration (other rules are dropped), and it is re-sent on every upload.
    rule_id = f'delete_after_{expiration_days}_days'
    prefix = oss_file_name.split('/')[0] + '/'
    rule = oss2.models.LifecycleRule(rule_id, prefix,
                                     status=oss2.models.LifecycleRule.ENABLED,
                                     expiration=oss2.models.LifecycleExpiration(days=expiration_days))
    bucket.put_bucket_lifecycle(oss2.models.BucketLifecycle([rule]))
    logger.info(f"已设置生命周期规则:文件将在{expiration_days}天后自动删除")

    # 2. Upload the file.
    bucket.put_object_from_file(oss_file_name, local_file_path)

    # 3. Build the virtual-hosted-style URL, keeping the endpoint's scheme
    #    (the old code produced "http://bucket.https://..." for https endpoints).
    scheme, sep, host = endpoint.partition('://')
    if not sep:
        scheme, host = 'http', endpoint
    file_url = f"{scheme}://{bucket_name}.{host.rstrip('/')}/{oss_file_name}"
    logger.info(f"文件上传成功,公共访问地址:{file_url}")
    return file_url


def contains_keywords_by_re(text):
    """Return True if the first regex match in ``text`` contains one of four analysis keywords.

    NOTE(review): the pattern ``(.*?)`` always matches the empty string at position 0, so the
    captured content is always '' and this function always returns False. The surrounding tag
    delimiters appear to have been lost. Restore them before relying on this function; it is
    not called within this file.
    """
    match = re.search(r'(.*?)', text)
    if not match:
        return False
    content = match.group(1)
    keywords = ['分析', '总结', '报告', '描述']
    return any(keyword in content for keyword in keywords)


def contains_keywords(text, keywords=None):
    """Return True if ``text`` contains any of ``keywords`` (default: the analysis keywords).

    Empty or None ``text`` returns False instead of raising.
    """
    if not text:
        return False
    if keywords is None:
        keywords = _ANALYSIS_KEYWORDS
    return any(keyword in text for keyword in keywords)