Você não pode selecionar mais de 25 tópicos. Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

434 linhas
18KB

  1. import os
  2. import re
  3. import threading
  4. import time
  5. from asyncio import CancelledError
  6. from concurrent.futures import Future, ThreadPoolExecutor
  7. from bridge.context import *
  8. from bridge.reply import *
  9. from channel.channel import Channel
  10. from common.dequeue import Dequeue
  11. from common.log import logger
  12. from config import conf
  13. from plugins import *
  14. try:
  15. from voice.audio_convert import any_to_wav
  16. except Exception as e:
  17. pass
  18. # 抽象类, 它包含了与消息通道无关的通用处理逻辑
  19. class ChatChannel(Channel):
  20. name = None # 登录的用户名
  21. user_id = None # 登录的用户id
  22. futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
  23. sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
  24. lock = threading.Lock() # 用于控制对sessions的访问
  25. handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
  26. def __init__(self):
  27. _thread = threading.Thread(target=self.consume)
  28. _thread.setDaemon(True)
  29. _thread.start()
  30. # 根据消息构造context,消息内容相关的触发项写在这里
  31. def _compose_context(self, ctype: ContextType, content, **kwargs):
  32. context = Context(ctype, content)
  33. context.kwargs = kwargs
  34. # context首次传入时,origin_ctype是None,
  35. # 引入的起因是:当输入语音时,会嵌套生成两个context,第一步语音转文本,第二步通过文本生成文字回复。
  36. # origin_ctype用于第二步文本回复时,判断是否需要匹配前缀,如果是私聊的语音,就不需要匹配前缀
  37. if "origin_ctype" not in context:
  38. context["origin_ctype"] = ctype
  39. # context首次传入时,receiver是None,根据类型设置receiver
  40. first_in = "receiver" not in context
  41. # 群名匹配过程,设置session_id和receiver
  42. if first_in: # context首次传入时,receiver是None,根据类型设置receiver
  43. config = conf()
  44. cmsg = context["msg"]
  45. if context.get("isgroup", False):
  46. group_name = cmsg.other_user_nickname
  47. group_id = cmsg.other_user_id
  48. group_name_white_list = config.get("group_name_white_list", [])
  49. group_name_keyword_white_list = config.get(
  50. "group_name_keyword_white_list", []
  51. )
  52. if any(
  53. [
  54. group_name in group_name_white_list,
  55. "ALL_GROUP" in group_name_white_list,
  56. check_contain(group_name, group_name_keyword_white_list),
  57. ]
  58. ):
  59. group_chat_in_one_session = conf().get(
  60. "group_chat_in_one_session", []
  61. )
  62. session_id = cmsg.actual_user_id
  63. if any(
  64. [
  65. group_name in group_chat_in_one_session,
  66. "ALL_GROUP" in group_chat_in_one_session,
  67. ]
  68. ):
  69. session_id = group_id
  70. else:
  71. return None
  72. context["session_id"] = session_id
  73. context["receiver"] = group_id
  74. else:
  75. context["session_id"] = cmsg.other_user_id
  76. context["receiver"] = cmsg.other_user_id
  77. e_context = PluginManager().emit_event(
  78. EventContext(
  79. Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context}
  80. )
  81. )
  82. context = e_context["context"]
  83. if e_context.is_pass() or context is None:
  84. return context
  85. if cmsg.from_user_id == self.user_id and not config.get(
  86. "trigger_by_self", True
  87. ):
  88. logger.debug("[WX]self message skipped")
  89. return None
  90. # 消息内容匹配过程,并处理content
  91. if ctype == ContextType.TEXT:
  92. if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息
  93. logger.debug("[WX]reference query skipped")
  94. return None
  95. if context.get("isgroup", False): # 群聊
  96. # 校验关键字
  97. match_prefix = check_prefix(content, conf().get("group_chat_prefix"))
  98. match_contain = check_contain(content, conf().get("group_chat_keyword"))
  99. flag = False
  100. if match_prefix is not None or match_contain is not None:
  101. flag = True
  102. if match_prefix:
  103. content = content.replace(match_prefix, "", 1).strip()
  104. if context["msg"].is_at:
  105. logger.info("[WX]receive group at")
  106. if not conf().get("group_at_off", False):
  107. flag = True
  108. pattern = f"@{self.name}(\u2005|\u0020)"
  109. content = re.sub(pattern, r"", content)
  110. if not flag:
  111. if context["origin_ctype"] == ContextType.VOICE:
  112. logger.info(
  113. "[WX]receive group voice, but checkprefix didn't match"
  114. )
  115. return None
  116. else: # 单聊
  117. match_prefix = check_prefix(
  118. content, conf().get("single_chat_prefix", [""])
  119. )
  120. if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
  121. content = content.replace(match_prefix, "", 1).strip()
  122. elif (
  123. context["origin_ctype"] == ContextType.VOICE
  124. ): # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
  125. pass
  126. else:
  127. return None
  128. img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
  129. if img_match_prefix:
  130. content = content.replace(img_match_prefix, "", 1)
  131. context.type = ContextType.IMAGE_CREATE
  132. else:
  133. context.type = ContextType.TEXT
  134. context.content = content.strip()
  135. if (
  136. "desire_rtype" not in context
  137. and conf().get("always_reply_voice")
  138. and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
  139. ):
  140. context["desire_rtype"] = ReplyType.VOICE
  141. elif context.type == ContextType.VOICE:
  142. if (
  143. "desire_rtype" not in context
  144. and conf().get("voice_reply_voice")
  145. and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
  146. ):
  147. context["desire_rtype"] = ReplyType.VOICE
  148. return context
  149. def _handle(self, context: Context):
  150. if context is None or not context.content:
  151. return
  152. logger.debug("[WX] ready to handle context: {}".format(context))
  153. # reply的构建步骤
  154. reply = self._generate_reply(context)
  155. logger.debug("[WX] ready to decorate reply: {}".format(reply))
  156. # reply的包装步骤
  157. reply = self._decorate_reply(context, reply)
  158. # reply的发送步骤
  159. self._send_reply(context, reply)
  160. def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
  161. e_context = PluginManager().emit_event(
  162. EventContext(
  163. Event.ON_HANDLE_CONTEXT,
  164. {"channel": self, "context": context, "reply": reply},
  165. )
  166. )
  167. reply = e_context["reply"]
  168. if not e_context.is_pass():
  169. logger.debug(
  170. "[WX] ready to handle context: type={}, content={}".format(
  171. context.type, context.content
  172. )
  173. )
  174. if (
  175. context.type == ContextType.TEXT
  176. or context.type == ContextType.IMAGE_CREATE
  177. ): # 文字和图片消息
  178. reply = super().build_reply_content(context.content, context)
  179. elif context.type == ContextType.VOICE: # 语音消息
  180. cmsg = context["msg"]
  181. cmsg.prepare()
  182. file_path = context.content
  183. wav_path = os.path.splitext(file_path)[0] + ".wav"
  184. try:
  185. any_to_wav(file_path, wav_path)
  186. except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别
  187. logger.warning("[WX]any to wav error, use raw path. " + str(e))
  188. wav_path = file_path
  189. # 语音识别
  190. reply = super().build_voice_to_text(wav_path)
  191. # 删除临时文件
  192. try:
  193. os.remove(file_path)
  194. if wav_path != file_path:
  195. os.remove(wav_path)
  196. except Exception as e:
  197. pass
  198. # logger.warning("[WX]delete temp file error: " + str(e))
  199. if reply.type == ReplyType.TEXT:
  200. new_context = self._compose_context(
  201. ContextType.TEXT, reply.content, **context.kwargs
  202. )
  203. if new_context:
  204. reply = self._generate_reply(new_context)
  205. else:
  206. return
  207. elif context.type == ContextType.IMAGE: # 图片消息,当前无默认逻辑
  208. pass
  209. else:
  210. logger.error("[WX] unknown context type: {}".format(context.type))
  211. return
  212. return reply
  213. def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
  214. if reply and reply.type:
  215. e_context = PluginManager().emit_event(
  216. EventContext(
  217. Event.ON_DECORATE_REPLY,
  218. {"channel": self, "context": context, "reply": reply},
  219. )
  220. )
  221. reply = e_context["reply"]
  222. desire_rtype = context.get("desire_rtype")
  223. if not e_context.is_pass() and reply and reply.type:
  224. if reply.type in self.NOT_SUPPORT_REPLYTYPE:
  225. logger.error("[WX]reply type not support: " + str(reply.type))
  226. reply.type = ReplyType.ERROR
  227. reply.content = "不支持发送的消息类型: " + str(reply.type)
  228. if reply.type == ReplyType.TEXT:
  229. reply_text = reply.content
  230. if (
  231. desire_rtype == ReplyType.VOICE
  232. and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
  233. ):
  234. reply = super().build_text_to_voice(reply.content)
  235. return self._decorate_reply(context, reply)
  236. if context.get("isgroup", False):
  237. reply_text = (
  238. "@"
  239. + context["msg"].actual_user_nickname
  240. + " "
  241. + reply_text.strip()
  242. )
  243. reply_text = (
  244. conf().get("group_chat_reply_prefix", "") + reply_text
  245. )
  246. else:
  247. reply_text = (
  248. conf().get("single_chat_reply_prefix", "") + reply_text
  249. )
  250. reply.content = reply_text
  251. elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
  252. reply.content = "[" + str(reply.type) + "]\n" + reply.content
  253. elif (
  254. reply.type == ReplyType.IMAGE_URL
  255. or reply.type == ReplyType.VOICE
  256. or reply.type == ReplyType.IMAGE
  257. ):
  258. pass
  259. else:
  260. logger.error("[WX] unknown reply type: {}".format(reply.type))
  261. return
  262. if (
  263. desire_rtype
  264. and desire_rtype != reply.type
  265. and reply.type not in [ReplyType.ERROR, ReplyType.INFO]
  266. ):
  267. logger.warning(
  268. "[WX] desire_rtype: {}, but reply type: {}".format(
  269. context.get("desire_rtype"), reply.type
  270. )
  271. )
  272. return reply
  273. def _send_reply(self, context: Context, reply: Reply):
  274. if reply and reply.type:
  275. e_context = PluginManager().emit_event(
  276. EventContext(
  277. Event.ON_SEND_REPLY,
  278. {"channel": self, "context": context, "reply": reply},
  279. )
  280. )
  281. reply = e_context["reply"]
  282. if not e_context.is_pass() and reply and reply.type:
  283. logger.debug(
  284. "[WX] ready to send reply: {}, context: {}".format(reply, context)
  285. )
  286. self._send(reply, context)
  287. def _send(self, reply: Reply, context: Context, retry_cnt=0):
  288. try:
  289. self.send(reply, context)
  290. except Exception as e:
  291. logger.error("[WX] sendMsg error: {}".format(str(e)))
  292. if isinstance(e, NotImplementedError):
  293. return
  294. logger.exception(e)
  295. if retry_cnt < 2:
  296. time.sleep(3 + 3 * retry_cnt)
  297. self._send(reply, context, retry_cnt + 1)
  298. def _success_callback(self, session_id, **kwargs): # 线程正常结束时的回调函数
  299. logger.debug("Worker return success, session_id = {}".format(session_id))
  300. def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数
  301. logger.exception("Worker return exception: {}".format(exception))
  302. def _thread_pool_callback(self, session_id, **kwargs):
  303. def func(worker: Future):
  304. try:
  305. worker_exception = worker.exception()
  306. if worker_exception:
  307. self._fail_callback(
  308. session_id, exception=worker_exception, **kwargs
  309. )
  310. else:
  311. self._success_callback(session_id, **kwargs)
  312. except CancelledError as e:
  313. logger.info("Worker cancelled, session_id = {}".format(session_id))
  314. except Exception as e:
  315. logger.exception("Worker raise exception: {}".format(e))
  316. with self.lock:
  317. self.sessions[session_id][1].release()
  318. return func
  319. def produce(self, context: Context):
  320. session_id = context["session_id"]
  321. with self.lock:
  322. if session_id not in self.sessions:
  323. self.sessions[session_id] = [
  324. Dequeue(),
  325. threading.BoundedSemaphore(conf().get("concurrency_in_session", 4)),
  326. ]
  327. if context.type == ContextType.TEXT and context.content.startswith("#"):
  328. self.sessions[session_id][0].putleft(context) # 优先处理管理命令
  329. else:
  330. self.sessions[session_id][0].put(context)
  331. # 消费者函数,单独线程,用于从消息队列中取出消息并处理
  332. def consume(self):
  333. while True:
  334. with self.lock:
  335. session_ids = list(self.sessions.keys())
  336. for session_id in session_ids:
  337. context_queue, semaphore = self.sessions[session_id]
  338. if semaphore.acquire(blocking=False): # 等线程处理完毕才能删除
  339. if not context_queue.empty():
  340. context = context_queue.get()
  341. logger.debug("[WX] consume context: {}".format(context))
  342. future: Future = self.handler_pool.submit(
  343. self._handle, context
  344. )
  345. future.add_done_callback(
  346. self._thread_pool_callback(session_id, context=context)
  347. )
  348. if session_id not in self.futures:
  349. self.futures[session_id] = []
  350. self.futures[session_id].append(future)
  351. elif (
  352. semaphore._initial_value == semaphore._value + 1
  353. ): # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
  354. self.futures[session_id] = [
  355. t for t in self.futures[session_id] if not t.done()
  356. ]
  357. assert (
  358. len(self.futures[session_id]) == 0
  359. ), "thread pool error"
  360. del self.sessions[session_id]
  361. else:
  362. semaphore.release()
  363. time.sleep(0.1)
  364. # 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务
  365. def cancel_session(self, session_id):
  366. with self.lock:
  367. if session_id in self.sessions:
  368. for future in self.futures[session_id]:
  369. future.cancel()
  370. cnt = self.sessions[session_id][0].qsize()
  371. if cnt > 0:
  372. logger.info(
  373. "Cancel {} messages in session {}".format(cnt, session_id)
  374. )
  375. self.sessions[session_id][0] = Dequeue()
  376. def cancel_all_session(self):
  377. with self.lock:
  378. for session_id in self.sessions:
  379. for future in self.futures[session_id]:
  380. future.cancel()
  381. cnt = self.sessions[session_id][0].qsize()
  382. if cnt > 0:
  383. logger.info(
  384. "Cancel {} messages in session {}".format(cnt, session_id)
  385. )
  386. self.sessions[session_id][0] = Dequeue()
  387. def check_prefix(content, prefix_list):
  388. if not prefix_list:
  389. return None
  390. for prefix in prefix_list:
  391. if content.startswith(prefix):
  392. return prefix
  393. return None
  394. def check_contain(content, keyword_list):
  395. if not keyword_list:
  396. return None
  397. for ky in keyword_list:
  398. if content.find(ky) != -1:
  399. return True
  400. return None