You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

433 satır
18KB

  1. from enum import Enum
  2. from config import conf
  3. from common.log import logger
  4. import requests
  5. import threading
  6. import time
  7. from bridge.reply import Reply, ReplyType
  8. import asyncio
  9. from bridge.context import ContextType
  10. from plugins import EventContext, EventAction
  11. from .utils import Util
  12. INVALID_REQUEST = 410
  13. NOT_FOUND_ORIGIN_IMAGE = 461
  14. NOT_FOUND_TASK = 462
  15. class TaskType(Enum):
  16. GENERATE = "generate"
  17. UPSCALE = "upscale"
  18. VARIATION = "variation"
  19. RESET = "reset"
  20. def __str__(self):
  21. return self.name
  22. class Status(Enum):
  23. PENDING = "pending"
  24. FINISHED = "finished"
  25. EXPIRED = "expired"
  26. ABORTED = "aborted"
  27. def __str__(self):
  28. return self.name
  29. class TaskMode(Enum):
  30. FAST = "fast"
  31. RELAX = "relax"
  32. task_name_mapping = {
  33. TaskType.GENERATE.name: "生成",
  34. TaskType.UPSCALE.name: "放大",
  35. TaskType.VARIATION.name: "变换",
  36. TaskType.RESET.name: "重新生成",
  37. }
  38. class MJTask:
  39. def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int = 60 * 6,
  40. status=Status.PENDING):
  41. self.id = id
  42. self.user_id = user_id
  43. self.task_type = task_type
  44. self.raw_prompt = raw_prompt
  45. self.send_func = None # send_func(img_url)
  46. self.expiry_time = time.time() + expires
  47. self.status = status
  48. self.img_url = None # url
  49. self.img_id = None
  50. def __str__(self):
  51. return f"id={self.id}, user_id={self.user_id}, task_type={self.task_type}, status={self.status}, img_id={self.img_id}"
  52. # midjourney bot
  53. class MJBot:
  54. def __init__(self, config):
  55. self.base_url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/img/midjourney"
  56. self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
  57. self.config = config
  58. self.tasks = {}
  59. self.temp_dict = {}
  60. self.tasks_lock = threading.Lock()
  61. self.event_loop = asyncio.new_event_loop()
  62. def judge_mj_task_type(self, e_context: EventContext):
  63. """
  64. 判断MJ任务的类型
  65. :param e_context: 上下文
  66. :return: 任务类型枚举
  67. """
  68. if not self.config:
  69. return None
  70. trigger_prefix = conf().get("plugin_trigger_prefix", "$")
  71. context = e_context['context']
  72. if context.type == ContextType.TEXT:
  73. cmd_list = context.content.split(maxsplit=1)
  74. if not cmd_list:
  75. return None
  76. if cmd_list[0].lower() == f"{trigger_prefix}mj":
  77. return TaskType.GENERATE
  78. elif cmd_list[0].lower() == f"{trigger_prefix}mju":
  79. return TaskType.UPSCALE
  80. elif cmd_list[0].lower() == f"{trigger_prefix}mjv":
  81. return TaskType.VARIATION
  82. elif cmd_list[0].lower() == f"{trigger_prefix}mjr":
  83. return TaskType.RESET
  84. elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix") and self.config.get("enabled"):
  85. return TaskType.GENERATE
  86. def process_mj_task(self, mj_type: TaskType, e_context: EventContext):
  87. """
  88. 处理mj任务
  89. :param mj_type: mj任务类型
  90. :param e_context: 对话上下文
  91. """
  92. context = e_context['context']
  93. session_id = context["session_id"]
  94. cmd = context.content.split(maxsplit=1)
  95. if len(cmd) == 1 and context.type == ContextType.TEXT:
  96. # midjourney 帮助指令
  97. self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO)
  98. return
  99. if len(cmd) == 2 and (cmd[1] == "open" or cmd[1] == "close"):
  100. if not Util.is_admin(e_context):
  101. Util.set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR)
  102. return
  103. # midjourney 开关指令
  104. is_open = True
  105. tips_text = "开启"
  106. if cmd[1] == "close":
  107. tips_text = "关闭"
  108. is_open = False
  109. self.config["enabled"] = is_open
  110. self._set_reply_text(f"Midjourney绘画已{tips_text}", e_context, level=ReplyType.INFO)
  111. return
  112. if not self.config.get("enabled"):
  113. logger.warn("Midjourney绘画未开启,请查看 plugins/linkai/config.json 中的配置")
  114. self._set_reply_text(f"Midjourney绘画未开启", e_context, level=ReplyType.INFO)
  115. return
  116. if not self._check_rate_limit(session_id, e_context):
  117. logger.warn("[MJ] midjourney task exceed rate limit")
  118. return
  119. if mj_type == TaskType.GENERATE:
  120. if context.type == ContextType.IMAGE_CREATE:
  121. raw_prompt = context.content
  122. else:
  123. # 图片生成
  124. raw_prompt = cmd[1]
  125. reply = self.generate(raw_prompt, session_id, e_context)
  126. e_context['reply'] = reply
  127. e_context.action = EventAction.BREAK_PASS
  128. return
  129. elif mj_type == TaskType.UPSCALE or mj_type == TaskType.VARIATION:
  130. # 图片放大/变换
  131. clist = cmd[1].split()
  132. if len(clist) < 2:
  133. self._set_reply_text(f"{cmd[0]} 命令缺少参数", e_context)
  134. return
  135. img_id = clist[0]
  136. index = int(clist[1])
  137. if index < 1 or index > 4:
  138. self._set_reply_text(f"图片序号 {index} 错误,应在 1 至 4 之间", e_context)
  139. return
  140. key = f"{str(mj_type)}_{img_id}_{index}"
  141. if self.temp_dict.get(key):
  142. self._set_reply_text(f"第 {index} 张图片已经{task_name_mapping.get(str(mj_type))}过了", e_context)
  143. return
  144. # 执行图片放大/变换操作
  145. reply = self.do_operate(mj_type, session_id, img_id, e_context, index)
  146. e_context['reply'] = reply
  147. e_context.action = EventAction.BREAK_PASS
  148. return
  149. elif mj_type == TaskType.RESET:
  150. # 图片重新生成
  151. clist = cmd[1].split()
  152. if len(clist) < 1:
  153. self._set_reply_text(f"{cmd[0]} 命令缺少参数", e_context)
  154. return
  155. img_id = clist[0]
  156. # 图片重新生成
  157. reply = self.do_operate(mj_type, session_id, img_id, e_context)
  158. e_context['reply'] = reply
  159. e_context.action = EventAction.BREAK_PASS
  160. else:
  161. self._set_reply_text(f"暂不支持该命令", e_context)
  162. def generate(self, prompt: str, user_id: str, e_context: EventContext) -> Reply:
  163. """
  164. 图片生成
  165. :param prompt: 提示词
  166. :param user_id: 用户id
  167. :param e_context: 对话上下文
  168. :return: 任务ID
  169. """
  170. logger.info(f"[MJ] image generate, prompt={prompt}")
  171. mode = self._fetch_mode(prompt)
  172. body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")}
  173. if not self.config.get("img_proxy"):
  174. body["img_proxy"] = False
  175. res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers, timeout=(5, 40))
  176. if res.status_code == 200:
  177. res = res.json()
  178. logger.debug(f"[MJ] image generate, res={res}")
  179. if res.get("code") == 200:
  180. task_id = res.get("data").get("task_id")
  181. real_prompt = res.get("data").get("real_prompt")
  182. if mode == TaskMode.RELAX.value:
  183. time_str = "1~10分钟"
  184. else:
  185. time_str = "1分钟"
  186. content = f"🚀您的作品将在{time_str}左右完成,请耐心等待\n- - - - - - - - -\n"
  187. if real_prompt:
  188. content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}"
  189. else:
  190. content += f"prompt: {prompt}"
  191. reply = Reply(ReplyType.INFO, content)
  192. task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id,
  193. task_type=TaskType.GENERATE)
  194. # put to memory dict
  195. self.tasks[task.id] = task
  196. # asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
  197. self._do_check_task(task, e_context)
  198. return reply
  199. else:
  200. res_json = res.json()
  201. logger.error(f"[MJ] generate error, msg={res_json.get('message')}, status_code={res.status_code}")
  202. if res.status_code == INVALID_REQUEST:
  203. reply = Reply(ReplyType.ERROR, "图片生成失败,请检查提示词参数或内容")
  204. else:
  205. reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试")
  206. return reply
  207. def do_operate(self, task_type: TaskType, user_id: str, img_id: str, e_context: EventContext,
  208. index: int = None) -> Reply:
  209. logger.info(f"[MJ] image operate, task_type={task_type}, img_id={img_id}, index={index}")
  210. body = {"type": task_type.name, "img_id": img_id}
  211. if index:
  212. body["index"] = index
  213. if not self.config.get("img_proxy"):
  214. body["img_proxy"] = False
  215. res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5, 40))
  216. logger.debug(res)
  217. if res.status_code == 200:
  218. res = res.json()
  219. if res.get("code") == 200:
  220. task_id = res.get("data").get("task_id")
  221. logger.info(f"[MJ] image operate processing, task_id={task_id}")
  222. icon_map = {TaskType.UPSCALE: "🔎", TaskType.VARIATION: "🪄", TaskType.RESET: "🔄"}
  223. content = f"{icon_map.get(task_type)}图片正在{task_name_mapping.get(task_type.name)}中,请耐心等待"
  224. reply = Reply(ReplyType.INFO, content)
  225. task = MJTask(id=task_id, status=Status.PENDING, user_id=user_id, task_type=task_type)
  226. # put to memory dict
  227. self.tasks[task.id] = task
  228. key = f"{task_type.name}_{img_id}_{index}"
  229. self.temp_dict[key] = True
  230. # asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
  231. self._do_check_task(task, e_context)
  232. return reply
  233. else:
  234. error_msg = ""
  235. if res.status_code == NOT_FOUND_ORIGIN_IMAGE:
  236. error_msg = "请输入正确的图片ID"
  237. res_json = res.json()
  238. logger.error(f"[MJ] operate error, msg={res_json.get('message')}, status_code={res.status_code}")
  239. reply = Reply(ReplyType.ERROR, error_msg or "图片生成失败,请稍后再试")
  240. return reply
  241. def check_task_sync(self, task: MJTask, e_context: EventContext):
  242. logger.debug(f"[MJ] start check task status, {task}")
  243. max_retry_times = 90
  244. while max_retry_times > 0:
  245. time.sleep(10)
  246. url = f"{self.base_url}/tasks/{task.id}"
  247. try:
  248. res = requests.get(url, headers=self.headers, timeout=8)
  249. if res.status_code == 200:
  250. res_json = res.json()
  251. logger.debug(f"[MJ] task check res sync, task_id={task.id}, status={res.status_code}, "
  252. f"data={res_json.get('data')}, thread={threading.current_thread().name}")
  253. if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name:
  254. # process success res
  255. if self.tasks.get(task.id):
  256. self.tasks[task.id].status = Status.FINISHED
  257. self._process_success_task(task, res_json.get("data"), e_context)
  258. return
  259. max_retry_times -= 1
  260. else:
  261. res_json = res.json()
  262. logger.warn(f"[MJ] image check error, status_code={res.status_code}, res={res_json}")
  263. max_retry_times -= 20
  264. except Exception as e:
  265. max_retry_times -= 20
  266. logger.warn(e)
  267. logger.warn("[MJ] end from poll")
  268. if self.tasks.get(task.id):
  269. self.tasks[task.id].status = Status.EXPIRED
  270. def _do_check_task(self, task: MJTask, e_context: EventContext):
  271. threading.Thread(target=self.check_task_sync, args=(task, e_context)).start()
  272. def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext):
  273. """
  274. 处理任务成功的结果
  275. :param task: MJ任务
  276. :param res: 请求结果
  277. :param e_context: 对话上下文
  278. """
  279. # channel send img
  280. task.status = Status.FINISHED
  281. task.img_id = res.get("img_id")
  282. task.img_url = res.get("img_url")
  283. logger.info(f"[MJ] task success, task_id={task.id}, img_id={task.img_id}, img_url={task.img_url}")
  284. # send img
  285. reply = Reply(ReplyType.IMAGE_URL, task.img_url)
  286. channel = e_context["channel"]
  287. _send(channel, reply, e_context["context"])
  288. # send info
  289. trigger_prefix = conf().get("plugin_trigger_prefix", "$")
  290. text = ""
  291. if task.task_type == TaskType.GENERATE or task.task_type == TaskType.VARIATION or task.task_type == TaskType.RESET:
  292. text = f"🎨绘画完成!\n"
  293. if task.raw_prompt:
  294. text += f"prompt: {task.raw_prompt}\n"
  295. text += f"- - - - - - - - -\n图片ID: {task.img_id}"
  296. text += f"\n\n🔎使用 {trigger_prefix}mju 命令放大图片\n"
  297. text += f"例如:\n{trigger_prefix}mju {task.img_id} 1"
  298. text += f"\n\n🪄使用 {trigger_prefix}mjv 命令变换图片\n"
  299. text += f"例如:\n{trigger_prefix}mjv {task.img_id} 1"
  300. text += f"\n\n🔄使用 {trigger_prefix}mjr 命令重新生成图片\n"
  301. text += f"例如:\n{trigger_prefix}mjr {task.img_id}"
  302. reply = Reply(ReplyType.INFO, text)
  303. _send(channel, reply, e_context["context"])
  304. self._print_tasks()
  305. return
  306. def _check_rate_limit(self, user_id: str, e_context: EventContext) -> bool:
  307. """
  308. midjourney任务限流控制
  309. :param user_id: 用户id
  310. :param e_context: 对话上下文
  311. :return: 任务是否能够生成, True:可以生成, False: 被限流
  312. """
  313. tasks = self.find_tasks_by_user_id(user_id)
  314. task_count = len([t for t in tasks if t.status == Status.PENDING])
  315. if task_count >= self.config.get("max_tasks_per_user"):
  316. reply = Reply(ReplyType.INFO, "您的Midjourney作图任务数已达上限,请稍后再试")
  317. e_context["reply"] = reply
  318. e_context.action = EventAction.BREAK_PASS
  319. return False
  320. task_count = len([t for t in self.tasks.values() if t.status == Status.PENDING])
  321. if task_count >= self.config.get("max_tasks"):
  322. reply = Reply(ReplyType.INFO, "Midjourney作图任务数已达上限,请稍后再试")
  323. e_context["reply"] = reply
  324. e_context.action = EventAction.BREAK_PASS
  325. return False
  326. return True
  327. def _fetch_mode(self, prompt) -> str:
  328. mode = self.config.get("mode")
  329. if "--relax" in prompt or mode == TaskMode.RELAX.value:
  330. return TaskMode.RELAX.value
  331. return mode or TaskMode.FAST.value
  332. def _run_loop(self, loop: asyncio.BaseEventLoop):
  333. """
  334. 运行事件循环,用于轮询任务的线程
  335. :param loop: 事件循环
  336. """
  337. loop.run_forever()
  338. loop.stop()
  339. def _print_tasks(self):
  340. for id in self.tasks:
  341. logger.debug(f"[MJ] current task: {self.tasks[id]}")
  342. def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR):
  343. """
  344. 设置回复文本
  345. :param content: 回复内容
  346. :param e_context: 对话上下文
  347. :param level: 回复等级
  348. """
  349. reply = Reply(level, content)
  350. e_context["reply"] = reply
  351. e_context.action = EventAction.BREAK_PASS
  352. def get_help_text(self, verbose=False, **kwargs):
  353. trigger_prefix = conf().get("plugin_trigger_prefix", "$")
  354. help_text = "🎨利用Midjourney进行画图\n\n"
  355. if not verbose:
  356. return help_text
  357. help_text += f" - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n - 变换: mjv 图片ID 图片序号\n - 重置: mjr 图片ID"
  358. help_text += f"\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 11055927171882 2\""
  359. help_text += f"\n\"{trigger_prefix}mjv 11055927171882 2\"\n\"{trigger_prefix}mjr 11055927171882\""
  360. return help_text
  361. def find_tasks_by_user_id(self, user_id) -> list:
  362. result = []
  363. with self.tasks_lock:
  364. now = time.time()
  365. for task in self.tasks.values():
  366. if task.status == Status.PENDING and now > task.expiry_time:
  367. task.status = Status.EXPIRED
  368. logger.info(f"[MJ] {task} expired")
  369. if task.user_id == user_id:
  370. result.append(task)
  371. return result
  372. def _send(channel, reply: Reply, context, retry_cnt=0):
  373. try:
  374. channel.send(reply, context)
  375. except Exception as e:
  376. logger.error("[WX] sendMsg error: {}".format(str(e)))
  377. if isinstance(e, NotImplementedError):
  378. return
  379. logger.exception(e)
  380. if retry_cnt < 2:
  381. time.sleep(3 + 3 * retry_cnt)
  382. channel.send(reply, context, retry_cnt + 1)
  383. def check_prefix(content, prefix_list):
  384. if not prefix_list:
  385. return None
  386. for prefix in prefix_list:
  387. if content.startswith(prefix):
  388. return prefix
  389. return None