You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

114 satır
4.7KB

  1. # encoding:utf-8
  2. import json
  3. import os
  4. from bridge.context import ContextType
  5. from bridge.reply import Reply, ReplyType
  6. from config import conf
  7. import plugins
  8. from plugins import *
  9. from common.log import logger
  10. import webuiapi
  11. import io
  12. @plugins.register(name="sdwebui", desc="利用stable-diffusion webui来画图", version="2.0", author="lanvent")
  13. class SDWebUI(Plugin):
  14. def __init__(self):
  15. super().__init__()
  16. curdir = os.path.dirname(__file__)
  17. config_path = os.path.join(curdir, "config.json")
  18. try:
  19. with open(config_path, "r", encoding="utf-8") as f:
  20. config = json.load(f)
  21. self.rules = config["rules"]
  22. defaults = config["defaults"]
  23. self.default_params = defaults["params"]
  24. self.default_options = defaults["options"]
  25. self.start_args = config["start"]
  26. self.api = webuiapi.WebUIApi(**self.start_args)
  27. self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
  28. logger.info("[SD] inited")
  29. except FileNotFoundError:
  30. logger.error(f"[SD] init failed, {config_path} not found")
  31. except Exception as e:
  32. logger.error("[SD] init failed, exception: %s" % e)
  33. def on_handle_context(self, e_context: EventContext):
  34. if e_context['context'].type != ContextType.IMAGE_CREATE:
  35. return
  36. logger.debug("[SD] on_handle_context. content: %s" %e_context['context'].content)
  37. logger.info("[SD] image_query={}".format(e_context['context'].content))
  38. reply = Reply()
  39. try:
  40. content = e_context['context'].content[:]
  41. # 解析用户输入 如"横版 高清 二次元:cat"
  42. if ":" in content:
  43. keywords, prompt = content.split(":", 1)
  44. else:
  45. keywords = content
  46. prompt = ""
  47. keywords = keywords.split()
  48. if "help" in keywords or "帮助" in keywords:
  49. reply.type = ReplyType.INFO
  50. reply.content = self.get_help_text()
  51. else:
  52. rule_params = {}
  53. rule_options = {}
  54. for keyword in keywords:
  55. matched = False
  56. for rule in self.rules:
  57. if keyword in rule["keywords"]:
  58. for key in rule["params"]:
  59. rule_params[key] = rule["params"][key]
  60. if "options" in rule:
  61. for key in rule["options"]:
  62. rule_options[key] = rule["options"][key]
  63. matched = True
  64. break # 一个关键词只匹配一个规则
  65. if not matched:
  66. logger.warning("[SD] keyword not matched: %s" % keyword)
  67. params = {**self.default_params, **rule_params}
  68. options = {**self.default_options, **rule_options}
  69. params["prompt"] = params.get("prompt", "")+f", {prompt}"
  70. if len(options) > 0:
  71. logger.info("[SD] cover options={}".format(options))
  72. self.api.set_options(options)
  73. logger.info("[SD] params={}".format(params))
  74. result = self.api.txt2img(
  75. **params
  76. )
  77. reply.type = ReplyType.IMAGE
  78. b_img = io.BytesIO()
  79. result.image.save(b_img, format="PNG")
  80. reply.content = b_img
  81. e_context.action = EventAction.BREAK_PASS # 事件结束后,跳过处理context的默认逻辑
  82. except Exception as e:
  83. reply.type = ReplyType.ERROR
  84. reply.content = "[SD] "+str(e)
  85. logger.error("[SD] exception: %s" % e)
  86. e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
  87. finally:
  88. e_context['reply'] = reply
  89. def get_help_text(self):
  90. if not conf().get('image_create_prefix'):
  91. return "画图功能未启用"
  92. else:
  93. trigger = conf()['image_create_prefix'][0]
  94. help_text = f"请使用<{trigger}[关键词1] [关键词2]...:提示语>的格式作画,如\"{trigger}横版 高清:cat\"\n"
  95. help_text += "目前可用关键词:\n"
  96. for rule in self.rules:
  97. keywords = [f"[{keyword}]" for keyword in rule['keywords']]
  98. help_text += f"{','.join(keywords)}"
  99. if "desc" in rule:
  100. help_text += f"-{rule['desc']}\n"
  101. else:
  102. help_text += "\n"
  103. return help_text