@@ -1,6 +1,6 @@ | |||||
# encoding:utf-8 | # encoding:utf-8 | ||||
import config | |||||
from config import conf, load_config | |||||
from channel import channel_factory | from channel import channel_factory | ||||
from common.log import logger | from common.log import logger | ||||
@@ -9,10 +9,10 @@ from plugins import * | |||||
def run(): | def run(): | ||||
try: | try: | ||||
# load config | # load config | ||||
config.load_config() | |||||
load_config() | |||||
# create channel | # create channel | ||||
channel_name='wx' | |||||
channel_name=conf().get('channel_type', 'wx') | |||||
channel = channel_factory.create_channel(channel_name) | channel = channel_factory.create_channel(channel_name) | ||||
if channel_name=='wx': | if channel_name=='wx': | ||||
PluginManager().load_plugins() | PluginManager().load_plugins() | ||||
@@ -6,9 +6,9 @@ from common import const | |||||
def create_bot(bot_type): | def create_bot(bot_type): | ||||
""" | """ | ||||
create a channel instance | |||||
:param channel_type: channel type code | |||||
:return: channel instance | |||||
create a bot_type instance | |||||
:param bot_type: bot type code | |||||
:return: bot instance | |||||
""" | """ | ||||
if bot_type == const.BAIDU: | if bot_type == const.BAIDU: | ||||
# Baidu Unit对话接口 | # Baidu Unit对话接口 | ||||
@@ -5,6 +5,9 @@ wechat channel | |||||
""" | """ | ||||
import os | import os | ||||
import requests | |||||
import io | |||||
import time | |||||
from lib import itchat | from lib import itchat | ||||
import json | import json | ||||
from lib.itchat.content import * | from lib.itchat.content import * | ||||
@@ -17,17 +20,18 @@ from common.tmp_dir import TmpDir | |||||
from config import conf | from config import conf | ||||
from common.time_check import time_checker | from common.time_check import time_checker | ||||
from plugins import * | from plugins import * | ||||
import requests | |||||
import io | |||||
import time | |||||
from voice.audio_convert import mp3_to_wav | |||||
thread_pool = ThreadPoolExecutor(max_workers=8) | thread_pool = ThreadPoolExecutor(max_workers=8) | ||||
def thread_pool_callback(worker): | def thread_pool_callback(worker): | ||||
worker_exception = worker.exception() | worker_exception = worker.exception() | ||||
if worker_exception: | if worker_exception: | ||||
logger.exception("Worker return exception: {}".format(worker_exception)) | logger.exception("Worker return exception: {}".format(worker_exception)) | ||||
@itchat.msg_register(TEXT) | @itchat.msg_register(TEXT) | ||||
def handler_single_msg(msg): | def handler_single_msg(msg): | ||||
WechatChannel().handle_text(msg) | WechatChannel().handle_text(msg) | ||||
@@ -48,6 +52,8 @@ def handler_group_voice(msg): | |||||
WechatChannel().handle_group_voice(msg) | WechatChannel().handle_group_voice(msg) | ||||
return None | return None | ||||
class WechatChannel(Channel): | class WechatChannel(Channel): | ||||
def __init__(self): | def __init__(self): | ||||
self.userName = None | self.userName = None | ||||
@@ -55,7 +61,7 @@ class WechatChannel(Channel): | |||||
def startup(self): | def startup(self): | ||||
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间 | |||||
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间 | |||||
# login by scan QRCode | # login by scan QRCode | ||||
hotReload = conf().get('hot_reload', False) | hotReload = conf().get('hot_reload', False) | ||||
try: | try: | ||||
@@ -119,7 +125,7 @@ class WechatChannel(Channel): | |||||
other_user_id = from_user_id | other_user_id = from_user_id | ||||
create_time = msg['CreateTime'] # 消息时间 | create_time = msg['CreateTime'] # 消息时间 | ||||
match_prefix = check_prefix(content, conf().get('single_chat_prefix')) | match_prefix = check_prefix(content, conf().get('single_chat_prefix')) | ||||
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: #跳过1分钟前的历史消息 | |||||
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息 | |||||
logger.debug("[WX]history message skipped") | logger.debug("[WX]history message skipped") | ||||
return | return | ||||
if "」\n- - - - - - - - - - - - - - -" in content: | if "」\n- - - - - - - - - - - - - - -" in content: | ||||
@@ -130,7 +136,8 @@ class WechatChannel(Channel): | |||||
elif match_prefix is None: | elif match_prefix is None: | ||||
return | return | ||||
context = Context() | context = Context() | ||||
context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id} | |||||
context.kwargs = {'isgroup': False, 'msg': msg, | |||||
'receiver': other_user_id, 'session_id': other_user_id} | |||||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) | img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) | ||||
if img_match_prefix: | if img_match_prefix: | ||||
@@ -148,7 +155,7 @@ class WechatChannel(Channel): | |||||
group_name = msg['User'].get('NickName', None) | group_name = msg['User'].get('NickName', None) | ||||
group_id = msg['User'].get('UserName', None) | group_id = msg['User'].get('UserName', None) | ||||
create_time = msg['CreateTime'] # 消息时间 | create_time = msg['CreateTime'] # 消息时间 | ||||
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: #跳过1分钟前的历史消息 | |||||
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息 | |||||
logger.debug("[WX]history group message skipped") | logger.debug("[WX]history group message skipped") | ||||
return | return | ||||
if not group_name: | if not group_name: | ||||
@@ -166,11 +173,11 @@ class WechatChannel(Channel): | |||||
return "" | return "" | ||||
config = conf() | config = conf() | ||||
match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or check_prefix(origin_content, config.get('group_chat_prefix')) \ | match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or check_prefix(origin_content, config.get('group_chat_prefix')) \ | ||||
or check_contain(origin_content, config.get('group_chat_keyword')) | |||||
or check_contain(origin_content, config.get('group_chat_keyword')) | |||||
if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix: | if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix: | ||||
context = Context() | context = Context() | ||||
context.kwargs = { 'isgroup': True, 'msg': msg, 'receiver': group_id} | context.kwargs = { 'isgroup': True, 'msg': msg, 'receiver': group_id} | ||||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) | img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) | ||||
if img_match_prefix: | if img_match_prefix: | ||||
content = content.replace(img_match_prefix, '', 1).strip() | content = content.replace(img_match_prefix, '', 1).strip() | ||||
@@ -217,7 +224,7 @@ class WechatChannel(Channel): | |||||
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) | thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) | ||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 | # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 | ||||
def send(self, reply : Reply, receiver): | |||||
def send(self, reply: Reply, receiver): | |||||
if reply.type == ReplyType.TEXT: | if reply.type == ReplyType.TEXT: | ||||
itchat.send(reply.content, toUserName=receiver) | itchat.send(reply.content, toUserName=receiver) | ||||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) | logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) | ||||
@@ -250,9 +257,10 @@ class WechatChannel(Channel): | |||||
reply = Reply() | reply = Reply() | ||||
logger.debug('[WX] ready to handle context: {}'.format(context)) | logger.debug('[WX] ready to handle context: {}'.format(context)) | ||||
# reply的构建步骤 | # reply的构建步骤 | ||||
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply})) | |||||
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, { | |||||
'channel': self, 'context': context, 'reply': reply})) | |||||
reply = e_context['reply'] | reply = e_context['reply'] | ||||
if not e_context.is_pass(): | if not e_context.is_pass(): | ||||
logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content)) | logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content)) | ||||
@@ -260,22 +268,31 @@ class WechatChannel(Channel): | |||||
reply = super().build_reply_content(context.content, context) | reply = super().build_reply_content(context.content, context) | ||||
elif context.type == ContextType.VOICE: # 语音消息 | elif context.type == ContextType.VOICE: # 语音消息 | ||||
msg = context['msg'] | msg = context['msg'] | ||||
file_name = TmpDir().path() + context.content | |||||
msg.download(file_name) | |||||
reply = super().build_voice_to_text(file_name) | |||||
if reply.type == ReplyType.TEXT: | |||||
content = reply.content # 语音转文字后,将文字内容作为新的context | |||||
# 如果是群消息,判断是否触发关键字 | |||||
if context['isgroup']: | |||||
mp3_path = TmpDir().path() + context.content | |||||
msg.download(mp3_path) | |||||
# mp3转wav | |||||
wav_path = os.path.splitext(mp3_path)[0] + '.wav' | |||||
mp3_to_wav(mp3_path=mp3_path, wav_path=wav_path) | |||||
# 语音识别 | |||||
reply = super().build_voice_to_text(wav_path) | |||||
# 删除临时文件 | |||||
os.remove(wav_path) | |||||
os.remove(mp3_path) | |||||
if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO: | |||||
content = reply.content # 语音转文字后,将文字内容作为新的context | |||||
context.type = ContextType.TEXT | |||||
if context["isgroup"]: | |||||
# 校验关键字 | |||||
match_prefix = check_prefix(content, conf().get('group_chat_prefix')) | match_prefix = check_prefix(content, conf().get('group_chat_prefix')) | ||||
match_contain = check_contain(content, conf().get('group_chat_keyword')) | match_contain = check_contain(content, conf().get('group_chat_keyword')) | ||||
logger.debug('[WX] group chat prefix match: {}'.format(match_prefix)) | |||||
if match_prefix is None and match_contain is None: | |||||
return | |||||
else: | |||||
if match_prefix is not None or match_contain is not None: | |||||
# 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容,用于实现类似自定义+前缀触发生成AI图片的功能 | |||||
if match_prefix: | if match_prefix: | ||||
content = content.replace(match_prefix, '', 1).strip() | content = content.replace(match_prefix, '', 1).strip() | ||||
else: | |||||
logger.info("[WX]receive voice, checkprefix didn't match") | |||||
return | |||||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) | img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) | ||||
if img_match_prefix: | if img_match_prefix: | ||||
content = content.replace(img_match_prefix, '', 1).strip() | content = content.replace(img_match_prefix, '', 1).strip() | ||||
@@ -292,11 +309,12 @@ class WechatChannel(Channel): | |||||
return | return | ||||
logger.debug('[WX] ready to decorate reply: {}'.format(reply)) | logger.debug('[WX] ready to decorate reply: {}'.format(reply)) | ||||
# reply的包装步骤 | # reply的包装步骤 | ||||
if reply and reply.type: | if reply and reply.type: | ||||
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply})) | |||||
reply=e_context['reply'] | |||||
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, { | |||||
'channel': self, 'context': context, 'reply': reply})) | |||||
reply = e_context['reply'] | |||||
if not e_context.is_pass() and reply and reply.type: | if not e_context.is_pass() and reply and reply.type: | ||||
if reply.type == ReplyType.TEXT: | if reply.type == ReplyType.TEXT: | ||||
reply_text = reply.content | reply_text = reply.content | ||||
@@ -314,10 +332,11 @@ class WechatChannel(Channel): | |||||
logger.error('[WX] unknown reply type: {}'.format(reply.type)) | logger.error('[WX] unknown reply type: {}'.format(reply.type)) | ||||
return | return | ||||
# reply的发送步骤 | |||||
# reply的发送步骤 | |||||
if reply and reply.type: | if reply and reply.type: | ||||
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply})) | |||||
reply=e_context['reply'] | |||||
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, { | |||||
'channel': self, 'context': context, 'reply': reply})) | |||||
reply = e_context['reply'] | |||||
if not e_context.is_pass() and reply and reply.type: | if not e_context.is_pass() and reply and reply.type: | ||||
logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver'])) | logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver'])) | ||||
self.send(reply, context['receiver']) | self.send(reply, context['receiver']) | ||||
@@ -4,25 +4,19 @@ | |||||
wechaty channel | wechaty channel | ||||
Python Wechaty - https://github.com/wechaty/python-wechaty | Python Wechaty - https://github.com/wechaty/python-wechaty | ||||
""" | """ | ||||
import io | |||||
import os | import os | ||||
import json | |||||
import time | import time | ||||
import asyncio | import asyncio | ||||
import requests | |||||
import pysilk | |||||
import wave | |||||
from pydub import AudioSegment | |||||
from typing import Optional, Union | from typing import Optional, Union | ||||
from bridge.context import Context, ContextType | from bridge.context import Context, ContextType | ||||
from wechaty_puppet import MessageType, FileBox, ScanStatus # type: ignore | from wechaty_puppet import MessageType, FileBox, ScanStatus # type: ignore | ||||
from wechaty import Wechaty, Contact | from wechaty import Wechaty, Contact | ||||
from wechaty.user import Message, Room, MiniProgram, UrlLink | |||||
from wechaty.user import Message, MiniProgram, UrlLink | |||||
from channel.channel import Channel | from channel.channel import Channel | ||||
from common.log import logger | from common.log import logger | ||||
from common.tmp_dir import TmpDir | from common.tmp_dir import TmpDir | ||||
from config import conf | from config import conf | ||||
from voice.audio_convert import sil_to_wav, mp3_to_sil | |||||
class WechatyChannel(Channel): | class WechatyChannel(Channel): | ||||
@@ -50,8 +44,9 @@ class WechatyChannel(Channel): | |||||
async def on_scan(self, status: ScanStatus, qr_code: Optional[str] = None, | async def on_scan(self, status: ScanStatus, qr_code: Optional[str] = None, | ||||
data: Optional[str] = None): | data: Optional[str] = None): | ||||
contact = self.Contact.load(self.contact_id) | |||||
logger.info('[WX] scan user={}, scan status={}, scan qr_code={}'.format(contact, status.name, qr_code)) | |||||
pass | |||||
# contact = self.Contact.load(self.contact_id) | |||||
# logger.info('[WX] scan user={}, scan status={}, scan qr_code={}'.format(contact, status.name, qr_code)) | |||||
# print(f'user <{contact}> scan status: {status.name} , 'f'qr_code: {qr_code}') | # print(f'user <{contact}> scan status: {status.name} , 'f'qr_code: {qr_code}') | ||||
async def on_message(self, msg: Message): | async def on_message(self, msg: Message): | ||||
@@ -67,7 +62,7 @@ class WechatyChannel(Channel): | |||||
content = msg.text() | content = msg.text() | ||||
mention_content = await msg.mention_text() # 返回过滤掉@name后的消息 | mention_content = await msg.mention_text() # 返回过滤掉@name后的消息 | ||||
match_prefix = self.check_prefix(content, conf().get('single_chat_prefix')) | match_prefix = self.check_prefix(content, conf().get('single_chat_prefix')) | ||||
conversation: Union[Room, Contact] = from_contact if room is None else room | |||||
# conversation: Union[Room, Contact] = from_contact if room is None else room | |||||
if room is None and msg.type() == MessageType.MESSAGE_TYPE_TEXT: | if room is None and msg.type() == MessageType.MESSAGE_TYPE_TEXT: | ||||
if not msg.is_self() and match_prefix is not None: | if not msg.is_self() and match_prefix is not None: | ||||
@@ -102,21 +97,8 @@ class WechatyChannel(Channel): | |||||
await voice_file.to_file(silk_file) | await voice_file.to_file(silk_file) | ||||
logger.info("[WX]receive voice file: " + silk_file) | logger.info("[WX]receive voice file: " + silk_file) | ||||
# 将文件转成wav格式音频 | # 将文件转成wav格式音频 | ||||
wav_file = silk_file.replace(".slk", ".wav") | |||||
with open(silk_file, 'rb') as f: | |||||
silk_data = f.read() | |||||
pcm_data = pysilk.decode(silk_data) | |||||
with wave.open(wav_file, 'wb') as wav_data: | |||||
wav_data.setnchannels(1) | |||||
wav_data.setsampwidth(2) | |||||
wav_data.setframerate(24000) | |||||
wav_data.writeframes(pcm_data) | |||||
if os.path.exists(wav_file): | |||||
converter_state = "true" # 转换wav成功 | |||||
else: | |||||
converter_state = "false" # 转换wav失败 | |||||
logger.info("[WX]receive voice converter: " + converter_state) | |||||
wav_file = os.path.splitext(silk_file)[0] + '.wav' | |||||
sil_to_wav(silk_file, wav_file) | |||||
# 语音识别为文本 | # 语音识别为文本 | ||||
query = super().build_voice_to_text(wav_file).content | query = super().build_voice_to_text(wav_file).content | ||||
# 交验关键字 | # 交验关键字 | ||||
@@ -183,21 +165,8 @@ class WechatyChannel(Channel): | |||||
await voice_file.to_file(silk_file) | await voice_file.to_file(silk_file) | ||||
logger.info("[WX]receive voice file: " + silk_file) | logger.info("[WX]receive voice file: " + silk_file) | ||||
# 将文件转成wav格式音频 | # 将文件转成wav格式音频 | ||||
wav_file = silk_file.replace(".slk", ".wav") | |||||
with open(silk_file, 'rb') as f: | |||||
silk_data = f.read() | |||||
pcm_data = pysilk.decode(silk_data) | |||||
with wave.open(wav_file, 'wb') as wav_data: | |||||
wav_data.setnchannels(1) | |||||
wav_data.setsampwidth(2) | |||||
wav_data.setframerate(24000) | |||||
wav_data.writeframes(pcm_data) | |||||
if os.path.exists(wav_file): | |||||
converter_state = "true" # 转换wav成功 | |||||
else: | |||||
converter_state = "false" # 转换wav失败 | |||||
logger.info("[WX]receive voice converter: " + converter_state) | |||||
wav_file = os.path.splitext(silk_file)[0] + '.wav' | |||||
sil_to_wav(silk_file, wav_file) | |||||
# 语音识别为文本 | # 语音识别为文本 | ||||
query = super().build_voice_to_text(wav_file).content | query = super().build_voice_to_text(wav_file).content | ||||
# 校验关键字 | # 校验关键字 | ||||
@@ -260,21 +229,12 @@ class WechatyChannel(Channel): | |||||
if reply_text: | if reply_text: | ||||
# 转换 mp3 文件为 silk 格式 | # 转换 mp3 文件为 silk 格式 | ||||
mp3_file = super().build_text_to_voice(reply_text).content | mp3_file = super().build_text_to_voice(reply_text).content | ||||
silk_file = mp3_file.replace(".mp3", ".silk") | |||||
# Load the MP3 file | |||||
audio = AudioSegment.from_file(mp3_file, format="mp3") | |||||
# Convert to WAV format | |||||
audio = audio.set_frame_rate(24000).set_channels(1) | |||||
wav_data = audio.raw_data | |||||
sample_width = audio.sample_width | |||||
# Encode to SILK format | |||||
silk_data = pysilk.encode(wav_data, 24000) | |||||
# Save the silk file | |||||
with open(silk_file, "wb") as f: | |||||
f.write(silk_data) | |||||
silk_file = os.path.splitext(mp3_file)[0] + '.sil' | |||||
voiceLength = mp3_to_sil(mp3_file, silk_file) | |||||
# 发送语音 | # 发送语音 | ||||
t = int(time.time()) | t = int(time.time()) | ||||
file_box = FileBox.from_file(silk_file, name=str(t) + '.silk') | |||||
file_box = FileBox.from_file(silk_file, name=str(t) + '.sil') | |||||
file_box.metadata = {'voiceLength': voiceLength} | |||||
await self.send(file_box, reply_user_id) | await self.send(file_box, reply_user_id) | ||||
# 清除缓存文件 | # 清除缓存文件 | ||||
os.remove(mp3_file) | os.remove(mp3_file) | ||||
@@ -337,21 +297,12 @@ class WechatyChannel(Channel): | |||||
reply_text = '@' + group_user_name + ' ' + reply_text.strip() | reply_text = '@' + group_user_name + ' ' + reply_text.strip() | ||||
# 转换 mp3 文件为 silk 格式 | # 转换 mp3 文件为 silk 格式 | ||||
mp3_file = super().build_text_to_voice(reply_text).content | mp3_file = super().build_text_to_voice(reply_text).content | ||||
silk_file = mp3_file.replace(".mp3", ".silk") | |||||
# Load the MP3 file | |||||
audio = AudioSegment.from_file(mp3_file, format="mp3") | |||||
# Convert to WAV format | |||||
audio = audio.set_frame_rate(24000).set_channels(1) | |||||
wav_data = audio.raw_data | |||||
sample_width = audio.sample_width | |||||
# Encode to SILK format | |||||
silk_data = pysilk.encode(wav_data, 24000) | |||||
# Save the silk file | |||||
with open(silk_file, "wb") as f: | |||||
f.write(silk_data) | |||||
silk_file = os.path.splitext(mp3_file)[0] + '.sil' | |||||
voiceLength = mp3_to_sil(mp3_file, silk_file) | |||||
# 发送语音 | # 发送语音 | ||||
t = int(time.time()) | t = int(time.time()) | ||||
file_box = FileBox.from_file(silk_file, name=str(t) + '.silk') | file_box = FileBox.from_file(silk_file, name=str(t) + '.silk') | ||||
file_box.metadata = {'voiceLength': voiceLength} | |||||
await self.send_group(file_box, group_id) | await self.send_group(file_box, group_id) | ||||
# 清除缓存文件 | # 清除缓存文件 | ||||
os.remove(mp3_file) | os.remove(mp3_file) | ||||
@@ -5,71 +5,77 @@ import os | |||||
from common.log import logger | from common.log import logger | ||||
# 将所有可用的配置项写在字典里, 请使用小写字母 | # 将所有可用的配置项写在字典里, 请使用小写字母 | ||||
available_setting ={ | |||||
#openai api配置 | |||||
"open_ai_api_key": "", # openai api key | |||||
"open_ai_api_base": "https://api.openai.com/v1", # openai apibase,当use_azure_chatgpt为true时,需要设置对应的api base | |||||
"proxy": "", # openai使用的代理 | |||||
"model": "gpt-3.5-turbo", # chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称 | |||||
"use_azure_chatgpt": False, # 是否使用azure的chatgpt | |||||
#Bot触发配置 | |||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复 | |||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人 | |||||
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复 | |||||
"group_chat_reply_prefix": "", # 群聊时自动回复的前缀 | |||||
"group_chat_keyword": [], # 群聊时包含该关键词则会触发机器人回复 | |||||
"group_at_off": False, # 是否关闭群聊时@bot的触发 | |||||
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表 | |||||
"group_name_keyword_white_list": [], # 开启自动回复的群名称关键词列表 | |||||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称 | |||||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀 | |||||
#chatgpt会话参数 | |||||
"expires_in_seconds": 3600, # 无操作会话的过期时间 | |||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述 | |||||
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数 | |||||
#chatgpt限流配置 | |||||
"rate_limit_chatgpt": 20, # chatgpt的调用频率限制 | |||||
"rate_limit_dalle": 50, # openai dalle的调用频率限制 | |||||
#chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create | |||||
available_setting = { | |||||
# openai api配置 | |||||
"open_ai_api_key": "", # openai api key | |||||
# openai apibase,当use_azure_chatgpt为true时,需要设置对应的api base | |||||
"open_ai_api_base": "https://api.openai.com/v1", | |||||
"proxy": "", # openai使用的代理 | |||||
# chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称 | |||||
"model": "gpt-3.5-turbo", | |||||
"use_azure_chatgpt": False, # 是否使用azure的chatgpt | |||||
# Bot触发配置 | |||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复 | |||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人 | |||||
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复 | |||||
"group_chat_reply_prefix": "", # 群聊时自动回复的前缀 | |||||
"group_chat_keyword": [], # 群聊时包含该关键词则会触发机器人回复 | |||||
"group_at_off": False, # 是否关闭群聊时@bot的触发 | |||||
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表 | |||||
"group_name_keyword_white_list": [], # 开启自动回复的群名称关键词列表 | |||||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称 | |||||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀 | |||||
# chatgpt会话参数 | |||||
"expires_in_seconds": 3600, # 无操作会话的过期时间 | |||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述 | |||||
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数 | |||||
# chatgpt限流配置 | |||||
"rate_limit_chatgpt": 20, # chatgpt的调用频率限制 | |||||
"rate_limit_dalle": 50, # openai dalle的调用频率限制 | |||||
# chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create | |||||
"temperature": 0.9, | "temperature": 0.9, | ||||
"top_p": 1, | "top_p": 1, | ||||
"frequency_penalty": 0, | "frequency_penalty": 0, | ||||
"presence_penalty": 0, | "presence_penalty": 0, | ||||
#语音设置 | |||||
"speech_recognition": False, # 是否开启语音识别 | |||||
"group_speech_recognition": False, # 是否开启群组语音识别 | |||||
"voice_reply_voice": False, # 是否使用语音回复语音,需要设置对应语音合成引擎的api key | |||||
"voice_to_text": "openai", # 语音识别引擎,支持openai和google | |||||
"text_to_voice": "baidu", # 语音合成引擎,支持baidu和google | |||||
# 语音设置 | |||||
"speech_recognition": False, # 是否开启语音识别 | |||||
"group_speech_recognition": False, # 是否开启群组语音识别 | |||||
"voice_reply_voice": False, # 是否使用语音回复语音,需要设置对应语音合成引擎的api key | |||||
"voice_to_text": "openai", # 语音识别引擎,支持openai,google | |||||
"text_to_voice": "baidu", # 语音合成引擎,支持baidu,google,pytts(offline) | |||||
# baidu api的配置, 使用百度语音识别和语音合成时需要 | # baidu api的配置, 使用百度语音识别和语音合成时需要 | ||||
'baidu_app_id': "", | |||||
'baidu_api_key': "", | |||||
'baidu_secret_key': "", | |||||
"baidu_app_id": "", | |||||
"baidu_api_key": "", | |||||
"baidu_secret_key": "", | |||||
# 1536普通话(支持简单的英文识别) 1737英语 1637粤语 1837四川话 1936普通话远场 | |||||
"baidu_dev_pid": "1536", | |||||
#服务时间限制,目前支持itchat | |||||
"chat_time_module": False, # 是否开启服务时间限制 | |||||
"chat_start_time": "00:00", # 服务开始时间 | |||||
"chat_stop_time": "24:00", # 服务结束时间 | |||||
# 服务时间限制,目前支持itchat | |||||
"chat_time_module": False, # 是否开启服务时间限制 | |||||
"chat_start_time": "00:00", # 服务开始时间 | |||||
"chat_stop_time": "24:00", # 服务结束时间 | |||||
# itchat的配置 | # itchat的配置 | ||||
"hot_reload": False, # 是否开启热重载 | |||||
"hot_reload": False, # 是否开启热重载 | |||||
# wechaty的配置 | # wechaty的配置 | ||||
"wechaty_puppet_service_token": "", # wechaty的token | |||||
"wechaty_puppet_service_token": "", # wechaty的token | |||||
# chatgpt指令自定义触发词 | # chatgpt指令自定义触发词 | ||||
"clear_memory_commands": ['#清除记忆'], # 重置会话指令 | |||||
"clear_memory_commands": ['#清除记忆'], # 重置会话指令 | |||||
"channel_type": "wx", # 通道类型,支持wx,wxy和terminal | |||||
} | } | ||||
class Config(dict): | class Config(dict): | ||||
def __getitem__(self, key): | def __getitem__(self, key): | ||||
if key not in available_setting: | if key not in available_setting: | ||||
@@ -82,15 +88,17 @@ class Config(dict): | |||||
return super().__setitem__(key, value) | return super().__setitem__(key, value) | ||||
def get(self, key, default=None): | def get(self, key, default=None): | ||||
try : | |||||
try: | |||||
return self[key] | return self[key] | ||||
except KeyError as e: | except KeyError as e: | ||||
return default | return default | ||||
except Exception as e: | except Exception as e: | ||||
raise e | raise e | ||||
config = Config() | config = Config() | ||||
def load_config(): | def load_config(): | ||||
global config | global config | ||||
config_path = "./config.json" | config_path = "./config.json" | ||||
@@ -109,7 +117,8 @@ def load_config(): | |||||
for name, value in os.environ.items(): | for name, value in os.environ.items(): | ||||
name = name.lower() | name = name.lower() | ||||
if name in available_setting: | if name in available_setting: | ||||
logger.info("[INIT] override config by environ args: {}={}".format(name, value)) | |||||
logger.info( | |||||
"[INIT] override config by environ args: {}={}".format(name, value)) | |||||
try: | try: | ||||
config[name] = eval(value) | config[name] = eval(value) | ||||
except: | except: | ||||
@@ -118,9 +127,8 @@ def load_config(): | |||||
logger.info("[INIT] load config: {}".format(config)) | logger.info("[INIT] load config: {}".format(config)) | ||||
def get_root(): | def get_root(): | ||||
return os.path.dirname(os.path.abspath( __file__ )) | |||||
return os.path.dirname(os.path.abspath(__file__)) | |||||
def read_file(path): | def read_file(path): | ||||
@@ -1,3 +1,15 @@ | |||||
itchat-uos==1.5.0.dev0 | |||||
openai | |||||
wechaty | |||||
openai>=0.27.2 | |||||
baidu_aip>=4.16.10 | |||||
gTTS>=2.3.1 | |||||
HTMLParser>=0.0.2 | |||||
pydub>=0.25.1 | |||||
PyQRCode>=1.2.1 | |||||
pysilk>=0.0.1 | |||||
pysilk_mod>=1.6.0 | |||||
pyttsx3>=2.90 | |||||
requests>=2.28.2 | |||||
SpeechRecognition>=3.10.0 | |||||
tiktoken>=0.3.2 | |||||
webuiapi>=0.6.2 | |||||
wechaty>=0.10.7 | |||||
wechaty_puppet>=0.4.23 |
@@ -0,0 +1,60 @@ | |||||
import wave | |||||
import pysilk | |||||
from pydub import AudioSegment | |||||
def get_pcm_from_wav(wav_path): | |||||
""" | |||||
从 wav 文件中读取 pcm | |||||
:param wav_path: wav 文件路径 | |||||
:returns: pcm 数据 | |||||
""" | |||||
wav = wave.open(wav_path, "rb") | |||||
return wav.readframes(wav.getnframes()) | |||||
def mp3_to_wav(mp3_path, wav_path): | |||||
""" | |||||
把mp3格式转成pcm文件 | |||||
""" | |||||
audio = AudioSegment.from_mp3(mp3_path) | |||||
audio.export(wav_path, format="wav") | |||||
def pcm_to_silk(pcm_path, silk_path): | |||||
""" | |||||
wav 文件转成 silk | |||||
return 声音长度,毫秒 | |||||
""" | |||||
audio = AudioSegment.from_wav(pcm_path) | |||||
wav_data = audio.raw_data | |||||
silk_data = pysilk.encode( | |||||
wav_data, data_rate=audio.frame_rate, sample_rate=audio.frame_rate) | |||||
with open(silk_path, "wb") as f: | |||||
f.write(silk_data) | |||||
return audio.duration_seconds * 1000 | |||||
def mp3_to_sil(mp3_path, silk_path): | |||||
""" | |||||
mp3 文件转成 silk | |||||
return 声音长度,毫秒 | |||||
""" | |||||
audio = AudioSegment.from_mp3(mp3_path) | |||||
wav_data = audio.raw_data | |||||
silk_data = pysilk.encode( | |||||
wav_data, data_rate=audio.frame_rate, sample_rate=audio.frame_rate) | |||||
# Save the silk file | |||||
with open(silk_path, "wb") as f: | |||||
f.write(silk_data) | |||||
return audio.duration_seconds * 1000 | |||||
def sil_to_wav(silk_path, wav_path, rate: int = 24000): | |||||
""" | |||||
silk 文件转 wav | |||||
""" | |||||
wav_data = pysilk.decode_file(silk_path, to_wav=True, sample_rate=rate) | |||||
with open(wav_path, "wb") as f: | |||||
f.write(wav_data) |
@@ -8,19 +8,53 @@ from bridge.reply import Reply, ReplyType | |||||
from common.log import logger | from common.log import logger | ||||
from common.tmp_dir import TmpDir | from common.tmp_dir import TmpDir | ||||
from voice.voice import Voice | from voice.voice import Voice | ||||
from voice.audio_convert import get_pcm_from_wav | |||||
from config import conf | from config import conf | ||||
""" | |||||
百度的语音识别API. | |||||
dev_pid: | |||||
- 1936: 普通话远场 | |||||
- 1536:普通话(支持简单的英文识别) | |||||
- 1537:普通话(纯中文识别) | |||||
- 1737:英语 | |||||
- 1637:粤语 | |||||
- 1837:四川话 | |||||
要使用本模块, 首先到 yuyin.baidu.com 注册一个开发者账号, | |||||
之后创建一个新应用, 然后在应用管理的"查看key"中获得 API Key 和 Secret Key | |||||
填入 config.json 中. | |||||
baidu_app_id: '' | |||||
baidu_api_key: '' | |||||
baidu_secret_key: '' | |||||
baidu_dev_pid: '1536' | |||||
""" | |||||
class BaiduVoice(Voice): | class BaiduVoice(Voice): | ||||
APP_ID = conf().get('baidu_app_id') | APP_ID = conf().get('baidu_app_id') | ||||
API_KEY = conf().get('baidu_api_key') | API_KEY = conf().get('baidu_api_key') | ||||
SECRET_KEY = conf().get('baidu_secret_key') | SECRET_KEY = conf().get('baidu_secret_key') | ||||
DEV_ID = conf().get('baidu_dev_pid') | |||||
client = AipSpeech(APP_ID, API_KEY, SECRET_KEY) | client = AipSpeech(APP_ID, API_KEY, SECRET_KEY) | ||||
def __init__(self): | def __init__(self): | ||||
pass | pass | ||||
def voiceToText(self, voice_file): | def voiceToText(self, voice_file): | ||||
pass | |||||
# 识别本地文件 | |||||
logger.debug('[Baidu] voice file name={}'.format(voice_file)) | |||||
pcm = get_pcm_from_wav(voice_file) | |||||
res = self.client.asr(pcm, "pcm", 16000, {"dev_pid": self.DEV_ID}) | |||||
if res["err_no"] == 0: | |||||
logger.info("百度语音识别到了:{}".format(res["result"])) | |||||
text = "".join(res["result"]) | |||||
reply = Reply(ReplyType.TEXT, text) | |||||
else: | |||||
logger.info("百度语音识别出错了: {}".format(res["err_msg"])) | |||||
if res["err_msg"] == "request pv too much": | |||||
logger.info(" 出现这个原因很可能是你的百度语音服务调用量超出限制,或未开通付费") | |||||
reply = Reply(ReplyType.ERROR, | |||||
"百度语音识别出错了;{0}".format(res["err_msg"])) | |||||
return reply | |||||
def textToVoice(self, text): | def textToVoice(self, text): | ||||
result = self.client.synthesis(text, 'zh', 1, { | result = self.client.synthesis(text, 'zh', 1, { | ||||
@@ -30,7 +64,8 @@ class BaiduVoice(Voice): | |||||
fileName = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3' | fileName = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3' | ||||
with open(fileName, 'wb') as f: | with open(fileName, 'wb') as f: | ||||
f.write(result) | f.write(result) | ||||
logger.info('[Baidu] textToVoice text={} voice file name={}'.format(text, fileName)) | |||||
logger.info( | |||||
'[Baidu] textToVoice text={} voice file name={}'.format(text, fileName)) | |||||
reply = Reply(ReplyType.VOICE, fileName) | reply = Reply(ReplyType.VOICE, fileName) | ||||
else: | else: | ||||
logger.error('[Baidu] textToVoice error={}'.format(result)) | logger.error('[Baidu] textToVoice error={}'.format(result)) | ||||
@@ -3,12 +3,10 @@ | |||||
google voice service | google voice service | ||||
""" | """ | ||||
import pathlib | |||||
import subprocess | |||||
import time | import time | ||||
from bridge.reply import Reply, ReplyType | |||||
import speech_recognition | import speech_recognition | ||||
import pyttsx3 | |||||
from gtts import gTTS | |||||
from bridge.reply import Reply, ReplyType | |||||
from common.log import logger | from common.log import logger | ||||
from common.tmp_dir import TmpDir | from common.tmp_dir import TmpDir | ||||
from voice.voice import Voice | from voice.voice import Voice | ||||
@@ -16,22 +14,12 @@ from voice.voice import Voice | |||||
class GoogleVoice(Voice): | class GoogleVoice(Voice): | ||||
recognizer = speech_recognition.Recognizer() | recognizer = speech_recognition.Recognizer() | ||||
engine = pyttsx3.init() | |||||
def __init__(self): | def __init__(self): | ||||
# 语速 | |||||
self.engine.setProperty('rate', 125) | |||||
# 音量 | |||||
self.engine.setProperty('volume', 1.0) | |||||
# 0为男声,1为女声 | |||||
voices = self.engine.getProperty('voices') | |||||
self.engine.setProperty('voice', voices[1].id) | |||||
pass | |||||
def voiceToText(self, voice_file): | def voiceToText(self, voice_file): | ||||
new_file = voice_file.replace('.mp3', '.wav') | |||||
subprocess.call('ffmpeg -i ' + voice_file + | |||||
' -acodec pcm_s16le -ac 1 -ar 16000 ' + new_file, shell=True) | |||||
with speech_recognition.AudioFile(new_file) as source: | |||||
with speech_recognition.AudioFile(voice_file) as source: | |||||
audio = self.recognizer.record(source) | audio = self.recognizer.record(source) | ||||
try: | try: | ||||
text = self.recognizer.recognize_google(audio, language='zh-CN') | text = self.recognizer.recognize_google(audio, language='zh-CN') | ||||
@@ -46,12 +34,12 @@ class GoogleVoice(Voice): | |||||
return reply | return reply | ||||
def textToVoice(self, text): | def textToVoice(self, text): | ||||
try: | try: | ||||
textFile = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3' | |||||
self.engine.save_to_file(text, textFile) | |||||
self.engine.runAndWait() | |||||
mp3File = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3' | |||||
tts = gTTS(text=text, lang='zh') | |||||
tts.save(mp3File) | |||||
logger.info( | logger.info( | ||||
'[Google] textToVoice text={} voice file name={}'.format(text, textFile)) | |||||
reply = Reply(ReplyType.VOICE, textFile) | |||||
'[Google] textToVoice text={} voice file name={}'.format(text, mp3File)) | |||||
reply = Reply(ReplyType.VOICE, mp3File) | |||||
except Exception as e: | except Exception as e: | ||||
reply = Reply(ReplyType.ERROR, str(e)) | reply = Reply(ReplyType.ERROR, str(e)) | ||||
finally: | finally: | ||||
@@ -28,6 +28,3 @@ class OpenaiVoice(Voice): | |||||
reply = Reply(ReplyType.ERROR, str(e)) | reply = Reply(ReplyType.ERROR, str(e)) | ||||
finally: | finally: | ||||
return reply | return reply | ||||
def textToVoice(self, text): | |||||
pass |
@@ -0,0 +1,37 @@ | |||||
""" | |||||
pytts voice service (offline) | |||||
""" | |||||
import time | |||||
import pyttsx3 | |||||
from bridge.reply import Reply, ReplyType | |||||
from common.log import logger | |||||
from common.tmp_dir import TmpDir | |||||
from voice.voice import Voice | |||||
class PyttsVoice(Voice):
    """Offline text-to-speech service backed by pyttsx3 (no network calls)."""

    # One engine shared by all instances of this class.
    engine = pyttsx3.init()

    def __init__(self):
        # Speech rate (words per minute).
        self.engine.setProperty('rate', 125)
        # Playback volume in the 0.0-1.0 range.
        self.engine.setProperty('volume', 1.0)
        # Select an installed voice whose name mentions Chinese, if any.
        for candidate in self.engine.getProperty('voices'):
            if "Chinese" in candidate.name:
                self.engine.setProperty('voice', candidate.id)

    def textToVoice(self, text):
        """Synthesize `text` into a temp audio file and return a VOICE reply.

        On failure, an ERROR reply carrying the exception text is returned
        instead of raising.
        """
        try:
            fileName = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
            self.engine.save_to_file(text, fileName)
            self.engine.runAndWait()
            logger.info(
                '[Pytts] textToVoice text={} voice file name={}'.format(text, fileName))
            reply = Reply(ReplyType.VOICE, fileName)
        except Exception as e:
            reply = Reply(ReplyType.ERROR, str(e))
        finally:
            return reply
@@ -17,4 +17,7 @@ def create_voice(voice_type): | |||||
elif voice_type == 'openai': | elif voice_type == 'openai': | ||||
from voice.openai.openai_voice import OpenaiVoice | from voice.openai.openai_voice import OpenaiVoice | ||||
return OpenaiVoice() | return OpenaiVoice() | ||||
elif voice_type == 'pytts': | |||||
from voice.pytts.pytts_voice import PyttsVoice | |||||
return PyttsVoice() | |||||
raise RuntimeError | raise RuntimeError |