@@ -1,6 +1,6 @@ | |||
# encoding:utf-8 | |||
import config | |||
from config import conf, load_config | |||
from channel import channel_factory | |||
from common.log import logger | |||
@@ -9,10 +9,10 @@ from plugins import * | |||
def run(): | |||
try: | |||
# load config | |||
config.load_config() | |||
load_config() | |||
# create channel | |||
channel_name='wx' | |||
channel_name=conf().get('channel_type', 'wx') | |||
channel = channel_factory.create_channel(channel_name) | |||
if channel_name=='wx': | |||
PluginManager().load_plugins() | |||
@@ -6,9 +6,9 @@ from common import const | |||
def create_bot(bot_type): | |||
""" | |||
create a channel instance | |||
:param channel_type: channel type code | |||
:return: channel instance | |||
create a bot_type instance | |||
:param bot_type: bot type code | |||
:return: bot instance | |||
""" | |||
if bot_type == const.BAIDU: | |||
# Baidu Unit对话接口 | |||
@@ -5,6 +5,9 @@ wechat channel | |||
""" | |||
import os | |||
import requests | |||
import io | |||
import time | |||
from lib import itchat | |||
import json | |||
from lib.itchat.content import * | |||
@@ -17,17 +20,18 @@ from common.tmp_dir import TmpDir | |||
from config import conf | |||
from common.time_check import time_checker | |||
from plugins import * | |||
import requests | |||
import io | |||
import time | |||
from voice.audio_convert import mp3_to_wav | |||
thread_pool = ThreadPoolExecutor(max_workers=8) | |||
def thread_pool_callback(worker): | |||
worker_exception = worker.exception() | |||
if worker_exception: | |||
logger.exception("Worker return exception: {}".format(worker_exception)) | |||
@itchat.msg_register(TEXT) | |||
def handler_single_msg(msg): | |||
WechatChannel().handle_text(msg) | |||
@@ -48,6 +52,8 @@ def handler_group_voice(msg): | |||
WechatChannel().handle_group_voice(msg) | |||
return None | |||
class WechatChannel(Channel): | |||
def __init__(self): | |||
self.userName = None | |||
@@ -55,7 +61,7 @@ class WechatChannel(Channel): | |||
def startup(self): | |||
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间 | |||
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间 | |||
# login by scan QRCode | |||
hotReload = conf().get('hot_reload', False) | |||
try: | |||
@@ -119,7 +125,7 @@ class WechatChannel(Channel): | |||
other_user_id = from_user_id | |||
create_time = msg['CreateTime'] # 消息时间 | |||
match_prefix = check_prefix(content, conf().get('single_chat_prefix')) | |||
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: #跳过1分钟前的历史消息 | |||
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息 | |||
logger.debug("[WX]history message skipped") | |||
return | |||
if "」\n- - - - - - - - - - - - - - -" in content: | |||
@@ -130,7 +136,8 @@ class WechatChannel(Channel): | |||
elif match_prefix is None: | |||
return | |||
context = Context() | |||
context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id} | |||
context.kwargs = {'isgroup': False, 'msg': msg, | |||
'receiver': other_user_id, 'session_id': other_user_id} | |||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) | |||
if img_match_prefix: | |||
@@ -148,7 +155,7 @@ class WechatChannel(Channel): | |||
group_name = msg['User'].get('NickName', None) | |||
group_id = msg['User'].get('UserName', None) | |||
create_time = msg['CreateTime'] # 消息时间 | |||
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: #跳过1分钟前的历史消息 | |||
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息 | |||
logger.debug("[WX]history group message skipped") | |||
return | |||
if not group_name: | |||
@@ -166,11 +173,11 @@ class WechatChannel(Channel): | |||
return "" | |||
config = conf() | |||
match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or check_prefix(origin_content, config.get('group_chat_prefix')) \ | |||
or check_contain(origin_content, config.get('group_chat_keyword')) | |||
or check_contain(origin_content, config.get('group_chat_keyword')) | |||
if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix: | |||
context = Context() | |||
context.kwargs = { 'isgroup': True, 'msg': msg, 'receiver': group_id} | |||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) | |||
if img_match_prefix: | |||
content = content.replace(img_match_prefix, '', 1).strip() | |||
@@ -217,7 +224,7 @@ class WechatChannel(Channel): | |||
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) | |||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 | |||
def send(self, reply : Reply, receiver): | |||
def send(self, reply: Reply, receiver): | |||
if reply.type == ReplyType.TEXT: | |||
itchat.send(reply.content, toUserName=receiver) | |||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) | |||
@@ -250,9 +257,10 @@ class WechatChannel(Channel): | |||
reply = Reply() | |||
logger.debug('[WX] ready to handle context: {}'.format(context)) | |||
# reply的构建步骤 | |||
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply})) | |||
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, { | |||
'channel': self, 'context': context, 'reply': reply})) | |||
reply = e_context['reply'] | |||
if not e_context.is_pass(): | |||
logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content)) | |||
@@ -260,22 +268,31 @@ class WechatChannel(Channel): | |||
reply = super().build_reply_content(context.content, context) | |||
elif context.type == ContextType.VOICE: # 语音消息 | |||
msg = context['msg'] | |||
file_name = TmpDir().path() + context.content | |||
msg.download(file_name) | |||
reply = super().build_voice_to_text(file_name) | |||
if reply.type == ReplyType.TEXT: | |||
content = reply.content # 语音转文字后,将文字内容作为新的context | |||
# 如果是群消息,判断是否触发关键字 | |||
if context['isgroup']: | |||
mp3_path = TmpDir().path() + context.content | |||
msg.download(mp3_path) | |||
# mp3转wav | |||
wav_path = os.path.splitext(mp3_path)[0] + '.wav' | |||
mp3_to_wav(mp3_path=mp3_path, wav_path=wav_path) | |||
# 语音识别 | |||
reply = super().build_voice_to_text(wav_path) | |||
# 删除临时文件 | |||
os.remove(wav_path) | |||
os.remove(mp3_path) | |||
if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO: | |||
content = reply.content # 语音转文字后,将文字内容作为新的context | |||
context.type = ContextType.TEXT | |||
if context["isgroup"]: | |||
# 校验关键字 | |||
match_prefix = check_prefix(content, conf().get('group_chat_prefix')) | |||
match_contain = check_contain(content, conf().get('group_chat_keyword')) | |||
logger.debug('[WX] group chat prefix match: {}'.format(match_prefix)) | |||
if match_prefix is None and match_contain is None: | |||
return | |||
else: | |||
if match_prefix is not None or match_contain is not None: | |||
# 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容,用于实现类似自定义+前缀触发生成AI图片的功能 | |||
if match_prefix: | |||
content = content.replace(match_prefix, '', 1).strip() | |||
else: | |||
logger.info("[WX]receive voice, checkprefix didn't match") | |||
return | |||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) | |||
if img_match_prefix: | |||
content = content.replace(img_match_prefix, '', 1).strip() | |||
@@ -292,11 +309,12 @@ class WechatChannel(Channel): | |||
return | |||
logger.debug('[WX] ready to decorate reply: {}'.format(reply)) | |||
# reply的包装步骤 | |||
if reply and reply.type: | |||
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply})) | |||
reply=e_context['reply'] | |||
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, { | |||
'channel': self, 'context': context, 'reply': reply})) | |||
reply = e_context['reply'] | |||
if not e_context.is_pass() and reply and reply.type: | |||
if reply.type == ReplyType.TEXT: | |||
reply_text = reply.content | |||
@@ -314,10 +332,11 @@ class WechatChannel(Channel): | |||
logger.error('[WX] unknown reply type: {}'.format(reply.type)) | |||
return | |||
# reply的发送步骤 | |||
# reply的发送步骤 | |||
if reply and reply.type: | |||
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply})) | |||
reply=e_context['reply'] | |||
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, { | |||
'channel': self, 'context': context, 'reply': reply})) | |||
reply = e_context['reply'] | |||
if not e_context.is_pass() and reply and reply.type: | |||
logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver'])) | |||
self.send(reply, context['receiver']) | |||
@@ -4,25 +4,19 @@ | |||
wechaty channel | |||
Python Wechaty - https://github.com/wechaty/python-wechaty | |||
""" | |||
import io | |||
import os | |||
import json | |||
import time | |||
import asyncio | |||
import requests | |||
import pysilk | |||
import wave | |||
from pydub import AudioSegment | |||
from typing import Optional, Union | |||
from bridge.context import Context, ContextType | |||
from wechaty_puppet import MessageType, FileBox, ScanStatus # type: ignore | |||
from wechaty import Wechaty, Contact | |||
from wechaty.user import Message, Room, MiniProgram, UrlLink | |||
from wechaty.user import Message, MiniProgram, UrlLink | |||
from channel.channel import Channel | |||
from common.log import logger | |||
from common.tmp_dir import TmpDir | |||
from config import conf | |||
from voice.audio_convert import sil_to_wav, mp3_to_sil | |||
class WechatyChannel(Channel): | |||
@@ -50,8 +44,9 @@ class WechatyChannel(Channel): | |||
async def on_scan(self, status: ScanStatus, qr_code: Optional[str] = None, | |||
data: Optional[str] = None): | |||
contact = self.Contact.load(self.contact_id) | |||
logger.info('[WX] scan user={}, scan status={}, scan qr_code={}'.format(contact, status.name, qr_code)) | |||
pass | |||
# contact = self.Contact.load(self.contact_id) | |||
# logger.info('[WX] scan user={}, scan status={}, scan qr_code={}'.format(contact, status.name, qr_code)) | |||
# print(f'user <{contact}> scan status: {status.name} , 'f'qr_code: {qr_code}') | |||
async def on_message(self, msg: Message): | |||
@@ -67,7 +62,7 @@ class WechatyChannel(Channel): | |||
content = msg.text() | |||
mention_content = await msg.mention_text() # 返回过滤掉@name后的消息 | |||
match_prefix = self.check_prefix(content, conf().get('single_chat_prefix')) | |||
conversation: Union[Room, Contact] = from_contact if room is None else room | |||
# conversation: Union[Room, Contact] = from_contact if room is None else room | |||
if room is None and msg.type() == MessageType.MESSAGE_TYPE_TEXT: | |||
if not msg.is_self() and match_prefix is not None: | |||
@@ -102,21 +97,8 @@ class WechatyChannel(Channel): | |||
await voice_file.to_file(silk_file) | |||
logger.info("[WX]receive voice file: " + silk_file) | |||
# 将文件转成wav格式音频 | |||
wav_file = silk_file.replace(".slk", ".wav") | |||
with open(silk_file, 'rb') as f: | |||
silk_data = f.read() | |||
pcm_data = pysilk.decode(silk_data) | |||
with wave.open(wav_file, 'wb') as wav_data: | |||
wav_data.setnchannels(1) | |||
wav_data.setsampwidth(2) | |||
wav_data.setframerate(24000) | |||
wav_data.writeframes(pcm_data) | |||
if os.path.exists(wav_file): | |||
converter_state = "true" # 转换wav成功 | |||
else: | |||
converter_state = "false" # 转换wav失败 | |||
logger.info("[WX]receive voice converter: " + converter_state) | |||
wav_file = os.path.splitext(silk_file)[0] + '.wav' | |||
sil_to_wav(silk_file, wav_file) | |||
# 语音识别为文本 | |||
query = super().build_voice_to_text(wav_file).content | |||
# 交验关键字 | |||
@@ -183,21 +165,8 @@ class WechatyChannel(Channel): | |||
await voice_file.to_file(silk_file) | |||
logger.info("[WX]receive voice file: " + silk_file) | |||
# 将文件转成wav格式音频 | |||
wav_file = silk_file.replace(".slk", ".wav") | |||
with open(silk_file, 'rb') as f: | |||
silk_data = f.read() | |||
pcm_data = pysilk.decode(silk_data) | |||
with wave.open(wav_file, 'wb') as wav_data: | |||
wav_data.setnchannels(1) | |||
wav_data.setsampwidth(2) | |||
wav_data.setframerate(24000) | |||
wav_data.writeframes(pcm_data) | |||
if os.path.exists(wav_file): | |||
converter_state = "true" # 转换wav成功 | |||
else: | |||
converter_state = "false" # 转换wav失败 | |||
logger.info("[WX]receive voice converter: " + converter_state) | |||
wav_file = os.path.splitext(silk_file)[0] + '.wav' | |||
sil_to_wav(silk_file, wav_file) | |||
# 语音识别为文本 | |||
query = super().build_voice_to_text(wav_file).content | |||
# 校验关键字 | |||
@@ -260,21 +229,12 @@ class WechatyChannel(Channel): | |||
if reply_text: | |||
# 转换 mp3 文件为 silk 格式 | |||
mp3_file = super().build_text_to_voice(reply_text).content | |||
silk_file = mp3_file.replace(".mp3", ".silk") | |||
# Load the MP3 file | |||
audio = AudioSegment.from_file(mp3_file, format="mp3") | |||
# Convert to WAV format | |||
audio = audio.set_frame_rate(24000).set_channels(1) | |||
wav_data = audio.raw_data | |||
sample_width = audio.sample_width | |||
# Encode to SILK format | |||
silk_data = pysilk.encode(wav_data, 24000) | |||
# Save the silk file | |||
with open(silk_file, "wb") as f: | |||
f.write(silk_data) | |||
silk_file = os.path.splitext(mp3_file)[0] + '.sil' | |||
voiceLength = mp3_to_sil(mp3_file, silk_file) | |||
# 发送语音 | |||
t = int(time.time()) | |||
file_box = FileBox.from_file(silk_file, name=str(t) + '.silk') | |||
file_box = FileBox.from_file(silk_file, name=str(t) + '.sil') | |||
file_box.metadata = {'voiceLength': voiceLength} | |||
await self.send(file_box, reply_user_id) | |||
# 清除缓存文件 | |||
os.remove(mp3_file) | |||
@@ -337,21 +297,12 @@ class WechatyChannel(Channel): | |||
reply_text = '@' + group_user_name + ' ' + reply_text.strip() | |||
# 转换 mp3 文件为 silk 格式 | |||
mp3_file = super().build_text_to_voice(reply_text).content | |||
silk_file = mp3_file.replace(".mp3", ".silk") | |||
# Load the MP3 file | |||
audio = AudioSegment.from_file(mp3_file, format="mp3") | |||
# Convert to WAV format | |||
audio = audio.set_frame_rate(24000).set_channels(1) | |||
wav_data = audio.raw_data | |||
sample_width = audio.sample_width | |||
# Encode to SILK format | |||
silk_data = pysilk.encode(wav_data, 24000) | |||
# Save the silk file | |||
with open(silk_file, "wb") as f: | |||
f.write(silk_data) | |||
silk_file = os.path.splitext(mp3_file)[0] + '.sil' | |||
voiceLength = mp3_to_sil(mp3_file, silk_file) | |||
# 发送语音 | |||
t = int(time.time()) | |||
file_box = FileBox.from_file(silk_file, name=str(t) + '.silk') | |||
file_box.metadata = {'voiceLength': voiceLength} | |||
await self.send_group(file_box, group_id) | |||
# 清除缓存文件 | |||
os.remove(mp3_file) | |||
@@ -5,71 +5,77 @@ import os | |||
from common.log import logger | |||
# 将所有可用的配置项写在字典里, 请使用小写字母 | |||
available_setting ={ | |||
#openai api配置 | |||
"open_ai_api_key": "", # openai api key | |||
"open_ai_api_base": "https://api.openai.com/v1", # openai apibase,当use_azure_chatgpt为true时,需要设置对应的api base | |||
"proxy": "", # openai使用的代理 | |||
"model": "gpt-3.5-turbo", # chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称 | |||
"use_azure_chatgpt": False, # 是否使用azure的chatgpt | |||
#Bot触发配置 | |||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复 | |||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人 | |||
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复 | |||
"group_chat_reply_prefix": "", # 群聊时自动回复的前缀 | |||
"group_chat_keyword": [], # 群聊时包含该关键词则会触发机器人回复 | |||
"group_at_off": False, # 是否关闭群聊时@bot的触发 | |||
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表 | |||
"group_name_keyword_white_list": [], # 开启自动回复的群名称关键词列表 | |||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称 | |||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀 | |||
#chatgpt会话参数 | |||
"expires_in_seconds": 3600, # 无操作会话的过期时间 | |||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述 | |||
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数 | |||
#chatgpt限流配置 | |||
"rate_limit_chatgpt": 20, # chatgpt的调用频率限制 | |||
"rate_limit_dalle": 50, # openai dalle的调用频率限制 | |||
#chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create | |||
available_setting = { | |||
# openai api配置 | |||
"open_ai_api_key": "", # openai api key | |||
# openai apibase,当use_azure_chatgpt为true时,需要设置对应的api base | |||
"open_ai_api_base": "https://api.openai.com/v1", | |||
"proxy": "", # openai使用的代理 | |||
# chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称 | |||
"model": "gpt-3.5-turbo", | |||
"use_azure_chatgpt": False, # 是否使用azure的chatgpt | |||
# Bot触发配置 | |||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复 | |||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人 | |||
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复 | |||
"group_chat_reply_prefix": "", # 群聊时自动回复的前缀 | |||
"group_chat_keyword": [], # 群聊时包含该关键词则会触发机器人回复 | |||
"group_at_off": False, # 是否关闭群聊时@bot的触发 | |||
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表 | |||
"group_name_keyword_white_list": [], # 开启自动回复的群名称关键词列表 | |||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称 | |||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀 | |||
# chatgpt会话参数 | |||
"expires_in_seconds": 3600, # 无操作会话的过期时间 | |||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述 | |||
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数 | |||
# chatgpt限流配置 | |||
"rate_limit_chatgpt": 20, # chatgpt的调用频率限制 | |||
"rate_limit_dalle": 50, # openai dalle的调用频率限制 | |||
# chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create | |||
"temperature": 0.9, | |||
"top_p": 1, | |||
"frequency_penalty": 0, | |||
"presence_penalty": 0, | |||
#语音设置 | |||
"speech_recognition": False, # 是否开启语音识别 | |||
"group_speech_recognition": False, # 是否开启群组语音识别 | |||
"voice_reply_voice": False, # 是否使用语音回复语音,需要设置对应语音合成引擎的api key | |||
"voice_to_text": "openai", # 语音识别引擎,支持openai和google | |||
"text_to_voice": "baidu", # 语音合成引擎,支持baidu和google | |||
# 语音设置 | |||
"speech_recognition": False, # 是否开启语音识别 | |||
"group_speech_recognition": False, # 是否开启群组语音识别 | |||
"voice_reply_voice": False, # 是否使用语音回复语音,需要设置对应语音合成引擎的api key | |||
"voice_to_text": "openai", # 语音识别引擎,支持openai,google | |||
"text_to_voice": "baidu", # 语音合成引擎,支持baidu,google,pytts(offline) | |||
# baidu api的配置, 使用百度语音识别和语音合成时需要 | |||
'baidu_app_id': "", | |||
'baidu_api_key': "", | |||
'baidu_secret_key': "", | |||
"baidu_app_id": "", | |||
"baidu_api_key": "", | |||
"baidu_secret_key": "", | |||
# 1536普通话(支持简单的英文识别) 1737英语 1637粤语 1837四川话 1936普通话远场 | |||
"baidu_dev_pid": "1536", | |||
#服务时间限制,目前支持itchat | |||
"chat_time_module": False, # 是否开启服务时间限制 | |||
"chat_start_time": "00:00", # 服务开始时间 | |||
"chat_stop_time": "24:00", # 服务结束时间 | |||
# 服务时间限制,目前支持itchat | |||
"chat_time_module": False, # 是否开启服务时间限制 | |||
"chat_start_time": "00:00", # 服务开始时间 | |||
"chat_stop_time": "24:00", # 服务结束时间 | |||
# itchat的配置 | |||
"hot_reload": False, # 是否开启热重载 | |||
"hot_reload": False, # 是否开启热重载 | |||
# wechaty的配置 | |||
"wechaty_puppet_service_token": "", # wechaty的token | |||
"wechaty_puppet_service_token": "", # wechaty的token | |||
# chatgpt指令自定义触发词 | |||
"clear_memory_commands": ['#清除记忆'], # 重置会话指令 | |||
"clear_memory_commands": ['#清除记忆'], # 重置会话指令 | |||
"channel_type": "wx", # 通道类型,支持wx,wxy和terminal | |||
} | |||
class Config(dict): | |||
def __getitem__(self, key): | |||
if key not in available_setting: | |||
@@ -82,15 +88,17 @@ class Config(dict): | |||
return super().__setitem__(key, value) | |||
def get(self, key, default=None): | |||
try : | |||
try: | |||
return self[key] | |||
except KeyError as e: | |||
return default | |||
except Exception as e: | |||
raise e | |||
config = Config() | |||
def load_config(): | |||
global config | |||
config_path = "./config.json" | |||
@@ -109,7 +117,8 @@ def load_config(): | |||
for name, value in os.environ.items(): | |||
name = name.lower() | |||
if name in available_setting: | |||
logger.info("[INIT] override config by environ args: {}={}".format(name, value)) | |||
logger.info( | |||
"[INIT] override config by environ args: {}={}".format(name, value)) | |||
try: | |||
config[name] = eval(value) | |||
except: | |||
@@ -118,9 +127,8 @@ def load_config(): | |||
logger.info("[INIT] load config: {}".format(config)) | |||
def get_root(): | |||
return os.path.dirname(os.path.abspath( __file__ )) | |||
return os.path.dirname(os.path.abspath(__file__)) | |||
def read_file(path): | |||
@@ -1,3 +1,15 @@ | |||
itchat-uos==1.5.0.dev0 | |||
openai | |||
wechaty | |||
openai>=0.27.2 | |||
baidu_aip>=4.16.10 | |||
gTTS>=2.3.1 | |||
HTMLParser>=0.0.2 | |||
pydub>=0.25.1 | |||
PyQRCode>=1.2.1 | |||
pysilk>=0.0.1 | |||
pysilk_mod>=1.6.0 | |||
pyttsx3>=2.90 | |||
requests>=2.28.2 | |||
SpeechRecognition>=3.10.0 | |||
tiktoken>=0.3.2 | |||
webuiapi>=0.6.2 | |||
wechaty>=0.10.7 | |||
wechaty_puppet>=0.4.23 |
@@ -0,0 +1,60 @@ | |||
import wave | |||
import pysilk | |||
from pydub import AudioSegment | |||
def get_pcm_from_wav(wav_path): | |||
""" | |||
从 wav 文件中读取 pcm | |||
:param wav_path: wav 文件路径 | |||
:returns: pcm 数据 | |||
""" | |||
wav = wave.open(wav_path, "rb") | |||
return wav.readframes(wav.getnframes()) | |||
def mp3_to_wav(mp3_path, wav_path): | |||
""" | |||
把mp3格式转成pcm文件 | |||
""" | |||
audio = AudioSegment.from_mp3(mp3_path) | |||
audio.export(wav_path, format="wav") | |||
def pcm_to_silk(pcm_path, silk_path): | |||
""" | |||
wav 文件转成 silk | |||
return 声音长度,毫秒 | |||
""" | |||
audio = AudioSegment.from_wav(pcm_path) | |||
wav_data = audio.raw_data | |||
silk_data = pysilk.encode( | |||
wav_data, data_rate=audio.frame_rate, sample_rate=audio.frame_rate) | |||
with open(silk_path, "wb") as f: | |||
f.write(silk_data) | |||
return audio.duration_seconds * 1000 | |||
def mp3_to_sil(mp3_path, silk_path): | |||
""" | |||
mp3 文件转成 silk | |||
return 声音长度,毫秒 | |||
""" | |||
audio = AudioSegment.from_mp3(mp3_path) | |||
wav_data = audio.raw_data | |||
silk_data = pysilk.encode( | |||
wav_data, data_rate=audio.frame_rate, sample_rate=audio.frame_rate) | |||
# Save the silk file | |||
with open(silk_path, "wb") as f: | |||
f.write(silk_data) | |||
return audio.duration_seconds * 1000 | |||
def sil_to_wav(silk_path, wav_path, rate: int = 24000): | |||
""" | |||
silk 文件转 wav | |||
""" | |||
wav_data = pysilk.decode_file(silk_path, to_wav=True, sample_rate=rate) | |||
with open(wav_path, "wb") as f: | |||
f.write(wav_data) |
@@ -8,19 +8,53 @@ from bridge.reply import Reply, ReplyType | |||
from common.log import logger | |||
from common.tmp_dir import TmpDir | |||
from voice.voice import Voice | |||
from voice.audio_convert import get_pcm_from_wav | |||
from config import conf | |||
""" | |||
百度的语音识别API. | |||
dev_pid: | |||
- 1936: 普通话远场 | |||
- 1536:普通话(支持简单的英文识别) | |||
- 1537:普通话(纯中文识别) | |||
- 1737:英语 | |||
- 1637:粤语 | |||
- 1837:四川话 | |||
要使用本模块, 首先到 yuyin.baidu.com 注册一个开发者账号, | |||
之后创建一个新应用, 然后在应用管理的"查看key"中获得 API Key 和 Secret Key | |||
填入 config.json 中. | |||
baidu_app_id: '' | |||
baidu_api_key: '' | |||
baidu_secret_key: '' | |||
baidu_dev_pid: '1536' | |||
""" | |||
class BaiduVoice(Voice): | |||
APP_ID = conf().get('baidu_app_id') | |||
API_KEY = conf().get('baidu_api_key') | |||
SECRET_KEY = conf().get('baidu_secret_key') | |||
DEV_ID = conf().get('baidu_dev_pid') | |||
client = AipSpeech(APP_ID, API_KEY, SECRET_KEY) | |||
def __init__(self): | |||
pass | |||
def voiceToText(self, voice_file): | |||
pass | |||
# 识别本地文件 | |||
logger.debug('[Baidu] voice file name={}'.format(voice_file)) | |||
pcm = get_pcm_from_wav(voice_file) | |||
res = self.client.asr(pcm, "pcm", 16000, {"dev_pid": self.DEV_ID}) | |||
if res["err_no"] == 0: | |||
logger.info("百度语音识别到了:{}".format(res["result"])) | |||
text = "".join(res["result"]) | |||
reply = Reply(ReplyType.TEXT, text) | |||
else: | |||
logger.info("百度语音识别出错了: {}".format(res["err_msg"])) | |||
if res["err_msg"] == "request pv too much": | |||
logger.info(" 出现这个原因很可能是你的百度语音服务调用量超出限制,或未开通付费") | |||
reply = Reply(ReplyType.ERROR, | |||
"百度语音识别出错了;{0}".format(res["err_msg"])) | |||
return reply | |||
def textToVoice(self, text): | |||
result = self.client.synthesis(text, 'zh', 1, { | |||
@@ -30,7 +64,8 @@ class BaiduVoice(Voice): | |||
fileName = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3' | |||
with open(fileName, 'wb') as f: | |||
f.write(result) | |||
logger.info('[Baidu] textToVoice text={} voice file name={}'.format(text, fileName)) | |||
logger.info( | |||
'[Baidu] textToVoice text={} voice file name={}'.format(text, fileName)) | |||
reply = Reply(ReplyType.VOICE, fileName) | |||
else: | |||
logger.error('[Baidu] textToVoice error={}'.format(result)) | |||
@@ -3,12 +3,10 @@ | |||
google voice service | |||
""" | |||
import pathlib | |||
import subprocess | |||
import time | |||
from bridge.reply import Reply, ReplyType | |||
import speech_recognition | |||
import pyttsx3 | |||
from gtts import gTTS | |||
from bridge.reply import Reply, ReplyType | |||
from common.log import logger | |||
from common.tmp_dir import TmpDir | |||
from voice.voice import Voice | |||
@@ -16,22 +14,12 @@ from voice.voice import Voice | |||
class GoogleVoice(Voice): | |||
recognizer = speech_recognition.Recognizer() | |||
engine = pyttsx3.init() | |||
def __init__(self): | |||
# 语速 | |||
self.engine.setProperty('rate', 125) | |||
# 音量 | |||
self.engine.setProperty('volume', 1.0) | |||
# 0为男声,1为女声 | |||
voices = self.engine.getProperty('voices') | |||
self.engine.setProperty('voice', voices[1].id) | |||
pass | |||
def voiceToText(self, voice_file): | |||
new_file = voice_file.replace('.mp3', '.wav') | |||
subprocess.call('ffmpeg -i ' + voice_file + | |||
' -acodec pcm_s16le -ac 1 -ar 16000 ' + new_file, shell=True) | |||
with speech_recognition.AudioFile(new_file) as source: | |||
with speech_recognition.AudioFile(voice_file) as source: | |||
audio = self.recognizer.record(source) | |||
try: | |||
text = self.recognizer.recognize_google(audio, language='zh-CN') | |||
@@ -46,12 +34,12 @@ class GoogleVoice(Voice): | |||
return reply | |||
def textToVoice(self, text): | |||
try: | |||
textFile = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3' | |||
self.engine.save_to_file(text, textFile) | |||
self.engine.runAndWait() | |||
mp3File = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3' | |||
tts = gTTS(text=text, lang='zh') | |||
tts.save(mp3File) | |||
logger.info( | |||
'[Google] textToVoice text={} voice file name={}'.format(text, textFile)) | |||
reply = Reply(ReplyType.VOICE, textFile) | |||
'[Google] textToVoice text={} voice file name={}'.format(text, mp3File)) | |||
reply = Reply(ReplyType.VOICE, mp3File) | |||
except Exception as e: | |||
reply = Reply(ReplyType.ERROR, str(e)) | |||
finally: | |||
@@ -28,6 +28,3 @@ class OpenaiVoice(Voice): | |||
reply = Reply(ReplyType.ERROR, str(e)) | |||
finally: | |||
return reply | |||
def textToVoice(self, text): | |||
pass |
@@ -0,0 +1,37 @@ | |||
""" | |||
pytts voice service (offline) | |||
""" | |||
import time | |||
import pyttsx3 | |||
from bridge.reply import Reply, ReplyType | |||
from common.log import logger | |||
from common.tmp_dir import TmpDir | |||
from voice.voice import Voice | |||
class PyttsVoice(Voice): | |||
engine = pyttsx3.init() | |||
def __init__(self): | |||
# 语速 | |||
self.engine.setProperty('rate', 125) | |||
# 音量 | |||
self.engine.setProperty('volume', 1.0) | |||
for voice in self.engine.getProperty('voices'): | |||
if "Chinese" in voice.name: | |||
self.engine.setProperty('voice', voice.id) | |||
def textToVoice(self, text): | |||
try: | |||
mp3File = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3' | |||
self.engine.save_to_file(text, mp3File) | |||
self.engine.runAndWait() | |||
logger.info( | |||
'[Pytts] textToVoice text={} voice file name={}'.format(text, mp3File)) | |||
reply = Reply(ReplyType.VOICE, mp3File) | |||
except Exception as e: | |||
reply = Reply(ReplyType.ERROR, str(e)) | |||
finally: | |||
return reply |
@@ -17,4 +17,7 @@ def create_voice(voice_type): | |||
elif voice_type == 'openai': | |||
from voice.openai.openai_voice import OpenaiVoice | |||
return OpenaiVoice() | |||
elif voice_type == 'pytts': | |||
from voice.pytts.pytts_voice import PyttsVoice | |||
return PyttsVoice() | |||
raise RuntimeError |