您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

456 行
19KB

  1. import requests
  2. import json
  3. import plugins
  4. from bridge.reply import Reply, ReplyType
  5. from bridge.context import ContextType
  6. from channel.chat_message import ChatMessage
  7. from plugins import *
  8. from common.log import logger
  9. from common.expired_dict import ExpiredDict
  10. import os
  11. import base64
  12. from pathlib import Path
  13. from PIL import Image
  14. import oss2
  15. from lib import itchat
  16. from lib.itchat.content import *
  17. import re
  18. from bot.session_manager import Session
  19. from bot.session_manager import SessionManager
  20. from bot.chatgpt.chat_gpt_session import ChatGPTSession
  21. @plugins.register(
  22. name="healthai",
  23. desire_priority=-1,
  24. desc="A plugin for upload",
  25. version="0.0.01",
  26. author="",
  27. )
  28. class healthai(Plugin):
  29. def __init__(self):
  30. super().__init__()
  31. try:
  32. curdir = os.path.dirname(__file__)
  33. config_path = os.path.join(curdir, "config.json")
  34. if os.path.exists(config_path):
  35. with open(config_path, "r", encoding="utf-8") as f:
  36. self.config = json.load(f)
  37. else:
  38. # 使用父类的方法来加载配置
  39. self.config = super().load_config()
  40. if not self.config:
  41. raise Exception("config.json not found")
  42. # 设置事件处理函数
  43. self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
  44. self.params_cache = ExpiredDict(300)
  45. # 从配置中提取所需的设置
  46. self.oss = self.config.get("oss", {})
  47. self.oss_access_key_id=self.oss.get("access_key_id","LTAI5tRTG6pLhTpKACJYoPR5")
  48. self.oss_access_key_secret=self.oss.get("access_key_secret","E7dMzeeMxq4VQvLg7Tq7uKf3XWpYfN")
  49. self.oss_endpoint=self.oss.get("endpoint","http://oss-cn-shanghai.aliyuncs.com")
  50. self.oss_bucket_name=self.oss.get("bucket_name","cow-agent")
  51. # 之前提示
  52. self.previous_prompt=''
  53. self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
  54. # 初始化成功日志
  55. logger.info("[file4upload] inited.")
  56. except Exception as e:
  57. # 初始化失败日志
  58. logger.warn(f"file4upload init failed: {e}")
  59. def on_handle_context(self, e_context: EventContext):
  60. context = e_context["context"]
  61. if context.type not in [ContextType.TEXT, ContextType.SHARING,ContextType.FILE,ContextType.IMAGE]:
  62. return
  63. msg: ChatMessage = e_context["context"]["msg"]
  64. user_id = msg.from_user_id
  65. content = context.content
  66. isgroup = e_context["context"].get("isgroup", False)
  67. context.get("msg").prepare()
  68. logger.info(f'当前缓存:self.params_cache:{self.params_cache}')
  69. print(f'输入内容:{content}')
  70. print(f'类型:{context.type}')
  71. if user_id not in self.params_cache:
  72. self.params_cache[user_id] = {}
  73. logger.info(f'初始化缓存:{self.params_cache}')
  74. if context.type == ContextType.TEXT and user_id in self.params_cache:
  75. self.params_cache[user_id]['previous_prompt']=content
  76. logger.info(f'上次提示缓存:{self.params_cache}')
  77. # if context.type == ContextType.TEXT and user_id in self.params_cache and contains_keywords(content):
  78. # self.params_cache[user_id]['previous_prompt']=content
  79. # logger.info(f'上次提示缓存:{self.params_cache}')
  80. # session_id = context["session_id"]
  81. # session = self.sessions.session_query(content, session_id)
  82. # print(f'session 消息{session.messages}')
  83. # if 'last_content' not in self.params_cache[user_id]:
  84. # reply = Reply()
  85. # reply.type = ReplyType.TEXT
  86. # reply.content = f"请上传相关报告或图片"
  87. # e_context["reply"] = reply
  88. # e_context.action = EventAction.BREAK_PASS
  89. session_id = context["session_id"]
  90. print(f'会话id:{session_id}')
  91. # friends=itchat.get_friends(update=True)[1:]
  92. # # logger.info(f'好友列表{friends}')
  93. # # 提取所有好友的 NickName
  94. # nicknames = [friend['NickName'] for friend in friends]
  95. # print(nicknames)
  96. # 打印所有 NickName
  97. # for nickname in nicknames:
  98. # print(nickname)
  99. session = self.sessions.build_session(session_id)
  100. print(f'session 消息{session.messages}')
  101. # if context.type == ContextType.TEXT and user_id in self.params_cache and contains_keywords(content):
  102. # self.params_cache[user_id]['previous_prompt']=content
  103. # logger.info(f'上次提示缓存:{self.params_cache}')
  104. # session_id = context["session_id"]
  105. # session = self.sessions.session_query(content, session_id)
  106. # print(f'session 消息{session.messages}')
  107. # if 'last_content' not in self.params_cache[user_id]:
  108. # reply = Reply()
  109. # reply.type = ReplyType.TEXT
  110. # reply.content = f"请上传相关报告或图片"
  111. # e_context["reply"] = reply
  112. # e_context.action = EventAction.BREAK_PASS
  113. if context.type in [ContextType.IMAGE]:
  114. logger.info('处理上传')
  115. file_path = context.content
  116. logger.info(f"on_handle_context: 获取到图片路径 {file_path},{user_id in self.params_cache}")
  117. if user_id in self.params_cache:
  118. if 'previous_prompt' not in self.params_cache[user_id] and not e_context['context']['isgroup']:
  119. reply = Reply()
  120. reply.type = ReplyType.TEXT
  121. reply.content = f"您刚刚上传图片,请问我有什么可以帮您的呢?"
  122. e_context["reply"] = reply
  123. e_context.action = EventAction.BREAK
  124. file_content = upload_oss(self.oss_access_key_id, self.oss_access_key_secret, self.oss_endpoint, self.oss_bucket_name, file_path, f'cow/{os.path.basename(file_path)}')
  125. # 确保 'last_content' 键存在,并且是一个列表
  126. if 'last_content' not in self.params_cache[user_id]:
  127. self.params_cache[user_id]['last_content'] = []
  128. # 添加文件内容到 'urls' 列表
  129. self.params_cache[user_id]['last_content'].append(file_content)
  130. logger.info('删除图片')
  131. os.remove(file_path)
  132. if context.type == ContextType.FILE:
  133. logger.info('处理图片')
  134. file_path = context.content
  135. logger.info(f"on_handle_context: 获取到文件路径 {file_path}")
  136. if user_id in self.params_cache:
  137. if 'previous_prompt' not in self.params_cache[user_id] and not e_context['context']['isgroup']:
  138. reply = Reply()
  139. reply.type = ReplyType.TEXT
  140. reply.content = f"您刚刚上传了一份文件,请问我有什么可以帮您的呢?"
  141. e_context["reply"] = reply
  142. e_context.action = EventAction.BREAK
  143. # else:
  144. print(f'准备抽取文字')
  145. file_content=extract_content_by_llm(file_path,"sk-5z2L4zy9T1w90j6e3T90ANZdyN2zLWClRwFnBzWgzdrG4onx")
  146. if file_content is None:
  147. reply = Reply()
  148. reply.type = ReplyType.TEXT
  149. reply.content = f"不能处理这份文件"
  150. e_context["reply"] = reply
  151. e_context.action = EventAction.BREAK
  152. return
  153. else:
  154. self.params_cache[user_id]['last_content']=file_content
  155. logger.info('删除文件')
  156. os.remove(file_path)
  157. # 先回应
  158. if 'previous_prompt' in self.params_cache[user_id] and 'last_content' in self.params_cache[user_id] and contains_keywords(self.params_cache[user_id]['previous_prompt']):
  159. logger.info('先回应')
  160. receiver=user_id
  161. print(receiver)
  162. text=self.params_cache[user_id]['previous_prompt']
  163. logger.info(f'{text},{contains_keywords(text)}')
  164. itchat_content= f'@{msg.actual_user_nickname}' if e_context['context']['isgroup'] else '[小蕴]'
  165. itchat_content+="已经收到,立刻为您服务"
  166. flag=contains_keywords(text)
  167. if flag==True:
  168. print('发送'+itchat_content)
  169. itchat.send(itchat_content, toUserName=receiver)
  170. e_context.action = EventAction.BREAK
  171. # 图片和提示次齐全
  172. if 'previous_prompt' in self.params_cache[user_id] and 'last_content' in self.params_cache[user_id]:
  173. if contains_keywords(self.params_cache[user_id]['previous_prompt']):
  174. e_context["context"].type = ContextType.TEXT
  175. last_content=self.params_cache[user_id]['last_content']
  176. prompt=self.params_cache[user_id]['previous_prompt']
  177. # if isinstance(last_content, list):
  178. # e_context["context"].content =self.generate_openai_messages_content(last_content,prompt)
  179. # elif isinstance(last_content, str):
  180. # e_context["context"].content ="<content>"+last_content+"</content>"+'\n\t'+"<ask>"+prompt+"</ask>"
  181. # else:
  182. # return "urls is neither a list nor a string"
  183. e_context["context"].content =self.generate_openai_messages_content(last_content,prompt)
  184. logger.info(f'插件处理上传文件或图片')
  185. e_context.action = EventAction.CONTINUE
  186. # 清空清空缓存
  187. self.params_cache.clear()
  188. logger.info(f'清空缓存后:{self.params_cache}')
  189. else:
  190. if not e_context['context']['isgroup']:
  191. reply = Reply()
  192. reply.type = ReplyType.TEXT
  193. # reply.content = f"{remove_markdown(reply_content)}\n\n💬5min内输入{self.file_sum_qa_prefix}+问题,可继续追问"
  194. reply.content = f"您刚刚上传了,请问我有什么可以帮您的呢?"
  195. e_context["reply"] = reply
  196. e_context.action = EventAction.BREAK
  197. return
  198. def on_handle_context2(self, e_context: EventContext):
  199. context = e_context["context"]
  200. # 检查 context 类型
  201. if context.type not in {ContextType.TEXT, ContextType.SHARING, ContextType.FILE, ContextType.IMAGE}:
  202. return
  203. msg: ChatMessage = context["msg"]
  204. user_id = msg.from_user_id
  205. content = context.content
  206. is_group = context.get("isgroup", False)
  207. # 准备消息
  208. context.get("msg").prepare()
  209. logger.info(f'当前缓存:self.params_cache:{self.params_cache}')
  210. # 初始化用户缓存
  211. user_cache = self.params_cache.setdefault(user_id, {})
  212. if not user_cache:
  213. logger.info(f'初始化缓存:{self.params_cache}')
  214. previous_prompt = user_cache.get('previous_prompt')
  215. last_content = user_cache.get('last_content')
  216. # 更新 previous_prompt
  217. if context.type == ContextType.TEXT and previous_prompt and contains_keywords(previous_prompt):
  218. user_cache['previous_prompt'] = msg.content
  219. # 处理 previous_prompt 和 last_content
  220. if previous_prompt and last_content and contains_keywords(previous_prompt):
  221. logger.info('先回应')
  222. receiver = user_id
  223. itchat_content = f'@{msg.actual_user_nickname}' if is_group else '[小蕴]'
  224. itchat_content += "已经收到,立刻为您服务"
  225. if contains_keywords(previous_prompt):
  226. logger.info(f'发送消息: {itchat_content}')
  227. itchat.send(itchat_content, toUserName=receiver)
  228. e_context.action = EventAction.BREAK
  229. # 清空缓存
  230. self.params_cache.clear()
  231. logger.info(f'清空缓存后:{self.params_cache}')
  232. else:
  233. if not is_group:
  234. reply = Reply()
  235. reply.type = ReplyType.TEXT
  236. reply.content = "您刚刚上传了,请问我有什么可以帮您的呢?"
  237. e_context["reply"] = reply
  238. e_context.action = EventAction.BREAK
  239. if context.type in [ContextType.FILE,ContextType.IMAGE]:
  240. logger.info('处理上传')
  241. file_path = context.content
  242. logger.info(f"on_handle_context: 获取到图片路径 {file_path},{user_id in self.params_cache}")
  243. if user_id in self.params_cache:
  244. if 'previous_prompt' not in self.params_cache[user_id] and not e_context['context']['isgroup']:
  245. reply = Reply()
  246. reply.type = ReplyType.TEXT
  247. if context.type==ContextType.FILE:
  248. reply.content = f"您刚刚上传文件,请问我有什么可以帮您的呢?"
  249. else:
  250. reply.content = f"您刚刚上传图片,请问我有什么可以帮您的呢?"
  251. e_context["reply"] = reply
  252. e_context.action = EventAction.BREAK
  253. file_content = upload_oss(self.oss_access_key_id, self.oss_access_key_secret, self.oss_endpoint, self.oss_bucket_name, file_path, f'cow/{os.path.basename(file_path)}')
  254. # 确保 'urls' 键存在,并且是一个列表
  255. if 'urls' not in self.params_cache[user_id]:
  256. self.params_cache[user_id]['urls'] = []
  257. # 添加文件内容到 'urls' 列表
  258. self.params_cache[user_id]['urls'].append(file_content)
  259. logger.info('删除图片')
  260. os.remove(file_path)
  261. def generate_openai_messages_content(self, last_content,prompt):
  262. content = []
  263. if isinstance(last_content, list):
  264. # 遍历每个 URL,生成对应的消息结构
  265. for url in last_content:
  266. if url.endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
  267. # 对于图片,生成 "image_url" 类型的消息
  268. content.append({
  269. "type": "image_url",
  270. "image_url": {
  271. "url": url
  272. }
  273. })
  274. else:
  275. # 对于其他文件,生成 "file_url" 或类似的处理方式
  276. content.append({
  277. "type": "file_url",
  278. "file_url": {
  279. "url": url
  280. }
  281. })
  282. else:
  283. prompt="<content>"+last_content+"</content>"+'\n\t'+"<ask>"+prompt+"</ask>"
  284. # 遍历每个 URL,生成对应的消息结构
  285. # for url in urls:
  286. # if url.endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
  287. # # 对于图片,生成 "image_url" 类型的消息
  288. # content.append({
  289. # "type": "image_url",
  290. # "image_url": {
  291. # "url": url
  292. # }
  293. # })
  294. # else:
  295. # # 对于其他文件,生成 "file_url" 或类似的处理方式
  296. # content.append({
  297. # "type": "file_url",
  298. # "file_url": {
  299. # "url": url
  300. # }
  301. # })
  302. # 添加额外的文本说明
  303. content.append({
  304. "type": "text",
  305. "text": prompt
  306. })
  307. return json.dumps(content, ensure_ascii=False)
  308. def remove_markdown(text):
  309. # 替换Markdown的粗体标记
  310. text = text.replace("**", "")
  311. # 替换Markdown的标题标记
  312. text = text.replace("### ", "").replace("## ", "").replace("# ", "")
  313. return text
  314. def extract_content_by_llm(file_path: str, api_key: str) -> str:
  315. logger.info(f'大模型开始抽取文字')
  316. try:
  317. headers = {
  318. 'Authorization': f'Bearer {api_key}'
  319. }
  320. data = {
  321. 'purpose': 'file-extract',
  322. }
  323. file_name=os.path.basename(file_path)
  324. files = {
  325. 'file': (file_name, open(Path(file_path), 'rb')),
  326. }
  327. # print(files)
  328. api_url='https://api.moonshot.cn/v1/files'
  329. response = requests.post(api_url, headers=headers, files=files, data=data)
  330. response_data = response.json()
  331. file_id = response_data.get('id')
  332. response=requests.get(url=f"https://api.moonshot.cn/v1/files/{file_id}/content", headers=headers)
  333. print(response.text)
  334. response_data = response.json()
  335. content = response_data.get('content')
  336. return content
  337. except requests.exceptions.RequestException as e:
  338. logger.error(f"Error calling LLM API: {e}")
  339. return None
  340. def upload_oss(access_key_id, access_key_secret, endpoint, bucket_name, local_file_path, oss_file_name, expiration_days=7):
  341. """
  342. 上传文件到阿里云OSS并设置生命周期规则,同时返回文件的公共访问地址。
  343. :param access_key_id: 阿里云AccessKey ID
  344. :param access_key_secret: 阿里云AccessKey Secret
  345. :param endpoint: OSS区域对应的Endpoint
  346. :param bucket_name: OSS中的Bucket名称
  347. :param local_file_path: 本地文件路径
  348. :param oss_file_name: OSS中的文件存储路径
  349. :param expiration_days: 文件保存天数,默认7天后删除
  350. :return: 文件的公共访问地址
  351. """
  352. # 创建Bucket实例
  353. auth = oss2.Auth(access_key_id, access_key_secret)
  354. bucket = oss2.Bucket(auth, endpoint, bucket_name)
  355. ### 1. 设置生命周期规则 ###
  356. rule_id = f'delete_after_{expiration_days}_days' # 规则ID
  357. prefix = oss_file_name.split('/')[0] + '/' # 设置规则应用的前缀为文件所在目录
  358. # 定义生命周期规则
  359. rule = oss2.models.LifecycleRule(rule_id, prefix, status=oss2.models.LifecycleRule.ENABLED,
  360. expiration=oss2.models.LifecycleExpiration(days=expiration_days))
  361. # 设置Bucket的生命周期
  362. lifecycle = oss2.models.BucketLifecycle([rule])
  363. bucket.put_bucket_lifecycle(lifecycle)
  364. print(f"已设置生命周期规则:文件将在{expiration_days}天后自动删除")
  365. ### 2. 上传文件到OSS ###
  366. bucket.put_object_from_file(oss_file_name, local_file_path)
  367. ### 3. 构建公共访问URL ###
  368. file_url = f"http://{bucket_name}.{endpoint.replace('http://', '')}/{oss_file_name}"
  369. print(f"文件上传成功,公共访问地址:{file_url}")
  370. return file_url
  371. def contains_keywords_by_re(text):
  372. # 匹配<ask>标签中的内容
  373. # match = re.search(r'<ask>(.*?)</ask>', text)
  374. match = re.search(r'(.*?)', text)
  375. if match:
  376. content = match.group(1)
  377. # 检查关键词
  378. keywords = ['分析', '总结', '报告', '描述']
  379. for keyword in keywords:
  380. if keyword in content:
  381. return True
  382. return False
  383. def contains_keywords(text):
  384. keywords = ["分析", "总结", "报告", "描述","说说","讲述","讲讲","讲一下","图片"]
  385. return any(keyword in text for keyword in keywords)