diff --git a/plugins/linkai/midjourney.py b/plugins/linkai/midjourney.py index 93463cc..3a94b24 100644 --- a/plugins/linkai/midjourney.py +++ b/plugins/linkai/midjourney.py @@ -199,34 +199,38 @@ class MJBot: return reply async def check_task(self, task: MJTask, e_context: EventContext): - max_retry_times = 90 - while max_retry_times > 0: - await asyncio.sleep(10) - async with aiohttp.ClientSession() as session: - url = f"{self.base_url}/tasks/{task.id}" - try: - async with session.get(url, headers=self.headers) as res: - if res.status == 200: - res_json = await res.json() - logger.debug(f"[MJ] task check res, task_id={task.id}, status={res.status}, " - f"data={res_json.get('data')}, thread={threading.current_thread().name}") - if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name: - # process success res - if self.tasks.get(task.id): - self.tasks[task.id].status = Status.FINISHED - self._process_success_task(task, res_json.get("data"), e_context) - return - else: - res_json = await res.json() - logger.warn(f"[MJ] image check error, status_code={res.status}, res={res_json}") - max_retry_times -= 20 - except Exception as e: - max_retry_times -= 20 - logger.warn(e) - max_retry_times -= 1 - logger.warn("[MJ] end from poll") - if self.tasks.get(task.id): - self.tasks[task.id].status = Status.EXPIRED + try: + logger.debug(f"[MJ] start check task status, {task}") + max_retry_times = 90 + while max_retry_times > 0: + await asyncio.sleep(10) + async with aiohttp.ClientSession() as session: + url = f"{self.base_url}/tasks/{task.id}" + try: + async with session.get(url, headers=self.headers) as res: + if res.status == 200: + res_json = await res.json() + logger.debug(f"[MJ] task check res, task_id={task.id}, status={res.status}, " + f"data={res_json.get('data')}, thread={threading.current_thread().name}") + if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name: + # process success res + if self.tasks.get(task.id): + self.tasks[task.id].status = Status.FINISHED + self._process_success_task(task, res_json.get("data"), e_context) + return + else: + res_json = await res.json() + logger.warn(f"[MJ] image check error, status_code={res.status}, res={res_json}") + max_retry_times -= 20 + except Exception as e: + max_retry_times -= 20 + logger.warn(e) + max_retry_times -= 1 + logger.warn("[MJ] end from poll") + if self.tasks.get(task.id): + self.tasks[task.id].status = Status.EXPIRED + except Exception as e: + logger.error(e) def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext): """