You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

391 lines
18KB

  1. # encoding:utf-8
  2. """
  3. wechat channel
  4. """
  5. import os
  6. import re
  7. import requests
  8. import io
  9. import time
  10. from common.singleton import singleton
  11. from lib import itchat
  12. import json
  13. from lib.itchat.content import *
  14. from bridge.reply import *
  15. from bridge.context import *
  16. from channel.channel import Channel
  17. from concurrent.futures import ThreadPoolExecutor
  18. from common.log import logger
  19. from common.tmp_dir import TmpDir
  20. from config import conf
  21. from common.time_check import time_checker
  22. from common.expired_dict import ExpiredDict
  23. from plugins import *
  24. try:
  25. from voice.audio_convert import mp3_to_wav
  26. except Exception as e:
  27. pass
  28. thread_pool = ThreadPoolExecutor(max_workers=8)
  29. def thread_pool_callback(worker):
  30. worker_exception = worker.exception()
  31. if worker_exception:
  32. logger.exception("Worker return exception: {}".format(worker_exception))
  33. @itchat.msg_register(TEXT)
  34. def handler_single_msg(msg):
  35. WechatChannel().handle_text(msg)
  36. return None
  37. @itchat.msg_register(TEXT, isGroupChat=True)
  38. def handler_group_msg(msg):
  39. WechatChannel().handle_group(msg)
  40. return None
  41. @itchat.msg_register(VOICE)
  42. def handler_single_voice(msg):
  43. WechatChannel().handle_voice(msg)
  44. return None
  45. @itchat.msg_register(VOICE, isGroupChat=True)
  46. def handler_group_voice(msg):
  47. WechatChannel().handle_group_voice(msg)
  48. return None
  49. def _check(func):
  50. def wrapper(self, msg):
  51. msgId = msg['MsgId']
  52. if msgId in self.receivedMsgs:
  53. logger.info("Wechat message {} already received, ignore".format(msgId))
  54. return
  55. self.receivedMsgs[msgId] = msg
  56. create_time = msg['CreateTime'] # 消息时间
  57. if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
  58. logger.debug("[WX]history message {} skipped".format(msgId))
  59. return
  60. return func(self, msg)
  61. return wrapper
  62. @singleton
  63. class WechatChannel(Channel):
  64. def __init__(self):
  65. self.user_id = None
  66. self.name = None
  67. self.receivedMsgs = ExpiredDict(60*60*24)
  68. def startup(self):
  69. itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
  70. # login by scan QRCode
  71. hotReload = conf().get('hot_reload', False)
  72. try:
  73. itchat.auto_login(enableCmdQR=2, hotReload=hotReload)
  74. except Exception as e:
  75. if hotReload:
  76. logger.error("Hot reload failed, try to login without hot reload")
  77. itchat.logout()
  78. os.remove("itchat.pkl")
  79. itchat.auto_login(enableCmdQR=2, hotReload=hotReload)
  80. else:
  81. raise e
  82. self.user_id = itchat.instance.storageClass.userName
  83. self.name = itchat.instance.storageClass.nickName
  84. logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
  85. # start message listener
  86. itchat.run()
  87. # handle_* 系列函数处理收到的消息后构造Context,然后传入handle函数中处理Context和发送回复
  88. # Context包含了消息的所有信息,包括以下属性
  89. # type 消息类型, 包括TEXT、VOICE、IMAGE_CREATE
  90. # content 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是IMAGE_CREATE类型,content就是图片生成命令
  91. # kwargs 附加参数字典,包含以下的key:
  92. # session_id: 会话id
  93. # isgroup: 是否是群聊
  94. # receiver: 需要回复的对象
  95. # msg: itchat的原始消息对象
  96. # origin_ctype: 原始消息类型,语音转文字后,私聊时如果匹配前缀失败,会根据初始消息是否是语音来放宽触发规则
  97. # desire_rtype: 希望回复类型,默认是文本回复,设置为ReplyType.VOICE是语音回复
  98. @time_checker
  99. @_check
  100. def handle_voice(self, msg):
  101. if conf().get('speech_recognition') != True:
  102. return
  103. logger.debug("[WX]receive voice msg: " + msg['FileName'])
  104. to_user_id = msg['ToUserName']
  105. from_user_id = msg['FromUserName']
  106. try:
  107. other_user_id = msg['User']['UserName'] # 对手方id
  108. except Exception as e:
  109. logger.warn("[WX]get other_user_id failed: " + str(e))
  110. if from_user_id == self.userName:
  111. other_user_id = to_user_id
  112. else:
  113. other_user_id = from_user_id
  114. if from_user_id == other_user_id:
  115. context = self._compose_context(ContextType.VOICE, msg['FileName'], isgroup=False, msg=msg, receiver=other_user_id, session_id=other_user_id)
  116. if context:
  117. thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
  118. @time_checker
  119. @_check
  120. def handle_text(self, msg):
  121. logger.debug("[WX]receive text msg: " + json.dumps(msg, ensure_ascii=False))
  122. content = msg['Text']
  123. from_user_id = msg['FromUserName']
  124. to_user_id = msg['ToUserName'] # 接收人id
  125. try:
  126. other_user_id = msg['User']['UserName'] # 对手方id
  127. except Exception as e:
  128. logger.warn("[WX]get other_user_id failed: " + str(e))
  129. if from_user_id == self.userName:
  130. other_user_id = to_user_id
  131. else:
  132. other_user_id = from_user_id
  133. if "」\n- - - - - - - - - - - - - - -" in content:
  134. logger.debug("[WX]reference query skipped")
  135. return
  136. context = self._compose_context(ContextType.TEXT, content, isgroup=False, msg=msg, receiver=other_user_id, session_id=other_user_id)
  137. if context:
  138. thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
  139. @time_checker
  140. @_check
  141. def handle_group(self, msg):
  142. logger.debug("[WX]receive group msg: " + json.dumps(msg, ensure_ascii=False))
  143. group_name = msg['User'].get('NickName', None)
  144. group_id = msg['User'].get('UserName', None)
  145. if not group_name:
  146. return ""
  147. content = msg.content
  148. if "」\n- - - - - - - - - - - - - - -" in content:
  149. logger.debug("[WX]reference query skipped")
  150. return ""
  151. pattern = f'@{self.name}(\u2005|\u0020)'
  152. content = re.sub(pattern, r'', content)
  153. config = conf()
  154. group_name_white_list = config.get('group_name_white_list', [])
  155. group_name_keyword_white_list = config.get('group_name_keyword_white_list', [])
  156. if any([group_name in group_name_white_list, 'ALL_GROUP' in group_name_white_list, check_contain(group_name, group_name_keyword_white_list)]):
  157. group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
  158. session_id = msg['ActualUserName']
  159. if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]):
  160. session_id = group_id
  161. context = self._compose_context(ContextType.TEXT, content, isgroup=True, msg=msg, receiver=group_id, session_id=session_id)
  162. if context:
  163. thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
  164. @time_checker
  165. @_check
  166. def handle_group_voice(self, msg):
  167. if conf().get('group_speech_recognition', False) != True:
  168. return
  169. logger.debug("[WX]receive voice for group msg: " + msg['FileName'])
  170. group_name = msg['User'].get('NickName', None)
  171. group_id = msg['User'].get('UserName', None)
  172. # 验证群名
  173. if not group_name:
  174. return ""
  175. config = conf()
  176. group_name_white_list = config.get('group_name_white_list', [])
  177. group_name_keyword_white_list = config.get('group_name_keyword_white_list', [])
  178. if any([group_name in group_name_white_list, 'ALL_GROUP' in group_name_white_list, check_contain(group_name, group_name_keyword_white_list)]):
  179. group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
  180. session_id =msg['ActualUserName']
  181. if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]):
  182. session_id = group_id
  183. context = self._compose_context(ContextType.VOICE, msg['FileName'], isgroup=True, msg=msg, receiver=group_id, session_id=session_id)
  184. if context:
  185. thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
  186. # 根据消息构造context,消息内容相关的触发项写在这里
  187. def _compose_context(self, ctype: ContextType, content, **kwargs):
  188. context = Context(ctype, content)
  189. context.kwargs = kwargs
  190. if 'origin_ctype' not in context:
  191. context['origin_ctype'] = ctype
  192. if ctype == ContextType.TEXT:
  193. if context["isgroup"]: # 群聊
  194. # 校验关键字
  195. match_prefix = check_prefix(content, conf().get('group_chat_prefix'))
  196. match_contain = check_contain(content, conf().get('group_chat_keyword'))
  197. if match_prefix is not None or match_contain is not None:
  198. # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容,用于实现类似自定义+前缀触发生成AI图片的功能
  199. if match_prefix:
  200. content = content.replace(match_prefix, '', 1).strip()
  201. elif context['msg']['IsAt'] and not conf().get("group_at_off", False):
  202. logger.info("[WX]receive group at, continue")
  203. elif context["origin_ctype"] == ContextType.VOICE:
  204. logger.info("[WX]receive group voice, checkprefix didn't match")
  205. return None
  206. else:
  207. return None
  208. else: # 单聊
  209. match_prefix = check_prefix(content, conf().get('single_chat_prefix'))
  210. if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
  211. content = content.replace(match_prefix, '', 1).strip()
  212. elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
  213. pass
  214. else:
  215. return None
  216. img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
  217. if img_match_prefix:
  218. content = content.replace(img_match_prefix, '', 1).strip()
  219. context.type = ContextType.IMAGE_CREATE
  220. else:
  221. context.type = ContextType.TEXT
  222. context.content = content
  223. elif context.type == ContextType.VOICE:
  224. if 'desire_rtype' not in context and conf().get('voice_reply_voice'):
  225. context['desire_rtype'] = ReplyType.VOICE
  226. return context
  227. # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
  228. def send(self, reply: Reply, receiver, retry_cnt = 0):
  229. try:
  230. if reply.type == ReplyType.TEXT:
  231. itchat.send(reply.content, toUserName=receiver)
  232. logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
  233. elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
  234. itchat.send(reply.content, toUserName=receiver)
  235. logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
  236. elif reply.type == ReplyType.VOICE:
  237. itchat.send_file(reply.content, toUserName=receiver)
  238. logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver))
  239. elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
  240. img_url = reply.content
  241. pic_res = requests.get(img_url, stream=True)
  242. image_storage = io.BytesIO()
  243. for block in pic_res.iter_content(1024):
  244. image_storage.write(block)
  245. image_storage.seek(0)
  246. itchat.send_image(image_storage, toUserName=receiver)
  247. logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver))
  248. elif reply.type == ReplyType.IMAGE: # 从文件读取图片
  249. image_storage = reply.content
  250. image_storage.seek(0)
  251. itchat.send_image(image_storage, toUserName=receiver)
  252. logger.info('[WX] sendImage, receiver={}'.format(receiver))
  253. except Exception as e:
  254. logger.error('[WX] sendMsg error: {}, receiver={}'.format(e, receiver))
  255. if retry_cnt < 2:
  256. time.sleep(3+3*retry_cnt)
  257. self.send(reply, receiver, retry_cnt + 1)
  258. # 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类
  259. def handle(self, context: Context):
  260. if context is None or not context.content:
  261. return
  262. logger.debug('[WX] ready to handle context: {}'.format(context))
  263. # reply的构建步骤
  264. reply = self._generate_reply(context)
  265. logger.debug('[WX] ready to decorate reply: {}'.format(reply))
  266. # reply的包装步骤
  267. reply = self._decorate_reply(context, reply)
  268. # reply的发送步骤
  269. self._send_reply(context, reply)
  270. def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
  271. e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
  272. 'channel': self, 'context': context, 'reply': reply}))
  273. reply = e_context['reply']
  274. if not e_context.is_pass():
  275. logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content))
  276. if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
  277. reply = super().build_reply_content(context.content, context)
  278. elif context.type == ContextType.VOICE: # 语音消息
  279. msg = context['msg']
  280. mp3_path = TmpDir().path() + context.content
  281. msg.download(mp3_path)
  282. # mp3转wav
  283. wav_path = os.path.splitext(mp3_path)[0] + '.wav'
  284. try:
  285. mp3_to_wav(mp3_path=mp3_path, wav_path=wav_path)
  286. except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别
  287. logger.warning("[WX]mp3 to wav error, use mp3 path. " + str(e))
  288. wav_path = mp3_path
  289. # 语音识别
  290. reply = super().build_voice_to_text(wav_path)
  291. # 删除临时文件
  292. try:
  293. os.remove(wav_path)
  294. os.remove(mp3_path)
  295. except Exception as e:
  296. logger.warning("[WX]delete temp file error: " + str(e))
  297. if reply.type == ReplyType.TEXT:
  298. new_context = self._compose_context(
  299. ContextType.TEXT, reply.content, **context.kwargs)
  300. if new_context:
  301. reply = self._generate_reply(new_context)
  302. else:
  303. return
  304. else:
  305. logger.error('[WX] unknown context type: {}'.format(context.type))
  306. return
  307. return reply
  308. def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
  309. if reply and reply.type:
  310. e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {
  311. 'channel': self, 'context': context, 'reply': reply}))
  312. reply = e_context['reply']
  313. desire_rtype = context.get('desire_rtype')
  314. if not e_context.is_pass() and reply and reply.type:
  315. if reply.type == ReplyType.TEXT:
  316. reply_text = reply.content
  317. if desire_rtype == ReplyType.VOICE:
  318. reply = super().build_text_to_voice(reply.content)
  319. return self._decorate_reply(context, reply)
  320. if context['isgroup']:
  321. reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip()
  322. reply_text = conf().get("group_chat_reply_prefix", "")+reply_text
  323. else:
  324. reply_text = conf().get("single_chat_reply_prefix", "")+reply_text
  325. reply.content = reply_text
  326. elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
  327. reply.content = str(reply.type)+":\n" + reply.content
  328. elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
  329. pass
  330. else:
  331. logger.error('[WX] unknown reply type: {}'.format(reply.type))
  332. return
  333. if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
  334. logger.warning('[WX] desire_rtype: {}, but reply type: {}'.format(context.get('desire_rtype'), reply.type))
  335. return reply
  336. def _send_reply(self, context: Context, reply: Reply):
  337. if reply and reply.type:
  338. e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {
  339. 'channel': self, 'context': context, 'reply': reply}))
  340. reply = e_context['reply']
  341. if not e_context.is_pass() and reply and reply.type:
  342. logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver']))
  343. self.send(reply, context['receiver'])
  344. def check_prefix(content, prefix_list):
  345. for prefix in prefix_list:
  346. if content.startswith(prefix):
  347. return prefix
  348. return None
  349. def check_contain(content, keyword_list):
  350. if not keyword_list:
  351. return None
  352. for ky in keyword_list:
  353. if content.find(ky) != -1:
  354. return True
  355. return None