You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

256 satır
11KB

  1. from enum import Enum
  2. from config import conf
  3. from common.log import logger
  4. import requests
  5. import threading
  6. import time
  7. from bridge.reply import Reply, ReplyType
  8. import aiohttp
  9. import asyncio
  10. from bridge.context import ContextType
  11. from plugins import EventContext, EventAction
  12. class TaskType(Enum):
  13. GENERATE = "generate"
  14. UPSCALE = "upscale"
  15. VARIATION = "variation"
  16. RESET = "reset"
  17. class Status(Enum):
  18. PENDING = "pending"
  19. FINISHED = "finished"
  20. EXPIRED = "expired"
  21. ABORTED = "aborted"
  22. def __str__(self):
  23. return self.name
  24. class MJTask:
  25. def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int=60*30, status=Status.PENDING):
  26. self.id = id
  27. self.user_id = user_id
  28. self.task_type = task_type
  29. self.raw_prompt = raw_prompt
  30. self.send_func = None # send_func(img_url)
  31. self.expiry_time = time.time() + expires
  32. self.status = status
  33. self.img_url = None # url
  34. self.img_id = None
  35. def __str__(self):
  36. return f"id={self.id}, user_id={self.user_id}, task_type={self.task_type}, status={self.status}, img_id={self.img_id}"
  37. # midjourney bot
  38. class MJBot:
  39. def __init__(self, config):
  40. self.base_url = "https://api.link-ai.chat/v1/img/midjourney"
  41. # self.base_url = "http://127.0.0.1:8911/v1/img/midjourney"
  42. self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
  43. self.config = config
  44. self.tasks = {}
  45. self.temp_dict = {}
  46. self.tasks_lock = threading.Lock()
  47. self.event_loop = asyncio.new_event_loop()
  48. threading.Thread(name="mj-check-thread", target=self._run_loop, args=(self.event_loop,)).start()
  49. def judge_mj_task_type(self, e_context: EventContext) -> TaskType:
  50. """
  51. 判断MJ任务的类型
  52. :param e_context: 上下文
  53. :return: 任务类型枚举
  54. """
  55. trigger_prefix = conf().get("plugin_trigger_prefix", "$")
  56. context = e_context['context']
  57. if context.type == ContextType.TEXT:
  58. if self.config and self.config.get("enabled"):
  59. cmd_list = context.content.split(maxsplit=1)
  60. if cmd_list[0].lower() == f"{trigger_prefix}mj":
  61. return TaskType.GENERATE
  62. elif cmd_list[0].lower() == f"{trigger_prefix}mju":
  63. return TaskType.UPSCALE
  64. # elif cmd_list[0].lower() == f"{trigger_prefix}mjv":
  65. # return TaskType.VARIATION
  66. # elif cmd_list[0].lower() == f"{trigger_prefix}mjr":
  67. # return TaskType.RESET
  68. def process_mj_task(self, mj_type: TaskType, e_context: EventContext):
  69. """
  70. 处理mj任务
  71. :param mj_type: mj任务类型
  72. :param e_context: 对话上下文
  73. """
  74. context = e_context['context']
  75. session_id = context["session_id"]
  76. cmd = context.content.split(maxsplit=1)
  77. if len(cmd) == 1:
  78. self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.ERROR)
  79. return
  80. if mj_type == TaskType.GENERATE:
  81. # 图片生成
  82. raw_prompt = cmd[1]
  83. reply = self.generate(raw_prompt, session_id, e_context)
  84. e_context['reply'] = reply
  85. e_context.action = EventAction.BREAK_PASS
  86. return
  87. elif mj_type == TaskType.UPSCALE:
  88. # 图片放大
  89. clist = cmd[1].split()
  90. if len(clist) < 2:
  91. self._set_reply_text(f"{cmd[0]} 命令缺少参数", e_context)
  92. return
  93. img_id = clist[0]
  94. index = int(clist[1])
  95. if index < 1 or index > 4:
  96. self._set_reply_text(f"图片序号 {index} 错误,应在 1 至 4 之间", e_context)
  97. return
  98. key = f"{TaskType.UPSCALE.name}_{img_id}_{index}"
  99. if self.temp_dict.get(key):
  100. self._set_reply_text(f"第 {index} 张图片已经放大过了", e_context)
  101. return
  102. # 图片放大操作
  103. reply = self.upscale(session_id, img_id, index, e_context)
  104. e_context['reply'] = reply
  105. e_context.action = EventAction.BREAK_PASS
  106. return
  107. else:
  108. self._set_reply_text(f"暂不支持该命令", e_context)
  109. def generate(self, prompt: str, user_id: str, e_context: EventContext) -> Reply:
  110. """
  111. 图片生成
  112. :param prompt: 提示词
  113. :param user_id: 用户id
  114. :return: 任务ID
  115. """
  116. logger.info(f"[MJ] image generate, prompt={prompt}")
  117. body = {"prompt": prompt}
  118. res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers)
  119. if res.status_code == 200:
  120. res = res.json()
  121. logger.debug(f"[MJ] image generate, res={res}")
  122. if res.get("code") == 200:
  123. task_id = res.get("data").get("taskId")
  124. real_prompt = res.get("data").get("realPrompt")
  125. content = f"🚀你的作品将在1~2分钟左右完成,请耐心等待\n- - - - - - - - -\n"
  126. if real_prompt:
  127. content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}"
  128. else:
  129. content += f"prompt: {prompt}"
  130. reply = Reply(ReplyType.INFO, content)
  131. task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id, task_type=TaskType.GENERATE)
  132. # put to memory dict
  133. self.tasks[task.id] = task
  134. asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
  135. return reply
  136. else:
  137. res_json = res.json()
  138. logger.error(f"[MJ] generate error, msg={res_json.get('message')}, status_code={res.status_code}")
  139. reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试")
  140. return reply
  141. def upscale(self, user_id: str, img_id: str, index: int, e_context: EventContext) -> Reply:
  142. logger.info(f"[MJ] image upscale, img_id={img_id}, index={index}")
  143. body = {"type": TaskType.UPSCALE.name, "imgId": img_id, "index": index}
  144. res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers)
  145. if res.status_code == 200:
  146. res = res.json()
  147. logger.info(res)
  148. if res.get("code") == 200:
  149. task_id = res.get("data").get("taskId")
  150. content = f"🔎图片正在放大中,请耐心等待"
  151. reply = Reply(ReplyType.INFO, content)
  152. task = MJTask(id=task_id, status=Status.PENDING, user_id=user_id, task_type=TaskType.UPSCALE)
  153. # put to memory dict
  154. self.tasks[task.id] = task
  155. key = f"{TaskType.UPSCALE.name}_{img_id}_{index}"
  156. self.temp_dict[key] = True
  157. asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
  158. return reply
  159. else:
  160. error_msg = ""
  161. if res.status_code == 461:
  162. error_msg = "请输入正确的图片ID"
  163. res_json = res.json()
  164. logger.error(f"[MJ] upscale error, msg={res_json.get('message')}, status_code={res.status_code}")
  165. reply = Reply(ReplyType.ERROR, error_msg or "图片生成失败,请稍后再试")
  166. return reply
  167. async def check_task(self, task: MJTask, e_context: EventContext):
  168. max_retry_time = 80
  169. while max_retry_time > 0:
  170. async with aiohttp.ClientSession() as session:
  171. url = f"{self.base_url}/tasks/{task.id}"
  172. async with session.get(url, headers=self.headers) as res:
  173. if res.status == 200:
  174. res_json = await res.json()
  175. logger.debug(f"[MJ] task check res, task_id={task.id}, status={res.status}, "
  176. f"data={res_json.get('data')}, thread={threading.current_thread().name}")
  177. if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name:
  178. # process success res
  179. self._process_success_task(task, res_json.get("data"), e_context)
  180. return
  181. else:
  182. logger.warn(f"[MJ] image check error, status_code={res.status}")
  183. max_retry_time -= 20
  184. await asyncio.sleep(10)
  185. max_retry_time -= 1
  186. logger.warn("[MJ] end from poll")
  187. def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext):
  188. """
  189. 处理任务成功的结果
  190. :param task: MJ任务
  191. :param res: 请求结果
  192. :param e_context: 对话上下文
  193. """
  194. # channel send img
  195. task.status = Status.FINISHED
  196. task.img_id = res.get("imgId")
  197. task.img_url = res.get("imgUrl")
  198. logger.info(f"[MJ] task success, task_id={task.id}, img_id={task.img_id}, img_url={task.img_url}")
  199. # send img
  200. reply = Reply(ReplyType.IMAGE_URL, task.img_url)
  201. channel = e_context["channel"]
  202. channel._send(reply, e_context["context"])
  203. # send info
  204. trigger_prefix = conf().get("plugin_trigger_prefix", "$")
  205. text = ""
  206. if task.task_type == TaskType.GENERATE:
  207. text = f"🎨绘画完成!\nprompt: {task.raw_prompt}\n- - - - - - - - -\n图片ID: {task.img_id}"
  208. text += f"\n\n🔎可使用 {trigger_prefix}mju 命令放大指定图片\n"
  209. text += f"例如:\n{trigger_prefix}mju {task.img_id} 1"
  210. reply = Reply(ReplyType.INFO, text)
  211. channel._send(reply, e_context["context"])
  212. self._print_tasks()
  213. return
  214. def _run_loop(self, loop: asyncio.BaseEventLoop):
  215. loop.run_forever()
  216. loop.stop()
  217. def _print_tasks(self):
  218. for id in self.tasks:
  219. logger.debug(f"[MJ] current task: {self.tasks[id]}")
  220. def get_help_text(self, verbose=False, **kwargs):
  221. trigger_prefix = conf().get("plugin_trigger_prefix", "$")
  222. help_text = "利用midjourney来画图。\n"
  223. if not verbose:
  224. return help_text
  225. help_text += f"{trigger_prefix}mj 描述词1,描述词2 ... : 利用描述词作画,参数请放在提示词之后。\n{trigger_prefix}mjimage 描述词1,描述词2 ... : 利用描述词进行图生图,参数请放在提示词之后。\n{trigger_prefix}mjr ID: 对指定ID消息重新生成图片。\n{trigger_prefix}mju ID 图片序号: 对指定ID消息中的第x张图片进行放大。\n{trigger_prefix}mjv ID 图片序号: 对指定ID消息中的第x张图片进行变换。\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mjimage a white cat --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\""
  226. return help_text
  227. def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR):
  228. reply = Reply(level, content)
  229. e_context["reply"] = reply
  230. e_context.action = EventAction.BREAK_PASS