浏览代码

feat: check app_code dynamically

master
zhayujie 1年前
父节点
当前提交
2f9e5b1219
共有 7 个文件被更改,包括 181 次插入35 次删除
  1. +6
    -0
      config.py
  2. +3
    -1
      plugins/godcmd/godcmd.py
  3. +0
    -0
      plugins/linkai/README.md
  4. +2
    -1
      plugins/linkai/config.json.template
  5. +53
    -13
      plugins/linkai/linkai.py
  6. +99
    -19
      plugins/linkai/midjourney.py
  7. +18
    -1
      plugins/plugin.py

+ 6
- 0
config.py 查看文件

@@ -252,3 +252,9 @@ def pconf(plugin_name: str) -> dict:
:return: 该插件的配置项 :return: 该插件的配置项
""" """
return plugin_config.get(plugin_name.lower()) return plugin_config.get(plugin_name.lower())


# 全局配置,用于存放全局生效的状态
global_config = {
"admin_users": []
}

+ 3
- 1
plugins/godcmd/godcmd.py 查看文件

@@ -13,7 +13,7 @@ from bridge.context import ContextType
from bridge.reply import Reply, ReplyType from bridge.reply import Reply, ReplyType
from common import const from common import const
from common.log import logger from common.log import logger
from config import conf, load_config
from config import conf, load_config, global_config
from plugins import * from plugins import *


# 定义指令集 # 定义指令集
@@ -426,9 +426,11 @@ class Godcmd(Plugin):
password = args[0] password = args[0]
if password == self.password: if password == self.password:
self.admin_users.append(userid) self.admin_users.append(userid)
global_config["admin_users"].append(userid)
return True, "认证成功" return True, "认证成功"
elif password == self.temp_password: elif password == self.temp_password:
self.admin_users.append(userid) self.admin_users.append(userid)
global_config["admin_users"].append(userid)
return True, "认证成功,请尽快设置口令" return True, "认证成功,请尽快设置口令"
else: else:
return False, "认证失败" return False, "认证失败"


+ 0
- 0
plugins/linkai/README.md 查看文件


+ 2
- 1
plugins/linkai/config.json.template 查看文件

@@ -8,6 +8,7 @@
"mode": "relax", "mode": "relax",
"auto_translate": true, "auto_translate": true,
"max_tasks": 3, "max_tasks": 3,
"max_tasks_per_user": 1
"max_tasks_per_user": 1,
"use_image_create_prefix": true
} }
} }

+ 53
- 13
plugins/linkai/linkai.py 查看文件

@@ -8,7 +8,7 @@ from bridge.context import ContextType
from bridge.reply import Reply, ReplyType from bridge.reply import Reply, ReplyType
from channel.chat_message import ChatMessage from channel.chat_message import ChatMessage
from common.log import logger from common.log import logger
from config import conf
from config import conf, global_config
from plugins import * from plugins import *
from .midjourney import MJBot, TaskType from .midjourney import MJBot, TaskType


@@ -46,14 +46,48 @@ class LinkAI(Plugin):
self.mj_bot.process_mj_task(mj_type, e_context) self.mj_bot.process_mj_task(mj_type, e_context)
return return


if context.content.startswith(f"{_get_trigger_prefix()}linkai"):
# 应用管理功能
self._process_admin_cmd(e_context)
return

if self._is_chat_task(e_context): if self._is_chat_task(e_context):
# 文本对话任务处理
self._process_chat_task(e_context) self._process_chat_task(e_context)


# 插件管理功能
def _process_admin_cmd(self, e_context: EventContext):
context = e_context['context']
cmd = context.content.split()
if len(cmd) == 1 or (len(cmd) == 2 and cmd[1] == "help"):
_set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO)
return
if len(cmd) == 3 and cmd[1] == "app":
if not context.kwargs.get("isgroup"):
_set_reply_text("该指令需在群聊中使用", e_context, level=ReplyType.ERROR)
return
if e_context["context"]["session_id"] not in global_config["admin_users"]:
_set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR)
return
app_code = cmd[2]
group_name = context.kwargs.get("msg").from_user_nickname
group_mapping = self.config.get("group_app_map")
if group_mapping:
group_mapping[group_name] = app_code
else:
self.config["group_app_map"] = {group_name: app_code}
# 保存插件配置
super().save_config(self.config)
_set_reply_text(f"应用设置成功: {app_code}", e_context, level=ReplyType.INFO)
else:
_set_reply_text(f"指令错误,请输入{_get_trigger_prefix()}linkai help 获取帮助", e_context, level=ReplyType.INFO)
return

# LinkAI 对话任务处理 # LinkAI 对话任务处理
def _is_chat_task(self, e_context: EventContext): def _is_chat_task(self, e_context: EventContext):
context = e_context['context'] context = e_context['context']
# 群聊应用管理 # 群聊应用管理
return self.config.get("knowledge_base") and context.kwargs.get("isgroup")
return self.config.get("group_app_map") and context.kwargs.get("isgroup")


def _process_chat_task(self, e_context: EventContext): def _process_chat_task(self, e_context: EventContext):
""" """
@@ -73,21 +107,27 @@ class LinkAI(Plugin):
:param group_name: 群聊名称 :param group_name: 群聊名称
:return: 应用code :return: 应用code
""" """
knowledge_base_config = self.config.get("knowledge_base")
if knowledge_base_config and knowledge_base_config.get("group_mapping"):
app_code = knowledge_base_config.get("group_mapping").get(group_name) \
or knowledge_base_config.get("group_mapping").get("ALL_GROUP")
group_mapping = self.config.get("group_app_map")
if group_mapping:
app_code = group_mapping.get(group_name) or group_mapping.get("ALL_GROUP")
return app_code return app_code


def get_help_text(self, verbose=False, **kwargs): def get_help_text(self, verbose=False, **kwargs):
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
help_text = "利用midjourney来画图。\n"
trigger_prefix = _get_trigger_prefix()
help_text = "用于集成 LinkAI 提供的文本对话、知识库、绘画等能力。\n"
if not verbose: if not verbose:
return help_text return help_text
help_text += f"{trigger_prefix}mj 描述词1,描述词2 ... : 利用描述词作画,参数请放在提示词之后。\n{trigger_prefix}mjimage 描述词1,描述词2 ... : 利用描述词进行图生图,参数请放在提示词之后。\n{trigger_prefix}mjr ID: 对指定ID消息重新生成图片。\n{trigger_prefix}mju ID 图片序号: 对指定ID消息中的第x张图片进行放大。\n{trigger_prefix}mjv ID 图片序号: 对指定ID消息中的第x张图片进行变换。\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mjimage a white cat --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\""
help_text += ""
help_text += f"{trigger_prefix}mj 描述词1,描述词2 ... : 利用描述词作画,参数请放在提示词之后。\n\n{trigger_prefix}mju ID 图片序号: 对指定ID消息中的第x张图片进行放大。\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mjimage a white cat --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\""
return help_text return help_text


def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR):
reply = Reply(level, content)
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS

# 静态方法
def _set_reply_text(content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR):
reply = Reply(level, content)
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS


def _get_trigger_prefix():
return conf().get("plugin_trigger_prefix", "$")

+ 99
- 19
plugins/linkai/midjourney.py 查看文件

@@ -28,6 +28,11 @@ class Status(Enum):
return self.name return self.name




class TaskMode(Enum):
FAST = "fast"
RELAX = "relax"


class MJTask: class MJTask:
def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int=60*30, status=Status.PENDING): def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int=60*30, status=Status.PENDING):
self.id = id self.id = id
@@ -47,7 +52,6 @@ class MJTask:
class MJBot: class MJBot:
def __init__(self, config): def __init__(self, config):
self.base_url = "https://api.link-ai.chat/v1/img/midjourney" self.base_url = "https://api.link-ai.chat/v1/img/midjourney"
# self.base_url = "http://127.0.0.1:8911/v1/img/midjourney"
self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")} self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
self.config = config self.config = config
self.tasks = {} self.tasks = {}
@@ -71,10 +75,10 @@ class MJBot:
return TaskType.GENERATE return TaskType.GENERATE
elif cmd_list[0].lower() == f"{trigger_prefix}mju": elif cmd_list[0].lower() == f"{trigger_prefix}mju":
return TaskType.UPSCALE return TaskType.UPSCALE
# elif cmd_list[0].lower() == f"{trigger_prefix}mjv":
# return TaskType.VARIATION
# elif cmd_list[0].lower() == f"{trigger_prefix}mjr":
# return TaskType.RESET
elif self.config.get("use_image_create_prefix") and \
check_prefix(context.content, conf().get("image_create_prefix")):
return TaskType.GENERATE


def process_mj_task(self, mj_type: TaskType, e_context: EventContext): def process_mj_task(self, mj_type: TaskType, e_context: EventContext):
""" """
@@ -86,12 +90,20 @@ class MJBot:
session_id = context["session_id"] session_id = context["session_id"]
cmd = context.content.split(maxsplit=1) cmd = context.content.split(maxsplit=1)
if len(cmd) == 1: if len(cmd) == 1:
self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.ERROR)
self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO)
return

if not self._check_rate_limit(session_id, e_context):
logger.warn("[MJ] midjourney task exceed rate limit")
return return


if mj_type == TaskType.GENERATE: if mj_type == TaskType.GENERATE:
# 图片生成
raw_prompt = cmd[1]
image_prefix = check_prefix(context.content, conf().get("image_create_prefix"))
if image_prefix:
raw_prompt = context.content.replace(image_prefix, "", 1)
else:
# 图片生成
raw_prompt = cmd[1]
reply = self.generate(raw_prompt, session_id, e_context) reply = self.generate(raw_prompt, session_id, e_context)
e_context['reply'] = reply e_context['reply'] = reply
e_context.action = EventAction.BREAK_PASS e_context.action = EventAction.BREAK_PASS
@@ -126,10 +138,12 @@ class MJBot:
图片生成 图片生成
:param prompt: 提示词 :param prompt: 提示词
:param user_id: 用户id :param user_id: 用户id
:param e_context: 对话上下文
:return: 任务ID :return: 任务ID
""" """
logger.info(f"[MJ] image generate, prompt={prompt}") logger.info(f"[MJ] image generate, prompt={prompt}")
body = {"prompt": prompt}
mode = self._fetch_mode(prompt)
body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")}
res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers) res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers)
if res.status_code == 200: if res.status_code == 200:
res = res.json() res = res.json()
@@ -137,7 +151,11 @@ class MJBot:
if res.get("code") == 200: if res.get("code") == 200:
task_id = res.get("data").get("taskId") task_id = res.get("data").get("taskId")
real_prompt = res.get("data").get("realPrompt") real_prompt = res.get("data").get("realPrompt")
content = f"🚀你的作品将在1~2分钟左右完成,请耐心等待\n- - - - - - - - -\n"
if mode == TaskMode.RELAX.name:
time_str = "1~10分钟"
else:
time_str = "1~2分钟"
content = f"🚀你的作品将在{time_str}左右完成,请耐心等待\n- - - - - - - - -\n"
if real_prompt: if real_prompt:
content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}" content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}"
else: else:
@@ -182,8 +200,9 @@ class MJBot:
return reply return reply


async def check_task(self, task: MJTask, e_context: EventContext): async def check_task(self, task: MJTask, e_context: EventContext):
max_retry_time = 80
while max_retry_time > 0:
max_retry_times = 90
while max_retry_times > 0:
await asyncio.sleep(10)
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
url = f"{self.base_url}/tasks/{task.id}" url = f"{self.base_url}/tasks/{task.id}"
async with session.get(url, headers=self.headers) as res: async with session.get(url, headers=self.headers) as res:
@@ -193,14 +212,17 @@ class MJBot:
f"data={res_json.get('data')}, thread={threading.current_thread().name}") f"data={res_json.get('data')}, thread={threading.current_thread().name}")
if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name: if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name:
# process success res # process success res
if self.tasks.get(task.id):
self.tasks[task.id].status = Status.FINISHED
self._process_success_task(task, res_json.get("data"), e_context) self._process_success_task(task, res_json.get("data"), e_context)
return return
else: else:
logger.warn(f"[MJ] image check error, status_code={res.status}") logger.warn(f"[MJ] image check error, status_code={res.status}")
max_retry_time -= 20
await asyncio.sleep(10)
max_retry_time -= 1
max_retry_times -= 20
max_retry_times -= 1
logger.warn("[MJ] end from poll") logger.warn("[MJ] end from poll")
if self.tasks.get(task.id):
self.tasks[task.id].status = Status.EXPIRED


def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext): def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext):
""" """
@@ -233,7 +255,39 @@ class MJBot:
self._print_tasks() self._print_tasks()
return return


def _check_rate_limit(self, user_id: str, e_context: EventContext) -> bool:
"""
midjourney任务限流控制
:param user_id: 用户id
:param e_context: 对话上下文
:return: 任务是否能够生成, True:可以生成, False: 被限流
"""
tasks = self.find_tasks_by_user_id(user_id)
task_count = len([t for t in tasks if t.status == Status.PENDING])
if task_count >= self.config.get("max_tasks_per_user"):
reply = Reply(ReplyType.INFO, "您的Midjourney作图任务数已达上限,请稍后再试")
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
return False
task_count = len([t for t in self.tasks.values() if t.status == Status.PENDING])
if task_count >= self.config.get("max_tasks"):
reply = Reply(ReplyType.INFO, "Midjourney服务的总任务数已达上限,请稍后再试")
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
return False
return True

def _fetch_mode(self, prompt) -> str:
mode = self.config.get("mode")
if "--relax" in prompt or mode == TaskMode.RELAX.name:
return TaskMode.RELAX.name
return TaskMode.FAST.name

def _run_loop(self, loop: asyncio.BaseEventLoop): def _run_loop(self, loop: asyncio.BaseEventLoop):
"""
运行事件循环,用于轮询任务的线程
:param loop: 事件循环
"""
loop.run_forever() loop.run_forever()
loop.stop() loop.stop()


@@ -241,6 +295,16 @@ class MJBot:
for id in self.tasks: for id in self.tasks:
logger.debug(f"[MJ] current task: {self.tasks[id]}") logger.debug(f"[MJ] current task: {self.tasks[id]}")


def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR):
"""
设置回复文本
:param content: 回复内容
:param e_context: 对话上下文
:param level: 回复等级
"""
reply = Reply(level, content)
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS


def get_help_text(self, verbose=False, **kwargs): def get_help_text(self, verbose=False, **kwargs):
trigger_prefix = conf().get("plugin_trigger_prefix", "$") trigger_prefix = conf().get("plugin_trigger_prefix", "$")
@@ -250,7 +314,23 @@ class MJBot:
help_text += f"{trigger_prefix}mj 描述词1,描述词2 ... : 利用描述词作画,参数请放在提示词之后。\n{trigger_prefix}mjimage 描述词1,描述词2 ... : 利用描述词进行图生图,参数请放在提示词之后。\n{trigger_prefix}mjr ID: 对指定ID消息重新生成图片。\n{trigger_prefix}mju ID 图片序号: 对指定ID消息中的第x张图片进行放大。\n{trigger_prefix}mjv ID 图片序号: 对指定ID消息中的第x张图片进行变换。\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mjimage a white cat --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\"" help_text += f"{trigger_prefix}mj 描述词1,描述词2 ... : 利用描述词作画,参数请放在提示词之后。\n{trigger_prefix}mjimage 描述词1,描述词2 ... : 利用描述词进行图生图,参数请放在提示词之后。\n{trigger_prefix}mjr ID: 对指定ID消息重新生成图片。\n{trigger_prefix}mju ID 图片序号: 对指定ID消息中的第x张图片进行放大。\n{trigger_prefix}mjv ID 图片序号: 对指定ID消息中的第x张图片进行变换。\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mjimage a white cat --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\""
return help_text return help_text


def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR):
reply = Reply(level, content)
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
def find_tasks_by_user_id(self, user_id) -> list[MJTask]:
result = []
with self.tasks_lock:
now = time.time()
for task in self.tasks.values():
if task.status == Status.PENDING and now > task.expiry_time:
task.status = Status.EXPIRED
logger.info(f"[MJ] {task} expired")
if task.user_id == user_id:
result.append(task)
return result


def check_prefix(content, prefix_list):
if not prefix_list:
return None
for prefix in prefix_list:
if content.startswith(prefix):
return prefix
return None

+ 18
- 1
plugins/plugin.py 查看文件

@@ -1,6 +1,6 @@
import os import os
import json import json
from config import pconf
from config import pconf, plugin_config
from common.log import logger from common.log import logger




@@ -24,5 +24,22 @@ class Plugin:
logger.debug(f"loading plugin config, plugin_name={self.name}, conf={plugin_conf}") logger.debug(f"loading plugin config, plugin_name={self.name}, conf={plugin_conf}")
return plugin_conf return plugin_conf


def save_config(self, config: dict):
try:
plugin_config[self.name] = config
# 写入全局配置
global_config_path = "./plugins/config.json"
if os.path.exists(global_config_path):
with open(global_config_path, "w", encoding='utf-8') as f:
json.dump(plugin_config, f, indent=4, ensure_ascii=False)
# 写入插件配置
plugin_config_path = os.path.join(self.path, "config.json")
if os.path.exists(plugin_config_path):
with open(plugin_config_path, "w", encoding='utf-8') as f:
json.dump(config, f, indent=4, ensure_ascii=False)

except Exception as e:
logger.warn("save plugin config failed: {}".format(e))

def get_help_text(self, **kwargs): def get_help_text(self, **kwargs):
return "暂无帮助信息" return "暂无帮助信息"

正在加载...
取消
保存