您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

376 行
17KB

  1. # encoding:utf-8
  2. """
  3. wechat channel
  4. """
  5. import os
  6. import re
  7. import requests
  8. import io
  9. import time
  10. from channel.wechat.wechat_message import *
  11. from common.singleton import singleton
  12. from lib import itchat
  13. import json
  14. from lib.itchat.content import *
  15. from bridge.reply import *
  16. from bridge.context import *
  17. from channel.channel import Channel
  18. from concurrent.futures import ThreadPoolExecutor
  19. from common.log import logger
  20. from common.tmp_dir import TmpDir
  21. from config import conf
  22. from common.time_check import time_checker
  23. from common.expired_dict import ExpiredDict
  24. from plugins import *
  25. try:
  26. from voice.audio_convert import mp3_to_wav
  27. except Exception as e:
  28. pass
  29. thread_pool = ThreadPoolExecutor(max_workers=8)
  30. def thread_pool_callback(worker):
  31. worker_exception = worker.exception()
  32. if worker_exception:
  33. logger.exception("Worker return exception: {}".format(worker_exception))
  34. @itchat.msg_register(TEXT)
  35. def handler_single_msg(msg):
  36. WechatChannel().handle_text(WeChatMessage(msg))
  37. return None
  38. @itchat.msg_register(TEXT, isGroupChat=True)
  39. def handler_group_msg(msg):
  40. WechatChannel().handle_group(WeChatMessage(msg,True))
  41. return None
  42. @itchat.msg_register(VOICE)
  43. def handler_single_voice(msg):
  44. WechatChannel().handle_voice(WeChatMessage(msg))
  45. return None
  46. @itchat.msg_register(VOICE, isGroupChat=True)
  47. def handler_group_voice(msg):
  48. WechatChannel().handle_group_voice(WeChatMessage(msg,True))
  49. return None
  50. def _check(func):
  51. def wrapper(self, cmsg: ChatMessage):
  52. msgId = cmsg.msg_id
  53. if msgId in self.receivedMsgs:
  54. logger.info("Wechat message {} already received, ignore".format(msgId))
  55. return
  56. self.receivedMsgs[msgId] = cmsg
  57. create_time = cmsg.create_time # 消息时间戳
  58. if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
  59. logger.debug("[WX]history message {} skipped".format(msgId))
  60. return
  61. return func(self, cmsg)
  62. return wrapper
  63. @singleton
  64. class WechatChannel(Channel):
  65. def __init__(self):
  66. self.user_id = None
  67. self.name = None
  68. self.receivedMsgs = ExpiredDict(60*60*24)
  69. def startup(self):
  70. itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
  71. # login by scan QRCode
  72. hotReload = conf().get('hot_reload', False)
  73. try:
  74. itchat.auto_login(enableCmdQR=2, hotReload=hotReload)
  75. except Exception as e:
  76. if hotReload:
  77. logger.error("Hot reload failed, try to login without hot reload")
  78. itchat.logout()
  79. os.remove("itchat.pkl")
  80. itchat.auto_login(enableCmdQR=2, hotReload=hotReload)
  81. else:
  82. raise e
  83. self.user_id = itchat.instance.storageClass.userName
  84. self.name = itchat.instance.storageClass.nickName
  85. logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
  86. # start message listener
  87. itchat.run()
  88. # handle_* 系列函数处理收到的消息后构造Context,然后传入handle函数中处理Context和发送回复
  89. # Context包含了消息的所有信息,包括以下属性
  90. # type 消息类型, 包括TEXT、VOICE、IMAGE_CREATE
  91. # content 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是IMAGE_CREATE类型,content就是图片生成命令
  92. # kwargs 附加参数字典,包含以下的key:
  93. # session_id: 会话id
  94. # isgroup: 是否是群聊
  95. # receiver: 需要回复的对象
  96. # msg: itchat的原始消息对象
  97. # origin_ctype: 原始消息类型,语音转文字后,私聊时如果匹配前缀失败,会根据初始消息是否是语音来放宽触发规则
  98. # desire_rtype: 希望回复类型,默认是文本回复,设置为ReplyType.VOICE是语音回复
  99. @time_checker
  100. @_check
  101. def handle_voice(self, cmsg : ChatMessage):
  102. if conf().get('speech_recognition') != True:
  103. return
  104. logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
  105. from_user_id = cmsg.from_user_id
  106. other_user_id = cmsg.other_user_id
  107. if from_user_id == other_user_id:
  108. context = self._compose_context(ContextType.VOICE, cmsg.content, isgroup=False, msg=cmsg, receiver=other_user_id, session_id=other_user_id)
  109. if context:
  110. thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
  111. @time_checker
  112. @_check
  113. def handle_text(self, cmsg : ChatMessage):
  114. logger.debug("[WX]receive text msg: " + json.dumps(cmsg._rawmsg, ensure_ascii=False))
  115. content = cmsg.content
  116. other_user_id = cmsg.other_user_id
  117. if "」\n- - - - - - - - - - - - - - -" in cmsg.content:
  118. logger.debug("[WX]reference query skipped")
  119. return
  120. context = self._compose_context(ContextType.TEXT, content, isgroup=False, msg=cmsg, receiver=other_user_id, session_id=other_user_id)
  121. if context:
  122. thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
  123. @time_checker
  124. @_check
  125. def handle_group(self, cmsg : ChatMessage):
  126. logger.debug("[WX]receive group msg: " + json.dumps(cmsg._rawmsg, ensure_ascii=False))
  127. group_name = cmsg.other_user_nickname
  128. group_id = cmsg.other_user_id
  129. if not group_name:
  130. return ""
  131. content = cmsg.content
  132. if "」\n- - - - - - - - - - - - - - -" in content:
  133. logger.debug("[WX]reference query skipped")
  134. return ""
  135. pattern = f'@{self.name}(\u2005|\u0020)'
  136. content = re.sub(pattern, r'', content)
  137. config = conf()
  138. group_name_white_list = config.get('group_name_white_list', [])
  139. group_name_keyword_white_list = config.get('group_name_keyword_white_list', [])
  140. if any([group_name in group_name_white_list, 'ALL_GROUP' in group_name_white_list, check_contain(group_name, group_name_keyword_white_list)]):
  141. group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
  142. session_id = cmsg.actual_user_id
  143. if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]):
  144. session_id = group_id
  145. context = self._compose_context(ContextType.TEXT, content, isgroup=True, msg=cmsg, receiver=group_id, session_id=session_id)
  146. if context:
  147. thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
  148. @time_checker
  149. @_check
  150. def handle_group_voice(self, cmsg : ChatMessage):
  151. if conf().get('group_speech_recognition', False) != True:
  152. return
  153. logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
  154. group_name = cmsg.other_user_nickname
  155. group_id = cmsg.other_user_id
  156. # 验证群名
  157. if not group_name:
  158. return ""
  159. config = conf()
  160. group_name_white_list = config.get('group_name_white_list', [])
  161. group_name_keyword_white_list = config.get('group_name_keyword_white_list', [])
  162. if any([group_name in group_name_white_list, 'ALL_GROUP' in group_name_white_list, check_contain(group_name, group_name_keyword_white_list)]):
  163. group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
  164. session_id = cmsg.actual_user_id
  165. if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]):
  166. session_id = group_id
  167. context = self._compose_context(ContextType.VOICE, cmsg.content, isgroup=True, msg=cmsg, receiver=group_id, session_id=session_id)
  168. if context:
  169. thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
  170. # 根据消息构造context,消息内容相关的触发项写在这里
  171. def _compose_context(self, ctype: ContextType, content, **kwargs):
  172. context = Context(ctype, content)
  173. context.kwargs = kwargs
  174. if 'origin_ctype' not in context:
  175. context['origin_ctype'] = ctype
  176. if ctype == ContextType.TEXT:
  177. if context["isgroup"]: # 群聊
  178. # 校验关键字
  179. match_prefix = check_prefix(content, conf().get('group_chat_prefix'))
  180. match_contain = check_contain(content, conf().get('group_chat_keyword'))
  181. if match_prefix is not None or match_contain is not None:
  182. # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容,用于实现类似自定义+前缀触发生成AI图片的功能
  183. if match_prefix:
  184. content = content.replace(match_prefix, '', 1).strip()
  185. elif context['msg'].is_at and not conf().get("group_at_off", False):
  186. logger.info("[WX]receive group at, continue")
  187. elif context["origin_ctype"] == ContextType.VOICE:
  188. logger.info("[WX]receive group voice, checkprefix didn't match")
  189. return None
  190. else:
  191. return None
  192. else: # 单聊
  193. match_prefix = check_prefix(content, conf().get('single_chat_prefix'))
  194. if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
  195. content = content.replace(match_prefix, '', 1).strip()
  196. elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
  197. pass
  198. else:
  199. return None
  200. img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
  201. if img_match_prefix:
  202. content = content.replace(img_match_prefix, '', 1).strip()
  203. context.type = ContextType.IMAGE_CREATE
  204. else:
  205. context.type = ContextType.TEXT
  206. context.content = content
  207. elif context.type == ContextType.VOICE:
  208. if 'desire_rtype' not in context and conf().get('voice_reply_voice'):
  209. context['desire_rtype'] = ReplyType.VOICE
  210. return context
  211. # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
  212. def send(self, reply: Reply, receiver, retry_cnt = 0):
  213. try:
  214. if reply.type == ReplyType.TEXT:
  215. itchat.send(reply.content, toUserName=receiver)
  216. logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
  217. elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
  218. itchat.send(reply.content, toUserName=receiver)
  219. logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
  220. elif reply.type == ReplyType.VOICE:
  221. itchat.send_file(reply.content, toUserName=receiver)
  222. logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver))
  223. elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
  224. img_url = reply.content
  225. pic_res = requests.get(img_url, stream=True)
  226. image_storage = io.BytesIO()
  227. for block in pic_res.iter_content(1024):
  228. image_storage.write(block)
  229. image_storage.seek(0)
  230. itchat.send_image(image_storage, toUserName=receiver)
  231. logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver))
  232. elif reply.type == ReplyType.IMAGE: # 从文件读取图片
  233. image_storage = reply.content
  234. image_storage.seek(0)
  235. itchat.send_image(image_storage, toUserName=receiver)
  236. logger.info('[WX] sendImage, receiver={}'.format(receiver))
  237. except Exception as e:
  238. logger.error('[WX] sendMsg error: {}, receiver={}'.format(e, receiver))
  239. if retry_cnt < 2:
  240. time.sleep(3+3*retry_cnt)
  241. self.send(reply, receiver, retry_cnt + 1)
  242. # 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类
  243. def handle(self, context: Context):
  244. if context is None or not context.content:
  245. return
  246. logger.debug('[WX] ready to handle context: {}'.format(context))
  247. # reply的构建步骤
  248. reply = self._generate_reply(context)
  249. logger.debug('[WX] ready to decorate reply: {}'.format(reply))
  250. # reply的包装步骤
  251. reply = self._decorate_reply(context, reply)
  252. # reply的发送步骤
  253. self._send_reply(context, reply)
  254. def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
  255. e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
  256. 'channel': self, 'context': context, 'reply': reply}))
  257. reply = e_context['reply']
  258. if not e_context.is_pass():
  259. logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content))
  260. if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
  261. reply = super().build_reply_content(context.content, context)
  262. elif context.type == ContextType.VOICE: # 语音消息
  263. msg = context['msg']
  264. msg.prepare()
  265. mp3_path = context.content
  266. # mp3转wav
  267. wav_path = os.path.splitext(mp3_path)[0] + '.wav'
  268. try:
  269. mp3_to_wav(mp3_path=mp3_path, wav_path=wav_path)
  270. except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别
  271. logger.warning("[WX]mp3 to wav error, use mp3 path. " + str(e))
  272. wav_path = mp3_path
  273. # 语音识别
  274. reply = super().build_voice_to_text(wav_path)
  275. # 删除临时文件
  276. try:
  277. os.remove(wav_path)
  278. os.remove(mp3_path)
  279. except Exception as e:
  280. logger.warning("[WX]delete temp file error: " + str(e))
  281. if reply.type == ReplyType.TEXT:
  282. new_context = self._compose_context(
  283. ContextType.TEXT, reply.content, **context.kwargs)
  284. if new_context:
  285. reply = self._generate_reply(new_context)
  286. else:
  287. return
  288. else:
  289. logger.error('[WX] unknown context type: {}'.format(context.type))
  290. return
  291. return reply
  292. def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
  293. if reply and reply.type:
  294. e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {
  295. 'channel': self, 'context': context, 'reply': reply}))
  296. reply = e_context['reply']
  297. desire_rtype = context.get('desire_rtype')
  298. if not e_context.is_pass() and reply and reply.type:
  299. if reply.type == ReplyType.TEXT:
  300. reply_text = reply.content
  301. if desire_rtype == ReplyType.VOICE:
  302. reply = super().build_text_to_voice(reply.content)
  303. return self._decorate_reply(context, reply)
  304. if context['isgroup']:
  305. reply_text = '@' + context['msg'].actual_user_nickname + ' ' + reply_text.strip()
  306. reply_text = conf().get("group_chat_reply_prefix", "")+reply_text
  307. else:
  308. reply_text = conf().get("single_chat_reply_prefix", "")+reply_text
  309. reply.content = reply_text
  310. elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
  311. reply.content = str(reply.type)+":\n" + reply.content
  312. elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
  313. pass
  314. else:
  315. logger.error('[WX] unknown reply type: {}'.format(reply.type))
  316. return
  317. if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
  318. logger.warning('[WX] desire_rtype: {}, but reply type: {}'.format(context.get('desire_rtype'), reply.type))
  319. return reply
  320. def _send_reply(self, context: Context, reply: Reply):
  321. if reply and reply.type:
  322. e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {
  323. 'channel': self, 'context': context, 'reply': reply}))
  324. reply = e_context['reply']
  325. if not e_context.is_pass() and reply and reply.type:
  326. logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver']))
  327. self.send(reply, context['receiver'])
  328. def check_prefix(content, prefix_list):
  329. for prefix in prefix_list:
  330. if content.startswith(prefix):
  331. return prefix
  332. return None
  333. def check_contain(content, keyword_list):
  334. if not keyword_list:
  335. return None
  336. for ky in keyword_list:
  337. if content.find(ky) != -1:
  338. return True
  339. return None