|
@@ -10,6 +10,7 @@ import asyncio |
|
|
from bridge.context import ContextType |
|
|
from bridge.context import ContextType |
|
|
from plugins import EventContext, EventAction |
|
|
from plugins import EventContext, EventAction |
|
|
|
|
|
|
|
|
|
|
|
INVALID_REQUEST = 410 |
|
|
|
|
|
|
|
|
class TaskType(Enum): |
|
|
class TaskType(Enum): |
|
|
GENERATE = "generate" |
|
|
GENERATE = "generate" |
|
@@ -34,7 +35,8 @@ class TaskMode(Enum): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MJTask: |
|
|
class MJTask: |
|
|
def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int=60*30, status=Status.PENDING): |
|
|
|
|
|
|
|
|
def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int = 60 * 30, |
|
|
|
|
|
status=Status.PENDING): |
|
|
self.id = id |
|
|
self.id = id |
|
|
self.user_id = user_id |
|
|
self.user_id = user_id |
|
|
self.task_type = task_type |
|
|
self.task_type = task_type |
|
@@ -48,17 +50,18 @@ class MJTask: |
|
|
def __str__(self): |
|
|
def __str__(self): |
|
|
return f"id={self.id}, user_id={self.user_id}, task_type={self.task_type}, status={self.status}, img_id={self.img_id}" |
|
|
return f"id={self.id}, user_id={self.user_id}, task_type={self.task_type}, status={self.status}, img_id={self.img_id}" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# midjourney bot |
|
|
# midjourney bot |
|
|
class MJBot: |
|
|
class MJBot: |
|
|
def __init__(self, config): |
|
|
def __init__(self, config): |
|
|
self.base_url = "https://api.link-ai.chat/v1/img/midjourney" |
|
|
self.base_url = "https://api.link-ai.chat/v1/img/midjourney" |
|
|
|
|
|
|
|
|
self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")} |
|
|
self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")} |
|
|
self.config = config |
|
|
self.config = config |
|
|
self.tasks = {} |
|
|
self.tasks = {} |
|
|
self.temp_dict = {} |
|
|
self.temp_dict = {} |
|
|
self.tasks_lock = threading.Lock() |
|
|
self.tasks_lock = threading.Lock() |
|
|
self.event_loop = asyncio.new_event_loop() |
|
|
self.event_loop = asyncio.new_event_loop() |
|
|
# threading.Thread(name="mj-check-thread", target=self._run_loop, args=(self.event_loop,)).start() |
|
|
|
|
|
|
|
|
|
|
|
def judge_mj_task_type(self, e_context: EventContext): |
|
|
def judge_mj_task_type(self, e_context: EventContext): |
|
|
""" |
|
|
""" |
|
@@ -79,7 +82,6 @@ class MJBot: |
|
|
elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix"): |
|
|
elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix"): |
|
|
return TaskType.GENERATE |
|
|
return TaskType.GENERATE |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def process_mj_task(self, mj_type: TaskType, e_context: EventContext): |
|
|
def process_mj_task(self, mj_type: TaskType, e_context: EventContext): |
|
|
""" |
|
|
""" |
|
|
处理mj任务 |
|
|
处理mj任务 |
|
@@ -143,13 +145,15 @@ class MJBot: |
|
|
logger.info(f"[MJ] image generate, prompt={prompt}") |
|
|
logger.info(f"[MJ] image generate, prompt={prompt}") |
|
|
mode = self._fetch_mode(prompt) |
|
|
mode = self._fetch_mode(prompt) |
|
|
body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")} |
|
|
body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")} |
|
|
res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers, timeout=(5,15)) |
|
|
|
|
|
|
|
|
if not self.config.get("img_proxy"): |
|
|
|
|
|
body["img_proxy"] = False |
|
|
|
|
|
res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers, timeout=(5, 40)) |
|
|
if res.status_code == 200: |
|
|
if res.status_code == 200: |
|
|
res = res.json() |
|
|
res = res.json() |
|
|
logger.debug(f"[MJ] image generate, res={res}") |
|
|
logger.debug(f"[MJ] image generate, res={res}") |
|
|
if res.get("code") == 200: |
|
|
if res.get("code") == 200: |
|
|
task_id = res.get("data").get("taskId") |
|
|
|
|
|
real_prompt = res.get("data").get("realPrompt") |
|
|
|
|
|
|
|
|
task_id = res.get("data").get("task_id") |
|
|
|
|
|
real_prompt = res.get("data").get("real_prompt") |
|
|
if mode == TaskMode.RELAX.value: |
|
|
if mode == TaskMode.RELAX.value: |
|
|
time_str = "1~10分钟" |
|
|
time_str = "1~10分钟" |
|
|
else: |
|
|
else: |
|
@@ -160,7 +164,8 @@ class MJBot: |
|
|
else: |
|
|
else: |
|
|
content += f"prompt: {prompt}" |
|
|
content += f"prompt: {prompt}" |
|
|
reply = Reply(ReplyType.INFO, content) |
|
|
reply = Reply(ReplyType.INFO, content) |
|
|
task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id, task_type=TaskType.GENERATE) |
|
|
|
|
|
|
|
|
task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id, |
|
|
|
|
|
task_type=TaskType.GENERATE) |
|
|
# put to memory dict |
|
|
# put to memory dict |
|
|
self.tasks[task.id] = task |
|
|
self.tasks[task.id] = task |
|
|
# asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop) |
|
|
# asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop) |
|
@@ -169,18 +174,23 @@ class MJBot: |
|
|
else: |
|
|
else: |
|
|
res_json = res.json() |
|
|
res_json = res.json() |
|
|
logger.error(f"[MJ] generate error, msg={res_json.get('message')}, status_code={res.status_code}") |
|
|
logger.error(f"[MJ] generate error, msg={res_json.get('message')}, status_code={res.status_code}") |
|
|
reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试") |
|
|
|
|
|
|
|
|
if res.status_code == INVALID_REQUEST: |
|
|
|
|
|
reply = Reply(ReplyType.ERROR, "图片生成失败,请检查提示词参数或内容") |
|
|
|
|
|
else: |
|
|
|
|
|
reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试") |
|
|
return reply |
|
|
return reply |
|
|
|
|
|
|
|
|
def upscale(self, user_id: str, img_id: str, index: int, e_context: EventContext) -> Reply: |
|
|
def upscale(self, user_id: str, img_id: str, index: int, e_context: EventContext) -> Reply: |
|
|
logger.info(f"[MJ] image upscale, img_id={img_id}, index={index}") |
|
|
logger.info(f"[MJ] image upscale, img_id={img_id}, index={index}") |
|
|
body = {"type": TaskType.UPSCALE.name, "imgId": img_id, "index": index} |
|
|
|
|
|
res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5,15)) |
|
|
|
|
|
|
|
|
body = {"type": TaskType.UPSCALE.name, "img_id": img_id, "index": index} |
|
|
|
|
|
if not self.config.get("img_proxy"): |
|
|
|
|
|
body["img_proxy"] = False |
|
|
|
|
|
res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5, 40)) |
|
|
logger.debug(res) |
|
|
logger.debug(res) |
|
|
if res.status_code == 200: |
|
|
if res.status_code == 200: |
|
|
res = res.json() |
|
|
res = res.json() |
|
|
if res.get("code") == 200: |
|
|
if res.get("code") == 200: |
|
|
task_id = res.get("data").get("taskId") |
|
|
|
|
|
|
|
|
task_id = res.get("data").get("task_id") |
|
|
logger.info(f"[MJ] image upscale processing, task_id={task_id}") |
|
|
logger.info(f"[MJ] image upscale processing, task_id={task_id}") |
|
|
content = f"🔎图片正在放大中,请耐心等待" |
|
|
content = f"🔎图片正在放大中,请耐心等待" |
|
|
reply = Reply(ReplyType.INFO, content) |
|
|
reply = Reply(ReplyType.INFO, content) |
|
@@ -208,7 +218,7 @@ class MJBot: |
|
|
time.sleep(10) |
|
|
time.sleep(10) |
|
|
url = f"{self.base_url}/tasks/{task.id}" |
|
|
url = f"{self.base_url}/tasks/{task.id}" |
|
|
try: |
|
|
try: |
|
|
res = requests.get(url, headers=self.headers, timeout=5) |
|
|
|
|
|
|
|
|
res = requests.get(url, headers=self.headers, timeout=8) |
|
|
if res.status_code == 200: |
|
|
if res.status_code == 200: |
|
|
res_json = res.json() |
|
|
res_json = res.json() |
|
|
logger.debug(f"[MJ] task check res sync, task_id={task.id}, status={res.status_code}, " |
|
|
logger.debug(f"[MJ] task check res sync, task_id={task.id}, status={res.status_code}, " |
|
@@ -219,6 +229,7 @@ class MJBot: |
|
|
self.tasks[task.id].status = Status.FINISHED |
|
|
self.tasks[task.id].status = Status.FINISHED |
|
|
self._process_success_task(task, res_json.get("data"), e_context) |
|
|
self._process_success_task(task, res_json.get("data"), e_context) |
|
|
return |
|
|
return |
|
|
|
|
|
max_retry_times -= 1 |
|
|
else: |
|
|
else: |
|
|
res_json = res.json() |
|
|
res_json = res.json() |
|
|
logger.warn(f"[MJ] image check error, status_code={res.status_code}, res={res_json}") |
|
|
logger.warn(f"[MJ] image check error, status_code={res.status_code}, res={res_json}") |
|
@@ -276,8 +287,8 @@ class MJBot: |
|
|
""" |
|
|
""" |
|
|
# channel send img |
|
|
# channel send img |
|
|
task.status = Status.FINISHED |
|
|
task.status = Status.FINISHED |
|
|
task.img_id = res.get("imgId") |
|
|
|
|
|
task.img_url = res.get("imgUrl") |
|
|
|
|
|
|
|
|
task.img_id = res.get("img_id") |
|
|
|
|
|
task.img_url = res.get("img_url") |
|
|
logger.info(f"[MJ] task success, task_id={task.id}, img_id={task.img_id}, img_url={task.img_url}") |
|
|
logger.info(f"[MJ] task success, task_id={task.id}, img_id={task.img_id}, img_url={task.img_url}") |
|
|
|
|
|
|
|
|
# send img |
|
|
# send img |
|
@@ -338,7 +349,7 @@ class MJBot: |
|
|
for id in self.tasks: |
|
|
for id in self.tasks: |
|
|
logger.debug(f"[MJ] current task: {self.tasks[id]}") |
|
|
logger.debug(f"[MJ] current task: {self.tasks[id]}") |
|
|
|
|
|
|
|
|
def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR): |
|
|
|
|
|
|
|
|
def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR): |
|
|
""" |
|
|
""" |
|
|
设置回复文本 |
|
|
设置回复文本 |
|
|
:param content: 回复内容 |
|
|
:param content: 回复内容 |
|
@@ -358,7 +369,7 @@ class MJBot: |
|
|
|
|
|
|
|
|
return help_text |
|
|
return help_text |
|
|
|
|
|
|
|
|
def find_tasks_by_user_id(self, user_id) -> list[MJTask]: |
|
|
|
|
|
|
|
|
def find_tasks_by_user_id(self, user_id) -> list: |
|
|
result = [] |
|
|
result = [] |
|
|
with self.tasks_lock: |
|
|
with self.tasks_lock: |
|
|
now = time.time() |
|
|
now = time.time() |
|
@@ -377,4 +388,4 @@ def check_prefix(content, prefix_list): |
|
|
for prefix in prefix_list: |
|
|
for prefix in prefix_list: |
|
|
if content.startswith(prefix): |
|
|
if content.startswith(prefix): |
|
|
return prefix |
|
|
return prefix |
|
|
return None |
|
|
|
|
|
|
|
|
return None |