Browse Source

Merge pull request #442 from lanvent/dev

简易支持插件,添加sdwebui(novelai画图), godcmd(管理员指令增强)插件,Banwords(敏感词过滤)插件
develop
zhayujie GitHub 1 year ago
parent
commit
2cb30b5f59
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
37 changed files with 1628 additions and 238 deletions
  1. +1
    -0
      .gitignore
  2. +5
    -2
      app.py
  3. +3
    -1
      bot/baidu/baidu_unit_bot.py
  4. +5
    -1
      bot/bot.py
  5. +66
    -48
      bot/chatgpt/chat_gpt_bot.py
  6. +25
    -22
      bot/openai/open_ai_bot.py
  7. +33
    -7
      bridge/bridge.py
  8. +42
    -0
      bridge/context.py
  9. +22
    -0
      bridge/reply.py
  10. +5
    -3
      channel/channel.py
  11. +152
    -126
      channel/wechat/wechat_channel.py
  12. +9
    -10
      channel/wechat/wechaty_channel.py
  13. +9
    -0
      common/singleton.py
  14. +65
    -0
      common/sorted_dict.py
  15. +9
    -0
      plugins/__init__.py
  16. +1
    -0
      plugins/banwords/.gitignore
  17. +9
    -0
      plugins/banwords/README.md
  18. +250
    -0
      plugins/banwords/WordsSearch.py
  19. +0
    -0
      plugins/banwords/__init__.py
  20. +63
    -0
      plugins/banwords/banwords.py
  21. +3
    -0
      plugins/banwords/banwords.txt.template
  22. +3
    -0
      plugins/banwords/config.json.template
  23. +49
    -0
      plugins/event.py
  24. +0
    -0
      plugins/godcmd/__init__.py
  25. +4
    -0
      plugins/godcmd/config.json.template
  26. +289
    -0
      plugins/godcmd/godcmd.py
  27. +0
    -0
      plugins/hello/__init__.py
  28. +46
    -0
      plugins/hello/hello.py
  29. +3
    -0
      plugins/plugin.py
  30. +171
    -0
      plugins/plugin_manager.py
  31. +0
    -0
      plugins/sdwebui/__init__.py
  32. +70
    -0
      plugins/sdwebui/config.json.template
  33. +69
    -0
      plugins/sdwebui/readme.md
  34. +114
    -0
      plugins/sdwebui/sdwebui.py
  35. +4
    -2
      voice/baidu/baidu_voice.py
  36. +17
    -10
      voice/google/google_voice.py
  37. +12
    -6
      voice/openai/openai_voice.py

+ 1
- 0
.gitignore View File

@@ -7,3 +7,4 @@ config.json
QR.png QR.png
nohup.out nohup.out
tmp tmp
plugins.json

+ 5
- 2
app.py View File

@@ -4,14 +4,17 @@ import config
from channel import channel_factory from channel import channel_factory
from common.log import logger from common.log import logger


from plugins import *
if __name__ == '__main__': if __name__ == '__main__':
try: try:
# load config # load config
config.load_config() config.load_config()


# create channel # create channel
channel = channel_factory.create_channel("wx")
channel_name='wx'
channel = channel_factory.create_channel(channel_name)
if channel_name=='wx':
PluginManager().load_plugins()


# startup channel # startup channel
channel.startup() channel.startup()


+ 3
- 1
bot/baidu/baidu_unit_bot.py View File

@@ -2,6 +2,7 @@


import requests import requests
from bot.bot import Bot from bot.bot import Bot
from bridge.reply import Reply, ReplyType




# Baidu Unit对话接口 (可用, 但能力较弱) # Baidu Unit对话接口 (可用, 但能力较弱)
@@ -14,7 +15,8 @@ class BaiduUnitBot(Bot):
headers = {'content-type': 'application/x-www-form-urlencoded'} headers = {'content-type': 'application/x-www-form-urlencoded'}
response = requests.post(url, data=post_data.encode(), headers=headers) response = requests.post(url, data=post_data.encode(), headers=headers)
if response: if response:
return response.json()['result']['context']['SYS_PRESUMED_HIST'][1]
reply = Reply(ReplyType.TEXT, response.json()['result']['context']['SYS_PRESUMED_HIST'][1])
return reply


def get_token(self): def get_token(self):
access_key = 'YOUR_ACCESS_KEY' access_key = 'YOUR_ACCESS_KEY'


+ 5
- 1
bot/bot.py View File

@@ -3,8 +3,12 @@ Auto-replay chat robot abstract class
""" """




from bridge.context import Context
from bridge.reply import Reply


class Bot(object): class Bot(object):
def reply(self, query, context=None):
def reply(self, query, context : Context =None) -> Reply:
""" """
bot auto-reply content bot auto-reply content
:param req: received message :param req: received message


+ 66
- 48
bot/chatgpt/chat_gpt_bot.py View File

@@ -1,41 +1,42 @@
# encoding:utf-8 # encoding:utf-8


from bot.bot import Bot from bot.bot import Bot
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from config import conf, load_config from config import conf, load_config
from common.log import logger from common.log import logger
from common.expired_dict import ExpiredDict from common.expired_dict import ExpiredDict
import openai import openai
import time import time


if conf().get('expires_in_seconds'):
all_sessions = ExpiredDict(conf().get('expires_in_seconds'))
else:
all_sessions = dict()


# OpenAI对话模型API (可用) # OpenAI对话模型API (可用)
class ChatGPTBot(Bot): class ChatGPTBot(Bot):
def __init__(self): def __init__(self):
openai.api_key = conf().get('open_ai_api_key') openai.api_key = conf().get('open_ai_api_key')
proxy = conf().get('proxy') proxy = conf().get('proxy')
self.sessions = SessionManager()
if proxy: if proxy:
openai.proxy = proxy openai.proxy = proxy


def reply(self, query, context=None): def reply(self, query, context=None):
# acquire reply content # acquire reply content
if not context or not context.get('type') or context.get('type') == 'TEXT':
if context.type == ContextType.TEXT:
logger.info("[OPEN_AI] query={}".format(query)) logger.info("[OPEN_AI] query={}".format(query))
session_id = context.get('session_id') or context.get('from_user_id')
session_id = context['session_id']
reply = None
if query == '#清除记忆': if query == '#清除记忆':
Session.clear_session(session_id)
return '记忆已清除'
self.sessions.clear_session(session_id)
reply = Reply(ReplyType.INFO, '记忆已清除')
elif query == '#清除所有': elif query == '#清除所有':
Session.clear_all_session()
return '所有人记忆已清除'
self.sessions.clear_all_session()
reply = Reply(ReplyType.INFO, '所有人记忆已清除')
elif query == '#更新配置': elif query == '#更新配置':
load_config() load_config()
return '配置已更新'

session = Session.build_session_query(query, session_id)
reply = Reply(ReplyType.INFO, '配置已更新')
if reply:
return reply
session = self.sessions.build_session_query(query, session_id)
logger.debug("[OPEN_AI] session query={}".format(session)) logger.debug("[OPEN_AI] session query={}".format(session))


# if context.get('stream'): # if context.get('stream'):
@@ -44,14 +45,29 @@ class ChatGPTBot(Bot):


reply_content = self.reply_text(session, session_id, 0) reply_content = self.reply_text(session, session_id, 0)
logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"])) logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"]))
if reply_content["completion_tokens"] > 0:
Session.save_session(reply_content["content"], session_id, reply_content["total_tokens"])
return reply_content["content"]

elif context.get('type', None) == 'IMAGE_CREATE':
return self.create_img(query, 0)
if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0:
reply = Reply(ReplyType.ERROR, reply_content['content'])
elif reply_content["completion_tokens"] > 0:
self.sessions.save_session(reply_content["content"], session_id, reply_content["total_tokens"])
reply = Reply(ReplyType.TEXT, reply_content["content"])
else:
reply = Reply(ReplyType.ERROR, reply_content['content'])
logger.debug("[OPEN_AI] reply {} used 0 tokens.".format(reply_content))
return reply

elif context.type == ContextType.IMAGE_CREATE:
ok, retstring = self.create_img(query, 0)
reply = None
if ok:
reply = Reply(ReplyType.IMAGE_URL, retstring)
else:
reply = Reply(ReplyType.ERROR, retstring)
return reply
else:
reply = Reply(ReplyType.ERROR, 'Bot不支持处理{}类型的消息'.format(context.type))
return reply


def reply_text(self, session, session_id, retry_count=0) ->dict:
def reply_text(self, session, session_id, retry_count=0) -> dict:
''' '''
call openai's ChatCompletion to get the answer call openai's ChatCompletion to get the answer
:param session: a conversation session :param session: a conversation session
@@ -70,8 +86,8 @@ class ChatGPTBot(Bot):
presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容 presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
) )
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
return {"total_tokens": response["usage"]["total_tokens"],
"completion_tokens": response["usage"]["completion_tokens"],
return {"total_tokens": response["usage"]["total_tokens"],
"completion_tokens": response["usage"]["completion_tokens"],
"content": response.choices[0]['message']['content']} "content": response.choices[0]['message']['content']}
except openai.error.RateLimitError as e: except openai.error.RateLimitError as e:
# rate limit exception # rate limit exception
@@ -86,15 +102,15 @@ class ChatGPTBot(Bot):
# api connection exception # api connection exception
logger.warn(e) logger.warn(e)
logger.warn("[OPEN_AI] APIConnection failed") logger.warn("[OPEN_AI] APIConnection failed")
return {"completion_tokens": 0, "content":"我连接不到你的网络"}
return {"completion_tokens": 0, "content": "我连接不到你的网络"}
except openai.error.Timeout as e: except openai.error.Timeout as e:
logger.warn(e) logger.warn(e)
logger.warn("[OPEN_AI] Timeout") logger.warn("[OPEN_AI] Timeout")
return {"completion_tokens": 0, "content":"我没有收到你的消息"}
return {"completion_tokens": 0, "content": "我没有收到你的消息"}
except Exception as e: except Exception as e:
# unknown exception # unknown exception
logger.exception(e) logger.exception(e)
Session.clear_session(session_id)
self.sessions.clear_session(session_id)
return {"completion_tokens": 0, "content": "请再问我一次吧"} return {"completion_tokens": 0, "content": "请再问我一次吧"}


def create_img(self, query, retry_count=0): def create_img(self, query, retry_count=0):
@@ -107,7 +123,7 @@ class ChatGPTBot(Bot):
) )
image_url = response['data'][0]['url'] image_url = response['data'][0]['url']
logger.info("[OPEN_AI] image_url={}".format(image_url)) logger.info("[OPEN_AI] image_url={}".format(image_url))
return image_url
return True, image_url
except openai.error.RateLimitError as e: except openai.error.RateLimitError as e:
logger.warn(e) logger.warn(e)
if retry_count < 1: if retry_count < 1:
@@ -115,14 +131,21 @@ class ChatGPTBot(Bot):
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1)) logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
return self.create_img(query, retry_count+1) return self.create_img(query, retry_count+1)
else: else:
return "提问太快啦,请休息一下再问我吧"
return False, "提问太快啦,请休息一下再问我吧"
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
return None
return False, str(e)


class SessionManager(object):
def __init__(self):
if conf().get('expires_in_seconds'):
sessions = ExpiredDict(conf().get('expires_in_seconds'))
else:
sessions = dict()
self.sessions = sessions


class Session(object):
@staticmethod
def build_session_query(query, session_id):
def build_session_query(self, query, session_id):
''' '''
build query with conversation history build query with conversation history
e.g. [ e.g. [
@@ -135,36 +158,33 @@ class Session(object):
:param session_id: session id :param session_id: session id
:return: query content with conversaction :return: query content with conversaction
''' '''
session = all_sessions.get(session_id, [])
session = self.sessions.get(session_id, [])
if len(session) == 0: if len(session) == 0:
system_prompt = conf().get("character_desc", "") system_prompt = conf().get("character_desc", "")
system_item = {'role': 'system', 'content': system_prompt} system_item = {'role': 'system', 'content': system_prompt}
session.append(system_item) session.append(system_item)
all_sessions[session_id] = session
self.sessions[session_id] = session
user_item = {'role': 'user', 'content': query} user_item = {'role': 'user', 'content': query}
session.append(user_item) session.append(user_item)
return session return session


@staticmethod
def save_session(answer, session_id, total_tokens):
def save_session(self, answer, session_id, total_tokens):
max_tokens = conf().get("conversation_max_tokens") max_tokens = conf().get("conversation_max_tokens")
if not max_tokens: if not max_tokens:
# default 3000 # default 3000
max_tokens = 1000 max_tokens = 1000
max_tokens=int(max_tokens)
max_tokens = int(max_tokens)


session = all_sessions.get(session_id)
session = self.sessions.get(session_id)
if session: if session:
# append conversation # append conversation
gpt_item = {'role': 'assistant', 'content': answer} gpt_item = {'role': 'assistant', 'content': answer}
session.append(gpt_item) session.append(gpt_item)


# discard exceed limit conversation # discard exceed limit conversation
Session.discard_exceed_conversation(session, max_tokens, total_tokens)
self.discard_exceed_conversation(session, max_tokens, total_tokens)


@staticmethod
def discard_exceed_conversation(session, max_tokens, total_tokens):
def discard_exceed_conversation(self, session, max_tokens, total_tokens):
dec_tokens = int(total_tokens) dec_tokens = int(total_tokens)
# logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens)) # logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens))
while dec_tokens > max_tokens: while dec_tokens > max_tokens:
@@ -173,13 +193,11 @@ class Session(object):
session.pop(1) session.pop(1)
session.pop(1) session.pop(1)
else: else:
break
break
dec_tokens = dec_tokens - max_tokens dec_tokens = dec_tokens - max_tokens


@staticmethod
def clear_session(session_id):
all_sessions[session_id] = []
def clear_session(self, session_id):
self.sessions[session_id] = []


@staticmethod
def clear_all_session():
all_sessions.clear()
def clear_all_session(self):
self.sessions.clear()

+ 25
- 22
bot/openai/open_ai_bot.py View File

@@ -1,6 +1,8 @@
# encoding:utf-8 # encoding:utf-8


from bot.bot import Bot from bot.bot import Bot
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from config import conf from config import conf
from common.log import logger from common.log import logger
import openai import openai
@@ -13,30 +15,31 @@ class OpenAIBot(Bot):
def __init__(self): def __init__(self):
openai.api_key = conf().get('open_ai_api_key') openai.api_key = conf().get('open_ai_api_key')



def reply(self, query, context=None): def reply(self, query, context=None):
# acquire reply content # acquire reply content
if not context or not context.get('type') or context.get('type') == 'TEXT':
logger.info("[OPEN_AI] query={}".format(query))
from_user_id = context.get('from_user_id') or context.get('session_id')
if query == '#清除记忆':
Session.clear_session(from_user_id)
return '记忆已清除'
elif query == '#清除所有':
Session.clear_all_session()
return '所有人记忆已清除'

new_query = Session.build_session_query(query, from_user_id)
logger.debug("[OPEN_AI] session query={}".format(new_query))

reply_content = self.reply_text(new_query, from_user_id, 0)
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
if reply_content and query:
Session.save_session(query, reply_content, from_user_id)
return reply_content

elif context.get('type', None) == 'IMAGE_CREATE':
return self.create_img(query, 0)
if context and context.type:
if context.type == ContextType.TEXT:
logger.info("[OPEN_AI] query={}".format(query))
from_user_id = context['session_id']
reply = None
if query == '#清除记忆':
Session.clear_session(from_user_id)
reply = Reply(ReplyType.INFO, '记忆已清除')
elif query == '#清除所有':
Session.clear_all_session()
reply = Reply(ReplyType.INFO, '所有人记忆已清除')
else:
new_query = Session.build_session_query(query, from_user_id)
logger.debug("[OPEN_AI] session query={}".format(new_query))

reply_content = self.reply_text(new_query, from_user_id, 0)
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
if reply_content and query:
Session.save_session(query, reply_content, from_user_id)
reply = Reply(ReplyType.TEXT, reply_content)
return reply
elif context.type == ContextType.IMAGE_CREATE:
return self.create_img(query, 0)


def reply_text(self, query, user_id, retry_count=0): def reply_text(self, query, user_id, retry_count=0):
try: try:


+ 33
- 7
bridge/bridge.py View File

@@ -1,16 +1,42 @@
from bridge.context import Context
from bridge.reply import Reply
from common.log import logger
from bot import bot_factory from bot import bot_factory
from common.singleton import singleton
from voice import voice_factory from voice import voice_factory




@singleton
class Bridge(object): class Bridge(object):
def __init__(self): def __init__(self):
pass
self.btype={
"chat": "chatGPT",
"voice_to_text": "openai",
"text_to_voice": "baidu"
}
self.bots={}


def fetch_reply_content(self, query, context):
return bot_factory.create_bot("chatGPT").reply(query, context)
def get_bot(self,typename):
if self.bots.get(typename) is None:
logger.info("create bot {} for {}".format(self.btype[typename],typename))
if typename == "text_to_voice":
self.bots[typename] = voice_factory.create_voice(self.btype[typename])
elif typename == "voice_to_text":
self.bots[typename] = voice_factory.create_voice(self.btype[typename])
elif typename == "chat":
self.bots[typename] = bot_factory.create_bot(self.btype[typename])
return self.bots[typename]
def get_bot_type(self,typename):
return self.btype[typename]


def fetch_voice_to_text(self, voiceFile):
return voice_factory.create_voice("openai").voiceToText(voiceFile)


def fetch_text_to_voice(self, text):
return voice_factory.create_voice("baidu").textToVoice(text)
def fetch_reply_content(self, query, context : Context) -> Reply:
return self.get_bot("chat").reply(query, context)

def fetch_voice_to_text(self, voiceFile) -> Reply:
return self.get_bot("voice_to_text").voiceToText(voiceFile)

def fetch_text_to_voice(self, text) -> Reply:
return self.get_bot("text_to_voice").textToVoice(text)


+ 42
- 0
bridge/context.py View File

@@ -0,0 +1,42 @@
# encoding:utf-8

from enum import Enum

class ContextType (Enum):
TEXT = 1 # 文本消息
VOICE = 2 # 音频消息
IMAGE_CREATE = 3 # 创建图片命令
def __str__(self):
return self.name
class Context:
def __init__(self, type : ContextType = None , content = None, kwargs = dict()):
self.type = type
self.content = content
self.kwargs = kwargs
def __getitem__(self, key):
if key == 'type':
return self.type
elif key == 'content':
return self.content
else:
return self.kwargs[key]

def __setitem__(self, key, value):
if key == 'type':
self.type = value
elif key == 'content':
self.content = value
else:
self.kwargs[key] = value

def __delitem__(self, key):
if key == 'type':
self.type = None
elif key == 'content':
self.content = None
else:
del self.kwargs[key]
def __str__(self):
return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)

+ 22
- 0
bridge/reply.py View File

@@ -0,0 +1,22 @@

# encoding:utf-8

from enum import Enum

class ReplyType(Enum):
TEXT = 1 # 文本
VOICE = 2 # 音频文件
IMAGE = 3 # 图片文件
IMAGE_URL = 4 # 图片URL
INFO = 9
ERROR = 10
def __str__(self):
return self.name

class Reply:
def __init__(self, type : ReplyType = None , content = None):
self.type = type
self.content = content
def __str__(self):
return "Reply(type={}, content={})".format(self.type, self.content)

+ 5
- 3
channel/channel.py View File

@@ -3,6 +3,8 @@ Message sending channel abstract class
""" """


from bridge.bridge import Bridge from bridge.bridge import Bridge
from bridge.context import Context
from bridge.reply import Reply


class Channel(object): class Channel(object):
def startup(self): def startup(self):
@@ -27,11 +29,11 @@ class Channel(object):
""" """
raise NotImplementedError raise NotImplementedError


def build_reply_content(self, query, context=None):
def build_reply_content(self, query, context : Context=None) -> Reply:
return Bridge().fetch_reply_content(query, context) return Bridge().fetch_reply_content(query, context)


def build_voice_to_text(self, voice_file):
def build_voice_to_text(self, voice_file) -> Reply:
return Bridge().fetch_voice_to_text(voice_file) return Bridge().fetch_voice_to_text(voice_file)
def build_text_to_voice(self, text):
def build_text_to_voice(self, text) -> Reply:
return Bridge().fetch_text_to_voice(text) return Bridge().fetch_text_to_voice(text)

+ 152
- 126
channel/wechat/wechat_channel.py View File

@@ -7,16 +7,24 @@ wechat channel
import itchat import itchat
import json import json
from itchat.content import * from itchat.content import *
from bridge.reply import *
from bridge.context import *
from channel.channel import Channel from channel.channel import Channel
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from common.log import logger from common.log import logger
from common.tmp_dir import TmpDir from common.tmp_dir import TmpDir
from config import conf from config import conf
from plugins import *

import requests import requests
import io import io


thread_pool = ThreadPoolExecutor(max_workers=8)


thread_pool = ThreadPoolExecutor(max_workers=8)
def thread_pool_callback(worker):
worker_exception = worker.exception()
if worker_exception:
logger.exception("Worker return exception: {}".format(worker_exception))


@itchat.msg_register(TEXT) @itchat.msg_register(TEXT)
def handler_single_msg(msg): def handler_single_msg(msg):
@@ -47,62 +55,52 @@ class WechatChannel(Channel):
# start message listener # start message listener
itchat.run() itchat.run()


# handle_* 系列函数处理收到的消息后构造context,然后调用handle函数处理context
# context是一个字典,包含了消息的所有信息,包括以下key
# type: 消息类型,包括TEXT、VOICE、IMAGE_CREATE
# content: 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是IMAGE_CREATE类型,content就是图片生成命令
# session_id: 会话id
# isgroup: 是否是群聊
# msg: 原始消息对象
# receiver: 需要回复的对象

def handle_voice(self, msg): def handle_voice(self, msg):
if conf().get('speech_recognition') != True :
if conf().get('speech_recognition') != True:
return return
logger.debug("[WX]receive voice msg: " + msg['FileName']) logger.debug("[WX]receive voice msg: " + msg['FileName'])
thread_pool.submit(self._do_handle_voice, msg)

def _do_handle_voice(self, msg):
from_user_id = msg['FromUserName'] from_user_id = msg['FromUserName']
other_user_id = msg['User']['UserName'] other_user_id = msg['User']['UserName']
if from_user_id == other_user_id: if from_user_id == other_user_id:
file_name = TmpDir().path() + msg['FileName']
msg.download(file_name)
query = super().build_voice_to_text(file_name)
if conf().get('voice_reply_voice'):
self._do_send_voice(query, from_user_id)
else:
self._do_send_text(query, from_user_id)
context = Context(ContextType.VOICE,msg['FileName'])
context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id}
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)


def handle_text(self, msg): def handle_text(self, msg):
logger.debug("[WX]receive text msg: " + json.dumps(msg, ensure_ascii=False)) logger.debug("[WX]receive text msg: " + json.dumps(msg, ensure_ascii=False))
content = msg['Text'] content = msg['Text']
self._handle_single_msg(msg, content)

def _handle_single_msg(self, msg, content):
from_user_id = msg['FromUserName'] from_user_id = msg['FromUserName']
to_user_id = msg['ToUserName'] # 接收人id to_user_id = msg['ToUserName'] # 接收人id
other_user_id = msg['User']['UserName'] # 对手方id other_user_id = msg['User']['UserName'] # 对手方id
match_prefix = self.check_prefix(content, conf().get('single_chat_prefix'))
match_prefix = check_prefix(content, conf().get('single_chat_prefix'))
if "」\n- - - - - - - - - - - - - - -" in content: if "」\n- - - - - - - - - - - - - - -" in content:
logger.debug("[WX]reference query skipped") logger.debug("[WX]reference query skipped")
return return
if from_user_id == other_user_id and match_prefix is not None:
# 好友向自己发送消息
if match_prefix != '':
str_list = content.split(match_prefix, 1)
if len(str_list) == 2:
content = str_list[1].strip()

img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
if img_match_prefix:
content = content.split(img_match_prefix, 1)[1].strip()
thread_pool.submit(self._do_send_img, content, from_user_id)
else :
thread_pool.submit(self._do_send_text, content, from_user_id)
elif to_user_id == other_user_id and match_prefix:
# 自己给好友发送消息
str_list = content.split(match_prefix, 1)
if len(str_list) == 2:
content = str_list[1].strip()
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
if img_match_prefix:
content = content.split(img_match_prefix, 1)[1].strip()
thread_pool.submit(self._do_send_img, content, to_user_id)
else:
thread_pool.submit(self._do_send_text, content, to_user_id)
if match_prefix:
content = content.replace(match_prefix, '', 1).strip()
else:
return
context = Context()
context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id}

img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
if img_match_prefix:
content = content.replace(img_match_prefix, '', 1).strip()
context.type = ContextType.IMAGE_CREATE
else:
context.type = ContextType.TEXT


context.content = content
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)


def handle_group(self, msg): def handle_group(self, msg):
logger.debug("[WX]receive group msg: " + json.dumps(msg, ensure_ascii=False)) logger.debug("[WX]receive group msg: " + json.dumps(msg, ensure_ascii=False))
@@ -122,100 +120,128 @@ class WechatChannel(Channel):
logger.debug("[WX]reference query skipped") logger.debug("[WX]reference query skipped")
return "" return ""
config = conf() config = conf()
match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or self.check_prefix(origin_content, config.get('group_chat_prefix')) \
or self.check_contain(origin_content, config.get('group_chat_keyword'))
if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or self.check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix:
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or check_prefix(origin_content, config.get('group_chat_prefix')) \
or check_contain(origin_content, config.get('group_chat_keyword'))
if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix:
context = Context()
context.kwargs = { 'isgroup': True, 'msg': msg, 'receiver': group_id}
img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
if img_match_prefix: if img_match_prefix:
content = content.split(img_match_prefix, 1)[1].strip()
thread_pool.submit(self._do_send_img, content, group_id)
content = content.replace(img_match_prefix, '', 1).strip()
context.type = ContextType.IMAGE_CREATE
else: else:
thread_pool.submit(self._do_send_group, content, msg)

def send(self, msg, receiver):
itchat.send(msg, toUserName=receiver)
logger.info('[WX] sendMsg={}, receiver={}'.format(msg, receiver))

def _do_send_voice(self, query, reply_user_id):
try:
if not query:
return
context = dict()
context['from_user_id'] = reply_user_id
reply_text = super().build_reply_content(query, context)
if reply_text:
replyFile = super().build_text_to_voice(reply_text)
itchat.send_file(replyFile, toUserName=reply_user_id)
logger.info('[WX] sendFile={}, receiver={}'.format(replyFile, reply_user_id))
except Exception as e:
logger.exception(e)

def _do_send_text(self, query, reply_user_id):
try:
if not query:
return
context = dict()
context['session_id'] = reply_user_id
reply_text = super().build_reply_content(query, context)
if reply_text:
self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id)
except Exception as e:
logger.exception(e)

def _do_send_img(self, query, reply_user_id):
try:
if not query:
return
context = dict()
context['type'] = 'IMAGE_CREATE'
img_url = super().build_reply_content(query, context)
if not img_url:
return

# 图片下载
context.type = ContextType.TEXT
context.content = content

group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
if ('ALL_GROUP' in group_chat_in_one_session or
group_name in group_chat_in_one_session or
check_contain(group_name, group_chat_in_one_session)):
context['session_id'] = group_id
else:
context['session_id'] = msg['ActualUserName']

thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)

# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
def send(self, reply : Reply, receiver):
if reply.type == ReplyType.TEXT:
itchat.send(reply.content, toUserName=receiver)
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
itchat.send(reply.content, toUserName=receiver)
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
elif reply.type == ReplyType.VOICE:
itchat.send_file(reply.content, toUserName=receiver)
logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver))
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
img_url = reply.content
pic_res = requests.get(img_url, stream=True) pic_res = requests.get(img_url, stream=True)
image_storage = io.BytesIO() image_storage = io.BytesIO()
for block in pic_res.iter_content(1024): for block in pic_res.iter_content(1024):
image_storage.write(block) image_storage.write(block)
image_storage.seek(0) image_storage.seek(0)
itchat.send_image(image_storage, toUserName=receiver)
logger.info('[WX] sendImage url=, receiver={}'.format(img_url,receiver))
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
image_storage = reply.content
image_storage.seek(0)
itchat.send_image(image_storage, toUserName=receiver)
logger.info('[WX] sendImage, receiver={}'.format(receiver))

# 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类
def handle(self, context):
reply = Reply()

logger.debug('[WX] ready to handle context: {}'.format(context))
# reply的构建步骤
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply}))
reply = e_context['reply']
if not e_context.is_pass():
logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content))
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE:
reply = super().build_reply_content(context.content, context)
elif context.type == ContextType.VOICE:
msg = context['msg']
file_name = TmpDir().path() + context.content
msg.download(file_name)
reply = super().build_voice_to_text(file_name)
if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO:
context.content = reply.content # 语音转文字后,将文字内容作为新的context
context.type = ContextType.TEXT
reply = super().build_reply_content(context.content, context)
if reply.type == ReplyType.TEXT:
if conf().get('voice_reply_voice'):
reply = super().build_text_to_voice(reply.content)
else:
logger.error('[WX] unknown context type: {}'.format(context.type))
return


# 图片发送
itchat.send_image(image_storage, reply_user_id)
logger.info('[WX] sendImage, receiver={}'.format(reply_user_id))
except Exception as e:
logger.exception(e)

def _do_send_group(self, query, msg):
if not query:
return
context = dict()
group_name = msg['User']['NickName']
group_id = msg['User']['UserName']
group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
if ('ALL_GROUP' in group_chat_in_one_session or \
group_name in group_chat_in_one_session or \
self.check_contain(group_name, group_chat_in_one_session)):
context['session_id'] = group_id
else:
context['session_id'] = msg['ActualUserName']
reply_text = super().build_reply_content(query, context)
if reply_text:
reply_text = '@' + msg['ActualNickName'] + ' ' + reply_text.strip()
self.send(conf().get("group_chat_reply_prefix", "") + reply_text, group_id)


def check_prefix(self, content, prefix_list):
for prefix in prefix_list:
if content.startswith(prefix):
return prefix
return None
logger.debug('[WX] ready to decorate reply: {}'.format(reply))
# reply的包装步骤
if reply and reply.type:
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
reply=e_context['reply']
if not e_context.is_pass() and reply and reply.type:
if reply.type == ReplyType.TEXT:
reply_text = reply.content
if context['isgroup']:
reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip()
reply_text = conf().get("group_chat_reply_prefix", "")+reply_text
else:
reply_text = conf().get("single_chat_reply_prefix", "")+reply_text
reply.content = reply_text
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
reply.content = str(reply.type)+":\n" + reply.content
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
pass
else:
logger.error('[WX] unknown reply type: {}'.format(reply.type))
return

# reply的发送步骤
if reply and reply.type:
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
reply=e_context['reply']
if not e_context.is_pass() and reply and reply.type:
logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver']))
self.send(reply, context['receiver'])


def check_prefix(content, prefix_list):
for prefix in prefix_list:
if content.startswith(prefix):
return prefix
return None




def check_contain(self, content, keyword_list):
if not keyword_list:
return None
for ky in keyword_list:
if content.find(ky) != -1:
return True
def check_contain(content, keyword_list):
if not keyword_list:
return None return None

for ky in keyword_list:
if content.find(ky) != -1:
return True
return None

+ 9
- 10
channel/wechat/wechaty_channel.py View File

@@ -11,6 +11,7 @@ import time
import asyncio import asyncio
import requests import requests
from typing import Optional, Union from typing import Optional, Union
from bridge.context import Context, ContextType
from wechaty_puppet import MessageType, FileBox, ScanStatus # type: ignore from wechaty_puppet import MessageType, FileBox, ScanStatus # type: ignore
from wechaty import Wechaty, Contact from wechaty import Wechaty, Contact
from wechaty.user import Message, Room, MiniProgram, UrlLink from wechaty.user import Message, Room, MiniProgram, UrlLink
@@ -127,9 +128,9 @@ class WechatyChannel(Channel):
try: try:
if not query: if not query:
return return
context = dict()
context = Context(ContextType.TEXT, query)
context['session_id'] = reply_user_id context['session_id'] = reply_user_id
reply_text = super().build_reply_content(query, context)
reply_text = super().build_reply_content(query, context).content
if reply_text: if reply_text:
await self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id) await self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id)
except Exception as e: except Exception as e:
@@ -139,9 +140,8 @@ class WechatyChannel(Channel):
try: try:
if not query: if not query:
return return
context = dict()
context['type'] = 'IMAGE_CREATE'
img_url = super().build_reply_content(query, context)
context = Context(ContextType.IMAGE_CREATE, query)
img_url = super().build_reply_content(query, context).content
if not img_url: if not img_url:
return return
# 图片下载 # 图片下载
@@ -162,7 +162,7 @@ class WechatyChannel(Channel):
async def _do_send_group(self, query, group_id, group_name, group_user_id, group_user_name): async def _do_send_group(self, query, group_id, group_name, group_user_id, group_user_name):
if not query: if not query:
return return
context = dict()
context = Context(ContextType.TEXT, query)
group_chat_in_one_session = conf().get('group_chat_in_one_session', []) group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
if ('ALL_GROUP' in group_chat_in_one_session or \ if ('ALL_GROUP' in group_chat_in_one_session or \
group_name in group_chat_in_one_session or \ group_name in group_chat_in_one_session or \
@@ -170,7 +170,7 @@ class WechatyChannel(Channel):
context['session_id'] = str(group_id) context['session_id'] = str(group_id)
else: else:
context['session_id'] = str(group_id) + '-' + str(group_user_id) context['session_id'] = str(group_id) + '-' + str(group_user_id)
reply_text = super().build_reply_content(query, context)
reply_text = super().build_reply_content(query, context).content
if reply_text: if reply_text:
reply_text = '@' + group_user_name + ' ' + reply_text.strip() reply_text = '@' + group_user_name + ' ' + reply_text.strip()
await self.send_group(conf().get("group_chat_reply_prefix", "") + reply_text, group_id) await self.send_group(conf().get("group_chat_reply_prefix", "") + reply_text, group_id)
@@ -179,9 +179,8 @@ class WechatyChannel(Channel):
try: try:
if not query: if not query:
return return
context = dict()
context['type'] = 'IMAGE_CREATE'
img_url = super().build_reply_content(query, context)
context = Context(ContextType.IMAGE_CREATE, query)
img_url = super().build_reply_content(query, context).content
if not img_url: if not img_url:
return return
# 图片发送 # 图片发送


+ 9
- 0
common/singleton.py View File

@@ -0,0 +1,9 @@
def singleton(cls):
instances = {}

def get_instance(*args, **kwargs):
if cls not in instances:
instances[cls] = cls(*args, **kwargs)
return instances[cls]

return get_instance

+ 65
- 0
common/sorted_dict.py View File

@@ -0,0 +1,65 @@
import heapq


class SortedDict(dict):
def __init__(self, sort_func=lambda k, v: k, init_dict=None, reverse=False):
if init_dict is None:
init_dict = []
if isinstance(init_dict, dict):
init_dict = init_dict.items()
self.sort_func = sort_func
self.sorted_keys = None
self.reverse = reverse
self.heap = []
for k, v in init_dict:
self[k] = v

def __setitem__(self, key, value):
if key in self:
super().__setitem__(key, value)
for i, (priority, k) in enumerate(self.heap):
if k == key:
self.heap[i] = (self.sort_func(key, value), key)
heapq.heapify(self.heap)
break
self.sorted_keys = None
else:
super().__setitem__(key, value)
heapq.heappush(self.heap, (self.sort_func(key, value), key))
self.sorted_keys = None

def __delitem__(self, key):
super().__delitem__(key)
for i, (priority, k) in enumerate(self.heap):
if k == key:
del self.heap[i]
heapq.heapify(self.heap)
break
self.sorted_keys = None

def keys(self):
if self.sorted_keys is None:
self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)]
return self.sorted_keys

def items(self):
if self.sorted_keys is None:
self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)]
sorted_items = [(k, self[k]) for k in self.sorted_keys]
return sorted_items

def _update_heap(self, key):
for i, (priority, k) in enumerate(self.heap):
if k == key:
new_priority = self.sort_func(key, self[key])
if new_priority != priority:
self.heap[i] = (new_priority, key)
heapq.heapify(self.heap)
self.sorted_keys = None
break

def __iter__(self):
return iter(self.keys())

def __repr__(self):
return f'{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})'

+ 9
- 0
plugins/__init__.py View File

@@ -0,0 +1,9 @@
from .plugin_manager import PluginManager
from .event import *
from .plugin import *

instance = PluginManager()

register = instance.register
# load_plugins = instance.load_plugins
# emit_event = instance.emit_event

+ 1
- 0
plugins/banwords/.gitignore View File

@@ -0,0 +1 @@
banwords.txt

+ 9
- 0
plugins/banwords/README.md View File

@@ -0,0 +1,9 @@
### 说明
简易的敏感词插件,暂不支持分词,请自行导入词库到插件文件夹中的`banwords.txt`,每行一个词,一个参考词库是[1](https://github.com/cjh0613/tencent-sensitive-words/blob/main/sensitive_words_lines.txt)。

`config.json`中能够填写默认的处理行为,目前行为有:
- `ignore` : 无视这条消息。
- `replace` : 将消息中的敏感词替换成"*",并回复违规。

### 致谢
搜索功能实现来自https://github.com/toolgood/ToolGood.Words

+ 250
- 0
plugins/banwords/WordsSearch.py View File

@@ -0,0 +1,250 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# ToolGood.Words.WordsSearch.py
# 2020, Lin Zhijun, https://github.com/toolgood/ToolGood.Words
# Licensed under the Apache License 2.0
# 更新日志
# 2020.04.06 第一次提交
# 2020.05.16 修改,支持大于0xffff的字符

__all__ = ['WordsSearch']
__author__ = 'Lin Zhijun'
__date__ = '2020.05.16'

class TrieNode():
def __init__(self):
self.Index = 0
self.Index = 0
self.Layer = 0
self.End = False
self.Char = ''
self.Results = []
self.m_values = {}
self.Failure = None
self.Parent = None

def Add(self,c):
if c in self.m_values :
return self.m_values[c]
node = TrieNode()
node.Parent = self
node.Char = c
self.m_values[c] = node
return node

def SetResults(self,index):
if (self.End == False):
self.End = True
self.Results.append(index)

class TrieNode2():
def __init__(self):
self.End = False
self.Results = []
self.m_values = {}
self.minflag = 0xffff
self.maxflag = 0

def Add(self,c,node3):
if (self.minflag > c):
self.minflag = c
if (self.maxflag < c):
self.maxflag = c
self.m_values[c] = node3

def SetResults(self,index):
if (self.End == False) :
self.End = True
if (index in self.Results )==False :
self.Results.append(index)

def HasKey(self,c):
return c in self.m_values
def TryGetValue(self,c):
if (self.minflag <= c and self.maxflag >= c):
if c in self.m_values:
return self.m_values[c]
return None


class WordsSearch():
def __init__(self):
self._first = {}
self._keywords = []
self._indexs=[]
def SetKeywords(self,keywords):
self._keywords = keywords
self._indexs=[]
for i in range(len(keywords)):
self._indexs.append(i)

root = TrieNode()
allNodeLayer={}

for i in range(len(self._keywords)): # for (i = 0; i < _keywords.length; i++)
p = self._keywords[i]
nd = root
for j in range(len(p)): # for (j = 0; j < p.length; j++)
nd = nd.Add(ord(p[j]))
if (nd.Layer == 0):
nd.Layer = j + 1
if nd.Layer in allNodeLayer:
allNodeLayer[nd.Layer].append(nd)
else:
allNodeLayer[nd.Layer]=[]
allNodeLayer[nd.Layer].append(nd)
nd.SetResults(i)


allNode = []
allNode.append(root)
for key in allNodeLayer.keys():
for nd in allNodeLayer[key]:
allNode.append(nd)
allNodeLayer=None

for i in range(len(allNode)): # for (i = 0; i < allNode.length; i++)
if i==0 :
continue
nd=allNode[i]
nd.Index = i
r = nd.Parent.Failure
c = nd.Char
while (r != None and (c in r.m_values)==False):
r = r.Failure
if (r == None):
nd.Failure = root
else:
nd.Failure = r.m_values[c]
for key2 in nd.Failure.Results :
nd.SetResults(key2)
root.Failure = root

allNode2 = []
for i in range(len(allNode)): # for (i = 0; i < allNode.length; i++)
allNode2.append( TrieNode2())
for i in range(len(allNode2)): # for (i = 0; i < allNode2.length; i++)
oldNode = allNode[i]
newNode = allNode2[i]

for key in oldNode.m_values :
index = oldNode.m_values[key].Index
newNode.Add(key, allNode2[index])
for index in range(len(oldNode.Results)): # for (index = 0; index < oldNode.Results.length; index++)
item = oldNode.Results[index]
newNode.SetResults(item)
oldNode=oldNode.Failure
while oldNode != root:
for key in oldNode.m_values :
if (newNode.HasKey(key) == False):
index = oldNode.m_values[key].Index
newNode.Add(key, allNode2[index])
for index in range(len(oldNode.Results)):
item = oldNode.Results[index]
newNode.SetResults(item)
oldNode=oldNode.Failure
allNode = None
root = None

# first = []
# for index in range(65535):# for (index = 0; index < 0xffff; index++)
# first.append(None)
# for key in allNode2[0].m_values :
# first[key] = allNode2[0].m_values[key]
self._first = allNode2[0]

def FindFirst(self,text):
ptr = None
for index in range(len(text)): # for (index = 0; index < text.length; index++)
t =ord(text[index]) # text.charCodeAt(index)
tn = None
if (ptr == None):
tn = self._first.TryGetValue(t)
else:
tn = ptr.TryGetValue(t)
if (tn==None):
tn = self._first.TryGetValue(t)
if (tn != None):
if (tn.End):
item = tn.Results[0]
keyword = self._keywords[item]
return { "Keyword": keyword, "Success": True, "End": index, "Start": index + 1 - len(keyword), "Index": self._indexs[item] }
ptr = tn
return None

def FindAll(self,text):
ptr = None
list = []

for index in range(len(text)): # for (index = 0; index < text.length; index++)
t =ord(text[index]) # text.charCodeAt(index)
tn = None
if (ptr == None):
tn = self._first.TryGetValue(t)
else:
tn = ptr.TryGetValue(t)
if (tn==None):
tn = self._first.TryGetValue(t)
if (tn != None):
if (tn.End):
for j in range(len(tn.Results)): # for (j = 0; j < tn.Results.length; j++)
item = tn.Results[j]
keyword = self._keywords[item]
list.append({ "Keyword": keyword, "Success": True, "End": index, "Start": index + 1 - len(keyword), "Index": self._indexs[item] })
ptr = tn
return list


def ContainsAny(self,text):
ptr = None
for index in range(len(text)): # for (index = 0; index < text.length; index++)
t =ord(text[index]) # text.charCodeAt(index)
tn = None
if (ptr == None):
tn = self._first.TryGetValue(t)
else:
tn = ptr.TryGetValue(t)
if (tn==None):
tn = self._first.TryGetValue(t)
if (tn != None):
if (tn.End):
return True
ptr = tn
return False
def Replace(self,text, replaceChar = '*'):
result = list(text)

ptr = None
for i in range(len(text)): # for (i = 0; i < text.length; i++)
t =ord(text[i]) # text.charCodeAt(index)
tn = None
if (ptr == None):
tn = self._first.TryGetValue(t)
else:
tn = ptr.TryGetValue(t)
if (tn==None):
tn = self._first.TryGetValue(t)
if (tn != None):
if (tn.End):
maxLength = len( self._keywords[tn.Results[0]])
start = i + 1 - maxLength
for j in range(start,i+1): # for (j = start; j <= i; j++)
result[j] = replaceChar
ptr = tn
return ''.join(result)

+ 0
- 0
plugins/banwords/__init__.py View File


+ 63
- 0
plugins/banwords/banwords.py View File

@@ -0,0 +1,63 @@
# encoding:utf-8

import json
import os
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
import plugins
from plugins import *
from common.log import logger
from .WordsSearch import WordsSearch


@plugins.register(name="Banwords", desc="判断消息中是否有敏感词、决定是否回复。", version="1.0", author="lanvent", desire_priority= 100)
class Banwords(Plugin):
def __init__(self):
super().__init__()
try:
curdir=os.path.dirname(__file__)
config_path=os.path.join(curdir,"config.json")
conf=None
if not os.path.exists(config_path):
conf={"action":"ignore"}
with open(config_path,"w") as f:
json.dump(conf,f,indent=4)
else:
with open(config_path,"r") as f:
conf=json.load(f)
self.searchr = WordsSearch()
self.action = conf["action"]
banwords_path = os.path.join(curdir,"banwords.txt")
with open(banwords_path, 'r', encoding='utf-8') as f:
words=[]
for line in f:
word = line.strip()
if word:
words.append(word)
self.searchr.SetKeywords(words)
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
logger.info("[Banwords] inited")
except Exception as e:
logger.error("Banwords init failed: %s" % e)


def on_handle_context(self, e_context: EventContext):

if e_context['context'].type not in [ContextType.TEXT,ContextType.IMAGE_CREATE]:
return
content = e_context['context'].content
logger.debug("[Banwords] on_handle_context. content: %s" % content)
if self.action == "ignore":
f = self.searchr.FindFirst(content)
if f:
logger.info("Banwords: %s" % f["Keyword"])
e_context.action = EventAction.BREAK_PASS
return
elif self.action == "replace":
if self.searchr.ContainsAny(content):
reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n"+self.searchr.Replace(content))
e_context['reply'] = reply
e_context.action = EventAction.BREAK_PASS
return

+ 3
- 0
plugins/banwords/banwords.txt.template View File

@@ -0,0 +1,3 @@
nipples
pennis
法轮功

+ 3
- 0
plugins/banwords/config.json.template View File

@@ -0,0 +1,3 @@
{
"action": "ignore"
}

+ 49
- 0
plugins/event.py View File

@@ -0,0 +1,49 @@
# encoding:utf-8

from enum import Enum


class Event(Enum):
# ON_RECEIVE_MESSAGE = 1 # 收到消息

ON_HANDLE_CONTEXT = 2 # 处理消息前
"""
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复,初始为空 }
"""

ON_DECORATE_REPLY = 3 # 得到回复后准备装饰
"""
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 }
"""

ON_SEND_REPLY = 4 # 发送回复前
"""
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 }
"""

# AFTER_SEND_REPLY = 5 # 发送回复后


class EventAction(Enum):
CONTINUE = 1 # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑
BREAK = 2 # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑
BREAK_PASS = 3 # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑


class EventContext:
def __init__(self, event, econtext=dict()):
self.event = event
self.econtext = econtext
self.action = EventAction.CONTINUE

def __getitem__(self, key):
return self.econtext[key]

def __setitem__(self, key, value):
self.econtext[key] = value

def __delitem__(self, key):
del self.econtext[key]

def is_pass(self):
return self.action == EventAction.BREAK_PASS

+ 0
- 0
plugins/godcmd/__init__.py View File


+ 4
- 0
plugins/godcmd/config.json.template View File

@@ -0,0 +1,4 @@
{
"password": "",
"admin_users": []
}

+ 289
- 0
plugins/godcmd/godcmd.py View File

@@ -0,0 +1,289 @@
# encoding:utf-8

import json
import os
import traceback
from typing import Tuple
from bridge.bridge import Bridge
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from config import load_config
import plugins
from plugins import *
from common.log import logger

# 定义指令集
COMMANDS = {
"help": {
"alias": ["help", "帮助"],
"desc": "打印指令集合",
},
"auth": {
"alias": ["auth", "认证"],
"args": ["口令"],
"desc": "管理员认证",
},
# "id": {
# "alias": ["id", "用户"],
# "desc": "获取用户id", #目前无实际意义
# },
"reset": {
"alias": ["reset", "重置会话"],
"desc": "重置会话",
},
}

ADMIN_COMMANDS = {
"resume": {
"alias": ["resume", "恢复服务"],
"desc": "恢复服务",
},
"stop": {
"alias": ["stop", "暂停服务"],
"desc": "暂停服务",
},
"reconf": {
"alias": ["reconf", "重载配置"],
"desc": "重载配置(不包含插件配置)",
},
"resetall": {
"alias": ["resetall", "重置所有会话"],
"desc": "重置所有会话",
},
"scanp": {
"alias": ["scanp", "扫描插件"],
"desc": "扫描插件目录是否有新插件",
},
"plist": {
"alias": ["plist", "插件"],
"desc": "打印当前插件列表",
},
"setpri": {
"alias": ["setpri", "设置插件优先级"],
"args": ["插件名", "优先级"],
"desc": "设置指定插件的优先级,越大越优先",
},
"reloadp": {
"alias": ["reloadp", "重载插件"],
"args": ["插件名"],
"desc": "重载指定插件配置",
},
"enablep": {
"alias": ["enablep", "启用插件"],
"args": ["插件名"],
"desc": "启用指定插件",
},
"disablep": {
"alias": ["disablep", "禁用插件"],
"args": ["插件名"],
"desc": "禁用指定插件",
},
"debug": {
"alias": ["debug", "调试模式", "DEBUG"],
"desc": "开启机器调试日志",
},
}
# 定义帮助函数
def get_help_text(isadmin, isgroup):
help_text = "可用指令:\n"
for cmd, info in COMMANDS.items():
if cmd=="auth" and (isadmin or isgroup): # 群聊不可认证
continue

alias=["#"+a for a in info['alias']]
help_text += f"{','.join(alias)} "
if 'args' in info:
args=["{"+a+"}" for a in info['args']]
help_text += f"{' '.join(args)} "
help_text += f": {info['desc']}\n"
if ADMIN_COMMANDS and isadmin:
help_text += "\n管理员指令:\n"
for cmd, info in ADMIN_COMMANDS.items():
alias=["#"+a for a in info['alias']]
help_text += f"{','.join(alias)} "
help_text += f": {info['desc']}\n"
return help_text

@plugins.register(name="Godcmd", desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证", version="1.0", author="lanvent", desire_priority= 999)
class Godcmd(Plugin):

def __init__(self):
super().__init__()

curdir=os.path.dirname(__file__)
config_path=os.path.join(curdir,"config.json")
gconf=None
if not os.path.exists(config_path):
gconf={"password":"","admin_users":[]}
with open(config_path,"w") as f:
json.dump(gconf,f,indent=4)
else:
with open(config_path,"r") as f:
gconf=json.load(f)
self.password = gconf["password"]
self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证 TODO: 用户名每次都会变,目前不可用
self.isrunning = True # 机器人是否运行中

self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
logger.info("[Godcmd] inited")

def on_handle_context(self, e_context: EventContext):
context_type = e_context['context'].type
if context_type != ContextType.TEXT:
if not self.isrunning:
e_context.action = EventAction.BREAK_PASS
return
content = e_context['context'].content
logger.debug("[Godcmd] on_handle_context. content: %s" % content)
if content.startswith("#"):
# msg = e_context['context']['msg']
user = e_context['context']['receiver']
session_id = e_context['context']['session_id']
isgroup = e_context['context']['isgroup']
bottype = Bridge().get_bot_type("chat")
bot = Bridge().get_bot("chat")
# 将命令和参数分割
command_parts = content[1:].split(" ")
cmd = command_parts[0]
args = command_parts[1:]
isadmin=False
if user in self.admin_users:
isadmin=True
ok=False
result="string"
if any(cmd in info['alias'] for info in COMMANDS.values()):
cmd = next(c for c, info in COMMANDS.items() if cmd in info['alias'])
if cmd == "auth":
ok, result = self.authenticate(user, args, isadmin, isgroup)
elif cmd == "help":
ok, result = True, get_help_text(isadmin, isgroup)
elif cmd == "id":
ok, result = True, f"用户id=\n{user}"
elif cmd == "reset":
if bottype == "chatGPT":
bot.sessions.clear_session(session_id)
ok, result = True, "会话已重置"
else:
ok, result = False, "当前对话机器人不支持重置会话"
logger.debug("[Godcmd] command: %s by %s" % (cmd, user))
elif any(cmd in info['alias'] for info in ADMIN_COMMANDS.values()):
if isadmin:
if isgroup:
ok, result = False, "群聊不可执行管理员指令"
else:
cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info['alias'])
if cmd == "stop":
self.isrunning = False
ok, result = True, "服务已暂停"
elif cmd == "resume":
self.isrunning = True
ok, result = True, "服务已恢复"
elif cmd == "reconf":
load_config()
ok, result = True, "配置已重载"
elif cmd == "resetall":
if bottype == "chatGPT":
bot.sessions.clear_all_session()
ok, result = True, "重置所有会话成功"
else:
ok, result = False, "当前对话机器人不支持重置会话"
elif cmd == "debug":
logger.setLevel('DEBUG')
ok, result = True, "DEBUG模式已开启"
elif cmd == "plist":
plugins = PluginManager().list_plugins()
ok = True
result = "插件列表:\n"
for name,plugincls in plugins.items():
result += f"{plugincls.name}_v{plugincls.version} {plugincls.priority} - "
if plugincls.enabled:
result += "已启用\n"
else:
result += "未启用\n"
elif cmd == "scanp":
new_plugins = PluginManager().scan_plugins()
ok, result = True, "插件扫描完成"
PluginManager().activate_plugins()
if len(new_plugins) >0 :
result += "\n发现新插件:\n"
result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins])
else :
result +=", 未发现新插件"
elif cmd == "setpri":
if len(args) != 2:
ok, result = False, "请提供插件名和优先级"
else:
ok = PluginManager().set_plugin_priority(args[0], int(args[1]))
if ok:
result = "插件" + args[0] + "优先级已设置为" + args[1]
else:
result = "插件不存在"
elif cmd == "reloadp":
if len(args) != 1:
ok, result = False, "请提供插件名"
else:
ok = PluginManager().reload_plugin(args[0])
if ok:
result = "插件配置已重载"
else:
result = "插件不存在"
elif cmd == "enablep":
if len(args) != 1:
ok, result = False, "请提供插件名"
else:
ok = PluginManager().enable_plugin(args[0])
if ok:
result = "插件已启用"
else:
result = "插件不存在"
elif cmd == "disablep":
if len(args) != 1:
ok, result = False, "请提供插件名"
else:
ok = PluginManager().disable_plugin(args[0])
if ok:
result = "插件已禁用"
else:
result = "插件不存在"

logger.debug("[Godcmd] admin command: %s by %s" % (cmd, user))
else:
ok, result = False, "需要管理员权限才能执行该指令"
else:
ok, result = False, f"未知指令:{cmd}\n查看指令列表请输入#help \n"
reply = Reply()
if ok:
reply.type = ReplyType.INFO
else:
reply.type = ReplyType.ERROR
reply.content = result
e_context['reply'] = reply

e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
elif not self.isrunning:
e_context.action = EventAction.BREAK_PASS
def authenticate(self, userid, args, isadmin, isgroup) -> Tuple[bool,str] :
if isgroup:
return False,"请勿在群聊中认证"
if isadmin:
return False,"管理员账号无需认证"
if len(self.password) == 0:
return False,"未设置口令,无法认证"
if len(args) != 1:
return False,"请提供口令"
password = args[0]
if password == self.password:
self.admin_users.append(userid)
return True,"认证成功"
else:
return False,"认证失败"


+ 0
- 0
plugins/hello/__init__.py View File


+ 46
- 0
plugins/hello/hello.py View File

@@ -0,0 +1,46 @@
# encoding:utf-8

from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
import plugins
from plugins import *
from common.log import logger


@plugins.register(name="Hello", desc="A simple plugin that says hello", version="0.1", author="lanvent", desire_priority= -1)
class Hello(Plugin):
def __init__(self):
super().__init__()
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
logger.info("[Hello] inited")

def on_handle_context(self, e_context: EventContext):

if e_context['context'].type != ContextType.TEXT:
return
content = e_context['context'].content
logger.debug("[Hello] on_handle_context. content: %s" % content)
if content == "Hello":
reply = Reply()
reply.type = ReplyType.TEXT
msg = e_context['context']['msg']
if e_context['context']['isgroup']:
reply.content = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group")
else:
reply.content = "Hello, " + msg['User'].get('NickName', "My friend")
e_context['reply'] = reply
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑

if content == "Hi":
reply = Reply()
reply.type = ReplyType.TEXT
reply.content = "Hi"
e_context['reply'] = reply
e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑,一般会覆写reply

if content == "End":
# 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World"
e_context['context'].type = "IMAGE_CREATE"
content = "The World"
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑

+ 3
- 0
plugins/plugin.py View File

@@ -0,0 +1,3 @@
class Plugin:
def __init__(self):
self.handlers = {}

+ 171
- 0
plugins/plugin_manager.py View File

@@ -0,0 +1,171 @@
# encoding:utf-8

import importlib
import json
import os
from common.singleton import singleton
from common.sorted_dict import SortedDict
from .event import *
from .plugin import *
from common.log import logger


@singleton
class PluginManager:
def __init__(self):
self.plugins = SortedDict(lambda k,v: v.priority,reverse=True)
self.listening_plugins = {}
self.instances = {}
self.pconf = {}

def register(self, name: str, desc: str, version: str, author: str, desire_priority: int = 0):
def wrapper(plugincls):
plugincls.name = name
plugincls.desc = desc
plugincls.version = version
plugincls.author = author
plugincls.priority = desire_priority
plugincls.enabled = True
self.plugins[name.upper()] = plugincls
logger.info("Plugin %s_v%s registered" % (name, version))
return plugincls
return wrapper

def save_config(self):
with open("plugins/plugins.json", "w", encoding="utf-8") as f:
json.dump(self.pconf, f, indent=4, ensure_ascii=False)

def load_config(self):
logger.info("Loading plugins config...")

modified = False
if os.path.exists("plugins/plugins.json"):
with open("plugins/plugins.json", "r", encoding="utf-8") as f:
pconf = json.load(f)
pconf['plugins'] = SortedDict(lambda k,v: v["priority"],pconf['plugins'],reverse=True)
else:
modified = True
pconf = {"plugins": SortedDict(lambda k,v: v["priority"],reverse=True)}
self.pconf = pconf
if modified:
self.save_config()
return pconf

def scan_plugins(self):
logger.info("Scaning plugins ...")
plugins_dir = "plugins"
for plugin_name in os.listdir(plugins_dir):
plugin_path = os.path.join(plugins_dir, plugin_name)
if os.path.isdir(plugin_path):
# 判断插件是否包含同名.py文件
main_module_path = os.path.join(plugin_path, plugin_name+".py")
if os.path.isfile(main_module_path):
# 导入插件
import_path = "{}.{}.{}".format(plugins_dir, plugin_name, plugin_name)
main_module = importlib.import_module(import_path)
pconf = self.pconf
new_plugins = []
modified = False
for name, plugincls in self.plugins.items():
rawname = plugincls.name
if rawname not in pconf["plugins"]:
new_plugins.append(plugincls)
modified = True
logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name)
pconf["plugins"][rawname] = {"enabled": plugincls.enabled, "priority": plugincls.priority}
else:
self.plugins[name].enabled = pconf["plugins"][rawname]["enabled"]
self.plugins[name].priority = pconf["plugins"][rawname]["priority"]
self.plugins._update_heap(name) # 更新下plugins中的顺序
if modified:
self.save_config()
return new_plugins

def refresh_order(self):
for event in self.listening_plugins.keys():
self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True)

def activate_plugins(self): # 生成新开启的插件实例
for name, plugincls in self.plugins.items():
if plugincls.enabled:
if name not in self.instances:
instance = plugincls()
self.instances[name] = instance
for event in instance.handlers:
if event not in self.listening_plugins:
self.listening_plugins[event] = []
self.listening_plugins[event].append(name)
self.refresh_order()

def reload_plugin(self, name:str):
name = name.upper()
if name in self.instances:
for event in self.listening_plugins:
if name in self.listening_plugins[event]:
self.listening_plugins[event].remove(name)
del self.instances[name]
self.activate_plugins()
return True
return False
def load_plugins(self):
self.load_config()
self.scan_plugins()
pconf = self.pconf
logger.debug("plugins.json config={}".format(pconf))
for name,plugin in pconf["plugins"].items():
if name.upper() not in self.plugins:
logger.error("Plugin %s not found, but found in plugins.json" % name)
self.activate_plugins()

def emit_event(self, e_context: EventContext, *args, **kwargs):
if e_context.event in self.listening_plugins:
for name in self.listening_plugins[e_context.event]:
if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE:
logger.debug("Plugin %s triggered by event %s" % (name,e_context.event))
instance = self.instances[name]
instance.handlers[e_context.event](e_context, *args, **kwargs)
return e_context

def set_plugin_priority(self, name:str, priority:int):
name = name.upper()
if name not in self.plugins:
return False
if self.plugins[name].priority == priority:
return True
self.plugins[name].priority = priority
self.plugins._update_heap(name)
rawname = self.plugins[name].name
self.pconf["plugins"][rawname]["priority"] = priority
self.pconf["plugins"]._update_heap(rawname)
self.save_config()
self.refresh_order()
return True

def enable_plugin(self, name:str):
name = name.upper()
if name not in self.plugins:
return False
if not self.plugins[name].enabled :
self.plugins[name].enabled = True
rawname = self.plugins[name].name
self.pconf["plugins"][rawname]["enabled"] = True
self.save_config()
self.activate_plugins()
return True
return True
def disable_plugin(self, name:str):
name = name.upper()
if name not in self.plugins:
return False
if self.plugins[name].enabled :
self.plugins[name].enabled = False
rawname = self.plugins[name].name
self.pconf["plugins"][rawname]["enabled"] = False
self.save_config()
return True
return True
def list_plugins(self):
return self.plugins

+ 0
- 0
plugins/sdwebui/__init__.py View File


+ 70
- 0
plugins/sdwebui/config.json.template View File

@@ -0,0 +1,70 @@
{
"start":{
"host" : "127.0.0.1",
"port" : 7860
},
"defaults": {
"params": {
"sampler_name": "DPM++ 2M Karras",
"steps": 20,
"width": 512,
"height": 512,
"cfg_scale": 7,
"prompt":"masterpiece, best quality",
"negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)",
"enable_hr": false,
"hr_scale": 2,
"hr_upscaler": "Latent",
"hr_second_pass_steps": 15,
"denoising_strength": 0.7
},
"options": {
"sd_model_checkpoint": "perfectWorld_v2Baked"
}
},
"rules": [
{
"keywords": [
"横版",
"壁纸"
],
"params": {
"width": 640,
"height": 384
},
"desc": "分辨率会变成640x384"
},
{
"keywords": [
"竖版"
],
"params": {
"width": 384,
"height": 640
}
},
{
"keywords": [
"高清"
],
"params": {
"enable_hr": true,
"hr_scale": 1.6
},
"desc": "出图分辨率长宽都会提高1.6倍"
},
{
"keywords": [
"二次元"
],
"params": {
"negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)",
"prompt": "masterpiece, best quality"
},
"options": {
"sd_model_checkpoint": "meinamix_meinaV8"
},
"desc": "使用二次元风格模型出图"
}
]
}

+ 69
- 0
plugins/sdwebui/readme.md View File

@@ -0,0 +1,69 @@
### 插件描述
本插件用于将画图请求转发给stable diffusion webui。

### 环境要求
使用前先安装stable diffusion webui,并在它的启动参数中添加 "--api"。
具体信息,请参考[文章](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API)。

请**安装**本插件的依赖包```webuiapi```
```
```pip install webuiapi```
```
### 使用说明
请将`config.json.template`复制为`config.json`,并修改其中的参数和规则。

#### 画图请求格式
用户的画图请求格式为:
```
<画图触发词><关键词1> <关键词2> ... <关键词n>:<prompt>
```
- 本插件会对画图触发词后的关键词进行逐个匹配,如果触发了规则中的关键词,则会在画图请求中重载对应的参数。
- 规则的匹配顺序参考`config.json`中的顺序,每个关键词最多被匹配到1次,如果多个关键词触发了重复的参数,重复参数以最后一个关键词为准:
- 关键词中包含`help`或`帮助`,会打印出帮助文档。
第一个"**:**"号之后的内容会作为附加的**prompt**,接在最终的prompt后

例如: 画横版 高清 二次元:cat
会触发三个关键词 "横版", "高清", "二次元",prompt为"cat"
若默认参数是:
```
"width": 512,
"height": 512,
"enable_hr": false,
"prompt": "8k"
"negative_prompt": "nsfw",
"sd_model_checkpoint": "perfectWorld_v2Baked"
```

"横版"触发的规则参数为:
```
"width": 640,
"height": 384,
```
"高清"触发的规则参数为:
```
"enable_hr": true,
"hr_scale": 1.6,
```
"二次元"触发的规则参数为:
```
"negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)",
"steps": 20,
"prompt": "masterpiece, best quality",

"sd_model_checkpoint": "meinamix_meinaV8"
```
最后将第一个":"后的内容cat连接在prompt后,得到最终参数为:
```
"width": 640,
"height": 384,
"enable_hr": true,
"hr_scale": 1.6,
"negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)",
"steps": 20,
"prompt": "masterpiece, best quality, cat",
"sd_model_checkpoint": "meinamix_meinaV8"
```
PS: 参数分为两部分:
- 一部分是params,为画画的参数;参数名**必须**与webuiapi包中[txt2img api](https://github.com/mix1009/sdwebuiapi/blob/fb2054e149c0a4e25125c0cd7e7dca06bda839d4/webuiapi/webuiapi.py#L163)的参数名一致
- 另一部分是options,指sdwebui的设置,使用的模型和vae需要写在里面。它和http://127.0.0.1:7860/sdapi/v1/options所返回的键一致。

+ 114
- 0
plugins/sdwebui/sdwebui.py View File

@@ -0,0 +1,114 @@
# encoding:utf-8

import json
import os
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from config import conf
import plugins
from plugins import *
from common.log import logger
import webuiapi
import io


@plugins.register(name="sdwebui", desc="利用stable-diffusion webui来画图", version="2.0", author="lanvent")
class SDWebUI(Plugin):
def __init__(self):
super().__init__()
curdir = os.path.dirname(__file__)
config_path = os.path.join(curdir, "config.json")
try:
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
self.rules = config["rules"]
defaults = config["defaults"]
self.default_params = defaults["params"]
self.default_options = defaults["options"]
self.start_args = config["start"]
self.api = webuiapi.WebUIApi(**self.start_args)
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
logger.info("[SD] inited")
except FileNotFoundError:
logger.error(f"[SD] init failed, {config_path} not found")
except Exception as e:
logger.error("[SD] init failed, exception: %s" % e)
def on_handle_context(self, e_context: EventContext):

if e_context['context'].type != ContextType.IMAGE_CREATE:
return

logger.debug("[SD] on_handle_context. content: %s" %e_context['context'].content)

logger.info("[SD] image_query={}".format(e_context['context'].content))
reply = Reply()
try:
content = e_context['context'].content[:]
# 解析用户输入 如"横版 高清 二次元:cat"
if ":" in content:
keywords, prompt = content.split(":", 1)
else:
keywords = content
prompt = ""

keywords = keywords.split()

if "help" in keywords or "帮助" in keywords:
reply.type = ReplyType.INFO
reply.content = self.get_help_text()
else:
rule_params = {}
rule_options = {}
for keyword in keywords:
matched = False
for rule in self.rules:
if keyword in rule["keywords"]:
for key in rule["params"]:
rule_params[key] = rule["params"][key]
if "options" in rule:
for key in rule["options"]:
rule_options[key] = rule["options"][key]
matched = True
break # 一个关键词只匹配一个规则
if not matched:
logger.warning("[SD] keyword not matched: %s" % keyword)
params = {**self.default_params, **rule_params}
options = {**self.default_options, **rule_options}
params["prompt"] = params.get("prompt", "")+f", {prompt}"
if len(options) > 0:
logger.info("[SD] cover options={}".format(options))
self.api.set_options(options)
logger.info("[SD] params={}".format(params))
result = self.api.txt2img(
**params
)
reply.type = ReplyType.IMAGE
b_img = io.BytesIO()
result.image.save(b_img, format="PNG")
reply.content = b_img
e_context.action = EventAction.BREAK_PASS # 事件结束后,跳过处理context的默认逻辑
except Exception as e:
reply.type = ReplyType.ERROR
reply.content = "[SD] "+str(e)
logger.error("[SD] exception: %s" % e)
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
finally:
e_context['reply'] = reply

def get_help_text(self):
if not conf().get('image_create_prefix'):
return "画图功能未启用"
else:
trigger = conf()['image_create_prefix'][0]
help_text = f"请使用<{trigger}[关键词1] [关键词2]...:提示语>的格式作画,如\"{trigger}横版 高清:cat\"\n"
help_text += "目前可用关键词:\n"
for rule in self.rules:
keywords = [f"[{keyword}]" for keyword in rule['keywords']]
help_text += f"{','.join(keywords)}"
if "desc" in rule:
help_text += f"-{rule['desc']}\n"
else:
help_text += "\n"
return help_text

+ 4
- 2
voice/baidu/baidu_voice.py View File

@@ -4,6 +4,7 @@ baidu voice service
""" """
import time import time
from aip import AipSpeech from aip import AipSpeech
from bridge.reply import Reply, ReplyType
from common.log import logger from common.log import logger
from common.tmp_dir import TmpDir from common.tmp_dir import TmpDir
from voice.voice import Voice from voice.voice import Voice
@@ -30,7 +31,8 @@ class BaiduVoice(Voice):
with open(fileName, 'wb') as f: with open(fileName, 'wb') as f:
f.write(result) f.write(result)
logger.info('[Baidu] textToVoice text={} voice file name={}'.format(text, fileName)) logger.info('[Baidu] textToVoice text={} voice file name={}'.format(text, fileName))
return fileName
reply = Reply(ReplyType.VOICE, fileName)
else: else:
logger.error('[Baidu] textToVoice error={}'.format(result)) logger.error('[Baidu] textToVoice error={}'.format(result))
return None
reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
return reply

+ 17
- 10
voice/google/google_voice.py View File

@@ -6,6 +6,7 @@ google voice service
import pathlib import pathlib
import subprocess import subprocess
import time import time
from bridge.reply import Reply, ReplyType
import speech_recognition import speech_recognition
import pyttsx3 import pyttsx3
from common.log import logger from common.log import logger
@@ -36,16 +37,22 @@ class GoogleVoice(Voice):
text = self.recognizer.recognize_google(audio, language='zh-CN') text = self.recognizer.recognize_google(audio, language='zh-CN')
logger.info( logger.info(
'[Google] voiceToText text={} voice file name={}'.format(text, voice_file)) '[Google] voiceToText text={} voice file name={}'.format(text, voice_file))
return text
reply = Reply(ReplyType.TEXT, text)
except speech_recognition.UnknownValueError: except speech_recognition.UnknownValueError:
return "抱歉,我听不懂。"
reply = Reply(ReplyType.ERROR, "抱歉,我听不懂")
except speech_recognition.RequestError as e: except speech_recognition.RequestError as e:
return "抱歉,无法连接到 Google 语音识别服务;{0}".format(e)

reply = Reply(ReplyType.ERROR, "抱歉,无法连接到 Google 语音识别服务;{0}".format(e))
finally:
return reply
def textToVoice(self, text): def textToVoice(self, text):
textFile = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
self.engine.save_to_file(text, textFile)
self.engine.runAndWait()
logger.info(
'[Google] textToVoice text={} voice file name={}'.format(text, textFile))
return textFile
try:
textFile = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
self.engine.save_to_file(text, textFile)
self.engine.runAndWait()
logger.info(
'[Google] textToVoice text={} voice file name={}'.format(text, textFile))
reply = Reply(ReplyType.VOICE, textFile)
except Exception as e:
reply = Reply(ReplyType.ERROR, str(e))
finally:
return reply

+ 12
- 6
voice/openai/openai_voice.py View File

@@ -4,6 +4,7 @@ google voice service
""" """
import json import json
import openai import openai
from bridge.reply import Reply, ReplyType
from config import conf from config import conf
from common.log import logger from common.log import logger
from voice.voice import Voice from voice.voice import Voice
@@ -16,12 +17,17 @@ class OpenaiVoice(Voice):
def voiceToText(self, voice_file): def voiceToText(self, voice_file):
logger.debug( logger.debug(
'[Openai] voice file name={}'.format(voice_file)) '[Openai] voice file name={}'.format(voice_file))
file = open(voice_file, "rb")
reply = openai.Audio.transcribe("whisper-1", file)
text = reply["text"]
logger.info(
'[Openai] voiceToText text={} voice file name={}'.format(text, voice_file))
return text
try:
file = open(voice_file, "rb")
result = openai.Audio.transcribe("whisper-1", file)
text = result["text"]
reply = Reply(ReplyType.TEXT, text)
logger.info(
'[Openai] voiceToText text={} voice file name={}'.format(text, voice_file))
except Exception as e:
reply = Reply(ReplyType.ERROR, str(e))
finally:
return reply


def textToVoice(self, text): def textToVoice(self, text):
pass pass

Loading…
Cancel
Save