diff --git a/app.py b/app.py
index 7d42b9d..c78a72c 100644
--- a/app.py
+++ b/app.py
@@ -1,6 +1,6 @@
 # encoding:utf-8
 
-import config
+from config import conf, load_config
 from channel import channel_factory
 from common.log import logger
 
@@ -9,10 +9,10 @@ from plugins import *
 def run():
     try:
         # load config
-        config.load_config()
+        load_config()
 
         # create channel
-        channel_name='wx'
+        channel_name=conf().get('channel_type', 'wx')
         channel = channel_factory.create_channel(channel_name)
         if channel_name=='wx':
             PluginManager().load_plugins()
diff --git a/bot/bot_factory.py b/bot/bot_factory.py
index 06df336..cf9cfe7 100644
--- a/bot/bot_factory.py
+++ b/bot/bot_factory.py
@@ -6,9 +6,9 @@ from common import const
 
 def create_bot(bot_type):
     """
-    create a channel instance
-    :param channel_type: channel type code
-    :return: channel instance
+    create a bot_type instance
+    :param bot_type: bot type code
+    :return: bot instance
    """
    if bot_type == const.BAIDU:
        # Baidu Unit对话接口
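With this change startup is config-driven instead of hardcoded to itchat. A minimal usage sketch, assuming a valid ./config.json next to app.py (all names are the repo's own):

    from config import conf, load_config
    from channel import channel_factory

    load_config()
    # 'channel_type' falls back to 'wx' (itchat) when absent from config.json
    channel = channel_factory.create_channel(conf().get('channel_type', 'wx'))
    channel.startup()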
diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py
index 4e06248..c427407 100644
--- a/channel/wechat/wechat_channel.py
+++ b/channel/wechat/wechat_channel.py
@@ -5,6 +5,9 @@ wechat channel
 """
 
 import os
+import requests
+import io
+import time
 from lib import itchat
 import json
 from lib.itchat.content import *
@@ -17,17 +20,18 @@ from common.tmp_dir import TmpDir
 from config import conf
 from common.time_check import time_checker
 from plugins import *
-import requests
-import io
-import time
+from voice.audio_convert import mp3_to_wav
 
 thread_pool = ThreadPoolExecutor(max_workers=8)
 
+
 def thread_pool_callback(worker):
     worker_exception = worker.exception()
     if worker_exception:
         logger.exception("Worker return exception: {}".format(worker_exception))
 
+
 @itchat.msg_register(TEXT)
 def handler_single_msg(msg):
     WechatChannel().handle_text(msg)
@@ -48,6 +52,8 @@ def handler_group_voice(msg):
     WechatChannel().handle_group_voice(msg)
     return None
 
+
+
 class WechatChannel(Channel):
     def __init__(self):
         self.userName = None
@@ -55,7 +61,7 @@ class WechatChannel(Channel):
 
     def startup(self):
-        itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
+        itchat.instance.receivingRetryCount = 600  # 修改断线超时时间
         # login by scan QRCode
         hotReload = conf().get('hot_reload', False)
         try:
@@ -119,7 +125,7 @@
         other_user_id = from_user_id
         create_time = msg['CreateTime']  # 消息时间
         match_prefix = check_prefix(content, conf().get('single_chat_prefix'))
-        if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: #跳过1分钟前的历史消息
+        if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60:  # 跳过1分钟前的历史消息
             logger.debug("[WX]history message skipped")
             return
         if "」\n- - - - - - - - - - - - - - -" in content:
@@ -130,7 +136,8 @@
         elif match_prefix is None:
             return
         context = Context()
-        context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id}
+        context.kwargs = {'isgroup': False, 'msg': msg,
+                          'receiver': other_user_id, 'session_id': other_user_id}
 
         img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
         if img_match_prefix:
@@ -148,7 +155,7 @@
         group_name = msg['User'].get('NickName', None)
         group_id = msg['User'].get('UserName', None)
         create_time = msg['CreateTime']  # 消息时间
-        if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: #跳过1分钟前的历史消息
+        if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60:  # 跳过1分钟前的历史消息
             logger.debug("[WX]history group message skipped")
             return
         if not group_name:
@@ -166,11 +173,11 @@
             return ""
         config = conf()
         match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or check_prefix(origin_content, config.get('group_chat_prefix')) \
-            or check_contain(origin_content, config.get('group_chat_keyword'))
+                       or check_contain(origin_content, config.get('group_chat_keyword'))
         if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix:
             context = Context()
             context.kwargs = { 'isgroup': True, 'msg': msg, 'receiver': group_id}
-            
+
             img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
             if img_match_prefix:
                 content = content.replace(img_match_prefix, '', 1).strip()
@@ -217,7 +224,7 @@
         thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
 
     # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
-    def send(self, reply : Reply, receiver):
+    def send(self, reply: Reply, receiver):
         if reply.type == ReplyType.TEXT:
             itchat.send(reply.content, toUserName=receiver)
             logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
@@ -250,9 +257,10 @@
 
         reply = Reply()
         logger.debug('[WX] ready to handle context: {}'.format(context))
-        
+
         # reply的构建步骤
-        e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply}))
+        e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
+            'channel': self, 'context': context, 'reply': reply}))
         reply = e_context['reply']
         if not e_context.is_pass():
             logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content))
@@ -260,22 +268,31 @@
                 reply = super().build_reply_content(context.content, context)
             elif context.type == ContextType.VOICE:  # 语音消息
                 msg = context['msg']
-                file_name = TmpDir().path() + context.content
-                msg.download(file_name)
-                reply = super().build_voice_to_text(file_name)
-                if reply.type == ReplyType.TEXT:
-                    content = reply.content # 语音转文字后,将文字内容作为新的context
-                    # 如果是群消息,判断是否触发关键字
-                    if context['isgroup']:
+                mp3_path = TmpDir().path() + context.content
+                msg.download(mp3_path)
+                # mp3转wav
+                wav_path = os.path.splitext(mp3_path)[0] + '.wav'
+                mp3_to_wav(mp3_path=mp3_path, wav_path=wav_path)
+                # 语音识别
+                reply = super().build_voice_to_text(wav_path)
+                # 删除临时文件
+                os.remove(wav_path)
+                os.remove(mp3_path)
+                if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO:
+                    content = reply.content  # 语音转文字后,将文字内容作为新的context
+                    context.type = ContextType.TEXT
+                    if context["isgroup"]:
+                        # 校验关键字
                         match_prefix = check_prefix(content, conf().get('group_chat_prefix'))
                         match_contain = check_contain(content, conf().get('group_chat_keyword'))
-                        logger.debug('[WX] group chat prefix match: {}'.format(match_prefix))
-                        if match_prefix is None and match_contain is None:
-                            return
-                        else:
+                        if match_prefix is not None or match_contain is not None:
+                            # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容,用于实现类似自定义+前缀触发生成AI图片的功能
                             if match_prefix:
                                 content = content.replace(match_prefix, '', 1).strip()
-                    
+                        else:
+                            logger.info("[WX]receive voice, checkprefix didn't match")
+                            return
+
                     img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
                     if img_match_prefix:
                         content = content.replace(img_match_prefix, '', 1).strip()
@@ -292,11 +309,12 @@ class WechatChannel(Channel):
                     return
 
         logger.debug('[WX] ready to decorate reply: {}'.format(reply))
-        
+
         # reply的包装步骤
         if reply and reply.type:
-            e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
-            reply=e_context['reply']
+            e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {
+                'channel': self, 'context': context, 'reply': reply}))
+            reply = e_context['reply']
             if not e_context.is_pass() and reply and reply.type:
                 if reply.type == ReplyType.TEXT:
                     reply_text = reply.content
@@ -314,10 +332,11 @@ class WechatChannel(Channel):
                     logger.error('[WX] unknown reply type: {}'.format(reply.type))
                     return
 
-        # reply的发送步骤
+        # reply的发送步骤
         if reply and reply.type:
-            e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
-            reply=e_context['reply']
+            e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {
+                'channel': self, 'context': context, 'reply': reply}))
+            reply = e_context['reply']
             if not e_context.is_pass() and reply and reply.type:
                 logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver']))
                 self.send(reply, context['receiver'])
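The VOICE branch above now normalizes itchat's mp3 download to wav before speech recognition, since engines such as Baidu expect pcm/wav input. The same flow extracted as a standalone sketch (the tmp path is hypothetical; everything else is from this diff):

    import os
    from voice.audio_convert import mp3_to_wav

    mp3_path = '/tmp/voice_123.mp3'    # written by msg.download(mp3_path)
    wav_path = os.path.splitext(mp3_path)[0] + '.wav'
    mp3_to_wav(mp3_path=mp3_path, wav_path=wav_path)
    # reply = super().build_voice_to_text(wav_path)   # inside WechatChannel.handle
    os.remove(wav_path)                # both temp files are deleted afterwards
    os.remove(mp3_path)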
diff --git a/channel/wechat/wechaty_channel.py b/channel/wechat/wechaty_channel.py
index 1caf5b4..ac6717c 100644
--- a/channel/wechat/wechaty_channel.py
+++ b/channel/wechat/wechaty_channel.py
@@ -4,25 +4,19 @@ wechaty channel
 Python Wechaty - https://github.com/wechaty/python-wechaty
 """
 
-import io
 import os
-import json
 import time
 import asyncio
-import requests
-import pysilk
-import wave
-from pydub import AudioSegment
 from typing import Optional, Union
 from bridge.context import Context, ContextType
 from wechaty_puppet import MessageType, FileBox, ScanStatus  # type: ignore
 from wechaty import Wechaty, Contact
-from wechaty.user import Message, Room, MiniProgram, UrlLink
+from wechaty.user import Message, MiniProgram, UrlLink
 from channel.channel import Channel
 from common.log import logger
 from common.tmp_dir import TmpDir
 from config import conf
-
+from voice.audio_convert import sil_to_wav, mp3_to_sil
 
 
 class WechatyChannel(Channel):
@@ -50,8 +44,9 @@ class WechatyChannel(Channel):
 
     async def on_scan(self, status: ScanStatus, qr_code: Optional[str] = None,
                       data: Optional[str] = None):
-        contact = self.Contact.load(self.contact_id)
-        logger.info('[WX] scan user={}, scan status={}, scan qr_code={}'.format(contact, status.name, qr_code))
+        pass
+        # contact = self.Contact.load(self.contact_id)
+        # logger.info('[WX] scan user={}, scan status={}, scan qr_code={}'.format(contact, status.name, qr_code))
         # print(f'user <{contact}> scan status: {status.name} , 'f'qr_code: {qr_code}')
 
     async def on_message(self, msg: Message):
@@ -67,7 +62,7 @@ class WechatyChannel(Channel):
         content = msg.text()
         mention_content = await msg.mention_text()  # 返回过滤掉@name后的消息
         match_prefix = self.check_prefix(content, conf().get('single_chat_prefix'))
-        conversation: Union[Room, Contact] = from_contact if room is None else room
+        # conversation: Union[Room, Contact] = from_contact if room is None else room
 
         if room is None and msg.type() == MessageType.MESSAGE_TYPE_TEXT:
             if not msg.is_self() and match_prefix is not None:
@@ -102,21 +97,8 @@ class WechatyChannel(Channel):
                     await voice_file.to_file(silk_file)
                     logger.info("[WX]receive voice file: " + silk_file)
                     # 将文件转成wav格式音频
-                    wav_file = silk_file.replace(".slk", ".wav")
-                    with open(silk_file, 'rb') as f:
-                        silk_data = f.read()
-                    pcm_data = pysilk.decode(silk_data)
-
-                    with wave.open(wav_file, 'wb') as wav_data:
-                        wav_data.setnchannels(1)
-                        wav_data.setsampwidth(2)
-                        wav_data.setframerate(24000)
-                        wav_data.writeframes(pcm_data)
-                    if os.path.exists(wav_file):
-                        converter_state = "true"  # 转换wav成功
-                    else:
-                        converter_state = "false"  # 转换wav失败
-                    logger.info("[WX]receive voice converter: " + converter_state)
+                    wav_file = os.path.splitext(silk_file)[0] + '.wav'
+                    sil_to_wav(silk_file, wav_file)
                     # 语音识别为文本
                     query = super().build_voice_to_text(wav_file).content
                     # 交验关键字
@@ -183,21 +165,8 @@ class WechatyChannel(Channel):
                     await voice_file.to_file(silk_file)
                     logger.info("[WX]receive voice file: " + silk_file)
                     # 将文件转成wav格式音频
-                    wav_file = silk_file.replace(".slk", ".wav")
-                    with open(silk_file, 'rb') as f:
-                        silk_data = f.read()
-                    pcm_data = pysilk.decode(silk_data)
-
-                    with wave.open(wav_file, 'wb') as wav_data:
-                        wav_data.setnchannels(1)
-                        wav_data.setsampwidth(2)
-                        wav_data.setframerate(24000)
-                        wav_data.writeframes(pcm_data)
-                    if os.path.exists(wav_file):
-                        converter_state = "true"  # 转换wav成功
-                    else:
-                        converter_state = "false"  # 转换wav失败
-                    logger.info("[WX]receive voice converter: " + converter_state)
+                    wav_file = os.path.splitext(silk_file)[0] + '.wav'
+                    sil_to_wav(silk_file, wav_file)
                     # 语音识别为文本
                     query = super().build_voice_to_text(wav_file).content
                     # 校验关键字
@@ -260,21 +229,12 @@ class WechatyChannel(Channel):
                 if reply_text:
                     # 转换 mp3 文件为 silk 格式
                     mp3_file = super().build_text_to_voice(reply_text).content
-                    silk_file = mp3_file.replace(".mp3", ".silk")
-                    # Load the MP3 file
-                    audio = AudioSegment.from_file(mp3_file, format="mp3")
-                    # Convert to WAV format
-                    audio = audio.set_frame_rate(24000).set_channels(1)
-                    wav_data = audio.raw_data
-                    sample_width = audio.sample_width
-                    # Encode to SILK format
-                    silk_data = pysilk.encode(wav_data, 24000)
-                    # Save the silk file
-                    with open(silk_file, "wb") as f:
-                        f.write(silk_data)
+                    silk_file = os.path.splitext(mp3_file)[0] + '.sil'
+                    voiceLength = mp3_to_sil(mp3_file, silk_file)
                     # 发送语音
                     t = int(time.time())
-                    file_box = FileBox.from_file(silk_file, name=str(t) + '.silk')
+                    file_box = FileBox.from_file(silk_file, name=str(t) + '.sil')
+                    file_box.metadata = {'voiceLength': voiceLength}
                     await self.send(file_box, reply_user_id)
                     # 清除缓存文件
                     os.remove(mp3_file)
@@ -337,21 +297,12 @@ class WechatyChannel(Channel):
                         reply_text = '@' + group_user_name + ' ' + reply_text.strip()
                     # 转换 mp3 文件为 silk 格式
                     mp3_file = super().build_text_to_voice(reply_text).content
-                    silk_file = mp3_file.replace(".mp3", ".silk")
-                    # Load the MP3 file
-                    audio = AudioSegment.from_file(mp3_file, format="mp3")
-                    # Convert to WAV format
-                    audio = audio.set_frame_rate(24000).set_channels(1)
-                    wav_data = audio.raw_data
-                    sample_width = audio.sample_width
-                    # Encode to SILK format
-                    silk_data = pysilk.encode(wav_data, 24000)
-                    # Save the silk file
-                    with open(silk_file, "wb") as f:
-                        f.write(silk_data)
+                    silk_file = os.path.splitext(mp3_file)[0] + '.sil'
+                    voiceLength = mp3_to_sil(mp3_file, silk_file)
                     # 发送语音
                     t = int(time.time())
                     file_box = FileBox.from_file(silk_file, name=str(t) + '.silk')
+                    file_box.metadata = {'voiceLength': voiceLength}
                     await self.send_group(file_box, group_id)
                     # 清除缓存文件
                     os.remove(mp3_file)
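Outgoing wechaty voice replies now attach a voiceLength to the FileBox metadata, the millisecond duration returned by mp3_to_sil, so the receiving side knows the clip length. A condensed sketch of the send path above (silk_file and voice_length_ms stand in for the values produced earlier in the method):

    import time
    from wechaty_puppet import FileBox

    file_box = FileBox.from_file(silk_file, name=str(int(time.time())) + '.sil')
    file_box.metadata = {'voiceLength': voice_length_ms}  # ms, from mp3_to_sil
    # await self.send(file_box, reply_user_id)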
diff --git a/config.py b/config.py
index cf67745..de2e380 100644
--- a/config.py
+++ b/config.py
@@ -5,71 +5,77 @@ import os
 from common.log import logger
 
 # 将所有可用的配置项写在字典里, 请使用小写字母
-available_setting ={
-    #openai api配置
-    "open_ai_api_key": "", # openai api key
-    "open_ai_api_base": "https://api.openai.com/v1", # openai apibase,当use_azure_chatgpt为true时,需要设置对应的api base
-    "proxy": "", # openai使用的代理
-    "model": "gpt-3.5-turbo", # chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
-    "use_azure_chatgpt": False, # 是否使用azure的chatgpt
-
-    #Bot触发配置
-    "single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
-    "single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
-    "group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
-    "group_chat_reply_prefix": "", # 群聊时自动回复的前缀
-    "group_chat_keyword": [], # 群聊时包含该关键词则会触发机器人回复
-    "group_at_off": False, # 是否关闭群聊时@bot的触发
-    "group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
-    "group_name_keyword_white_list": [], # 开启自动回复的群名称关键词列表
-    "group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
-    "image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
-
-    #chatgpt会话参数
-    "expires_in_seconds": 3600, # 无操作会话的过期时间
-    "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述
-    "conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
-
-    #chatgpt限流配置
-    "rate_limit_chatgpt": 20, # chatgpt的调用频率限制
-    "rate_limit_dalle": 50, # openai dalle的调用频率限制
-
-
-    #chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create
+available_setting = {
+    # openai api配置
+    "open_ai_api_key": "",  # openai api key
+    # openai apibase,当use_azure_chatgpt为true时,需要设置对应的api base
+    "open_ai_api_base": "https://api.openai.com/v1",
+    "proxy": "",  # openai使用的代理
+    # chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
+    "model": "gpt-3.5-turbo",
+    "use_azure_chatgpt": False,  # 是否使用azure的chatgpt
+
+    # Bot触发配置
+    "single_chat_prefix": ["bot", "@bot"],  # 私聊时文本需要包含该前缀才能触发机器人回复
+    "single_chat_reply_prefix": "[bot] ",  # 私聊时自动回复的前缀,用于区分真人
+    "group_chat_prefix": ["@bot"],  # 群聊时包含该前缀则会触发机器人回复
+    "group_chat_reply_prefix": "",  # 群聊时自动回复的前缀
+    "group_chat_keyword": [],  # 群聊时包含该关键词则会触发机器人回复
+    "group_at_off": False,  # 是否关闭群聊时@bot的触发
+    "group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"],  # 开启自动回复的群名称列表
+    "group_name_keyword_white_list": [],  # 开启自动回复的群名称关键词列表
+    "group_chat_in_one_session": ["ChatGPT测试群"],  # 支持会话上下文共享的群名称
+    "image_create_prefix": ["画", "看", "找"],  # 开启图片回复的前缀
+
+    # chatgpt会话参数
+    "expires_in_seconds": 3600,  # 无操作会话的过期时间
+    "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",  # 人格描述
+    "conversation_max_tokens": 1000,  # 支持上下文记忆的最多字符数
+
+    # chatgpt限流配置
+    "rate_limit_chatgpt": 20,  # chatgpt的调用频率限制
+    "rate_limit_dalle": 50,  # openai dalle的调用频率限制
+
+
+    # chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create
     "temperature": 0.9,
     "top_p": 1,
     "frequency_penalty": 0,
     "presence_penalty": 0,
 
-    #语音设置
-    "speech_recognition": False, # 是否开启语音识别
-    "group_speech_recognition": False, # 是否开启群组语音识别
-    "voice_reply_voice": False, # 是否使用语音回复语音,需要设置对应语音合成引擎的api key
-    "voice_to_text": "openai", # 语音识别引擎,支持openai和google
-    "text_to_voice": "baidu", # 语音合成引擎,支持baidu和google
+    # 语音设置
+    "speech_recognition": False,  # 是否开启语音识别
+    "group_speech_recognition": False,  # 是否开启群组语音识别
+    "voice_reply_voice": False,  # 是否使用语音回复语音,需要设置对应语音合成引擎的api key
+    "voice_to_text": "openai",  # 语音识别引擎,支持openai,google
+    "text_to_voice": "baidu",  # 语音合成引擎,支持baidu,google,pytts(offline)
 
     # baidu api的配置, 使用百度语音识别和语音合成时需要
-    'baidu_app_id': "",
-    'baidu_api_key': "",
-    'baidu_secret_key': "",
+    "baidu_app_id": "",
+    "baidu_api_key": "",
+    "baidu_secret_key": "",
+    # 1536普通话(支持简单的英文识别) 1737英语 1637粤语 1837四川话 1936普通话远场
+    "baidu_dev_pid": "1536",
 
-    #服务时间限制,目前支持itchat
-    "chat_time_module": False, # 是否开启服务时间限制
-    "chat_start_time": "00:00", # 服务开始时间
-    "chat_stop_time": "24:00", # 服务结束时间
+    # 服务时间限制,目前支持itchat
+    "chat_time_module": False,  # 是否开启服务时间限制
+    "chat_start_time": "00:00",  # 服务开始时间
+    "chat_stop_time": "24:00",  # 服务结束时间
 
     # itchat的配置
-    "hot_reload": False, # 是否开启热重载
+    "hot_reload": False,  # 是否开启热重载
 
     # wechaty的配置
-    "wechaty_puppet_service_token": "", # wechaty的token
+    "wechaty_puppet_service_token": "",  # wechaty的token
 
     # chatgpt指令自定义触发词
-    "clear_memory_commands": ['#清除记忆'], # 重置会话指令
+    "clear_memory_commands": ['#清除记忆'],  # 重置会话指令
+
+    "channel_type": "wx",  # 通道类型,支持wx,wxy和terminal
 }
 
+
 class Config(dict):
     def __getitem__(self, key):
         if key not in available_setting:
@@ -82,15 +88,17 @@
         return super().__setitem__(key, value)
 
     def get(self, key, default=None):
-        try :
+        try:
             return self[key]
         except KeyError as e:
             return default
         except Exception as e:
             raise e
-    
+
+
 config = Config()
 
+
 def load_config():
     global config
     config_path = "./config.json"
@@ -109,7 +117,8 @@ def load_config():
     for name, value in os.environ.items():
         name = name.lower()
         if name in available_setting:
-            logger.info("[INIT] override config by environ args: {}={}".format(name, value))
+            logger.info(
+                "[INIT] override config by environ args: {}={}".format(name, value))
             try:
                 config[name] = eval(value)
             except:
@@ -118,9 +127,8 @@ def load_config():
     logger.info("[INIT] load config: {}".format(config))
 
 
-
 def get_root():
-    return os.path.dirname(os.path.abspath( __file__ ))
+    return os.path.dirname(os.path.abspath(__file__))
 
 
 def read_file(path):
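Config is a dict that whitelists keys against available_setting, and get() falls back to the default only for known-but-unset keys. A typical read, assuming config.json has been created from the template above (actual values depend on that file and on environment overrides):

    from config import conf, load_config

    load_config()                       # reads ./config.json, then os.environ overrides
    if conf().get('speech_recognition', False):
        # engine name handed to voice_factory.create_voice()
        engine = conf().get('voice_to_text', 'openai')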
diff --git a/requirements.txt b/requirements.txt
index 98dd5af..ca08960 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,15 @@
-itchat-uos==1.5.0.dev0
-openai
-wechaty
+openai>=0.27.2
+baidu_aip>=4.16.10
+gTTS>=2.3.1
+HTMLParser>=0.0.2
+pydub>=0.25.1
+PyQRCode>=1.2.1
+pysilk>=0.0.1
+pysilk_mod>=1.6.0
+pyttsx3>=2.90
+requests>=2.28.2
+SpeechRecognition>=3.10.0
+tiktoken>=0.3.2
+webuiapi>=0.6.2
+wechaty>=0.10.7
+wechaty_puppet>=0.4.23
\ No newline at end of file
"24:00", # 服务结束时间 + # 服务时间限制,目前支持itchat + "chat_time_module": False, # 是否开启服务时间限制 + "chat_start_time": "00:00", # 服务开始时间 + "chat_stop_time": "24:00", # 服务结束时间 # itchat的配置 - "hot_reload": False, # 是否开启热重载 + "hot_reload": False, # 是否开启热重载 # wechaty的配置 - "wechaty_puppet_service_token": "", # wechaty的token + "wechaty_puppet_service_token": "", # wechaty的token # chatgpt指令自定义触发词 - "clear_memory_commands": ['#清除记忆'], # 重置会话指令 + "clear_memory_commands": ['#清除记忆'], # 重置会话指令 + "channel_type": "wx", # 通道类型,支持wx,wxy和terminal } + class Config(dict): def __getitem__(self, key): if key not in available_setting: @@ -82,15 +88,17 @@ class Config(dict): return super().__setitem__(key, value) def get(self, key, default=None): - try : + try: return self[key] except KeyError as e: return default except Exception as e: raise e - + + config = Config() + def load_config(): global config config_path = "./config.json" @@ -109,7 +117,8 @@ def load_config(): for name, value in os.environ.items(): name = name.lower() if name in available_setting: - logger.info("[INIT] override config by environ args: {}={}".format(name, value)) + logger.info( + "[INIT] override config by environ args: {}={}".format(name, value)) try: config[name] = eval(value) except: @@ -118,9 +127,8 @@ def load_config(): logger.info("[INIT] load config: {}".format(config)) - def get_root(): - return os.path.dirname(os.path.abspath( __file__ )) + return os.path.dirname(os.path.abspath(__file__)) def read_file(path): diff --git a/requirements.txt b/requirements.txt index 98dd5af..ca08960 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,15 @@ -itchat-uos==1.5.0.dev0 -openai -wechaty +openai>=0.27.2 +baidu_aip>=4.16.10 +gTTS>=2.3.1 +HTMLParser>=0.0.2 +pydub>=0.25.1 +PyQRCode>=1.2.1 +pysilk>=0.0.1 +pysilk_mod>=1.6.0 +pyttsx3>=2.90 +requests>=2.28.2 +SpeechRecognition>=3.10.0 +tiktoken>=0.3.2 +webuiapi>=0.6.2 +wechaty>=0.10.7 +wechaty_puppet>=0.4.23 \ No newline at end of file diff --git a/voice/audio_convert.py b/voice/audio_convert.py new file mode 100644 index 0000000..73d18c8 --- /dev/null +++ b/voice/audio_convert.py @@ -0,0 +1,60 @@ +import wave +import pysilk +from pydub import AudioSegment + + +def get_pcm_from_wav(wav_path): + """ + 从 wav 文件中读取 pcm + + :param wav_path: wav 文件路径 + :returns: pcm 数据 + """ + wav = wave.open(wav_path, "rb") + return wav.readframes(wav.getnframes()) + + +def mp3_to_wav(mp3_path, wav_path): + """ + 把mp3格式转成pcm文件 + """ + audio = AudioSegment.from_mp3(mp3_path) + audio.export(wav_path, format="wav") + + +def pcm_to_silk(pcm_path, silk_path): + """ + wav 文件转成 silk + return 声音长度,毫秒 + """ + audio = AudioSegment.from_wav(pcm_path) + wav_data = audio.raw_data + silk_data = pysilk.encode( + wav_data, data_rate=audio.frame_rate, sample_rate=audio.frame_rate) + with open(silk_path, "wb") as f: + f.write(silk_data) + return audio.duration_seconds * 1000 + + +def mp3_to_sil(mp3_path, silk_path): + """ + mp3 文件转成 silk + return 声音长度,毫秒 + """ + audio = AudioSegment.from_mp3(mp3_path) + wav_data = audio.raw_data + silk_data = pysilk.encode( + wav_data, data_rate=audio.frame_rate, sample_rate=audio.frame_rate) + # Save the silk file + with open(silk_path, "wb") as f: + f.write(silk_data) + return audio.duration_seconds * 1000 + + +def sil_to_wav(silk_path, wav_path, rate: int = 24000): + """ + silk 文件转 wav + """ + wav_data = pysilk.decode_file(silk_path, to_wav=True, sample_rate=rate) + with open(wav_path, "wb") as f: + f.write(wav_data) diff --git a/voice/baidu/baidu_voice.py b/voice/baidu/baidu_voice.py index 
diff --git a/voice/baidu/baidu_voice.py b/voice/baidu/baidu_voice.py
index 531d8ce..73375bd 100644
--- a/voice/baidu/baidu_voice.py
+++ b/voice/baidu/baidu_voice.py
@@ -8,19 +8,53 @@ from bridge.reply import Reply, ReplyType
 from common.log import logger
 from common.tmp_dir import TmpDir
 from voice.voice import Voice
+from voice.audio_convert import get_pcm_from_wav
 from config import conf
 
+"""
+    百度的语音识别API.
+    dev_pid:
+        - 1936: 普通话远场
+        - 1536:普通话(支持简单的英文识别)
+        - 1537:普通话(纯中文识别)
+        - 1737:英语
+        - 1637:粤语
+        - 1837:四川话
+    要使用本模块, 首先到 yuyin.baidu.com 注册一个开发者账号,
+    之后创建一个新应用, 然后在应用管理的"查看key"中获得 API Key 和 Secret Key
+    填入 config.json 中.
+        baidu_app_id: ''
+        baidu_api_key: ''
+        baidu_secret_key: ''
+        baidu_dev_pid: '1536'
+"""
+
 
 class BaiduVoice(Voice):
     APP_ID = conf().get('baidu_app_id')
     API_KEY = conf().get('baidu_api_key')
     SECRET_KEY = conf().get('baidu_secret_key')
+    DEV_ID = conf().get('baidu_dev_pid')
     client = AipSpeech(APP_ID, API_KEY, SECRET_KEY)
-    
+
     def __init__(self):
         pass
 
     def voiceToText(self, voice_file):
-        pass
+        # 识别本地文件
+        logger.debug('[Baidu] voice file name={}'.format(voice_file))
+        pcm = get_pcm_from_wav(voice_file)
+        res = self.client.asr(pcm, "pcm", 16000, {"dev_pid": self.DEV_ID})
+        if res["err_no"] == 0:
+            logger.info("百度语音识别到了:{}".format(res["result"]))
+            text = "".join(res["result"])
+            reply = Reply(ReplyType.TEXT, text)
+        else:
+            logger.info("百度语音识别出错了: {}".format(res["err_msg"]))
+            if res["err_msg"] == "request pv too much":
+                logger.info("  出现这个原因很可能是你的百度语音服务调用量超出限制,或未开通付费")
+            reply = Reply(ReplyType.ERROR,
+                          "百度语音识别出错了;{0}".format(res["err_msg"]))
+        return reply
 
     def textToVoice(self, text):
         result = self.client.synthesis(text, 'zh', 1, {
@@ -30,7 +64,8 @@
            fileName = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
            with open(fileName, 'wb') as f:
                f.write(result)
-           logger.info('[Baidu] textToVoice text={} voice file name={}'.format(text, fileName))
+           logger.info(
+               '[Baidu] textToVoice text={} voice file name={}'.format(text, fileName))
            reply = Reply(ReplyType.VOICE, fileName)
        else:
            logger.error('[Baidu] textToVoice error={}'.format(result))
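The asr() call above follows the baidu-aip SDK: raw pcm bytes, the format string, the sample rate, then an options dict with dev_pid. A condensed sketch (credentials are placeholders; the input is assumed to be 16000 Hz wav, matching the call above):

    from aip import AipSpeech
    from voice.audio_convert import get_pcm_from_wav

    client = AipSpeech('app_id', 'api_key', 'secret_key')   # from config.json
    res = client.asr(get_pcm_from_wav('voice.wav'), 'pcm', 16000, {'dev_pid': '1536'})
    text = ''.join(res['result']) if res['err_no'] == 0 else None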
diff --git a/voice/google/google_voice.py b/voice/google/google_voice.py
index 74431db..901cc99 100644
--- a/voice/google/google_voice.py
+++ b/voice/google/google_voice.py
@@ -3,12 +3,10 @@
 google voice service
 """
 
-import pathlib
-import subprocess
 import time
-from bridge.reply import Reply, ReplyType
 import speech_recognition
-import pyttsx3
+from gtts import gTTS
+from bridge.reply import Reply, ReplyType
 from common.log import logger
 from common.tmp_dir import TmpDir
 from voice.voice import Voice
@@ -16,22 +14,12 @@ from voice.voice import Voice
 
 class GoogleVoice(Voice):
     recognizer = speech_recognition.Recognizer()
-    engine = pyttsx3.init()
 
     def __init__(self):
-        # 语速
-        self.engine.setProperty('rate', 125)
-        # 音量
-        self.engine.setProperty('volume', 1.0)
-        # 0为男声,1为女声
-        voices = self.engine.getProperty('voices')
-        self.engine.setProperty('voice', voices[1].id)
+        pass
 
     def voiceToText(self, voice_file):
-        new_file = voice_file.replace('.mp3', '.wav')
-        subprocess.call('ffmpeg -i ' + voice_file +
-                        ' -acodec pcm_s16le -ac 1 -ar 16000 ' + new_file, shell=True)
-        with speech_recognition.AudioFile(new_file) as source:
+        with speech_recognition.AudioFile(voice_file) as source:
             audio = self.recognizer.record(source)
         try:
             text = self.recognizer.recognize_google(audio, language='zh-CN')
@@ -46,12 +34,12 @@ class GoogleVoice(Voice):
         return reply
 
     def textToVoice(self, text):
         try:
-            textFile = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
-            self.engine.save_to_file(text, textFile)
-            self.engine.runAndWait()
+            mp3File = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
+            tts = gTTS(text=text, lang='zh')
+            tts.save(mp3File)
             logger.info(
-                '[Google] textToVoice text={} voice file name={}'.format(text, textFile))
-            reply = Reply(ReplyType.VOICE, textFile)
+                '[Google] textToVoice text={} voice file name={}'.format(text, mp3File))
+            reply = Reply(ReplyType.VOICE, mp3File)
         except Exception as e:
             reply = Reply(ReplyType.ERROR, str(e))
         finally:
diff --git a/voice/openai/openai_voice.py b/voice/openai/openai_voice.py
index 2e85e10..c98d0c9 100644
--- a/voice/openai/openai_voice.py
+++ b/voice/openai/openai_voice.py
@@ -28,6 +28,3 @@ class OpenaiVoice(Voice):
             reply = Reply(ReplyType.ERROR, str(e))
         finally:
             return reply
-
-    def textToVoice(self, text):
-        pass
diff --git a/voice/pytts/pytts_voice.py b/voice/pytts/pytts_voice.py
new file mode 100644
index 0000000..8884f39
--- /dev/null
+++ b/voice/pytts/pytts_voice.py
@@ -0,0 +1,37 @@
+
+"""
+pytts voice service (offline)
+"""
+
+import time
+import pyttsx3
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from common.tmp_dir import TmpDir
+from voice.voice import Voice
+
+
+class PyttsVoice(Voice):
+    engine = pyttsx3.init()
+
+    def __init__(self):
+        # 语速
+        self.engine.setProperty('rate', 125)
+        # 音量
+        self.engine.setProperty('volume', 1.0)
+        for voice in self.engine.getProperty('voices'):
+            if "Chinese" in voice.name:
+                self.engine.setProperty('voice', voice.id)
+
+    def textToVoice(self, text):
+        try:
+            mp3File = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
+            self.engine.save_to_file(text, mp3File)
+            self.engine.runAndWait()
+            logger.info(
+                '[Pytts] textToVoice text={} voice file name={}'.format(text, mp3File))
+            reply = Reply(ReplyType.VOICE, mp3File)
+        except Exception as e:
+            reply = Reply(ReplyType.ERROR, str(e))
+        finally:
+            return reply
diff --git a/voice/voice_factory.py b/voice/voice_factory.py
index 053840e..591e346 100644
--- a/voice/voice_factory.py
+++ b/voice/voice_factory.py
@@ -17,4 +17,7 @@ def create_voice(voice_type):
     elif voice_type == 'openai':
         from voice.openai.openai_voice import OpenaiVoice
         return OpenaiVoice()
+    elif voice_type == 'pytts':
+        from voice.pytts.pytts_voice import PyttsVoice
+        return PyttsVoice()
     raise RuntimeError
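All engines are constructed through the factory, so the new offline pytts engine is opt-in via the text_to_voice setting. A usage sketch (per this diff PyttsVoice only implements textToVoice, and Reply carries the mp3 path on success):

    from config import conf
    from voice.voice_factory import create_voice

    tts = create_voice(conf().get('text_to_voice', 'baidu'))  # 'baidu', 'google' or 'pytts'
    reply = tts.textToVoice('你好')   # Reply(ReplyType.VOICE, <mp3 file path>) on success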