浏览代码

fix: some image bug

master
zhayujie 1年前
父节点
当前提交
b22994c2d2
共有 2 个文件被更改,包括 29 次插入17 次删除
  1. +1
    -0
      plugins/linkai/config.json.template
  2. +28
    -17
      plugins/linkai/midjourney.py

+ 1
- 0
plugins/linkai/config.json.template 查看文件

@@ -7,6 +7,7 @@
"enabled": true,
"mode": "relax",
"auto_translate": true,
"img_proxy": true,
"max_tasks": 3,
"max_tasks_per_user": 1,
"use_image_create_prefix": true


+ 28
- 17
plugins/linkai/midjourney.py 查看文件

@@ -10,6 +10,7 @@ import asyncio
from bridge.context import ContextType
from plugins import EventContext, EventAction

INVALID_REQUEST = 410

class TaskType(Enum):
GENERATE = "generate"
@@ -34,7 +35,8 @@ class TaskMode(Enum):


class MJTask:
def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int=60*30, status=Status.PENDING):
def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int = 60 * 30,
status=Status.PENDING):
self.id = id
self.user_id = user_id
self.task_type = task_type
@@ -48,17 +50,18 @@ class MJTask:
def __str__(self):
return f"id={self.id}, user_id={self.user_id}, task_type={self.task_type}, status={self.status}, img_id={self.img_id}"


# midjourney bot
class MJBot:
def __init__(self, config):
self.base_url = "https://api.link-ai.chat/v1/img/midjourney"

self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
self.config = config
self.tasks = {}
self.temp_dict = {}
self.tasks_lock = threading.Lock()
self.event_loop = asyncio.new_event_loop()
# threading.Thread(name="mj-check-thread", target=self._run_loop, args=(self.event_loop,)).start()

def judge_mj_task_type(self, e_context: EventContext):
"""
@@ -79,7 +82,6 @@ class MJBot:
elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix"):
return TaskType.GENERATE


def process_mj_task(self, mj_type: TaskType, e_context: EventContext):
"""
处理mj任务
@@ -143,13 +145,15 @@ class MJBot:
logger.info(f"[MJ] image generate, prompt={prompt}")
mode = self._fetch_mode(prompt)
body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")}
res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers, timeout=(5,15))
if not self.config.get("img_proxy"):
body["img_proxy"] = False
res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers, timeout=(5, 40))
if res.status_code == 200:
res = res.json()
logger.debug(f"[MJ] image generate, res={res}")
if res.get("code") == 200:
task_id = res.get("data").get("taskId")
real_prompt = res.get("data").get("realPrompt")
task_id = res.get("data").get("task_id")
real_prompt = res.get("data").get("real_prompt")
if mode == TaskMode.RELAX.value:
time_str = "1~10分钟"
else:
@@ -160,7 +164,8 @@ class MJBot:
else:
content += f"prompt: {prompt}"
reply = Reply(ReplyType.INFO, content)
task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id, task_type=TaskType.GENERATE)
task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id,
task_type=TaskType.GENERATE)
# put to memory dict
self.tasks[task.id] = task
# asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
@@ -169,18 +174,23 @@ class MJBot:
else:
res_json = res.json()
logger.error(f"[MJ] generate error, msg={res_json.get('message')}, status_code={res.status_code}")
reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试")
if res.status_code == INVALID_REQUEST:
reply = Reply(ReplyType.ERROR, "图片生成失败,请检查提示词参数或内容")
else:
reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试")
return reply

def upscale(self, user_id: str, img_id: str, index: int, e_context: EventContext) -> Reply:
logger.info(f"[MJ] image upscale, img_id={img_id}, index={index}")
body = {"type": TaskType.UPSCALE.name, "imgId": img_id, "index": index}
res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5,15))
body = {"type": TaskType.UPSCALE.name, "img_id": img_id, "index": index}
if not self.config.get("img_proxy"):
body["img_proxy"] = False
res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5, 40))
logger.debug(res)
if res.status_code == 200:
res = res.json()
if res.get("code") == 200:
task_id = res.get("data").get("taskId")
task_id = res.get("data").get("task_id")
logger.info(f"[MJ] image upscale processing, task_id={task_id}")
content = f"🔎图片正在放大中,请耐心等待"
reply = Reply(ReplyType.INFO, content)
@@ -208,7 +218,7 @@ class MJBot:
time.sleep(10)
url = f"{self.base_url}/tasks/{task.id}"
try:
res = requests.get(url, headers=self.headers, timeout=5)
res = requests.get(url, headers=self.headers, timeout=8)
if res.status_code == 200:
res_json = res.json()
logger.debug(f"[MJ] task check res sync, task_id={task.id}, status={res.status_code}, "
@@ -219,6 +229,7 @@ class MJBot:
self.tasks[task.id].status = Status.FINISHED
self._process_success_task(task, res_json.get("data"), e_context)
return
max_retry_times -= 1
else:
res_json = res.json()
logger.warn(f"[MJ] image check error, status_code={res.status_code}, res={res_json}")
@@ -276,8 +287,8 @@ class MJBot:
"""
# channel send img
task.status = Status.FINISHED
task.img_id = res.get("imgId")
task.img_url = res.get("imgUrl")
task.img_id = res.get("img_id")
task.img_url = res.get("img_url")
logger.info(f"[MJ] task success, task_id={task.id}, img_id={task.img_id}, img_url={task.img_url}")

# send img
@@ -338,7 +349,7 @@ class MJBot:
for id in self.tasks:
logger.debug(f"[MJ] current task: {self.tasks[id]}")

def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR):
def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR):
"""
设置回复文本
:param content: 回复内容
@@ -358,7 +369,7 @@ class MJBot:

return help_text

def find_tasks_by_user_id(self, user_id) -> list[MJTask]:
def find_tasks_by_user_id(self, user_id) -> list:
result = []
with self.tasks_lock:
now = time.time()
@@ -377,4 +388,4 @@ def check_prefix(content, prefix_list):
for prefix in prefix_list:
if content.startswith(prefix):
return prefix
return None
return None

正在加载...
取消
保存