简易支持插件,添加sdwebui(novelai画图), godcmd(管理员指令增强)插件,Banwords(敏感词过滤)插件develop
@@ -7,3 +7,4 @@ config.json | |||||
QR.png | QR.png | ||||
nohup.out | nohup.out | ||||
tmp | tmp | ||||
plugins.json |
@@ -4,14 +4,17 @@ import config | |||||
from channel import channel_factory | from channel import channel_factory | ||||
from common.log import logger | from common.log import logger | ||||
from plugins import * | |||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
try: | try: | ||||
# load config | # load config | ||||
config.load_config() | config.load_config() | ||||
# create channel | # create channel | ||||
channel = channel_factory.create_channel("wx") | |||||
channel_name='wx' | |||||
channel = channel_factory.create_channel(channel_name) | |||||
if channel_name=='wx': | |||||
PluginManager().load_plugins() | |||||
# startup channel | # startup channel | ||||
channel.startup() | channel.startup() | ||||
@@ -2,6 +2,7 @@ | |||||
import requests | import requests | ||||
from bot.bot import Bot | from bot.bot import Bot | ||||
from bridge.reply import Reply, ReplyType | |||||
# Baidu Unit对话接口 (可用, 但能力较弱) | # Baidu Unit对话接口 (可用, 但能力较弱) | ||||
@@ -14,7 +15,8 @@ class BaiduUnitBot(Bot): | |||||
headers = {'content-type': 'application/x-www-form-urlencoded'} | headers = {'content-type': 'application/x-www-form-urlencoded'} | ||||
response = requests.post(url, data=post_data.encode(), headers=headers) | response = requests.post(url, data=post_data.encode(), headers=headers) | ||||
if response: | if response: | ||||
return response.json()['result']['context']['SYS_PRESUMED_HIST'][1] | |||||
reply = Reply(ReplyType.TEXT, response.json()['result']['context']['SYS_PRESUMED_HIST'][1]) | |||||
return reply | |||||
def get_token(self): | def get_token(self): | ||||
access_key = 'YOUR_ACCESS_KEY' | access_key = 'YOUR_ACCESS_KEY' | ||||
@@ -3,8 +3,12 @@ Auto-replay chat robot abstract class | |||||
""" | """ | ||||
from bridge.context import Context | |||||
from bridge.reply import Reply | |||||
class Bot(object): | class Bot(object): | ||||
def reply(self, query, context=None): | |||||
def reply(self, query, context : Context =None) -> Reply: | |||||
""" | """ | ||||
bot auto-reply content | bot auto-reply content | ||||
:param req: received message | :param req: received message | ||||
@@ -1,41 +1,42 @@ | |||||
# encoding:utf-8 | # encoding:utf-8 | ||||
from bot.bot import Bot | from bot.bot import Bot | ||||
from bridge.context import ContextType | |||||
from bridge.reply import Reply, ReplyType | |||||
from config import conf, load_config | from config import conf, load_config | ||||
from common.log import logger | from common.log import logger | ||||
from common.expired_dict import ExpiredDict | from common.expired_dict import ExpiredDict | ||||
import openai | import openai | ||||
import time | import time | ||||
if conf().get('expires_in_seconds'): | |||||
all_sessions = ExpiredDict(conf().get('expires_in_seconds')) | |||||
else: | |||||
all_sessions = dict() | |||||
# OpenAI对话模型API (可用) | # OpenAI对话模型API (可用) | ||||
class ChatGPTBot(Bot): | class ChatGPTBot(Bot): | ||||
def __init__(self): | def __init__(self): | ||||
openai.api_key = conf().get('open_ai_api_key') | openai.api_key = conf().get('open_ai_api_key') | ||||
proxy = conf().get('proxy') | proxy = conf().get('proxy') | ||||
self.sessions = SessionManager() | |||||
if proxy: | if proxy: | ||||
openai.proxy = proxy | openai.proxy = proxy | ||||
def reply(self, query, context=None): | def reply(self, query, context=None): | ||||
# acquire reply content | # acquire reply content | ||||
if not context or not context.get('type') or context.get('type') == 'TEXT': | |||||
if context.type == ContextType.TEXT: | |||||
logger.info("[OPEN_AI] query={}".format(query)) | logger.info("[OPEN_AI] query={}".format(query)) | ||||
session_id = context.get('session_id') or context.get('from_user_id') | |||||
session_id = context['session_id'] | |||||
reply = None | |||||
if query == '#清除记忆': | if query == '#清除记忆': | ||||
Session.clear_session(session_id) | |||||
return '记忆已清除' | |||||
self.sessions.clear_session(session_id) | |||||
reply = Reply(ReplyType.INFO, '记忆已清除') | |||||
elif query == '#清除所有': | elif query == '#清除所有': | ||||
Session.clear_all_session() | |||||
return '所有人记忆已清除' | |||||
self.sessions.clear_all_session() | |||||
reply = Reply(ReplyType.INFO, '所有人记忆已清除') | |||||
elif query == '#更新配置': | elif query == '#更新配置': | ||||
load_config() | load_config() | ||||
return '配置已更新' | |||||
session = Session.build_session_query(query, session_id) | |||||
reply = Reply(ReplyType.INFO, '配置已更新') | |||||
if reply: | |||||
return reply | |||||
session = self.sessions.build_session_query(query, session_id) | |||||
logger.debug("[OPEN_AI] session query={}".format(session)) | logger.debug("[OPEN_AI] session query={}".format(session)) | ||||
# if context.get('stream'): | # if context.get('stream'): | ||||
@@ -44,14 +45,29 @@ class ChatGPTBot(Bot): | |||||
reply_content = self.reply_text(session, session_id, 0) | reply_content = self.reply_text(session, session_id, 0) | ||||
logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"])) | logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"])) | ||||
if reply_content["completion_tokens"] > 0: | |||||
Session.save_session(reply_content["content"], session_id, reply_content["total_tokens"]) | |||||
return reply_content["content"] | |||||
elif context.get('type', None) == 'IMAGE_CREATE': | |||||
return self.create_img(query, 0) | |||||
if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0: | |||||
reply = Reply(ReplyType.ERROR, reply_content['content']) | |||||
elif reply_content["completion_tokens"] > 0: | |||||
self.sessions.save_session(reply_content["content"], session_id, reply_content["total_tokens"]) | |||||
reply = Reply(ReplyType.TEXT, reply_content["content"]) | |||||
else: | |||||
reply = Reply(ReplyType.ERROR, reply_content['content']) | |||||
logger.debug("[OPEN_AI] reply {} used 0 tokens.".format(reply_content)) | |||||
return reply | |||||
elif context.type == ContextType.IMAGE_CREATE: | |||||
ok, retstring = self.create_img(query, 0) | |||||
reply = None | |||||
if ok: | |||||
reply = Reply(ReplyType.IMAGE_URL, retstring) | |||||
else: | |||||
reply = Reply(ReplyType.ERROR, retstring) | |||||
return reply | |||||
else: | |||||
reply = Reply(ReplyType.ERROR, 'Bot不支持处理{}类型的消息'.format(context.type)) | |||||
return reply | |||||
def reply_text(self, session, session_id, retry_count=0) ->dict: | |||||
def reply_text(self, session, session_id, retry_count=0) -> dict: | |||||
''' | ''' | ||||
call openai's ChatCompletion to get the answer | call openai's ChatCompletion to get the answer | ||||
:param session: a conversation session | :param session: a conversation session | ||||
@@ -70,8 +86,8 @@ class ChatGPTBot(Bot): | |||||
presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容 | presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容 | ||||
) | ) | ||||
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) | # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) | ||||
return {"total_tokens": response["usage"]["total_tokens"], | |||||
"completion_tokens": response["usage"]["completion_tokens"], | |||||
return {"total_tokens": response["usage"]["total_tokens"], | |||||
"completion_tokens": response["usage"]["completion_tokens"], | |||||
"content": response.choices[0]['message']['content']} | "content": response.choices[0]['message']['content']} | ||||
except openai.error.RateLimitError as e: | except openai.error.RateLimitError as e: | ||||
# rate limit exception | # rate limit exception | ||||
@@ -86,15 +102,15 @@ class ChatGPTBot(Bot): | |||||
# api connection exception | # api connection exception | ||||
logger.warn(e) | logger.warn(e) | ||||
logger.warn("[OPEN_AI] APIConnection failed") | logger.warn("[OPEN_AI] APIConnection failed") | ||||
return {"completion_tokens": 0, "content":"我连接不到你的网络"} | |||||
return {"completion_tokens": 0, "content": "我连接不到你的网络"} | |||||
except openai.error.Timeout as e: | except openai.error.Timeout as e: | ||||
logger.warn(e) | logger.warn(e) | ||||
logger.warn("[OPEN_AI] Timeout") | logger.warn("[OPEN_AI] Timeout") | ||||
return {"completion_tokens": 0, "content":"我没有收到你的消息"} | |||||
return {"completion_tokens": 0, "content": "我没有收到你的消息"} | |||||
except Exception as e: | except Exception as e: | ||||
# unknown exception | # unknown exception | ||||
logger.exception(e) | logger.exception(e) | ||||
Session.clear_session(session_id) | |||||
self.sessions.clear_session(session_id) | |||||
return {"completion_tokens": 0, "content": "请再问我一次吧"} | return {"completion_tokens": 0, "content": "请再问我一次吧"} | ||||
def create_img(self, query, retry_count=0): | def create_img(self, query, retry_count=0): | ||||
@@ -107,7 +123,7 @@ class ChatGPTBot(Bot): | |||||
) | ) | ||||
image_url = response['data'][0]['url'] | image_url = response['data'][0]['url'] | ||||
logger.info("[OPEN_AI] image_url={}".format(image_url)) | logger.info("[OPEN_AI] image_url={}".format(image_url)) | ||||
return image_url | |||||
return True, image_url | |||||
except openai.error.RateLimitError as e: | except openai.error.RateLimitError as e: | ||||
logger.warn(e) | logger.warn(e) | ||||
if retry_count < 1: | if retry_count < 1: | ||||
@@ -115,14 +131,21 @@ class ChatGPTBot(Bot): | |||||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1)) | logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1)) | ||||
return self.create_img(query, retry_count+1) | return self.create_img(query, retry_count+1) | ||||
else: | else: | ||||
return "提问太快啦,请休息一下再问我吧" | |||||
return False, "提问太快啦,请休息一下再问我吧" | |||||
except Exception as e: | except Exception as e: | ||||
logger.exception(e) | logger.exception(e) | ||||
return None | |||||
return False, str(e) | |||||
class SessionManager(object): | |||||
def __init__(self): | |||||
if conf().get('expires_in_seconds'): | |||||
sessions = ExpiredDict(conf().get('expires_in_seconds')) | |||||
else: | |||||
sessions = dict() | |||||
self.sessions = sessions | |||||
class Session(object): | |||||
@staticmethod | |||||
def build_session_query(query, session_id): | |||||
def build_session_query(self, query, session_id): | |||||
''' | ''' | ||||
build query with conversation history | build query with conversation history | ||||
e.g. [ | e.g. [ | ||||
@@ -135,36 +158,33 @@ class Session(object): | |||||
:param session_id: session id | :param session_id: session id | ||||
:return: query content with conversaction | :return: query content with conversaction | ||||
''' | ''' | ||||
session = all_sessions.get(session_id, []) | |||||
session = self.sessions.get(session_id, []) | |||||
if len(session) == 0: | if len(session) == 0: | ||||
system_prompt = conf().get("character_desc", "") | system_prompt = conf().get("character_desc", "") | ||||
system_item = {'role': 'system', 'content': system_prompt} | system_item = {'role': 'system', 'content': system_prompt} | ||||
session.append(system_item) | session.append(system_item) | ||||
all_sessions[session_id] = session | |||||
self.sessions[session_id] = session | |||||
user_item = {'role': 'user', 'content': query} | user_item = {'role': 'user', 'content': query} | ||||
session.append(user_item) | session.append(user_item) | ||||
return session | return session | ||||
@staticmethod | |||||
def save_session(answer, session_id, total_tokens): | |||||
def save_session(self, answer, session_id, total_tokens): | |||||
max_tokens = conf().get("conversation_max_tokens") | max_tokens = conf().get("conversation_max_tokens") | ||||
if not max_tokens: | if not max_tokens: | ||||
# default 3000 | # default 3000 | ||||
max_tokens = 1000 | max_tokens = 1000 | ||||
max_tokens=int(max_tokens) | |||||
max_tokens = int(max_tokens) | |||||
session = all_sessions.get(session_id) | |||||
session = self.sessions.get(session_id) | |||||
if session: | if session: | ||||
# append conversation | # append conversation | ||||
gpt_item = {'role': 'assistant', 'content': answer} | gpt_item = {'role': 'assistant', 'content': answer} | ||||
session.append(gpt_item) | session.append(gpt_item) | ||||
# discard exceed limit conversation | # discard exceed limit conversation | ||||
Session.discard_exceed_conversation(session, max_tokens, total_tokens) | |||||
self.discard_exceed_conversation(session, max_tokens, total_tokens) | |||||
@staticmethod | |||||
def discard_exceed_conversation(session, max_tokens, total_tokens): | |||||
def discard_exceed_conversation(self, session, max_tokens, total_tokens): | |||||
dec_tokens = int(total_tokens) | dec_tokens = int(total_tokens) | ||||
# logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens)) | # logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens)) | ||||
while dec_tokens > max_tokens: | while dec_tokens > max_tokens: | ||||
@@ -173,13 +193,11 @@ class Session(object): | |||||
session.pop(1) | session.pop(1) | ||||
session.pop(1) | session.pop(1) | ||||
else: | else: | ||||
break | |||||
break | |||||
dec_tokens = dec_tokens - max_tokens | dec_tokens = dec_tokens - max_tokens | ||||
@staticmethod | |||||
def clear_session(session_id): | |||||
all_sessions[session_id] = [] | |||||
def clear_session(self, session_id): | |||||
self.sessions[session_id] = [] | |||||
@staticmethod | |||||
def clear_all_session(): | |||||
all_sessions.clear() | |||||
def clear_all_session(self): | |||||
self.sessions.clear() |
@@ -1,6 +1,8 @@ | |||||
# encoding:utf-8 | # encoding:utf-8 | ||||
from bot.bot import Bot | from bot.bot import Bot | ||||
from bridge.context import ContextType | |||||
from bridge.reply import Reply, ReplyType | |||||
from config import conf | from config import conf | ||||
from common.log import logger | from common.log import logger | ||||
import openai | import openai | ||||
@@ -13,30 +15,31 @@ class OpenAIBot(Bot): | |||||
def __init__(self): | def __init__(self): | ||||
openai.api_key = conf().get('open_ai_api_key') | openai.api_key = conf().get('open_ai_api_key') | ||||
def reply(self, query, context=None): | def reply(self, query, context=None): | ||||
# acquire reply content | # acquire reply content | ||||
if not context or not context.get('type') or context.get('type') == 'TEXT': | |||||
logger.info("[OPEN_AI] query={}".format(query)) | |||||
from_user_id = context.get('from_user_id') or context.get('session_id') | |||||
if query == '#清除记忆': | |||||
Session.clear_session(from_user_id) | |||||
return '记忆已清除' | |||||
elif query == '#清除所有': | |||||
Session.clear_all_session() | |||||
return '所有人记忆已清除' | |||||
new_query = Session.build_session_query(query, from_user_id) | |||||
logger.debug("[OPEN_AI] session query={}".format(new_query)) | |||||
reply_content = self.reply_text(new_query, from_user_id, 0) | |||||
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content)) | |||||
if reply_content and query: | |||||
Session.save_session(query, reply_content, from_user_id) | |||||
return reply_content | |||||
elif context.get('type', None) == 'IMAGE_CREATE': | |||||
return self.create_img(query, 0) | |||||
if context and context.type: | |||||
if context.type == ContextType.TEXT: | |||||
logger.info("[OPEN_AI] query={}".format(query)) | |||||
from_user_id = context['session_id'] | |||||
reply = None | |||||
if query == '#清除记忆': | |||||
Session.clear_session(from_user_id) | |||||
reply = Reply(ReplyType.INFO, '记忆已清除') | |||||
elif query == '#清除所有': | |||||
Session.clear_all_session() | |||||
reply = Reply(ReplyType.INFO, '所有人记忆已清除') | |||||
else: | |||||
new_query = Session.build_session_query(query, from_user_id) | |||||
logger.debug("[OPEN_AI] session query={}".format(new_query)) | |||||
reply_content = self.reply_text(new_query, from_user_id, 0) | |||||
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content)) | |||||
if reply_content and query: | |||||
Session.save_session(query, reply_content, from_user_id) | |||||
reply = Reply(ReplyType.TEXT, reply_content) | |||||
return reply | |||||
elif context.type == ContextType.IMAGE_CREATE: | |||||
return self.create_img(query, 0) | |||||
def reply_text(self, query, user_id, retry_count=0): | def reply_text(self, query, user_id, retry_count=0): | ||||
try: | try: | ||||
@@ -1,16 +1,42 @@ | |||||
from bridge.context import Context | |||||
from bridge.reply import Reply | |||||
from common.log import logger | |||||
from bot import bot_factory | from bot import bot_factory | ||||
from common.singleton import singleton | |||||
from voice import voice_factory | from voice import voice_factory | ||||
@singleton | |||||
class Bridge(object): | class Bridge(object): | ||||
def __init__(self): | def __init__(self): | ||||
pass | |||||
self.btype={ | |||||
"chat": "chatGPT", | |||||
"voice_to_text": "openai", | |||||
"text_to_voice": "baidu" | |||||
} | |||||
self.bots={} | |||||
def fetch_reply_content(self, query, context): | |||||
return bot_factory.create_bot("chatGPT").reply(query, context) | |||||
def get_bot(self,typename): | |||||
if self.bots.get(typename) is None: | |||||
logger.info("create bot {} for {}".format(self.btype[typename],typename)) | |||||
if typename == "text_to_voice": | |||||
self.bots[typename] = voice_factory.create_voice(self.btype[typename]) | |||||
elif typename == "voice_to_text": | |||||
self.bots[typename] = voice_factory.create_voice(self.btype[typename]) | |||||
elif typename == "chat": | |||||
self.bots[typename] = bot_factory.create_bot(self.btype[typename]) | |||||
return self.bots[typename] | |||||
def get_bot_type(self,typename): | |||||
return self.btype[typename] | |||||
def fetch_voice_to_text(self, voiceFile): | |||||
return voice_factory.create_voice("openai").voiceToText(voiceFile) | |||||
def fetch_text_to_voice(self, text): | |||||
return voice_factory.create_voice("baidu").textToVoice(text) | |||||
def fetch_reply_content(self, query, context : Context) -> Reply: | |||||
return self.get_bot("chat").reply(query, context) | |||||
def fetch_voice_to_text(self, voiceFile) -> Reply: | |||||
return self.get_bot("voice_to_text").voiceToText(voiceFile) | |||||
def fetch_text_to_voice(self, text) -> Reply: | |||||
return self.get_bot("text_to_voice").textToVoice(text) | |||||
@@ -0,0 +1,42 @@ | |||||
# encoding:utf-8 | |||||
from enum import Enum | |||||
class ContextType (Enum): | |||||
TEXT = 1 # 文本消息 | |||||
VOICE = 2 # 音频消息 | |||||
IMAGE_CREATE = 3 # 创建图片命令 | |||||
def __str__(self): | |||||
return self.name | |||||
class Context: | |||||
def __init__(self, type : ContextType = None , content = None, kwargs = dict()): | |||||
self.type = type | |||||
self.content = content | |||||
self.kwargs = kwargs | |||||
def __getitem__(self, key): | |||||
if key == 'type': | |||||
return self.type | |||||
elif key == 'content': | |||||
return self.content | |||||
else: | |||||
return self.kwargs[key] | |||||
def __setitem__(self, key, value): | |||||
if key == 'type': | |||||
self.type = value | |||||
elif key == 'content': | |||||
self.content = value | |||||
else: | |||||
self.kwargs[key] = value | |||||
def __delitem__(self, key): | |||||
if key == 'type': | |||||
self.type = None | |||||
elif key == 'content': | |||||
self.content = None | |||||
else: | |||||
del self.kwargs[key] | |||||
def __str__(self): | |||||
return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs) |
@@ -0,0 +1,22 @@ | |||||
# encoding:utf-8 | |||||
from enum import Enum | |||||
class ReplyType(Enum): | |||||
TEXT = 1 # 文本 | |||||
VOICE = 2 # 音频文件 | |||||
IMAGE = 3 # 图片文件 | |||||
IMAGE_URL = 4 # 图片URL | |||||
INFO = 9 | |||||
ERROR = 10 | |||||
def __str__(self): | |||||
return self.name | |||||
class Reply: | |||||
def __init__(self, type : ReplyType = None , content = None): | |||||
self.type = type | |||||
self.content = content | |||||
def __str__(self): | |||||
return "Reply(type={}, content={})".format(self.type, self.content) |
@@ -3,6 +3,8 @@ Message sending channel abstract class | |||||
""" | """ | ||||
from bridge.bridge import Bridge | from bridge.bridge import Bridge | ||||
from bridge.context import Context | |||||
from bridge.reply import Reply | |||||
class Channel(object): | class Channel(object): | ||||
def startup(self): | def startup(self): | ||||
@@ -27,11 +29,11 @@ class Channel(object): | |||||
""" | """ | ||||
raise NotImplementedError | raise NotImplementedError | ||||
def build_reply_content(self, query, context=None): | |||||
def build_reply_content(self, query, context : Context=None) -> Reply: | |||||
return Bridge().fetch_reply_content(query, context) | return Bridge().fetch_reply_content(query, context) | ||||
def build_voice_to_text(self, voice_file): | |||||
def build_voice_to_text(self, voice_file) -> Reply: | |||||
return Bridge().fetch_voice_to_text(voice_file) | return Bridge().fetch_voice_to_text(voice_file) | ||||
def build_text_to_voice(self, text): | |||||
def build_text_to_voice(self, text) -> Reply: | |||||
return Bridge().fetch_text_to_voice(text) | return Bridge().fetch_text_to_voice(text) |
@@ -7,16 +7,24 @@ wechat channel | |||||
import itchat | import itchat | ||||
import json | import json | ||||
from itchat.content import * | from itchat.content import * | ||||
from bridge.reply import * | |||||
from bridge.context import * | |||||
from channel.channel import Channel | from channel.channel import Channel | ||||
from concurrent.futures import ThreadPoolExecutor | from concurrent.futures import ThreadPoolExecutor | ||||
from common.log import logger | from common.log import logger | ||||
from common.tmp_dir import TmpDir | from common.tmp_dir import TmpDir | ||||
from config import conf | from config import conf | ||||
from plugins import * | |||||
import requests | import requests | ||||
import io | import io | ||||
thread_pool = ThreadPoolExecutor(max_workers=8) | |||||
thread_pool = ThreadPoolExecutor(max_workers=8) | |||||
def thread_pool_callback(worker): | |||||
worker_exception = worker.exception() | |||||
if worker_exception: | |||||
logger.exception("Worker return exception: {}".format(worker_exception)) | |||||
@itchat.msg_register(TEXT) | @itchat.msg_register(TEXT) | ||||
def handler_single_msg(msg): | def handler_single_msg(msg): | ||||
@@ -47,62 +55,52 @@ class WechatChannel(Channel): | |||||
# start message listener | # start message listener | ||||
itchat.run() | itchat.run() | ||||
# handle_* 系列函数处理收到的消息后构造context,然后调用handle函数处理context | |||||
# context是一个字典,包含了消息的所有信息,包括以下key | |||||
# type: 消息类型,包括TEXT、VOICE、IMAGE_CREATE | |||||
# content: 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是IMAGE_CREATE类型,content就是图片生成命令 | |||||
# session_id: 会话id | |||||
# isgroup: 是否是群聊 | |||||
# msg: 原始消息对象 | |||||
# receiver: 需要回复的对象 | |||||
def handle_voice(self, msg): | def handle_voice(self, msg): | ||||
if conf().get('speech_recognition') != True : | |||||
if conf().get('speech_recognition') != True: | |||||
return | return | ||||
logger.debug("[WX]receive voice msg: " + msg['FileName']) | logger.debug("[WX]receive voice msg: " + msg['FileName']) | ||||
thread_pool.submit(self._do_handle_voice, msg) | |||||
def _do_handle_voice(self, msg): | |||||
from_user_id = msg['FromUserName'] | from_user_id = msg['FromUserName'] | ||||
other_user_id = msg['User']['UserName'] | other_user_id = msg['User']['UserName'] | ||||
if from_user_id == other_user_id: | if from_user_id == other_user_id: | ||||
file_name = TmpDir().path() + msg['FileName'] | |||||
msg.download(file_name) | |||||
query = super().build_voice_to_text(file_name) | |||||
if conf().get('voice_reply_voice'): | |||||
self._do_send_voice(query, from_user_id) | |||||
else: | |||||
self._do_send_text(query, from_user_id) | |||||
context = Context(ContextType.VOICE,msg['FileName']) | |||||
context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id} | |||||
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) | |||||
def handle_text(self, msg): | def handle_text(self, msg): | ||||
logger.debug("[WX]receive text msg: " + json.dumps(msg, ensure_ascii=False)) | logger.debug("[WX]receive text msg: " + json.dumps(msg, ensure_ascii=False)) | ||||
content = msg['Text'] | content = msg['Text'] | ||||
self._handle_single_msg(msg, content) | |||||
def _handle_single_msg(self, msg, content): | |||||
from_user_id = msg['FromUserName'] | from_user_id = msg['FromUserName'] | ||||
to_user_id = msg['ToUserName'] # 接收人id | to_user_id = msg['ToUserName'] # 接收人id | ||||
other_user_id = msg['User']['UserName'] # 对手方id | other_user_id = msg['User']['UserName'] # 对手方id | ||||
match_prefix = self.check_prefix(content, conf().get('single_chat_prefix')) | |||||
match_prefix = check_prefix(content, conf().get('single_chat_prefix')) | |||||
if "」\n- - - - - - - - - - - - - - -" in content: | if "」\n- - - - - - - - - - - - - - -" in content: | ||||
logger.debug("[WX]reference query skipped") | logger.debug("[WX]reference query skipped") | ||||
return | return | ||||
if from_user_id == other_user_id and match_prefix is not None: | |||||
# 好友向自己发送消息 | |||||
if match_prefix != '': | |||||
str_list = content.split(match_prefix, 1) | |||||
if len(str_list) == 2: | |||||
content = str_list[1].strip() | |||||
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix')) | |||||
if img_match_prefix: | |||||
content = content.split(img_match_prefix, 1)[1].strip() | |||||
thread_pool.submit(self._do_send_img, content, from_user_id) | |||||
else : | |||||
thread_pool.submit(self._do_send_text, content, from_user_id) | |||||
elif to_user_id == other_user_id and match_prefix: | |||||
# 自己给好友发送消息 | |||||
str_list = content.split(match_prefix, 1) | |||||
if len(str_list) == 2: | |||||
content = str_list[1].strip() | |||||
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix')) | |||||
if img_match_prefix: | |||||
content = content.split(img_match_prefix, 1)[1].strip() | |||||
thread_pool.submit(self._do_send_img, content, to_user_id) | |||||
else: | |||||
thread_pool.submit(self._do_send_text, content, to_user_id) | |||||
if match_prefix: | |||||
content = content.replace(match_prefix, '', 1).strip() | |||||
else: | |||||
return | |||||
context = Context() | |||||
context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id} | |||||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) | |||||
if img_match_prefix: | |||||
content = content.replace(img_match_prefix, '', 1).strip() | |||||
context.type = ContextType.IMAGE_CREATE | |||||
else: | |||||
context.type = ContextType.TEXT | |||||
context.content = content | |||||
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) | |||||
def handle_group(self, msg): | def handle_group(self, msg): | ||||
logger.debug("[WX]receive group msg: " + json.dumps(msg, ensure_ascii=False)) | logger.debug("[WX]receive group msg: " + json.dumps(msg, ensure_ascii=False)) | ||||
@@ -122,100 +120,128 @@ class WechatChannel(Channel): | |||||
logger.debug("[WX]reference query skipped") | logger.debug("[WX]reference query skipped") | ||||
return "" | return "" | ||||
config = conf() | config = conf() | ||||
match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or self.check_prefix(origin_content, config.get('group_chat_prefix')) \ | |||||
or self.check_contain(origin_content, config.get('group_chat_keyword')) | |||||
if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or self.check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix: | |||||
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix')) | |||||
match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or check_prefix(origin_content, config.get('group_chat_prefix')) \ | |||||
or check_contain(origin_content, config.get('group_chat_keyword')) | |||||
if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix: | |||||
context = Context() | |||||
context.kwargs = { 'isgroup': True, 'msg': msg, 'receiver': group_id} | |||||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) | |||||
if img_match_prefix: | if img_match_prefix: | ||||
content = content.split(img_match_prefix, 1)[1].strip() | |||||
thread_pool.submit(self._do_send_img, content, group_id) | |||||
content = content.replace(img_match_prefix, '', 1).strip() | |||||
context.type = ContextType.IMAGE_CREATE | |||||
else: | else: | ||||
thread_pool.submit(self._do_send_group, content, msg) | |||||
def send(self, msg, receiver): | |||||
itchat.send(msg, toUserName=receiver) | |||||
logger.info('[WX] sendMsg={}, receiver={}'.format(msg, receiver)) | |||||
def _do_send_voice(self, query, reply_user_id): | |||||
try: | |||||
if not query: | |||||
return | |||||
context = dict() | |||||
context['from_user_id'] = reply_user_id | |||||
reply_text = super().build_reply_content(query, context) | |||||
if reply_text: | |||||
replyFile = super().build_text_to_voice(reply_text) | |||||
itchat.send_file(replyFile, toUserName=reply_user_id) | |||||
logger.info('[WX] sendFile={}, receiver={}'.format(replyFile, reply_user_id)) | |||||
except Exception as e: | |||||
logger.exception(e) | |||||
def _do_send_text(self, query, reply_user_id): | |||||
try: | |||||
if not query: | |||||
return | |||||
context = dict() | |||||
context['session_id'] = reply_user_id | |||||
reply_text = super().build_reply_content(query, context) | |||||
if reply_text: | |||||
self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id) | |||||
except Exception as e: | |||||
logger.exception(e) | |||||
def _do_send_img(self, query, reply_user_id): | |||||
try: | |||||
if not query: | |||||
return | |||||
context = dict() | |||||
context['type'] = 'IMAGE_CREATE' | |||||
img_url = super().build_reply_content(query, context) | |||||
if not img_url: | |||||
return | |||||
# 图片下载 | |||||
context.type = ContextType.TEXT | |||||
context.content = content | |||||
group_chat_in_one_session = conf().get('group_chat_in_one_session', []) | |||||
if ('ALL_GROUP' in group_chat_in_one_session or | |||||
group_name in group_chat_in_one_session or | |||||
check_contain(group_name, group_chat_in_one_session)): | |||||
context['session_id'] = group_id | |||||
else: | |||||
context['session_id'] = msg['ActualUserName'] | |||||
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) | |||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 | |||||
def send(self, reply : Reply, receiver): | |||||
if reply.type == ReplyType.TEXT: | |||||
itchat.send(reply.content, toUserName=receiver) | |||||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) | |||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: | |||||
itchat.send(reply.content, toUserName=receiver) | |||||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) | |||||
elif reply.type == ReplyType.VOICE: | |||||
itchat.send_file(reply.content, toUserName=receiver) | |||||
logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver)) | |||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 | |||||
img_url = reply.content | |||||
pic_res = requests.get(img_url, stream=True) | pic_res = requests.get(img_url, stream=True) | ||||
image_storage = io.BytesIO() | image_storage = io.BytesIO() | ||||
for block in pic_res.iter_content(1024): | for block in pic_res.iter_content(1024): | ||||
image_storage.write(block) | image_storage.write(block) | ||||
image_storage.seek(0) | image_storage.seek(0) | ||||
itchat.send_image(image_storage, toUserName=receiver) | |||||
logger.info('[WX] sendImage url=, receiver={}'.format(img_url,receiver)) | |||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片 | |||||
image_storage = reply.content | |||||
image_storage.seek(0) | |||||
itchat.send_image(image_storage, toUserName=receiver) | |||||
logger.info('[WX] sendImage, receiver={}'.format(receiver)) | |||||
# 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类 | |||||
def handle(self, context): | |||||
reply = Reply() | |||||
logger.debug('[WX] ready to handle context: {}'.format(context)) | |||||
# reply的构建步骤 | |||||
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply})) | |||||
reply = e_context['reply'] | |||||
if not e_context.is_pass(): | |||||
logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content)) | |||||
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: | |||||
reply = super().build_reply_content(context.content, context) | |||||
elif context.type == ContextType.VOICE: | |||||
msg = context['msg'] | |||||
file_name = TmpDir().path() + context.content | |||||
msg.download(file_name) | |||||
reply = super().build_voice_to_text(file_name) | |||||
if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO: | |||||
context.content = reply.content # 语音转文字后,将文字内容作为新的context | |||||
context.type = ContextType.TEXT | |||||
reply = super().build_reply_content(context.content, context) | |||||
if reply.type == ReplyType.TEXT: | |||||
if conf().get('voice_reply_voice'): | |||||
reply = super().build_text_to_voice(reply.content) | |||||
else: | |||||
logger.error('[WX] unknown context type: {}'.format(context.type)) | |||||
return | |||||
# 图片发送 | |||||
itchat.send_image(image_storage, reply_user_id) | |||||
logger.info('[WX] sendImage, receiver={}'.format(reply_user_id)) | |||||
except Exception as e: | |||||
logger.exception(e) | |||||
def _do_send_group(self, query, msg): | |||||
if not query: | |||||
return | |||||
context = dict() | |||||
group_name = msg['User']['NickName'] | |||||
group_id = msg['User']['UserName'] | |||||
group_chat_in_one_session = conf().get('group_chat_in_one_session', []) | |||||
if ('ALL_GROUP' in group_chat_in_one_session or \ | |||||
group_name in group_chat_in_one_session or \ | |||||
self.check_contain(group_name, group_chat_in_one_session)): | |||||
context['session_id'] = group_id | |||||
else: | |||||
context['session_id'] = msg['ActualUserName'] | |||||
reply_text = super().build_reply_content(query, context) | |||||
if reply_text: | |||||
reply_text = '@' + msg['ActualNickName'] + ' ' + reply_text.strip() | |||||
self.send(conf().get("group_chat_reply_prefix", "") + reply_text, group_id) | |||||
def check_prefix(self, content, prefix_list): | |||||
for prefix in prefix_list: | |||||
if content.startswith(prefix): | |||||
return prefix | |||||
return None | |||||
logger.debug('[WX] ready to decorate reply: {}'.format(reply)) | |||||
# reply的包装步骤 | |||||
if reply and reply.type: | |||||
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply})) | |||||
reply=e_context['reply'] | |||||
if not e_context.is_pass() and reply and reply.type: | |||||
if reply.type == ReplyType.TEXT: | |||||
reply_text = reply.content | |||||
if context['isgroup']: | |||||
reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip() | |||||
reply_text = conf().get("group_chat_reply_prefix", "")+reply_text | |||||
else: | |||||
reply_text = conf().get("single_chat_reply_prefix", "")+reply_text | |||||
reply.content = reply_text | |||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: | |||||
reply.content = str(reply.type)+":\n" + reply.content | |||||
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE: | |||||
pass | |||||
else: | |||||
logger.error('[WX] unknown reply type: {}'.format(reply.type)) | |||||
return | |||||
# reply的发送步骤 | |||||
if reply and reply.type: | |||||
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply})) | |||||
reply=e_context['reply'] | |||||
if not e_context.is_pass() and reply and reply.type: | |||||
logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver'])) | |||||
self.send(reply, context['receiver']) | |||||
def check_prefix(content, prefix_list): | |||||
for prefix in prefix_list: | |||||
if content.startswith(prefix): | |||||
return prefix | |||||
return None | |||||
def check_contain(self, content, keyword_list): | |||||
if not keyword_list: | |||||
return None | |||||
for ky in keyword_list: | |||||
if content.find(ky) != -1: | |||||
return True | |||||
def check_contain(content, keyword_list): | |||||
if not keyword_list: | |||||
return None | return None | ||||
for ky in keyword_list: | |||||
if content.find(ky) != -1: | |||||
return True | |||||
return None |
@@ -11,6 +11,7 @@ import time | |||||
import asyncio | import asyncio | ||||
import requests | import requests | ||||
from typing import Optional, Union | from typing import Optional, Union | ||||
from bridge.context import Context, ContextType | |||||
from wechaty_puppet import MessageType, FileBox, ScanStatus # type: ignore | from wechaty_puppet import MessageType, FileBox, ScanStatus # type: ignore | ||||
from wechaty import Wechaty, Contact | from wechaty import Wechaty, Contact | ||||
from wechaty.user import Message, Room, MiniProgram, UrlLink | from wechaty.user import Message, Room, MiniProgram, UrlLink | ||||
@@ -127,9 +128,9 @@ class WechatyChannel(Channel): | |||||
try: | try: | ||||
if not query: | if not query: | ||||
return | return | ||||
context = dict() | |||||
context = Context(ContextType.TEXT, query) | |||||
context['session_id'] = reply_user_id | context['session_id'] = reply_user_id | ||||
reply_text = super().build_reply_content(query, context) | |||||
reply_text = super().build_reply_content(query, context).content | |||||
if reply_text: | if reply_text: | ||||
await self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id) | await self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id) | ||||
except Exception as e: | except Exception as e: | ||||
@@ -139,9 +140,8 @@ class WechatyChannel(Channel): | |||||
try: | try: | ||||
if not query: | if not query: | ||||
return | return | ||||
context = dict() | |||||
context['type'] = 'IMAGE_CREATE' | |||||
img_url = super().build_reply_content(query, context) | |||||
context = Context(ContextType.IMAGE_CREATE, query) | |||||
img_url = super().build_reply_content(query, context).content | |||||
if not img_url: | if not img_url: | ||||
return | return | ||||
# 图片下载 | # 图片下载 | ||||
@@ -162,7 +162,7 @@ class WechatyChannel(Channel): | |||||
async def _do_send_group(self, query, group_id, group_name, group_user_id, group_user_name): | async def _do_send_group(self, query, group_id, group_name, group_user_id, group_user_name): | ||||
if not query: | if not query: | ||||
return | return | ||||
context = dict() | |||||
context = Context(ContextType.TEXT, query) | |||||
group_chat_in_one_session = conf().get('group_chat_in_one_session', []) | group_chat_in_one_session = conf().get('group_chat_in_one_session', []) | ||||
if ('ALL_GROUP' in group_chat_in_one_session or \ | if ('ALL_GROUP' in group_chat_in_one_session or \ | ||||
group_name in group_chat_in_one_session or \ | group_name in group_chat_in_one_session or \ | ||||
@@ -170,7 +170,7 @@ class WechatyChannel(Channel): | |||||
context['session_id'] = str(group_id) | context['session_id'] = str(group_id) | ||||
else: | else: | ||||
context['session_id'] = str(group_id) + '-' + str(group_user_id) | context['session_id'] = str(group_id) + '-' + str(group_user_id) | ||||
reply_text = super().build_reply_content(query, context) | |||||
reply_text = super().build_reply_content(query, context).content | |||||
if reply_text: | if reply_text: | ||||
reply_text = '@' + group_user_name + ' ' + reply_text.strip() | reply_text = '@' + group_user_name + ' ' + reply_text.strip() | ||||
await self.send_group(conf().get("group_chat_reply_prefix", "") + reply_text, group_id) | await self.send_group(conf().get("group_chat_reply_prefix", "") + reply_text, group_id) | ||||
@@ -179,9 +179,8 @@ class WechatyChannel(Channel): | |||||
try: | try: | ||||
if not query: | if not query: | ||||
return | return | ||||
context = dict() | |||||
context['type'] = 'IMAGE_CREATE' | |||||
img_url = super().build_reply_content(query, context) | |||||
context = Context(ContextType.IMAGE_CREATE, query) | |||||
img_url = super().build_reply_content(query, context).content | |||||
if not img_url: | if not img_url: | ||||
return | return | ||||
# 图片发送 | # 图片发送 | ||||
@@ -0,0 +1,9 @@ | |||||
def singleton(cls): | |||||
instances = {} | |||||
def get_instance(*args, **kwargs): | |||||
if cls not in instances: | |||||
instances[cls] = cls(*args, **kwargs) | |||||
return instances[cls] | |||||
return get_instance |
@@ -0,0 +1,65 @@ | |||||
import heapq | |||||
class SortedDict(dict): | |||||
def __init__(self, sort_func=lambda k, v: k, init_dict=None, reverse=False): | |||||
if init_dict is None: | |||||
init_dict = [] | |||||
if isinstance(init_dict, dict): | |||||
init_dict = init_dict.items() | |||||
self.sort_func = sort_func | |||||
self.sorted_keys = None | |||||
self.reverse = reverse | |||||
self.heap = [] | |||||
for k, v in init_dict: | |||||
self[k] = v | |||||
def __setitem__(self, key, value): | |||||
if key in self: | |||||
super().__setitem__(key, value) | |||||
for i, (priority, k) in enumerate(self.heap): | |||||
if k == key: | |||||
self.heap[i] = (self.sort_func(key, value), key) | |||||
heapq.heapify(self.heap) | |||||
break | |||||
self.sorted_keys = None | |||||
else: | |||||
super().__setitem__(key, value) | |||||
heapq.heappush(self.heap, (self.sort_func(key, value), key)) | |||||
self.sorted_keys = None | |||||
def __delitem__(self, key): | |||||
super().__delitem__(key) | |||||
for i, (priority, k) in enumerate(self.heap): | |||||
if k == key: | |||||
del self.heap[i] | |||||
heapq.heapify(self.heap) | |||||
break | |||||
self.sorted_keys = None | |||||
def keys(self): | |||||
if self.sorted_keys is None: | |||||
self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)] | |||||
return self.sorted_keys | |||||
def items(self): | |||||
if self.sorted_keys is None: | |||||
self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)] | |||||
sorted_items = [(k, self[k]) for k in self.sorted_keys] | |||||
return sorted_items | |||||
def _update_heap(self, key): | |||||
for i, (priority, k) in enumerate(self.heap): | |||||
if k == key: | |||||
new_priority = self.sort_func(key, self[key]) | |||||
if new_priority != priority: | |||||
self.heap[i] = (new_priority, key) | |||||
heapq.heapify(self.heap) | |||||
self.sorted_keys = None | |||||
break | |||||
def __iter__(self): | |||||
return iter(self.keys()) | |||||
def __repr__(self): | |||||
return f'{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})' |
@@ -0,0 +1,9 @@ | |||||
from .plugin_manager import PluginManager | |||||
from .event import * | |||||
from .plugin import * | |||||
instance = PluginManager() | |||||
register = instance.register | |||||
# load_plugins = instance.load_plugins | |||||
# emit_event = instance.emit_event |
@@ -0,0 +1 @@ | |||||
banwords.txt |
@@ -0,0 +1,9 @@ | |||||
### 说明 | |||||
简易的敏感词插件,暂不支持分词,请自行导入词库到插件文件夹中的`banwords.txt`,每行一个词,一个参考词库是[1](https://github.com/cjh0613/tencent-sensitive-words/blob/main/sensitive_words_lines.txt)。 | |||||
`config.json`中能够填写默认的处理行为,目前行为有: | |||||
- `ignore` : 无视这条消息。 | |||||
- `replace` : 将消息中的敏感词替换成"*",并回复违规。 | |||||
### 致谢 | |||||
搜索功能实现来自https://github.com/toolgood/ToolGood.Words |
@@ -0,0 +1,250 @@ | |||||
#!/usr/bin/env python | |||||
# -*- coding:utf-8 -*- | |||||
# ToolGood.Words.WordsSearch.py | |||||
# 2020, Lin Zhijun, https://github.com/toolgood/ToolGood.Words | |||||
# Licensed under the Apache License 2.0 | |||||
# 更新日志 | |||||
# 2020.04.06 第一次提交 | |||||
# 2020.05.16 修改,支持大于0xffff的字符 | |||||
__all__ = ['WordsSearch'] | |||||
__author__ = 'Lin Zhijun' | |||||
__date__ = '2020.05.16' | |||||
class TrieNode(): | |||||
def __init__(self): | |||||
self.Index = 0 | |||||
self.Index = 0 | |||||
self.Layer = 0 | |||||
self.End = False | |||||
self.Char = '' | |||||
self.Results = [] | |||||
self.m_values = {} | |||||
self.Failure = None | |||||
self.Parent = None | |||||
def Add(self,c): | |||||
if c in self.m_values : | |||||
return self.m_values[c] | |||||
node = TrieNode() | |||||
node.Parent = self | |||||
node.Char = c | |||||
self.m_values[c] = node | |||||
return node | |||||
def SetResults(self,index): | |||||
if (self.End == False): | |||||
self.End = True | |||||
self.Results.append(index) | |||||
class TrieNode2(): | |||||
def __init__(self): | |||||
self.End = False | |||||
self.Results = [] | |||||
self.m_values = {} | |||||
self.minflag = 0xffff | |||||
self.maxflag = 0 | |||||
def Add(self,c,node3): | |||||
if (self.minflag > c): | |||||
self.minflag = c | |||||
if (self.maxflag < c): | |||||
self.maxflag = c | |||||
self.m_values[c] = node3 | |||||
def SetResults(self,index): | |||||
if (self.End == False) : | |||||
self.End = True | |||||
if (index in self.Results )==False : | |||||
self.Results.append(index) | |||||
def HasKey(self,c): | |||||
return c in self.m_values | |||||
def TryGetValue(self,c): | |||||
if (self.minflag <= c and self.maxflag >= c): | |||||
if c in self.m_values: | |||||
return self.m_values[c] | |||||
return None | |||||
class WordsSearch(): | |||||
def __init__(self): | |||||
self._first = {} | |||||
self._keywords = [] | |||||
self._indexs=[] | |||||
def SetKeywords(self,keywords): | |||||
self._keywords = keywords | |||||
self._indexs=[] | |||||
for i in range(len(keywords)): | |||||
self._indexs.append(i) | |||||
root = TrieNode() | |||||
allNodeLayer={} | |||||
for i in range(len(self._keywords)): # for (i = 0; i < _keywords.length; i++) | |||||
p = self._keywords[i] | |||||
nd = root | |||||
for j in range(len(p)): # for (j = 0; j < p.length; j++) | |||||
nd = nd.Add(ord(p[j])) | |||||
if (nd.Layer == 0): | |||||
nd.Layer = j + 1 | |||||
if nd.Layer in allNodeLayer: | |||||
allNodeLayer[nd.Layer].append(nd) | |||||
else: | |||||
allNodeLayer[nd.Layer]=[] | |||||
allNodeLayer[nd.Layer].append(nd) | |||||
nd.SetResults(i) | |||||
allNode = [] | |||||
allNode.append(root) | |||||
for key in allNodeLayer.keys(): | |||||
for nd in allNodeLayer[key]: | |||||
allNode.append(nd) | |||||
allNodeLayer=None | |||||
for i in range(len(allNode)): # for (i = 0; i < allNode.length; i++) | |||||
if i==0 : | |||||
continue | |||||
nd=allNode[i] | |||||
nd.Index = i | |||||
r = nd.Parent.Failure | |||||
c = nd.Char | |||||
while (r != None and (c in r.m_values)==False): | |||||
r = r.Failure | |||||
if (r == None): | |||||
nd.Failure = root | |||||
else: | |||||
nd.Failure = r.m_values[c] | |||||
for key2 in nd.Failure.Results : | |||||
nd.SetResults(key2) | |||||
root.Failure = root | |||||
allNode2 = [] | |||||
for i in range(len(allNode)): # for (i = 0; i < allNode.length; i++) | |||||
allNode2.append( TrieNode2()) | |||||
for i in range(len(allNode2)): # for (i = 0; i < allNode2.length; i++) | |||||
oldNode = allNode[i] | |||||
newNode = allNode2[i] | |||||
for key in oldNode.m_values : | |||||
index = oldNode.m_values[key].Index | |||||
newNode.Add(key, allNode2[index]) | |||||
for index in range(len(oldNode.Results)): # for (index = 0; index < oldNode.Results.length; index++) | |||||
item = oldNode.Results[index] | |||||
newNode.SetResults(item) | |||||
oldNode=oldNode.Failure | |||||
while oldNode != root: | |||||
for key in oldNode.m_values : | |||||
if (newNode.HasKey(key) == False): | |||||
index = oldNode.m_values[key].Index | |||||
newNode.Add(key, allNode2[index]) | |||||
for index in range(len(oldNode.Results)): | |||||
item = oldNode.Results[index] | |||||
newNode.SetResults(item) | |||||
oldNode=oldNode.Failure | |||||
allNode = None | |||||
root = None | |||||
# first = [] | |||||
# for index in range(65535):# for (index = 0; index < 0xffff; index++) | |||||
# first.append(None) | |||||
# for key in allNode2[0].m_values : | |||||
# first[key] = allNode2[0].m_values[key] | |||||
self._first = allNode2[0] | |||||
def FindFirst(self,text): | |||||
ptr = None | |||||
for index in range(len(text)): # for (index = 0; index < text.length; index++) | |||||
t =ord(text[index]) # text.charCodeAt(index) | |||||
tn = None | |||||
if (ptr == None): | |||||
tn = self._first.TryGetValue(t) | |||||
else: | |||||
tn = ptr.TryGetValue(t) | |||||
if (tn==None): | |||||
tn = self._first.TryGetValue(t) | |||||
if (tn != None): | |||||
if (tn.End): | |||||
item = tn.Results[0] | |||||
keyword = self._keywords[item] | |||||
return { "Keyword": keyword, "Success": True, "End": index, "Start": index + 1 - len(keyword), "Index": self._indexs[item] } | |||||
ptr = tn | |||||
return None | |||||
def FindAll(self,text): | |||||
ptr = None | |||||
list = [] | |||||
for index in range(len(text)): # for (index = 0; index < text.length; index++) | |||||
t =ord(text[index]) # text.charCodeAt(index) | |||||
tn = None | |||||
if (ptr == None): | |||||
tn = self._first.TryGetValue(t) | |||||
else: | |||||
tn = ptr.TryGetValue(t) | |||||
if (tn==None): | |||||
tn = self._first.TryGetValue(t) | |||||
if (tn != None): | |||||
if (tn.End): | |||||
for j in range(len(tn.Results)): # for (j = 0; j < tn.Results.length; j++) | |||||
item = tn.Results[j] | |||||
keyword = self._keywords[item] | |||||
list.append({ "Keyword": keyword, "Success": True, "End": index, "Start": index + 1 - len(keyword), "Index": self._indexs[item] }) | |||||
ptr = tn | |||||
return list | |||||
def ContainsAny(self,text): | |||||
ptr = None | |||||
for index in range(len(text)): # for (index = 0; index < text.length; index++) | |||||
t =ord(text[index]) # text.charCodeAt(index) | |||||
tn = None | |||||
if (ptr == None): | |||||
tn = self._first.TryGetValue(t) | |||||
else: | |||||
tn = ptr.TryGetValue(t) | |||||
if (tn==None): | |||||
tn = self._first.TryGetValue(t) | |||||
if (tn != None): | |||||
if (tn.End): | |||||
return True | |||||
ptr = tn | |||||
return False | |||||
def Replace(self,text, replaceChar = '*'): | |||||
result = list(text) | |||||
ptr = None | |||||
for i in range(len(text)): # for (i = 0; i < text.length; i++) | |||||
t =ord(text[i]) # text.charCodeAt(index) | |||||
tn = None | |||||
if (ptr == None): | |||||
tn = self._first.TryGetValue(t) | |||||
else: | |||||
tn = ptr.TryGetValue(t) | |||||
if (tn==None): | |||||
tn = self._first.TryGetValue(t) | |||||
if (tn != None): | |||||
if (tn.End): | |||||
maxLength = len( self._keywords[tn.Results[0]]) | |||||
start = i + 1 - maxLength | |||||
for j in range(start,i+1): # for (j = start; j <= i; j++) | |||||
result[j] = replaceChar | |||||
ptr = tn | |||||
return ''.join(result) |
@@ -0,0 +1,63 @@ | |||||
# encoding:utf-8 | |||||
import json | |||||
import os | |||||
from bridge.context import ContextType | |||||
from bridge.reply import Reply, ReplyType | |||||
import plugins | |||||
from plugins import * | |||||
from common.log import logger | |||||
from .WordsSearch import WordsSearch | |||||
@plugins.register(name="Banwords", desc="判断消息中是否有敏感词、决定是否回复。", version="1.0", author="lanvent", desire_priority= 100) | |||||
class Banwords(Plugin): | |||||
def __init__(self): | |||||
super().__init__() | |||||
try: | |||||
curdir=os.path.dirname(__file__) | |||||
config_path=os.path.join(curdir,"config.json") | |||||
conf=None | |||||
if not os.path.exists(config_path): | |||||
conf={"action":"ignore"} | |||||
with open(config_path,"w") as f: | |||||
json.dump(conf,f,indent=4) | |||||
else: | |||||
with open(config_path,"r") as f: | |||||
conf=json.load(f) | |||||
self.searchr = WordsSearch() | |||||
self.action = conf["action"] | |||||
banwords_path = os.path.join(curdir,"banwords.txt") | |||||
with open(banwords_path, 'r', encoding='utf-8') as f: | |||||
words=[] | |||||
for line in f: | |||||
word = line.strip() | |||||
if word: | |||||
words.append(word) | |||||
self.searchr.SetKeywords(words) | |||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context | |||||
logger.info("[Banwords] inited") | |||||
except Exception as e: | |||||
logger.error("Banwords init failed: %s" % e) | |||||
def on_handle_context(self, e_context: EventContext): | |||||
if e_context['context'].type not in [ContextType.TEXT,ContextType.IMAGE_CREATE]: | |||||
return | |||||
content = e_context['context'].content | |||||
logger.debug("[Banwords] on_handle_context. content: %s" % content) | |||||
if self.action == "ignore": | |||||
f = self.searchr.FindFirst(content) | |||||
if f: | |||||
logger.info("Banwords: %s" % f["Keyword"]) | |||||
e_context.action = EventAction.BREAK_PASS | |||||
return | |||||
elif self.action == "replace": | |||||
if self.searchr.ContainsAny(content): | |||||
reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n"+self.searchr.Replace(content)) | |||||
e_context['reply'] = reply | |||||
e_context.action = EventAction.BREAK_PASS | |||||
return |
@@ -0,0 +1,3 @@ | |||||
nipples | |||||
pennis | |||||
法轮功 |
@@ -0,0 +1,3 @@ | |||||
{ | |||||
"action": "ignore" | |||||
} |
@@ -0,0 +1,49 @@ | |||||
# encoding:utf-8 | |||||
from enum import Enum | |||||
class Event(Enum): | |||||
# ON_RECEIVE_MESSAGE = 1 # 收到消息 | |||||
ON_HANDLE_CONTEXT = 2 # 处理消息前 | |||||
""" | |||||
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复,初始为空 } | |||||
""" | |||||
ON_DECORATE_REPLY = 3 # 得到回复后准备装饰 | |||||
""" | |||||
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 } | |||||
""" | |||||
ON_SEND_REPLY = 4 # 发送回复前 | |||||
""" | |||||
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 } | |||||
""" | |||||
# AFTER_SEND_REPLY = 5 # 发送回复后 | |||||
class EventAction(Enum): | |||||
CONTINUE = 1 # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑 | |||||
BREAK = 2 # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑 | |||||
BREAK_PASS = 3 # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑 | |||||
class EventContext: | |||||
def __init__(self, event, econtext=dict()): | |||||
self.event = event | |||||
self.econtext = econtext | |||||
self.action = EventAction.CONTINUE | |||||
def __getitem__(self, key): | |||||
return self.econtext[key] | |||||
def __setitem__(self, key, value): | |||||
self.econtext[key] = value | |||||
def __delitem__(self, key): | |||||
del self.econtext[key] | |||||
def is_pass(self): | |||||
return self.action == EventAction.BREAK_PASS |
@@ -0,0 +1,4 @@ | |||||
{ | |||||
"password": "", | |||||
"admin_users": [] | |||||
} |
@@ -0,0 +1,289 @@ | |||||
# encoding:utf-8 | |||||
import json | |||||
import os | |||||
import traceback | |||||
from typing import Tuple | |||||
from bridge.bridge import Bridge | |||||
from bridge.context import ContextType | |||||
from bridge.reply import Reply, ReplyType | |||||
from config import load_config | |||||
import plugins | |||||
from plugins import * | |||||
from common.log import logger | |||||
# 定义指令集 | |||||
COMMANDS = { | |||||
"help": { | |||||
"alias": ["help", "帮助"], | |||||
"desc": "打印指令集合", | |||||
}, | |||||
"auth": { | |||||
"alias": ["auth", "认证"], | |||||
"args": ["口令"], | |||||
"desc": "管理员认证", | |||||
}, | |||||
# "id": { | |||||
# "alias": ["id", "用户"], | |||||
# "desc": "获取用户id", #目前无实际意义 | |||||
# }, | |||||
"reset": { | |||||
"alias": ["reset", "重置会话"], | |||||
"desc": "重置会话", | |||||
}, | |||||
} | |||||
ADMIN_COMMANDS = { | |||||
"resume": { | |||||
"alias": ["resume", "恢复服务"], | |||||
"desc": "恢复服务", | |||||
}, | |||||
"stop": { | |||||
"alias": ["stop", "暂停服务"], | |||||
"desc": "暂停服务", | |||||
}, | |||||
"reconf": { | |||||
"alias": ["reconf", "重载配置"], | |||||
"desc": "重载配置(不包含插件配置)", | |||||
}, | |||||
"resetall": { | |||||
"alias": ["resetall", "重置所有会话"], | |||||
"desc": "重置所有会话", | |||||
}, | |||||
"scanp": { | |||||
"alias": ["scanp", "扫描插件"], | |||||
"desc": "扫描插件目录是否有新插件", | |||||
}, | |||||
"plist": { | |||||
"alias": ["plist", "插件"], | |||||
"desc": "打印当前插件列表", | |||||
}, | |||||
"setpri": { | |||||
"alias": ["setpri", "设置插件优先级"], | |||||
"args": ["插件名", "优先级"], | |||||
"desc": "设置指定插件的优先级,越大越优先", | |||||
}, | |||||
"reloadp": { | |||||
"alias": ["reloadp", "重载插件"], | |||||
"args": ["插件名"], | |||||
"desc": "重载指定插件配置", | |||||
}, | |||||
"enablep": { | |||||
"alias": ["enablep", "启用插件"], | |||||
"args": ["插件名"], | |||||
"desc": "启用指定插件", | |||||
}, | |||||
"disablep": { | |||||
"alias": ["disablep", "禁用插件"], | |||||
"args": ["插件名"], | |||||
"desc": "禁用指定插件", | |||||
}, | |||||
"debug": { | |||||
"alias": ["debug", "调试模式", "DEBUG"], | |||||
"desc": "开启机器调试日志", | |||||
}, | |||||
} | |||||
# 定义帮助函数 | |||||
def get_help_text(isadmin, isgroup): | |||||
help_text = "可用指令:\n" | |||||
for cmd, info in COMMANDS.items(): | |||||
if cmd=="auth" and (isadmin or isgroup): # 群聊不可认证 | |||||
continue | |||||
alias=["#"+a for a in info['alias']] | |||||
help_text += f"{','.join(alias)} " | |||||
if 'args' in info: | |||||
args=["{"+a+"}" for a in info['args']] | |||||
help_text += f"{' '.join(args)} " | |||||
help_text += f": {info['desc']}\n" | |||||
if ADMIN_COMMANDS and isadmin: | |||||
help_text += "\n管理员指令:\n" | |||||
for cmd, info in ADMIN_COMMANDS.items(): | |||||
alias=["#"+a for a in info['alias']] | |||||
help_text += f"{','.join(alias)} " | |||||
help_text += f": {info['desc']}\n" | |||||
return help_text | |||||
@plugins.register(name="Godcmd", desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证", version="1.0", author="lanvent", desire_priority= 999) | |||||
class Godcmd(Plugin): | |||||
def __init__(self): | |||||
super().__init__() | |||||
curdir=os.path.dirname(__file__) | |||||
config_path=os.path.join(curdir,"config.json") | |||||
gconf=None | |||||
if not os.path.exists(config_path): | |||||
gconf={"password":"","admin_users":[]} | |||||
with open(config_path,"w") as f: | |||||
json.dump(gconf,f,indent=4) | |||||
else: | |||||
with open(config_path,"r") as f: | |||||
gconf=json.load(f) | |||||
self.password = gconf["password"] | |||||
self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证 TODO: 用户名每次都会变,目前不可用 | |||||
self.isrunning = True # 机器人是否运行中 | |||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context | |||||
logger.info("[Godcmd] inited") | |||||
def on_handle_context(self, e_context: EventContext): | |||||
context_type = e_context['context'].type | |||||
if context_type != ContextType.TEXT: | |||||
if not self.isrunning: | |||||
e_context.action = EventAction.BREAK_PASS | |||||
return | |||||
content = e_context['context'].content | |||||
logger.debug("[Godcmd] on_handle_context. content: %s" % content) | |||||
if content.startswith("#"): | |||||
# msg = e_context['context']['msg'] | |||||
user = e_context['context']['receiver'] | |||||
session_id = e_context['context']['session_id'] | |||||
isgroup = e_context['context']['isgroup'] | |||||
bottype = Bridge().get_bot_type("chat") | |||||
bot = Bridge().get_bot("chat") | |||||
# 将命令和参数分割 | |||||
command_parts = content[1:].split(" ") | |||||
cmd = command_parts[0] | |||||
args = command_parts[1:] | |||||
isadmin=False | |||||
if user in self.admin_users: | |||||
isadmin=True | |||||
ok=False | |||||
result="string" | |||||
if any(cmd in info['alias'] for info in COMMANDS.values()): | |||||
cmd = next(c for c, info in COMMANDS.items() if cmd in info['alias']) | |||||
if cmd == "auth": | |||||
ok, result = self.authenticate(user, args, isadmin, isgroup) | |||||
elif cmd == "help": | |||||
ok, result = True, get_help_text(isadmin, isgroup) | |||||
elif cmd == "id": | |||||
ok, result = True, f"用户id=\n{user}" | |||||
elif cmd == "reset": | |||||
if bottype == "chatGPT": | |||||
bot.sessions.clear_session(session_id) | |||||
ok, result = True, "会话已重置" | |||||
else: | |||||
ok, result = False, "当前对话机器人不支持重置会话" | |||||
logger.debug("[Godcmd] command: %s by %s" % (cmd, user)) | |||||
elif any(cmd in info['alias'] for info in ADMIN_COMMANDS.values()): | |||||
if isadmin: | |||||
if isgroup: | |||||
ok, result = False, "群聊不可执行管理员指令" | |||||
else: | |||||
cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info['alias']) | |||||
if cmd == "stop": | |||||
self.isrunning = False | |||||
ok, result = True, "服务已暂停" | |||||
elif cmd == "resume": | |||||
self.isrunning = True | |||||
ok, result = True, "服务已恢复" | |||||
elif cmd == "reconf": | |||||
load_config() | |||||
ok, result = True, "配置已重载" | |||||
elif cmd == "resetall": | |||||
if bottype == "chatGPT": | |||||
bot.sessions.clear_all_session() | |||||
ok, result = True, "重置所有会话成功" | |||||
else: | |||||
ok, result = False, "当前对话机器人不支持重置会话" | |||||
elif cmd == "debug": | |||||
logger.setLevel('DEBUG') | |||||
ok, result = True, "DEBUG模式已开启" | |||||
elif cmd == "plist": | |||||
plugins = PluginManager().list_plugins() | |||||
ok = True | |||||
result = "插件列表:\n" | |||||
for name,plugincls in plugins.items(): | |||||
result += f"{plugincls.name}_v{plugincls.version} {plugincls.priority} - " | |||||
if plugincls.enabled: | |||||
result += "已启用\n" | |||||
else: | |||||
result += "未启用\n" | |||||
elif cmd == "scanp": | |||||
new_plugins = PluginManager().scan_plugins() | |||||
ok, result = True, "插件扫描完成" | |||||
PluginManager().activate_plugins() | |||||
if len(new_plugins) >0 : | |||||
result += "\n发现新插件:\n" | |||||
result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins]) | |||||
else : | |||||
result +=", 未发现新插件" | |||||
elif cmd == "setpri": | |||||
if len(args) != 2: | |||||
ok, result = False, "请提供插件名和优先级" | |||||
else: | |||||
ok = PluginManager().set_plugin_priority(args[0], int(args[1])) | |||||
if ok: | |||||
result = "插件" + args[0] + "优先级已设置为" + args[1] | |||||
else: | |||||
result = "插件不存在" | |||||
elif cmd == "reloadp": | |||||
if len(args) != 1: | |||||
ok, result = False, "请提供插件名" | |||||
else: | |||||
ok = PluginManager().reload_plugin(args[0]) | |||||
if ok: | |||||
result = "插件配置已重载" | |||||
else: | |||||
result = "插件不存在" | |||||
elif cmd == "enablep": | |||||
if len(args) != 1: | |||||
ok, result = False, "请提供插件名" | |||||
else: | |||||
ok = PluginManager().enable_plugin(args[0]) | |||||
if ok: | |||||
result = "插件已启用" | |||||
else: | |||||
result = "插件不存在" | |||||
elif cmd == "disablep": | |||||
if len(args) != 1: | |||||
ok, result = False, "请提供插件名" | |||||
else: | |||||
ok = PluginManager().disable_plugin(args[0]) | |||||
if ok: | |||||
result = "插件已禁用" | |||||
else: | |||||
result = "插件不存在" | |||||
logger.debug("[Godcmd] admin command: %s by %s" % (cmd, user)) | |||||
else: | |||||
ok, result = False, "需要管理员权限才能执行该指令" | |||||
else: | |||||
ok, result = False, f"未知指令:{cmd}\n查看指令列表请输入#help \n" | |||||
reply = Reply() | |||||
if ok: | |||||
reply.type = ReplyType.INFO | |||||
else: | |||||
reply.type = ReplyType.ERROR | |||||
reply.content = result | |||||
e_context['reply'] = reply | |||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | |||||
elif not self.isrunning: | |||||
e_context.action = EventAction.BREAK_PASS | |||||
def authenticate(self, userid, args, isadmin, isgroup) -> Tuple[bool,str] : | |||||
if isgroup: | |||||
return False,"请勿在群聊中认证" | |||||
if isadmin: | |||||
return False,"管理员账号无需认证" | |||||
if len(self.password) == 0: | |||||
return False,"未设置口令,无法认证" | |||||
if len(args) != 1: | |||||
return False,"请提供口令" | |||||
password = args[0] | |||||
if password == self.password: | |||||
self.admin_users.append(userid) | |||||
return True,"认证成功" | |||||
else: | |||||
return False,"认证失败" | |||||
@@ -0,0 +1,46 @@ | |||||
# encoding:utf-8 | |||||
from bridge.context import ContextType | |||||
from bridge.reply import Reply, ReplyType | |||||
import plugins | |||||
from plugins import * | |||||
from common.log import logger | |||||
@plugins.register(name="Hello", desc="A simple plugin that says hello", version="0.1", author="lanvent", desire_priority= -1) | |||||
class Hello(Plugin): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context | |||||
logger.info("[Hello] inited") | |||||
def on_handle_context(self, e_context: EventContext): | |||||
if e_context['context'].type != ContextType.TEXT: | |||||
return | |||||
content = e_context['context'].content | |||||
logger.debug("[Hello] on_handle_context. content: %s" % content) | |||||
if content == "Hello": | |||||
reply = Reply() | |||||
reply.type = ReplyType.TEXT | |||||
msg = e_context['context']['msg'] | |||||
if e_context['context']['isgroup']: | |||||
reply.content = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group") | |||||
else: | |||||
reply.content = "Hello, " + msg['User'].get('NickName', "My friend") | |||||
e_context['reply'] = reply | |||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | |||||
if content == "Hi": | |||||
reply = Reply() | |||||
reply.type = ReplyType.TEXT | |||||
reply.content = "Hi" | |||||
e_context['reply'] = reply | |||||
e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑,一般会覆写reply | |||||
if content == "End": | |||||
# 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World" | |||||
e_context['context'].type = "IMAGE_CREATE" | |||||
content = "The World" | |||||
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 |
@@ -0,0 +1,3 @@ | |||||
class Plugin: | |||||
def __init__(self): | |||||
self.handlers = {} |
@@ -0,0 +1,171 @@ | |||||
# encoding:utf-8 | |||||
import importlib | |||||
import json | |||||
import os | |||||
from common.singleton import singleton | |||||
from common.sorted_dict import SortedDict | |||||
from .event import * | |||||
from .plugin import * | |||||
from common.log import logger | |||||
@singleton | |||||
class PluginManager: | |||||
def __init__(self): | |||||
self.plugins = SortedDict(lambda k,v: v.priority,reverse=True) | |||||
self.listening_plugins = {} | |||||
self.instances = {} | |||||
self.pconf = {} | |||||
def register(self, name: str, desc: str, version: str, author: str, desire_priority: int = 0): | |||||
def wrapper(plugincls): | |||||
plugincls.name = name | |||||
plugincls.desc = desc | |||||
plugincls.version = version | |||||
plugincls.author = author | |||||
plugincls.priority = desire_priority | |||||
plugincls.enabled = True | |||||
self.plugins[name.upper()] = plugincls | |||||
logger.info("Plugin %s_v%s registered" % (name, version)) | |||||
return plugincls | |||||
return wrapper | |||||
def save_config(self): | |||||
with open("plugins/plugins.json", "w", encoding="utf-8") as f: | |||||
json.dump(self.pconf, f, indent=4, ensure_ascii=False) | |||||
def load_config(self): | |||||
logger.info("Loading plugins config...") | |||||
modified = False | |||||
if os.path.exists("plugins/plugins.json"): | |||||
with open("plugins/plugins.json", "r", encoding="utf-8") as f: | |||||
pconf = json.load(f) | |||||
pconf['plugins'] = SortedDict(lambda k,v: v["priority"],pconf['plugins'],reverse=True) | |||||
else: | |||||
modified = True | |||||
pconf = {"plugins": SortedDict(lambda k,v: v["priority"],reverse=True)} | |||||
self.pconf = pconf | |||||
if modified: | |||||
self.save_config() | |||||
return pconf | |||||
def scan_plugins(self): | |||||
logger.info("Scaning plugins ...") | |||||
plugins_dir = "plugins" | |||||
for plugin_name in os.listdir(plugins_dir): | |||||
plugin_path = os.path.join(plugins_dir, plugin_name) | |||||
if os.path.isdir(plugin_path): | |||||
# 判断插件是否包含同名.py文件 | |||||
main_module_path = os.path.join(plugin_path, plugin_name+".py") | |||||
if os.path.isfile(main_module_path): | |||||
# 导入插件 | |||||
import_path = "{}.{}.{}".format(plugins_dir, plugin_name, plugin_name) | |||||
main_module = importlib.import_module(import_path) | |||||
pconf = self.pconf | |||||
new_plugins = [] | |||||
modified = False | |||||
for name, plugincls in self.plugins.items(): | |||||
rawname = plugincls.name | |||||
if rawname not in pconf["plugins"]: | |||||
new_plugins.append(plugincls) | |||||
modified = True | |||||
logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name) | |||||
pconf["plugins"][rawname] = {"enabled": plugincls.enabled, "priority": plugincls.priority} | |||||
else: | |||||
self.plugins[name].enabled = pconf["plugins"][rawname]["enabled"] | |||||
self.plugins[name].priority = pconf["plugins"][rawname]["priority"] | |||||
self.plugins._update_heap(name) # 更新下plugins中的顺序 | |||||
if modified: | |||||
self.save_config() | |||||
return new_plugins | |||||
def refresh_order(self): | |||||
for event in self.listening_plugins.keys(): | |||||
self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True) | |||||
def activate_plugins(self): # 生成新开启的插件实例 | |||||
for name, plugincls in self.plugins.items(): | |||||
if plugincls.enabled: | |||||
if name not in self.instances: | |||||
instance = plugincls() | |||||
self.instances[name] = instance | |||||
for event in instance.handlers: | |||||
if event not in self.listening_plugins: | |||||
self.listening_plugins[event] = [] | |||||
self.listening_plugins[event].append(name) | |||||
self.refresh_order() | |||||
def reload_plugin(self, name:str): | |||||
name = name.upper() | |||||
if name in self.instances: | |||||
for event in self.listening_plugins: | |||||
if name in self.listening_plugins[event]: | |||||
self.listening_plugins[event].remove(name) | |||||
del self.instances[name] | |||||
self.activate_plugins() | |||||
return True | |||||
return False | |||||
def load_plugins(self): | |||||
self.load_config() | |||||
self.scan_plugins() | |||||
pconf = self.pconf | |||||
logger.debug("plugins.json config={}".format(pconf)) | |||||
for name,plugin in pconf["plugins"].items(): | |||||
if name.upper() not in self.plugins: | |||||
logger.error("Plugin %s not found, but found in plugins.json" % name) | |||||
self.activate_plugins() | |||||
def emit_event(self, e_context: EventContext, *args, **kwargs): | |||||
if e_context.event in self.listening_plugins: | |||||
for name in self.listening_plugins[e_context.event]: | |||||
if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE: | |||||
logger.debug("Plugin %s triggered by event %s" % (name,e_context.event)) | |||||
instance = self.instances[name] | |||||
instance.handlers[e_context.event](e_context, *args, **kwargs) | |||||
return e_context | |||||
def set_plugin_priority(self, name:str, priority:int): | |||||
name = name.upper() | |||||
if name not in self.plugins: | |||||
return False | |||||
if self.plugins[name].priority == priority: | |||||
return True | |||||
self.plugins[name].priority = priority | |||||
self.plugins._update_heap(name) | |||||
rawname = self.plugins[name].name | |||||
self.pconf["plugins"][rawname]["priority"] = priority | |||||
self.pconf["plugins"]._update_heap(rawname) | |||||
self.save_config() | |||||
self.refresh_order() | |||||
return True | |||||
def enable_plugin(self, name:str): | |||||
name = name.upper() | |||||
if name not in self.plugins: | |||||
return False | |||||
if not self.plugins[name].enabled : | |||||
self.plugins[name].enabled = True | |||||
rawname = self.plugins[name].name | |||||
self.pconf["plugins"][rawname]["enabled"] = True | |||||
self.save_config() | |||||
self.activate_plugins() | |||||
return True | |||||
return True | |||||
def disable_plugin(self, name:str): | |||||
name = name.upper() | |||||
if name not in self.plugins: | |||||
return False | |||||
if self.plugins[name].enabled : | |||||
self.plugins[name].enabled = False | |||||
rawname = self.plugins[name].name | |||||
self.pconf["plugins"][rawname]["enabled"] = False | |||||
self.save_config() | |||||
return True | |||||
return True | |||||
def list_plugins(self): | |||||
return self.plugins |
@@ -0,0 +1,70 @@ | |||||
{ | |||||
"start":{ | |||||
"host" : "127.0.0.1", | |||||
"port" : 7860 | |||||
}, | |||||
"defaults": { | |||||
"params": { | |||||
"sampler_name": "DPM++ 2M Karras", | |||||
"steps": 20, | |||||
"width": 512, | |||||
"height": 512, | |||||
"cfg_scale": 7, | |||||
"prompt":"masterpiece, best quality", | |||||
"negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)", | |||||
"enable_hr": false, | |||||
"hr_scale": 2, | |||||
"hr_upscaler": "Latent", | |||||
"hr_second_pass_steps": 15, | |||||
"denoising_strength": 0.7 | |||||
}, | |||||
"options": { | |||||
"sd_model_checkpoint": "perfectWorld_v2Baked" | |||||
} | |||||
}, | |||||
"rules": [ | |||||
{ | |||||
"keywords": [ | |||||
"横版", | |||||
"壁纸" | |||||
], | |||||
"params": { | |||||
"width": 640, | |||||
"height": 384 | |||||
}, | |||||
"desc": "分辨率会变成640x384" | |||||
}, | |||||
{ | |||||
"keywords": [ | |||||
"竖版" | |||||
], | |||||
"params": { | |||||
"width": 384, | |||||
"height": 640 | |||||
} | |||||
}, | |||||
{ | |||||
"keywords": [ | |||||
"高清" | |||||
], | |||||
"params": { | |||||
"enable_hr": true, | |||||
"hr_scale": 1.6 | |||||
}, | |||||
"desc": "出图分辨率长宽都会提高1.6倍" | |||||
}, | |||||
{ | |||||
"keywords": [ | |||||
"二次元" | |||||
], | |||||
"params": { | |||||
"negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)", | |||||
"prompt": "masterpiece, best quality" | |||||
}, | |||||
"options": { | |||||
"sd_model_checkpoint": "meinamix_meinaV8" | |||||
}, | |||||
"desc": "使用二次元风格模型出图" | |||||
} | |||||
] | |||||
} |
@@ -0,0 +1,69 @@ | |||||
### 插件描述 | |||||
本插件用于将画图请求转发给stable diffusion webui。 | |||||
### 环境要求 | |||||
使用前先安装stable diffusion webui,并在它的启动参数中添加 "--api"。 | |||||
具体信息,请参考[文章](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API)。 | |||||
请**安装**本插件的依赖包```webuiapi``` | |||||
``` | |||||
```pip install webuiapi``` | |||||
``` | |||||
### 使用说明 | |||||
请将`config.json.template`复制为`config.json`,并修改其中的参数和规则。 | |||||
#### 画图请求格式 | |||||
用户的画图请求格式为: | |||||
``` | |||||
<画图触发词><关键词1> <关键词2> ... <关键词n>:<prompt> | |||||
``` | |||||
- 本插件会对画图触发词后的关键词进行逐个匹配,如果触发了规则中的关键词,则会在画图请求中重载对应的参数。 | |||||
- 规则的匹配顺序参考`config.json`中的顺序,每个关键词最多被匹配到1次,如果多个关键词触发了重复的参数,重复参数以最后一个关键词为准: | |||||
- 关键词中包含`help`或`帮助`,会打印出帮助文档。 | |||||
第一个"**:**"号之后的内容会作为附加的**prompt**,接在最终的prompt后 | |||||
例如: 画横版 高清 二次元:cat | |||||
会触发三个关键词 "横版", "高清", "二次元",prompt为"cat" | |||||
若默认参数是: | |||||
``` | |||||
"width": 512, | |||||
"height": 512, | |||||
"enable_hr": false, | |||||
"prompt": "8k" | |||||
"negative_prompt": "nsfw", | |||||
"sd_model_checkpoint": "perfectWorld_v2Baked" | |||||
``` | |||||
"横版"触发的规则参数为: | |||||
``` | |||||
"width": 640, | |||||
"height": 384, | |||||
``` | |||||
"高清"触发的规则参数为: | |||||
``` | |||||
"enable_hr": true, | |||||
"hr_scale": 1.6, | |||||
``` | |||||
"二次元"触发的规则参数为: | |||||
``` | |||||
"negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)", | |||||
"steps": 20, | |||||
"prompt": "masterpiece, best quality", | |||||
"sd_model_checkpoint": "meinamix_meinaV8" | |||||
``` | |||||
最后将第一个":"后的内容cat连接在prompt后,得到最终参数为: | |||||
``` | |||||
"width": 640, | |||||
"height": 384, | |||||
"enable_hr": true, | |||||
"hr_scale": 1.6, | |||||
"negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)", | |||||
"steps": 20, | |||||
"prompt": "masterpiece, best quality, cat", | |||||
"sd_model_checkpoint": "meinamix_meinaV8" | |||||
``` | |||||
PS: 参数分为两部分: | |||||
- 一部分是params,为画画的参数;参数名**必须**与webuiapi包中[txt2img api](https://github.com/mix1009/sdwebuiapi/blob/fb2054e149c0a4e25125c0cd7e7dca06bda839d4/webuiapi/webuiapi.py#L163)的参数名一致 | |||||
- 另一部分是options,指sdwebui的设置,使用的模型和vae需要写在里面。它和http://127.0.0.1:7860/sdapi/v1/options所返回的键一致。 |
@@ -0,0 +1,114 @@ | |||||
# encoding:utf-8 | |||||
import json | |||||
import os | |||||
from bridge.context import ContextType | |||||
from bridge.reply import Reply, ReplyType | |||||
from config import conf | |||||
import plugins | |||||
from plugins import * | |||||
from common.log import logger | |||||
import webuiapi | |||||
import io | |||||
@plugins.register(name="sdwebui", desc="利用stable-diffusion webui来画图", version="2.0", author="lanvent") | |||||
class SDWebUI(Plugin): | |||||
def __init__(self): | |||||
super().__init__() | |||||
curdir = os.path.dirname(__file__) | |||||
config_path = os.path.join(curdir, "config.json") | |||||
try: | |||||
with open(config_path, "r", encoding="utf-8") as f: | |||||
config = json.load(f) | |||||
self.rules = config["rules"] | |||||
defaults = config["defaults"] | |||||
self.default_params = defaults["params"] | |||||
self.default_options = defaults["options"] | |||||
self.start_args = config["start"] | |||||
self.api = webuiapi.WebUIApi(**self.start_args) | |||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context | |||||
logger.info("[SD] inited") | |||||
except FileNotFoundError: | |||||
logger.error(f"[SD] init failed, {config_path} not found") | |||||
except Exception as e: | |||||
logger.error("[SD] init failed, exception: %s" % e) | |||||
def on_handle_context(self, e_context: EventContext): | |||||
if e_context['context'].type != ContextType.IMAGE_CREATE: | |||||
return | |||||
logger.debug("[SD] on_handle_context. content: %s" %e_context['context'].content) | |||||
logger.info("[SD] image_query={}".format(e_context['context'].content)) | |||||
reply = Reply() | |||||
try: | |||||
content = e_context['context'].content[:] | |||||
# 解析用户输入 如"横版 高清 二次元:cat" | |||||
if ":" in content: | |||||
keywords, prompt = content.split(":", 1) | |||||
else: | |||||
keywords = content | |||||
prompt = "" | |||||
keywords = keywords.split() | |||||
if "help" in keywords or "帮助" in keywords: | |||||
reply.type = ReplyType.INFO | |||||
reply.content = self.get_help_text() | |||||
else: | |||||
rule_params = {} | |||||
rule_options = {} | |||||
for keyword in keywords: | |||||
matched = False | |||||
for rule in self.rules: | |||||
if keyword in rule["keywords"]: | |||||
for key in rule["params"]: | |||||
rule_params[key] = rule["params"][key] | |||||
if "options" in rule: | |||||
for key in rule["options"]: | |||||
rule_options[key] = rule["options"][key] | |||||
matched = True | |||||
break # 一个关键词只匹配一个规则 | |||||
if not matched: | |||||
logger.warning("[SD] keyword not matched: %s" % keyword) | |||||
params = {**self.default_params, **rule_params} | |||||
options = {**self.default_options, **rule_options} | |||||
params["prompt"] = params.get("prompt", "")+f", {prompt}" | |||||
if len(options) > 0: | |||||
logger.info("[SD] cover options={}".format(options)) | |||||
self.api.set_options(options) | |||||
logger.info("[SD] params={}".format(params)) | |||||
result = self.api.txt2img( | |||||
**params | |||||
) | |||||
reply.type = ReplyType.IMAGE | |||||
b_img = io.BytesIO() | |||||
result.image.save(b_img, format="PNG") | |||||
reply.content = b_img | |||||
e_context.action = EventAction.BREAK_PASS # 事件结束后,跳过处理context的默认逻辑 | |||||
except Exception as e: | |||||
reply.type = ReplyType.ERROR | |||||
reply.content = "[SD] "+str(e) | |||||
logger.error("[SD] exception: %s" % e) | |||||
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 | |||||
finally: | |||||
e_context['reply'] = reply | |||||
def get_help_text(self): | |||||
if not conf().get('image_create_prefix'): | |||||
return "画图功能未启用" | |||||
else: | |||||
trigger = conf()['image_create_prefix'][0] | |||||
help_text = f"请使用<{trigger}[关键词1] [关键词2]...:提示语>的格式作画,如\"{trigger}横版 高清:cat\"\n" | |||||
help_text += "目前可用关键词:\n" | |||||
for rule in self.rules: | |||||
keywords = [f"[{keyword}]" for keyword in rule['keywords']] | |||||
help_text += f"{','.join(keywords)}" | |||||
if "desc" in rule: | |||||
help_text += f"-{rule['desc']}\n" | |||||
else: | |||||
help_text += "\n" | |||||
return help_text |
@@ -4,6 +4,7 @@ baidu voice service | |||||
""" | """ | ||||
import time | import time | ||||
from aip import AipSpeech | from aip import AipSpeech | ||||
from bridge.reply import Reply, ReplyType | |||||
from common.log import logger | from common.log import logger | ||||
from common.tmp_dir import TmpDir | from common.tmp_dir import TmpDir | ||||
from voice.voice import Voice | from voice.voice import Voice | ||||
@@ -30,7 +31,8 @@ class BaiduVoice(Voice): | |||||
with open(fileName, 'wb') as f: | with open(fileName, 'wb') as f: | ||||
f.write(result) | f.write(result) | ||||
logger.info('[Baidu] textToVoice text={} voice file name={}'.format(text, fileName)) | logger.info('[Baidu] textToVoice text={} voice file name={}'.format(text, fileName)) | ||||
return fileName | |||||
reply = Reply(ReplyType.VOICE, fileName) | |||||
else: | else: | ||||
logger.error('[Baidu] textToVoice error={}'.format(result)) | logger.error('[Baidu] textToVoice error={}'.format(result)) | ||||
return None | |||||
reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败") | |||||
return reply |
@@ -6,6 +6,7 @@ google voice service | |||||
import pathlib | import pathlib | ||||
import subprocess | import subprocess | ||||
import time | import time | ||||
from bridge.reply import Reply, ReplyType | |||||
import speech_recognition | import speech_recognition | ||||
import pyttsx3 | import pyttsx3 | ||||
from common.log import logger | from common.log import logger | ||||
@@ -36,16 +37,22 @@ class GoogleVoice(Voice): | |||||
text = self.recognizer.recognize_google(audio, language='zh-CN') | text = self.recognizer.recognize_google(audio, language='zh-CN') | ||||
logger.info( | logger.info( | ||||
'[Google] voiceToText text={} voice file name={}'.format(text, voice_file)) | '[Google] voiceToText text={} voice file name={}'.format(text, voice_file)) | ||||
return text | |||||
reply = Reply(ReplyType.TEXT, text) | |||||
except speech_recognition.UnknownValueError: | except speech_recognition.UnknownValueError: | ||||
return "抱歉,我听不懂。" | |||||
reply = Reply(ReplyType.ERROR, "抱歉,我听不懂") | |||||
except speech_recognition.RequestError as e: | except speech_recognition.RequestError as e: | ||||
return "抱歉,无法连接到 Google 语音识别服务;{0}".format(e) | |||||
reply = Reply(ReplyType.ERROR, "抱歉,无法连接到 Google 语音识别服务;{0}".format(e)) | |||||
finally: | |||||
return reply | |||||
def textToVoice(self, text): | def textToVoice(self, text): | ||||
textFile = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3' | |||||
self.engine.save_to_file(text, textFile) | |||||
self.engine.runAndWait() | |||||
logger.info( | |||||
'[Google] textToVoice text={} voice file name={}'.format(text, textFile)) | |||||
return textFile | |||||
try: | |||||
textFile = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3' | |||||
self.engine.save_to_file(text, textFile) | |||||
self.engine.runAndWait() | |||||
logger.info( | |||||
'[Google] textToVoice text={} voice file name={}'.format(text, textFile)) | |||||
reply = Reply(ReplyType.VOICE, textFile) | |||||
except Exception as e: | |||||
reply = Reply(ReplyType.ERROR, str(e)) | |||||
finally: | |||||
return reply |
@@ -4,6 +4,7 @@ google voice service | |||||
""" | """ | ||||
import json | import json | ||||
import openai | import openai | ||||
from bridge.reply import Reply, ReplyType | |||||
from config import conf | from config import conf | ||||
from common.log import logger | from common.log import logger | ||||
from voice.voice import Voice | from voice.voice import Voice | ||||
@@ -16,12 +17,17 @@ class OpenaiVoice(Voice): | |||||
def voiceToText(self, voice_file): | def voiceToText(self, voice_file): | ||||
logger.debug( | logger.debug( | ||||
'[Openai] voice file name={}'.format(voice_file)) | '[Openai] voice file name={}'.format(voice_file)) | ||||
file = open(voice_file, "rb") | |||||
reply = openai.Audio.transcribe("whisper-1", file) | |||||
text = reply["text"] | |||||
logger.info( | |||||
'[Openai] voiceToText text={} voice file name={}'.format(text, voice_file)) | |||||
return text | |||||
try: | |||||
file = open(voice_file, "rb") | |||||
result = openai.Audio.transcribe("whisper-1", file) | |||||
text = result["text"] | |||||
reply = Reply(ReplyType.TEXT, text) | |||||
logger.info( | |||||
'[Openai] voiceToText text={} voice file name={}'.format(text, voice_file)) | |||||
except Exception as e: | |||||
reply = Reply(ReplyType.ERROR, str(e)) | |||||
finally: | |||||
return reply | |||||
def textToVoice(self, text): | def textToVoice(self, text): | ||||
pass | pass |