Browse Source

feat: check app_code dynamically

master
zhayujie 1 year ago
parent
commit
2f9e5b1219
7 changed files with 181 additions and 35 deletions
  1. +6
    -0
      config.py
  2. +3
    -1
      plugins/godcmd/godcmd.py
  3. +0
    -0
      plugins/linkai/README.md
  4. +2
    -1
      plugins/linkai/config.json.template
  5. +53
    -13
      plugins/linkai/linkai.py
  6. +99
    -19
      plugins/linkai/midjourney.py
  7. +18
    -1
      plugins/plugin.py

+ 6
- 0
config.py View File

@@ -252,3 +252,9 @@ def pconf(plugin_name: str) -> dict:
:return: 该插件的配置项
"""
return plugin_config.get(plugin_name.lower())


# 全局配置,用于存放全局生效的状态
global_config = {
"admin_users": []
}

+ 3
- 1
plugins/godcmd/godcmd.py View File

@@ -13,7 +13,7 @@ from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common import const
from common.log import logger
from config import conf, load_config
from config import conf, load_config, global_config
from plugins import *

# 定义指令集
@@ -426,9 +426,11 @@ class Godcmd(Plugin):
password = args[0]
if password == self.password:
self.admin_users.append(userid)
global_config["admin_users"].append(userid)
return True, "认证成功"
elif password == self.temp_password:
self.admin_users.append(userid)
global_config["admin_users"].append(userid)
return True, "认证成功,请尽快设置口令"
else:
return False, "认证失败"


+ 0
- 0
plugins/linkai/README.md View File


+ 2
- 1
plugins/linkai/config.json.template View File

@@ -8,6 +8,7 @@
"mode": "relax",
"auto_translate": true,
"max_tasks": 3,
"max_tasks_per_user": 1
"max_tasks_per_user": 1,
"use_image_create_prefix": true
}
}

+ 53
- 13
plugins/linkai/linkai.py View File

@@ -8,7 +8,7 @@ from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from channel.chat_message import ChatMessage
from common.log import logger
from config import conf
from config import conf, global_config
from plugins import *
from .midjourney import MJBot, TaskType

@@ -46,14 +46,48 @@ class LinkAI(Plugin):
self.mj_bot.process_mj_task(mj_type, e_context)
return

if context.content.startswith(f"{_get_trigger_prefix()}linkai"):
# 应用管理功能
self._process_admin_cmd(e_context)
return

if self._is_chat_task(e_context):
# 文本对话任务处理
self._process_chat_task(e_context)

# 插件管理功能
def _process_admin_cmd(self, e_context: EventContext):
context = e_context['context']
cmd = context.content.split()
if len(cmd) == 1 or (len(cmd) == 2 and cmd[1] == "help"):
_set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO)
return
if len(cmd) == 3 and cmd[1] == "app":
if not context.kwargs.get("isgroup"):
_set_reply_text("该指令需在群聊中使用", e_context, level=ReplyType.ERROR)
return
if e_context["context"]["session_id"] not in global_config["admin_users"]:
_set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR)
return
app_code = cmd[2]
group_name = context.kwargs.get("msg").from_user_nickname
group_mapping = self.config.get("group_app_map")
if group_mapping:
group_mapping[group_name] = app_code
else:
self.config["group_app_map"] = {group_name: app_code}
# 保存插件配置
super().save_config(self.config)
_set_reply_text(f"应用设置成功: {app_code}", e_context, level=ReplyType.INFO)
else:
_set_reply_text(f"指令错误,请输入{_get_trigger_prefix()}linkai help 获取帮助", e_context, level=ReplyType.INFO)
return

# LinkAI 对话任务处理
def _is_chat_task(self, e_context: EventContext):
context = e_context['context']
# 群聊应用管理
return self.config.get("knowledge_base") and context.kwargs.get("isgroup")
return self.config.get("group_app_map") and context.kwargs.get("isgroup")

def _process_chat_task(self, e_context: EventContext):
"""
@@ -73,21 +107,27 @@ class LinkAI(Plugin):
:param group_name: 群聊名称
:return: 应用code
"""
knowledge_base_config = self.config.get("knowledge_base")
if knowledge_base_config and knowledge_base_config.get("group_mapping"):
app_code = knowledge_base_config.get("group_mapping").get(group_name) \
or knowledge_base_config.get("group_mapping").get("ALL_GROUP")
group_mapping = self.config.get("group_app_map")
if group_mapping:
app_code = group_mapping.get(group_name) or group_mapping.get("ALL_GROUP")
return app_code

def get_help_text(self, verbose=False, **kwargs):
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
help_text = "利用midjourney来画图。\n"
trigger_prefix = _get_trigger_prefix()
help_text = "用于集成 LinkAI 提供的文本对话、知识库、绘画等能力。\n"
if not verbose:
return help_text
help_text += f"{trigger_prefix}mj 描述词1,描述词2 ... : 利用描述词作画,参数请放在提示词之后。\n{trigger_prefix}mjimage 描述词1,描述词2 ... : 利用描述词进行图生图,参数请放在提示词之后。\n{trigger_prefix}mjr ID: 对指定ID消息重新生成图片。\n{trigger_prefix}mju ID 图片序号: 对指定ID消息中的第x张图片进行放大。\n{trigger_prefix}mjv ID 图片序号: 对指定ID消息中的第x张图片进行变换。\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mjimage a white cat --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\""
help_text += ""
help_text += f"{trigger_prefix}mj 描述词1,描述词2 ... : 利用描述词作画,参数请放在提示词之后。\n\n{trigger_prefix}mju ID 图片序号: 对指定ID消息中的第x张图片进行放大。\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mjimage a white cat --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\""
return help_text

def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR):
reply = Reply(level, content)
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS

# 静态方法
def _set_reply_text(content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR):
reply = Reply(level, content)
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS


def _get_trigger_prefix():
return conf().get("plugin_trigger_prefix", "$")

+ 99
- 19
plugins/linkai/midjourney.py View File

@@ -28,6 +28,11 @@ class Status(Enum):
return self.name


class TaskMode(Enum):
FAST = "fast"
RELAX = "relax"


class MJTask:
def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int=60*30, status=Status.PENDING):
self.id = id
@@ -47,7 +52,6 @@ class MJTask:
class MJBot:
def __init__(self, config):
self.base_url = "https://api.link-ai.chat/v1/img/midjourney"
# self.base_url = "http://127.0.0.1:8911/v1/img/midjourney"
self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
self.config = config
self.tasks = {}
@@ -71,10 +75,10 @@ class MJBot:
return TaskType.GENERATE
elif cmd_list[0].lower() == f"{trigger_prefix}mju":
return TaskType.UPSCALE
# elif cmd_list[0].lower() == f"{trigger_prefix}mjv":
# return TaskType.VARIATION
# elif cmd_list[0].lower() == f"{trigger_prefix}mjr":
# return TaskType.RESET
elif self.config.get("use_image_create_prefix") and \
check_prefix(context.content, conf().get("image_create_prefix")):
return TaskType.GENERATE

def process_mj_task(self, mj_type: TaskType, e_context: EventContext):
"""
@@ -86,12 +90,20 @@ class MJBot:
session_id = context["session_id"]
cmd = context.content.split(maxsplit=1)
if len(cmd) == 1:
self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.ERROR)
self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO)
return

if not self._check_rate_limit(session_id, e_context):
logger.warn("[MJ] midjourney task exceed rate limit")
return

if mj_type == TaskType.GENERATE:
# 图片生成
raw_prompt = cmd[1]
image_prefix = check_prefix(context.content, conf().get("image_create_prefix"))
if image_prefix:
raw_prompt = context.content.replace(image_prefix, "", 1)
else:
# 图片生成
raw_prompt = cmd[1]
reply = self.generate(raw_prompt, session_id, e_context)
e_context['reply'] = reply
e_context.action = EventAction.BREAK_PASS
@@ -126,10 +138,12 @@ class MJBot:
图片生成
:param prompt: 提示词
:param user_id: 用户id
:param e_context: 对话上下文
:return: 任务ID
"""
logger.info(f"[MJ] image generate, prompt={prompt}")
body = {"prompt": prompt}
mode = self._fetch_mode(prompt)
body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")}
res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers)
if res.status_code == 200:
res = res.json()
@@ -137,7 +151,11 @@ class MJBot:
if res.get("code") == 200:
task_id = res.get("data").get("taskId")
real_prompt = res.get("data").get("realPrompt")
content = f"🚀你的作品将在1~2分钟左右完成,请耐心等待\n- - - - - - - - -\n"
if mode == TaskMode.RELAX.name:
time_str = "1~10分钟"
else:
time_str = "1~2分钟"
content = f"🚀你的作品将在{time_str}左右完成,请耐心等待\n- - - - - - - - -\n"
if real_prompt:
content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}"
else:
@@ -182,8 +200,9 @@ class MJBot:
return reply

async def check_task(self, task: MJTask, e_context: EventContext):
max_retry_time = 80
while max_retry_time > 0:
max_retry_times = 90
while max_retry_times > 0:
await asyncio.sleep(10)
async with aiohttp.ClientSession() as session:
url = f"{self.base_url}/tasks/{task.id}"
async with session.get(url, headers=self.headers) as res:
@@ -193,14 +212,17 @@ class MJBot:
f"data={res_json.get('data')}, thread={threading.current_thread().name}")
if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name:
# process success res
if self.tasks.get(task.id):
self.tasks[task.id].status = Status.FINISHED
self._process_success_task(task, res_json.get("data"), e_context)
return
else:
logger.warn(f"[MJ] image check error, status_code={res.status}")
max_retry_time -= 20
await asyncio.sleep(10)
max_retry_time -= 1
max_retry_times -= 20
max_retry_times -= 1
logger.warn("[MJ] end from poll")
if self.tasks.get(task.id):
self.tasks[task.id].status = Status.EXPIRED

def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext):
"""
@@ -233,7 +255,39 @@ class MJBot:
self._print_tasks()
return

def _check_rate_limit(self, user_id: str, e_context: EventContext) -> bool:
"""
midjourney任务限流控制
:param user_id: 用户id
:param e_context: 对话上下文
:return: 任务是否能够生成, True:可以生成, False: 被限流
"""
tasks = self.find_tasks_by_user_id(user_id)
task_count = len([t for t in tasks if t.status == Status.PENDING])
if task_count >= self.config.get("max_tasks_per_user"):
reply = Reply(ReplyType.INFO, "您的Midjourney作图任务数已达上限,请稍后再试")
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
return False
task_count = len([t for t in self.tasks.values() if t.status == Status.PENDING])
if task_count >= self.config.get("max_tasks"):
reply = Reply(ReplyType.INFO, "Midjourney服务的总任务数已达上限,请稍后再试")
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
return False
return True

def _fetch_mode(self, prompt) -> str:
mode = self.config.get("mode")
if "--relax" in prompt or mode == TaskMode.RELAX.name:
return TaskMode.RELAX.name
return TaskMode.FAST.name

def _run_loop(self, loop: asyncio.BaseEventLoop):
"""
运行事件循环,用于轮询任务的线程
:param loop: 事件循环
"""
loop.run_forever()
loop.stop()

@@ -241,6 +295,16 @@ class MJBot:
for id in self.tasks:
logger.debug(f"[MJ] current task: {self.tasks[id]}")

def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR):
"""
设置回复文本
:param content: 回复内容
:param e_context: 对话上下文
:param level: 回复等级
"""
reply = Reply(level, content)
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS

def get_help_text(self, verbose=False, **kwargs):
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
@@ -250,7 +314,23 @@ class MJBot:
help_text += f"{trigger_prefix}mj 描述词1,描述词2 ... : 利用描述词作画,参数请放在提示词之后。\n{trigger_prefix}mjimage 描述词1,描述词2 ... : 利用描述词进行图生图,参数请放在提示词之后。\n{trigger_prefix}mjr ID: 对指定ID消息重新生成图片。\n{trigger_prefix}mju ID 图片序号: 对指定ID消息中的第x张图片进行放大。\n{trigger_prefix}mjv ID 图片序号: 对指定ID消息中的第x张图片进行变换。\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mjimage a white cat --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\""
return help_text

def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR):
reply = Reply(level, content)
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
def find_tasks_by_user_id(self, user_id) -> list[MJTask]:
result = []
with self.tasks_lock:
now = time.time()
for task in self.tasks.values():
if task.status == Status.PENDING and now > task.expiry_time:
task.status = Status.EXPIRED
logger.info(f"[MJ] {task} expired")
if task.user_id == user_id:
result.append(task)
return result


def check_prefix(content, prefix_list):
if not prefix_list:
return None
for prefix in prefix_list:
if content.startswith(prefix):
return prefix
return None

+ 18
- 1
plugins/plugin.py View File

@@ -1,6 +1,6 @@
import os
import json
from config import pconf
from config import pconf, plugin_config
from common.log import logger


@@ -24,5 +24,22 @@ class Plugin:
logger.debug(f"loading plugin config, plugin_name={self.name}, conf={plugin_conf}")
return plugin_conf

def save_config(self, config: dict):
try:
plugin_config[self.name] = config
# 写入全局配置
global_config_path = "./plugins/config.json"
if os.path.exists(global_config_path):
with open(global_config_path, "w", encoding='utf-8') as f:
json.dump(plugin_config, f, indent=4, ensure_ascii=False)
# 写入插件配置
plugin_config_path = os.path.join(self.path, "config.json")
if os.path.exists(plugin_config_path):
with open(plugin_config_path, "w", encoding='utf-8') as f:
json.dump(config, f, indent=4, ensure_ascii=False)

except Exception as e:
logger.warn("save plugin config failed: {}".format(e))

def get_help_text(self, **kwargs):
return "暂无帮助信息"

Loading…
Cancel
Save