Nevar pievienot vairāk kā 25 tēmas Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.

wechaty_channel.py 18KB

pirms 1 gada
pirms 1 gada
pirms 1 gada
pirms 1 gada
pirms 1 gada
pirms 1 gada
pirms 1 gada
pirms 1 gada
pirms 1 gada
pirms 1 gada
pirms 1 gada
pirms 1 gada
pirms 1 gada
pirms 1 gada
pirms 1 gada
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. # encoding:utf-8
  2. """
  3. wechaty channel
  4. Python Wechaty - https://github.com/wechaty/python-wechaty
  5. """
  6. import io
  7. import os
  8. import json
  9. import time
  10. import asyncio
  11. import requests
  12. import pysilk
  13. import wave
  14. from pydub import AudioSegment
  15. from typing import Optional, Union
  16. from bridge.context import Context, ContextType
  17. from wechaty_puppet import MessageType, FileBox, ScanStatus # type: ignore
  18. from wechaty import Wechaty, Contact
  19. from wechaty.user import Message, Room, MiniProgram, UrlLink
  20. from channel.channel import Channel
  21. from common.log import logger
  22. from common.tmp_dir import TmpDir
  23. from config import conf
  24. class WechatyChannel(Channel):
  25. def __init__(self):
  26. pass
  def startup(self):
      """Run the ``main`` coroutine to completion via ``asyncio.run`` (blocking call)."""
      entrypoint = self.main()
      asyncio.run(entrypoint)
  async def main(self):
      """Configure and start the Wechaty bot.

      Reads ``wechaty_puppet_service_token`` from the project config and
      exports it as ``WECHATY_PUPPET_SERVICE_TOKEN``, then creates the
      module-global ``bot``, registers the scan/login/message handlers and
      awaits ``bot.start()``.
      """
      global bot
      config = conf()
      # Uses the PadLocal protocol, which is comparatively stable (for the free web
      # protocol set os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:8080').
      token = config.get('wechaty_puppet_service_token')
      if token:
          os.environ['WECHATY_PUPPET_SERVICE_TOKEN'] = str(token)
      else:
          # os.environ only accepts strings, so assigning a missing (None) token
          # would raise TypeError; keep whatever the environment already provides.
          logger.warning('[WX] wechaty_puppet_service_token is not configured, '
                         'relying on the existing WECHATY_PUPPET_SERVICE_TOKEN environment variable')
      bot = Wechaty()
      bot.on('scan', self.on_scan)
      bot.on('login', self.on_login)
      bot.on('message', self.on_message)
      await bot.start()
  async def on_login(self, contact: Contact):
      """Log the contact passed to the ``login`` event handler."""
      login_message = '[WX] login user={}'.format(contact)
      logger.info(login_message)
  async def on_scan(self, status: ScanStatus, qr_code: Optional[str] = None,
                    data: Optional[str] = None):
      """Log the scan status and QR code passed to the ``scan`` event handler.

      Args:
          status: scan status; its ``name`` is written to the log.
          qr_code: QR-code payload, if any.
          data: accepted for handler-signature compatibility; not used.
      """
      # No contact is resolved here: this class is a Channel, not the Wechaty
      # instance, and nothing in it sets ``Contact`` or ``contact_id``, so
      # looking them up on ``self`` would raise AttributeError and the QR code
      # would never reach the log.
      logger.info('[WX] scan status={}, scan qr_code={}'.format(status.name, qr_code))
  async def on_message(self, msg: Message):
      """
      Dispatch an incoming message event.

      Four cases are handled: private text, private voice, group text and
      group voice. Every other combination is ignored.
      """
      room = msg.room()  # None when the message does not come from a group chat
      msg_type = msg.type()
      if room is None:
          if msg_type == MessageType.MESSAGE_TYPE_TEXT:
              await self._on_single_text(msg)
          elif msg_type == MessageType.MESSAGE_TYPE_AUDIO:
              await self._on_single_voice(msg)
      elif room:
          if msg_type == MessageType.MESSAGE_TYPE_TEXT:
              await self._on_group_text(msg, room)
          elif msg_type == MessageType.MESSAGE_TYPE_AUDIO:
              await self._on_group_voice(msg, room)

  @staticmethod
  def _conf_list(key):
      """Return the list-valued config entry ``key``, or an empty list when it is unset."""
      return conf().get(key) or []

  @staticmethod
  def _strip_prefix(text, prefix):
      """Return ``text`` without its leading ``prefix`` and surrounding whitespace."""
      return text[len(prefix):].strip()

  def _is_group_allowed(self, room_name):
      """Tell whether ``room_name`` passes the group-name white lists from the config."""
      white_list = self._conf_list('group_name_white_list')
      return bool('ALL_GROUP' in white_list
                  or room_name in white_list
                  or self.check_contain(room_name, self._conf_list('group_name_keyword_white_list')))

  async def _voice_to_text(self, msg: Message):
      """Save a voice message, convert it from SILK to WAV and return the recognised text.

      The temporary SILK and WAV files are removed even when decoding or
      recognition fails.
      """
      voice_file = await msg.to_file_box()
      silk_file = TmpDir().path() + voice_file.name
      # Build the WAV name from the file extension instead of replacing ".slk":
      # with any other extension the replace would yield the same path and the
      # WAV output would overwrite the downloaded voice file.
      wav_file = os.path.splitext(silk_file)[0] + '.wav'
      try:
          await voice_file.to_file(silk_file)
          logger.info("[WX]receive voice file: " + silk_file)
          # Convert the SILK payload to a mono, 16-bit, 24 kHz WAV file
          with open(silk_file, 'rb') as f:
              silk_data = f.read()
          pcm_data = pysilk.decode(silk_data)
          with wave.open(wav_file, 'wb') as wav_data:
              wav_data.setnchannels(1)
              wav_data.setsampwidth(2)
              wav_data.setframerate(24000)
              wav_data.writeframes(pcm_data)
          converter_state = "true" if os.path.exists(wav_file) else "false"
          logger.info("[WX]receive voice converter: " + converter_state)
          # Speech recognition: voice -> text
          return super().build_voice_to_text(wav_file).content
      finally:
          # Remove the cached files
          for path in (wav_file, silk_file):
              if os.path.exists(path):
                  os.remove(path)

  async def _on_single_text(self, msg: Message):
      """Answer a private text message that matches a configured single-chat prefix."""
      content = msg.text()
      match_prefix = self.check_prefix(content, self._conf_list('single_chat_prefix'))
      if not msg.is_self():
          # A friend wrote to this account; an empty-string prefix matches everything.
          if match_prefix is None:
              return
          receiver = msg.talker().contact_id
      else:
          # This account wrote to a friend; only a non-empty prefix triggers a reply.
          if not match_prefix:
              return
          # msg.to() may return None, so the receiver id is only read when present.
          to_contact = msg.to()
          if to_contact is None:
              return
          receiver = to_contact.contact_id
      if match_prefix:
          content = self._strip_prefix(content, match_prefix)
      img_match_prefix = self.check_prefix(content, self._conf_list('image_create_prefix'))
      if img_match_prefix:
          await self._do_send_img(self._strip_prefix(content, img_match_prefix), receiver)
      else:
          await self._do_send(content, receiver)

  async def _on_single_voice(self, msg: Message):
      """Answer a private voice message sent by someone else."""
      if msg.is_self():
          return
      from_user_id = msg.talker().contact_id
      query = await self._voice_to_text(msg)
      if not query:
          # Nothing was recognised, so there is nothing to match or answer.
          return
      # Check the trigger prefix
      match_prefix = self.check_prefix(query, self._conf_list('single_chat_prefix'))
      if match_prefix is None:
          logger.info("[WX]receive voice check prefix: " + 'False')
          return
      if match_prefix:
          query = self._strip_prefix(query, match_prefix)
      # Send the reply
      if conf().get('voice_reply_voice'):
          await self._do_send_voice(query, from_user_id)
      else:
          await self._do_send(query, from_user_id)

  async def _on_group_text(self, msg: Message, room: Room):
      """Answer a group text message that mentions the bot or matches a prefix/keyword."""
      config = conf()
      from_contact = msg.talker()
      room_id = room.room_id
      room_name = await room.topic()
      from_user_id = from_contact.contact_id
      from_user_name = from_contact.name
      is_at = await msg.mention_self()
      content = await msg.mention_text()  # message text with the @name part filtered out
      prefix = self.check_prefix(content, self._conf_list('group_chat_prefix'))
      match_prefix = (is_at and not config.get("group_at_off", False)) \
          or prefix \
          or self.check_contain(content, self._conf_list('group_chat_keyword'))
      # When is_at is True Wechaty returns the text with the @mention removed;
      # otherwise the full text comes back. A matched custom prefix (and the
      # whitespace after it) is therefore stripped here, which allows e.g.
      # "custom prefix + image prefix" to trigger image generation.
      if prefix is not None:
          content = self._strip_prefix(content, prefix)
      if not (match_prefix and self._is_group_allowed(room_name)):
          return
      img_match_prefix = self.check_prefix(content, self._conf_list('image_create_prefix'))
      if img_match_prefix:
          await self._do_send_group_img(self._strip_prefix(content, img_match_prefix), room_id)
      else:
          await self._do_send_group(content, room_id, room_name, from_user_id, from_user_name)

  async def _on_group_voice(self, msg: Message, room: Room):
      """Answer a group voice message when group speech recognition is enabled."""
      config = conf()
      room_name = await room.topic()
      # Speech recognition must be enabled and the group must be white-listed
      if not (config.get('group_speech_recognition') and self._is_group_allowed(room_name)):
          return
      from_contact = msg.talker()
      room_id = room.room_id
      from_user_id = from_contact.contact_id
      from_user_name = from_contact.name
      query = await self._voice_to_text(msg)
      if not query:
          # Nothing was recognised, so there is nothing to match or answer.
          return
      # Check the trigger prefix / keyword
      prefix = self.check_prefix(query, self._conf_list('group_chat_prefix'))
      match_prefix = prefix or self.check_contain(query, self._conf_list('group_chat_keyword'))
      if match_prefix is None:
          logger.info("[WX]receive voice check prefix: " + 'False')
          return
      # Strip a matched custom prefix so that "custom prefix + image prefix"
      # can trigger image generation.
      if prefix is not None:
          query = self._strip_prefix(query, prefix)
      # Send the reply
      img_match_prefix = self.check_prefix(query, self._conf_list('image_create_prefix'))
      if img_match_prefix:
          await self._do_send_group_img(self._strip_prefix(query, img_match_prefix), room_id)
      elif config.get('voice_reply_voice'):
          await self._do_send_group_voice(query, room_id, room_name, from_user_id, from_user_name)
      else:
          await self._do_send_group(query, room_id, room_name, from_user_id, from_user_name)
  async def send(self, message: Union[str, Message, FileBox, Contact, UrlLink, MiniProgram], receiver):
      """Send ``message`` to the contact looked up from ``receiver``.

      Nothing is sent when ``receiver`` is empty or no matching contact is found.
      """
      logger.info('[WX] sendMsg={}, receiver={}'.format(message, receiver))
      if not receiver:
          return
      contact = await bot.Contact.find(receiver)
      if contact is None:
          # The lookup can come back empty; skip instead of failing on ``None.say``.
          logger.warning('[WX] send skipped, contact not found, receiver={}'.format(receiver))
          return
      await contact.say(message)
  async def send_group(self, message: Union[str, Message, FileBox, Contact, UrlLink, MiniProgram], receiver):
      """Send ``message`` to the room looked up from ``receiver``.

      Nothing is sent when ``receiver`` is empty or no matching room is found.
      """
      logger.info('[WX] sendMsg={}, receiver={}'.format(message, receiver))
      if not receiver:
          return
      room = await bot.Room.find(receiver)
      if room is None:
          # The lookup can come back empty; skip instead of failing on ``None.say``.
          logger.warning('[WX] send skipped, room not found, receiver={}'.format(receiver))
          return
      await room.say(message)
  async def _do_send(self, query, reply_user_id):
      """Build a text reply for ``query`` and send it to ``reply_user_id``.

      ``reply_user_id`` doubles as the session id of the conversation. Any
      exception is logged and swallowed.
      """
      try:
          if not query:
              return
          context = Context(ContextType.TEXT, query)
          context['session_id'] = reply_user_id
          reply_text = super().build_reply_content(query, context).content
          if reply_text:
              # Default to an empty prefix (as the group reply does): a missing
              # config key would otherwise make ``None + str`` fail and drop the reply.
              reply_prefix = conf().get("single_chat_reply_prefix", "") or ""
              await self.send(reply_prefix + reply_text, reply_user_id)
      except Exception as e:
          logger.exception(e)
  async def _do_send_voice(self, query, reply_user_id):
      """Build a reply for ``query`` and send it to ``reply_user_id`` as a SILK voice file.

      The reply text is synthesised to an MP3 file, re-encoded to SILK
      (24 kHz, mono) and sent as a file box. Any exception is logged and
      swallowed; the temporary MP3/SILK files are removed in every case.
      """
      if not query:
          return
      mp3_file = None
      silk_file = None
      try:
          context = Context(ContextType.TEXT, query)
          context['session_id'] = reply_user_id
          reply_text = super().build_reply_content(query, context).content
          if not reply_text:
              return
          # Text -> MP3 file
          mp3_file = super().build_text_to_voice(reply_text).content
          # Swap the extension instead of replacing ".mp3" in the path: with any
          # other extension the replace would yield the same path for both files.
          silk_file = os.path.splitext(mp3_file)[0] + ".silk"
          # Load the MP3 and resample it to raw 24 kHz mono PCM
          audio = AudioSegment.from_file(mp3_file, format="mp3")
          audio = audio.set_frame_rate(24000).set_channels(1)
          # Encode to SILK and save the result
          silk_data = pysilk.encode(audio.raw_data, 24000)
          with open(silk_file, "wb") as f:
              f.write(silk_data)
          # Send the voice message
          t = int(time.time())
          file_box = FileBox.from_file(silk_file, name=str(t) + '.silk')
          await self.send(file_box, reply_user_id)
      except Exception as e:
          logger.exception(e)
      finally:
          # Remove the cached files, including after a failure part-way through
          for path in (mp3_file, silk_file):
              try:
                  if path and os.path.exists(path):
                      os.remove(path)
              except OSError as e:
                  logger.exception(e)
  async def _do_send_img(self, query, reply_user_id):
      """Create an image for ``query`` and send it to ``reply_user_id`` as a PNG file box.

      The image is handed over by URL. Any exception is logged and swallowed.
      """
      try:
          if query:
              image_context = Context(ContextType.IMAGE_CREATE, query)
              img_url = super().build_reply_content(query, image_context).content
              if img_url:
                  # Send the image
                  logger.info('[WX] sendImage, receiver={}'.format(reply_user_id))
                  file_name = str(int(time.time())) + '.png'
                  image_box = FileBox.from_url(url=img_url, name=file_name)
                  await self.send(image_box, reply_user_id)
      except Exception as err:
          logger.exception(err)
  async def _do_send_group(self, query, group_id, group_name, group_user_id, group_user_name):
      """Build a text reply for ``query`` and send it to the group, @-mentioning the asker.

      Groups matched by ``group_chat_in_one_session`` share one session per
      group; otherwise the session is per group member. Any exception is
      logged and swallowed, as in the single-chat senders.
      """
      if not query:
          return
      try:
          context = Context(ContextType.TEXT, query)
          group_chat_in_one_session = conf().get('group_chat_in_one_session', []) or []
          if ('ALL_GROUP' in group_chat_in_one_session
                  or group_name in group_chat_in_one_session
                  or self.check_contain(group_name, group_chat_in_one_session)):
              context['session_id'] = str(group_id)
          else:
              context['session_id'] = str(group_id) + '-' + str(group_user_id)
          reply_text = super().build_reply_content(query, context).content
          if reply_text:
              reply_text = '@' + group_user_name + ' ' + reply_text.strip()
              await self.send_group(conf().get("group_chat_reply_prefix", "") + reply_text, group_id)
      except Exception as e:
          # Keep a failing reply from propagating out of the message handler.
          logger.exception(e)
  async def _do_send_group_voice(self, query, group_id, group_name, group_user_id, group_user_name):
      """Build a reply for ``query`` and send it to the group as a SILK voice file.

      Groups matched by ``group_chat_in_one_session`` share one session per
      group; otherwise the session is per group member. The reply text
      (prefixed with ``@<group_user_name>``) is synthesised to MP3, re-encoded
      to SILK (24 kHz, mono) and sent as a file box. Any exception is logged
      and swallowed; the temporary MP3/SILK files are removed in every case.
      """
      if not query:
          return
      mp3_file = None
      silk_file = None
      try:
          context = Context(ContextType.TEXT, query)
          group_chat_in_one_session = conf().get('group_chat_in_one_session', []) or []
          if ('ALL_GROUP' in group_chat_in_one_session
                  or group_name in group_chat_in_one_session
                  or self.check_contain(group_name, group_chat_in_one_session)):
              context['session_id'] = str(group_id)
          else:
              context['session_id'] = str(group_id) + '-' + str(group_user_id)
          reply_text = super().build_reply_content(query, context).content
          if not reply_text:
              return
          reply_text = '@' + group_user_name + ' ' + reply_text.strip()
          # Text -> MP3 file
          mp3_file = super().build_text_to_voice(reply_text).content
          # Swap the extension instead of replacing ".mp3" in the path: with any
          # other extension the replace would yield the same path for both files.
          silk_file = os.path.splitext(mp3_file)[0] + ".silk"
          # Load the MP3 and resample it to raw 24 kHz mono PCM
          audio = AudioSegment.from_file(mp3_file, format="mp3")
          audio = audio.set_frame_rate(24000).set_channels(1)
          # Encode to SILK and save the result
          silk_data = pysilk.encode(audio.raw_data, 24000)
          with open(silk_file, "wb") as f:
              f.write(silk_data)
          # Send the voice message
          t = int(time.time())
          file_box = FileBox.from_file(silk_file, name=str(t) + '.silk')
          await self.send_group(file_box, group_id)
      except Exception as e:
          # Keep a failing reply from propagating out of the message handler.
          logger.exception(e)
      finally:
          # Remove the cached files, including after a failure part-way through
          for path in (mp3_file, silk_file):
              try:
                  if path and os.path.exists(path):
                      os.remove(path)
              except OSError as e:
                  logger.exception(e)
  async def _do_send_group_img(self, query, reply_room_id):
      """Create an image for ``query`` and send it to the room as a PNG file box.

      The image is handed over by URL. Any exception is logged and swallowed.
      """
      try:
          if query:
              image_context = Context(ContextType.IMAGE_CREATE, query)
              img_url = super().build_reply_content(query, image_context).content
              if img_url:
                  # Send the image
                  logger.info('[WX] sendImage, receiver={}'.format(reply_room_id))
                  file_name = str(int(time.time())) + '.png'
                  image_box = FileBox.from_url(url=img_url, name=file_name)
                  await self.send_group(image_box, reply_room_id)
      except Exception as err:
          logger.exception(err)
  def check_prefix(self, content, prefix_list):
      """Return the first entry of ``prefix_list`` that ``content`` starts with.

      Args:
          content: text to inspect.
          prefix_list: candidate prefixes, tried in order; ``None`` or an
              empty list means "no prefixes".

      Returns:
          The matching prefix (possibly the empty string, which matches any
          text), or ``None`` when nothing matches.
      """
      # A missing config entry arrives here as None; treat it as "no prefixes",
      # the same way check_contain treats a missing keyword list.
      if not prefix_list:
          return None
      return next((prefix for prefix in prefix_list if content.startswith(prefix)), None)
  def check_contain(self, content, keyword_list):
      """Return ``True`` when ``content`` contains any keyword, otherwise ``None``.

      An empty or missing ``keyword_list`` never matches.
      """
      if keyword_list and any(keyword in content for keyword in keyword_list):
          return True
      return None