From 2f9e5b1219aa32f1c7b11a8c36c9eda71b8438d3 Mon Sep 17 00:00:00 2001 From: zhayujie Date: Fri, 28 Jul 2023 12:40:06 +0800 Subject: [PATCH] feat: check app_code dynamically --- config.py | 6 ++ plugins/godcmd/godcmd.py | 4 +- plugins/linkai/README.md | 0 plugins/linkai/config.json.template | 3 +- plugins/linkai/linkai.py | 66 +++++++++++++--- plugins/linkai/midjourney.py | 118 +++++++++++++++++++++++----- plugins/plugin.py | 19 ++++- 7 files changed, 181 insertions(+), 35 deletions(-) create mode 100644 plugins/linkai/README.md diff --git a/config.py b/config.py index 85c5436..289f872 100644 --- a/config.py +++ b/config.py @@ -252,3 +252,9 @@ def pconf(plugin_name: str) -> dict: :return: 该插件的配置项 """ return plugin_config.get(plugin_name.lower()) + + +# 全局配置,用于存放全局生效的状态 +global_config = { + "admin_users": [] +} diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index 08bc09e..0ff204e 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -13,7 +13,7 @@ from bridge.context import ContextType from bridge.reply import Reply, ReplyType from common import const from common.log import logger -from config import conf, load_config +from config import conf, load_config, global_config from plugins import * # 定义指令集 @@ -426,9 +426,11 @@ class Godcmd(Plugin): password = args[0] if password == self.password: self.admin_users.append(userid) + global_config["admin_users"].append(userid) return True, "认证成功" elif password == self.temp_password: self.admin_users.append(userid) + global_config["admin_users"].append(userid) return True, "认证成功,请尽快设置口令" else: return False, "认证失败" diff --git a/plugins/linkai/README.md b/plugins/linkai/README.md new file mode 100644 index 0000000..e69de29 diff --git a/plugins/linkai/config.json.template b/plugins/linkai/config.json.template index bdfed7a..98d114b 100644 --- a/plugins/linkai/config.json.template +++ b/plugins/linkai/config.json.template @@ -8,6 +8,7 @@ "mode": "relax", "auto_translate": true, "max_tasks": 3, - "max_tasks_per_user": 1 + "max_tasks_per_user": 1, + "use_image_create_prefix": true } } diff --git a/plugins/linkai/linkai.py b/plugins/linkai/linkai.py index 8698ad5..cff6e16 100644 --- a/plugins/linkai/linkai.py +++ b/plugins/linkai/linkai.py @@ -8,7 +8,7 @@ from bridge.context import ContextType from bridge.reply import Reply, ReplyType from channel.chat_message import ChatMessage from common.log import logger -from config import conf +from config import conf, global_config from plugins import * from .midjourney import MJBot, TaskType @@ -46,14 +46,48 @@ class LinkAI(Plugin): self.mj_bot.process_mj_task(mj_type, e_context) return + if context.content.startswith(f"{_get_trigger_prefix()}linkai"): + # 应用管理功能 + self._process_admin_cmd(e_context) + return + if self._is_chat_task(e_context): + # 文本对话任务处理 self._process_chat_task(e_context) + # 插件管理功能 + def _process_admin_cmd(self, e_context: EventContext): + context = e_context['context'] + cmd = context.content.split() + if len(cmd) == 1 or (len(cmd) == 2 and cmd[1] == "help"): + _set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO) + return + if len(cmd) == 3 and cmd[1] == "app": + if not context.kwargs.get("isgroup"): + _set_reply_text("该指令需在群聊中使用", e_context, level=ReplyType.ERROR) + return + if e_context["context"]["session_id"] not in global_config["admin_users"]: + _set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR) + return + app_code = cmd[2] + group_name = context.kwargs.get("msg").from_user_nickname + group_mapping = self.config.get("group_app_map") + if group_mapping: + group_mapping[group_name] = app_code + else: + self.config["group_app_map"] = {group_name: app_code} + # 保存插件配置 + super().save_config(self.config) + _set_reply_text(f"应用设置成功: {app_code}", e_context, level=ReplyType.INFO) + else: + _set_reply_text(f"指令错误,请输入{_get_trigger_prefix()}linkai help 获取帮助", e_context, level=ReplyType.INFO) + return + # LinkAI 对话任务处理 def _is_chat_task(self, e_context: EventContext): context = e_context['context'] # 群聊应用管理 - return self.config.get("knowledge_base") and context.kwargs.get("isgroup") + return self.config.get("group_app_map") and context.kwargs.get("isgroup") def _process_chat_task(self, e_context: EventContext): """ @@ -73,21 +107,27 @@ class LinkAI(Plugin): :param group_name: 群聊名称 :return: 应用code """ - knowledge_base_config = self.config.get("knowledge_base") - if knowledge_base_config and knowledge_base_config.get("group_mapping"): - app_code = knowledge_base_config.get("group_mapping").get(group_name) \ - or knowledge_base_config.get("group_mapping").get("ALL_GROUP") + group_mapping = self.config.get("group_app_map") + if group_mapping: + app_code = group_mapping.get(group_name) or group_mapping.get("ALL_GROUP") return app_code def get_help_text(self, verbose=False, **kwargs): - trigger_prefix = conf().get("plugin_trigger_prefix", "$") - help_text = "利用midjourney来画图。\n" + trigger_prefix = _get_trigger_prefix() + help_text = "用于集成 LinkAI 提供的文本对话、知识库、绘画等能力。\n" if not verbose: return help_text - help_text += f"{trigger_prefix}mj 描述词1,描述词2 ... : 利用描述词作画,参数请放在提示词之后。\n{trigger_prefix}mjimage 描述词1,描述词2 ... : 利用描述词进行图生图,参数请放在提示词之后。\n{trigger_prefix}mjr ID: 对指定ID消息重新生成图片。\n{trigger_prefix}mju ID 图片序号: 对指定ID消息中的第x张图片进行放大。\n{trigger_prefix}mjv ID 图片序号: 对指定ID消息中的第x张图片进行变换。\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mjimage a white cat --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\"" + help_text += "" + help_text += f"{trigger_prefix}mj 描述词1,描述词2 ... : 利用描述词作画,参数请放在提示词之后。\n\n{trigger_prefix}mju ID 图片序号: 对指定ID消息中的第x张图片进行放大。\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mjimage a white cat --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\"" return help_text - def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR): - reply = Reply(level, content) - e_context["reply"] = reply - e_context.action = EventAction.BREAK_PASS + +# 静态方法 +def _set_reply_text(content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR): + reply = Reply(level, content) + e_context["reply"] = reply + e_context.action = EventAction.BREAK_PASS + + +def _get_trigger_prefix(): + return conf().get("plugin_trigger_prefix", "$") diff --git a/plugins/linkai/midjourney.py b/plugins/linkai/midjourney.py index c7c6c35..a2a4662 100644 --- a/plugins/linkai/midjourney.py +++ b/plugins/linkai/midjourney.py @@ -28,6 +28,11 @@ class Status(Enum): return self.name +class TaskMode(Enum): + FAST = "fast" + RELAX = "relax" + + class MJTask: def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int=60*30, status=Status.PENDING): self.id = id @@ -47,7 +52,6 @@ class MJTask: class MJBot: def __init__(self, config): self.base_url = "https://api.link-ai.chat/v1/img/midjourney" - # self.base_url = "http://127.0.0.1:8911/v1/img/midjourney" self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")} self.config = config self.tasks = {} @@ -71,10 +75,10 @@ class MJBot: return TaskType.GENERATE elif cmd_list[0].lower() == f"{trigger_prefix}mju": return TaskType.UPSCALE - # elif cmd_list[0].lower() == f"{trigger_prefix}mjv": - # return TaskType.VARIATION - # elif cmd_list[0].lower() == f"{trigger_prefix}mjr": - # return TaskType.RESET + elif self.config.get("use_image_create_prefix") and \ + check_prefix(context.content, conf().get("image_create_prefix")): + return TaskType.GENERATE + def process_mj_task(self, mj_type: TaskType, e_context: EventContext): """ @@ -86,12 +90,20 @@ class MJBot: session_id = context["session_id"] cmd = context.content.split(maxsplit=1) if len(cmd) == 1: - self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.ERROR) + self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO) + return + + if not self._check_rate_limit(session_id, e_context): + logger.warn("[MJ] midjourney task exceed rate limit") return if mj_type == TaskType.GENERATE: - # 图片生成 - raw_prompt = cmd[1] + image_prefix = check_prefix(context.content, conf().get("image_create_prefix")) + if image_prefix: + raw_prompt = context.content.replace(image_prefix, "", 1) + else: + # 图片生成 + raw_prompt = cmd[1] reply = self.generate(raw_prompt, session_id, e_context) e_context['reply'] = reply e_context.action = EventAction.BREAK_PASS @@ -126,10 +138,12 @@ class MJBot: 图片生成 :param prompt: 提示词 :param user_id: 用户id + :param e_context: 对话上下文 :return: 任务ID """ logger.info(f"[MJ] image generate, prompt={prompt}") - body = {"prompt": prompt} + mode = self._fetch_mode(prompt) + body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")} res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers) if res.status_code == 200: res = res.json() @@ -137,7 +151,11 @@ class MJBot: if res.get("code") == 200: task_id = res.get("data").get("taskId") real_prompt = res.get("data").get("realPrompt") - content = f"🚀你的作品将在1~2分钟左右完成,请耐心等待\n- - - - - - - - -\n" + if mode == TaskMode.RELAX.name: + time_str = "1~10分钟" + else: + time_str = "1~2分钟" + content = f"🚀你的作品将在{time_str}左右完成,请耐心等待\n- - - - - - - - -\n" if real_prompt: content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}" else: @@ -182,8 +200,9 @@ class MJBot: return reply async def check_task(self, task: MJTask, e_context: EventContext): - max_retry_time = 80 - while max_retry_time > 0: + max_retry_times = 90 + while max_retry_times > 0: + await asyncio.sleep(10) async with aiohttp.ClientSession() as session: url = f"{self.base_url}/tasks/{task.id}" async with session.get(url, headers=self.headers) as res: @@ -193,14 +212,17 @@ class MJBot: f"data={res_json.get('data')}, thread={threading.current_thread().name}") if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name: # process success res + if self.tasks.get(task.id): + self.tasks[task.id].status = Status.FINISHED self._process_success_task(task, res_json.get("data"), e_context) return else: logger.warn(f"[MJ] image check error, status_code={res.status}") - max_retry_time -= 20 - await asyncio.sleep(10) - max_retry_time -= 1 + max_retry_times -= 20 + max_retry_times -= 1 logger.warn("[MJ] end from poll") + if self.tasks.get(task.id): + self.tasks[task.id].status = Status.EXPIRED def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext): """ @@ -233,7 +255,39 @@ class MJBot: self._print_tasks() return + def _check_rate_limit(self, user_id: str, e_context: EventContext) -> bool: + """ + midjourney任务限流控制 + :param user_id: 用户id + :param e_context: 对话上下文 + :return: 任务是否能够生成, True:可以生成, False: 被限流 + """ + tasks = self.find_tasks_by_user_id(user_id) + task_count = len([t for t in tasks if t.status == Status.PENDING]) + if task_count >= self.config.get("max_tasks_per_user"): + reply = Reply(ReplyType.INFO, "您的Midjourney作图任务数已达上限,请稍后再试") + e_context["reply"] = reply + e_context.action = EventAction.BREAK_PASS + return False + task_count = len([t for t in self.tasks.values() if t.status == Status.PENDING]) + if task_count >= self.config.get("max_tasks"): + reply = Reply(ReplyType.INFO, "Midjourney服务的总任务数已达上限,请稍后再试") + e_context["reply"] = reply + e_context.action = EventAction.BREAK_PASS + return False + return True + + def _fetch_mode(self, prompt) -> str: + mode = self.config.get("mode") + if "--relax" in prompt or mode == TaskMode.RELAX.name: + return TaskMode.RELAX.name + return TaskMode.FAST.name + def _run_loop(self, loop: asyncio.BaseEventLoop): + """ + 运行事件循环,用于轮询任务的线程 + :param loop: 事件循环 + """ loop.run_forever() loop.stop() @@ -241,6 +295,16 @@ class MJBot: for id in self.tasks: logger.debug(f"[MJ] current task: {self.tasks[id]}") + def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR): + """ + 设置回复文本 + :param content: 回复内容 + :param e_context: 对话上下文 + :param level: 回复等级 + """ + reply = Reply(level, content) + e_context["reply"] = reply + e_context.action = EventAction.BREAK_PASS def get_help_text(self, verbose=False, **kwargs): trigger_prefix = conf().get("plugin_trigger_prefix", "$") @@ -250,7 +314,23 @@ class MJBot: help_text += f"{trigger_prefix}mj 描述词1,描述词2 ... : 利用描述词作画,参数请放在提示词之后。\n{trigger_prefix}mjimage 描述词1,描述词2 ... : 利用描述词进行图生图,参数请放在提示词之后。\n{trigger_prefix}mjr ID: 对指定ID消息重新生成图片。\n{trigger_prefix}mju ID 图片序号: 对指定ID消息中的第x张图片进行放大。\n{trigger_prefix}mjv ID 图片序号: 对指定ID消息中的第x张图片进行变换。\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mjimage a white cat --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\"" return help_text - def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR): - reply = Reply(level, content) - e_context["reply"] = reply - e_context.action = EventAction.BREAK_PASS \ No newline at end of file + def find_tasks_by_user_id(self, user_id) -> list[MJTask]: + result = [] + with self.tasks_lock: + now = time.time() + for task in self.tasks.values(): + if task.status == Status.PENDING and now > task.expiry_time: + task.status = Status.EXPIRED + logger.info(f"[MJ] {task} expired") + if task.user_id == user_id: + result.append(task) + return result + + +def check_prefix(content, prefix_list): + if not prefix_list: + return None + for prefix in prefix_list: + if content.startswith(prefix): + return prefix + return None \ No newline at end of file diff --git a/plugins/plugin.py b/plugins/plugin.py index e7444d2..1d74a7c 100644 --- a/plugins/plugin.py +++ b/plugins/plugin.py @@ -1,6 +1,6 @@ import os import json -from config import pconf +from config import pconf, plugin_config from common.log import logger @@ -24,5 +24,22 @@ class Plugin: logger.debug(f"loading plugin config, plugin_name={self.name}, conf={plugin_conf}") return plugin_conf + def save_config(self, config: dict): + try: + plugin_config[self.name] = config + # 写入全局配置 + global_config_path = "./plugins/config.json" + if os.path.exists(global_config_path): + with open(global_config_path, "w", encoding='utf-8') as f: + json.dump(plugin_config, f, indent=4, ensure_ascii=False) + # 写入插件配置 + plugin_config_path = os.path.join(self.path, "config.json") + if os.path.exists(plugin_config_path): + with open(plugin_config_path, "w", encoding='utf-8') as f: + json.dump(config, f, indent=4, ensure_ascii=False) + + except Exception as e: + logger.warn("save plugin config failed: {}".format(e)) + def get_help_text(self, **kwargs): return "暂无帮助信息"