import requests
import json
import plugins
from bridge.reply import Reply, ReplyType
from bridge.context import ContextType
from channel.chat_message import ChatMessage
from plugins import *
from common.log import logger
from common.expired_dict import ExpiredDict
import os
import base64
from pathlib import Path
from PIL import Image
import oss2

EXTENSION_TO_TYPE = {
    'pdf': 'pdf',
    'doc': 'docx', 'docx': 'docx',
    'md': 'md',
    'txt': 'txt',
    'xls': 'excel', 'xlsx': 'excel',
    'csv': 'csv',
    'html': 'html', 'htm': 'html',
    'ppt': 'ppt', 'pptx': 'ppt'
}
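
# A minimal config.json sketch for this plugin, assuming only the keys that
# __init__ below actually reads (all values here are placeholders, not real settings):
# {
#     "keys": {
#         "open_ai_api_key": "sk-...",
#         "open_ai_api_base": "https://api.openai.com/v1",
#         "model": "gpt-3.5-turbo"
#     },
#     "file_sum": {"enabled": true, "group": true, "qa_prefix": "问", "prompt": ""},
#     "image_sum": {"enabled": true, "group": true, "qa_prefix": "问", "prompt": ""}
# }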

@plugins.register(
    name="kimi4upload",
    desire_priority=-1,
    desc="A plugin for upload",
    version="0.0.01",
    author="",
)
class file4upload(Plugin):
    def __init__(self):
        super().__init__()
        try:
            curdir = os.path.dirname(__file__)
            config_path = os.path.join(curdir, "config.json")
            if os.path.exists(config_path):
                with open(config_path, "r", encoding="utf-8") as f:
                    self.config = json.load(f)
            else:
                # Fall back to the parent class loader for the config
                self.config = super().load_config()
                if not self.config:
                    raise Exception("config.json not found")
            # Register the event handler
            self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
            self.params_cache = ExpiredDict(300)
            # Pull the settings this plugin needs out of the config
            self.keys = self.config.get("keys", {})
            self.url_sum = self.config.get("url_sum", {})
            self.search_sum = self.config.get("search_sum", {})
            self.file_sum = self.config.get("file_sum", {})
            self.image_sum = self.config.get("image_sum", {})
            self.note = self.config.get("note", {})
            self.sum4all_key = self.keys.get("sum4all_key", "")
            self.search1api_key = self.keys.get("search1api_key", "")
            self.gemini_key = self.keys.get("gemini_key", "")
            self.bibigpt_key = self.keys.get("bibigpt_key", "")
            self.outputLanguage = self.keys.get("outputLanguage", "zh-CN")
            self.opensum_key = self.keys.get("opensum_key", "")
            self.open_ai_api_key = self.keys.get("open_ai_api_key", "")
            self.model = self.keys.get("model", "gpt-3.5-turbo")
            self.open_ai_api_base = self.keys.get("open_ai_api_base", "https://api.openai.com/v1")
            self.xunfei_app_id = self.keys.get("xunfei_app_id", "")
            self.xunfei_api_key = self.keys.get("xunfei_api_key", "")
            self.xunfei_api_secret = self.keys.get("xunfei_api_secret", "")
            self.perplexity_key = self.keys.get("perplexity_key", "")
            self.flomo_key = self.keys.get("flomo_key", "")
            # Most recent user prompt
            self.previous_prompt = ''
            self.file_sum_enabled = self.file_sum.get("enabled", False)
            self.file_sum_service = self.file_sum.get("service", "")
            self.max_file_size = self.file_sum.get("max_file_size", 15000)
            self.file_sum_group = self.file_sum.get("group", True)
            self.file_sum_qa_prefix = self.file_sum.get("qa_prefix", "问")
            self.file_sum_prompt = self.file_sum.get("prompt", "")
            self.image_sum_enabled = self.image_sum.get("enabled", False)
            self.image_sum_service = self.image_sum.get("service", "")
            self.image_sum_group = self.image_sum.get("group", True)
            self.image_sum_qa_prefix = self.image_sum.get("qa_prefix", "问")
            self.image_sum_prompt = self.image_sum.get("prompt", "")
            # Initialization succeeded
            logger.info("[file4upload] inited.")
        except Exception as e:
            # Initialization failed
            logger.warning(f"file4upload init failed: {e}")

    def on_handle_context(self, e_context: EventContext):
        context = e_context["context"]
        if context.type not in [ContextType.TEXT, ContextType.SHARING, ContextType.FILE, ContextType.IMAGE]:
            return
        msg: ChatMessage = e_context["context"]["msg"]
        user_id = msg.from_user_id
        content = context.content
        isgroup = e_context["context"].get("isgroup", False)
        # logger.info(f"user_id:{user_id},content:{content},isgroup:{isgroup}")
        # logger.info(f'params_cache keys: {self.params_cache.keys}')
        # logger.info(f'user_id in self.params_cache: {user_id in self.params_cache}')
        # Remember the latest text prompt
        if context.type == ContextType.TEXT:
            self.previous_prompt = msg.content
        if isgroup and not self.file_sum_group:
            # File handling is disabled for group chats
            logger.info("群聊消息,文件处理功能已禁用")
            return
        logger.info("on_handle_context: 处理上下文开始")
        context.get("msg").prepare()
        # file_path = context.content
        # logger.info(f"on_handle_context: 获取到文件路径 {file_path}")
        # NOTE: hardcoded API key; better kept in config.json
        api_key = 'sk-5z2L4zy9T1w90j6e3T90ANZdyN2zLWClRwFnBzWgzdrG4onx'
        if context.type == ContextType.IMAGE:
            file_path = context.content
            logger.info(f"on_handle_context: 获取到文件路径 {file_path}")
            print('处理首次上传的图片,准备抽取文字')
            file_content = self.extract_content_by_llm(file_path, api_key)
            self.params_cache[user_id] = {}
            if file_content is not None:
                logger.info('图片中抽取到文字,使用图片中的文字请求LLM')
                messages = [{
                    "role": "system",
                    "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。你会为用户提供安全,有帮助,准确的回答。同时,你会拒绝一切涉及恐怖主义,种族歧视,黄色暴力等问题的回答。Moonshot AI 为专有名词,不可翻译成其他语言。",
                }, {
                    "role": "system",
                    "content": file_content,
                }]
                self.params_cache[user_id]['last_word_messages'] = messages
                self.params_cache[user_id]['last_image_oss'] = None
            else:
                logger.info('不能抽取文字,使用图片oss请求LLM')
                # logger.info(f"on_handle_context: 获取到图片路径 {file_path}")
                # base64_image = self.encode_image_to_base64(file_path)
                # self.params_cache[user_id]['last_image_oss'] = base64_image
                # self.params_cache[user_id]['last_word_messages'] = None
                # NOTE: hardcoded OSS credentials; better kept in config.json
                access_key_id = 'LTAI5tRTG6pLhTpKACJYoPR5'
                access_key_secret = 'E7dMzeeMxq4VQvLg7Tq7uKf3XWpYfN'
                # Endpoint of the OSS region
                endpoint = 'http://oss-cn-shanghai.aliyuncs.com'  # choose the endpoint for your region
                # Bucket name
                bucket_name = 'cow-agent'
                local_file_path = file_path
                oss_file_name = f'cow/{os.path.basename(file_path)}'
                logger.info(f'oss_file_name:{oss_file_name}\n local_file_path :{local_file_path}')
                file_url = upload_oss(access_key_id, access_key_secret, endpoint, bucket_name, local_file_path, oss_file_name)
                logger.info(f'写入图片缓存oss 地址{file_url}')
                self.params_cache[user_id]['last_image_oss'] = file_url
                self.params_cache[user_id]['last_word_messages'] = None
            if self.previous_prompt == '':
                reply = Reply()
                reply.type = ReplyType.TEXT
                # reply.content = f"{remove_markdown(reply_content)}\n\n💬5min内输入{self.file_sum_qa_prefix}+问题,可继续追问"
                reply.content = "您刚刚上传了一张图片,请问我有什么可以帮您的呢?"
                e_context["reply"] = reply
                e_context.action = EventAction.BREAK
                return
            else:
                e_context.action = EventAction.CONTINUE
        if context.type == ContextType.FILE:
            file_path = context.content
            logger.info(f"on_handle_context: 获取到文件路径 {file_path}")
            print('处理首次上传的文件')
            file_content = self.extract_content_by_llm(file_path, api_key)
            if file_content is not None:
                self.params_cache[user_id] = {}
                messages = [{
                    "role": "system",
                    "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。你会为用户提供安全,有帮助,准确的回答。同时,你会拒绝一切涉及恐怖主义,种族歧视,黄色暴力等问题的回答。Moonshot AI 为专有名词,不可翻译成其他语言。",
                }, {
                    "role": "system",
                    "content": file_content,
                }]
                self.params_cache[user_id]['last_word_messages'] = messages
                self.params_cache[user_id]['last_image_oss'] = None
            if self.previous_prompt == '':
                reply = Reply()
                reply.type = ReplyType.TEXT
                # reply.content = f"{remove_markdown(reply_content)}\n\n💬5min内输入{self.file_sum_qa_prefix}+问题,可继续追问"
                reply.content = "您刚刚上传了一个文件,请问我有什么可以帮您的呢?如可以问“请总结分析这份报告文件,同时,提供治疗和健康建议。”"
                e_context["reply"] = reply
                e_context.action = EventAction.BREAK
                return
            else:
                e_context.action = EventAction.CONTINUE
        # if user_id in self.params_cache and (self.params_cache[user_id]['last_word_messages'] != None):
        if user_id in self.params_cache:
            if 'last_word_messages' in self.params_cache[user_id] and self.params_cache[user_id]['last_word_messages'] is not None:
                print('缓存处理已经上传的文件')
                # last_word_messages = self.params_cache[user_id]['last_word_messages']
                # cache_messages = last_word_messages[:2]
                cache_messages = self.params_cache[user_id]['last_word_messages']
                messages = [
                    *cache_messages,
                    {
                        "role": "user",
                        "content": self.previous_prompt,  # msg.content,
                    },
                ]
                self.handle_file_upload(messages, e_context)
        # if user_id in self.params_cache and ('last_image_oss' in self.params_cache[user_id] or self.params_cache[user_id]['last_image_oss'] != None):
        if user_id in self.params_cache:
            if 'last_image_oss' in self.params_cache[user_id] and self.params_cache[user_id]['last_image_oss'] is not None:
                print('缓存处理已上传到oss的图片')
                file_url = self.params_cache[user_id]['last_image_oss']
                messages = [{
                    "role": "system",
                    "content": "你是一个能描述任何图片的智能助手",
                }, {
                    "role": "user",
                    "content": f'{file_url}\n{self.previous_prompt}',
                }]
                # messages = [
                #     {
                #         "role": "user",
                #         "content": [
                #             {
                #                 "type": "image_url",
                #                 "image_url": {
                #                     "url": f"{file_url}"
                #                 }
                #             },
                #             {
                #                 "type": "text",
                #                 "text": f"{self.previous_prompt}"
                #             }
                #         ]
                #     }
                # ]
                self.handle_images_oss(messages, e_context)

    def handle_file_upload(self, messages, e_context):
        logger.info("handle_file_upload: 向LLM发送内容总结请求")
        msg: ChatMessage = e_context["context"]["msg"]
        user_id = msg.from_user_id
        user_params = self.params_cache.get(user_id, {})
        prompt = user_params.get('prompt', self.file_sum_prompt)
        self.params_cache[user_id] = {}
        try:
            # NOTE: hardcoded API key; better kept in config.json
            api_key = "sk-5z2L4zy9T1w90j6e3T90ANZdyN2zLWClRwFnBzWgzdrG4onx"
            # base_url = "https://api.moonshot.cn/v1",
            api_url = "https://api.moonshot.cn/v1/chat/completions"
            headers = {
                'Content-Type': 'application/json',
                'Authorization': f'Bearer {api_key}'
            }
            data = {
                "model": "moonshot-v1-128k",
                "messages": messages,
                # "temperature": 0.3
            }
            response = requests.post(url=api_url, headers=headers, data=json.dumps(data))
            logger.info(f'handle_file_upload: 请求文件内容{json.dumps(messages, ensure_ascii=False)}')
            response.raise_for_status()
            response_data = response.json()
            if "choices" in response_data and len(response_data["choices"]) > 0:
                first_choice = response_data["choices"][0]
                if "message" in first_choice and "content" in first_choice["message"]:
                    response_content = first_choice["message"]["content"].strip()  # the model's reply text
                    reply_content = response_content.replace("\\n", "\n")  # turn literal "\n" into real newlines
                    # self.params_cache[user_id]['last_word_messages'] = messages
                    # if self.params_cache[user_id]['last_word_messages'] != None:
                    #     self.params_cache[user_id]['last_word_messages'] = messages
                    self.previous_prompt = ''
                else:
                    logger.error("Content not found in the response")
                    reply_content = "Content not found in the LLM API response"
            else:
                logger.error("No choices available in the response")
                reply_content = "No choices available in the LLM API response"
        except requests.exceptions.RequestException as e:
            logger.error(f"Error calling LLM API: {e}")
            reply_content = "An error occurred while calling the LLM API"
        reply = Reply()
        reply.type = ReplyType.TEXT
        reply.content = f"{remove_markdown(reply_content)}"
        e_context["reply"] = reply
        e_context.action = EventAction.BREAK_PASS

    def handle_images_base64(self, messages, e_context):
        logger.info("handle_images_base64: 向LLM发送内容总结请求")
        msg: ChatMessage = e_context["context"]["msg"]
        user_id = msg.from_user_id
        user_params = self.params_cache.get(user_id, {})
        prompt = user_params.get('prompt', self.file_sum_prompt)
        try:
            # api_key = "sk-5z2L4zy9T1w90j6e3T90ANZdyN2zLWClRwFnBzWgzdrG4onx"
            # # base_url = "https://api.moonshot.cn/v1",
            # api_url = "https://api.moonshot.cn/v1/chat/completions"
            # api_key = "sk-5dyg7PMUNeoSqHH807453eB06f434c34Ba6fB4764aC8358c"
            # api_url = "http://106.15.182.218:3001/v1/chat/completions"
            # headers = {
            #     'Content-Type': 'application/json',
            #     'Authorization': f'Bearer {api_key}'
            # }
            # data = {
            #     "model": "moonshot-v1-128k",
            #     "messages": messages,
            #     # "temperature": 0.3
            # }
            # response = requests.post(url=api_url, headers=headers, json=data)
            # NOTE: hardcoded test image path left over from debugging
            base64_image = self.encode_image_to_base64('tmp/240926-164856.png')
            api_key = self.open_ai_api_key
            api_base = f"{self.open_ai_api_base}/chat/completions"
            logger.info(api_base)
            payload = {
                "model": "moonshot-v1-128k",
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": self.previous_prompt
                            },
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/jpeg;base64,{base64_image}"
                                }
                            }
                        ]
                    }
                ],
                "max_tokens": 3000
            }
            # payload = {
            #     "model": "moonshot-v1-128k",
            #     "messages": messages,
            #     "max_tokens": 3000
            # }
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}"
            }
            logger.info('开始')
            response = requests.post(api_base, headers=headers, json=payload)
            # logger.info(f'handle_images_base64: 请求文件内容{json.dumps(messages, ensure_ascii=False)}')
            response.raise_for_status()
            response_data = response.json()
            if "choices" in response_data and len(response_data["choices"]) > 0:
                first_choice = response_data["choices"][0]
                if "message" in first_choice and "content" in first_choice["message"]:
                    response_content = first_choice["message"]["content"].strip()  # the model's reply text
                    reply_content = response_content.replace("\\n", "\n")  # turn literal "\n" into real newlines
                    # self.params_cache[user_id]['last_word_messages'] = messages
                    # if self.params_cache[user_id]['last_word_messages'] != None:
                    #     self.params_cache[user_id]['last_word_messages'] = messages
                    self.previous_prompt = ''
                else:
                    logger.error("Content not found in the response")
                    reply_content = "Content not found in the LLM API response"
            else:
                logger.error("No choices available in the response")
                reply_content = "No choices available in the LLM API response"
        except requests.exceptions.RequestException as e:
            logger.error(f"Error calling LLM API: {e}")
            reply_content = "An error occurred while calling the LLM API"
        reply = Reply()
        reply.type = ReplyType.TEXT
        reply.content = f"{remove_markdown(reply_content)}"
        e_context["reply"] = reply
        e_context.action = EventAction.BREAK_PASS

    def handle_images_oss(self, messages, e_context):
        logger.info("handle_images_oss: 向LLM发送内容总结请求")
        msg: ChatMessage = e_context["context"]["msg"]
        user_id = msg.from_user_id
        user_params = self.params_cache.get(user_id, {})
        prompt = user_params.get('prompt', self.file_sum_prompt)
        self.params_cache[user_id] = {}
        try:
            api_key = self.open_ai_api_key
            api_base = f"{self.open_ai_api_base}/chat/completions"
            logger.info(api_base)
            payload = {
                "model": "7374349217580056592",
                "messages": messages,
                "max_tokens": 3000
            }
            # payload = {
            #     "model": "moonshot-v1-128k",
            #     "messages": messages,
            #     "max_tokens": 3000
            # }
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}"
            }
            # logger.info('开始')
            response = requests.post(api_base, headers=headers, json=payload)
            logger.info(f'handle_images_oss: 请求文件内容{json.dumps(messages, ensure_ascii=False)}')
            response.raise_for_status()
            response_data = response.json()
            if "choices" in response_data and len(response_data["choices"]) > 0:
                first_choice = response_data["choices"][0]
                if "message" in first_choice and "content" in first_choice["message"]:
                    response_content = first_choice["message"]["content"].strip()  # the model's reply text
                    reply_content = response_content.replace("\\n", "\n")  # turn literal "\n" into real newlines
                    # self.params_cache[user_id]['last_word_messages'] = messages
                    # if self.params_cache[user_id]['last_word_messages'] != None:
                    #     self.params_cache[user_id]['last_word_messages'] = messages
                    self.previous_prompt = ''
                else:
                    logger.error("Content not found in the response")
                    reply_content = "Content not found in the LLM API response"
            else:
                logger.error("No choices available in the response")
                reply_content = "No choices available in the LLM API response"
        except requests.exceptions.RequestException as e:
            logger.error(f"Error calling LLM API: {e}")
            reply_content = "An error occurred while calling the LLM API"
        reply = Reply()
        reply.type = ReplyType.TEXT
        reply.content = f"{remove_markdown(reply_content)}"
        e_context["reply"] = reply
        e_context.action = EventAction.BREAK_PASS

    def extract_content_by_llm(self, file_path: str, api_key: str) -> str:
        logger.info('大模型开始抽取文字')
        try:
            headers = {
                'Authorization': f'Bearer {api_key}'
            }
            data = {
                'purpose': 'file-extract',
            }
            file_name = os.path.basename(file_path)
            files = {
                'file': (file_name, open(Path(file_path), 'rb')),
            }
            print(files)
            # Step 1: upload the file to the Moonshot files endpoint
            api_url = 'https://api.moonshot.cn/v1/files'
            response = requests.post(api_url, headers=headers, files=files, data=data)
            # print(response.text)
            response_data = response.json()
            file_id = response_data.get('id')
            # print(f'文件id:{file_id}')
            # Step 2: fetch the extracted text content for that file id
            response = requests.get(url=f"https://api.moonshot.cn/v1/files/{file_id}/content", headers=headers)
            print(response.text)
            response_data = response.json()
            content = response_data.get('content')
            return content
        except requests.exceptions.RequestException as e:
            logger.error(f"Error calling LLM API: {e}")
            return None

    def encode_image_to_base64(self, image_path):
        logger.info(f"开始处理图片: {image_path}")
        try:
            with Image.open(image_path) as img:
                logger.info(f"成功打开图片. 原始大小: {img.size}")
                if img.width > 1024:
                    new_size = (1024, int(img.height * 1024 / img.width))
                    img = img.resize(new_size)
                    img.save(image_path)  # overwrite the file with the resized image
                    logger.info(f"调整图片大小至: {new_size}")
            with open(image_path, "rb") as image_file:
                img_byte_arr = image_file.read()
                logger.info(f"读取图片完成. 大小: {len(img_byte_arr)} 字节")
            encoded = base64.b64encode(img_byte_arr).decode('ascii')
            logger.info(f"Base64编码完成. 编码后长度: {len(encoded)}")
            return encoded
        except Exception as e:
            logger.error(f"图片编码过程中发生错误: {str(e)}", exc_info=True)
            raise


def remove_markdown(text):
    # Strip Markdown bold markers
    text = text.replace("**", "")
    # Strip Markdown heading markers
    text = text.replace("### ", "").replace("## ", "").replace("# ", "")
    return text


def upload_oss(access_key_id, access_key_secret, endpoint, bucket_name, local_file_path, oss_file_name, expiration_days=7):
    """
    Upload a file to Aliyun OSS, set a lifecycle rule on the bucket, and return the file's public URL.

    :param access_key_id: Aliyun AccessKey ID
    :param access_key_secret: Aliyun AccessKey Secret
    :param endpoint: endpoint of the OSS region
    :param bucket_name: name of the OSS bucket
    :param local_file_path: path of the local file
    :param oss_file_name: object key (storage path) in OSS
    :param expiration_days: how many days to keep the file; defaults to deletion after 7 days
    :return: public URL of the uploaded file
    """
    # Create the bucket client
    auth = oss2.Auth(access_key_id, access_key_secret)
    bucket = oss2.Bucket(auth, endpoint, bucket_name)

    # 1. Set the lifecycle rule
    rule_id = f'delete_after_{expiration_days}_days'  # rule ID
    prefix = oss_file_name.split('/')[0] + '/'  # apply the rule to the directory that holds the file
    # Define the lifecycle rule
    rule = oss2.models.LifecycleRule(rule_id, prefix, status=oss2.models.LifecycleRule.ENABLED,
                                     expiration=oss2.models.LifecycleExpiration(days=expiration_days))
    # Apply the lifecycle configuration to the bucket
    lifecycle = oss2.models.BucketLifecycle([rule])
    bucket.put_bucket_lifecycle(lifecycle)
    print(f"已设置生命周期规则:文件将在{expiration_days}天后自动删除")

    # 2. Upload the file to OSS
    bucket.put_object_from_file(oss_file_name, local_file_path)

    # 3. Build the public access URL
    file_url = f"http://{bucket_name}.{endpoint.replace('http://', '')}/{oss_file_name}"
    print(f"文件上传成功,公共访问地址:{file_url}")
    return file_url
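
# Example call (hypothetical values; assumes valid OSS credentials and an existing bucket):
#   file_url = upload_oss('<ACCESS_KEY_ID>', '<ACCESS_KEY_SECRET>',
#                         'http://oss-cn-shanghai.aliyuncs.com', 'cow-agent',
#                         'tmp/example.png', 'cow/example.png')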