소스 검색

Merge branch 'master' of github.com:zwssunny/chatgpt-on-wechat into zwssunny-master

master
lanvent 1 년 전
부모
커밋
80bf6a0c7a
9개의 변경된 파일292개의 추가작업 그리고 187개의 파일을 삭제
  1. +3
    -3
      app.py
  2. +3
    -3
      bot/bot_factory.py
  3. +85
    -48
      channel/wechat/wechat_channel.py
  4. +17
    -66
      channel/wechat/wechaty_channel.py
  5. +60
    -52
      config.py
  6. +13
    -0
      requirements.txt
  7. +60
    -0
      voice/audio_convert.py
  8. +38
    -3
      voice/baidu/baidu_voice.py
  9. +13
    -12
      voice/google/google_voice.py

+ 3
- 3
app.py 파일 보기

@@ -1,6 +1,6 @@
# encoding:utf-8

import config
from config import conf, load_config
from channel import channel_factory
from common.log import logger

@@ -9,10 +9,10 @@ from plugins import *
def run():
try:
# load config
config.load_config()
load_config()

# create channel
channel_name='wx'
channel_name=conf().get('channel_type', 'wx')
channel = channel_factory.create_channel(channel_name)
if channel_name=='wx':
PluginManager().load_plugins()


+ 3
- 3
bot/bot_factory.py 파일 보기

@@ -6,9 +6,9 @@ from common import const

def create_bot(bot_type):
"""
create a channel instance
:param channel_type: channel type code
:return: channel instance
create a bot_type instance
:param bot_type: bot type code
:return: bot instance
"""
if bot_type == const.BAIDU:
# Baidu Unit对话接口


+ 85
- 48
channel/wechat/wechat_channel.py 파일 보기

@@ -5,6 +5,9 @@ wechat channel
"""

import os
import requests
import io
import time
from lib import itchat
import json
from lib.itchat.content import *
@@ -17,17 +20,18 @@ from common.tmp_dir import TmpDir
from config import conf
from common.time_check import time_checker
from plugins import *
import requests
import io
import time
from voice.audio_convert import mp3_to_wav


thread_pool = ThreadPoolExecutor(max_workers=8)


def thread_pool_callback(worker):
worker_exception = worker.exception()
if worker_exception:
logger.exception("Worker return exception: {}".format(worker_exception))


@itchat.msg_register(TEXT)
def handler_single_msg(msg):
WechatChannel().handle_text(msg)
@@ -48,6 +52,8 @@ def handler_group_voice(msg):
WechatChannel().handle_group_voice(msg)
return None



class WechatChannel(Channel):
def __init__(self):
self.userName = None
@@ -55,14 +61,15 @@ class WechatChannel(Channel):

def startup(self):

itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
# login by scan QRCode
hotReload = conf().get('hot_reload', False)
try:
itchat.auto_login(enableCmdQR=2, hotReload=hotReload)
except Exception as e:
if hotReload:
logger.error("Hot reload failed, try to login without hot reload")
logger.error(
"Hot reload failed, try to login without hot reload")
itchat.logout()
os.remove("itchat.pkl")
itchat.auto_login(enableCmdQR=2, hotReload=hotReload)
@@ -105,7 +112,8 @@ class WechatChannel(Channel):

@time_checker
def handle_text(self, msg):
logger.debug("[WX]receive text msg: " + json.dumps(msg, ensure_ascii=False))
logger.debug("[WX]receive text msg: " +
json.dumps(msg, ensure_ascii=False))
content = msg['Text']
from_user_id = msg['FromUserName']
to_user_id = msg['ToUserName'] # 接收人id
@@ -119,7 +127,7 @@ class WechatChannel(Channel):
other_user_id = from_user_id
create_time = msg['CreateTime'] # 消息时间
match_prefix = check_prefix(content, conf().get('single_chat_prefix'))
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: #跳过1分钟前的历史消息
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
logger.debug("[WX]history message skipped")
return
if "」\n- - - - - - - - - - - - - - -" in content:
@@ -130,9 +138,11 @@ class WechatChannel(Channel):
elif match_prefix is None:
return
context = Context()
context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id}
context.kwargs = {'isgroup': False, 'msg': msg,
'receiver': other_user_id, 'session_id': other_user_id}

img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
img_match_prefix = check_prefix(
content, conf().get('image_create_prefix'))
if img_match_prefix:
content = content.replace(img_match_prefix, '', 1).strip()
context.type = ContextType.IMAGE_CREATE
@@ -140,15 +150,17 @@ class WechatChannel(Channel):
context.type = ContextType.TEXT

context.content = content
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
thread_pool.submit(self.handle, context).add_done_callback(
thread_pool_callback)

@time_checker
def handle_group(self, msg):
logger.debug("[WX]receive group msg: " + json.dumps(msg, ensure_ascii=False))
logger.debug("[WX]receive group msg: " +
json.dumps(msg, ensure_ascii=False))
group_name = msg['User'].get('NickName', None)
group_id = msg['User'].get('UserName', None)
create_time = msg['CreateTime'] # 消息时间
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: #跳过1分钟前的历史消息
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
logger.debug("[WX]history group message skipped")
return
if not group_name:
@@ -166,12 +178,14 @@ class WechatChannel(Channel):
return ""
config = conf()
match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or check_prefix(origin_content, config.get('group_chat_prefix')) \
or check_contain(origin_content, config.get('group_chat_keyword'))
or check_contain(origin_content, config.get('group_chat_keyword'))
if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix:
context = Context()
context.kwargs = { 'isgroup': True, 'msg': msg, 'receiver': group_id}
img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
context.kwargs = {'isgroup': True,
'msg': msg, 'receiver': group_id}

img_match_prefix = check_prefix(
content, conf().get('image_create_prefix'))
if img_match_prefix:
content = content.replace(img_match_prefix, '', 1).strip()
context.type = ContextType.IMAGE_CREATE
@@ -187,7 +201,8 @@ class WechatChannel(Channel):
else:
context['session_id'] = msg['ActualUserName']

thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
thread_pool.submit(self.handle, context).add_done_callback(
thread_pool_callback)

def handle_group_voice(self, msg):
if conf().get('group_speech_recognition', False) != True:
@@ -217,7 +232,7 @@ class WechatChannel(Channel):
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)

# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
def send(self, reply : Reply, receiver):
def send(self, reply: Reply, receiver):
if reply.type == ReplyType.TEXT:
itchat.send(reply.content, toUserName=receiver)
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
@@ -226,8 +241,9 @@ class WechatChannel(Channel):
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
elif reply.type == ReplyType.VOICE:
itchat.send_file(reply.content, toUserName=receiver)
logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver))
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
logger.info('[WX] sendFile={}, receiver={}'.format(
reply.content, receiver))
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
img_url = reply.content
pic_res = requests.get(img_url, stream=True)
image_storage = io.BytesIO()
@@ -235,8 +251,9 @@ class WechatChannel(Channel):
image_storage.write(block)
image_storage.seek(0)
itchat.send_image(image_storage, toUserName=receiver)
logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver))
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
logger.info('[WX] sendImage url={}, receiver={}'.format(
img_url, receiver))
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
image_storage = reply.content
image_storage.seek(0)
itchat.send_image(image_storage, toUserName=receiver)
@@ -250,9 +267,10 @@ class WechatChannel(Channel):
reply = Reply()

logger.debug('[WX] ready to handle context: {}'.format(context))
# reply的构建步骤
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply}))
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
'channel': self, 'context': context, 'reply': reply}))
reply = e_context['reply']
if not e_context.is_pass():
logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content))
@@ -260,22 +278,35 @@ class WechatChannel(Channel):
reply = super().build_reply_content(context.content, context)
elif context.type == ContextType.VOICE: # 语音消息
msg = context['msg']
file_name = TmpDir().path() + context.content
msg.download(file_name)
reply = super().build_voice_to_text(file_name)
if reply.type == ReplyType.TEXT:
content = reply.content # 语音转文字后,将文字内容作为新的context
# 如果是群消息,判断是否触发关键字
if context['isgroup']:
match_prefix = check_prefix(content, conf().get('group_chat_prefix'))
match_contain = check_contain(content, conf().get('group_chat_keyword'))
logger.debug('[WX] group chat prefix match: {}'.format(match_prefix))
if match_prefix is None and match_contain is None:
return
mp3_path = TmpDir().path() + context.content
msg.download(mp3_path)
# mp3转wav
wav_path = os.path.splitext(mp3_path)[0] + '.wav'
mp3_to_wav(mp3_path=mp3_path, wav_path=wav_path)
# 语音识别
reply = super().build_voice_to_text(wav_path)
# 删除临时文件
os.remove(wav_path)
os.remove(mp3_path)
if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO:
content = reply.content # 语音转文字后,将文字内容作为新的context
context.type = ContextType.TEXT
if (context["isgroup"] == True):
# 校验关键字
match_prefix = check_prefix(content, conf().get('group_chat_prefix')) \
or check_contain(content, conf().get('group_chat_keyword'))
# Wechaty判断is_at为True,返回的内容是过滤掉@之后的内容;而is_at为False,则会返回完整的内容
if match_prefix is not None:
# 故判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容,用于实现类似自定义+前缀触发生成AI图片的功能
prefixes = conf().get('group_chat_prefix')
for prefix in prefixes:
if content.startswith(prefix):
content = content.replace(prefix, '', 1).strip()
break
else:
if match_prefix:
content = content.replace(match_prefix, '', 1).strip()
logger.info("[WX]receive voice check prefix: " + 'False')
return
img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
if img_match_prefix:
content = content.replace(img_match_prefix, '', 1).strip()
@@ -292,16 +323,19 @@ class WechatChannel(Channel):
return

logger.debug('[WX] ready to decorate reply: {}'.format(reply))
# reply的包装步骤
if reply and reply.type:
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
reply=e_context['reply']
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {
'channel': self, 'context': context, 'reply': reply}))
reply = e_context['reply']
if not e_context.is_pass() and reply and reply.type:
if reply.type == ReplyType.TEXT:
reply_text = reply.content
if context['isgroup']:
reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip()
reply_text = '@' + \
context['msg']['ActualNickName'] + \
' ' + reply_text.strip()
reply_text = conf().get("group_chat_reply_prefix", "")+reply_text
else:
reply_text = conf().get("single_chat_reply_prefix", "")+reply_text
@@ -311,15 +345,18 @@ class WechatChannel(Channel):
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
pass
else:
logger.error('[WX] unknown reply type: {}'.format(reply.type))
logger.error(
'[WX] unknown reply type: {}'.format(reply.type))
return

# reply的发送步骤
# reply的发送步骤
if reply and reply.type:
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
reply=e_context['reply']
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {
'channel': self, 'context': context, 'reply': reply}))
reply = e_context['reply']
if not e_context.is_pass() and reply and reply.type:
logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver']))
logger.debug('[WX] ready to send reply: {} to {}'.format(
reply, context['receiver']))
self.send(reply, context['receiver'])

def check_prefix(content, prefix_list):


+ 17
- 66
channel/wechat/wechaty_channel.py 파일 보기

@@ -4,25 +4,19 @@
wechaty channel
Python Wechaty - https://github.com/wechaty/python-wechaty
"""
import io
import os
import json
import time
import asyncio
import requests
import pysilk
import wave
from pydub import AudioSegment
from typing import Optional, Union
from bridge.context import Context, ContextType
from wechaty_puppet import MessageType, FileBox, ScanStatus # type: ignore
from wechaty import Wechaty, Contact
from wechaty.user import Message, Room, MiniProgram, UrlLink
from wechaty.user import Message, MiniProgram, UrlLink
from channel.channel import Channel
from common.log import logger
from common.tmp_dir import TmpDir
from config import conf
from voice.audio_convert import sil_to_wav, mp3_to_sil

class WechatyChannel(Channel):

@@ -50,8 +44,9 @@ class WechatyChannel(Channel):

async def on_scan(self, status: ScanStatus, qr_code: Optional[str] = None,
data: Optional[str] = None):
contact = self.Contact.load(self.contact_id)
logger.info('[WX] scan user={}, scan status={}, scan qr_code={}'.format(contact, status.name, qr_code))
pass
# contact = self.Contact.load(self.contact_id)
# logger.info('[WX] scan user={}, scan status={}, scan qr_code={}'.format(contact, status.name, qr_code))
# print(f'user <{contact}> scan status: {status.name} , 'f'qr_code: {qr_code}')

async def on_message(self, msg: Message):
@@ -67,7 +62,7 @@ class WechatyChannel(Channel):
content = msg.text()
mention_content = await msg.mention_text() # 返回过滤掉@name后的消息
match_prefix = self.check_prefix(content, conf().get('single_chat_prefix'))
conversation: Union[Room, Contact] = from_contact if room is None else room
# conversation: Union[Room, Contact] = from_contact if room is None else room

if room is None and msg.type() == MessageType.MESSAGE_TYPE_TEXT:
if not msg.is_self() and match_prefix is not None:
@@ -102,21 +97,8 @@ class WechatyChannel(Channel):
await voice_file.to_file(silk_file)
logger.info("[WX]receive voice file: " + silk_file)
# 将文件转成wav格式音频
wav_file = silk_file.replace(".slk", ".wav")
with open(silk_file, 'rb') as f:
silk_data = f.read()
pcm_data = pysilk.decode(silk_data)

with wave.open(wav_file, 'wb') as wav_data:
wav_data.setnchannels(1)
wav_data.setsampwidth(2)
wav_data.setframerate(24000)
wav_data.writeframes(pcm_data)
if os.path.exists(wav_file):
converter_state = "true" # 转换wav成功
else:
converter_state = "false" # 转换wav失败
logger.info("[WX]receive voice converter: " + converter_state)
wav_file = os.path.splitext(silk_file)[0] + '.wav'
sil_to_wav(silk_file, wav_file)
# 语音识别为文本
query = super().build_voice_to_text(wav_file).content
# 交验关键字
@@ -183,21 +165,8 @@ class WechatyChannel(Channel):
await voice_file.to_file(silk_file)
logger.info("[WX]receive voice file: " + silk_file)
# 将文件转成wav格式音频
wav_file = silk_file.replace(".slk", ".wav")
with open(silk_file, 'rb') as f:
silk_data = f.read()
pcm_data = pysilk.decode(silk_data)

with wave.open(wav_file, 'wb') as wav_data:
wav_data.setnchannels(1)
wav_data.setsampwidth(2)
wav_data.setframerate(24000)
wav_data.writeframes(pcm_data)
if os.path.exists(wav_file):
converter_state = "true" # 转换wav成功
else:
converter_state = "false" # 转换wav失败
logger.info("[WX]receive voice converter: " + converter_state)
wav_file = os.path.splitext(silk_file)[0] + '.wav'
sil_to_wav(silk_file, wav_file)
# 语音识别为文本
query = super().build_voice_to_text(wav_file).content
# 校验关键字
@@ -260,21 +229,12 @@ class WechatyChannel(Channel):
if reply_text:
# 转换 mp3 文件为 silk 格式
mp3_file = super().build_text_to_voice(reply_text).content
silk_file = mp3_file.replace(".mp3", ".silk")
# Load the MP3 file
audio = AudioSegment.from_file(mp3_file, format="mp3")
# Convert to WAV format
audio = audio.set_frame_rate(24000).set_channels(1)
wav_data = audio.raw_data
sample_width = audio.sample_width
# Encode to SILK format
silk_data = pysilk.encode(wav_data, 24000)
# Save the silk file
with open(silk_file, "wb") as f:
f.write(silk_data)
silk_file = os.path.splitext(mp3_file)[0] + '.sil'
voiceLength = mp3_to_sil(mp3_file, silk_file)
# 发送语音
t = int(time.time())
file_box = FileBox.from_file(silk_file, name=str(t) + '.silk')
file_box = FileBox.from_file(silk_file, name=str(t) + '.sil')
file_box.metadata = {'voiceLength': voiceLength}
await self.send(file_box, reply_user_id)
# 清除缓存文件
os.remove(mp3_file)
@@ -337,21 +297,12 @@ class WechatyChannel(Channel):
reply_text = '@' + group_user_name + ' ' + reply_text.strip()
# 转换 mp3 文件为 silk 格式
mp3_file = super().build_text_to_voice(reply_text).content
silk_file = mp3_file.replace(".mp3", ".silk")
# Load the MP3 file
audio = AudioSegment.from_file(mp3_file, format="mp3")
# Convert to WAV format
audio = audio.set_frame_rate(24000).set_channels(1)
wav_data = audio.raw_data
sample_width = audio.sample_width
# Encode to SILK format
silk_data = pysilk.encode(wav_data, 24000)
# Save the silk file
with open(silk_file, "wb") as f:
f.write(silk_data)
silk_file = os.path.splitext(mp3_file)[0] + '.sil'
voiceLength = mp3_to_sil(mp3_file, silk_file)
# 发送语音
t = int(time.time())
file_box = FileBox.from_file(silk_file, name=str(t) + '.silk')
file_box.metadata = {'voiceLength': voiceLength}
await self.send_group(file_box, group_id)
# 清除缓存文件
os.remove(mp3_file)


+ 60
- 52
config.py 파일 보기

@@ -5,71 +5,77 @@ import os
from common.log import logger

# 将所有可用的配置项写在字典里, 请使用小写字母
available_setting ={
#openai api配置
"open_ai_api_key": "", # openai api key
"open_ai_api_base": "https://api.openai.com/v1", # openai apibase,当use_azure_chatgpt为true时,需要设置对应的api base
"proxy": "", # openai使用的代理
"model": "gpt-3.5-turbo", # chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
"use_azure_chatgpt": False, # 是否使用azure的chatgpt

#Bot触发配置
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
"group_chat_reply_prefix": "", # 群聊时自动回复的前缀
"group_chat_keyword": [], # 群聊时包含该关键词则会触发机器人回复
"group_at_off": False, # 是否关闭群聊时@bot的触发
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
"group_name_keyword_white_list": [], # 开启自动回复的群名称关键词列表
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
#chatgpt会话参数
"expires_in_seconds": 3600, # 无操作会话的过期时间
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
#chatgpt限流配置
"rate_limit_chatgpt": 20, # chatgpt的调用频率限制
"rate_limit_dalle": 50, # openai dalle的调用频率限制


#chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create
available_setting = {
# openai api配置
"open_ai_api_key": "", # openai api key
# openai apibase,当use_azure_chatgpt为true时,需要设置对应的api base
"open_ai_api_base": "https://api.openai.com/v1",
"proxy": "", # openai使用的代理
# chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
"model": "gpt-3.5-turbo",
"use_azure_chatgpt": False, # 是否使用azure的chatgpt

# Bot触发配置
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
"group_chat_reply_prefix": "", # 群聊时自动回复的前缀
"group_chat_keyword": [], # 群聊时包含该关键词则会触发机器人回复
"group_at_off": False, # 是否关闭群聊时@bot的触发
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
"group_name_keyword_white_list": [], # 开启自动回复的群名称关键词列表
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀

# chatgpt会话参数
"expires_in_seconds": 3600, # 无操作会话的过期时间
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数

# chatgpt限流配置
"rate_limit_chatgpt": 20, # chatgpt的调用频率限制
"rate_limit_dalle": 50, # openai dalle的调用频率限制


# chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create
"temperature": 0.9,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,

#语音设置
"speech_recognition": False, # 是否开启语音识别
"group_speech_recognition": False, # 是否开启群组语音识别
"voice_reply_voice": False, # 是否使用语音回复语音,需要设置对应语音合成引擎的api key
"voice_to_text": "openai", # 语音识别引擎,支持openai和google
"text_to_voice": "baidu", # 语音合成引擎,支持baidu和google
# 语音设置
"speech_recognition": False, # 是否开启语音识别
"group_speech_recognition": False, # 是否开启群组语音识别
"voice_reply_voice": False, # 是否使用语音回复语音,需要设置对应语音合成引擎的api key
"voice_to_text": "openai", # 语音识别引擎,支持openai和google
"text_to_voice": "baidu", # 语音合成引擎,支持baidu和google

# baidu api的配置, 使用百度语音识别和语音合成时需要
'baidu_app_id': "",
'baidu_api_key': "",
'baidu_secret_key': "",
"baidu_app_id": "",
"baidu_api_key": "",
"baidu_secret_key": "",
# 1536普通话(支持简单的英文识别) 1737英语 1637粤语 1837四川话 1936普通话远场
"baidu_dev_pid": "1536",

#服务时间限制,目前支持itchat
"chat_time_module": False, # 是否开启服务时间限制
"chat_start_time": "00:00", # 服务开始时间
"chat_stop_time": "24:00", # 服务结束时间
# 服务时间限制,目前支持itchat
"chat_time_module": False, # 是否开启服务时间限制
"chat_start_time": "00:00", # 服务开始时间
"chat_stop_time": "24:00", # 服务结束时间

# itchat的配置
"hot_reload": False, # 是否开启热重载
"hot_reload": False, # 是否开启热重载

# wechaty的配置
"wechaty_puppet_service_token": "", # wechaty的token
"wechaty_puppet_service_token": "", # wechaty的token

# chatgpt指令自定义触发词
"clear_memory_commands": ['#清除记忆'], # 重置会话指令
"clear_memory_commands": ['#清除记忆'], # 重置会话指令
"channel_type": "wx", # 通道类型,支持wx,wxy和terminal


}


class Config(dict):
def __getitem__(self, key):
if key not in available_setting:
@@ -82,15 +88,17 @@ class Config(dict):
return super().__setitem__(key, value)

def get(self, key, default=None):
try :
try:
return self[key]
except KeyError as e:
return default
except Exception as e:
raise e


config = Config()


def load_config():
global config
config_path = "./config.json"
@@ -109,7 +117,8 @@ def load_config():
for name, value in os.environ.items():
name = name.lower()
if name in available_setting:
logger.info("[INIT] override config by environ args: {}={}".format(name, value))
logger.info(
"[INIT] override config by environ args: {}={}".format(name, value))
try:
config[name] = eval(value)
except:
@@ -118,9 +127,8 @@ def load_config():
logger.info("[INIT] load config: {}".format(config))



def get_root():
return os.path.dirname(os.path.abspath( __file__ ))
return os.path.dirname(os.path.abspath(__file__))


def read_file(path):


+ 13
- 0
requirements.txt 파일 보기

@@ -1,3 +1,16 @@
itchat-uos==1.5.0.dev0
openai
wechaty
tiktoken
whisper
baidu-aip
chardet
ffmpy
pydub
pilk
pysilk
pysilk-mod
wave
SpeechRecognition
pyttsx3
gTTS

+ 60
- 0
voice/audio_convert.py 파일 보기

@@ -0,0 +1,60 @@
import wave
import pysilk
from pydub import AudioSegment


def get_pcm_from_wav(wav_path):
"""
从 wav 文件中读取 pcm

:param wav_path: wav 文件路径
:returns: pcm 数据
"""
wav = wave.open(wav_path, "rb")
return wav.readframes(wav.getnframes())


def mp3_to_wav(mp3_path, wav_path):
"""
把mp3格式转成pcm文件
"""
audio = AudioSegment.from_mp3(mp3_path)
audio.export(wav_path, format="wav")


def pcm_to_silk(pcm_path, silk_path):
"""
wav 文件转成 silk
return 声音长度,毫秒
"""
audio = AudioSegment.from_wav(pcm_path)
wav_data = audio.raw_data
silk_data = pysilk.encode(
wav_data, data_rate=audio.frame_rate, sample_rate=audio.frame_rate)
with open(silk_path, "wb") as f:
f.write(silk_data)
return audio.duration_seconds * 1000


def mp3_to_sil(mp3_path, silk_path):
"""
mp3 文件转成 silk
return 声音长度,毫秒
"""
audio = AudioSegment.from_mp3(mp3_path)
wav_data = audio.raw_data
silk_data = pysilk.encode(
wav_data, data_rate=audio.frame_rate, sample_rate=audio.frame_rate)
# Save the silk file
with open(silk_path, "wb") as f:
f.write(silk_data)
return audio.duration_seconds * 1000


def sil_to_wav(silk_path, wav_path, rate: int = 24000):
"""
silk 文件转 wav
"""
wav_data = pysilk.decode_file(silk_path, to_wav=True, sample_rate=rate)
with open(wav_path, "wb") as f:
f.write(wav_data)

+ 38
- 3
voice/baidu/baidu_voice.py 파일 보기

@@ -8,19 +8,53 @@ from bridge.reply import Reply, ReplyType
from common.log import logger
from common.tmp_dir import TmpDir
from voice.voice import Voice
from voice.audio_convert import get_pcm_from_wav
from config import conf
"""
百度的语音识别API.
dev_pid:
- 1936: 普通话远场
- 1536:普通话(支持简单的英文识别)
- 1537:普通话(纯中文识别)
- 1737:英语
- 1637:粤语
- 1837:四川话
要使用本模块, 首先到 yuyin.baidu.com 注册一个开发者账号,
之后创建一个新应用, 然后在应用管理的"查看key"中获得 API Key 和 Secret Key
填入 config.json 中.
baidu_app_id: ''
baidu_api_key: ''
baidu_secret_key: ''
baidu_dev_pid: '1536'
"""


class BaiduVoice(Voice):
APP_ID = conf().get('baidu_app_id')
API_KEY = conf().get('baidu_api_key')
SECRET_KEY = conf().get('baidu_secret_key')
DEV_ID = conf().get('baidu_dev_pid')
client = AipSpeech(APP_ID, API_KEY, SECRET_KEY)
def __init__(self):
pass

def voiceToText(self, voice_file):
pass
# 识别本地文件
logger.debug('[Baidu] voice file name={}'.format(voice_file))
pcm = get_pcm_from_wav(voice_file)
res = self.client.asr(pcm, "pcm", 16000, {"dev_pid": self.DEV_ID})
if res["err_no"] == 0:
logger.info("百度语音识别到了:{}".format(res["result"]))
text = "".join(res["result"])
reply = Reply(ReplyType.TEXT, text)
else:
logger.info("百度语音识别出错了: {}".format(res["err_msg"]))
if res["err_msg"] == "request pv too much":
logger.info(" 出现这个原因很可能是你的百度语音服务调用量超出限制,或未开通付费")
reply = Reply(ReplyType.ERROR,
"百度语音识别出错了;{0}".format(res["err_msg"]))
return reply

def textToVoice(self, text):
result = self.client.synthesis(text, 'zh', 1, {
@@ -30,7 +64,8 @@ class BaiduVoice(Voice):
fileName = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
with open(fileName, 'wb') as f:
f.write(result)
logger.info('[Baidu] textToVoice text={} voice file name={}'.format(text, fileName))
logger.info(
'[Baidu] textToVoice text={} voice file name={}'.format(text, fileName))
reply = Reply(ReplyType.VOICE, fileName)
else:
logger.error('[Baidu] textToVoice error={}'.format(result))


+ 13
- 12
voice/google/google_voice.py 파일 보기

@@ -3,12 +3,11 @@
google voice service
"""

import pathlib
import subprocess
import time
from bridge.reply import Reply, ReplyType
import speech_recognition
import pyttsx3
from gtts import gTTS
from bridge.reply import Reply, ReplyType
from common.log import logger
from common.tmp_dir import TmpDir
from voice.voice import Voice
@@ -28,10 +27,10 @@ class GoogleVoice(Voice):
self.engine.setProperty('voice', voices[1].id)

def voiceToText(self, voice_file):
new_file = voice_file.replace('.mp3', '.wav')
subprocess.call('ffmpeg -i ' + voice_file +
' -acodec pcm_s16le -ac 1 -ar 16000 ' + new_file, shell=True)
with speech_recognition.AudioFile(new_file) as source:
# new_file = voice_file.replace('.mp3', '.wav')
# subprocess.call('ffmpeg -i ' + voice_file +
# ' -acodec pcm_s16le -ac 1 -ar 16000 ' + new_file, shell=True)
with speech_recognition.AudioFile(voice_file) as source:
audio = self.recognizer.record(source)
try:
text = self.recognizer.recognize_google(audio, language='zh-CN')
@@ -46,12 +45,14 @@ class GoogleVoice(Voice):
return reply
def textToVoice(self, text):
try:
textFile = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
self.engine.save_to_file(text, textFile)
self.engine.runAndWait()
mp3File = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
# self.engine.save_to_file(text, textFile)
# self.engine.runAndWait()
tts = gTTS(text=text, lang='zh')
tts.save(mp3File)
logger.info(
'[Google] textToVoice text={} voice file name={}'.format(text, textFile))
reply = Reply(ReplyType.VOICE, textFile)
'[Google] textToVoice text={} voice file name={}'.format(text, mp3File))
reply = Reply(ReplyType.VOICE, mp3File)
except Exception as e:
reply = Reply(ReplyType.ERROR, str(e))
finally:


Loading…
취소
저장