简易支持插件,添加sdwebui(novelai画图), godcmd(管理员指令增强)插件,Banwords(敏感词过滤)插件develop
@@ -7,3 +7,4 @@ config.json | |||
QR.png | |||
nohup.out | |||
tmp | |||
plugins.json |
@@ -4,14 +4,17 @@ import config | |||
from channel import channel_factory | |||
from common.log import logger | |||
from plugins import * | |||
if __name__ == '__main__': | |||
try: | |||
# load config | |||
config.load_config() | |||
# create channel | |||
channel = channel_factory.create_channel("wx") | |||
channel_name='wx' | |||
channel = channel_factory.create_channel(channel_name) | |||
if channel_name=='wx': | |||
PluginManager().load_plugins() | |||
# startup channel | |||
channel.startup() | |||
@@ -2,6 +2,7 @@ | |||
import requests | |||
from bot.bot import Bot | |||
from bridge.reply import Reply, ReplyType | |||
# Baidu Unit对话接口 (可用, 但能力较弱) | |||
@@ -14,7 +15,8 @@ class BaiduUnitBot(Bot): | |||
headers = {'content-type': 'application/x-www-form-urlencoded'} | |||
response = requests.post(url, data=post_data.encode(), headers=headers) | |||
if response: | |||
return response.json()['result']['context']['SYS_PRESUMED_HIST'][1] | |||
reply = Reply(ReplyType.TEXT, response.json()['result']['context']['SYS_PRESUMED_HIST'][1]) | |||
return reply | |||
def get_token(self): | |||
access_key = 'YOUR_ACCESS_KEY' | |||
@@ -3,8 +3,12 @@ Auto-replay chat robot abstract class | |||
""" | |||
from bridge.context import Context | |||
from bridge.reply import Reply | |||
class Bot(object): | |||
def reply(self, query, context=None): | |||
def reply(self, query, context : Context =None) -> Reply: | |||
""" | |||
bot auto-reply content | |||
:param req: received message | |||
@@ -1,41 +1,42 @@ | |||
# encoding:utf-8 | |||
from bot.bot import Bot | |||
from bridge.context import ContextType | |||
from bridge.reply import Reply, ReplyType | |||
from config import conf, load_config | |||
from common.log import logger | |||
from common.expired_dict import ExpiredDict | |||
import openai | |||
import time | |||
if conf().get('expires_in_seconds'): | |||
all_sessions = ExpiredDict(conf().get('expires_in_seconds')) | |||
else: | |||
all_sessions = dict() | |||
# OpenAI对话模型API (可用) | |||
class ChatGPTBot(Bot): | |||
def __init__(self): | |||
openai.api_key = conf().get('open_ai_api_key') | |||
proxy = conf().get('proxy') | |||
self.sessions = SessionManager() | |||
if proxy: | |||
openai.proxy = proxy | |||
def reply(self, query, context=None): | |||
# acquire reply content | |||
if not context or not context.get('type') or context.get('type') == 'TEXT': | |||
if context.type == ContextType.TEXT: | |||
logger.info("[OPEN_AI] query={}".format(query)) | |||
session_id = context.get('session_id') or context.get('from_user_id') | |||
session_id = context['session_id'] | |||
reply = None | |||
if query == '#清除记忆': | |||
Session.clear_session(session_id) | |||
return '记忆已清除' | |||
self.sessions.clear_session(session_id) | |||
reply = Reply(ReplyType.INFO, '记忆已清除') | |||
elif query == '#清除所有': | |||
Session.clear_all_session() | |||
return '所有人记忆已清除' | |||
self.sessions.clear_all_session() | |||
reply = Reply(ReplyType.INFO, '所有人记忆已清除') | |||
elif query == '#更新配置': | |||
load_config() | |||
return '配置已更新' | |||
session = Session.build_session_query(query, session_id) | |||
reply = Reply(ReplyType.INFO, '配置已更新') | |||
if reply: | |||
return reply | |||
session = self.sessions.build_session_query(query, session_id) | |||
logger.debug("[OPEN_AI] session query={}".format(session)) | |||
# if context.get('stream'): | |||
@@ -44,14 +45,29 @@ class ChatGPTBot(Bot): | |||
reply_content = self.reply_text(session, session_id, 0) | |||
logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"])) | |||
if reply_content["completion_tokens"] > 0: | |||
Session.save_session(reply_content["content"], session_id, reply_content["total_tokens"]) | |||
return reply_content["content"] | |||
elif context.get('type', None) == 'IMAGE_CREATE': | |||
return self.create_img(query, 0) | |||
if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0: | |||
reply = Reply(ReplyType.ERROR, reply_content['content']) | |||
elif reply_content["completion_tokens"] > 0: | |||
self.sessions.save_session(reply_content["content"], session_id, reply_content["total_tokens"]) | |||
reply = Reply(ReplyType.TEXT, reply_content["content"]) | |||
else: | |||
reply = Reply(ReplyType.ERROR, reply_content['content']) | |||
logger.debug("[OPEN_AI] reply {} used 0 tokens.".format(reply_content)) | |||
return reply | |||
elif context.type == ContextType.IMAGE_CREATE: | |||
ok, retstring = self.create_img(query, 0) | |||
reply = None | |||
if ok: | |||
reply = Reply(ReplyType.IMAGE_URL, retstring) | |||
else: | |||
reply = Reply(ReplyType.ERROR, retstring) | |||
return reply | |||
else: | |||
reply = Reply(ReplyType.ERROR, 'Bot不支持处理{}类型的消息'.format(context.type)) | |||
return reply | |||
def reply_text(self, session, session_id, retry_count=0) ->dict: | |||
def reply_text(self, session, session_id, retry_count=0) -> dict: | |||
''' | |||
call openai's ChatCompletion to get the answer | |||
:param session: a conversation session | |||
@@ -70,8 +86,8 @@ class ChatGPTBot(Bot): | |||
presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容 | |||
) | |||
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) | |||
return {"total_tokens": response["usage"]["total_tokens"], | |||
"completion_tokens": response["usage"]["completion_tokens"], | |||
return {"total_tokens": response["usage"]["total_tokens"], | |||
"completion_tokens": response["usage"]["completion_tokens"], | |||
"content": response.choices[0]['message']['content']} | |||
except openai.error.RateLimitError as e: | |||
# rate limit exception | |||
@@ -86,15 +102,15 @@ class ChatGPTBot(Bot): | |||
# api connection exception | |||
logger.warn(e) | |||
logger.warn("[OPEN_AI] APIConnection failed") | |||
return {"completion_tokens": 0, "content":"我连接不到你的网络"} | |||
return {"completion_tokens": 0, "content": "我连接不到你的网络"} | |||
except openai.error.Timeout as e: | |||
logger.warn(e) | |||
logger.warn("[OPEN_AI] Timeout") | |||
return {"completion_tokens": 0, "content":"我没有收到你的消息"} | |||
return {"completion_tokens": 0, "content": "我没有收到你的消息"} | |||
except Exception as e: | |||
# unknown exception | |||
logger.exception(e) | |||
Session.clear_session(session_id) | |||
self.sessions.clear_session(session_id) | |||
return {"completion_tokens": 0, "content": "请再问我一次吧"} | |||
def create_img(self, query, retry_count=0): | |||
@@ -107,7 +123,7 @@ class ChatGPTBot(Bot): | |||
) | |||
image_url = response['data'][0]['url'] | |||
logger.info("[OPEN_AI] image_url={}".format(image_url)) | |||
return image_url | |||
return True, image_url | |||
except openai.error.RateLimitError as e: | |||
logger.warn(e) | |||
if retry_count < 1: | |||
@@ -115,14 +131,21 @@ class ChatGPTBot(Bot): | |||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1)) | |||
return self.create_img(query, retry_count+1) | |||
else: | |||
return "提问太快啦,请休息一下再问我吧" | |||
return False, "提问太快啦,请休息一下再问我吧" | |||
except Exception as e: | |||
logger.exception(e) | |||
return None | |||
return False, str(e) | |||
class SessionManager(object): | |||
def __init__(self): | |||
if conf().get('expires_in_seconds'): | |||
sessions = ExpiredDict(conf().get('expires_in_seconds')) | |||
else: | |||
sessions = dict() | |||
self.sessions = sessions | |||
class Session(object): | |||
@staticmethod | |||
def build_session_query(query, session_id): | |||
def build_session_query(self, query, session_id): | |||
''' | |||
build query with conversation history | |||
e.g. [ | |||
@@ -135,36 +158,33 @@ class Session(object): | |||
:param session_id: session id | |||
:return: query content with conversaction | |||
''' | |||
session = all_sessions.get(session_id, []) | |||
session = self.sessions.get(session_id, []) | |||
if len(session) == 0: | |||
system_prompt = conf().get("character_desc", "") | |||
system_item = {'role': 'system', 'content': system_prompt} | |||
session.append(system_item) | |||
all_sessions[session_id] = session | |||
self.sessions[session_id] = session | |||
user_item = {'role': 'user', 'content': query} | |||
session.append(user_item) | |||
return session | |||
@staticmethod | |||
def save_session(answer, session_id, total_tokens): | |||
def save_session(self, answer, session_id, total_tokens): | |||
max_tokens = conf().get("conversation_max_tokens") | |||
if not max_tokens: | |||
# default 3000 | |||
max_tokens = 1000 | |||
max_tokens=int(max_tokens) | |||
max_tokens = int(max_tokens) | |||
session = all_sessions.get(session_id) | |||
session = self.sessions.get(session_id) | |||
if session: | |||
# append conversation | |||
gpt_item = {'role': 'assistant', 'content': answer} | |||
session.append(gpt_item) | |||
# discard exceed limit conversation | |||
Session.discard_exceed_conversation(session, max_tokens, total_tokens) | |||
self.discard_exceed_conversation(session, max_tokens, total_tokens) | |||
@staticmethod | |||
def discard_exceed_conversation(session, max_tokens, total_tokens): | |||
def discard_exceed_conversation(self, session, max_tokens, total_tokens): | |||
dec_tokens = int(total_tokens) | |||
# logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens)) | |||
while dec_tokens > max_tokens: | |||
@@ -173,13 +193,11 @@ class Session(object): | |||
session.pop(1) | |||
session.pop(1) | |||
else: | |||
break | |||
break | |||
dec_tokens = dec_tokens - max_tokens | |||
@staticmethod | |||
def clear_session(session_id): | |||
all_sessions[session_id] = [] | |||
def clear_session(self, session_id): | |||
self.sessions[session_id] = [] | |||
@staticmethod | |||
def clear_all_session(): | |||
all_sessions.clear() | |||
def clear_all_session(self): | |||
self.sessions.clear() |
@@ -1,6 +1,8 @@ | |||
# encoding:utf-8 | |||
from bot.bot import Bot | |||
from bridge.context import ContextType | |||
from bridge.reply import Reply, ReplyType | |||
from config import conf | |||
from common.log import logger | |||
import openai | |||
@@ -13,30 +15,31 @@ class OpenAIBot(Bot): | |||
def __init__(self): | |||
openai.api_key = conf().get('open_ai_api_key') | |||
def reply(self, query, context=None): | |||
# acquire reply content | |||
if not context or not context.get('type') or context.get('type') == 'TEXT': | |||
logger.info("[OPEN_AI] query={}".format(query)) | |||
from_user_id = context.get('from_user_id') or context.get('session_id') | |||
if query == '#清除记忆': | |||
Session.clear_session(from_user_id) | |||
return '记忆已清除' | |||
elif query == '#清除所有': | |||
Session.clear_all_session() | |||
return '所有人记忆已清除' | |||
new_query = Session.build_session_query(query, from_user_id) | |||
logger.debug("[OPEN_AI] session query={}".format(new_query)) | |||
reply_content = self.reply_text(new_query, from_user_id, 0) | |||
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content)) | |||
if reply_content and query: | |||
Session.save_session(query, reply_content, from_user_id) | |||
return reply_content | |||
elif context.get('type', None) == 'IMAGE_CREATE': | |||
return self.create_img(query, 0) | |||
if context and context.type: | |||
if context.type == ContextType.TEXT: | |||
logger.info("[OPEN_AI] query={}".format(query)) | |||
from_user_id = context['session_id'] | |||
reply = None | |||
if query == '#清除记忆': | |||
Session.clear_session(from_user_id) | |||
reply = Reply(ReplyType.INFO, '记忆已清除') | |||
elif query == '#清除所有': | |||
Session.clear_all_session() | |||
reply = Reply(ReplyType.INFO, '所有人记忆已清除') | |||
else: | |||
new_query = Session.build_session_query(query, from_user_id) | |||
logger.debug("[OPEN_AI] session query={}".format(new_query)) | |||
reply_content = self.reply_text(new_query, from_user_id, 0) | |||
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content)) | |||
if reply_content and query: | |||
Session.save_session(query, reply_content, from_user_id) | |||
reply = Reply(ReplyType.TEXT, reply_content) | |||
return reply | |||
elif context.type == ContextType.IMAGE_CREATE: | |||
return self.create_img(query, 0) | |||
def reply_text(self, query, user_id, retry_count=0): | |||
try: | |||
@@ -1,16 +1,42 @@ | |||
from bridge.context import Context | |||
from bridge.reply import Reply | |||
from common.log import logger | |||
from bot import bot_factory | |||
from common.singleton import singleton | |||
from voice import voice_factory | |||
@singleton | |||
class Bridge(object): | |||
def __init__(self): | |||
pass | |||
self.btype={ | |||
"chat": "chatGPT", | |||
"voice_to_text": "openai", | |||
"text_to_voice": "baidu" | |||
} | |||
self.bots={} | |||
def fetch_reply_content(self, query, context): | |||
return bot_factory.create_bot("chatGPT").reply(query, context) | |||
def get_bot(self,typename): | |||
if self.bots.get(typename) is None: | |||
logger.info("create bot {} for {}".format(self.btype[typename],typename)) | |||
if typename == "text_to_voice": | |||
self.bots[typename] = voice_factory.create_voice(self.btype[typename]) | |||
elif typename == "voice_to_text": | |||
self.bots[typename] = voice_factory.create_voice(self.btype[typename]) | |||
elif typename == "chat": | |||
self.bots[typename] = bot_factory.create_bot(self.btype[typename]) | |||
return self.bots[typename] | |||
def get_bot_type(self,typename): | |||
return self.btype[typename] | |||
def fetch_voice_to_text(self, voiceFile): | |||
return voice_factory.create_voice("openai").voiceToText(voiceFile) | |||
def fetch_text_to_voice(self, text): | |||
return voice_factory.create_voice("baidu").textToVoice(text) | |||
def fetch_reply_content(self, query, context : Context) -> Reply: | |||
return self.get_bot("chat").reply(query, context) | |||
def fetch_voice_to_text(self, voiceFile) -> Reply: | |||
return self.get_bot("voice_to_text").voiceToText(voiceFile) | |||
def fetch_text_to_voice(self, text) -> Reply: | |||
return self.get_bot("text_to_voice").textToVoice(text) | |||
@@ -0,0 +1,42 @@ | |||
# encoding:utf-8 | |||
from enum import Enum | |||
class ContextType (Enum): | |||
TEXT = 1 # 文本消息 | |||
VOICE = 2 # 音频消息 | |||
IMAGE_CREATE = 3 # 创建图片命令 | |||
def __str__(self): | |||
return self.name | |||
class Context: | |||
def __init__(self, type : ContextType = None , content = None, kwargs = dict()): | |||
self.type = type | |||
self.content = content | |||
self.kwargs = kwargs | |||
def __getitem__(self, key): | |||
if key == 'type': | |||
return self.type | |||
elif key == 'content': | |||
return self.content | |||
else: | |||
return self.kwargs[key] | |||
def __setitem__(self, key, value): | |||
if key == 'type': | |||
self.type = value | |||
elif key == 'content': | |||
self.content = value | |||
else: | |||
self.kwargs[key] = value | |||
def __delitem__(self, key): | |||
if key == 'type': | |||
self.type = None | |||
elif key == 'content': | |||
self.content = None | |||
else: | |||
del self.kwargs[key] | |||
def __str__(self): | |||
return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs) |
@@ -0,0 +1,22 @@ | |||
# encoding:utf-8 | |||
from enum import Enum | |||
class ReplyType(Enum): | |||
TEXT = 1 # 文本 | |||
VOICE = 2 # 音频文件 | |||
IMAGE = 3 # 图片文件 | |||
IMAGE_URL = 4 # 图片URL | |||
INFO = 9 | |||
ERROR = 10 | |||
def __str__(self): | |||
return self.name | |||
class Reply: | |||
def __init__(self, type : ReplyType = None , content = None): | |||
self.type = type | |||
self.content = content | |||
def __str__(self): | |||
return "Reply(type={}, content={})".format(self.type, self.content) |
@@ -3,6 +3,8 @@ Message sending channel abstract class | |||
""" | |||
from bridge.bridge import Bridge | |||
from bridge.context import Context | |||
from bridge.reply import Reply | |||
class Channel(object): | |||
def startup(self): | |||
@@ -27,11 +29,11 @@ class Channel(object): | |||
""" | |||
raise NotImplementedError | |||
def build_reply_content(self, query, context=None): | |||
def build_reply_content(self, query, context : Context=None) -> Reply: | |||
return Bridge().fetch_reply_content(query, context) | |||
def build_voice_to_text(self, voice_file): | |||
def build_voice_to_text(self, voice_file) -> Reply: | |||
return Bridge().fetch_voice_to_text(voice_file) | |||
def build_text_to_voice(self, text): | |||
def build_text_to_voice(self, text) -> Reply: | |||
return Bridge().fetch_text_to_voice(text) |
@@ -7,16 +7,24 @@ wechat channel | |||
import itchat | |||
import json | |||
from itchat.content import * | |||
from bridge.reply import * | |||
from bridge.context import * | |||
from channel.channel import Channel | |||
from concurrent.futures import ThreadPoolExecutor | |||
from common.log import logger | |||
from common.tmp_dir import TmpDir | |||
from config import conf | |||
from plugins import * | |||
import requests | |||
import io | |||
thread_pool = ThreadPoolExecutor(max_workers=8) | |||
thread_pool = ThreadPoolExecutor(max_workers=8) | |||
def thread_pool_callback(worker): | |||
worker_exception = worker.exception() | |||
if worker_exception: | |||
logger.exception("Worker return exception: {}".format(worker_exception)) | |||
@itchat.msg_register(TEXT) | |||
def handler_single_msg(msg): | |||
@@ -47,62 +55,52 @@ class WechatChannel(Channel): | |||
# start message listener | |||
itchat.run() | |||
# handle_* 系列函数处理收到的消息后构造context,然后调用handle函数处理context | |||
# context是一个字典,包含了消息的所有信息,包括以下key | |||
# type: 消息类型,包括TEXT、VOICE、IMAGE_CREATE | |||
# content: 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是IMAGE_CREATE类型,content就是图片生成命令 | |||
# session_id: 会话id | |||
# isgroup: 是否是群聊 | |||
# msg: 原始消息对象 | |||
# receiver: 需要回复的对象 | |||
def handle_voice(self, msg): | |||
if conf().get('speech_recognition') != True : | |||
if conf().get('speech_recognition') != True: | |||
return | |||
logger.debug("[WX]receive voice msg: " + msg['FileName']) | |||
thread_pool.submit(self._do_handle_voice, msg) | |||
def _do_handle_voice(self, msg): | |||
from_user_id = msg['FromUserName'] | |||
other_user_id = msg['User']['UserName'] | |||
if from_user_id == other_user_id: | |||
file_name = TmpDir().path() + msg['FileName'] | |||
msg.download(file_name) | |||
query = super().build_voice_to_text(file_name) | |||
if conf().get('voice_reply_voice'): | |||
self._do_send_voice(query, from_user_id) | |||
else: | |||
self._do_send_text(query, from_user_id) | |||
context = Context(ContextType.VOICE,msg['FileName']) | |||
context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id} | |||
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) | |||
def handle_text(self, msg): | |||
logger.debug("[WX]receive text msg: " + json.dumps(msg, ensure_ascii=False)) | |||
content = msg['Text'] | |||
self._handle_single_msg(msg, content) | |||
def _handle_single_msg(self, msg, content): | |||
from_user_id = msg['FromUserName'] | |||
to_user_id = msg['ToUserName'] # 接收人id | |||
other_user_id = msg['User']['UserName'] # 对手方id | |||
match_prefix = self.check_prefix(content, conf().get('single_chat_prefix')) | |||
match_prefix = check_prefix(content, conf().get('single_chat_prefix')) | |||
if "」\n- - - - - - - - - - - - - - -" in content: | |||
logger.debug("[WX]reference query skipped") | |||
return | |||
if from_user_id == other_user_id and match_prefix is not None: | |||
# 好友向自己发送消息 | |||
if match_prefix != '': | |||
str_list = content.split(match_prefix, 1) | |||
if len(str_list) == 2: | |||
content = str_list[1].strip() | |||
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix')) | |||
if img_match_prefix: | |||
content = content.split(img_match_prefix, 1)[1].strip() | |||
thread_pool.submit(self._do_send_img, content, from_user_id) | |||
else : | |||
thread_pool.submit(self._do_send_text, content, from_user_id) | |||
elif to_user_id == other_user_id and match_prefix: | |||
# 自己给好友发送消息 | |||
str_list = content.split(match_prefix, 1) | |||
if len(str_list) == 2: | |||
content = str_list[1].strip() | |||
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix')) | |||
if img_match_prefix: | |||
content = content.split(img_match_prefix, 1)[1].strip() | |||
thread_pool.submit(self._do_send_img, content, to_user_id) | |||
else: | |||
thread_pool.submit(self._do_send_text, content, to_user_id) | |||
if match_prefix: | |||
content = content.replace(match_prefix, '', 1).strip() | |||
else: | |||
return | |||
context = Context() | |||
context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id} | |||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) | |||
if img_match_prefix: | |||
content = content.replace(img_match_prefix, '', 1).strip() | |||
context.type = ContextType.IMAGE_CREATE | |||
else: | |||
context.type = ContextType.TEXT | |||
context.content = content | |||
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) | |||
def handle_group(self, msg): | |||
logger.debug("[WX]receive group msg: " + json.dumps(msg, ensure_ascii=False)) | |||
@@ -122,100 +120,128 @@ class WechatChannel(Channel): | |||
logger.debug("[WX]reference query skipped") | |||
return "" | |||
config = conf() | |||
match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or self.check_prefix(origin_content, config.get('group_chat_prefix')) \ | |||
or self.check_contain(origin_content, config.get('group_chat_keyword')) | |||
if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or self.check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix: | |||
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix')) | |||
match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or check_prefix(origin_content, config.get('group_chat_prefix')) \ | |||
or check_contain(origin_content, config.get('group_chat_keyword')) | |||
if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix: | |||
context = Context() | |||
context.kwargs = { 'isgroup': True, 'msg': msg, 'receiver': group_id} | |||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) | |||
if img_match_prefix: | |||
content = content.split(img_match_prefix, 1)[1].strip() | |||
thread_pool.submit(self._do_send_img, content, group_id) | |||
content = content.replace(img_match_prefix, '', 1).strip() | |||
context.type = ContextType.IMAGE_CREATE | |||
else: | |||
thread_pool.submit(self._do_send_group, content, msg) | |||
def send(self, msg, receiver): | |||
itchat.send(msg, toUserName=receiver) | |||
logger.info('[WX] sendMsg={}, receiver={}'.format(msg, receiver)) | |||
def _do_send_voice(self, query, reply_user_id): | |||
try: | |||
if not query: | |||
return | |||
context = dict() | |||
context['from_user_id'] = reply_user_id | |||
reply_text = super().build_reply_content(query, context) | |||
if reply_text: | |||
replyFile = super().build_text_to_voice(reply_text) | |||
itchat.send_file(replyFile, toUserName=reply_user_id) | |||
logger.info('[WX] sendFile={}, receiver={}'.format(replyFile, reply_user_id)) | |||
except Exception as e: | |||
logger.exception(e) | |||
def _do_send_text(self, query, reply_user_id): | |||
try: | |||
if not query: | |||
return | |||
context = dict() | |||
context['session_id'] = reply_user_id | |||
reply_text = super().build_reply_content(query, context) | |||
if reply_text: | |||
self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id) | |||
except Exception as e: | |||
logger.exception(e) | |||
def _do_send_img(self, query, reply_user_id): | |||
try: | |||
if not query: | |||
return | |||
context = dict() | |||
context['type'] = 'IMAGE_CREATE' | |||
img_url = super().build_reply_content(query, context) | |||
if not img_url: | |||
return | |||
# 图片下载 | |||
context.type = ContextType.TEXT | |||
context.content = content | |||
group_chat_in_one_session = conf().get('group_chat_in_one_session', []) | |||
if ('ALL_GROUP' in group_chat_in_one_session or | |||
group_name in group_chat_in_one_session or | |||
check_contain(group_name, group_chat_in_one_session)): | |||
context['session_id'] = group_id | |||
else: | |||
context['session_id'] = msg['ActualUserName'] | |||
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback) | |||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 | |||
def send(self, reply : Reply, receiver): | |||
if reply.type == ReplyType.TEXT: | |||
itchat.send(reply.content, toUserName=receiver) | |||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) | |||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: | |||
itchat.send(reply.content, toUserName=receiver) | |||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) | |||
elif reply.type == ReplyType.VOICE: | |||
itchat.send_file(reply.content, toUserName=receiver) | |||
logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver)) | |||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 | |||
img_url = reply.content | |||
pic_res = requests.get(img_url, stream=True) | |||
image_storage = io.BytesIO() | |||
for block in pic_res.iter_content(1024): | |||
image_storage.write(block) | |||
image_storage.seek(0) | |||
itchat.send_image(image_storage, toUserName=receiver) | |||
logger.info('[WX] sendImage url=, receiver={}'.format(img_url,receiver)) | |||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片 | |||
image_storage = reply.content | |||
image_storage.seek(0) | |||
itchat.send_image(image_storage, toUserName=receiver) | |||
logger.info('[WX] sendImage, receiver={}'.format(receiver)) | |||
# 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类 | |||
def handle(self, context): | |||
reply = Reply() | |||
logger.debug('[WX] ready to handle context: {}'.format(context)) | |||
# reply的构建步骤 | |||
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply})) | |||
reply = e_context['reply'] | |||
if not e_context.is_pass(): | |||
logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content)) | |||
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: | |||
reply = super().build_reply_content(context.content, context) | |||
elif context.type == ContextType.VOICE: | |||
msg = context['msg'] | |||
file_name = TmpDir().path() + context.content | |||
msg.download(file_name) | |||
reply = super().build_voice_to_text(file_name) | |||
if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO: | |||
context.content = reply.content # 语音转文字后,将文字内容作为新的context | |||
context.type = ContextType.TEXT | |||
reply = super().build_reply_content(context.content, context) | |||
if reply.type == ReplyType.TEXT: | |||
if conf().get('voice_reply_voice'): | |||
reply = super().build_text_to_voice(reply.content) | |||
else: | |||
logger.error('[WX] unknown context type: {}'.format(context.type)) | |||
return | |||
# 图片发送 | |||
itchat.send_image(image_storage, reply_user_id) | |||
logger.info('[WX] sendImage, receiver={}'.format(reply_user_id)) | |||
except Exception as e: | |||
logger.exception(e) | |||
def _do_send_group(self, query, msg): | |||
if not query: | |||
return | |||
context = dict() | |||
group_name = msg['User']['NickName'] | |||
group_id = msg['User']['UserName'] | |||
group_chat_in_one_session = conf().get('group_chat_in_one_session', []) | |||
if ('ALL_GROUP' in group_chat_in_one_session or \ | |||
group_name in group_chat_in_one_session or \ | |||
self.check_contain(group_name, group_chat_in_one_session)): | |||
context['session_id'] = group_id | |||
else: | |||
context['session_id'] = msg['ActualUserName'] | |||
reply_text = super().build_reply_content(query, context) | |||
if reply_text: | |||
reply_text = '@' + msg['ActualNickName'] + ' ' + reply_text.strip() | |||
self.send(conf().get("group_chat_reply_prefix", "") + reply_text, group_id) | |||
def check_prefix(self, content, prefix_list): | |||
for prefix in prefix_list: | |||
if content.startswith(prefix): | |||
return prefix | |||
return None | |||
logger.debug('[WX] ready to decorate reply: {}'.format(reply)) | |||
# reply的包装步骤 | |||
if reply and reply.type: | |||
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply})) | |||
reply=e_context['reply'] | |||
if not e_context.is_pass() and reply and reply.type: | |||
if reply.type == ReplyType.TEXT: | |||
reply_text = reply.content | |||
if context['isgroup']: | |||
reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip() | |||
reply_text = conf().get("group_chat_reply_prefix", "")+reply_text | |||
else: | |||
reply_text = conf().get("single_chat_reply_prefix", "")+reply_text | |||
reply.content = reply_text | |||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: | |||
reply.content = str(reply.type)+":\n" + reply.content | |||
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE: | |||
pass | |||
else: | |||
logger.error('[WX] unknown reply type: {}'.format(reply.type)) | |||
return | |||
# reply的发送步骤 | |||
if reply and reply.type: | |||
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply})) | |||
reply=e_context['reply'] | |||
if not e_context.is_pass() and reply and reply.type: | |||
logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver'])) | |||
self.send(reply, context['receiver']) | |||
def check_prefix(content, prefix_list): | |||
for prefix in prefix_list: | |||
if content.startswith(prefix): | |||
return prefix | |||
return None | |||
def check_contain(self, content, keyword_list): | |||
if not keyword_list: | |||
return None | |||
for ky in keyword_list: | |||
if content.find(ky) != -1: | |||
return True | |||
def check_contain(content, keyword_list): | |||
if not keyword_list: | |||
return None | |||
for ky in keyword_list: | |||
if content.find(ky) != -1: | |||
return True | |||
return None |
@@ -11,6 +11,7 @@ import time | |||
import asyncio | |||
import requests | |||
from typing import Optional, Union | |||
from bridge.context import Context, ContextType | |||
from wechaty_puppet import MessageType, FileBox, ScanStatus # type: ignore | |||
from wechaty import Wechaty, Contact | |||
from wechaty.user import Message, Room, MiniProgram, UrlLink | |||
@@ -127,9 +128,9 @@ class WechatyChannel(Channel): | |||
try: | |||
if not query: | |||
return | |||
context = dict() | |||
context = Context(ContextType.TEXT, query) | |||
context['session_id'] = reply_user_id | |||
reply_text = super().build_reply_content(query, context) | |||
reply_text = super().build_reply_content(query, context).content | |||
if reply_text: | |||
await self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id) | |||
except Exception as e: | |||
@@ -139,9 +140,8 @@ class WechatyChannel(Channel): | |||
try: | |||
if not query: | |||
return | |||
context = dict() | |||
context['type'] = 'IMAGE_CREATE' | |||
img_url = super().build_reply_content(query, context) | |||
context = Context(ContextType.IMAGE_CREATE, query) | |||
img_url = super().build_reply_content(query, context).content | |||
if not img_url: | |||
return | |||
# 图片下载 | |||
@@ -162,7 +162,7 @@ class WechatyChannel(Channel): | |||
async def _do_send_group(self, query, group_id, group_name, group_user_id, group_user_name): | |||
if not query: | |||
return | |||
context = dict() | |||
context = Context(ContextType.TEXT, query) | |||
group_chat_in_one_session = conf().get('group_chat_in_one_session', []) | |||
if ('ALL_GROUP' in group_chat_in_one_session or \ | |||
group_name in group_chat_in_one_session or \ | |||
@@ -170,7 +170,7 @@ class WechatyChannel(Channel): | |||
context['session_id'] = str(group_id) | |||
else: | |||
context['session_id'] = str(group_id) + '-' + str(group_user_id) | |||
reply_text = super().build_reply_content(query, context) | |||
reply_text = super().build_reply_content(query, context).content | |||
if reply_text: | |||
reply_text = '@' + group_user_name + ' ' + reply_text.strip() | |||
await self.send_group(conf().get("group_chat_reply_prefix", "") + reply_text, group_id) | |||
@@ -179,9 +179,8 @@ class WechatyChannel(Channel): | |||
try: | |||
if not query: | |||
return | |||
context = dict() | |||
context['type'] = 'IMAGE_CREATE' | |||
img_url = super().build_reply_content(query, context) | |||
context = Context(ContextType.IMAGE_CREATE, query) | |||
img_url = super().build_reply_content(query, context).content | |||
if not img_url: | |||
return | |||
# 图片发送 | |||
@@ -0,0 +1,9 @@ | |||
def singleton(cls): | |||
instances = {} | |||
def get_instance(*args, **kwargs): | |||
if cls not in instances: | |||
instances[cls] = cls(*args, **kwargs) | |||
return instances[cls] | |||
return get_instance |
@@ -0,0 +1,65 @@ | |||
import heapq | |||
class SortedDict(dict): | |||
def __init__(self, sort_func=lambda k, v: k, init_dict=None, reverse=False): | |||
if init_dict is None: | |||
init_dict = [] | |||
if isinstance(init_dict, dict): | |||
init_dict = init_dict.items() | |||
self.sort_func = sort_func | |||
self.sorted_keys = None | |||
self.reverse = reverse | |||
self.heap = [] | |||
for k, v in init_dict: | |||
self[k] = v | |||
def __setitem__(self, key, value): | |||
if key in self: | |||
super().__setitem__(key, value) | |||
for i, (priority, k) in enumerate(self.heap): | |||
if k == key: | |||
self.heap[i] = (self.sort_func(key, value), key) | |||
heapq.heapify(self.heap) | |||
break | |||
self.sorted_keys = None | |||
else: | |||
super().__setitem__(key, value) | |||
heapq.heappush(self.heap, (self.sort_func(key, value), key)) | |||
self.sorted_keys = None | |||
def __delitem__(self, key): | |||
super().__delitem__(key) | |||
for i, (priority, k) in enumerate(self.heap): | |||
if k == key: | |||
del self.heap[i] | |||
heapq.heapify(self.heap) | |||
break | |||
self.sorted_keys = None | |||
def keys(self): | |||
if self.sorted_keys is None: | |||
self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)] | |||
return self.sorted_keys | |||
def items(self): | |||
if self.sorted_keys is None: | |||
self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)] | |||
sorted_items = [(k, self[k]) for k in self.sorted_keys] | |||
return sorted_items | |||
def _update_heap(self, key): | |||
for i, (priority, k) in enumerate(self.heap): | |||
if k == key: | |||
new_priority = self.sort_func(key, self[key]) | |||
if new_priority != priority: | |||
self.heap[i] = (new_priority, key) | |||
heapq.heapify(self.heap) | |||
self.sorted_keys = None | |||
break | |||
def __iter__(self): | |||
return iter(self.keys()) | |||
def __repr__(self): | |||
return f'{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})' |
@@ -0,0 +1,9 @@ | |||
from .plugin_manager import PluginManager | |||
from .event import * | |||
from .plugin import * | |||
instance = PluginManager() | |||
register = instance.register | |||
# load_plugins = instance.load_plugins | |||
# emit_event = instance.emit_event |
@@ -0,0 +1 @@ | |||
banwords.txt |
@@ -0,0 +1,9 @@ | |||
### 说明 | |||
简易的敏感词插件,暂不支持分词,请自行导入词库到插件文件夹中的`banwords.txt`,每行一个词,一个参考词库是[1](https://github.com/cjh0613/tencent-sensitive-words/blob/main/sensitive_words_lines.txt)。 | |||
`config.json`中能够填写默认的处理行为,目前行为有: | |||
- `ignore` : 无视这条消息。 | |||
- `replace` : 将消息中的敏感词替换成"*",并回复违规。 | |||
### 致谢 | |||
搜索功能实现来自https://github.com/toolgood/ToolGood.Words |
@@ -0,0 +1,250 @@ | |||
#!/usr/bin/env python | |||
# -*- coding:utf-8 -*- | |||
# ToolGood.Words.WordsSearch.py | |||
# 2020, Lin Zhijun, https://github.com/toolgood/ToolGood.Words | |||
# Licensed under the Apache License 2.0 | |||
# 更新日志 | |||
# 2020.04.06 第一次提交 | |||
# 2020.05.16 修改,支持大于0xffff的字符 | |||
__all__ = ['WordsSearch'] | |||
__author__ = 'Lin Zhijun' | |||
__date__ = '2020.05.16' | |||
class TrieNode(): | |||
def __init__(self): | |||
self.Index = 0 | |||
self.Index = 0 | |||
self.Layer = 0 | |||
self.End = False | |||
self.Char = '' | |||
self.Results = [] | |||
self.m_values = {} | |||
self.Failure = None | |||
self.Parent = None | |||
def Add(self,c): | |||
if c in self.m_values : | |||
return self.m_values[c] | |||
node = TrieNode() | |||
node.Parent = self | |||
node.Char = c | |||
self.m_values[c] = node | |||
return node | |||
def SetResults(self,index): | |||
if (self.End == False): | |||
self.End = True | |||
self.Results.append(index) | |||
class TrieNode2(): | |||
def __init__(self): | |||
self.End = False | |||
self.Results = [] | |||
self.m_values = {} | |||
self.minflag = 0xffff | |||
self.maxflag = 0 | |||
def Add(self,c,node3): | |||
if (self.minflag > c): | |||
self.minflag = c | |||
if (self.maxflag < c): | |||
self.maxflag = c | |||
self.m_values[c] = node3 | |||
def SetResults(self,index): | |||
if (self.End == False) : | |||
self.End = True | |||
if (index in self.Results )==False : | |||
self.Results.append(index) | |||
def HasKey(self,c): | |||
return c in self.m_values | |||
def TryGetValue(self,c): | |||
if (self.minflag <= c and self.maxflag >= c): | |||
if c in self.m_values: | |||
return self.m_values[c] | |||
return None | |||
class WordsSearch(): | |||
def __init__(self): | |||
self._first = {} | |||
self._keywords = [] | |||
self._indexs=[] | |||
def SetKeywords(self,keywords): | |||
self._keywords = keywords | |||
self._indexs=[] | |||
for i in range(len(keywords)): | |||
self._indexs.append(i) | |||
root = TrieNode() | |||
allNodeLayer={} | |||
for i in range(len(self._keywords)): # for (i = 0; i < _keywords.length; i++) | |||
p = self._keywords[i] | |||
nd = root | |||
for j in range(len(p)): # for (j = 0; j < p.length; j++) | |||
nd = nd.Add(ord(p[j])) | |||
if (nd.Layer == 0): | |||
nd.Layer = j + 1 | |||
if nd.Layer in allNodeLayer: | |||
allNodeLayer[nd.Layer].append(nd) | |||
else: | |||
allNodeLayer[nd.Layer]=[] | |||
allNodeLayer[nd.Layer].append(nd) | |||
nd.SetResults(i) | |||
allNode = [] | |||
allNode.append(root) | |||
for key in allNodeLayer.keys(): | |||
for nd in allNodeLayer[key]: | |||
allNode.append(nd) | |||
allNodeLayer=None | |||
for i in range(len(allNode)): # for (i = 0; i < allNode.length; i++) | |||
if i==0 : | |||
continue | |||
nd=allNode[i] | |||
nd.Index = i | |||
r = nd.Parent.Failure | |||
c = nd.Char | |||
while (r != None and (c in r.m_values)==False): | |||
r = r.Failure | |||
if (r == None): | |||
nd.Failure = root | |||
else: | |||
nd.Failure = r.m_values[c] | |||
for key2 in nd.Failure.Results : | |||
nd.SetResults(key2) | |||
root.Failure = root | |||
allNode2 = [] | |||
for i in range(len(allNode)): # for (i = 0; i < allNode.length; i++) | |||
allNode2.append( TrieNode2()) | |||
for i in range(len(allNode2)): # for (i = 0; i < allNode2.length; i++) | |||
oldNode = allNode[i] | |||
newNode = allNode2[i] | |||
for key in oldNode.m_values : | |||
index = oldNode.m_values[key].Index | |||
newNode.Add(key, allNode2[index]) | |||
for index in range(len(oldNode.Results)): # for (index = 0; index < oldNode.Results.length; index++) | |||
item = oldNode.Results[index] | |||
newNode.SetResults(item) | |||
oldNode=oldNode.Failure | |||
while oldNode != root: | |||
for key in oldNode.m_values : | |||
if (newNode.HasKey(key) == False): | |||
index = oldNode.m_values[key].Index | |||
newNode.Add(key, allNode2[index]) | |||
for index in range(len(oldNode.Results)): | |||
item = oldNode.Results[index] | |||
newNode.SetResults(item) | |||
oldNode=oldNode.Failure | |||
allNode = None | |||
root = None | |||
# first = [] | |||
# for index in range(65535):# for (index = 0; index < 0xffff; index++) | |||
# first.append(None) | |||
# for key in allNode2[0].m_values : | |||
# first[key] = allNode2[0].m_values[key] | |||
self._first = allNode2[0] | |||
def FindFirst(self,text): | |||
ptr = None | |||
for index in range(len(text)): # for (index = 0; index < text.length; index++) | |||
t =ord(text[index]) # text.charCodeAt(index) | |||
tn = None | |||
if (ptr == None): | |||
tn = self._first.TryGetValue(t) | |||
else: | |||
tn = ptr.TryGetValue(t) | |||
if (tn==None): | |||
tn = self._first.TryGetValue(t) | |||
if (tn != None): | |||
if (tn.End): | |||
item = tn.Results[0] | |||
keyword = self._keywords[item] | |||
return { "Keyword": keyword, "Success": True, "End": index, "Start": index + 1 - len(keyword), "Index": self._indexs[item] } | |||
ptr = tn | |||
return None | |||
def FindAll(self,text): | |||
ptr = None | |||
list = [] | |||
for index in range(len(text)): # for (index = 0; index < text.length; index++) | |||
t =ord(text[index]) # text.charCodeAt(index) | |||
tn = None | |||
if (ptr == None): | |||
tn = self._first.TryGetValue(t) | |||
else: | |||
tn = ptr.TryGetValue(t) | |||
if (tn==None): | |||
tn = self._first.TryGetValue(t) | |||
if (tn != None): | |||
if (tn.End): | |||
for j in range(len(tn.Results)): # for (j = 0; j < tn.Results.length; j++) | |||
item = tn.Results[j] | |||
keyword = self._keywords[item] | |||
list.append({ "Keyword": keyword, "Success": True, "End": index, "Start": index + 1 - len(keyword), "Index": self._indexs[item] }) | |||
ptr = tn | |||
return list | |||
def ContainsAny(self,text): | |||
ptr = None | |||
for index in range(len(text)): # for (index = 0; index < text.length; index++) | |||
t =ord(text[index]) # text.charCodeAt(index) | |||
tn = None | |||
if (ptr == None): | |||
tn = self._first.TryGetValue(t) | |||
else: | |||
tn = ptr.TryGetValue(t) | |||
if (tn==None): | |||
tn = self._first.TryGetValue(t) | |||
if (tn != None): | |||
if (tn.End): | |||
return True | |||
ptr = tn | |||
return False | |||
def Replace(self,text, replaceChar = '*'): | |||
result = list(text) | |||
ptr = None | |||
for i in range(len(text)): # for (i = 0; i < text.length; i++) | |||
t =ord(text[i]) # text.charCodeAt(index) | |||
tn = None | |||
if (ptr == None): | |||
tn = self._first.TryGetValue(t) | |||
else: | |||
tn = ptr.TryGetValue(t) | |||
if (tn==None): | |||
tn = self._first.TryGetValue(t) | |||
if (tn != None): | |||
if (tn.End): | |||
maxLength = len( self._keywords[tn.Results[0]]) | |||
start = i + 1 - maxLength | |||
for j in range(start,i+1): # for (j = start; j <= i; j++) | |||
result[j] = replaceChar | |||
ptr = tn | |||
return ''.join(result) |
@@ -0,0 +1,63 @@ | |||
# encoding:utf-8 | |||
import json | |||
import os | |||
from bridge.context import ContextType | |||
from bridge.reply import Reply, ReplyType | |||
import plugins | |||
from plugins import * | |||
from common.log import logger | |||
from .WordsSearch import WordsSearch | |||
@plugins.register(name="Banwords", desc="判断消息中是否有敏感词、决定是否回复。", version="1.0", author="lanvent", desire_priority= 100) | |||
class Banwords(Plugin): | |||
def __init__(self): | |||
super().__init__() | |||
try: | |||
curdir=os.path.dirname(__file__) | |||
config_path=os.path.join(curdir,"config.json") | |||
conf=None | |||
if not os.path.exists(config_path): | |||
conf={"action":"ignore"} | |||
with open(config_path,"w") as f: | |||
json.dump(conf,f,indent=4) | |||
else: | |||
with open(config_path,"r") as f: | |||
conf=json.load(f) | |||
self.searchr = WordsSearch() | |||
self.action = conf["action"] | |||
banwords_path = os.path.join(curdir,"banwords.txt") | |||
with open(banwords_path, 'r', encoding='utf-8') as f: | |||
words=[] | |||
for line in f: | |||
word = line.strip() | |||
if word: | |||
words.append(word) | |||
self.searchr.SetKeywords(words) | |||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context | |||
logger.info("[Banwords] inited") | |||
except Exception as e: | |||
logger.error("Banwords init failed: %s" % e) | |||
def on_handle_context(self, e_context: EventContext): | |||
if e_context['context'].type not in [ContextType.TEXT,ContextType.IMAGE_CREATE]: | |||
return | |||
content = e_context['context'].content | |||
logger.debug("[Banwords] on_handle_context. content: %s" % content) | |||
if self.action == "ignore": | |||
f = self.searchr.FindFirst(content) | |||
if f: | |||
logger.info("Banwords: %s" % f["Keyword"]) | |||
e_context.action = EventAction.BREAK_PASS | |||
return | |||
elif self.action == "replace": | |||
if self.searchr.ContainsAny(content): | |||
reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n"+self.searchr.Replace(content)) | |||
e_context['reply'] = reply | |||
e_context.action = EventAction.BREAK_PASS | |||
return |
@@ -0,0 +1,3 @@ | |||
nipples | |||
pennis | |||
法轮功 |
@@ -0,0 +1,3 @@ | |||
{ | |||
"action": "ignore" | |||
} |
@@ -0,0 +1,49 @@ | |||
# encoding:utf-8 | |||
from enum import Enum | |||
class Event(Enum): | |||
# ON_RECEIVE_MESSAGE = 1 # 收到消息 | |||
ON_HANDLE_CONTEXT = 2 # 处理消息前 | |||
""" | |||
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复,初始为空 } | |||
""" | |||
ON_DECORATE_REPLY = 3 # 得到回复后准备装饰 | |||
""" | |||
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 } | |||
""" | |||
ON_SEND_REPLY = 4 # 发送回复前 | |||
""" | |||
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 } | |||
""" | |||
# AFTER_SEND_REPLY = 5 # 发送回复后 | |||
class EventAction(Enum): | |||
CONTINUE = 1 # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑 | |||
BREAK = 2 # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑 | |||
BREAK_PASS = 3 # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑 | |||
class EventContext: | |||
def __init__(self, event, econtext=dict()): | |||
self.event = event | |||
self.econtext = econtext | |||
self.action = EventAction.CONTINUE | |||
def __getitem__(self, key): | |||
return self.econtext[key] | |||
def __setitem__(self, key, value): | |||
self.econtext[key] = value | |||
def __delitem__(self, key): | |||
del self.econtext[key] | |||
def is_pass(self): | |||
return self.action == EventAction.BREAK_PASS |
@@ -0,0 +1,4 @@ | |||
{ | |||
"password": "", | |||
"admin_users": [] | |||
} |
@@ -0,0 +1,289 @@ | |||
# encoding:utf-8 | |||
import json | |||
import os | |||
import traceback | |||
from typing import Tuple | |||
from bridge.bridge import Bridge | |||
from bridge.context import ContextType | |||
from bridge.reply import Reply, ReplyType | |||
from config import load_config | |||
import plugins | |||
from plugins import * | |||
from common.log import logger | |||
# 定义指令集 | |||
COMMANDS = { | |||
"help": { | |||
"alias": ["help", "帮助"], | |||
"desc": "打印指令集合", | |||
}, | |||
"auth": { | |||
"alias": ["auth", "认证"], | |||
"args": ["口令"], | |||
"desc": "管理员认证", | |||
}, | |||
# "id": { | |||
# "alias": ["id", "用户"], | |||
# "desc": "获取用户id", #目前无实际意义 | |||
# }, | |||
"reset": { | |||
"alias": ["reset", "重置会话"], | |||
"desc": "重置会话", | |||
}, | |||
} | |||
ADMIN_COMMANDS = { | |||
"resume": { | |||
"alias": ["resume", "恢复服务"], | |||
"desc": "恢复服务", | |||
}, | |||
"stop": { | |||
"alias": ["stop", "暂停服务"], | |||
"desc": "暂停服务", | |||
}, | |||
"reconf": { | |||
"alias": ["reconf", "重载配置"], | |||
"desc": "重载配置(不包含插件配置)", | |||
}, | |||
"resetall": { | |||
"alias": ["resetall", "重置所有会话"], | |||
"desc": "重置所有会话", | |||
}, | |||
"scanp": { | |||
"alias": ["scanp", "扫描插件"], | |||
"desc": "扫描插件目录是否有新插件", | |||
}, | |||
"plist": { | |||
"alias": ["plist", "插件"], | |||
"desc": "打印当前插件列表", | |||
}, | |||
"setpri": { | |||
"alias": ["setpri", "设置插件优先级"], | |||
"args": ["插件名", "优先级"], | |||
"desc": "设置指定插件的优先级,越大越优先", | |||
}, | |||
"reloadp": { | |||
"alias": ["reloadp", "重载插件"], | |||
"args": ["插件名"], | |||
"desc": "重载指定插件配置", | |||
}, | |||
"enablep": { | |||
"alias": ["enablep", "启用插件"], | |||
"args": ["插件名"], | |||
"desc": "启用指定插件", | |||
}, | |||
"disablep": { | |||
"alias": ["disablep", "禁用插件"], | |||
"args": ["插件名"], | |||
"desc": "禁用指定插件", | |||
}, | |||
"debug": { | |||
"alias": ["debug", "调试模式", "DEBUG"], | |||
"desc": "开启机器调试日志", | |||
}, | |||
} | |||
# 定义帮助函数 | |||
def get_help_text(isadmin, isgroup): | |||
help_text = "可用指令:\n" | |||
for cmd, info in COMMANDS.items(): | |||
if cmd=="auth" and (isadmin or isgroup): # 群聊不可认证 | |||
continue | |||
alias=["#"+a for a in info['alias']] | |||
help_text += f"{','.join(alias)} " | |||
if 'args' in info: | |||
args=["{"+a+"}" for a in info['args']] | |||
help_text += f"{' '.join(args)} " | |||
help_text += f": {info['desc']}\n" | |||
if ADMIN_COMMANDS and isadmin: | |||
help_text += "\n管理员指令:\n" | |||
for cmd, info in ADMIN_COMMANDS.items(): | |||
alias=["#"+a for a in info['alias']] | |||
help_text += f"{','.join(alias)} " | |||
help_text += f": {info['desc']}\n" | |||
return help_text | |||
@plugins.register(name="Godcmd", desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证", version="1.0", author="lanvent", desire_priority= 999) | |||
class Godcmd(Plugin): | |||
def __init__(self): | |||
super().__init__() | |||
curdir=os.path.dirname(__file__) | |||
config_path=os.path.join(curdir,"config.json") | |||
gconf=None | |||
if not os.path.exists(config_path): | |||
gconf={"password":"","admin_users":[]} | |||
with open(config_path,"w") as f: | |||
json.dump(gconf,f,indent=4) | |||
else: | |||
with open(config_path,"r") as f: | |||
gconf=json.load(f) | |||
self.password = gconf["password"] | |||
self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证 TODO: 用户名每次都会变,目前不可用 | |||
self.isrunning = True # 机器人是否运行中 | |||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context | |||
logger.info("[Godcmd] inited") | |||
def on_handle_context(self, e_context: EventContext): | |||
context_type = e_context['context'].type | |||
if context_type != ContextType.TEXT: | |||
if not self.isrunning: | |||
e_context.action = EventAction.BREAK_PASS | |||
return | |||
content = e_context['context'].content | |||
logger.debug("[Godcmd] on_handle_context. content: %s" % content) | |||
if content.startswith("#"): | |||
# msg = e_context['context']['msg'] | |||
user = e_context['context']['receiver'] | |||
session_id = e_context['context']['session_id'] | |||
isgroup = e_context['context']['isgroup'] | |||
bottype = Bridge().get_bot_type("chat") | |||
bot = Bridge().get_bot("chat") | |||
# 将命令和参数分割 | |||
command_parts = content[1:].split(" ") | |||
cmd = command_parts[0] | |||
args = command_parts[1:] | |||
isadmin=False | |||
if user in self.admin_users: | |||
isadmin=True | |||
ok=False | |||
result="string" | |||
if any(cmd in info['alias'] for info in COMMANDS.values()): | |||
cmd = next(c for c, info in COMMANDS.items() if cmd in info['alias']) | |||
if cmd == "auth": | |||
ok, result = self.authenticate(user, args, isadmin, isgroup) | |||
elif cmd == "help": | |||
ok, result = True, get_help_text(isadmin, isgroup) | |||
elif cmd == "id": | |||
ok, result = True, f"用户id=\n{user}" | |||
elif cmd == "reset": | |||
if bottype == "chatGPT": | |||
bot.sessions.clear_session(session_id) | |||
ok, result = True, "会话已重置" | |||
else: | |||
ok, result = False, "当前对话机器人不支持重置会话" | |||
logger.debug("[Godcmd] command: %s by %s" % (cmd, user)) | |||
elif any(cmd in info['alias'] for info in ADMIN_COMMANDS.values()): | |||
if isadmin: | |||
if isgroup: | |||
ok, result = False, "群聊不可执行管理员指令" | |||
else: | |||
cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info['alias']) | |||
if cmd == "stop": | |||
self.isrunning = False | |||
ok, result = True, "服务已暂停" | |||
elif cmd == "resume": | |||
self.isrunning = True | |||
ok, result = True, "服务已恢复" | |||
elif cmd == "reconf": | |||
load_config() | |||
ok, result = True, "配置已重载" | |||
elif cmd == "resetall": | |||
if bottype == "chatGPT": | |||
bot.sessions.clear_all_session() | |||
ok, result = True, "重置所有会话成功" | |||
else: | |||
ok, result = False, "当前对话机器人不支持重置会话" | |||
elif cmd == "debug": | |||
logger.setLevel('DEBUG') | |||
ok, result = True, "DEBUG模式已开启" | |||
elif cmd == "plist": | |||
plugins = PluginManager().list_plugins() | |||
ok = True | |||
result = "插件列表:\n" | |||
for name,plugincls in plugins.items(): | |||
result += f"{plugincls.name}_v{plugincls.version} {plugincls.priority} - " | |||
if plugincls.enabled: | |||
result += "已启用\n" | |||
else: | |||
result += "未启用\n" | |||
elif cmd == "scanp": | |||
new_plugins = PluginManager().scan_plugins() | |||
ok, result = True, "插件扫描完成" | |||
PluginManager().activate_plugins() | |||
if len(new_plugins) >0 : | |||
result += "\n发现新插件:\n" | |||
result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins]) | |||
else : | |||
result +=", 未发现新插件" | |||
elif cmd == "setpri": | |||
if len(args) != 2: | |||
ok, result = False, "请提供插件名和优先级" | |||
else: | |||
ok = PluginManager().set_plugin_priority(args[0], int(args[1])) | |||
if ok: | |||
result = "插件" + args[0] + "优先级已设置为" + args[1] | |||
else: | |||
result = "插件不存在" | |||
elif cmd == "reloadp": | |||
if len(args) != 1: | |||
ok, result = False, "请提供插件名" | |||
else: | |||
ok = PluginManager().reload_plugin(args[0]) | |||
if ok: | |||
result = "插件配置已重载" | |||
else: | |||
result = "插件不存在" | |||
elif cmd == "enablep": | |||
if len(args) != 1: | |||
ok, result = False, "请提供插件名" | |||
else: | |||
ok = PluginManager().enable_plugin(args[0]) | |||
if ok: | |||
result = "插件已启用" | |||
else: | |||
result = "插件不存在" | |||
elif cmd == "disablep": | |||
if len(args) != 1: | |||
ok, result = False, "请提供插件名" | |||
else: | |||
ok = PluginManager().disable_plugin(args[0]) | |||
if ok: | |||
result = "插件已禁用" | |||
else: | |||
result = "插件不存在" | |||
logger.debug("[Godcmd] admin command: %s by %s" % (cmd, user)) | |||
else: | |||
ok, result = False, "需要管理员权限才能执行该指令" | |||
else: | |||
ok, result = False, f"未知指令:{cmd}\n查看指令列表请输入#help \n" | |||
reply = Reply() | |||
if ok: | |||
reply.type = ReplyType.INFO | |||
else: | |||
reply.type = ReplyType.ERROR | |||
reply.content = result | |||
e_context['reply'] = reply | |||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | |||
elif not self.isrunning: | |||
e_context.action = EventAction.BREAK_PASS | |||
def authenticate(self, userid, args, isadmin, isgroup) -> Tuple[bool,str] : | |||
if isgroup: | |||
return False,"请勿在群聊中认证" | |||
if isadmin: | |||
return False,"管理员账号无需认证" | |||
if len(self.password) == 0: | |||
return False,"未设置口令,无法认证" | |||
if len(args) != 1: | |||
return False,"请提供口令" | |||
password = args[0] | |||
if password == self.password: | |||
self.admin_users.append(userid) | |||
return True,"认证成功" | |||
else: | |||
return False,"认证失败" | |||
@@ -0,0 +1,46 @@ | |||
# encoding:utf-8 | |||
from bridge.context import ContextType | |||
from bridge.reply import Reply, ReplyType | |||
import plugins | |||
from plugins import * | |||
from common.log import logger | |||
@plugins.register(name="Hello", desc="A simple plugin that says hello", version="0.1", author="lanvent", desire_priority= -1) | |||
class Hello(Plugin): | |||
def __init__(self): | |||
super().__init__() | |||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context | |||
logger.info("[Hello] inited") | |||
def on_handle_context(self, e_context: EventContext): | |||
if e_context['context'].type != ContextType.TEXT: | |||
return | |||
content = e_context['context'].content | |||
logger.debug("[Hello] on_handle_context. content: %s" % content) | |||
if content == "Hello": | |||
reply = Reply() | |||
reply.type = ReplyType.TEXT | |||
msg = e_context['context']['msg'] | |||
if e_context['context']['isgroup']: | |||
reply.content = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group") | |||
else: | |||
reply.content = "Hello, " + msg['User'].get('NickName', "My friend") | |||
e_context['reply'] = reply | |||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | |||
if content == "Hi": | |||
reply = Reply() | |||
reply.type = ReplyType.TEXT | |||
reply.content = "Hi" | |||
e_context['reply'] = reply | |||
e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑,一般会覆写reply | |||
if content == "End": | |||
# 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World" | |||
e_context['context'].type = "IMAGE_CREATE" | |||
content = "The World" | |||
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 |
@@ -0,0 +1,3 @@ | |||
class Plugin: | |||
def __init__(self): | |||
self.handlers = {} |
@@ -0,0 +1,171 @@ | |||
# encoding:utf-8 | |||
import importlib | |||
import json | |||
import os | |||
from common.singleton import singleton | |||
from common.sorted_dict import SortedDict | |||
from .event import * | |||
from .plugin import * | |||
from common.log import logger | |||
@singleton | |||
class PluginManager: | |||
def __init__(self): | |||
self.plugins = SortedDict(lambda k,v: v.priority,reverse=True) | |||
self.listening_plugins = {} | |||
self.instances = {} | |||
self.pconf = {} | |||
def register(self, name: str, desc: str, version: str, author: str, desire_priority: int = 0): | |||
def wrapper(plugincls): | |||
plugincls.name = name | |||
plugincls.desc = desc | |||
plugincls.version = version | |||
plugincls.author = author | |||
plugincls.priority = desire_priority | |||
plugincls.enabled = True | |||
self.plugins[name.upper()] = plugincls | |||
logger.info("Plugin %s_v%s registered" % (name, version)) | |||
return plugincls | |||
return wrapper | |||
def save_config(self): | |||
with open("plugins/plugins.json", "w", encoding="utf-8") as f: | |||
json.dump(self.pconf, f, indent=4, ensure_ascii=False) | |||
def load_config(self): | |||
logger.info("Loading plugins config...") | |||
modified = False | |||
if os.path.exists("plugins/plugins.json"): | |||
with open("plugins/plugins.json", "r", encoding="utf-8") as f: | |||
pconf = json.load(f) | |||
pconf['plugins'] = SortedDict(lambda k,v: v["priority"],pconf['plugins'],reverse=True) | |||
else: | |||
modified = True | |||
pconf = {"plugins": SortedDict(lambda k,v: v["priority"],reverse=True)} | |||
self.pconf = pconf | |||
if modified: | |||
self.save_config() | |||
return pconf | |||
def scan_plugins(self): | |||
logger.info("Scaning plugins ...") | |||
plugins_dir = "plugins" | |||
for plugin_name in os.listdir(plugins_dir): | |||
plugin_path = os.path.join(plugins_dir, plugin_name) | |||
if os.path.isdir(plugin_path): | |||
# 判断插件是否包含同名.py文件 | |||
main_module_path = os.path.join(plugin_path, plugin_name+".py") | |||
if os.path.isfile(main_module_path): | |||
# 导入插件 | |||
import_path = "{}.{}.{}".format(plugins_dir, plugin_name, plugin_name) | |||
main_module = importlib.import_module(import_path) | |||
pconf = self.pconf | |||
new_plugins = [] | |||
modified = False | |||
for name, plugincls in self.plugins.items(): | |||
rawname = plugincls.name | |||
if rawname not in pconf["plugins"]: | |||
new_plugins.append(plugincls) | |||
modified = True | |||
logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name) | |||
pconf["plugins"][rawname] = {"enabled": plugincls.enabled, "priority": plugincls.priority} | |||
else: | |||
self.plugins[name].enabled = pconf["plugins"][rawname]["enabled"] | |||
self.plugins[name].priority = pconf["plugins"][rawname]["priority"] | |||
self.plugins._update_heap(name) # 更新下plugins中的顺序 | |||
if modified: | |||
self.save_config() | |||
return new_plugins | |||
def refresh_order(self): | |||
for event in self.listening_plugins.keys(): | |||
self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True) | |||
def activate_plugins(self): # 生成新开启的插件实例 | |||
for name, plugincls in self.plugins.items(): | |||
if plugincls.enabled: | |||
if name not in self.instances: | |||
instance = plugincls() | |||
self.instances[name] = instance | |||
for event in instance.handlers: | |||
if event not in self.listening_plugins: | |||
self.listening_plugins[event] = [] | |||
self.listening_plugins[event].append(name) | |||
self.refresh_order() | |||
def reload_plugin(self, name:str): | |||
name = name.upper() | |||
if name in self.instances: | |||
for event in self.listening_plugins: | |||
if name in self.listening_plugins[event]: | |||
self.listening_plugins[event].remove(name) | |||
del self.instances[name] | |||
self.activate_plugins() | |||
return True | |||
return False | |||
def load_plugins(self): | |||
self.load_config() | |||
self.scan_plugins() | |||
pconf = self.pconf | |||
logger.debug("plugins.json config={}".format(pconf)) | |||
for name,plugin in pconf["plugins"].items(): | |||
if name.upper() not in self.plugins: | |||
logger.error("Plugin %s not found, but found in plugins.json" % name) | |||
self.activate_plugins() | |||
def emit_event(self, e_context: EventContext, *args, **kwargs): | |||
if e_context.event in self.listening_plugins: | |||
for name in self.listening_plugins[e_context.event]: | |||
if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE: | |||
logger.debug("Plugin %s triggered by event %s" % (name,e_context.event)) | |||
instance = self.instances[name] | |||
instance.handlers[e_context.event](e_context, *args, **kwargs) | |||
return e_context | |||
def set_plugin_priority(self, name:str, priority:int): | |||
name = name.upper() | |||
if name not in self.plugins: | |||
return False | |||
if self.plugins[name].priority == priority: | |||
return True | |||
self.plugins[name].priority = priority | |||
self.plugins._update_heap(name) | |||
rawname = self.plugins[name].name | |||
self.pconf["plugins"][rawname]["priority"] = priority | |||
self.pconf["plugins"]._update_heap(rawname) | |||
self.save_config() | |||
self.refresh_order() | |||
return True | |||
def enable_plugin(self, name:str): | |||
name = name.upper() | |||
if name not in self.plugins: | |||
return False | |||
if not self.plugins[name].enabled : | |||
self.plugins[name].enabled = True | |||
rawname = self.plugins[name].name | |||
self.pconf["plugins"][rawname]["enabled"] = True | |||
self.save_config() | |||
self.activate_plugins() | |||
return True | |||
return True | |||
def disable_plugin(self, name:str): | |||
name = name.upper() | |||
if name not in self.plugins: | |||
return False | |||
if self.plugins[name].enabled : | |||
self.plugins[name].enabled = False | |||
rawname = self.plugins[name].name | |||
self.pconf["plugins"][rawname]["enabled"] = False | |||
self.save_config() | |||
return True | |||
return True | |||
def list_plugins(self): | |||
return self.plugins |
@@ -0,0 +1,70 @@ | |||
{ | |||
"start":{ | |||
"host" : "127.0.0.1", | |||
"port" : 7860 | |||
}, | |||
"defaults": { | |||
"params": { | |||
"sampler_name": "DPM++ 2M Karras", | |||
"steps": 20, | |||
"width": 512, | |||
"height": 512, | |||
"cfg_scale": 7, | |||
"prompt":"masterpiece, best quality", | |||
"negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)", | |||
"enable_hr": false, | |||
"hr_scale": 2, | |||
"hr_upscaler": "Latent", | |||
"hr_second_pass_steps": 15, | |||
"denoising_strength": 0.7 | |||
}, | |||
"options": { | |||
"sd_model_checkpoint": "perfectWorld_v2Baked" | |||
} | |||
}, | |||
"rules": [ | |||
{ | |||
"keywords": [ | |||
"横版", | |||
"壁纸" | |||
], | |||
"params": { | |||
"width": 640, | |||
"height": 384 | |||
}, | |||
"desc": "分辨率会变成640x384" | |||
}, | |||
{ | |||
"keywords": [ | |||
"竖版" | |||
], | |||
"params": { | |||
"width": 384, | |||
"height": 640 | |||
} | |||
}, | |||
{ | |||
"keywords": [ | |||
"高清" | |||
], | |||
"params": { | |||
"enable_hr": true, | |||
"hr_scale": 1.6 | |||
}, | |||
"desc": "出图分辨率长宽都会提高1.6倍" | |||
}, | |||
{ | |||
"keywords": [ | |||
"二次元" | |||
], | |||
"params": { | |||
"negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)", | |||
"prompt": "masterpiece, best quality" | |||
}, | |||
"options": { | |||
"sd_model_checkpoint": "meinamix_meinaV8" | |||
}, | |||
"desc": "使用二次元风格模型出图" | |||
} | |||
] | |||
} |
@@ -0,0 +1,69 @@ | |||
### 插件描述 | |||
本插件用于将画图请求转发给stable diffusion webui。 | |||
### 环境要求 | |||
使用前先安装stable diffusion webui,并在它的启动参数中添加 "--api"。 | |||
具体信息,请参考[文章](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API)。 | |||
请**安装**本插件的依赖包```webuiapi``` | |||
``` | |||
```pip install webuiapi``` | |||
``` | |||
### 使用说明 | |||
请将`config.json.template`复制为`config.json`,并修改其中的参数和规则。 | |||
#### 画图请求格式 | |||
用户的画图请求格式为: | |||
``` | |||
<画图触发词><关键词1> <关键词2> ... <关键词n>:<prompt> | |||
``` | |||
- 本插件会对画图触发词后的关键词进行逐个匹配,如果触发了规则中的关键词,则会在画图请求中重载对应的参数。 | |||
- 规则的匹配顺序参考`config.json`中的顺序,每个关键词最多被匹配到1次,如果多个关键词触发了重复的参数,重复参数以最后一个关键词为准: | |||
- 关键词中包含`help`或`帮助`,会打印出帮助文档。 | |||
第一个"**:**"号之后的内容会作为附加的**prompt**,接在最终的prompt后 | |||
例如: 画横版 高清 二次元:cat | |||
会触发三个关键词 "横版", "高清", "二次元",prompt为"cat" | |||
若默认参数是: | |||
``` | |||
"width": 512, | |||
"height": 512, | |||
"enable_hr": false, | |||
"prompt": "8k" | |||
"negative_prompt": "nsfw", | |||
"sd_model_checkpoint": "perfectWorld_v2Baked" | |||
``` | |||
"横版"触发的规则参数为: | |||
``` | |||
"width": 640, | |||
"height": 384, | |||
``` | |||
"高清"触发的规则参数为: | |||
``` | |||
"enable_hr": true, | |||
"hr_scale": 1.6, | |||
``` | |||
"二次元"触发的规则参数为: | |||
``` | |||
"negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)", | |||
"steps": 20, | |||
"prompt": "masterpiece, best quality", | |||
"sd_model_checkpoint": "meinamix_meinaV8" | |||
``` | |||
最后将第一个":"后的内容cat连接在prompt后,得到最终参数为: | |||
``` | |||
"width": 640, | |||
"height": 384, | |||
"enable_hr": true, | |||
"hr_scale": 1.6, | |||
"negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)", | |||
"steps": 20, | |||
"prompt": "masterpiece, best quality, cat", | |||
"sd_model_checkpoint": "meinamix_meinaV8" | |||
``` | |||
PS: 参数分为两部分: | |||
- 一部分是params,为画画的参数;参数名**必须**与webuiapi包中[txt2img api](https://github.com/mix1009/sdwebuiapi/blob/fb2054e149c0a4e25125c0cd7e7dca06bda839d4/webuiapi/webuiapi.py#L163)的参数名一致 | |||
- 另一部分是options,指sdwebui的设置,使用的模型和vae需要写在里面。它和http://127.0.0.1:7860/sdapi/v1/options所返回的键一致。 |
@@ -0,0 +1,114 @@ | |||
# encoding:utf-8 | |||
import json | |||
import os | |||
from bridge.context import ContextType | |||
from bridge.reply import Reply, ReplyType | |||
from config import conf | |||
import plugins | |||
from plugins import * | |||
from common.log import logger | |||
import webuiapi | |||
import io | |||
@plugins.register(name="sdwebui", desc="利用stable-diffusion webui来画图", version="2.0", author="lanvent") | |||
class SDWebUI(Plugin): | |||
def __init__(self): | |||
super().__init__() | |||
curdir = os.path.dirname(__file__) | |||
config_path = os.path.join(curdir, "config.json") | |||
try: | |||
with open(config_path, "r", encoding="utf-8") as f: | |||
config = json.load(f) | |||
self.rules = config["rules"] | |||
defaults = config["defaults"] | |||
self.default_params = defaults["params"] | |||
self.default_options = defaults["options"] | |||
self.start_args = config["start"] | |||
self.api = webuiapi.WebUIApi(**self.start_args) | |||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context | |||
logger.info("[SD] inited") | |||
except FileNotFoundError: | |||
logger.error(f"[SD] init failed, {config_path} not found") | |||
except Exception as e: | |||
logger.error("[SD] init failed, exception: %s" % e) | |||
def on_handle_context(self, e_context: EventContext): | |||
if e_context['context'].type != ContextType.IMAGE_CREATE: | |||
return | |||
logger.debug("[SD] on_handle_context. content: %s" %e_context['context'].content) | |||
logger.info("[SD] image_query={}".format(e_context['context'].content)) | |||
reply = Reply() | |||
try: | |||
content = e_context['context'].content[:] | |||
# 解析用户输入 如"横版 高清 二次元:cat" | |||
if ":" in content: | |||
keywords, prompt = content.split(":", 1) | |||
else: | |||
keywords = content | |||
prompt = "" | |||
keywords = keywords.split() | |||
if "help" in keywords or "帮助" in keywords: | |||
reply.type = ReplyType.INFO | |||
reply.content = self.get_help_text() | |||
else: | |||
rule_params = {} | |||
rule_options = {} | |||
for keyword in keywords: | |||
matched = False | |||
for rule in self.rules: | |||
if keyword in rule["keywords"]: | |||
for key in rule["params"]: | |||
rule_params[key] = rule["params"][key] | |||
if "options" in rule: | |||
for key in rule["options"]: | |||
rule_options[key] = rule["options"][key] | |||
matched = True | |||
break # 一个关键词只匹配一个规则 | |||
if not matched: | |||
logger.warning("[SD] keyword not matched: %s" % keyword) | |||
params = {**self.default_params, **rule_params} | |||
options = {**self.default_options, **rule_options} | |||
params["prompt"] = params.get("prompt", "")+f", {prompt}" | |||
if len(options) > 0: | |||
logger.info("[SD] cover options={}".format(options)) | |||
self.api.set_options(options) | |||
logger.info("[SD] params={}".format(params)) | |||
result = self.api.txt2img( | |||
**params | |||
) | |||
reply.type = ReplyType.IMAGE | |||
b_img = io.BytesIO() | |||
result.image.save(b_img, format="PNG") | |||
reply.content = b_img | |||
e_context.action = EventAction.BREAK_PASS # 事件结束后,跳过处理context的默认逻辑 | |||
except Exception as e: | |||
reply.type = ReplyType.ERROR | |||
reply.content = "[SD] "+str(e) | |||
logger.error("[SD] exception: %s" % e) | |||
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 | |||
finally: | |||
e_context['reply'] = reply | |||
def get_help_text(self): | |||
if not conf().get('image_create_prefix'): | |||
return "画图功能未启用" | |||
else: | |||
trigger = conf()['image_create_prefix'][0] | |||
help_text = f"请使用<{trigger}[关键词1] [关键词2]...:提示语>的格式作画,如\"{trigger}横版 高清:cat\"\n" | |||
help_text += "目前可用关键词:\n" | |||
for rule in self.rules: | |||
keywords = [f"[{keyword}]" for keyword in rule['keywords']] | |||
help_text += f"{','.join(keywords)}" | |||
if "desc" in rule: | |||
help_text += f"-{rule['desc']}\n" | |||
else: | |||
help_text += "\n" | |||
return help_text |
@@ -4,6 +4,7 @@ baidu voice service | |||
""" | |||
import time | |||
from aip import AipSpeech | |||
from bridge.reply import Reply, ReplyType | |||
from common.log import logger | |||
from common.tmp_dir import TmpDir | |||
from voice.voice import Voice | |||
@@ -30,7 +31,8 @@ class BaiduVoice(Voice): | |||
with open(fileName, 'wb') as f: | |||
f.write(result) | |||
logger.info('[Baidu] textToVoice text={} voice file name={}'.format(text, fileName)) | |||
return fileName | |||
reply = Reply(ReplyType.VOICE, fileName) | |||
else: | |||
logger.error('[Baidu] textToVoice error={}'.format(result)) | |||
return None | |||
reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败") | |||
return reply |
@@ -6,6 +6,7 @@ google voice service | |||
import pathlib | |||
import subprocess | |||
import time | |||
from bridge.reply import Reply, ReplyType | |||
import speech_recognition | |||
import pyttsx3 | |||
from common.log import logger | |||
@@ -36,16 +37,22 @@ class GoogleVoice(Voice): | |||
text = self.recognizer.recognize_google(audio, language='zh-CN') | |||
logger.info( | |||
'[Google] voiceToText text={} voice file name={}'.format(text, voice_file)) | |||
return text | |||
reply = Reply(ReplyType.TEXT, text) | |||
except speech_recognition.UnknownValueError: | |||
return "抱歉,我听不懂。" | |||
reply = Reply(ReplyType.ERROR, "抱歉,我听不懂") | |||
except speech_recognition.RequestError as e: | |||
return "抱歉,无法连接到 Google 语音识别服务;{0}".format(e) | |||
reply = Reply(ReplyType.ERROR, "抱歉,无法连接到 Google 语音识别服务;{0}".format(e)) | |||
finally: | |||
return reply | |||
def textToVoice(self, text): | |||
textFile = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3' | |||
self.engine.save_to_file(text, textFile) | |||
self.engine.runAndWait() | |||
logger.info( | |||
'[Google] textToVoice text={} voice file name={}'.format(text, textFile)) | |||
return textFile | |||
try: | |||
textFile = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3' | |||
self.engine.save_to_file(text, textFile) | |||
self.engine.runAndWait() | |||
logger.info( | |||
'[Google] textToVoice text={} voice file name={}'.format(text, textFile)) | |||
reply = Reply(ReplyType.VOICE, textFile) | |||
except Exception as e: | |||
reply = Reply(ReplyType.ERROR, str(e)) | |||
finally: | |||
return reply |
@@ -4,6 +4,7 @@ google voice service | |||
""" | |||
import json | |||
import openai | |||
from bridge.reply import Reply, ReplyType | |||
from config import conf | |||
from common.log import logger | |||
from voice.voice import Voice | |||
@@ -16,12 +17,17 @@ class OpenaiVoice(Voice): | |||
def voiceToText(self, voice_file): | |||
logger.debug( | |||
'[Openai] voice file name={}'.format(voice_file)) | |||
file = open(voice_file, "rb") | |||
reply = openai.Audio.transcribe("whisper-1", file) | |||
text = reply["text"] | |||
logger.info( | |||
'[Openai] voiceToText text={} voice file name={}'.format(text, voice_file)) | |||
return text | |||
try: | |||
file = open(voice_file, "rb") | |||
result = openai.Audio.transcribe("whisper-1", file) | |||
text = result["text"] | |||
reply = Reply(ReplyType.TEXT, text) | |||
logger.info( | |||
'[Openai] voiceToText text={} voice file name={}'.format(text, voice_file)) | |||
except Exception as e: | |||
reply = Reply(ReplyType.ERROR, str(e)) | |||
finally: | |||
return reply | |||
def textToVoice(self, text): | |||
pass |