Browse Source

refactor: using one processing logic in chat channel

develop
lanvent 1 year ago
parent
commit
02cd553990
7 changed files with 159 additions and 339 deletions
  1. +1
    -0
      channel/channel.py
  2. +18
    -18
      channel/chat_channel.py
  3. +4
    -10
      channel/wechat/wechat_channel.py
  4. +1
    -1
      channel/wechat/wechat_message.py
  5. +99
    -304
      channel/wechat/wechaty_channel.py
  6. +26
    -6
      channel/wechat/wechaty_message.py
  7. +10
    -0
      voice/audio_convert.py

+ 1
- 0
channel/channel.py View File

@@ -20,6 +20,7 @@ class Channel(object):
""" """
raise NotImplementedError raise NotImplementedError


# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
def send(self, reply: Reply, context: Context): def send(self, reply: Reply, context: Context):
""" """
send message to user send message to user


+ 18
- 18
channel/chat_channel.py View File

@@ -1,6 +1,7 @@






import os
import re import re
import time import time
from common.expired_dict import ExpiredDict from common.expired_dict import ExpiredDict
@@ -10,7 +11,10 @@ from bridge.context import *
from config import conf from config import conf
from common.log import logger from common.log import logger
from plugins import * from plugins import *

try:
from voice.audio_convert import any_to_wav
except Exception as e:
pass


# 抽象类, 它包含了与消息通道无关的通用处理逻辑 # 抽象类, 它包含了与消息通道无关的通用处理逻辑
class ChatChannel(Channel): class ChatChannel(Channel):
@@ -30,11 +34,13 @@ class ChatChannel(Channel):
context['origin_ctype'] = ctype context['origin_ctype'] = ctype
# context首次传入时,receiver是None,根据类型设置receiver # context首次传入时,receiver是None,根据类型设置receiver
first_in = 'receiver' not in context first_in = 'receiver' not in context

# 群名匹配过程,设置session_id和receiver # 群名匹配过程,设置session_id和receiver
if first_in: # context首次传入时,receiver是None,根据类型设置receiver if first_in: # context首次传入时,receiver是None,根据类型设置receiver
config = conf() config = conf()
cmsg = context['msg'] cmsg = context['msg']
if cmsg.from_user_id == self.user_id:
logger.debug("[WX]self message skipped")
return None
if context["isgroup"]: if context["isgroup"]:
group_name = cmsg.other_user_nickname group_name = cmsg.other_user_nickname
group_id = cmsg.other_user_id group_id = cmsg.other_user_id
@@ -47,14 +53,13 @@ class ChatChannel(Channel):
if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]): if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]):
session_id = group_id session_id = group_id
else: else:
return
return None
context['session_id'] = session_id context['session_id'] = session_id
context['receiver'] = group_id context['receiver'] = group_id
else: else:
context['session_id'] = cmsg.other_user_id context['session_id'] = cmsg.other_user_id
context['receiver'] = cmsg.other_user_id context['receiver'] = cmsg.other_user_id



# 消息内容匹配过程,并处理content # 消息内容匹配过程,并处理content
if ctype == ContextType.TEXT: if ctype == ContextType.TEXT:
if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息 if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息
@@ -99,10 +104,6 @@ class ChatChannel(Channel):




return context return context
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
def send(self, reply: Reply, context: Context, retry_cnt = 0):
raise NotImplementedError


# 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类 # 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类
def _handle(self, context: Context): def _handle(self, context: Context):
@@ -128,22 +129,21 @@ class ChatChannel(Channel):
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息 if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
reply = super().build_reply_content(context.content, context) reply = super().build_reply_content(context.content, context)
elif context.type == ContextType.VOICE: # 语音消息 elif context.type == ContextType.VOICE: # 语音消息
msg = context['msg']
msg.prepare()
mp3_path = context.content
# mp3转wav
wav_path = os.path.splitext(mp3_path)[0] + '.wav'
cmsg = context['msg']
cmsg.prepare()
file_path = context.content
wav_path = os.path.splitext(file_path)[0] + '.wav'
try: try:
mp3_to_wav(mp3_path=mp3_path, wav_path=wav_path)
any_to_wav(file_path, wav_path)
except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别 except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别
logger.warning("[WX]mp3 to wav error, use mp3 path. " + str(e))
wav_path = mp3_path
logger.warning("[WX]any to wav error, use raw path. " + str(e))
wav_path = file_path
# 语音识别 # 语音识别
reply = super().build_voice_to_text(wav_path) reply = super().build_voice_to_text(wav_path)
# 删除临时文件 # 删除临时文件
try: try:
os.remove(file_path)
os.remove(wav_path) os.remove(wav_path)
os.remove(mp3_path)
except Exception as e: except Exception as e:
logger.warning("[WX]delete temp file error: " + str(e)) logger.warning("[WX]delete temp file error: " + str(e))


@@ -204,7 +204,7 @@ class ChatChannel(Channel):
logger.error('[WX] sendMsg error: {}'.format(e)) logger.error('[WX] sendMsg error: {}'.format(e))
if retry_cnt < 2: if retry_cnt < 2:
time.sleep(3+3*retry_cnt) time.sleep(3+3*retry_cnt)
self._send(reply, context, retry_cnt+1)
self._send(reply, context, retry_cnt+1)






+ 4
- 10
channel/wechat/wechat_channel.py View File

@@ -8,28 +8,22 @@ import os
import requests import requests
import io import io
import time import time
import json
from channel.chat_channel import ChatChannel from channel.chat_channel import ChatChannel
from channel.wechat.wechat_message import * from channel.wechat.wechat_message import *
from common.singleton import singleton from common.singleton import singleton
from common.log import logger
from lib import itchat from lib import itchat
import json
from lib.itchat.content import * from lib.itchat.content import *
from bridge.reply import * from bridge.reply import *
from bridge.context import * from bridge.context import *
from channel.channel import Channel
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from common.log import logger
from config import conf from config import conf
from common.time_check import time_checker from common.time_check import time_checker
from common.expired_dict import ExpiredDict from common.expired_dict import ExpiredDict
from plugins import * from plugins import *
try:
from voice.audio_convert import mp3_to_wav
except Exception as e:
pass
thread_pool = ThreadPoolExecutor(max_workers=8) thread_pool = ThreadPoolExecutor(max_workers=8)



def thread_pool_callback(worker): def thread_pool_callback(worker):
worker_exception = worker.exception() worker_exception = worker.exception()
if worker_exception: if worker_exception:
@@ -122,7 +116,7 @@ class WechatChannel(ChatChannel):
@time_checker @time_checker
@_check @_check
def handle_text(self, cmsg : ChatMessage): def handle_text(self, cmsg : ChatMessage):
logger.debug("[WX]receive text msg: " + json.dumps(cmsg._rawmsg, ensure_ascii=False))
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=False, msg=cmsg) context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=False, msg=cmsg)
if context: if context:
thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback) thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback)
@@ -130,7 +124,7 @@ class WechatChannel(ChatChannel):
@time_checker @time_checker
@_check @_check
def handle_group(self, cmsg : ChatMessage): def handle_group(self, cmsg : ChatMessage):
logger.debug("[WX]receive group msg: " + json.dumps(cmsg._rawmsg, ensure_ascii=False))
logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=True, msg=cmsg) context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=True, msg=cmsg)
if context: if context:
thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback) thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback)


+ 1
- 1
channel/wechat/wechat_message.py View File

@@ -44,7 +44,7 @@ class WeChatMessage(ChatMessage):
self.from_user_nickname = self.other_user_nickname self.from_user_nickname = self.other_user_nickname
if self.other_user_id == self.to_user_id: if self.other_user_id == self.to_user_id:
self.to_user_nickname = self.other_user_nickname self.to_user_nickname = self.other_user_nickname
except KeyError as e:
except KeyError as e: # 处理偶尔没有对方信息的情况
logger.warn("[WX]get other_user_id failed: " + str(e)) logger.warn("[WX]get other_user_id failed: " + str(e))
if self.from_user_id == user_id: if self.from_user_id == user_id:
self.other_user_id = self.to_user_id self.other_user_id = self.to_user_id


+ 99
- 304
channel/wechat/wechaty_channel.py View File

@@ -4,21 +4,32 @@
wechaty channel wechaty channel
Python Wechaty - https://github.com/wechaty/python-wechaty Python Wechaty - https://github.com/wechaty/python-wechaty
""" """
import base64
from concurrent.futures import ThreadPoolExecutor
import os import os
import time import time
import asyncio import asyncio
from typing import Optional, Union
from bridge.context import Context, ContextType
from wechaty_puppet import MessageType, FileBox, ScanStatus # type: ignore
from bridge.context import Context
from wechaty_puppet import FileBox
from wechaty import Wechaty, Contact from wechaty import Wechaty, Contact
from wechaty.user import Message, MiniProgram, UrlLink
from channel.channel import Channel
from wechaty.user import Message
from bridge.reply import *
from bridge.context import *
from channel.chat_channel import ChatChannel
from channel.wechat.wechaty_message import WechatyMessage
from common.log import logger from common.log import logger
from common.tmp_dir import TmpDir
from config import conf from config import conf
from voice.audio_convert import sil_to_wav, mp3_to_sil

class WechatyChannel(Channel):
try:
from voice.audio_convert import mp3_to_sil
except Exception as e:
pass

thread_pool = ThreadPoolExecutor(max_workers=8)
def thread_pool_callback(worker):
worker_exception = worker.exception()
if worker_exception:
logger.exception("Worker return exception: {}".format(worker_exception))
class WechatyChannel(ChatChannel):


def __init__(self): def __init__(self):
pass pass
@@ -28,312 +39,96 @@ class WechatyChannel(Channel):


async def main(self): async def main(self):
config = conf() config = conf()
# 使用PadLocal协议 比较稳定(免费web协议 os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:8080')
token = config.get('wechaty_puppet_service_token') token = config.get('wechaty_puppet_service_token')
os.environ['WECHATY_PUPPET_SERVICE_TOKEN'] = token os.environ['WECHATY_PUPPET_SERVICE_TOKEN'] = token
global bot
bot = Wechaty()

bot.on('scan', self.on_scan)
bot.on('login', self.on_login)
bot.on('message', self.on_message)
await bot.start()
self.bot = Wechaty()
self.bot.on('login', self.on_login)
self.bot.on('message', self.on_message)
await self.bot.start()


async def on_login(self, contact: Contact): async def on_login(self, contact: Contact):
self.user_id = contact.contact_id
self.name = contact.name
logger.info('[WX] login user={}'.format(contact)) logger.info('[WX] login user={}'.format(contact))


async def on_scan(self, status: ScanStatus, qr_code: Optional[str] = None,
data: Optional[str] = None):
pass
# contact = self.Contact.load(self.contact_id)
# logger.info('[WX] scan user={}, scan status={}, scan qr_code={}'.format(contact, status.name, qr_code))
# print(f'user <{contact}> scan status: {status.name} , 'f'qr_code: {qr_code}')
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
def send(self, reply: Reply, context: Context):
receiver_id = context['receiver']
loop = asyncio.get_event_loop()
if context['isgroup']:
receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id),loop).result()
else:
receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id),loop).result()
msg = None
if reply.type == ReplyType.TEXT:
msg = reply.content
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
msg = reply.content
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
elif reply.type == ReplyType.VOICE:
voiceLength = None
if reply.content.endswith('.mp3'):
mp3_file = reply.content
sil_file = os.path.splitext(mp3_file)[0] + '.sil'
voiceLength = mp3_to_sil(mp3_file, sil_file)
try:
os.remove(mp3_file)
except Exception as e:
pass
elif reply.content.endswith('.sil'):
sil_file = reply.content
else:
raise Exception('voice file must be mp3 or sil format')
# 发送语音
t = int(time.time())
msg = FileBox.from_file(sil_file, name=str(t) + '.sil')
if voiceLength is not None:
msg.metadata['voiceLength'] = voiceLength
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
try:
os.remove(sil_file)
except Exception as e:
pass
logger.info('[WX] sendVoice={}, receiver={}'.format(reply.content, receiver))
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
img_url = reply.content
t = int(time.time())
msg = FileBox.from_url(url=img_url, name=str(t) + '.png')
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver))
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
image_storage = reply.content
image_storage.seek(0)
t = int(time.time())
msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + '.png')
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
logger.info('[WX] sendImage, receiver={}'.format(receiver))


async def on_message(self, msg: Message): async def on_message(self, msg: Message):
""" """
listen for message event listen for message event
""" """
from_contact = msg.talker() # 获取消息的发送者
to_contact = msg.to() # 接收人
room = msg.room() # 获取消息来自的群聊. 如果消息不是来自群聊, 则返回None
from_user_id = from_contact.contact_id
to_user_id = to_contact.contact_id # 接收人id
# other_user_id = msg['User']['UserName'] # 对手方id
content = msg.text()
mention_content = await msg.mention_text() # 返回过滤掉@name后的消息
match_prefix = self.check_prefix(content, conf().get('single_chat_prefix'))
# conversation: Union[Room, Contact] = from_contact if room is None else room

if room is None and msg.type() == MessageType.MESSAGE_TYPE_TEXT:
if not msg.is_self() and match_prefix is not None:
# 好友向自己发送消息
if match_prefix != '':
str_list = content.split(match_prefix, 1)
if len(str_list) == 2:
content = str_list[1].strip()

img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
if img_match_prefix:
content = content.split(img_match_prefix, 1)[1].strip()
await self._do_send_img(content, from_user_id)
else:
await self._do_send(content, from_user_id)
elif msg.is_self() and match_prefix:
# 自己给好友发送消息
str_list = content.split(match_prefix, 1)
if len(str_list) == 2:
content = str_list[1].strip()
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
if img_match_prefix:
content = content.split(img_match_prefix, 1)[1].strip()
await self._do_send_img(content, to_user_id)
else:
await self._do_send(content, to_user_id)
elif room is None and msg.type() == MessageType.MESSAGE_TYPE_AUDIO:
if not msg.is_self(): # 接收语音消息
# 下载语音文件
voice_file = await msg.to_file_box()
silk_file = TmpDir().path() + voice_file.name
await voice_file.to_file(silk_file)
logger.info("[WX]receive voice file: " + silk_file)
# 将文件转成wav格式音频
wav_file = os.path.splitext(silk_file)[0] + '.wav'
sil_to_wav(silk_file, wav_file)
# 语音识别为文本
query = super().build_voice_to_text(wav_file).content
# 交验关键字
match_prefix = self.check_prefix(query, conf().get('single_chat_prefix'))
if match_prefix is not None:
if match_prefix != '':
str_list = query.split(match_prefix, 1)
if len(str_list) == 2:
query = str_list[1].strip()
# 返回消息
if conf().get('voice_reply_voice'):
await self._do_send_voice(query, from_user_id)
else:
await self._do_send(query, from_user_id)
else:
logger.info("[WX]receive voice check prefix: " + 'False')
# 清除缓存文件
os.remove(wav_file)
os.remove(silk_file)
elif room and msg.type() == MessageType.MESSAGE_TYPE_TEXT:
# 群组&文本消息
room_id = room.room_id
room_name = await room.topic()
from_user_id = from_contact.contact_id
from_user_name = from_contact.name
is_at = await msg.mention_self()
content = mention_content
config = conf()
match_prefix = (is_at and not config.get("group_at_off", False)) \
or self.check_prefix(content, config.get('group_chat_prefix')) \
or self.check_contain(content, config.get('group_chat_keyword'))
# Wechaty判断is_at为True,返回的内容是过滤掉@之后的内容;而is_at为False,则会返回完整的内容
# 故判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容,用于实现类似自定义+前缀触发生成AI图片的功能
prefixes = config.get('group_chat_prefix')
for prefix in prefixes:
if content.startswith(prefix):
content = content.replace(prefix, '', 1).strip()
break
if ('ALL_GROUP' in config.get('group_name_white_list') or room_name in config.get(
'group_name_white_list') or self.check_contain(room_name, config.get(
'group_name_keyword_white_list'))) and match_prefix:
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
if img_match_prefix:
content = content.split(img_match_prefix, 1)[1].strip()
await self._do_send_group_img(content, room_id)
else:
await self._do_send_group(content, room_id, room_name, from_user_id, from_user_name)
elif room and msg.type() == MessageType.MESSAGE_TYPE_AUDIO:
# 群组&语音消息
room_id = room.room_id
room_name = await room.topic()
from_user_id = from_contact.contact_id
from_user_name = from_contact.name
is_at = await msg.mention_self()
config = conf()
# 是否开启语音识别、群消息响应功能、群名白名单符合等条件
if config.get('group_speech_recognition') and (
'ALL_GROUP' in config.get('group_name_white_list') or room_name in config.get(
'group_name_white_list') or self.check_contain(room_name, config.get(
'group_name_keyword_white_list'))):
# 下载语音文件
voice_file = await msg.to_file_box()
silk_file = TmpDir().path() + voice_file.name
await voice_file.to_file(silk_file)
logger.info("[WX]receive voice file: " + silk_file)
# 将文件转成wav格式音频
wav_file = os.path.splitext(silk_file)[0] + '.wav'
sil_to_wav(silk_file, wav_file)
# 语音识别为文本
query = super().build_voice_to_text(wav_file).content
# 校验关键字
match_prefix = self.check_prefix(query, config.get('group_chat_prefix')) \
or self.check_contain(query, config.get('group_chat_keyword'))
# Wechaty判断is_at为True,返回的内容是过滤掉@之后的内容;而is_at为False,则会返回完整的内容
if match_prefix is not None:
# 故判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容,用于实现类似自定义+前缀触发生成AI图片的功能
prefixes = config.get('group_chat_prefix')
for prefix in prefixes:
if query.startswith(prefix):
query = query.replace(prefix, '', 1).strip()
break
# 返回消息
img_match_prefix = self.check_prefix(query, conf().get('image_create_prefix'))
if img_match_prefix:
query = query.split(img_match_prefix, 1)[1].strip()
await self._do_send_group_img(query, room_id)
elif config.get('voice_reply_voice'):
await self._do_send_group_voice(query, room_id, room_name, from_user_id, from_user_name)
else:
await self._do_send_group(query, room_id, room_name, from_user_id, from_user_name)
else:
logger.info("[WX]receive voice check prefix: " + 'False')
# 清除缓存文件
os.remove(wav_file)
os.remove(silk_file)

async def send(self, message: Union[str, Message, FileBox, Contact, UrlLink, MiniProgram], receiver):
logger.info('[WX] sendMsg={}, receiver={}'.format(message, receiver))
if receiver:
contact = await bot.Contact.find(receiver)
await contact.say(message)

async def send_group(self, message: Union[str, Message, FileBox, Contact, UrlLink, MiniProgram], receiver):
logger.info('[WX] sendMsg={}, receiver={}'.format(message, receiver))
if receiver:
room = await bot.Room.find(receiver)
await room.say(message)

async def _do_send(self, query, reply_user_id):
try:
if not query:
return
context = Context(ContextType.TEXT, query)
context['session_id'] = reply_user_id
reply_text = super().build_reply_content(query, context).content
if reply_text:
await self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id)
except Exception as e:
logger.exception(e)

async def _do_send_voice(self, query, reply_user_id):
try:
if not query:
return
context = Context(ContextType.TEXT, query)
context['session_id'] = reply_user_id
reply_text = super().build_reply_content(query, context).content
if reply_text:
# 转换 mp3 文件为 silk 格式
mp3_file = super().build_text_to_voice(reply_text).content
silk_file = os.path.splitext(mp3_file)[0] + '.sil'
voiceLength = mp3_to_sil(mp3_file, silk_file)
# 发送语音
t = int(time.time())
file_box = FileBox.from_file(silk_file, name=str(t) + '.sil')
file_box.metadata = {'voiceLength': voiceLength}
await self.send(file_box, reply_user_id)
# 清除缓存文件
os.remove(mp3_file)
os.remove(silk_file)
except Exception as e:
logger.exception(e)
async def _do_send_img(self, query, reply_user_id):
try: try:
if not query:
return
context = Context(ContextType.IMAGE_CREATE, query)
img_url = super().build_reply_content(query, context).content
if not img_url:
return
# 图片下载
# pic_res = requests.get(img_url, stream=True)
# image_storage = io.BytesIO()
# for block in pic_res.iter_content(1024):
# image_storage.write(block)
# image_storage.seek(0)

# 图片发送
logger.info('[WX] sendImage, receiver={}'.format(reply_user_id))
t = int(time.time())
file_box = FileBox.from_url(url=img_url, name=str(t) + '.png')
await self.send(file_box, reply_user_id)
except Exception as e:
logger.exception(e)

async def _do_send_group(self, query, group_id, group_name, group_user_id, group_user_name):
if not query:
cmsg = await WechatyMessage(msg)
except NotImplementedError as e:
logger.debug('[WX] {}'.format(e))
return return
context = Context(ContextType.TEXT, query)
group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
if ('ALL_GROUP' in group_chat_in_one_session or \
group_name in group_chat_in_one_session or \
self.check_contain(group_name, group_chat_in_one_session)):
context['session_id'] = str(group_id)
else:
context['session_id'] = str(group_id) + '-' + str(group_user_id)
reply_text = super().build_reply_content(query, context).content
if reply_text:
reply_text = '@' + group_user_name + ' ' + reply_text.strip()
await self.send_group(conf().get("group_chat_reply_prefix", "") + reply_text, group_id)

async def _do_send_group_voice(self, query, group_id, group_name, group_user_id, group_user_name):
if not query:
return
context = Context(ContextType.TEXT, query)
group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
if ('ALL_GROUP' in group_chat_in_one_session or \
group_name in group_chat_in_one_session or \
self.check_contain(group_name, group_chat_in_one_session)):
context['session_id'] = str(group_id)
else:
context['session_id'] = str(group_id) + '-' + str(group_user_id)
reply_text = super().build_reply_content(query, context).content
if reply_text:
reply_text = '@' + group_user_name + ' ' + reply_text.strip()
# 转换 mp3 文件为 silk 格式
mp3_file = super().build_text_to_voice(reply_text).content
silk_file = os.path.splitext(mp3_file)[0] + '.sil'
voiceLength = mp3_to_sil(mp3_file, silk_file)
# 发送语音
t = int(time.time())
file_box = FileBox.from_file(silk_file, name=str(t) + '.silk')
file_box.metadata = {'voiceLength': voiceLength}
await self.send_group(file_box, group_id)
# 清除缓存文件
os.remove(mp3_file)
os.remove(silk_file)

async def _do_send_group_img(self, query, reply_room_id):
try:
if not query:
return
context = Context(ContextType.IMAGE_CREATE, query)
img_url = super().build_reply_content(query, context).content
if not img_url:
return
# 图片发送
logger.info('[WX] sendImage, receiver={}'.format(reply_room_id))
t = int(time.time())
file_box = FileBox.from_url(url=img_url, name=str(t) + '.png')
await self.send_group(file_box, reply_room_id)
except Exception as e: except Exception as e:
logger.exception(e)
def check_prefix(self, content, prefix_list):
for prefix in prefix_list:
if content.startswith(prefix):
return prefix
return None
def check_contain(self, content, keyword_list):
if not keyword_list:
return None
for ky in keyword_list:
if content.find(ky) != -1:
return True
return None
logger.exception('[WX] {}'.format(e))
return
logger.debug('[WX] message:{}'.format(cmsg))
room = msg.room() # 获取消息来自的群聊. 如果消息不是来自群聊, 则返回None
isgroup = room is not None
ctype = cmsg.ctype
context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg)
if context:
logger.info('[WX] receiveMsg={}, context={}'.format(cmsg, context))
thread_pool.submit(self._handle_loop, context, asyncio.get_event_loop()).add_done_callback(thread_pool_callback)

def _handle_loop(self,context,loop):
asyncio.set_event_loop(loop)
self._handle(context)

+ 26
- 6
channel/wechat/wechaty_message.py View File

@@ -1,3 +1,5 @@
import asyncio
import re
from wechaty import MessageType from wechaty import MessageType
from bridge.context import ContextType from bridge.context import ContextType
from channel.chat_message import ChatMessage from channel.chat_message import ChatMessage
@@ -17,7 +19,6 @@ class aobject(object):


async def __init__(self): async def __init__(self):
pass pass

class WechatyMessage(ChatMessage, aobject): class WechatyMessage(ChatMessage, aobject):


async def __init__(self, wechaty_msg: Message): async def __init__(self, wechaty_msg: Message):
@@ -36,13 +37,23 @@ class WechatyMessage(ChatMessage, aobject):
self.ctype = ContextType.VOICE self.ctype = ContextType.VOICE
voice_file = await wechaty_msg.to_file_box() voice_file = await wechaty_msg.to_file_box()
self.content = TmpDir().path() + voice_file.name # content直接存临时目录路径 self.content = TmpDir().path() + voice_file.name # content直接存临时目录路径
self._prepare_fn = lambda: voice_file.to_file(self.content)

def func():
loop = asyncio.get_event_loop()
asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content),loop).result()
self._prepare_fn = func
else: else:
raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type())) raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type()))
from_contact = wechaty_msg.talker() # 获取消息的发送者 from_contact = wechaty_msg.talker() # 获取消息的发送者
self.from_user_id = from_contact.contact_id self.from_user_id = from_contact.contact_id
self.from_user_nickname = from_contact.name self.from_user_nickname = from_contact.name

# group中的from和to,wechaty跟itchat含义不一样
# wecahty: from是消息实际发送者, to:所在群
# itchat: 如果是你发送群消息,from和to是你自己和所在群,如果是别人发群消息,from和to是所在群和你自己
# 但这个差别不影响逻辑,group中只使用到:1.用from来判断是否是自己发的,2.actual_user_id来判断实际发送用户
if self.is_group: if self.is_group:
self.to_user_id = room.room_id self.to_user_id = room.room_id
@@ -52,14 +63,23 @@ class WechatyMessage(ChatMessage, aobject):
self.to_user_id = to_contact.contact_id self.to_user_id = to_contact.contact_id
self.to_user_nickname = to_contact.name self.to_user_nickname = to_contact.name


if wechaty_msg.is_self():
if self.is_group or wechaty_msg.is_self(): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。
self.other_user_id = self.to_user_id self.other_user_id = self.to_user_id
self.other_user_nickname = self.to_user_nickname self.other_user_nickname = self.to_user_nickname
else: else:
self.other_user_id = self.from_user_id self.other_user_id = self.from_user_id
self.other_user_nickname = self.from_user_nickname self.other_user_nickname = self.from_user_nickname


if self.is_group:

if self.is_group: # wechaty群聊中,实际发送用户就是from_user
self.is_at = await wechaty_msg.mention_self() self.is_at = await wechaty_msg.mention_self()
self.actual_user_id = self.other_user_id
self.actual_user_nickname = self.other_user_nickname
if not self.is_at: # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx,这里做一下兼容
name = wechaty_msg.wechaty.user_self().name
pattern = f'@{name}(\u2005|\u0020)'
if re.search(pattern,self.content):
logger.debug(f'wechaty message {self.msg_id} include at')
self.is_at = True

self.actual_user_id = self.from_user_id
self.actual_user_nickname = self.from_user_nickname

+ 10
- 0
voice/audio_convert.py View File

@@ -21,6 +21,16 @@ def mp3_to_wav(mp3_path, wav_path):
audio = AudioSegment.from_mp3(mp3_path) audio = AudioSegment.from_mp3(mp3_path)
audio.export(wav_path, format="wav") audio.export(wav_path, format="wav")


def any_to_wav(any_path, wav_path):
"""
把任意格式转成wav文件
"""
if any_path.endswith('.wav'):
return
if any_path.endswith('.sil') or any_path.endswith('.silk') or any_path.endswith('.slk'):
return sil_to_wav(any_path, wav_path)
audio = AudioSegment.from_file(any_path)
audio.export(wav_path, format="wav")


def pcm_to_silk(pcm_path, silk_path): def pcm_to_silk(pcm_path, silk_path):
""" """


Loading…
Cancel
Save