From 6b247ae880ac3e1210243f1e4d58bc9de5e431eb Mon Sep 17 00:00:00 2001 From: zhayujie Date: Mon, 7 Aug 2023 19:14:09 +0800 Subject: [PATCH] feat: add midjourney variation and reset --- plugins/linkai/linkai.py | 4 +- plugins/linkai/midjourney.py | 115 +++++++++++++++++++---------------- 2 files changed, 64 insertions(+), 55 deletions(-) diff --git a/plugins/linkai/linkai.py b/plugins/linkai/linkai.py index 9f21e60..a71b3b1 100644 --- a/plugins/linkai/linkai.py +++ b/plugins/linkai/linkai.py @@ -131,7 +131,9 @@ class LinkAI(Plugin): if not verbose: return help_text help_text += f'📖 知识库\n - 群聊中指定应用: {trigger_prefix}linkai app 应用编码\n\n例如: \n"$linkai app Kv2fXJcH"\n\n' - help_text += f"🎨 绘画\n - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\"" + help_text += f"🎨 绘画\n - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n - 变换: mjv 图片ID 图片序号\n - 重置: mjr 图片ID" + help_text += f"\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 11055927171882 2\"" + help_text += f"\n\"{trigger_prefix}mjv 11055927171882 2\"\n\"{trigger_prefix}mjr 11055927171882\"" return help_text diff --git a/plugins/linkai/midjourney.py b/plugins/linkai/midjourney.py index 9512db7..735195a 100644 --- a/plugins/linkai/midjourney.py +++ b/plugins/linkai/midjourney.py @@ -11,6 +11,9 @@ from bridge.context import ContextType from plugins import EventContext, EventAction INVALID_REQUEST = 410 +NOT_FOUND_ORIGIN_IMAGE = 461 +NOT_FOUND_TASK = 462 + class TaskType(Enum): GENERATE = "generate" @@ -18,6 +21,9 @@ class TaskType(Enum): VARIATION = "variation" RESET = "reset" + def __str__(self): + return self.name + class Status(Enum): PENDING = "pending" @@ -34,6 +40,14 @@ class TaskMode(Enum): RELAX = "relax" +task_name_mapping = { + TaskType.GENERATE.name: "生成", + TaskType.UPSCALE.name: "放大", + TaskType.VARIATION.name: "变换", + TaskType.RESET.name: "重新生成", +} + + class MJTask: def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int = 60 * 30, status=Status.PENDING): @@ -79,6 +93,10 @@ class MJBot: return TaskType.GENERATE elif cmd_list[0].lower() == f"{trigger_prefix}mju": return TaskType.UPSCALE + elif cmd_list[0].lower() == f"{trigger_prefix}mjv": + return TaskType.VARIATION + elif cmd_list[0].lower() == f"{trigger_prefix}mjr": + return TaskType.RESET elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix"): return TaskType.GENERATE @@ -127,8 +145,8 @@ class MJBot: e_context.action = EventAction.BREAK_PASS return - elif mj_type == TaskType.UPSCALE: - # 图片放大 + elif mj_type == TaskType.UPSCALE or mj_type == TaskType.VARIATION: + # 图片放大/变换 clist = cmd[1].split() if len(clist) < 2: self._set_reply_text(f"{cmd[0]} 命令缺少参数", e_context) @@ -138,16 +156,27 @@ class MJBot: if index < 1 or index > 4: self._set_reply_text(f"图片序号 {index} 错误,应在 1 至 4 之间", e_context) return - key = f"{TaskType.UPSCALE.name}_{img_id}_{index}" + key = f"{str(mj_type)}_{img_id}_{index}" if self.temp_dict.get(key): - self._set_reply_text(f"第 {index} 张图片已经放大过了", e_context) + self._set_reply_text(f"第 {index} 张图片已经{task_name_mapping.get(str(mj_type))}过了", e_context) return - # 图片放大操作 - reply = self.upscale(session_id, img_id, index, e_context) + # 执行图片放大/变换操作 + reply = self.do_operate(mj_type, session_id, img_id, e_context, index) e_context['reply'] = reply e_context.action = EventAction.BREAK_PASS return + elif mj_type == TaskType.RESET: + # 图片重新生成 + clist = cmd[1].split() + if len(clist) < 1: + self._set_reply_text(f"{cmd[0]} 命令缺少参数", e_context) + return + img_id = clist[0] + # 图片重新生成 + reply = self.do_operate(mj_type, session_id, img_id, e_context) + e_context['reply'] = reply + e_context.action = EventAction.BREAK_PASS else: self._set_reply_text(f"暂不支持该命令", e_context) @@ -197,9 +226,12 @@ class MJBot: reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试") return reply - def upscale(self, user_id: str, img_id: str, index: int, e_context: EventContext) -> Reply: - logger.info(f"[MJ] image upscale, img_id={img_id}, index={index}") - body = {"type": TaskType.UPSCALE.name, "img_id": img_id, "index": index} + def do_operate(self, task_type: TaskType, user_id: str, img_id: str, e_context: EventContext, + index: int = None) -> Reply: + logger.info(f"[MJ] image operate, task_type={task_type}, img_id={img_id}, index={index}") + body = {"type": task_type.name, "img_id": img_id} + if index: + body["index"] = index if not self.config.get("img_proxy"): body["img_proxy"] = False res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5, 40)) @@ -208,23 +240,24 @@ class MJBot: res = res.json() if res.get("code") == 200: task_id = res.get("data").get("task_id") - logger.info(f"[MJ] image upscale processing, task_id={task_id}") - content = f"🔎图片正在放大中,请耐心等待" + logger.info(f"[MJ] image operate processing, task_id={task_id}") + icon_map = {TaskType.UPSCALE: "🔎", TaskType.VARIATION: "🪄", TaskType.RESET: "🔄"} + content = f"{icon_map.get(task_type)}图片正在{task_name_mapping.get(task_type.name)}中,请耐心等待" reply = Reply(ReplyType.INFO, content) - task = MJTask(id=task_id, status=Status.PENDING, user_id=user_id, task_type=TaskType.UPSCALE) + task = MJTask(id=task_id, status=Status.PENDING, user_id=user_id, task_type=task_type) # put to memory dict self.tasks[task.id] = task - key = f"{TaskType.UPSCALE.name}_{img_id}_{index}" + key = f"{task_type.name}_{img_id}_{index}" self.temp_dict[key] = True # asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop) self._do_check_task(task, e_context) return reply else: error_msg = "" - if res.status_code == 461: + if res.status_code == NOT_FOUND_ORIGIN_IMAGE: error_msg = "请输入正确的图片ID" res_json = res.json() - logger.error(f"[MJ] upscale error, msg={res_json.get('message')}, status_code={res.status_code}") + logger.error(f"[MJ] operate error, msg={res_json.get('message')}, status_code={res.status_code}") reply = Reply(ReplyType.ERROR, error_msg or "图片生成失败,请稍后再试") return reply @@ -258,40 +291,6 @@ class MJBot: if self.tasks.get(task.id): self.tasks[task.id].status = Status.EXPIRED - async def check_task_async(self, task: MJTask, e_context: EventContext): - try: - logger.debug(f"[MJ] start check task status, {task}") - max_retry_times = 90 - while max_retry_times > 0: - await asyncio.sleep(10) - async with aiohttp.ClientSession() as session: - url = f"{self.base_url}/tasks/{task.id}" - try: - async with session.get(url, headers=self.headers) as res: - if res.status == 200: - res_json = await res.json() - logger.debug(f"[MJ] task check res, task_id={task.id}, status={res.status}, " - f"data={res_json.get('data')}, thread={threading.current_thread().name}") - if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name: - # process success res - if self.tasks.get(task.id): - self.tasks[task.id].status = Status.FINISHED - self._process_success_task(task, res_json.get("data"), e_context) - return - else: - res_json = await res.json() - logger.warn(f"[MJ] image check error, status_code={res.status}, res={res_json}") - max_retry_times -= 20 - except Exception as e: - max_retry_times -= 20 - logger.warn(e) - max_retry_times -= 1 - logger.warn("[MJ] end from poll") - if self.tasks.get(task.id): - self.tasks[task.id].status = Status.EXPIRED - except Exception as e: - logger.error(e) - def _do_check_task(self, task: MJTask, e_context: EventContext): threading.Thread(target=self.check_task_sync, args=(task, e_context)).start() @@ -316,10 +315,17 @@ class MJBot: # send info trigger_prefix = conf().get("plugin_trigger_prefix", "$") text = "" - if task.task_type == TaskType.GENERATE: - text = f"🎨绘画完成!\nprompt: {task.raw_prompt}\n- - - - - - - - -\n图片ID: {task.img_id}" - text += f"\n\n🔎可使用 {trigger_prefix}mju 命令放大指定图片\n" + if task.task_type == TaskType.GENERATE or task.task_type == TaskType.VARIATION or task.task_type == TaskType.RESET: + text = f"🎨绘画完成!\n" + if task.raw_prompt: + text += f"prompt: {task.raw_prompt}\n" + text += f"- - - - - - - - -\n图片ID: {task.img_id}" + text += f"\n\n🔎使用 {trigger_prefix}mju 命令放大图片\n" text += f"例如:\n{trigger_prefix}mju {task.img_id} 1" + text += f"\n\n🪄使用 {trigger_prefix}mjv 命令变换图片\n" + text += f"例如:\n{trigger_prefix}mjv {task.img_id} 1" + text += f"\n\n🔄使用 {trigger_prefix}mjr 命令重新生成图片\n" + text += f"例如:\n{trigger_prefix}mjr {task.img_id}" reply = Reply(ReplyType.INFO, text) channel._send(reply, e_context["context"]) @@ -382,8 +388,9 @@ class MJBot: help_text = "🎨利用Midjourney进行画图\n\n" if not verbose: return help_text - help_text += f" - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\"" - + help_text += f" - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n - 变换: mjv 图片ID 图片序号\n - 重置: mjr 图片ID" + help_text += f"\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 11055927171882 2\"" + help_text += f"\n\"{trigger_prefix}mjv 11055927171882 2\"\n\"{trigger_prefix}mjr 11055927171882\"" return help_text def find_tasks_by_user_id(self, user_id) -> list: