Nelze vybrat více než 25 témat Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

89 lines
3.6KB

  1. # encoding:utf-8
  2. import json
  3. import os
  4. from bridge.context import ContextType
  5. from bridge.reply import Reply, ReplyType
  6. import plugins
  7. from plugins import *
  8. from common.log import logger
  9. import webuiapi
  10. import io
  11. @plugins.register(name="sdwebui", desc="利用stable-diffusion webui来画图", version="2.0", author="lanvent")
  12. class SDWebUI(Plugin):
  13. def __init__(self):
  14. super().__init__()
  15. curdir = os.path.dirname(__file__)
  16. config_path = os.path.join(curdir, "config.json")
  17. try:
  18. with open(config_path, "r", encoding="utf-8") as f:
  19. config = json.load(f)
  20. self.rules = config["rules"]
  21. defaults = config["defaults"]
  22. self.default_params = defaults["params"]
  23. self.default_options = defaults["options"]
  24. self.start_args = config["start"]
  25. self.api = webuiapi.WebUIApi(**self.start_args)
  26. self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
  27. logger.info("[SD] inited")
  28. except FileNotFoundError:
  29. logger.error(f"[SD] init failed, {config_path} not found")
  30. except Exception as e:
  31. logger.error("[SD] init failed, exception: %s" % e)
  32. def on_handle_context(self, e_context: EventContext):
  33. if e_context['context'].type != ContextType.IMAGE_CREATE:
  34. return
  35. logger.debug("[SD] on_handle_context. content: %s" %e_context['context'].content)
  36. logger.info("[SD] image_query={}".format(e_context['context'].content))
  37. reply = Reply()
  38. try:
  39. content = e_context['context'].content[:]
  40. # 解析用户输入 如"横版 高清 二次元:cat"
  41. keywords, prompt = content.split(":", 1)
  42. keywords = keywords.split()
  43. rule_params = {}
  44. rule_options = {}
  45. for keyword in keywords:
  46. matched = False
  47. for rule in self.rules:
  48. if keyword in rule["keywords"]:
  49. for key in rule["params"]:
  50. rule_params[key] = rule["params"][key]
  51. if "options" in rule:
  52. for key in rule["options"]:
  53. rule_options[key] = rule["options"][key]
  54. matched = True
  55. break # 一个关键词只匹配一个规则
  56. if not matched:
  57. logger.warning("[SD] keyword not matched: %s" % keyword)
  58. params = {**self.default_params, **rule_params}
  59. options = {**self.default_options, **rule_options}
  60. params["prompt"] = params.get("prompt", "")+f", {prompt}"
  61. if len(options) > 0:
  62. logger.info("[SD] cover rule_options={}".format(rule_options))
  63. self.api.set_options(options)
  64. logger.info("[SD] params={}".format(params))
  65. result = self.api.txt2img(
  66. **params
  67. )
  68. reply.type = ReplyType.IMAGE
  69. b_img = io.BytesIO()
  70. result.image.save(b_img, format="PNG")
  71. reply.content = b_img
  72. e_context.action = EventAction.BREAK_PASS # 事件结束后,不跳过处理context的默认逻辑
  73. except Exception as e:
  74. reply.type = ReplyType.ERROR
  75. reply.content = "[SD] "+str(e)
  76. logger.error("[SD] exception: %s" % e)
  77. e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
  78. finally:
  79. e_context['reply'] = reply