diff --git a/plugins/linkai/config.json.template b/plugins/linkai/config.json.template index 98d114b..28fe2ca 100644 --- a/plugins/linkai/config.json.template +++ b/plugins/linkai/config.json.template @@ -7,6 +7,7 @@ "enabled": true, "mode": "relax", "auto_translate": true, + "img_proxy": true, "max_tasks": 3, "max_tasks_per_user": 1, "use_image_create_prefix": true diff --git a/plugins/linkai/midjourney.py b/plugins/linkai/midjourney.py index 9d92ee6..06fe9b3 100644 --- a/plugins/linkai/midjourney.py +++ b/plugins/linkai/midjourney.py @@ -10,6 +10,7 @@ import asyncio from bridge.context import ContextType from plugins import EventContext, EventAction +INVALID_REQUEST = 410 class TaskType(Enum): GENERATE = "generate" @@ -34,7 +35,8 @@ class TaskMode(Enum): class MJTask: - def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int=60*30, status=Status.PENDING): + def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int = 60 * 30, + status=Status.PENDING): self.id = id self.user_id = user_id self.task_type = task_type @@ -48,17 +50,18 @@ class MJTask: def __str__(self): return f"id={self.id}, user_id={self.user_id}, task_type={self.task_type}, status={self.status}, img_id={self.img_id}" + # midjourney bot class MJBot: def __init__(self, config): self.base_url = "https://api.link-ai.chat/v1/img/midjourney" + self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")} self.config = config self.tasks = {} self.temp_dict = {} self.tasks_lock = threading.Lock() self.event_loop = asyncio.new_event_loop() - # threading.Thread(name="mj-check-thread", target=self._run_loop, args=(self.event_loop,)).start() def judge_mj_task_type(self, e_context: EventContext): """ @@ -79,7 +82,6 @@ class MJBot: elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix"): return TaskType.GENERATE - def process_mj_task(self, mj_type: TaskType, e_context: EventContext): """ 处理mj任务 @@ -143,13 +145,15 @@ class MJBot: logger.info(f"[MJ] image generate, prompt={prompt}") mode = self._fetch_mode(prompt) body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")} - res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers, timeout=(5,15)) + if not self.config.get("img_proxy"): + body["img_proxy"] = False + res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers, timeout=(5, 40)) if res.status_code == 200: res = res.json() logger.debug(f"[MJ] image generate, res={res}") if res.get("code") == 200: - task_id = res.get("data").get("taskId") - real_prompt = res.get("data").get("realPrompt") + task_id = res.get("data").get("task_id") + real_prompt = res.get("data").get("real_prompt") if mode == TaskMode.RELAX.value: time_str = "1~10分钟" else: @@ -160,7 +164,8 @@ class MJBot: else: content += f"prompt: {prompt}" reply = Reply(ReplyType.INFO, content) - task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id, task_type=TaskType.GENERATE) + task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id, + task_type=TaskType.GENERATE) # put to memory dict self.tasks[task.id] = task # asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop) @@ -169,18 +174,23 @@ class MJBot: else: res_json = res.json() logger.error(f"[MJ] generate error, msg={res_json.get('message')}, status_code={res.status_code}") - reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试") + if res.status_code == INVALID_REQUEST: + reply = Reply(ReplyType.ERROR, "图片生成失败,请检查提示词参数或内容") + else: + reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试") return reply def upscale(self, user_id: str, img_id: str, index: int, e_context: EventContext) -> Reply: logger.info(f"[MJ] image upscale, img_id={img_id}, index={index}") - body = {"type": TaskType.UPSCALE.name, "imgId": img_id, "index": index} - res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5,15)) + body = {"type": TaskType.UPSCALE.name, "img_id": img_id, "index": index} + if not self.config.get("img_proxy"): + body["img_proxy"] = False + res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5, 40)) logger.debug(res) if res.status_code == 200: res = res.json() if res.get("code") == 200: - task_id = res.get("data").get("taskId") + task_id = res.get("data").get("task_id") logger.info(f"[MJ] image upscale processing, task_id={task_id}") content = f"🔎图片正在放大中,请耐心等待" reply = Reply(ReplyType.INFO, content) @@ -208,7 +218,7 @@ class MJBot: time.sleep(10) url = f"{self.base_url}/tasks/{task.id}" try: - res = requests.get(url, headers=self.headers, timeout=5) + res = requests.get(url, headers=self.headers, timeout=8) if res.status_code == 200: res_json = res.json() logger.debug(f"[MJ] task check res sync, task_id={task.id}, status={res.status_code}, " @@ -219,6 +229,7 @@ class MJBot: self.tasks[task.id].status = Status.FINISHED self._process_success_task(task, res_json.get("data"), e_context) return + max_retry_times -= 1 else: res_json = res.json() logger.warn(f"[MJ] image check error, status_code={res.status_code}, res={res_json}") @@ -276,8 +287,8 @@ class MJBot: """ # channel send img task.status = Status.FINISHED - task.img_id = res.get("imgId") - task.img_url = res.get("imgUrl") + task.img_id = res.get("img_id") + task.img_url = res.get("img_url") logger.info(f"[MJ] task success, task_id={task.id}, img_id={task.img_id}, img_url={task.img_url}") # send img @@ -338,7 +349,7 @@ class MJBot: for id in self.tasks: logger.debug(f"[MJ] current task: {self.tasks[id]}") - def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR): + def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR): """ 设置回复文本 :param content: 回复内容 @@ -358,7 +369,7 @@ class MJBot: return help_text - def find_tasks_by_user_id(self, user_id) -> list[MJTask]: + def find_tasks_by_user_id(self, user_id) -> list: result = [] with self.tasks_lock: now = time.time() @@ -377,4 +388,4 @@ def check_prefix(content, prefix_list): for prefix in prefix_list: if content.startswith(prefix): return prefix - return None \ No newline at end of file + return None