# encoding:utf-8

import json
import os
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
import plugins
from plugins import *
from common.log import logger
import webuiapi
import io


@plugins.register(name="sdwebui", desc="利用stable-diffusion webui来画图", version="2.0", author="lanvent")
class SDWebUI(Plugin):
    """Plugin that serves IMAGE_CREATE contexts through a stable-diffusion webui API.

    Configuration is read from ``config.json`` located next to this file. The
    keys this class reads are:

    * ``start``    -- keyword arguments passed to ``webuiapi.WebUIApi``.
    * ``defaults`` -- mapping with ``params`` (base ``txt2img`` keyword
      arguments) and ``options`` (base values passed to ``set_options``).
    * ``rules``    -- list of mappings, each with ``keywords``, ``params`` and
      an optional ``options`` entry, used to override the defaults when one
      of the user's keywords matches.

    NOTE(review): ``Plugin``, ``Event``, ``EventContext`` and ``EventAction``
    are not imported by name here; presumably they are provided by
    ``from plugins import *``.
    """

    def __init__(self):
        """Load the configuration, create the API client and register the handler.

        Any failure is logged rather than raised. In that case the
        ``ON_HANDLE_CONTEXT`` handler is not registered.
        """
        super().__init__()
        curdir = os.path.dirname(__file__)
        config_path = os.path.join(curdir, "config.json")
        try:
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
            self.rules = config["rules"]
            defaults = config["defaults"]
            self.default_params = defaults["params"]
            self.default_options = defaults["options"]
            self.start_args = config["start"]
            self.api = webuiapi.WebUIApi(**self.start_args)
            self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
            logger.info("[SD] inited")
        except FileNotFoundError:
            logger.error(f"[SD] init failed, {config_path} not found")
        except Exception as e:
            logger.error("[SD] init failed, exception: %s" % e)

    @staticmethod
    def _parse_query(content):
        """Split a raw query into ``(keywords, prompt)``.

        The expected format is ``"<kw1> <kw2> ...:<prompt>"``, for example
        ``"横版 高清 二次元:cat"``. Only the first ``:`` separates keywords
        from the prompt. A query without any ``:`` is treated as a bare
        prompt with no keywords instead of raising ``ValueError``.
        """
        head, sep, tail = content.partition(":")
        if not sep:
            return [], content.strip()
        return head.split(), tail.strip()

    def _match_rules(self, keywords):
        """Collect parameter/option overrides for the given keywords.

        Each keyword is applied to the first rule whose ``keywords`` contains
        it; later keywords override values set by earlier ones. Keywords that
        match no rule are logged and ignored.

        Returns:
            A ``(rule_params, rule_options)`` tuple of dicts.
        """
        rule_params = {}
        rule_options = {}
        for keyword in keywords:
            for rule in self.rules:
                if keyword in rule["keywords"]:
                    rule_params.update(rule["params"])
                    rule_options.update(rule.get("options", {}))
                    break  # a keyword matches at most one rule
            else:
                logger.warning("[SD] keyword not matched: %s" % keyword)
        return rule_params, rule_options

    @staticmethod
    def _join_prompt(base, extra):
        """Join the configured prompt and the user prompt with ``", "``.

        Empty or missing parts are skipped so the result never starts or ends
        with a dangling separator.
        """
        return ", ".join(part for part in (base, extra) if part)

    def on_handle_context(self, e_context: EventContext):
        """Handle an IMAGE_CREATE context by generating an image via txt2img.

        Contexts of any other type are ignored. On success the reply is an
        IMAGE whose content is an ``io.BytesIO`` holding PNG data and the
        action is set to ``EventAction.BREAK_PASS``. On any exception the
        reply is an ERROR carrying the exception text and the action is set
        to ``EventAction.CONTINUE``. In both cases ``e_context['reply']`` is
        assigned.
        """
        context = e_context['context']
        if context.type != ContextType.IMAGE_CREATE:
            return

        logger.debug("[SD] on_handle_context. content: %s" % context.content)
        logger.info("[SD] image_query={}".format(context.content))

        reply = Reply()
        try:
            # Parse the user input, e.g. "横版 高清 二次元:cat".
            keywords, prompt = self._parse_query(context.content)
            rule_params, rule_options = self._match_rules(keywords)

            # Rule values take precedence over the configured defaults.
            params = {**self.default_params, **rule_params}
            options = {**self.default_options, **rule_options}
            params["prompt"] = self._join_prompt(params.get("prompt", ""), prompt)

            if options:
                logger.info("[SD] cover rule_options={}".format(rule_options))
                self.api.set_options(options)
            logger.info("[SD] params={}".format(params))

            result = self.api.txt2img(**params)
            b_img = io.BytesIO()
            result.image.save(b_img, format="PNG")
            # Rewind so a consumer that reads without seeking gets the full image.
            b_img.seek(0)
            reply.type = ReplyType.IMAGE
            reply.content = b_img
            # Event finished.
            # NOTE(review): the original comment claimed the default context
            # handling is NOT skipped after BREAK_PASS -- confirm against the
            # EventAction definition.
            e_context.action = EventAction.BREAK_PASS
        except Exception as e:
            reply.type = ReplyType.ERROR
            reply.content = "[SD] " + str(e)
            logger.error("[SD] exception: %s" % e)
            # Event continues: hand over to the next plugin or the default logic.
            e_context.action = EventAction.CONTINUE
        finally:
            e_context['reply'] = reply