Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

midjourney.py 17KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. from enum import Enum
  2. from config import conf
  3. from common.log import logger
  4. import requests
  5. import threading
  6. import time
  7. from bridge.reply import Reply, ReplyType
  8. import aiohttp
  9. import asyncio
  10. from bridge.context import ContextType
  11. from plugins import EventContext, EventAction
  12. INVALID_REQUEST = 410
  13. class TaskType(Enum):
  14. GENERATE = "generate"
  15. UPSCALE = "upscale"
  16. VARIATION = "variation"
  17. RESET = "reset"
  18. class Status(Enum):
  19. PENDING = "pending"
  20. FINISHED = "finished"
  21. EXPIRED = "expired"
  22. ABORTED = "aborted"
  23. def __str__(self):
  24. return self.name
  25. class TaskMode(Enum):
  26. FAST = "fast"
  27. RELAX = "relax"
  28. class MJTask:
  29. def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int = 60 * 30,
  30. status=Status.PENDING):
  31. self.id = id
  32. self.user_id = user_id
  33. self.task_type = task_type
  34. self.raw_prompt = raw_prompt
  35. self.send_func = None # send_func(img_url)
  36. self.expiry_time = time.time() + expires
  37. self.status = status
  38. self.img_url = None # url
  39. self.img_id = None
  40. def __str__(self):
  41. return f"id={self.id}, user_id={self.user_id}, task_type={self.task_type}, status={self.status}, img_id={self.img_id}"
  42. # midjourney bot
  43. class MJBot:
  44. def __init__(self, config):
  45. self.base_url = "https://api.link-ai.chat/v1/img/midjourney"
  46. self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
  47. self.config = config
  48. self.tasks = {}
  49. self.temp_dict = {}
  50. self.tasks_lock = threading.Lock()
  51. self.event_loop = asyncio.new_event_loop()
  52. def judge_mj_task_type(self, e_context: EventContext):
  53. """
  54. 判断MJ任务的类型
  55. :param e_context: 上下文
  56. :return: 任务类型枚举
  57. """
  58. if not self.config:
  59. return None
  60. trigger_prefix = conf().get("plugin_trigger_prefix", "$")
  61. context = e_context['context']
  62. if context.type == ContextType.TEXT:
  63. cmd_list = context.content.split(maxsplit=1)
  64. if cmd_list[0].lower() == f"{trigger_prefix}mj":
  65. return TaskType.GENERATE
  66. elif cmd_list[0].lower() == f"{trigger_prefix}mju":
  67. return TaskType.UPSCALE
  68. elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix"):
  69. return TaskType.GENERATE
  70. def process_mj_task(self, mj_type: TaskType, e_context: EventContext):
  71. """
  72. 处理mj任务
  73. :param mj_type: mj任务类型
  74. :param e_context: 对话上下文
  75. """
  76. context = e_context['context']
  77. session_id = context["session_id"]
  78. cmd = context.content.split(maxsplit=1)
  79. if len(cmd) == 1 and context.type == ContextType.TEXT:
  80. # midjourney 帮助指令
  81. self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO)
  82. return
  83. if len(cmd) == 2 and (cmd[1] == "open" or cmd[1] == "close"):
  84. # midjourney 开关指令
  85. is_open = True
  86. tips_text = "开启"
  87. if cmd[1] == "close":
  88. tips_text = "关闭"
  89. is_open = False
  90. self.config["enabled"] = is_open
  91. self._set_reply_text(f"Midjourney绘画已{tips_text}", e_context, level=ReplyType.INFO)
  92. return
  93. if not self.config.get("enabled"):
  94. logger.warn("Midjourney绘画未开启,请查看 plugins/linkai/config.json 中的配置")
  95. self._set_reply_text(f"Midjourney绘画未开启", e_context, level=ReplyType.INFO)
  96. return
  97. if not self._check_rate_limit(session_id, e_context):
  98. logger.warn("[MJ] midjourney task exceed rate limit")
  99. return
  100. if mj_type == TaskType.GENERATE:
  101. if context.type == ContextType.IMAGE_CREATE:
  102. raw_prompt = context.content
  103. else:
  104. # 图片生成
  105. raw_prompt = cmd[1]
  106. reply = self.generate(raw_prompt, session_id, e_context)
  107. e_context['reply'] = reply
  108. e_context.action = EventAction.BREAK_PASS
  109. return
  110. elif mj_type == TaskType.UPSCALE:
  111. # 图片放大
  112. clist = cmd[1].split()
  113. if len(clist) < 2:
  114. self._set_reply_text(f"{cmd[0]} 命令缺少参数", e_context)
  115. return
  116. img_id = clist[0]
  117. index = int(clist[1])
  118. if index < 1 or index > 4:
  119. self._set_reply_text(f"图片序号 {index} 错误,应在 1 至 4 之间", e_context)
  120. return
  121. key = f"{TaskType.UPSCALE.name}_{img_id}_{index}"
  122. if self.temp_dict.get(key):
  123. self._set_reply_text(f"第 {index} 张图片已经放大过了", e_context)
  124. return
  125. # 图片放大操作
  126. reply = self.upscale(session_id, img_id, index, e_context)
  127. e_context['reply'] = reply
  128. e_context.action = EventAction.BREAK_PASS
  129. return
  130. else:
  131. self._set_reply_text(f"暂不支持该命令", e_context)
  132. def generate(self, prompt: str, user_id: str, e_context: EventContext) -> Reply:
  133. """
  134. 图片生成
  135. :param prompt: 提示词
  136. :param user_id: 用户id
  137. :param e_context: 对话上下文
  138. :return: 任务ID
  139. """
  140. logger.info(f"[MJ] image generate, prompt={prompt}")
  141. mode = self._fetch_mode(prompt)
  142. body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")}
  143. if not self.config.get("img_proxy"):
  144. body["img_proxy"] = False
  145. res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers, timeout=(5, 40))
  146. if res.status_code == 200:
  147. res = res.json()
  148. logger.debug(f"[MJ] image generate, res={res}")
  149. if res.get("code") == 200:
  150. task_id = res.get("data").get("task_id")
  151. real_prompt = res.get("data").get("real_prompt")
  152. if mode == TaskMode.RELAX.value:
  153. time_str = "1~10分钟"
  154. else:
  155. time_str = "1分钟"
  156. content = f"🚀您的作品将在{time_str}左右完成,请耐心等待\n- - - - - - - - -\n"
  157. if real_prompt:
  158. content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}"
  159. else:
  160. content += f"prompt: {prompt}"
  161. reply = Reply(ReplyType.INFO, content)
  162. task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id,
  163. task_type=TaskType.GENERATE)
  164. # put to memory dict
  165. self.tasks[task.id] = task
  166. # asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
  167. self._do_check_task(task, e_context)
  168. return reply
  169. else:
  170. res_json = res.json()
  171. logger.error(f"[MJ] generate error, msg={res_json.get('message')}, status_code={res.status_code}")
  172. if res.status_code == INVALID_REQUEST:
  173. reply = Reply(ReplyType.ERROR, "图片生成失败,请检查提示词参数或内容")
  174. else:
  175. reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试")
  176. return reply
  177. def upscale(self, user_id: str, img_id: str, index: int, e_context: EventContext) -> Reply:
  178. logger.info(f"[MJ] image upscale, img_id={img_id}, index={index}")
  179. body = {"type": TaskType.UPSCALE.name, "img_id": img_id, "index": index}
  180. if not self.config.get("img_proxy"):
  181. body["img_proxy"] = False
  182. res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5, 40))
  183. logger.debug(res)
  184. if res.status_code == 200:
  185. res = res.json()
  186. if res.get("code") == 200:
  187. task_id = res.get("data").get("task_id")
  188. logger.info(f"[MJ] image upscale processing, task_id={task_id}")
  189. content = f"🔎图片正在放大中,请耐心等待"
  190. reply = Reply(ReplyType.INFO, content)
  191. task = MJTask(id=task_id, status=Status.PENDING, user_id=user_id, task_type=TaskType.UPSCALE)
  192. # put to memory dict
  193. self.tasks[task.id] = task
  194. key = f"{TaskType.UPSCALE.name}_{img_id}_{index}"
  195. self.temp_dict[key] = True
  196. # asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
  197. self._do_check_task(task, e_context)
  198. return reply
  199. else:
  200. error_msg = ""
  201. if res.status_code == 461:
  202. error_msg = "请输入正确的图片ID"
  203. res_json = res.json()
  204. logger.error(f"[MJ] upscale error, msg={res_json.get('message')}, status_code={res.status_code}")
  205. reply = Reply(ReplyType.ERROR, error_msg or "图片生成失败,请稍后再试")
  206. return reply
  207. def check_task_sync(self, task: MJTask, e_context: EventContext):
  208. logger.debug(f"[MJ] start check task status, {task}")
  209. max_retry_times = 90
  210. while max_retry_times > 0:
  211. time.sleep(10)
  212. url = f"{self.base_url}/tasks/{task.id}"
  213. try:
  214. res = requests.get(url, headers=self.headers, timeout=8)
  215. if res.status_code == 200:
  216. res_json = res.json()
  217. logger.debug(f"[MJ] task check res sync, task_id={task.id}, status={res.status_code}, "
  218. f"data={res_json.get('data')}, thread={threading.current_thread().name}")
  219. if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name:
  220. # process success res
  221. if self.tasks.get(task.id):
  222. self.tasks[task.id].status = Status.FINISHED
  223. self._process_success_task(task, res_json.get("data"), e_context)
  224. return
  225. max_retry_times -= 1
  226. else:
  227. res_json = res.json()
  228. logger.warn(f"[MJ] image check error, status_code={res.status_code}, res={res_json}")
  229. max_retry_times -= 20
  230. except Exception as e:
  231. max_retry_times -= 20
  232. logger.warn(e)
  233. logger.warn("[MJ] end from poll")
  234. if self.tasks.get(task.id):
  235. self.tasks[task.id].status = Status.EXPIRED
  236. async def check_task_async(self, task: MJTask, e_context: EventContext):
  237. try:
  238. logger.debug(f"[MJ] start check task status, {task}")
  239. max_retry_times = 90
  240. while max_retry_times > 0:
  241. await asyncio.sleep(10)
  242. async with aiohttp.ClientSession() as session:
  243. url = f"{self.base_url}/tasks/{task.id}"
  244. try:
  245. async with session.get(url, headers=self.headers) as res:
  246. if res.status == 200:
  247. res_json = await res.json()
  248. logger.debug(f"[MJ] task check res, task_id={task.id}, status={res.status}, "
  249. f"data={res_json.get('data')}, thread={threading.current_thread().name}")
  250. if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name:
  251. # process success res
  252. if self.tasks.get(task.id):
  253. self.tasks[task.id].status = Status.FINISHED
  254. self._process_success_task(task, res_json.get("data"), e_context)
  255. return
  256. else:
  257. res_json = await res.json()
  258. logger.warn(f"[MJ] image check error, status_code={res.status}, res={res_json}")
  259. max_retry_times -= 20
  260. except Exception as e:
  261. max_retry_times -= 20
  262. logger.warn(e)
  263. max_retry_times -= 1
  264. logger.warn("[MJ] end from poll")
  265. if self.tasks.get(task.id):
  266. self.tasks[task.id].status = Status.EXPIRED
  267. except Exception as e:
  268. logger.error(e)
  269. def _do_check_task(self, task: MJTask, e_context: EventContext):
  270. threading.Thread(target=self.check_task_sync, args=(task, e_context)).start()
  271. def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext):
  272. """
  273. 处理任务成功的结果
  274. :param task: MJ任务
  275. :param res: 请求结果
  276. :param e_context: 对话上下文
  277. """
  278. # channel send img
  279. task.status = Status.FINISHED
  280. task.img_id = res.get("img_id")
  281. task.img_url = res.get("img_url")
  282. logger.info(f"[MJ] task success, task_id={task.id}, img_id={task.img_id}, img_url={task.img_url}")
  283. # send img
  284. reply = Reply(ReplyType.IMAGE_URL, task.img_url)
  285. channel = e_context["channel"]
  286. channel._send(reply, e_context["context"])
  287. # send info
  288. trigger_prefix = conf().get("plugin_trigger_prefix", "$")
  289. text = ""
  290. if task.task_type == TaskType.GENERATE:
  291. text = f"🎨绘画完成!\nprompt: {task.raw_prompt}\n- - - - - - - - -\n图片ID: {task.img_id}"
  292. text += f"\n\n🔎可使用 {trigger_prefix}mju 命令放大指定图片\n"
  293. text += f"例如:\n{trigger_prefix}mju {task.img_id} 1"
  294. reply = Reply(ReplyType.INFO, text)
  295. channel._send(reply, e_context["context"])
  296. self._print_tasks()
  297. return
  298. def _check_rate_limit(self, user_id: str, e_context: EventContext) -> bool:
  299. """
  300. midjourney任务限流控制
  301. :param user_id: 用户id
  302. :param e_context: 对话上下文
  303. :return: 任务是否能够生成, True:可以生成, False: 被限流
  304. """
  305. tasks = self.find_tasks_by_user_id(user_id)
  306. task_count = len([t for t in tasks if t.status == Status.PENDING])
  307. if task_count >= self.config.get("max_tasks_per_user"):
  308. reply = Reply(ReplyType.INFO, "您的Midjourney作图任务数已达上限,请稍后再试")
  309. e_context["reply"] = reply
  310. e_context.action = EventAction.BREAK_PASS
  311. return False
  312. task_count = len([t for t in self.tasks.values() if t.status == Status.PENDING])
  313. if task_count >= self.config.get("max_tasks"):
  314. reply = Reply(ReplyType.INFO, "Midjourney作图任务数已达上限,请稍后再试")
  315. e_context["reply"] = reply
  316. e_context.action = EventAction.BREAK_PASS
  317. return False
  318. return True
  319. def _fetch_mode(self, prompt) -> str:
  320. mode = self.config.get("mode")
  321. if "--relax" in prompt or mode == TaskMode.RELAX.value:
  322. return TaskMode.RELAX.value
  323. return mode or TaskMode.FAST.value
  324. def _run_loop(self, loop: asyncio.BaseEventLoop):
  325. """
  326. 运行事件循环,用于轮询任务的线程
  327. :param loop: 事件循环
  328. """
  329. loop.run_forever()
  330. loop.stop()
  331. def _print_tasks(self):
  332. for id in self.tasks:
  333. logger.debug(f"[MJ] current task: {self.tasks[id]}")
  334. def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR):
  335. """
  336. 设置回复文本
  337. :param content: 回复内容
  338. :param e_context: 对话上下文
  339. :param level: 回复等级
  340. """
  341. reply = Reply(level, content)
  342. e_context["reply"] = reply
  343. e_context.action = EventAction.BREAK_PASS
  344. def get_help_text(self, verbose=False, **kwargs):
  345. trigger_prefix = conf().get("plugin_trigger_prefix", "$")
  346. help_text = "🎨利用Midjourney进行画图\n\n"
  347. if not verbose:
  348. return help_text
  349. help_text += f" - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\""
  350. return help_text
  351. def find_tasks_by_user_id(self, user_id) -> list:
  352. result = []
  353. with self.tasks_lock:
  354. now = time.time()
  355. for task in self.tasks.values():
  356. if task.status == Status.PENDING and now > task.expiry_time:
  357. task.status = Status.EXPIRED
  358. logger.info(f"[MJ] {task} expired")
  359. if task.user_id == user_id:
  360. result.append(task)
  361. return result
  362. def check_prefix(content, prefix_list):
  363. if not prefix_list:
  364. return None
  365. for prefix in prefix_list:
  366. if content.startswith(prefix):
  367. return prefix
  368. return None