diff --git a/bridge/context.py b/bridge/context.py index 1fbe4d4..50be100 100644 --- a/bridge/context.py +++ b/bridge/context.py @@ -14,6 +14,15 @@ class Context: self.type = type self.content = content self.kwargs = kwargs + + def __contains__(self, key): + if key == 'type': + return self.type is not None + elif key == 'content': + return self.content is not None + else: + return key in self.kwargs + def __getitem__(self, key): if key == 'type': return self.type @@ -21,6 +30,12 @@ class Context: return self.content else: return self.kwargs[key] + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default def __setitem__(self, key, value): if key == 'type': diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index f789ec2..2958a88 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -90,6 +90,8 @@ class WechatChannel(Channel): # isgroup: 是否是群聊 # receiver: 需要回复的对象 # msg: itchat的原始消息对象 + # origin_ctype: 原始消息类型,用于私聊语音消息时,避免匹配前缀 + # desire_rtype: 希望回复类型,TEXT类型是文本回复,VOICE类型是语音回复 def handle_voice(self, msg): if conf().get('speech_recognition') != True: @@ -106,9 +108,9 @@ class WechatChannel(Channel): else: other_user_id = from_user_id if from_user_id == other_user_id: - context = Context(ContextType.VOICE,msg['FileName']) - context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id} - thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) + context = self._compose_context(ContextType.VOICE, msg['FileName'], isgroup=False, msg=msg, receiver=other_user_id, session_id=other_user_id) + if context: + thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) @time_checker def handle_text(self, msg): @@ -125,30 +127,16 @@ class WechatChannel(Channel): else: other_user_id = from_user_id create_time = msg['CreateTime'] # 消息时间 - match_prefix = check_prefix(content, conf().get('single_chat_prefix')) if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息 logger.debug("[WX]history message skipped") return if "」\n- - - - - - - - - - - - - - -" in content: logger.debug("[WX]reference query skipped") return - if match_prefix: - content = content.replace(match_prefix, '', 1).strip() - elif match_prefix is None: - return - context = Context() - context.kwargs = {'isgroup': False, 'msg': msg, - 'receiver': other_user_id, 'session_id': other_user_id} - - img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) - if img_match_prefix: - content = content.replace(img_match_prefix, '', 1).strip() - context.type = ContextType.IMAGE_CREATE - else: - context.type = ContextType.TEXT - - context.content = content - thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) + + context = self._compose_context(ContextType.TEXT, content, isgroup=False, msg=msg, receiver=other_user_id, session_id=other_user_id) + if context: + thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) @time_checker def handle_group(self, msg): @@ -172,30 +160,19 @@ class WechatChannel(Channel): if "」\n- - - - - - - - - - - - - - -" in content: logger.debug("[WX]reference query skipped") return "" - config = conf() - match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or check_prefix(origin_content, config.get('group_chat_prefix')) \ - or check_contain(origin_content, config.get('group_chat_keyword')) - if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix: - context = Context() - context.kwargs = { 'isgroup': True, 'msg': msg, 'receiver': group_id} - img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) - if img_match_prefix: - content = content.replace(img_match_prefix, '', 1).strip() - context.type = ContextType.IMAGE_CREATE - else: - context.type = ContextType.TEXT - context.content = content + config = conf() + group_name_white_list = config.get('group_name_white_list', []) + group_name_keyword_white_list = config.get('group_name_keyword_white_list', []) + if any([group_name in group_name_white_list, 'ALL_GROUP' in group_name_white_list, check_contain(group_name, group_name_keyword_white_list), msg['IsAt'] and not config.get("group_at_off", False)]): group_chat_in_one_session = conf().get('group_chat_in_one_session', []) - if ('ALL_GROUP' in group_chat_in_one_session or - group_name in group_chat_in_one_session or - check_contain(group_name, group_chat_in_one_session)): - context['session_id'] = group_id - else: - context['session_id'] = msg['ActualUserName'] - - thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) + session_id = msg['ActualUserName'] + if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]): + session_id = group_id + context = self._compose_context(ContextType.TEXT, content, isgroup=True, msg=msg, receiver=group_id, session_id=session_id) + if context: + thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) def handle_group_voice(self, msg): if conf().get('group_speech_recognition', False) != True: @@ -210,20 +187,57 @@ class WechatChannel(Channel): # 验证群名 if not group_name: return "" - if ('ALL_GROUP' in conf().get('group_name_white_list') or group_name in conf().get('group_name_white_list') or check_contain(group_name, conf().get('group_name_keyword_white_list'))): - context = Context(ContextType.VOICE,msg['FileName']) - context.kwargs = {'isgroup': True, 'msg': msg, 'receiver': group_id} - + + config = conf() + group_name_white_list = config.get('group_name_white_list', []) + group_name_keyword_white_list = config.get('group_name_keyword_white_list', []) + if any([group_name in group_name_white_list, 'ALL_GROUP' in group_name_white_list, check_contain(group_name, group_name_keyword_white_list)]): group_chat_in_one_session = conf().get('group_chat_in_one_session', []) - if ('ALL_GROUP' in group_chat_in_one_session or - group_name in group_chat_in_one_session or - check_contain(group_name, group_chat_in_one_session)): - context['session_id'] = group_id + session_id =msg['ActualUserName'] + if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]): + session_id = group_id + context = self._compose_context(ContextType.VOICE, msg['FileName'], isgroup=True, msg=msg, receiver=group_id, session_id=session_id) + if context: + thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) + + def _compose_context(self, ctype: ContextType, content, **kwargs): + context = Context(ctype, content) + context.kwargs = kwargs + if 'origin_ctype' not in context: + context['origin_ctype'] = ctype + + if ctype == ContextType.TEXT: + if context["isgroup"]: # 群聊 + # 校验关键字 + match_prefix = check_prefix(content, conf().get('group_chat_prefix')) + match_contain = check_contain(content, conf().get('group_chat_keyword')) + if match_prefix is not None or match_contain is not None: + # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容,用于实现类似自定义+前缀触发生成AI图片的功能 + if match_prefix: + content = content.replace(match_prefix, '', 1).strip() + elif context["origin_ctype"] == ContextType.VOICE: + logger.info("[WX]receive group voice, checkprefix didn't match") + return None + else: # 单聊 + match_prefix = check_prefix(content, conf().get('single_chat_prefix')) + if match_prefix: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容 + content = content.replace(match_prefix, '', 1).strip() + elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,不匹配前缀,直接返回 + pass + else: + return None + img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) + if img_match_prefix: + content = content.replace(img_match_prefix, '', 1).strip() + context.type = ContextType.IMAGE_CREATE else: - context['session_id'] = msg['ActualUserName'] - - thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) - + context.type = ContextType.TEXT + context.content = content + elif context.type == ContextType.VOICE: + if 'desire_rtype' not in context and conf().get('voice_reply_voice'): + context['desire_rtype'] = ReplyType.VOICE + return context + # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 def send(self, reply: Reply, receiver, retry_cnt = 0): try: @@ -257,23 +271,29 @@ class WechatChannel(Channel): self.send(reply, receiver, retry_cnt + 1) # 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类 - def handle(self, context): - if not context.content: - return - - reply = Reply() - + def handle(self, context: Context): + if context is None or not context.content: + return logger.debug('[WX] ready to handle context: {}'.format(context)) - # reply的构建步骤 + reply = self._generate_reply(context) + + logger.debug('[WX] ready to decorate reply: {}'.format(reply)) + # reply的包装步骤 + reply = self._decorate_reply(context, reply) + + # reply的发送步骤 + self._send_reply(context, reply) + + def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply: e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, { 'channel': self, 'context': context, 'reply': reply})) reply = e_context['reply'] if not e_context.is_pass(): logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content)) - if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息 + if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息 reply = super().build_reply_content(context.content, context) - elif context.type == ContextType.VOICE: # 语音消息 + elif context.type == ContextType.VOICE: # 语音消息 msg = context['msg'] mp3_path = TmpDir().path() + context.content msg.download(mp3_path) @@ -281,7 +301,7 @@ class WechatChannel(Channel): wav_path = os.path.splitext(mp3_path)[0] + '.wav' try: mp3_to_wav(mp3_path=mp3_path, wav_path=wav_path) - except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别 + except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别 logger.warning("[WX]mp3 to wav error, use mp3 path. " + str(e)) wav_path = mp3_path # 语音识别 @@ -293,50 +313,28 @@ class WechatChannel(Channel): except Exception as e: logger.warning("[WX]delete temp file error: " + str(e)) - if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO: - content = reply.content # 语音转文字后,将文字内容作为新的context - context.type = ContextType.TEXT - if context["isgroup"]: # 群聊 - # 校验关键字 - match_prefix = check_prefix(content, conf().get('group_chat_prefix')) - match_contain = check_contain(content, conf().get('group_chat_keyword')) - if match_prefix is not None or match_contain is not None: - # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容,用于实现类似自定义+前缀触发生成AI图片的功能 - if match_prefix: - content = content.replace(match_prefix, '', 1).strip() - else: - logger.info("[WX]receive voice, checkprefix didn't match") - return - else: # 单聊 - match_prefix = check_prefix(content, conf().get('single_chat_prefix')) - if match_prefix: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容 - content = content.replace(match_prefix, '', 1).strip() - - img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) - if img_match_prefix: - content = content.replace(img_match_prefix, '', 1).strip() - context.type = ContextType.IMAGE_CREATE - else: - context.type = ContextType.TEXT - context.content = content - reply = super().build_reply_content(context.content, context) - if reply.type == ReplyType.TEXT: - if conf().get('voice_reply_voice'): - reply = super().build_text_to_voice(reply.content) + if reply.type == ReplyType.TEXT: + new_context = self._compose_context( + ContextType.TEXT, reply.content, **context.kwargs) + if new_context: + reply = self._generate_reply(new_context) else: logger.error('[WX] unknown context type: {}'.format(context.type)) return + return reply - logger.debug('[WX] ready to decorate reply: {}'.format(reply)) - - # reply的包装步骤 + def _decorate_reply(self, context: Context, reply: Reply) -> Reply: if reply and reply.type: e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, { 'channel': self, 'context': context, 'reply': reply})) reply = e_context['reply'] + desire_rtype = context.get('desire_rtype') if not e_context.is_pass() and reply and reply.type: if reply.type == ReplyType.TEXT: reply_text = reply.content + if desire_rtype == ReplyType.VOICE: + reply = super().build_text_to_voice(reply.content) + return self._decorate_reply(context, reply) if context['isgroup']: reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip() reply_text = conf().get("group_chat_reply_prefix", "")+reply_text @@ -350,8 +348,11 @@ class WechatChannel(Channel): else: logger.error('[WX] unknown reply type: {}'.format(reply.type)) return + if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]: + logger.warning('[WX] desire_rtype: {}, but reply type: {}'.format(context.get('desire_rtype'), reply.type)) + return reply - # reply的发送步骤 + def _send_reply(self, context: Context, reply: Reply): if reply and reply.type: e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, { 'channel': self, 'context': context, 'reply': reply})) @@ -360,6 +361,7 @@ class WechatChannel(Channel): logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver'])) self.send(reply, context['receiver']) + def check_prefix(content, prefix_list): for prefix in prefix_list: if content.startswith(prefix):