From e027286b6de50e21dedce5b5663fed8c7183b758 Mon Sep 17 00:00:00 2001 From: zhayujie Date: Sun, 30 Jul 2023 15:16:19 +0800 Subject: [PATCH] fix: midjourney check task thread --- plugins/linkai/linkai.py | 6 +++- plugins/linkai/midjourney.py | 63 ++++++++++++++++++++++++++++-------- 2 files changed, 54 insertions(+), 15 deletions(-) diff --git a/plugins/linkai/linkai.py b/plugins/linkai/linkai.py index c48da9d..1031a24 100644 --- a/plugins/linkai/linkai.py +++ b/plugins/linkai/linkai.py @@ -27,7 +27,8 @@ class LinkAI(Plugin): super().__init__() self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context self.config = super().load_config() - self.mj_bot = MJBot(self.config.get("midjourney")) + if self.config: + self.mj_bot = MJBot(self.config.get("midjourney")) logger.info("[LinkAI] inited") def on_handle_context(self, e_context: EventContext): @@ -35,6 +36,9 @@ class LinkAI(Plugin): 消息处理逻辑 :param e_context: 消息上下文 """ + if not self.config: + return + context = e_context['context'] if context.type not in [ContextType.TEXT, ContextType.IMAGE, ContextType.IMAGE_CREATE]: # filter content no need solve diff --git a/plugins/linkai/midjourney.py b/plugins/linkai/midjourney.py index 3a94b24..9d92ee6 100644 --- a/plugins/linkai/midjourney.py +++ b/plugins/linkai/midjourney.py @@ -58,23 +58,24 @@ class MJBot: self.temp_dict = {} self.tasks_lock = threading.Lock() self.event_loop = asyncio.new_event_loop() - threading.Thread(name="mj-check-thread", target=self._run_loop, args=(self.event_loop,)).start() + # threading.Thread(name="mj-check-thread", target=self._run_loop, args=(self.event_loop,)).start() - def judge_mj_task_type(self, e_context: EventContext) -> TaskType: + def judge_mj_task_type(self, e_context: EventContext): """ 判断MJ任务的类型 :param e_context: 上下文 :return: 任务类型枚举 """ + if not self.config or not self.config.get("enabled"): + return None trigger_prefix = conf().get("plugin_trigger_prefix", "$") context = e_context['context'] if context.type == ContextType.TEXT: - if self.config and self.config.get("enabled"): - cmd_list = context.content.split(maxsplit=1) - if cmd_list[0].lower() == f"{trigger_prefix}mj": - return TaskType.GENERATE - elif cmd_list[0].lower() == f"{trigger_prefix}mju": - return TaskType.UPSCALE + cmd_list = context.content.split(maxsplit=1) + if cmd_list[0].lower() == f"{trigger_prefix}mj": + return TaskType.GENERATE + elif cmd_list[0].lower() == f"{trigger_prefix}mju": + return TaskType.UPSCALE elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix"): return TaskType.GENERATE @@ -142,7 +143,7 @@ class MJBot: logger.info(f"[MJ] image generate, prompt={prompt}") mode = self._fetch_mode(prompt) body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")} - res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers) + res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers, timeout=(5,15)) if res.status_code == 200: res = res.json() logger.debug(f"[MJ] image generate, res={res}") @@ -152,7 +153,7 @@ class MJBot: if mode == TaskMode.RELAX.value: time_str = "1~10分钟" else: - time_str = "1~2分钟" + time_str = "1分钟" content = f"🚀您的作品将在{time_str}左右完成,请耐心等待\n- - - - - - - - -\n" if real_prompt: content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}" @@ -162,7 +163,8 @@ class MJBot: task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id, task_type=TaskType.GENERATE) # put to memory dict self.tasks[task.id] = task - asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop) + # asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop) + self._do_check_task(task, e_context) return reply else: res_json = res.json() @@ -173,7 +175,7 @@ class MJBot: def upscale(self, user_id: str, img_id: str, index: int, e_context: EventContext) -> Reply: logger.info(f"[MJ] image upscale, img_id={img_id}, index={index}") body = {"type": TaskType.UPSCALE.name, "imgId": img_id, "index": index} - res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers) + res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5,15)) logger.debug(res) if res.status_code == 200: res = res.json() @@ -187,7 +189,8 @@ class MJBot: self.tasks[task.id] = task key = f"{TaskType.UPSCALE.name}_{img_id}_{index}" self.temp_dict[key] = True - asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop) + # asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop) + self._do_check_task(task, e_context) return reply else: error_msg = "" @@ -198,7 +201,36 @@ class MJBot: reply = Reply(ReplyType.ERROR, error_msg or "图片生成失败,请稍后再试") return reply - async def check_task(self, task: MJTask, e_context: EventContext): + def check_task_sync(self, task: MJTask, e_context: EventContext): + logger.debug(f"[MJ] start check task status, {task}") + max_retry_times = 90 + while max_retry_times > 0: + time.sleep(10) + url = f"{self.base_url}/tasks/{task.id}" + try: + res = requests.get(url, headers=self.headers, timeout=5) + if res.status_code == 200: + res_json = res.json() + logger.debug(f"[MJ] task check res sync, task_id={task.id}, status={res.status_code}, " + f"data={res_json.get('data')}, thread={threading.current_thread().name}") + if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name: + # process success res + if self.tasks.get(task.id): + self.tasks[task.id].status = Status.FINISHED + self._process_success_task(task, res_json.get("data"), e_context) + return + else: + res_json = res.json() + logger.warn(f"[MJ] image check error, status_code={res.status_code}, res={res_json}") + max_retry_times -= 20 + except Exception as e: + max_retry_times -= 20 + logger.warn(e) + logger.warn("[MJ] end from poll") + if self.tasks.get(task.id): + self.tasks[task.id].status = Status.EXPIRED + + async def check_task_async(self, task: MJTask, e_context: EventContext): try: logger.debug(f"[MJ] start check task status, {task}") max_retry_times = 90 @@ -232,6 +264,9 @@ class MJBot: except Exception as e: logger.error(e) + def _do_check_task(self, task: MJTask, e_context: EventContext): + threading.Thread(target=self.check_task_sync, args=(task, e_context)).start() + def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext): """ 处理任务成功的结果