You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

coze4upload.txt 12KB

2 weeks ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. import requests
  2. import json
  3. import plugins
  4. from bridge.reply import Reply, ReplyType
  5. from bridge.context import ContextType
  6. from channel.chat_message import ChatMessage
  7. from plugins import *
  8. from common.log import logger
  9. from common.expired_dict import ExpiredDict
  10. import os
  11. import base64
  12. from pathlib import Path
  13. from PIL import Image
  14. import oss2
  @plugins.register(
      name="coze4upload",
      desire_priority=-1,
      desc="A plugin for upload",
      version="0.0.01",
      author="",
  )
  class coze4upload(Plugin):
      """Pair an uploaded image/file with the user's text prompt.

      Per user (keyed by ``msg.from_user_id``) the plugin keeps two values in a
      300-second ``ExpiredDict``:

      * ``previous_prompt`` - content of the user's most recent TEXT message;
      * ``last_content``    - text extracted from the user's latest upload, or,
        for an image whose text cannot be extracted, a public OSS URL to it.

      When both are present, the incoming context is rewritten into a single
      TEXT message (``last_content``, a newline and a tab, then
      ``previous_prompt``), the action is set to ``EventAction.CONTINUE``, and
      that user's entry is dropped.

      Credentials come from config.json (``keys.moonshot_api_key`` and the
      ``oss`` section) or from the environment variables ``MOONSHOT_API_KEY``,
      ``OSS_ACCESS_KEY_ID`` and ``OSS_ACCESS_KEY_SECRET``. Keys that used to be
      hard-coded in this file were published with it and must be rotated.
      """

      def __init__(self):
          super().__init__()
          try:
              curdir = os.path.dirname(__file__)
              config_path = os.path.join(curdir, "config.json")
              if os.path.exists(config_path):
                  with open(config_path, "r", encoding="utf-8") as f:
                      self.config = json.load(f)
              else:
                  # No local config.json: fall back to the parent class's loader.
                  self.config = super().load_config()
              if not self.config:
                  raise Exception("config.json not found")
              # Per-user prompt/upload state; entries expire after 300 seconds.
              self.params_cache = ExpiredDict(300)
              # Configuration sections.
              self.keys = self.config.get("keys", {})
              self.url_sum = self.config.get("url_sum", {})
              self.search_sum = self.config.get("search_sum", {})
              self.file_sum = self.config.get("file_sum", {})
              self.image_sum = self.config.get("image_sum", {})
              self.note = self.config.get("note", {})
              self.sum4all_key = self.keys.get("sum4all_key", "")
              self.search1api_key = self.keys.get("search1api_key", "")
              self.gemini_key = self.keys.get("gemini_key", "")
              self.bibigpt_key = self.keys.get("bibigpt_key", "")
              self.outputLanguage = self.keys.get("outputLanguage", "zh-CN")
              self.opensum_key = self.keys.get("opensum_key", "")
              self.open_ai_api_key = self.keys.get("open_ai_api_key", "")
              self.model = self.keys.get("model", "gpt-3.5-turbo")
              self.open_ai_api_base = self.keys.get("open_ai_api_base", "https://api.openai.com/v1")
              self.xunfei_app_id = self.keys.get("xunfei_app_id", "")
              self.xunfei_api_key = self.keys.get("xunfei_api_key", "")
              self.xunfei_api_secret = self.keys.get("xunfei_api_secret", "")
              self.perplexity_key = self.keys.get("perplexity_key", "")
              self.flomo_key = self.keys.get("flomo_key", "")
              # Previous prompt (kept for compatibility; per-user prompts live in params_cache).
              self.previous_prompt = ''
              self.file_sum_enabled = self.file_sum.get("enabled", False)
              self.file_sum_service = self.file_sum.get("service", "")
              self.max_file_size = self.file_sum.get("max_file_size", 15000)
              self.file_sum_group = self.file_sum.get("group", True)
              self.file_sum_qa_prefix = self.file_sum.get("qa_prefix", "问")
              self.file_sum_prompt = self.file_sum.get("prompt", "")
              self.image_sum_enabled = self.image_sum.get("enabled", False)
              self.image_sum_service = self.image_sum.get("service", "")
              self.image_sum_group = self.image_sum.get("group", True)
              self.image_sum_qa_prefix = self.image_sum.get("qa_prefix", "问")
              self.image_sum_prompt = self.image_sum.get("prompt", "")
              # Credentials: config first, then environment -- never hard-coded in source.
              self.moonshot_api_key = self.keys.get("moonshot_api_key") or os.environ.get("MOONSHOT_API_KEY", "")
              oss_conf = self.config.get("oss", {})
              self.oss_access_key_id = oss_conf.get("access_key_id") or os.environ.get("OSS_ACCESS_KEY_ID", "")
              self.oss_access_key_secret = oss_conf.get("access_key_secret") or os.environ.get("OSS_ACCESS_KEY_SECRET", "")
              # Non-secret OSS settings default to the values previously hard-coded here.
              self.oss_endpoint = oss_conf.get("endpoint", "http://oss-cn-shanghai.aliyuncs.com")
              self.oss_bucket_name = oss_conf.get("bucket_name", "cow-agent")
              self.oss_dir = oss_conf.get("dir", "cow")
              # Register the handler only after every attribute it reads exists, so a
              # failed init cannot leave a half-configured handler active.
              self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
              logger.info("[file4upload] inited.")
          except Exception as e:
              logger.warning(f"file4upload init failed: {e}")

      def on_handle_context(self, e_context: EventContext):
          """Track the user's prompt and upload, merging them into one TEXT context.

          TEXT messages update ``previous_prompt``. IMAGE/FILE messages are turned
          into ``last_content``. If no prompt is known yet, the user is asked what
          they want (reply + ``EventAction.BREAK``), and the stored content waits
          for their next text message. Other context types are ignored, as are all
          messages in group chats when ``file_sum.group`` is false.
          """
          context = e_context["context"]
          if context.type not in [ContextType.TEXT, ContextType.SHARING, ContextType.FILE, ContextType.IMAGE]:
              return
          msg: ChatMessage = context["msg"]
          user_id = msg.from_user_id
          if context.get("isgroup", False) and not self.file_sum_group:
              # Group chat with file handling disabled.
              logger.info("群聊消息,文件处理功能已禁用")
              return
          logger.info("on_handle_context: 处理上下文开始")
          # NOTE(review): prepare() presumably downloads the attachment so that
          # context.content is a local path for IMAGE/FILE -- confirm in ChatMessage.
          msg.prepare()
          logger.debug(f"self.params_cache:{self.params_cache}")
          if user_id not in self.params_cache:
              self.params_cache[user_id] = {}
          # Mutated in place, as the original indexing pattern also relied on;
          # assumes ExpiredDict hands back the stored dict itself, not a copy.
          state = self.params_cache[user_id]

          if context.type == ContextType.TEXT:
              state["previous_prompt"] = msg.content
          elif context.type in (ContextType.IMAGE, ContextType.FILE):
              is_image = context.type == ContextType.IMAGE
              file_path = context.content
              logger.info("处理图片" if is_image else "处理文件")
              logger.info(f"on_handle_context: 获取到{'图片' if is_image else '文件'}路径 {file_path}")
              file_content = self._content_from_upload(file_path, is_image)
              if file_content is None:
                  self._reply_text(e_context, "不能处理这张图片" if is_image else "不能处理这份文件")
                  return
              # Store BEFORE asking, so the upload is still there when the user answers.
              state["last_content"] = file_content
              if "previous_prompt" not in state:
                  self._reply_text(
                      e_context,
                      "您刚刚上传了一张图片,请问我有什么可以帮您的呢?" if is_image
                      else "您刚刚上传了一份文件,请问我有什么可以帮您的呢?",
                  )
                  return

          if "previous_prompt" in state and "last_content" in state:
              context.type = ContextType.TEXT
              context.content = state["last_content"] + '\n\t' + state["previous_prompt"]
              logger.info(f'conze4upload 插件处理上传文件或图片')
              e_context.action = EventAction.CONTINUE
              # Drop only THIS user's state; other users may have uploads pending.
              # (Assumes ExpiredDict supports dict.pop, as it supported .clear().)
              self.params_cache.pop(user_id, None)
              logger.debug(f'清空缓存后:{self.params_cache}')

      def _content_from_upload(self, file_path, is_image):
          """Return LLM-ready text for an uploaded file, or None; always deletes the local copy.

          Text comes from ``extract_content_by_llm``. For an image whose text
          cannot be extracted, the image is uploaded with ``upload_oss`` and its
          public URL is returned instead. OSS and file-system errors are logged
          and reported as None.
          """
          try:
              file_content = None
              if self.moonshot_api_key:
                  logger.info('准备抽取文字')
                  file_content = extract_content_by_llm(file_path, self.moonshot_api_key)
              else:
                  logger.warning("moonshot_api_key 未配置,跳过文字抽取")
              if file_content is None and is_image:
                  if not (self.oss_access_key_id and self.oss_access_key_secret):
                      logger.warning("OSS 凭证未配置,无法上传图片")
                      return None
                  logger.info('不能抽取文字,使用图片oss请求LLM')
                  oss_file_name = f"{self.oss_dir}/{os.path.basename(file_path)}"
                  logger.info(f'oss_file_name:{oss_file_name}\n local_file_path :{file_path}')
                  file_content = upload_oss(
                      self.oss_access_key_id,
                      self.oss_access_key_secret,
                      self.oss_endpoint,
                      self.oss_bucket_name,
                      file_path,
                      oss_file_name,
                  )
                  logger.info(f'写入图片缓存oss 地址{file_content}')
              return file_content
          except (oss2.exceptions.OssError, OSError) as e:
              logger.error(f"处理上传内容失败: {e}")
              return None
          finally:
              # Clean up the downloaded copy on every path, including failures.
              self._remove_local_file(file_path)

      @staticmethod
      def _remove_local_file(file_path):
          """Delete a downloaded upload, logging (not raising) if that fails."""
          logger.info('删除文件')
          try:
              os.remove(file_path)
          except OSError as e:
              logger.warning(f"删除文件失败 {file_path}: {e}")

      @staticmethod
      def _reply_text(e_context, text):
          """Answer the user with *text* and stop further handling (EventAction.BREAK)."""
          reply = Reply()
          reply.type = ReplyType.TEXT
          reply.content = text
          e_context["reply"] = reply
          e_context.action = EventAction.BREAK
  def remove_markdown(text):
      """Strip bold markers and ATX heading markers from Markdown text.

      ``**`` is removed everywhere. A heading marker (1-6 ``#`` characters
      followed by spaces/tabs) is removed only at the start of a line.
      Text such as ``"C# is"`` or a ``"# "`` in mid-sentence is therefore left
      intact, while deeper headings such as ``####`` are stripped completely.
      Indentation in front of the marker is kept. Falsy input (``""``/``None``)
      is returned unchanged.
      """
      import re

      if not text:
          return text
      # Remove Markdown bold markers.
      text = text.replace("**", "")
      # Remove heading markers, anchored to line starts so inline '#' survives.
      return re.sub(r"^([ \t]*)#{1,6}[ \t]+", r"\1", text, flags=re.MULTILINE)
  def extract_content_by_llm(file_path: str, api_key: str,
                             base_url: str = "https://api.moonshot.cn/v1",
                             timeout: float = 60) -> "str | None":
      """Extract the text of a local file with Moonshot's file-extract API.

      The file is uploaded to ``{base_url}/files`` with purpose ``file-extract``.
      Its extracted text is then read from ``{base_url}/files/{id}/content``.
      Finally the remote copy is deleted (best effort), so uploads do not pile
      up against the account's file quota.

      :param file_path: local path of the file to extract.
      :param api_key: Moonshot API key, sent as a Bearer token.
      :param base_url: API base URL (defaults to the endpoint previously hard-coded here).
      :param timeout: per-request timeout in seconds, so a stalled API cannot block forever.
      :return: the extracted text, or None if the key is missing, the file cannot
               be read, an HTTP call fails, or a response lacks the expected field.
      """
      logger.info(f'大模型开始抽取文字')
      if not api_key:
          logger.error("Error calling LLM API: missing api_key")
          return None
      headers = {'Authorization': f'Bearer {api_key}'}
      base_url = base_url.rstrip('/')
      file_id = None
      try:
          # Keep the file open only for the duration of the upload request.
          with open(Path(file_path), 'rb') as fh:
              files = {'file': (os.path.basename(file_path), fh)}
              response = requests.post(f"{base_url}/files", headers=headers, files=files,
                                       data={'purpose': 'file-extract'}, timeout=timeout)
          response.raise_for_status()
          file_id = response.json().get('id')
          if not file_id:
              # Without an id the follow-up request would target /files/None/content.
              logger.error(f"Error calling LLM API: upload returned no file id: {response.text}")
              return None
          response = requests.get(f"{base_url}/files/{file_id}/content", headers=headers, timeout=timeout)
          response.raise_for_status()
          content = response.json().get('content')
          logger.debug(f"extracted content: {response.text}")
          return content
      except (requests.exceptions.RequestException, ValueError, OSError) as e:
          # ValueError covers JSON decoding errors on older requests versions;
          # OSError covers an unreadable local file.
          logger.error(f"Error calling LLM API: {e}")
          return None
      finally:
          if file_id:
              try:
                  requests.delete(f"{base_url}/files/{file_id}", headers=headers, timeout=timeout)
              except requests.exceptions.RequestException as e:
                  logger.warning(f"Failed to delete remote file {file_id}: {e}")
  def _ensure_lifecycle_rule(bucket, rule_id, prefix, expiration_days):
      """Install one expiration rule while keeping the bucket's other lifecycle rules."""
      try:
          existing = bucket.get_bucket_lifecycle().rules
      except oss2.exceptions.NoSuchLifecycle:
          existing = []
      for rule in existing:
          if (rule.id == rule_id and rule.prefix == prefix
                  and rule.status == oss2.models.LifecycleRule.ENABLED
                  and rule.expiration is not None
                  and rule.expiration.days == expiration_days):
              return  # already configured; skip the bucket-wide write
      # put_bucket_lifecycle replaces the whole configuration, so resend the other rules too.
      rules = [rule for rule in existing if rule.id != rule_id]
      rules.append(oss2.models.LifecycleRule(
          rule_id, prefix, status=oss2.models.LifecycleRule.ENABLED,
          expiration=oss2.models.LifecycleExpiration(days=expiration_days)))
      bucket.put_bucket_lifecycle(oss2.models.BucketLifecycle(rules))


  def upload_oss(access_key_id, access_key_secret, endpoint, bucket_name, local_file_path, oss_file_name, expiration_days=7):
      """Upload a local file to Aliyun OSS and return its public URL.

      Before uploading, a lifecycle rule is installed (best effort) that expires
      objects under the key's top-level directory after ``expiration_days``.
      Existing lifecycle rules are preserved. A failure to set the rule (e.g.
      missing permission) is logged and does not block the upload. Keys without
      a directory get no rule, because a rule cannot be scoped to them safely.

      :param access_key_id: Aliyun AccessKey ID.
      :param access_key_secret: Aliyun AccessKey Secret.
      :param endpoint: OSS endpoint for the region, with or without an http(s):// scheme.
      :param bucket_name: name of the OSS bucket.
      :param local_file_path: path of the local file to upload.
      :param oss_file_name: object key in OSS, e.g. "dir/name.png".
      :param expiration_days: days after which objects under the key's directory expire (default 7).
      :return: public URL of the object; it is only reachable if the bucket/object is public-read.
      :raises oss2.exceptions.OssError: if the upload itself fails.
      """
      import urllib.parse

      auth = oss2.Auth(access_key_id, access_key_secret)
      bucket = oss2.Bucket(auth, endpoint, bucket_name)

      # 1. Lifecycle rule scoped to the object's top-level directory.
      rule_id = f'delete_after_{expiration_days}_days'
      top_dir, has_dir, _ = oss_file_name.partition('/')
      if has_dir:
          try:
              _ensure_lifecycle_rule(bucket, rule_id, top_dir + '/', expiration_days)
              logger.info(f"已设置生命周期规则:文件将在{expiration_days}天后自动删除")
          except oss2.exceptions.OssError as e:
              logger.warning(f"设置生命周期规则失败,继续上传: {e}")
      else:
          logger.warning(f"{oss_file_name} 不在目录中,未设置生命周期规则")

      # 2. Upload the file.
      bucket.put_object_from_file(oss_file_name, local_file_path)

      # 3. Build the public URL, keeping the endpoint's scheme (http if none given).
      scheme, has_scheme, host = endpoint.partition('://')
      if not has_scheme:
          scheme, host = 'http', endpoint
      file_url = f"{scheme}://{bucket_name}.{host.strip('/')}/{urllib.parse.quote(oss_file_name)}"
      logger.info(f"文件上传成功,公共访问地址:{file_url}")
      return file_url