You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

310 lines
12KB

import functools
import io
import json
import os
import random
import re
import tempfile
import threading
import time
import uuid

# Must stay ahead of ``import ntwork`` (presumably ntwork reads it on import).
os.environ['ntwork_LOG'] = "ERROR"

import ntwork
import requests
from PIL import Image

from bridge.context import *
from bridge.reply import *
from channel.chat_channel import ChatChannel
from channel.wework.wework_message import *
from channel.wework.wework_message import WeworkMessage
from common.singleton import singleton
from common.log import logger
from common.time_check import time_checker
from config import conf
from channel.wework.run import wework
from channel.wework import run
  22. def get_wxid_by_name(room_members, group_wxid, name):
  23. if group_wxid in room_members:
  24. for member in room_members[group_wxid]['member_list']:
  25. if member['room_nickname'] == name or member['username'] == name:
  26. return member['user_id']
  27. return None # 如果没有找到对应的group_wxid或name,则返回None
  28. def download_and_compress_image(url, filename, quality=30):
  29. # 确定保存图片的目录
  30. directory = os.path.join(os.getcwd(), "tmp")
  31. # 如果目录不存在,则创建目录
  32. if not os.path.exists(directory):
  33. os.makedirs(directory)
  34. # 下载图片
  35. response = requests.get(url)
  36. image = Image.open(io.BytesIO(response.content))
  37. # 压缩图片
  38. image_path = os.path.join(directory, f"{filename}.jpg")
  39. image.save(image_path, "JPEG", quality=quality)
  40. return image_path
  41. def download_video(url, filename):
  42. # 确定保存视频的目录
  43. directory = os.path.join(os.getcwd(), "tmp")
  44. # 如果目录不存在,则创建目录
  45. if not os.path.exists(directory):
  46. os.makedirs(directory)
  47. # 下载视频
  48. response = requests.get(url, stream=True)
  49. total_size = 0
  50. video_path = os.path.join(directory, f"{filename}.mp4")
  51. with open(video_path, 'wb') as f:
  52. for block in response.iter_content(1024):
  53. total_size += len(block)
  54. # 如果视频的总大小超过30MB (30 * 1024 * 1024 bytes),则停止下载并返回
  55. if total_size > 30 * 1024 * 1024:
  56. logger.info("[WX] Video is larger than 30MB, skipping...")
  57. return None
  58. f.write(block)
  59. return video_path
  60. def create_message(wework_instance, message, is_group):
  61. logger.debug(f"正在为{'群聊' if is_group else '单聊'}创建 WeworkMessage")
  62. cmsg = WeworkMessage(message, wework=wework_instance, is_group=is_group)
  63. logger.debug(f"cmsg:{cmsg}")
  64. return cmsg
  65. def handle_message(cmsg, is_group):
  66. logger.debug(f"准备用 WeworkChannel 处理{'群聊' if is_group else '单聊'}消息")
  67. if is_group:
  68. WeworkChannel().handle_group(cmsg)
  69. else:
  70. WeworkChannel().handle_single(cmsg)
  71. logger.debug(f"已用 WeworkChannel 处理完{'群聊' if is_group else '单聊'}消息")
  72. def _check(func):
  73. def wrapper(self, cmsg: ChatMessage):
  74. msgId = cmsg.msg_id
  75. create_time = cmsg.create_time # 消息时间戳
  76. if create_time is None:
  77. return func(self, cmsg)
  78. if int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
  79. logger.debug("[WX]history message {} skipped".format(msgId))
  80. return
  81. return func(self, cmsg)
  82. return wrapper
  83. @wework.msg_register(
  84. [ntwork.MT_RECV_TEXT_MSG, ntwork.MT_RECV_IMAGE_MSG, 11072, ntwork.MT_RECV_VOICE_MSG])
  85. def all_msg_handler(wework_instance: ntwork.WeWork, message):
  86. logger.debug(f"收到消息: {message}")
  87. if 'data' in message:
  88. # 首先查找conversation_id,如果没有找到,则查找room_conversation_id
  89. conversation_id = message['data'].get('conversation_id', message['data'].get('room_conversation_id'))
  90. if conversation_id is not None:
  91. is_group = "R:" in conversation_id
  92. try:
  93. cmsg = create_message(wework_instance=wework_instance, message=message, is_group=is_group)
  94. except NotImplementedError as e:
  95. logger.error(f"[WX]{message.get('MsgId', 'unknown')} 跳过: {e}")
  96. return None
  97. delay = random.randint(1, 2)
  98. timer = threading.Timer(delay, handle_message, args=(cmsg, is_group))
  99. timer.start()
  100. else:
  101. logger.debug("消息数据中无 conversation_id")
  102. return None
  103. return None
  104. def accept_friend_with_retries(wework_instance, user_id, corp_id):
  105. result = wework_instance.accept_friend(user_id, corp_id)
  106. logger.debug(f'result:{result}')
  107. # @wework.msg_register(ntwork.MT_RECV_FRIEND_MSG)
  108. # def friend(wework_instance: ntwork.WeWork, message):
  109. # data = message["data"]
  110. # user_id = data["user_id"]
  111. # corp_id = data["corp_id"]
  112. # logger.info(f"接收到好友请求,消息内容:{data}")
  113. # delay = random.randint(1, 180)
  114. # threading.Timer(delay, accept_friend_with_retries, args=(wework_instance, user_id, corp_id)).start()
  115. #
  116. # return None
  117. def get_with_retry(get_func, max_retries=5, delay=5):
  118. retries = 0
  119. result = None
  120. while retries < max_retries:
  121. result = get_func()
  122. if result:
  123. break
  124. logger.warning(f"获取数据失败,重试第{retries + 1}次······")
  125. retries += 1
  126. time.sleep(delay) # 等待一段时间后重试
  127. return result
  128. @singleton
  129. class WeworkChannel(ChatChannel):
  130. NOT_SUPPORT_REPLYTYPE = []
  131. def __init__(self):
  132. super().__init__()
  133. def startup(self):
  134. smart = conf().get("wework_smart", True)
  135. wework.open(smart)
  136. logger.info("等待登录······")
  137. wework.wait_login()
  138. login_info = wework.get_login_info()
  139. self.user_id = login_info['user_id']
  140. self.name = login_info['nickname']
  141. logger.info(f"登录信息:>>>user_id:{self.user_id}>>>>>>>>name:{self.name}")
  142. logger.info("静默延迟60s,等待客户端刷新数据,请勿进行任何操作······")
  143. time.sleep(60)
  144. contacts = get_with_retry(wework.get_external_contacts)
  145. rooms = get_with_retry(wework.get_rooms)
  146. directory = os.path.join(os.getcwd(), "tmp")
  147. if not contacts or not rooms:
  148. logger.error("获取contacts或rooms失败,程序退出")
  149. ntwork.exit_()
  150. os.exit(0)
  151. if not os.path.exists(directory):
  152. os.makedirs(directory)
  153. # 将contacts保存到json文件中
  154. with open(os.path.join(directory, 'wework_contacts.json'), 'w', encoding='utf-8') as f:
  155. json.dump(contacts, f, ensure_ascii=False, indent=4)
  156. with open(os.path.join(directory, 'wework_rooms.json'), 'w', encoding='utf-8') as f:
  157. json.dump(rooms, f, ensure_ascii=False, indent=4)
  158. # 创建一个空字典来保存结果
  159. result = {}
  160. # 遍历列表中的每个字典
  161. for room in rooms['room_list']:
  162. # 获取聊天室ID
  163. room_wxid = room['conversation_id']
  164. # 获取聊天室成员
  165. room_members = wework.get_room_members(room_wxid)
  166. # 将聊天室成员保存到结果字典中
  167. result[room_wxid] = room_members
  168. # 将结果保存到json文件中
  169. with open(os.path.join(directory, 'wework_room_members.json'), 'w', encoding='utf-8') as f:
  170. json.dump(result, f, ensure_ascii=False, indent=4)
  171. logger.info("wework程序初始化完成········")
  172. run.forever()
  173. @time_checker
  174. @_check
  175. def handle_single(self, cmsg: ChatMessage):
  176. if cmsg.from_user_id == cmsg.to_user_id:
  177. # ignore self reply
  178. return
  179. if cmsg.ctype == ContextType.VOICE:
  180. if not conf().get("speech_recognition"):
  181. return
  182. logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
  183. elif cmsg.ctype == ContextType.IMAGE:
  184. logger.debug("[WX]receive image msg: {}".format(cmsg.content))
  185. elif cmsg.ctype == ContextType.PATPAT:
  186. logger.debug("[WX]receive patpat msg: {}".format(cmsg.content))
  187. elif cmsg.ctype == ContextType.TEXT:
  188. logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
  189. else:
  190. logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg))
  191. context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
  192. if context:
  193. self.produce(context)
  194. @time_checker
  195. @_check
  196. def handle_group(self, cmsg: ChatMessage):
  197. if cmsg.ctype == ContextType.VOICE:
  198. if not conf().get("speech_recognition"):
  199. return
  200. logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
  201. elif cmsg.ctype == ContextType.IMAGE:
  202. logger.debug("[WX]receive image for group msg: {}".format(cmsg.content))
  203. elif cmsg.ctype in [ContextType.JOIN_GROUP, ContextType.PATPAT]:
  204. logger.debug("[WX]receive note msg: {}".format(cmsg.content))
  205. elif cmsg.ctype == ContextType.TEXT:
  206. pass
  207. else:
  208. logger.debug("[WX]receive group msg: {}".format(cmsg.content))
  209. context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
  210. if context:
  211. self.produce(context)
  212. # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
  213. def send(self, reply: Reply, context: Context):
  214. logger.debug(f"context: {context}")
  215. receiver = context["receiver"]
  216. actual_user_id = context["msg"].actual_user_id
  217. if reply.type == ReplyType.TEXT or reply.type == ReplyType.TEXT_:
  218. match = re.search(r"^@(.*?)\n", reply.content)
  219. logger.debug(f"match: {match}")
  220. if match:
  221. new_content = re.sub(r"^@(.*?)\n", "\n", reply.content)
  222. at_list = [actual_user_id]
  223. logger.debug(f"new_content: {new_content}")
  224. wework.send_room_at_msg(receiver, new_content, at_list)
  225. else:
  226. wework.send_text(receiver, reply.content)
  227. logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
  228. elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
  229. wework.send_text(receiver, reply.content)
  230. logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
  231. elif reply.type == ReplyType.IMAGE: # 从文件读取图片
  232. image_storage = reply.content
  233. image_storage.seek(0)
  234. # Read data from image_storage
  235. data = image_storage.read()
  236. # Create a temporary file
  237. with tempfile.NamedTemporaryFile(delete=False) as temp:
  238. temp_path = temp.name
  239. temp.write(data)
  240. # Send the image
  241. wework.send_image(receiver, temp_path)
  242. logger.info("[WX] sendImage, receiver={}".format(receiver))
  243. # Remove the temporary file
  244. os.remove(temp_path)
  245. elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
  246. img_url = reply.content
  247. filename = str(uuid.uuid4())
  248. # 调用你的函数,下载图片并保存为本地文件
  249. image_path = download_and_compress_image(img_url, filename)
  250. wework.send_image(receiver, file_path=image_path)
  251. logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
  252. elif reply.type == ReplyType.VIDEO_URL:
  253. video_url = reply.content
  254. filename = str(uuid.uuid4())
  255. video_path = download_video(video_url, filename)
  256. if video_path is None:
  257. # 如果视频太大,下载可能会被跳过,此时 video_path 将为 None
  258. wework.send_text(receiver, "抱歉,视频太大了!!!")
  259. else:
  260. wework.send_video(receiver, video_path)
  261. logger.info("[WX] sendVideo, receiver={}".format(receiver))
  262. elif reply.type == ReplyType.VOICE:
  263. wework.send_file(receiver, reply.content)
  264. logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver))