import requests
import json
import plugins
from bridge.reply import Reply, ReplyType
from bridge.context import ContextType
from channel.chat_message import ChatMessage
from plugins import *
from common.log import logger
from common.expired_dict import ExpiredDict
import os
import base64
from pathlib import Path
from PIL import Image
import oss2
from lib import itchat
from lib.itchat.content import *
import re
from bot.session_manager import Session
from bot.session_manager import SessionManager
from bot.chatgpt.chat_gpt_session import ChatGPTSession
@plugins.register(
    name="healthai",
    desire_priority=-1,
    desc="A plugin for upload",
    version="0.0.01",
    author="",
)
class healthai(Plugin):
    """Collect an uploaded image/file plus a keyword prompt, then hand both to the bot.

    Per-user state lives in ``self.params_cache`` (an ``ExpiredDict`` constructed
    with 300 — presumably a TTL in seconds; confirm in common.expired_dict).
    Each user's entry is a plain dict that may hold:

    * ``previous_prompt`` - the most recent TEXT message from that user;
    * ``last_content``    - a list of uploaded material: OSS URLs for images and
      extracted text for documents.

    When both are present and the prompt contains a trigger keyword (see
    ``contains_keywords``), the context is rewritten into a multimodal TEXT
    context and passed on.
    """

    # Extensions (compared case-insensitively) treated as images when building
    # the multimodal message; other URLs are sent as "file_url" parts.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')

    def __init__(self):
        super().__init__()
        try:
            curdir = os.path.dirname(__file__)
            config_path = os.path.join(curdir, "config.json")
            if os.path.exists(config_path):
                with open(config_path, "r", encoding="utf-8") as f:
                    self.config = json.load(f)
            else:
                # No local config.json: fall back to the parent class's loader.
                self.config = super().load_config()
            if not self.config:
                raise Exception("config.json not found")
            # Per-user pending state; see the class docstring.
            self.params_cache = ExpiredDict(300)
            # OSS settings. Credentials come from config or the environment only:
            # secrets must never be hard-coded in source (keys committed earlier
            # should be treated as leaked and rotated).
            self.oss = self.config.get("oss", {})
            self.oss_access_key_id = self.oss.get("access_key_id") or os.environ.get("OSS_ACCESS_KEY_ID", "")
            self.oss_access_key_secret = self.oss.get("access_key_secret") or os.environ.get("OSS_ACCESS_KEY_SECRET", "")
            self.oss_endpoint = self.oss.get("endpoint", "http://oss-cn-shanghai.aliyuncs.com")
            self.oss_bucket_name = self.oss.get("bucket_name", "cow-agent")
            # Moonshot API key used for document text extraction.
            self.moonshot_api_key = self.config.get("moonshot_api_key") or os.environ.get("MOONSHOT_API_KEY", "")
            if not (self.oss_access_key_id and self.oss_access_key_secret):
                logger.warning("[healthai] OSS credentials missing: set oss.access_key_id/oss.access_key_secret "
                               "in config.json or OSS_ACCESS_KEY_ID/OSS_ACCESS_KEY_SECRET")
            if not self.moonshot_api_key:
                logger.warning("[healthai] moonshot_api_key missing: set it in config.json or MOONSHOT_API_KEY")
            # Not read anywhere in this file; kept for attribute compatibility.
            self.previous_prompt = ''
            # NOTE(review): `conf` is not imported explicitly here; presumably it
            # arrives via `from plugins import *` — confirm.
            self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
            # Register the handler last so it never runs on half-initialised state.
            self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
            logger.info("[file4upload] inited.")
        except Exception as e:
            logger.warning(f"file4upload init failed: {e}")

    def on_handle_context(self, e_context: EventContext):
        """Drive the "upload material, then ask" flow for one incoming context.

        * TEXT: stored as the user's ``previous_prompt``.
        * IMAGE: uploaded to OSS; the public URL is appended to ``last_content``.
        * FILE: text extracted with ``extract_content_by_llm`` and appended to
          ``last_content``; if extraction returns None the user is told the file
          cannot be processed.

        The local attachment is deleted after either upload path, even on error.
        Once the user has both a prompt and uploaded content:

        * prompt contains a trigger keyword -> an acknowledgement is sent with
          ``itchat.send``, the context becomes a TEXT context whose content is the
          JSON built by ``generate_openai_messages_content``, the action is set to
          ``EventAction.CONTINUE``, and only this user's cached state is dropped;
        * otherwise, in private chats, the user is asked what they need.
        """
        context = e_context["context"]
        if context.type not in [ContextType.TEXT, ContextType.SHARING, ContextType.FILE, ContextType.IMAGE]:
            return
        msg: ChatMessage = context["msg"]
        user_id = msg.from_user_id
        content = context.content
        isgroup = context.get("isgroup", False)
        # NOTE(review): presumably fetches the attachment to the local path held
        # in context.content; confirm in ChatMessage.prepare().
        msg.prepare()
        logger.info(f'当前缓存:self.params_cache:{self.params_cache}')
        logger.debug(f'输入内容:{content}')
        logger.debug(f'类型:{context.type}')

        if user_id not in self.params_cache:
            self.params_cache[user_id] = {}
            logger.info(f'初始化缓存:{self.params_cache}')
        user_cache = self.params_cache[user_id]

        if context.type == ContextType.TEXT:
            user_cache['previous_prompt'] = content
            logger.info(f'上次提示缓存:{self.params_cache}')

        session_id = context["session_id"]
        logger.debug(f'会话id:{session_id}')
        session = self.sessions.build_session(session_id)
        logger.debug(f'session 消息{session.messages}')

        if context.type == ContextType.IMAGE:
            logger.info('处理上传')
            file_path = context.content
            logger.info(f"on_handle_context: 获取到图片路径 {file_path},{user_id in self.params_cache}")
            if 'previous_prompt' not in user_cache and not isgroup:
                self._reply_text(e_context, "您刚刚上传图片,请问我有什么可以帮您的呢?")
            try:
                file_url = upload_oss(self.oss_access_key_id, self.oss_access_key_secret, self.oss_endpoint,
                                      self.oss_bucket_name, file_path, f'cow/{os.path.basename(file_path)}')
                self._append_content(user_cache, file_url)
            finally:
                logger.info('删除图片')
                self._discard_file(file_path)

        if context.type == ContextType.FILE:
            logger.info('处理文件')
            file_path = context.content
            logger.info(f"on_handle_context: 获取到文件路径 {file_path}")
            if 'previous_prompt' not in user_cache and not isgroup:
                self._reply_text(e_context, "您刚刚上传了一份文件,请问我有什么可以帮您的呢?")
            try:
                logger.info('准备抽取文字')
                file_content = extract_content_by_llm(file_path, self.moonshot_api_key)
            finally:
                logger.info('删除文件')
                self._discard_file(file_path)
            if file_content is None:
                self._reply_text(e_context, "不能处理这份文件")
                return
            self._append_content(user_cache, file_content)

        # Nothing more to do until both the prompt and some material are cached.
        if 'previous_prompt' not in user_cache or 'last_content' not in user_cache:
            return
        prompt = user_cache['previous_prompt']
        if contains_keywords(prompt):
            # Acknowledge first, then forward the combined request.
            logger.info('先回应')
            logger.info(f'{prompt},{contains_keywords(prompt)}')
            itchat_content = f'@{msg.actual_user_nickname}' if isgroup else '[小蕴]'
            itchat_content += "已经收到,立刻为您服务"
            logger.debug('发送' + itchat_content)
            itchat.send(itchat_content, toUserName=user_id)
            context.type = ContextType.TEXT
            context.content = self.generate_openai_messages_content(user_cache['last_content'], prompt)
            logger.info('插件处理上传文件或图片')
            e_context.action = EventAction.CONTINUE
            # Drop only this user's pending state; other users keep theirs.
            self._forget_user(user_id)
            logger.info(f'清空缓存后:{self.params_cache}')
        elif not isgroup:
            self._reply_text(e_context, "您刚刚上传了,请问我有什么可以帮您的呢?")

    def on_handle_context2(self, e_context: EventContext):
        """Alternative draft handler for the same upload-then-ask flow.

        NOTE(review): ``__init__`` only registers ``on_handle_context``, so this
        method is not wired to any event in this file. It also stores uploads
        under ``urls`` while ``on_handle_context`` uses ``last_content`` —
        confirm intent before enabling it.
        """
        context = e_context["context"]
        if context.type not in {ContextType.TEXT, ContextType.SHARING, ContextType.FILE, ContextType.IMAGE}:
            return
        msg: ChatMessage = context["msg"]
        user_id = msg.from_user_id
        is_group = context.get("isgroup", False)
        msg.prepare()
        logger.info(f'当前缓存:self.params_cache:{self.params_cache}')
        # Explicit assignment instead of dict.setdefault: if ExpiredDict subclasses
        # dict (presumably), setdefault would bypass its overridden __setitem__.
        if user_id not in self.params_cache:
            self.params_cache[user_id] = {}
            logger.info(f'初始化缓存:{self.params_cache}')
        user_cache = self.params_cache[user_id]
        previous_prompt = user_cache.get('previous_prompt')
        last_content = user_cache.get('last_content')
        if context.type == ContextType.TEXT and previous_prompt and contains_keywords(previous_prompt):
            user_cache['previous_prompt'] = msg.content
        if previous_prompt and last_content and contains_keywords(previous_prompt):
            logger.info('先回应')
            itchat_content = f'@{msg.actual_user_nickname}' if is_group else '[小蕴]'
            itchat_content += "已经收到,立刻为您服务"
            logger.info(f'发送消息: {itchat_content}')
            itchat.send(itchat_content, toUserName=user_id)
            e_context.action = EventAction.BREAK
            self._forget_user(user_id)
            logger.info(f'清空缓存后:{self.params_cache}')
        elif not is_group:
            self._reply_text(e_context, "您刚刚上传了,请问我有什么可以帮您的呢?")
        if context.type in (ContextType.FILE, ContextType.IMAGE):
            logger.info('处理上传')
            file_path = context.content
            logger.info(f"on_handle_context: 获取到图片路径 {file_path},{user_id in self.params_cache}")
            try:
                if user_id in self.params_cache:
                    user_cache = self.params_cache[user_id]
                    if 'previous_prompt' not in user_cache and not is_group:
                        if context.type == ContextType.FILE:
                            self._reply_text(e_context, "您刚刚上传文件,请问我有什么可以帮您的呢?")
                        else:
                            self._reply_text(e_context, "您刚刚上传图片,请问我有什么可以帮您的呢?")
                    file_url = upload_oss(self.oss_access_key_id, self.oss_access_key_secret, self.oss_endpoint,
                                          self.oss_bucket_name, file_path, f'cow/{os.path.basename(file_path)}')
                    user_cache.setdefault('urls', []).append(file_url)
            finally:
                logger.info('删除图片')
                self._discard_file(file_path)

    def generate_openai_messages_content(self, last_content, prompt):
        """Build a JSON-encoded OpenAI-style multimodal ``content`` array.

        :param last_content: uploaded material. Either a list whose items are
            http(s) URLs (uploaded images/files) or extracted document text, or a
            single string of extracted text.
        :param prompt: the user's question.
        :return: ``json.dumps`` (non-ASCII kept) of a list of parts: one
            ``image_url`` / ``file_url`` part per URL, followed by one ``text``
            part holding any document text and then the prompt, joined by "\\n\\t".
        """
        if last_content is None:
            items = []
        elif isinstance(last_content, str):
            items = [last_content]
        else:
            items = list(last_content)
        content = []
        texts = []
        for item in items:
            # A bare string (not a list element) is always document text.
            is_url = isinstance(last_content, list) and item.startswith(('http://', 'https://'))
            if not is_url:
                texts.append(item)
            elif item.lower().endswith(self.IMAGE_EXTENSIONS):
                content.append({"type": "image_url", "image_url": {"url": item}})
            else:
                content.append({"type": "file_url", "file_url": {"url": item}})
        # Document text precedes the question, separated by "\n\t".
        content.append({"type": "text", "text": '\n\t'.join(texts + [prompt])})
        return json.dumps(content, ensure_ascii=False)

    @staticmethod
    def _reply_text(e_context, text):
        """Attach a TEXT reply with *text* to *e_context* and set its action to EventAction.BREAK."""
        reply = Reply()
        reply.type = ReplyType.TEXT
        reply.content = text
        e_context["reply"] = reply
        e_context.action = EventAction.BREAK

    @staticmethod
    def _append_content(user_cache, item):
        """Append *item* to ``user_cache['last_content']``, keeping it a list."""
        existing = user_cache.get('last_content')
        if existing is None:
            user_cache['last_content'] = [item]
        elif isinstance(existing, list):
            existing.append(item)
        else:
            # Normalise a legacy single-string entry instead of crashing on .append.
            user_cache['last_content'] = [existing, item]

    def _forget_user(self, user_id):
        """Remove *user_id*'s entry from ``params_cache`` without touching other users."""
        try:
            del self.params_cache[user_id]
        except KeyError:
            pass

    @staticmethod
    def _discard_file(file_path):
        """Delete a downloaded attachment, tolerating a file that is already gone."""
        try:
            os.remove(file_path)
        except FileNotFoundError:
            logger.warning(f"[healthai] file already removed: {file_path}")
def remove_markdown(text):
    """Strip bold markers and ATX heading markers from Markdown *text*.

    ``**`` is removed everywhere. Heading markers (1-6 ``#`` followed by
    spaces/tabs) are removed only at the start of a line, so inline ``#``
    (e.g. "C# here") is preserved and deep headings such as ``#### x``
    leave no stray ``#`` behind.
    """
    text = text.replace("**", "")
    return re.sub(r'^[ \t]*#{1,6}[ \t]+', '', text, flags=re.MULTILINE)
def extract_content_by_llm(file_path: str, api_key: str, timeout: float = 60) -> str:
    """Extract the text of a document using the Moonshot file-extract API.

    Uploads *file_path* to ``/v1/files`` with ``purpose=file-extract`` and then
    fetches ``/v1/files/{id}/content``.

    :param file_path: local path of the document.
    :param api_key: Moonshot API key, sent as a Bearer token.
    :param timeout: per-request timeout in seconds.
    :return: the ``content`` field of the extraction response, or ``None`` when
        the file cannot be read, a request fails or returns an error status, no
        file id is returned, or a response is not valid JSON.
    """
    logger.info('大模型开始抽取文字')
    headers = {'Authorization': f'Bearer {api_key}'}
    api_url = 'https://api.moonshot.cn/v1/files'
    try:
        # Context manager guarantees the upload handle is closed.
        with open(file_path, 'rb') as fh:
            files = {'file': (os.path.basename(file_path), fh)}
            upload = requests.post(api_url, headers=headers, files=files,
                                   data={'purpose': 'file-extract'}, timeout=timeout)
        upload.raise_for_status()
        file_id = upload.json().get('id')
        if not file_id:
            logger.error(f"LLM file upload returned no id: {upload.text}")
            return None
        response = requests.get(url=f"{api_url}/{file_id}/content", headers=headers, timeout=timeout)
        response.raise_for_status()
        logger.debug(response.text)
        return response.json().get('content')
    except requests.exceptions.RequestException as e:
        # Must precede OSError: RequestException subclasses IOError.
        logger.error(f"Error calling LLM API: {e}")
        return None
    except (OSError, ValueError) as e:
        # Unreadable local file, or a response body that is not JSON.
        logger.error(f"Failed to extract content from {file_path}: {e}")
        return None
def upload_oss(access_key_id, access_key_secret, endpoint, bucket_name, local_file_path, oss_file_name, expiration_days=7):
    """
    Upload a file to Aliyun OSS and return its public URL.

    Also makes sure the bucket has a lifecycle rule that deletes objects under
    the file's top-level prefix after ``expiration_days`` days. Because
    ``put_bucket_lifecycle`` replaces the bucket's entire lifecycle
    configuration, existing rules are read first and preserved. The rule is
    only (re)written when it is not already present.

    :param access_key_id: Aliyun AccessKey ID
    :param access_key_secret: Aliyun AccessKey Secret
    :param endpoint: OSS endpoint, with or without an http(s):// scheme
    :param bucket_name: OSS bucket name
    :param local_file_path: local file to upload
    :param oss_file_name: object key in OSS
    :param expiration_days: days after which objects under the prefix expire (default 7)
    :return: public URL of the object (same scheme as *endpoint*, http if none)
    """
    auth = oss2.Auth(access_key_id, access_key_secret)
    bucket = oss2.Bucket(auth, endpoint, bucket_name)
    _ensure_oss_lifecycle_rule(bucket, oss_file_name, expiration_days)
    bucket.put_object_from_file(oss_file_name, local_file_path)
    file_url = _oss_public_url(endpoint, bucket_name, oss_file_name)
    logger.info(f"文件上传成功,公共访问地址:{file_url}")
    return file_url


def _ensure_oss_lifecycle_rule(bucket, oss_file_name, expiration_days):
    """Upsert the expiry rule for *oss_file_name*'s top-level prefix, keeping other rules."""
    rule_id = f'delete_after_{expiration_days}_days'
    prefix = oss_file_name.split('/')[0] + '/'
    try:
        rules = list(bucket.get_bucket_lifecycle().rules)
    except oss2.exceptions.NoSuchLifecycle:
        rules = []
    if any(r.id == rule_id and r.prefix == prefix for r in rules):
        return  # already configured; avoid rewriting the bucket config on every upload
    rules = [r for r in rules if r.id != rule_id]
    rules.append(oss2.models.LifecycleRule(rule_id, prefix, status=oss2.models.LifecycleRule.ENABLED,
                                           expiration=oss2.models.LifecycleExpiration(days=expiration_days)))
    bucket.put_bucket_lifecycle(oss2.models.BucketLifecycle(rules))
    logger.info(f"已设置生命周期规则:文件将在{expiration_days}天后自动删除")


def _oss_public_url(endpoint, bucket_name, oss_file_name):
    """Build ``scheme://bucket.host/key`` from *endpoint*, defaulting the scheme to http."""
    scheme, sep, host = endpoint.partition('://')
    if not sep:
        scheme, host = 'http', endpoint
    return f"{scheme}://{bucket_name}.{host.rstrip('/')}/{oss_file_name}"
def contains_keywords_by_re(text, keywords=('分析', '总结', '报告', '描述')):
    """Return True if the body of the first ``<tag>...</tag>`` pair in *text* contains a keyword.

    :param text: text that may contain a tagged section; empty/None yields False.
    :param keywords: substrings to look for inside the tag body.
    :return: True when a tag pair exists and its body contains any keyword.
    """
    # The previous pattern r'(.*?)' always matched the empty string at position 0,
    # so no keyword could ever be found; the tag delimiters had evidently been lost.
    # NOTE(review): the intended tag name is unknown, so any matching tag pair is
    # accepted here — confirm.
    if not text:
        return False
    match = re.search(r'<(?P<tag>\w+)[^>]*>(?P<body>.*?)</(?P=tag)>', text, re.DOTALL)
    if not match:
        return False
    body = match.group('body')
    return any(keyword in body for keyword in keywords)
def contains_keywords(text):
    """Report whether *text* mentions any of the trigger words that request analysis.

    :param text: the user's message.
    :return: True as soon as one trigger word is found as a substring, else False.
    """
    for trigger in ("分析", "总结", "报告", "描述", "说说", "讲述", "讲讲", "讲一下", "图片"):
        if trigger in text:
            return True
    return False