You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

coze4upload.py 16KB

3 weeks ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. import requests
  2. import json
  3. import plugins
  4. from bridge.reply import Reply, ReplyType
  5. from bridge.context import ContextType
  6. from channel.chat_message import ChatMessage
  7. from plugins import *
  8. from common.log import logger
  9. from common.expired_dict import ExpiredDict
  10. import os
  11. import base64
  12. from pathlib import Path
  13. from PIL import Image
  14. import oss2
  15. from lib import itchat
  16. from lib.itchat.content import *
  17. import re
  18. # C:\Users\vsoni\source\repos\chatgpt-on-wechat\channel\wechat\wechat_channel.py
  19. @plugins.register(
  20. name="coze4upload",
  21. desire_priority=-1,
  22. desc="A plugin for upload",
  23. version="0.0.01",
  24. author="",
  25. )
  26. class coze4upload(Plugin):
  27. def __init__(self):
  28. super().__init__()
  29. try:
  30. curdir = os.path.dirname(__file__)
  31. config_path = os.path.join(curdir, "config.json")
  32. if os.path.exists(config_path):
  33. with open(config_path, "r", encoding="utf-8") as f:
  34. self.config = json.load(f)
  35. else:
  36. # 使用父类的方法来加载配置
  37. self.config = super().load_config()
  38. if not self.config:
  39. raise Exception("config.json not found")
  40. # 设置事件处理函数
  41. self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
  42. self.params_cache = ExpiredDict(300)
  43. # 从配置中提取所需的设置
  44. self.keys = self.config.get("keys", {})
  45. self.url_sum = self.config.get("url_sum", {})
  46. self.search_sum = self.config.get("search_sum", {})
  47. self.file_sum = self.config.get("file_sum", {})
  48. self.image_sum = self.config.get("image_sum", {})
  49. self.note = self.config.get("note", {})
  50. self.sum4all_key = self.keys.get("sum4all_key", "")
  51. self.search1api_key = self.keys.get("search1api_key", "")
  52. self.gemini_key = self.keys.get("gemini_key", "")
  53. self.bibigpt_key = self.keys.get("bibigpt_key", "")
  54. self.outputLanguage = self.keys.get("outputLanguage", "zh-CN")
  55. self.opensum_key = self.keys.get("opensum_key", "")
  56. self.open_ai_api_key = self.keys.get("open_ai_api_key", "")
  57. self.model = self.keys.get("model", "gpt-3.5-turbo")
  58. self.open_ai_api_base = self.keys.get("open_ai_api_base", "https://api.openai.com/v1")
  59. self.xunfei_app_id = self.keys.get("xunfei_app_id", "")
  60. self.xunfei_api_key = self.keys.get("xunfei_api_key", "")
  61. self.xunfei_api_secret = self.keys.get("xunfei_api_secret", "")
  62. self.perplexity_key = self.keys.get("perplexity_key", "")
  63. self.flomo_key = self.keys.get("flomo_key", "")
  64. # 之前提示
  65. self.previous_prompt=''
  66. self.file_sum_enabled = self.file_sum.get("enabled", False)
  67. self.file_sum_service = self.file_sum.get("service", "")
  68. self.max_file_size = self.file_sum.get("max_file_size", 15000)
  69. self.file_sum_group = self.file_sum.get("group", True)
  70. self.file_sum_qa_prefix = self.file_sum.get("qa_prefix", "问")
  71. self.file_sum_prompt = self.file_sum.get("prompt", "")
  72. self.image_sum_enabled = self.image_sum.get("enabled", False)
  73. self.image_sum_service = self.image_sum.get("service", "")
  74. self.image_sum_group = self.image_sum.get("group", True)
  75. self.image_sum_qa_prefix = self.image_sum.get("qa_prefix", "问")
  76. self.image_sum_prompt = self.image_sum.get("prompt", "")
  77. # 初始化成功日志
  78. logger.info("[file4upload] inited.")
  79. except Exception as e:
  80. # 初始化失败日志
  81. logger.warn(f"file4upload init failed: {e}")
  82. # def on_handle_context(self, e_context: EventContext):
  83. # context = e_context["context"]
  84. # # logger.info(f'{e_context.__dict__}')
  85. # # logger.info('---------------------------------')
  86. # # logger.info(f'{ e_context["context"]}')
  87. # logger.info('---------------------------------')
  88. # logger.info(f'{e_context["context"]["msg"]}')
  89. # if context.type not in [ContextType.TEXT, ContextType.SHARING,ContextType.FILE,ContextType.IMAGE]:
  90. # return
  91. # msg: ChatMessage = e_context["context"]["msg"]
  92. # user_id = msg.from_user_id
  93. # content = context.content
  94. # isgroup = e_context["context"].get("isgroup", False)
  95. # print(msg.actual_user_nickname)
  96. # itchat.send(f'@{msg.actual_user_nickname}立刻为你服务', toUserName=user_id)
  97. def on_handle_context(self, e_context: EventContext):
  98. context = e_context["context"]
  99. # logger.info(f'{e_context.__dict__}')
  100. # logger.info('---------------------------------')
  101. # logger.info(f'{ e_context["context"]}')
  102. # logger.info('---------------------------------')
  103. # logger.info(f'{e_context["context"]["msg"]}')
  104. if context.type not in [ContextType.TEXT, ContextType.SHARING,ContextType.FILE,ContextType.IMAGE]:
  105. return
  106. msg: ChatMessage = e_context["context"]["msg"]
  107. user_id = msg.from_user_id
  108. content = context.content
  109. isgroup = e_context["context"].get("isgroup", False)
  110. # itchat.send(f'@{msg.actual_user_nickname}立刻为你服务', toUserName=msg.actual_user_nickname)
  111. if isgroup and not self.file_sum_group:
  112. # 群聊中忽略处理文件
  113. logger.info("群聊消息,文件处理功能已禁用")
  114. return
  115. logger.info("on_handle_context: 处理上下文开始")
  116. context.get("msg").prepare()
  117. api_key='sk-5z2L4zy9T1w90j6e3T90ANZdyN2zLWClRwFnBzWgzdrG4onx'
  118. logger.info(f'当前缓存:self.params_cache:{self.params_cache}')
  119. if user_id not in self.params_cache:
  120. self.params_cache[user_id] = {}
  121. logger.info(f'初始化缓存:{self.params_cache}')
  122. if context.type == ContextType.TEXT and user_id in self.params_cache:
  123. self.params_cache[user_id]['previous_prompt']=msg.content
  124. # print(f'{msg.__dict__}')
  125. if context.type == ContextType.IMAGE:
  126. logger.info('处理图片')
  127. file_path = context.content
  128. logger.info(f"on_handle_context: 获取到图片路径 {file_path},{user_id in self.params_cache}")
  129. if user_id in self.params_cache:
  130. if 'previous_prompt' not in self.params_cache[user_id] and not e_context['context']['isgroup']:
  131. reply = Reply()
  132. reply.type = ReplyType.TEXT
  133. reply.content = f"您刚刚上传了一张图片,请问我有什么可以帮您的呢?"
  134. e_context["reply"] = reply
  135. e_context.action = EventAction.BREAK
  136. # else:
  137. print(f'准备抽取文字')
  138. file_content=extract_content_by_llm(file_path,api_key)
  139. if file_content is None:
  140. logger.info('不能抽取文字,使用图片oss请求LLM')
  141. access_key_id = 'LTAI5tRTG6pLhTpKACJYoPR5'
  142. access_key_secret = 'E7dMzeeMxq4VQvLg7Tq7uKf3XWpYfN'
  143. # OSS区域对应的Endpoint
  144. endpoint = 'http://oss-cn-shanghai.aliyuncs.com' # 根据你的区域选择
  145. # Bucket名称
  146. bucket_name = 'cow-agent'
  147. local_file_path=file_path
  148. oss_file_name=f'cow/{os.path.basename(file_path)}'
  149. logger.info(f'oss_file_name:{oss_file_name}\n local_file_path :{local_file_path}')
  150. file_content = upload_oss(access_key_id, access_key_secret, endpoint, bucket_name, local_file_path, oss_file_name)
  151. logger.info(f'写入图片缓存oss 地址{file_content}')
  152. self.params_cache[user_id]['last_content']=file_content
  153. # else:
  154. # logger.warn(f'还没有建立会话')
  155. logger.info('删除图片')
  156. os.remove(file_path)
  157. if context.type == ContextType.FILE:
  158. logger.info('处理图片')
  159. file_path = context.content
  160. logger.info(f"on_handle_context: 获取到文件路径 {file_path}")
  161. if user_id in self.params_cache:
  162. if 'previous_prompt' not in self.params_cache[user_id] and not e_context['context']['isgroup']:
  163. reply = Reply()
  164. reply.type = ReplyType.TEXT
  165. reply.content = f"您刚刚上传了一份文件,请问我有什么可以帮您的呢?"
  166. e_context["reply"] = reply
  167. e_context.action = EventAction.BREAK
  168. # else:
  169. print(f'准备抽取文字')
  170. file_content=extract_content_by_llm(file_path,api_key)
  171. if file_content is None:
  172. reply = Reply()
  173. reply.type = ReplyType.TEXT
  174. reply.content = f"不能处理这份文件"
  175. e_context["reply"] = reply
  176. e_context.action = EventAction.BREAK
  177. return
  178. else:
  179. self.params_cache[user_id]['last_content']=file_content
  180. logger.info('删除图片')
  181. os.remove(file_path)
  182. # logger.info('previous_prompt' in self.params_cache[user_id])
  183. # logger.info('last_content' in self.params_cache[user_id])
  184. is_previous_prompt='previous_prompt' in self.params_cache[user_id]
  185. is_last_content='last_content' in self.params_cache[user_id]
  186. logger.info(f"存在提示词 previous_prompt:{is_previous_prompt}")
  187. logger.info(f'存在内容 last_content:{is_last_content}' )
  188. if 'previous_prompt' in self.params_cache[user_id] and 'last_content' in self.params_cache[user_id] and contains_keywords(self.params_cache[user_id]['previous_prompt']):
  189. #先回应
  190. logger.info('先回应')
  191. # reply2 = Reply()
  192. # reply2.type = ReplyType.TEXT
  193. # reply2.content = f"已经收到,立刻为你服务"
  194. # msg:ChatMessage = e_context['context']['msg']
  195. # e_context['reply'] = reply2
  196. # e_context.action = EventAction.BREAK # 事件结束
  197. # reply = Reply()
  198. # reply.type = ReplyType.TEXT
  199. # reply.content = f"已经收到,立刻为你服务"
  200. # e_context["reply"] = reply
  201. # e_context.action = EventAction.BREAK
  202. receiver=user_id
  203. print(receiver)
  204. # itchat_content= '' if e_context['context']['isgroup'] else '[小蕴]'+"已经收到,立刻为你服务"
  205. # if e_context['context']['isgroup']:
  206. # itchat_content =f'@{msg.actual_user_nickname}已经收到,立刻为你服务'
  207. # else:
  208. # itchat_content = '[小蕴]'+"已经收到,立刻为你服务"
  209. text=self.params_cache[user_id]['previous_prompt']
  210. logger.info(f'{text},{contains_keywords(text)}')
  211. itchat_content= f'@{msg.actual_user_nickname}' if e_context['context']['isgroup'] else '[小蕴]'
  212. itchat_content+="已经收到,立刻为您服务"
  213. flag=contains_keywords(text)
  214. if flag==True:
  215. print('发送'+itchat_content)
  216. itchat.send(itchat_content, toUserName=receiver)
  217. e_context.action = EventAction.BREAK
  218. if 'previous_prompt' in self.params_cache[user_id] and 'last_content' in self.params_cache[user_id]:
  219. if contains_keywords(self.params_cache[user_id]['previous_prompt']):
  220. e_context["context"].type = ContextType.TEXT
  221. e_context["context"].content ="<content>"+self.params_cache[user_id]['last_content']+"</content>"+'\n\t'+"<ask>"+self.params_cache[user_id]['previous_prompt']+"</ask>"
  222. logger.info(f'conze4upload 插件处理上传文件或图片')
  223. e_context.action = EventAction.CONTINUE
  224. # 清空清空缓存
  225. self.params_cache.clear()
  226. logger.info(f'清空缓存后:{self.params_cache}')
  227. else:
  228. if not e_context['context']['isgroup']:
  229. reply = Reply()
  230. reply.type = ReplyType.TEXT
  231. # reply.content = f"{remove_markdown(reply_content)}\n\n💬5min内输入{self.file_sum_qa_prefix}+问题,可继续追问"
  232. reply.content = f"您刚刚上传了,请问我有什么可以帮您的呢?"
  233. e_context["reply"] = reply
  234. e_context.action = EventAction.BREAK
  235. return
  236. ## e_context.action = EventAction.BREAK
  237. def remove_markdown(text):
  238. # 替换Markdown的粗体标记
  239. text = text.replace("**", "")
  240. # 替换Markdown的标题标记
  241. text = text.replace("### ", "").replace("## ", "").replace("# ", "")
  242. return text
  243. def extract_content_by_llm(file_path: str, api_key: str) -> str:
  244. logger.info(f'大模型开始抽取文字')
  245. try:
  246. headers = {
  247. 'Authorization': f'Bearer {api_key}'
  248. }
  249. data = {
  250. 'purpose': 'file-extract',
  251. }
  252. file_name=os.path.basename(file_path)
  253. files = {
  254. 'file': (file_name, open(Path(file_path), 'rb')),
  255. }
  256. # print(files)
  257. api_url='https://api.moonshot.cn/v1/files'
  258. response = requests.post(api_url, headers=headers, files=files, data=data)
  259. response_data = response.json()
  260. file_id = response_data.get('id')
  261. response=requests.get(url=f"https://api.moonshot.cn/v1/files/{file_id}/content", headers=headers)
  262. print(response.text)
  263. response_data = response.json()
  264. content = response_data.get('content')
  265. return content
  266. except requests.exceptions.RequestException as e:
  267. logger.error(f"Error calling LLM API: {e}")
  268. return None
  269. def upload_oss(access_key_id, access_key_secret, endpoint, bucket_name, local_file_path, oss_file_name, expiration_days=7):
  270. """
  271. 上传文件到阿里云OSS并设置生命周期规则,同时返回文件的公共访问地址。
  272. :param access_key_id: 阿里云AccessKey ID
  273. :param access_key_secret: 阿里云AccessKey Secret
  274. :param endpoint: OSS区域对应的Endpoint
  275. :param bucket_name: OSS中的Bucket名称
  276. :param local_file_path: 本地文件路径
  277. :param oss_file_name: OSS中的文件存储路径
  278. :param expiration_days: 文件保存天数,默认7天后删除
  279. :return: 文件的公共访问地址
  280. """
  281. # 创建Bucket实例
  282. auth = oss2.Auth(access_key_id, access_key_secret)
  283. bucket = oss2.Bucket(auth, endpoint, bucket_name)
  284. ### 1. 设置生命周期规则 ###
  285. rule_id = f'delete_after_{expiration_days}_days' # 规则ID
  286. prefix = oss_file_name.split('/')[0] + '/' # 设置规则应用的前缀为文件所在目录
  287. # 定义生命周期规则
  288. rule = oss2.models.LifecycleRule(rule_id, prefix, status=oss2.models.LifecycleRule.ENABLED,
  289. expiration=oss2.models.LifecycleExpiration(days=expiration_days))
  290. # 设置Bucket的生命周期
  291. lifecycle = oss2.models.BucketLifecycle([rule])
  292. bucket.put_bucket_lifecycle(lifecycle)
  293. print(f"已设置生命周期规则:文件将在{expiration_days}天后自动删除")
  294. ### 2. 上传文件到OSS ###
  295. bucket.put_object_from_file(oss_file_name, local_file_path)
  296. ### 3. 构建公共访问URL ###
  297. file_url = f"http://{bucket_name}.{endpoint.replace('http://', '')}/{oss_file_name}"
  298. print(f"文件上传成功,公共访问地址:{file_url}")
  299. return file_url
  300. def contains_keywords_by_re(text):
  301. # 匹配<ask>标签中的内容
  302. # match = re.search(r'<ask>(.*?)</ask>', text)
  303. match = re.search(r'(.*?)', text)
  304. if match:
  305. content = match.group(1)
  306. # 检查关键词
  307. keywords = ['分析', '总结', '报告', '描述']
  308. for keyword in keywords:
  309. if keyword in content:
  310. return True
  311. return False
  312. def contains_keywords(text):
  313. keywords = ["分析", "总结", "报告", "描述"]
  314. return any(keyword in text for keyword in keywords)