@@ -1,11 +1,12 @@ | |||||
from bot import bot_factory | |||||
from bot.factory import create_bot | |||||
from bridge.context import Context | from bridge.context import Context | ||||
from bridge.reply import Reply | from bridge.reply import Reply | ||||
from common import const | from common import const | ||||
from common.log import logger | from common.log import logger | ||||
from common.singleton import singleton | from common.singleton import singleton | ||||
from config import conf | from config import conf | ||||
from voice import voice_factory | |||||
from translate.factory import create_translator | |||||
from voice.factory import create_voice | |||||
@singleton | @singleton | ||||
@@ -15,6 +16,7 @@ class Bridge(object): | |||||
"chat": const.CHATGPT, | "chat": const.CHATGPT, | ||||
"voice_to_text": conf().get("voice_to_text", "openai"), | "voice_to_text": conf().get("voice_to_text", "openai"), | ||||
"text_to_voice": conf().get("text_to_voice", "google"), | "text_to_voice": conf().get("text_to_voice", "google"), | ||||
"translate": conf().get("translate", "baidu"), | |||||
} | } | ||||
model_type = conf().get("model") | model_type = conf().get("model") | ||||
if model_type in ["text-davinci-003"]: | if model_type in ["text-davinci-003"]: | ||||
@@ -27,11 +29,13 @@ class Bridge(object): | |||||
if self.bots.get(typename) is None: | if self.bots.get(typename) is None: | ||||
logger.info("create bot {} for {}".format(self.btype[typename], typename)) | logger.info("create bot {} for {}".format(self.btype[typename], typename)) | ||||
if typename == "text_to_voice": | if typename == "text_to_voice": | ||||
self.bots[typename] = voice_factory.create_voice(self.btype[typename]) | |||||
self.bots[typename] = create_voice(self.btype[typename]) | |||||
elif typename == "voice_to_text": | elif typename == "voice_to_text": | ||||
self.bots[typename] = voice_factory.create_voice(self.btype[typename]) | |||||
self.bots[typename] = create_voice(self.btype[typename]) | |||||
elif typename == "chat": | elif typename == "chat": | ||||
self.bots[typename] = bot_factory.create_bot(self.btype[typename]) | |||||
self.bots[typename] = create_bot(self.btype[typename]) | |||||
elif typename == "translate": | |||||
self.bots[typename] = create_translator(self.btype[typename]) | |||||
return self.bots[typename] | return self.bots[typename] | ||||
def get_bot_type(self, typename): | def get_bot_type(self, typename): | ||||
@@ -66,6 +66,11 @@ available_setting = { | |||||
"chat_time_module": False, # 是否开启服务时间限制 | "chat_time_module": False, # 是否开启服务时间限制 | ||||
"chat_start_time": "00:00", # 服务开始时间 | "chat_start_time": "00:00", # 服务开始时间 | ||||
"chat_stop_time": "24:00", # 服务结束时间 | "chat_stop_time": "24:00", # 服务结束时间 | ||||
# 翻译api | |||||
"translate": "baidu", # 翻译api,支持baidu | |||||
# baidu翻译api的配置 | |||||
"baidu_translate_app_id": "", # 百度翻译api的appid | |||||
"baidu_translate_app_key": "", # 百度翻译api的秘钥 | |||||
# itchat的配置 | # itchat的配置 | ||||
"hot_reload": False, # 是否开启热重载 | "hot_reload": False, # 是否开启热重载 | ||||
# wechaty的配置 | # wechaty的配置 | ||||
@@ -0,0 +1,49 @@ | |||||
# -*- coding: utf-8 -*- | |||||
import random | |||||
from hashlib import md5 | |||||
import requests | |||||
from config import conf | |||||
from translate.translator import Translator | |||||
# from langid import classify | |||||
class BaiduTranslator(Translator): | |||||
def __init__(self) -> None: | |||||
super().__init__() | |||||
endpoint = "http://api.fanyi.baidu.com" | |||||
path = "/api/trans/vip/translate" | |||||
self.url = endpoint + path | |||||
self.appid = conf().get("baidu_translate_app_id") | |||||
self.appkey = conf().get("baidu_translate_app_key") | |||||
# For list of language codes, please refer to `https://api.fanyi.baidu.com/doc/21`, need to convert to ISO 639-1 codes | |||||
def translate(self, query: str, from_lang: str = "", to_lang: str = "en") -> str: | |||||
if not from_lang: | |||||
from_lang = "auto" # baidu suppport auto detect | |||||
# from_lang = classify(query)[0] | |||||
salt = random.randint(32768, 65536) | |||||
sign = self.make_md5(self.appid + query + str(salt) + self.appkey) | |||||
headers = {"Content-Type": "application/x-www-form-urlencoded"} | |||||
payload = {"appid": self.appid, "q": query, "from": from_lang, "to": to_lang, "salt": salt, "sign": sign} | |||||
retry_cnt = 3 | |||||
while retry_cnt: | |||||
r = requests.post(self.url, params=payload, headers=headers) | |||||
result = r.json() | |||||
if errcode := result.get("error_code", "52000") != "52000": | |||||
if errcode == "52001" or errcode == "52002": | |||||
retry_cnt -= 1 | |||||
continue | |||||
else: | |||||
raise Exception(result["error_msg"]) | |||||
else: | |||||
break | |||||
text = "\n".join([item["dst"] for item in result["trans_result"]]) | |||||
return text | |||||
def make_md5(self, s, encoding="utf-8"): | |||||
return md5(s.encode(encoding)).hexdigest() |
@@ -0,0 +1,6 @@ | |||||
def create_translator(voice_type): | |||||
if voice_type == "baidu": | |||||
from translate.baidu.baidu_translate import BaiduTranslator | |||||
return BaiduTranslator() | |||||
raise RuntimeError |
@@ -0,0 +1,12 @@ | |||||
""" | |||||
Voice service abstract class | |||||
""" | |||||
class Translator(object): | |||||
# please use https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes to specify language | |||||
def translate(self, query: str, from_lang: str = "", to_lang: str = "en") -> str: | |||||
""" | |||||
Translate text from one language to another | |||||
""" | |||||
raise NotImplementedError |