@@ -3,8 +3,12 @@ Auto-replay chat robot abstract class | |||||
""" | """ | ||||
from bridge.context import Context | |||||
from bridge.reply import Reply | |||||
class Bot(object): | class Bot(object): | ||||
def reply(self, query, context=None): | |||||
def reply(self, query, context : Context =None) -> Reply: | |||||
""" | """ | ||||
bot auto-reply content | bot auto-reply content | ||||
:param req: received message | :param req: received message | ||||
@@ -1,6 +1,8 @@ | |||||
# encoding:utf-8 | # encoding:utf-8 | ||||
from bot.bot import Bot | from bot.bot import Bot | ||||
from bridge.context import ContextType | |||||
from bridge.reply import Reply, ReplyType | |||||
from config import conf, load_config | from config import conf, load_config | ||||
from common.log import logger | from common.log import logger | ||||
from common.expired_dict import ExpiredDict | from common.expired_dict import ExpiredDict | ||||
@@ -19,22 +21,19 @@ class ChatGPTBot(Bot): | |||||
def reply(self, query, context=None): | def reply(self, query, context=None): | ||||
# acquire reply content | # acquire reply content | ||||
if context['type'] == 'TEXT': | |||||
if context.type == ContextType.TEXT: | |||||
logger.info("[OPEN_AI] query={}".format(query)) | logger.info("[OPEN_AI] query={}".format(query)) | ||||
session_id = context['session_id'] | session_id = context['session_id'] | ||||
reply = None | reply = None | ||||
if query == '#清除记忆': | if query == '#清除记忆': | ||||
self.sessions.clear_session(session_id) | self.sessions.clear_session(session_id) | ||||
reply = {'type': 'INFO', 'content': '记忆已清除'} | |||||
reply = Reply(ReplyType.INFO, '记忆已清除') | |||||
elif query == '#清除所有': | elif query == '#清除所有': | ||||
self.sessions.clear_all_session() | self.sessions.clear_all_session() | ||||
reply = {'type': 'INFO', 'content': '所有人记忆已清除'} | |||||
reply = Reply(ReplyType.INFO, '所有人记忆已清除') | |||||
elif query == '#更新配置': | elif query == '#更新配置': | ||||
load_config() | load_config() | ||||
reply = {'type': 'INFO', 'content': '配置已更新'} | |||||
elif query == '#DEBUG': | |||||
logger.setLevel('DEBUG') | |||||
reply = {'type': 'INFO', 'content': 'DEBUG模式已开启'} | |||||
reply = Reply(ReplyType.INFO, '配置已更新') | |||||
if reply: | if reply: | ||||
return reply | return reply | ||||
session = self.sessions.build_session_query(query, session_id) | session = self.sessions.build_session_query(query, session_id) | ||||
@@ -47,25 +46,25 @@ class ChatGPTBot(Bot): | |||||
reply_content = self.reply_text(session, session_id, 0) | reply_content = self.reply_text(session, session_id, 0) | ||||
logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"])) | logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"])) | ||||
if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0: | if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0: | ||||
reply = {'type': 'ERROR', 'content': reply_content['content']} | |||||
reply = Reply(ReplyType.ERROR, reply_content['content']) | |||||
elif reply_content["completion_tokens"] > 0: | elif reply_content["completion_tokens"] > 0: | ||||
self.sessions.save_session(reply_content["content"], session_id, reply_content["total_tokens"]) | self.sessions.save_session(reply_content["content"], session_id, reply_content["total_tokens"]) | ||||
reply={'type':'TEXT', 'content':reply_content["content"]} | |||||
reply = Reply(ReplyType.TEXT, reply_content["content"]) | |||||
else: | else: | ||||
reply = {'type': 'ERROR', 'content': reply_content['content']} | |||||
reply = Reply(ReplyType.ERROR, reply_content['content']) | |||||
logger.debug("[OPEN_AI] reply {} used 0 tokens.".format(reply_content)) | logger.debug("[OPEN_AI] reply {} used 0 tokens.".format(reply_content)) | ||||
return reply | return reply | ||||
elif context['type'] == 'IMAGE_CREATE': | |||||
elif context.type == ContextType.IMAGE_CREATE: | |||||
ok, retstring = self.create_img(query, 0) | ok, retstring = self.create_img(query, 0) | ||||
reply = None | reply = None | ||||
if ok: | if ok: | ||||
reply = {'type': 'IMAGE_URL', 'content': retstring} | |||||
reply = Reply(ReplyType.IMAGE_URL, retstring) | |||||
else: | else: | ||||
reply = {'type': 'ERROR', 'content': retstring} | |||||
reply = Reply(ReplyType.ERROR, retstring) | |||||
return reply | return reply | ||||
else: | else: | ||||
reply= {'type':'ERROR', 'content':'Bot不支持处理{}类型的消息'.format(context['type'])} | |||||
reply = Reply(ReplyType.ERROR, 'Bot不支持处理{}类型的消息'.format(context.type)) | |||||
return reply | return reply | ||||
def reply_text(self, session, session_id, retry_count=0) -> dict: | def reply_text(self, session, session_id, retry_count=0) -> dict: | ||||
@@ -1,3 +1,5 @@ | |||||
from bridge.context import Context | |||||
from bridge.reply import Reply | |||||
from common.log import logger | from common.log import logger | ||||
from bot import bot_factory | from bot import bot_factory | ||||
from common.singleton import singleton | from common.singleton import singleton | ||||
@@ -28,16 +30,13 @@ class Bridge(object): | |||||
def get_bot_type(self,typename): | def get_bot_type(self,typename): | ||||
return self.btype[typename] | return self.btype[typename] | ||||
# 以下所有函数需要得到一个reply字典,格式如下: | |||||
# reply["type"] = "ERROR" / "TEXT" / "VOICE" / ... | |||||
# reply["content"] = reply的内容 | |||||
def fetch_reply_content(self, query, context): | |||||
def fetch_reply_content(self, query, context : Context) -> Reply: | |||||
return self.get_bot("chat").reply(query, context) | return self.get_bot("chat").reply(query, context) | ||||
def fetch_voice_to_text(self, voiceFile): | |||||
def fetch_voice_to_text(self, voiceFile) -> Reply: | |||||
return self.get_bot("voice_to_text").voiceToText(voiceFile) | return self.get_bot("voice_to_text").voiceToText(voiceFile) | ||||
def fetch_text_to_voice(self, text): | |||||
def fetch_text_to_voice(self, text) -> Reply: | |||||
return self.get_bot("text_to_voice").textToVoice(text) | return self.get_bot("text_to_voice").textToVoice(text) | ||||
@@ -0,0 +1,42 @@ | |||||
# encoding:utf-8 | |||||
from enum import Enum | |||||
class ContextType (Enum): | |||||
TEXT = 1 # 文本消息 | |||||
VOICE = 2 # 音频消息 | |||||
IMAGE_CREATE = 3 # 创建图片命令 | |||||
def __str__(self): | |||||
return self.name | |||||
class Context: | |||||
def __init__(self, type : ContextType = None , content = None, kwargs = dict()): | |||||
self.type = type | |||||
self.content = content | |||||
self.kwargs = kwargs | |||||
def __getitem__(self, key): | |||||
if key == 'type': | |||||
return self.type | |||||
elif key == 'content': | |||||
return self.content | |||||
else: | |||||
return self.kwargs[key] | |||||
def __setitem__(self, key, value): | |||||
if key == 'type': | |||||
self.type = value | |||||
elif key == 'content': | |||||
self.content = value | |||||
else: | |||||
self.kwargs[key] = value | |||||
def __delitem__(self, key): | |||||
if key == 'type': | |||||
self.type = None | |||||
elif key == 'content': | |||||
self.content = None | |||||
else: | |||||
del self.kwargs[key] | |||||
def __str__(self): | |||||
return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs) |
@@ -0,0 +1,22 @@ | |||||
# encoding:utf-8 | |||||
from enum import Enum | |||||
class ReplyType(Enum): | |||||
TEXT = 1 # 文本 | |||||
VOICE = 2 # 音频文件 | |||||
IMAGE = 3 # 图片文件 | |||||
IMAGE_URL = 4 # 图片URL | |||||
INFO = 9 | |||||
ERROR = 10 | |||||
def __str__(self): | |||||
return self.name | |||||
class Reply: | |||||
def __init__(self, type : ReplyType = None , content = None): | |||||
self.type = type | |||||
self.content = content | |||||
def __str__(self): | |||||
return "Reply(type={}, content={})".format(self.type, self.content) |
@@ -3,6 +3,8 @@ Message sending channel abstract class | |||||
""" | """ | ||||
from bridge.bridge import Bridge | from bridge.bridge import Bridge | ||||
from bridge.context import Context | |||||
from bridge.reply import Reply | |||||
class Channel(object): | class Channel(object): | ||||
def startup(self): | def startup(self): | ||||
@@ -27,11 +29,11 @@ class Channel(object): | |||||
""" | """ | ||||
raise NotImplementedError | raise NotImplementedError | ||||
def build_reply_content(self, query, context=None): | |||||
def build_reply_content(self, query, context : Context=None) -> Reply: | |||||
return Bridge().fetch_reply_content(query, context) | return Bridge().fetch_reply_content(query, context) | ||||
def build_voice_to_text(self, voice_file): | |||||
def build_voice_to_text(self, voice_file) -> Reply: | |||||
return Bridge().fetch_voice_to_text(voice_file) | return Bridge().fetch_voice_to_text(voice_file) | ||||
def build_text_to_voice(self, text): | |||||
def build_text_to_voice(self, text) -> Reply: | |||||
return Bridge().fetch_text_to_voice(text) | return Bridge().fetch_text_to_voice(text) |
@@ -7,6 +7,8 @@ wechat channel | |||||
import itchat | import itchat | ||||
import json | import json | ||||
from itchat.content import * | from itchat.content import * | ||||
from bridge.reply import * | |||||
from bridge.context import * | |||||
from channel.channel import Channel | from channel.channel import Channel | ||||
from concurrent.futures import ThreadPoolExecutor | from concurrent.futures import ThreadPoolExecutor | ||||
from common.log import logger | from common.log import logger | ||||
@@ -69,10 +71,8 @@ class WechatChannel(Channel): | |||||
from_user_id = msg['FromUserName'] | from_user_id = msg['FromUserName'] | ||||
other_user_id = msg['User']['UserName'] | other_user_id = msg['User']['UserName'] | ||||
if from_user_id == other_user_id: | if from_user_id == other_user_id: | ||||
context = {'isgroup': False, 'msg': msg, 'receiver': other_user_id} | |||||
context['type'] = 'VOICE' | |||||
context['content'] = msg['FileName'] | |||||
context['session_id'] = other_user_id | |||||
context = Context(ContextType.VOICE,msg['FileName']) | |||||
context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id} | |||||
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) | thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) | ||||
def handle_text(self, msg): | def handle_text(self, msg): | ||||
@@ -89,17 +89,17 @@ class WechatChannel(Channel): | |||||
content = content.replace(match_prefix, '', 1).strip() | content = content.replace(match_prefix, '', 1).strip() | ||||
else: | else: | ||||
return | return | ||||
context = {'isgroup': False, 'msg': msg, 'receiver': other_user_id} | |||||
context['session_id'] = other_user_id | |||||
context = Context() | |||||
context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id} | |||||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) | img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) | ||||
if img_match_prefix: | if img_match_prefix: | ||||
content = content.replace(img_match_prefix, '', 1).strip() | content = content.replace(img_match_prefix, '', 1).strip() | ||||
context['type'] = 'IMAGE_CREATE' | |||||
context.type = ContextType.IMAGE_CREATE | |||||
else: | else: | ||||
context['type'] = 'TEXT' | |||||
context.type = ContextType.TEXT | |||||
context['content'] = content | |||||
context.content = content | |||||
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) | thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) | ||||
def handle_group(self, msg): | def handle_group(self, msg): | ||||
@@ -123,15 +123,16 @@ class WechatChannel(Channel): | |||||
match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or check_prefix(origin_content, config.get('group_chat_prefix')) \ | match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or check_prefix(origin_content, config.get('group_chat_prefix')) \ | ||||
or check_contain(origin_content, config.get('group_chat_keyword')) | or check_contain(origin_content, config.get('group_chat_keyword')) | ||||
if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix: | if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix: | ||||
context = { 'isgroup': True, 'msg': msg, 'receiver': group_id} | |||||
context = Context() | |||||
context.kwargs = { 'isgroup': True, 'msg': msg, 'receiver': group_id} | |||||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) | img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) | ||||
if img_match_prefix: | if img_match_prefix: | ||||
content = content.replace(img_match_prefix, '', 1).strip() | content = content.replace(img_match_prefix, '', 1).strip() | ||||
context['type'] = 'IMAGE_CREATE' | |||||
context.type = ContextType.IMAGE_CREATE | |||||
else: | else: | ||||
context['type'] = 'TEXT' | |||||
context['content'] = content | |||||
context.type = ContextType.TEXT | |||||
context.content = content | |||||
group_chat_in_one_session = conf().get('group_chat_in_one_session', []) | group_chat_in_one_session = conf().get('group_chat_in_one_session', []) | ||||
if ('ALL_GROUP' in group_chat_in_one_session or | if ('ALL_GROUP' in group_chat_in_one_session or | ||||
@@ -144,18 +145,18 @@ class WechatChannel(Channel): | |||||
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) | thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) | ||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 | # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 | ||||
def send(self, reply, receiver): | |||||
if reply['type'] == 'TEXT': | |||||
itchat.send(reply['content'], toUserName=receiver) | |||||
def send(self, reply : Reply, receiver): | |||||
if reply.type == ReplyType.TEXT: | |||||
itchat.send(reply.content, toUserName=receiver) | |||||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) | logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) | ||||
elif reply['type'] == 'ERROR' or reply['type'] == 'INFO': | |||||
itchat.send(reply['content'], toUserName=receiver) | |||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: | |||||
itchat.send(reply.content, toUserName=receiver) | |||||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) | logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) | ||||
elif reply['type'] == 'VOICE': | |||||
itchat.send_file(reply['content'], toUserName=receiver) | |||||
logger.info('[WX] sendFile={}, receiver={}'.format(reply['content'], receiver)) | |||||
elif reply['type']=='IMAGE_URL': # 从网络下载图片 | |||||
img_url = reply['content'] | |||||
elif reply.type == ReplyType.VOICE: | |||||
itchat.send_file(reply.content, toUserName=receiver) | |||||
logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver)) | |||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 | |||||
img_url = reply.content | |||||
pic_res = requests.get(img_url, stream=True) | pic_res = requests.get(img_url, stream=True) | ||||
image_storage = io.BytesIO() | image_storage = io.BytesIO() | ||||
for block in pic_res.iter_content(1024): | for block in pic_res.iter_content(1024): | ||||
@@ -163,69 +164,69 @@ class WechatChannel(Channel): | |||||
image_storage.seek(0) | image_storage.seek(0) | ||||
itchat.send_image(image_storage, toUserName=receiver) | itchat.send_image(image_storage, toUserName=receiver) | ||||
logger.info('[WX] sendImage url=, receiver={}'.format(img_url,receiver)) | logger.info('[WX] sendImage url=, receiver={}'.format(img_url,receiver)) | ||||
elif reply['type']=='IMAGE': # 从文件读取图片 | |||||
image_storage = reply['content'] | |||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片 | |||||
image_storage = reply.content | |||||
image_storage.seek(0) | image_storage.seek(0) | ||||
itchat.send_image(image_storage, toUserName=receiver) | itchat.send_image(image_storage, toUserName=receiver) | ||||
logger.info('[WX] sendImage, receiver={}'.format(receiver)) | logger.info('[WX] sendImage, receiver={}'.format(receiver)) | ||||
# 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类 | # 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类 | ||||
def handle(self, context): | def handle(self, context): | ||||
reply = {} | |||||
reply = Reply() | |||||
logger.debug('[WX] ready to handle context: {}'.format(context)) | logger.debug('[WX] ready to handle context: {}'.format(context)) | ||||
# reply的构建步骤 | # reply的构建步骤 | ||||
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply})) | e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply})) | ||||
reply=e_context['reply'] | |||||
reply = e_context['reply'] | |||||
if not e_context.is_pass(): | if not e_context.is_pass(): | ||||
logger.debug('[WX] ready to handle context: type={}, content={}'.format(context['type'], context['content'])) | |||||
if context['type'] == 'TEXT' or context['type'] == 'IMAGE_CREATE': | |||||
reply = super().build_reply_content(context['content'], context) | |||||
elif context['type'] == 'VOICE': | |||||
logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content)) | |||||
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: | |||||
reply = super().build_reply_content(context.content, context) | |||||
elif context.type == ContextType.VOICE: | |||||
msg = context['msg'] | msg = context['msg'] | ||||
file_name = TmpDir().path() + context['content'] | |||||
file_name = TmpDir().path() + context.content | |||||
msg.download(file_name) | msg.download(file_name) | ||||
reply = super().build_voice_to_text(file_name) | reply = super().build_voice_to_text(file_name) | ||||
if reply['type'] != 'ERROR' and reply['type'] != 'INFO': | |||||
context['content'] = reply['content'] # 语音转文字后,将文字内容作为新的context | |||||
context['type'] = reply['type'] | |||||
reply = super().build_reply_content(context['content'], context) | |||||
if reply['type'] == 'TEXT': | |||||
if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO: | |||||
context.content = reply.content # 语音转文字后,将文字内容作为新的context | |||||
context.type = ContextType.TEXT | |||||
reply = super().build_reply_content(context.content, context) | |||||
if reply.type == ReplyType.TEXT: | |||||
if conf().get('voice_reply_voice'): | if conf().get('voice_reply_voice'): | ||||
reply = super().build_text_to_voice(reply['content']) | |||||
reply = super().build_text_to_voice(reply.content) | |||||
else: | else: | ||||
logger.error('[WX] unknown context type: {}'.format(context['type'])) | |||||
logger.error('[WX] unknown context type: {}'.format(context.type)) | |||||
return | return | ||||
logger.debug('[WX] ready to decorate reply: {}'.format(reply)) | logger.debug('[WX] ready to decorate reply: {}'.format(reply)) | ||||
# reply的包装步骤 | # reply的包装步骤 | ||||
if reply and reply['type']: | |||||
if reply and reply.type: | |||||
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply})) | e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply})) | ||||
reply=e_context['reply'] | reply=e_context['reply'] | ||||
if not e_context.is_pass() and reply and reply['type']: | |||||
if reply['type'] == 'TEXT': | |||||
reply_text = reply['content'] | |||||
if not e_context.is_pass() and reply and reply.type: | |||||
if reply.type == ReplyType.TEXT: | |||||
reply_text = reply.content | |||||
if context['isgroup']: | if context['isgroup']: | ||||
reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip() | reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip() | ||||
reply_text = conf().get("group_chat_reply_prefix", "")+reply_text | reply_text = conf().get("group_chat_reply_prefix", "")+reply_text | ||||
else: | else: | ||||
reply_text = conf().get("single_chat_reply_prefix", "")+reply_text | reply_text = conf().get("single_chat_reply_prefix", "")+reply_text | ||||
reply['content'] = reply_text | |||||
elif reply['type'] == 'ERROR' or reply['type'] == 'INFO': | |||||
reply['content'] = reply['type']+":\n" + reply['content'] | |||||
elif reply['type'] == 'IMAGE_URL' or reply['type'] == 'VOICE' or reply['type'] == 'IMAGE': | |||||
reply.content = reply_text | |||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: | |||||
reply.content = str(reply.type)+":\n" + reply.content | |||||
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE: | |||||
pass | pass | ||||
else: | else: | ||||
logger.error('[WX] unknown reply type: {}'.format(reply['type'])) | |||||
logger.error('[WX] unknown reply type: {}'.format(reply.type)) | |||||
return | return | ||||
# reply的发送步骤 | # reply的发送步骤 | ||||
if reply and reply['type']: | |||||
if reply and reply.type: | |||||
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply})) | e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply})) | ||||
reply=e_context['reply'] | reply=e_context['reply'] | ||||
if not e_context.is_pass() and reply and reply['type']: | |||||
if not e_context.is_pass() and reply and reply.type: | |||||
logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver'])) | logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver'])) | ||||
self.send(reply, context['receiver']) | self.send(reply, context['receiver']) | ||||
@@ -11,6 +11,7 @@ import time | |||||
import asyncio | import asyncio | ||||
import requests | import requests | ||||
from typing import Optional, Union | from typing import Optional, Union | ||||
from bridge.context import Context, ContextType | |||||
from wechaty_puppet import MessageType, FileBox, ScanStatus # type: ignore | from wechaty_puppet import MessageType, FileBox, ScanStatus # type: ignore | ||||
from wechaty import Wechaty, Contact | from wechaty import Wechaty, Contact | ||||
from wechaty.user import Message, Room, MiniProgram, UrlLink | from wechaty.user import Message, Room, MiniProgram, UrlLink | ||||
@@ -127,11 +128,9 @@ class WechatyChannel(Channel): | |||||
try: | try: | ||||
if not query: | if not query: | ||||
return | return | ||||
context = dict() | |||||
context = Context(ContextType.TEXT, query) | |||||
context['session_id'] = reply_user_id | context['session_id'] = reply_user_id | ||||
context['type'] = 'TEXT' | |||||
context['content'] = query | |||||
reply_text = super().build_reply_content(query, context)['content'] | |||||
reply_text = super().build_reply_content(query, context).content | |||||
if reply_text: | if reply_text: | ||||
await self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id) | await self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id) | ||||
except Exception as e: | except Exception as e: | ||||
@@ -141,10 +140,8 @@ class WechatyChannel(Channel): | |||||
try: | try: | ||||
if not query: | if not query: | ||||
return | return | ||||
context = dict() | |||||
context['type'] = 'IMAGE_CREATE' | |||||
context['content'] = query | |||||
img_url = super().build_reply_content(query, context)['content'] | |||||
context = Context(ContextType.IMAGE_CREATE, query) | |||||
img_url = super().build_reply_content(query, context).content | |||||
if not img_url: | if not img_url: | ||||
return | return | ||||
# 图片下载 | # 图片下载 | ||||
@@ -165,7 +162,7 @@ class WechatyChannel(Channel): | |||||
async def _do_send_group(self, query, group_id, group_name, group_user_id, group_user_name): | async def _do_send_group(self, query, group_id, group_name, group_user_id, group_user_name): | ||||
if not query: | if not query: | ||||
return | return | ||||
context = dict() | |||||
context = Context(ContextType.TEXT, query) | |||||
group_chat_in_one_session = conf().get('group_chat_in_one_session', []) | group_chat_in_one_session = conf().get('group_chat_in_one_session', []) | ||||
if ('ALL_GROUP' in group_chat_in_one_session or \ | if ('ALL_GROUP' in group_chat_in_one_session or \ | ||||
group_name in group_chat_in_one_session or \ | group_name in group_chat_in_one_session or \ | ||||
@@ -173,9 +170,7 @@ class WechatyChannel(Channel): | |||||
context['session_id'] = str(group_id) | context['session_id'] = str(group_id) | ||||
else: | else: | ||||
context['session_id'] = str(group_id) + '-' + str(group_user_id) | context['session_id'] = str(group_id) + '-' + str(group_user_id) | ||||
context['type'] = 'TEXT' | |||||
context['content'] = query | |||||
reply_text = super().build_reply_content(query, context)['content'] | |||||
reply_text = super().build_reply_content(query, context).content | |||||
if reply_text: | if reply_text: | ||||
reply_text = '@' + group_user_name + ' ' + reply_text.strip() | reply_text = '@' + group_user_name + ' ' + reply_text.strip() | ||||
await self.send_group(conf().get("group_chat_reply_prefix", "") + reply_text, group_id) | await self.send_group(conf().get("group_chat_reply_prefix", "") + reply_text, group_id) | ||||
@@ -184,10 +179,8 @@ class WechatyChannel(Channel): | |||||
try: | try: | ||||
if not query: | if not query: | ||||
return | return | ||||
context = dict() | |||||
context['type'] = 'IMAGE_CREATE' | |||||
context['content'] = query | |||||
img_url = super().build_reply_content(query, context)['content'] | |||||
context = Context(ContextType.IMAGE_CREATE, query) | |||||
img_url = super().build_reply_content(query, context).content | |||||
if not img_url: | if not img_url: | ||||
return | return | ||||
# 图片发送 | # 图片发送 | ||||
@@ -5,6 +5,8 @@ import os | |||||
import traceback | import traceback | ||||
from typing import Tuple | from typing import Tuple | ||||
from bridge.bridge import Bridge | from bridge.bridge import Bridge | ||||
from bridge.context import ContextType | |||||
from bridge.reply import Reply, ReplyType | |||||
from config import load_config | from config import load_config | ||||
import plugins | import plugins | ||||
from plugins import * | from plugins import * | ||||
@@ -123,13 +125,13 @@ class Godcmd(Plugin): | |||||
def on_handle_context(self, e_context: EventContext): | def on_handle_context(self, e_context: EventContext): | ||||
context_type = e_context['context']['type'] | |||||
if context_type != "TEXT": | |||||
context_type = e_context['context'].type | |||||
if context_type != ContextType.TEXT: | |||||
if not self.isrunning: | if not self.isrunning: | ||||
e_context.action = EventAction.BREAK_PASS | e_context.action = EventAction.BREAK_PASS | ||||
return | return | ||||
content = e_context['context']['content'] | |||||
content = e_context['context'].content | |||||
logger.debug("[Godcmd] on_handle_context. content: %s" % content) | logger.debug("[Godcmd] on_handle_context. content: %s" % content) | ||||
if content.startswith("#"): | if content.startswith("#"): | ||||
# msg = e_context['context']['msg'] | # msg = e_context['context']['msg'] | ||||
@@ -239,12 +241,12 @@ class Godcmd(Plugin): | |||||
else: | else: | ||||
ok, result = False, f"未知指令:{cmd}\n查看指令列表请输入#help \n" | ok, result = False, f"未知指令:{cmd}\n查看指令列表请输入#help \n" | ||||
reply = {} | |||||
reply = Reply() | |||||
if ok: | if ok: | ||||
reply["type"] = "INFO" | |||||
reply.type = ReplyType.INFO | |||||
else: | else: | ||||
reply["type"] = "ERROR" | |||||
reply["content"] = result | |||||
reply.type = ReplyType.ERROR | |||||
reply.content = result | |||||
e_context['reply'] = reply | e_context['reply'] = reply | ||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | ||||
@@ -1,5 +1,7 @@ | |||||
# encoding:utf-8 | # encoding:utf-8 | ||||
from bridge.context import ContextType | |||||
from bridge.reply import Reply, ReplyType | |||||
import plugins | import plugins | ||||
from plugins import * | from plugins import * | ||||
from common.log import logger | from common.log import logger | ||||
@@ -14,31 +16,31 @@ class Hello(Plugin): | |||||
def on_handle_context(self, e_context: EventContext): | def on_handle_context(self, e_context: EventContext): | ||||
if e_context['context']['type'] != "TEXT": | |||||
if e_context['context'].type != ContextType.TEXT: | |||||
return | return | ||||
content = e_context['context']['content'] | |||||
content = e_context['context'].content | |||||
logger.debug("[Hello] on_handle_context. content: %s" % content) | logger.debug("[Hello] on_handle_context. content: %s" % content) | ||||
if content == "Hello": | if content == "Hello": | ||||
reply = {} | |||||
reply['type'] = "TEXT" | |||||
reply = Reply() | |||||
reply.type = ReplyType.TEXT | |||||
msg = e_context['context']['msg'] | msg = e_context['context']['msg'] | ||||
if e_context['context']['isgroup']: | if e_context['context']['isgroup']: | ||||
reply['content'] = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group") | |||||
reply.content = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group") | |||||
else: | else: | ||||
reply['content'] = "Hello, " + msg['User'].get('NickName', "My friend") | |||||
reply.content = "Hello, " + msg['User'].get('NickName', "My friend") | |||||
e_context['reply'] = reply | e_context['reply'] = reply | ||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | ||||
if content == "Hi": | if content == "Hi": | ||||
reply={} | |||||
reply['type'] = "TEXT" | |||||
reply['content'] = "Hi" | |||||
reply = Reply() | |||||
reply.type = ReplyType.TEXT | |||||
reply.content = "Hi" | |||||
e_context['reply'] = reply | e_context['reply'] = reply | ||||
e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑,一般会覆写reply | e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑,一般会覆写reply | ||||
if content == "End": | if content == "End": | ||||
# 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World" | # 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World" | ||||
e_context['context']['type'] = "IMAGE_CREATE" | |||||
e_context['context'].type = "IMAGE_CREATE" | |||||
content = "The World" | content = "The World" | ||||
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 | e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 |
@@ -4,6 +4,7 @@ baidu voice service | |||||
""" | """ | ||||
import time | import time | ||||
from aip import AipSpeech | from aip import AipSpeech | ||||
from bridge.reply import Reply, ReplyType | |||||
from common.log import logger | from common.log import logger | ||||
from common.tmp_dir import TmpDir | from common.tmp_dir import TmpDir | ||||
from voice.voice import Voice | from voice.voice import Voice | ||||
@@ -30,8 +31,8 @@ class BaiduVoice(Voice): | |||||
with open(fileName, 'wb') as f: | with open(fileName, 'wb') as f: | ||||
f.write(result) | f.write(result) | ||||
logger.info('[Baidu] textToVoice text={} voice file name={}'.format(text, fileName)) | logger.info('[Baidu] textToVoice text={} voice file name={}'.format(text, fileName)) | ||||
reply = {"type": "VOICE", "content": fileName} | |||||
reply = Reply(ReplyType.VOICE, fileName) | |||||
else: | else: | ||||
logger.error('[Baidu] textToVoice error={}'.format(result)) | logger.error('[Baidu] textToVoice error={}'.format(result)) | ||||
reply = {"type": "ERROR", "content": "抱歉,语音合成失败"} | |||||
reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败") | |||||
return reply | return reply |
@@ -6,6 +6,7 @@ google voice service | |||||
import pathlib | import pathlib | ||||
import subprocess | import subprocess | ||||
import time | import time | ||||
from bridge.reply import Reply, ReplyType | |||||
import speech_recognition | import speech_recognition | ||||
import pyttsx3 | import pyttsx3 | ||||
from common.log import logger | from common.log import logger | ||||
@@ -32,16 +33,15 @@ class GoogleVoice(Voice): | |||||
' -acodec pcm_s16le -ac 1 -ar 16000 ' + new_file, shell=True) | ' -acodec pcm_s16le -ac 1 -ar 16000 ' + new_file, shell=True) | ||||
with speech_recognition.AudioFile(new_file) as source: | with speech_recognition.AudioFile(new_file) as source: | ||||
audio = self.recognizer.record(source) | audio = self.recognizer.record(source) | ||||
reply = {} | |||||
try: | try: | ||||
text = self.recognizer.recognize_google(audio, language='zh-CN') | text = self.recognizer.recognize_google(audio, language='zh-CN') | ||||
logger.info( | logger.info( | ||||
'[Google] voiceToText text={} voice file name={}'.format(text, voice_file)) | '[Google] voiceToText text={} voice file name={}'.format(text, voice_file)) | ||||
reply = {"type": "TEXT", "content": text} | |||||
reply = Reply(ReplyType.TEXT, text) | |||||
except speech_recognition.UnknownValueError: | except speech_recognition.UnknownValueError: | ||||
reply = {"type": "ERROR", "content": "抱歉,我听不懂"} | |||||
reply = Reply(ReplyType.ERROR, "抱歉,我听不懂") | |||||
except speech_recognition.RequestError as e: | except speech_recognition.RequestError as e: | ||||
reply = {"type": "ERROR", "content": "抱歉,无法连接到 Google 语音识别服务;{0}".format(e)} | |||||
reply = Reply(ReplyType.ERROR, "抱歉,无法连接到 Google 语音识别服务;{0}".format(e)) | |||||
finally: | finally: | ||||
return reply | return reply | ||||
def textToVoice(self, text): | def textToVoice(self, text): | ||||
@@ -51,8 +51,8 @@ class GoogleVoice(Voice): | |||||
self.engine.runAndWait() | self.engine.runAndWait() | ||||
logger.info( | logger.info( | ||||
'[Google] textToVoice text={} voice file name={}'.format(text, textFile)) | '[Google] textToVoice text={} voice file name={}'.format(text, textFile)) | ||||
reply = {"type": "VOICE", "content": textFile} | |||||
reply = Reply(ReplyType.VOICE, textFile) | |||||
except Exception as e: | except Exception as e: | ||||
reply = {"type": "ERROR", "content": str(e)} | |||||
reply = Reply(ReplyType.ERROR, str(e)) | |||||
finally: | finally: | ||||
return reply | return reply |
@@ -4,6 +4,7 @@ google voice service | |||||
""" | """ | ||||
import json | import json | ||||
import openai | import openai | ||||
from bridge.reply import Reply, ReplyType | |||||
from config import conf | from config import conf | ||||
from common.log import logger | from common.log import logger | ||||
from voice.voice import Voice | from voice.voice import Voice | ||||
@@ -16,16 +17,15 @@ class OpenaiVoice(Voice): | |||||
def voiceToText(self, voice_file): | def voiceToText(self, voice_file): | ||||
logger.debug( | logger.debug( | ||||
'[Openai] voice file name={}'.format(voice_file)) | '[Openai] voice file name={}'.format(voice_file)) | ||||
reply={} | |||||
try: | try: | ||||
file = open(voice_file, "rb") | file = open(voice_file, "rb") | ||||
result = openai.Audio.transcribe("whisper-1", file) | result = openai.Audio.transcribe("whisper-1", file) | ||||
text = result["text"] | text = result["text"] | ||||
reply = {"type": "TEXT", "content": text} | |||||
reply = Reply(ReplyType.TEXT, text) | |||||
logger.info( | logger.info( | ||||
'[Openai] voiceToText text={} voice file name={}'.format(text, voice_file)) | '[Openai] voiceToText text={} voice file name={}'.format(text, voice_file)) | ||||
except Exception as e: | except Exception as e: | ||||
reply = {"type": "ERROR", "content": str(e)} | |||||
reply = Reply(ReplyType.ERROR, str(e)) | |||||
finally: | finally: | ||||
return reply | return reply | ||||