Преглед на файлове

Merge pull request #1343 from zhayujie/feat-1.3.4

feat: add midjourney and app manager plugin
develop
zhayujie GitHub преди 1 година
родител
ревизия
39dd99b272
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
променени са 12 файла, в които са добавени 610 реда и са изтрити 31 реда
  1. +2
    -1
      .gitignore
  2. +1
    -1
      README.md
  3. +32
    -25
      bot/linkai/link_ai_bot.py
  4. +8
    -0
      config.py
  5. +1
    -0
      docker/docker-compose.yml
  6. +3
    -1
      plugins/godcmd/godcmd.py
  7. +0
    -0
      plugins/linkai/README.md
  8. +1
    -0
      plugins/linkai/__init__.py
  9. +14
    -0
      plugins/linkai/config.json.template
  10. +137
    -0
      plugins/linkai/linkai.py
  11. +391
    -0
      plugins/linkai/midjourney.py
  12. +20
    -3
      plugins/plugin.py

+ 2
- 1
.gitignore Целия файл

@@ -24,4 +24,5 @@ plugins/**/
!plugins/banwords/**/
!plugins/hello
!plugins/role
!plugins/keyword
!plugins/keyword
!plugins/linkai

+ 1
- 1
README.md Целия файл

@@ -111,7 +111,7 @@ pip3 install azure-cognitiveservices-speech
{
"open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY
"model": "gpt-3.5-turbo", # 模型名称。当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
"proxy": "127.0.0.1:7890", # 代理客户端的ip和端口
"proxy": "", # 代理客户端的ip和端口,国内网络环境需要填该项,如 "127.0.0.1:7890"
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复


+ 32
- 25
bot/linkai/link_ai_bot.py Целия файл

@@ -29,18 +29,24 @@ class LinkAIBot(Bot, OpenAIImage):
if context.type == ContextType.TEXT:
return self._chat(query, context)
elif context.type == ContextType.IMAGE_CREATE:
ok, retstring = self.create_img(query, 0)
reply = None
ok, res = self.create_img(query, 0)
if ok:
reply = Reply(ReplyType.IMAGE_URL, retstring)
reply = Reply(ReplyType.IMAGE_URL, res)
else:
reply = Reply(ReplyType.ERROR, retstring)
reply = Reply(ReplyType.ERROR, res)
return reply
else:
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
return reply

def _chat(self, query, context, retry_count=0):
def _chat(self, query, context, retry_count=0) -> Reply:
"""
发起对话请求
:param query: 请求提示词
:param context: 对话上下文
:param retry_count: 当前递归重试次数
:return: 回复
"""
if retry_count >= 2:
# exit from retry 2 times
logger.warn("[LINKAI] failed after maximum number of retry times")
@@ -52,7 +58,7 @@ class LinkAIBot(Bot, OpenAIImage):
logger.info(f"[LINKAI] won't set appcode because a plugin ({context['generate_breaked_by']}) affected the context")
app_code = None
else:
app_code = conf().get("linkai_app_code")
app_code = context.kwargs.get("app_code") or conf().get("linkai_app_code")
linkai_api_key = conf().get("linkai_api_key")

session_id = context["session_id"]
@@ -63,10 +69,8 @@ class LinkAIBot(Bot, OpenAIImage):
if app_code and session.messages[0].get("role") == "system":
session.messages.pop(0)

logger.info(f"[LINKAI] query={query}, app_code={app_code}")

body = {
"appCode": app_code,
"app_code": app_code,
"messages": session.messages,
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
"temperature": conf().get("temperature"),
@@ -74,31 +78,34 @@ class LinkAIBot(Bot, OpenAIImage):
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
}
logger.info(f"[LINKAI] query={query}, app_code={app_code}, mode={body.get('model')}")
headers = {"Authorization": "Bearer " + linkai_api_key}

# do http request
res = requests.post(url=self.base_url + "/chat/completion", json=body, headers=headers).json()

if not res or not res["success"]:
if res.get("code") == self.AUTH_FAILED_CODE:
logger.exception(f"[LINKAI] please check your linkai_api_key, res={res}")
return Reply(ReplyType.ERROR, "请再问我一次吧")
res = requests.post(url=self.base_url + "/chat/completions", json=body, headers=headers,
timeout=conf().get("request_timeout", 180))
if res.status_code == 200:
# execute success
response = res.json()
reply_content = response["choices"][0]["message"]["content"]
total_tokens = response["usage"]["total_tokens"]
logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
self.sessions.session_reply(reply_content, session_id, total_tokens)
return Reply(ReplyType.TEXT, reply_content)

elif res.get("code") == self.NO_QUOTA_CODE:
logger.exception(f"[LINKAI] please check your account quota, https://chat.link-ai.tech/console/account")
return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧")
else:
response = res.json()
error = response.get("error")
logger.error(f"[LINKAI] chat failed, status_code={res.status_code}, "
f"msg={error.get('message')}, type={error.get('type')}")

else:
# retry
if res.status_code >= 500:
# server error, need retry
time.sleep(2)
logger.warn(f"[LINKAI] do retry, times={retry_count}")
return self._chat(query, context, retry_count + 1)

# execute success
reply_content = res["data"]["content"]
logger.info(f"[LINKAI] reply={reply_content}")
self.sessions.session_reply(reply_content, session_id)
return Reply(ReplyType.TEXT, reply_content)
return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧")

except Exception as e:
logger.exception(e)


+ 8
- 0
config.py Целия файл

@@ -102,6 +102,8 @@ available_setting = {
"appdata_dir": "", # 数据目录
# 插件配置
"plugin_trigger_prefix": "$", # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突
# 是否使用全局插件配置
"use_global_plugin_config": False,
# 知识库平台配置
"use_linkai": False,
"linkai_api_key": "",
@@ -252,3 +254,9 @@ def pconf(plugin_name: str) -> dict:
:return: 该插件的配置项
"""
return plugin_config.get(plugin_name.lower())


# 全局配置,用于存放全局生效的状态
global_config = {
"admin_users": []
}

+ 1
- 0
docker/docker-compose.yml Целия файл

@@ -18,6 +18,7 @@ services:
SPEECH_RECOGNITION: 'False'
CHARACTER_DESC: '你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。'
EXPIRES_IN_SECONDS: 3600
USE_GLOBAL_PLUGIN_CONFIG: 'True'
USE_LINKAI: 'False'
LINKAI_API_KEY: ''
LINKAI_APP_CODE: ''

+ 3
- 1
plugins/godcmd/godcmd.py Целия файл

@@ -13,7 +13,7 @@ from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common import const
from common.log import logger
from config import conf, load_config
from config import conf, load_config, global_config
from plugins import *

# 定义指令集
@@ -426,9 +426,11 @@ class Godcmd(Plugin):
password = args[0]
if password == self.password:
self.admin_users.append(userid)
global_config["admin_users"].append(userid)
return True, "认证成功"
elif password == self.temp_password:
self.admin_users.append(userid)
global_config["admin_users"].append(userid)
return True, "认证成功,请尽快设置口令"
else:
return False, "认证失败"


+ 0
- 0
plugins/linkai/README.md Целия файл


+ 1
- 0
plugins/linkai/__init__.py Целия файл

@@ -0,0 +1 @@
from .linkai import *

+ 14
- 0
plugins/linkai/config.json.template Целия файл

@@ -0,0 +1,14 @@
{
"group_app_map": {
"测试群1": "default",
"测试群2": "Kv2fXJcH"
},
"midjourney": {
"enabled": true,
"auto_translate": true,
"img_proxy": true,
"max_tasks": 3,
"max_tasks_per_user": 1,
"use_image_create_prefix": true
}
}

+ 137
- 0
plugins/linkai/linkai.py Целия файл

@@ -0,0 +1,137 @@
import asyncio
import json
import threading
from concurrent.futures import ThreadPoolExecutor

import plugins
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from channel.chat_message import ChatMessage
from common.log import logger
from config import conf, global_config
from plugins import *
from .midjourney import MJBot, TaskType

# 任务线程池
task_thread_pool = ThreadPoolExecutor(max_workers=4)


@plugins.register(
name="linkai",
desc="A plugin that supports knowledge base and midjourney drawing.",
version="0.1.0",
author="https://link-ai.tech",
)
class LinkAI(Plugin):
def __init__(self):
super().__init__()
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
self.config = super().load_config()
if self.config:
self.mj_bot = MJBot(self.config.get("midjourney"))
logger.info("[LinkAI] inited")

def on_handle_context(self, e_context: EventContext):
"""
消息处理逻辑
:param e_context: 消息上下文
"""
if not self.config:
return

context = e_context['context']
if context.type not in [ContextType.TEXT, ContextType.IMAGE, ContextType.IMAGE_CREATE]:
# filter content no need solve
return

mj_type = self.mj_bot.judge_mj_task_type(e_context)
if mj_type:
# MJ作图任务处理
self.mj_bot.process_mj_task(mj_type, e_context)
return

if context.content.startswith(f"{_get_trigger_prefix()}linkai"):
# 应用管理功能
self._process_admin_cmd(e_context)
return

if self._is_chat_task(e_context):
# 文本对话任务处理
self._process_chat_task(e_context)

# 插件管理功能
def _process_admin_cmd(self, e_context: EventContext):
context = e_context['context']
cmd = context.content.split()
if len(cmd) == 1 or (len(cmd) == 2 and cmd[1] == "help"):
_set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO)
return
if len(cmd) == 3 and cmd[1] == "app":
if not context.kwargs.get("isgroup"):
_set_reply_text("该指令需在群聊中使用", e_context, level=ReplyType.ERROR)
return
if context.kwargs.get("msg").actual_user_id not in global_config["admin_users"]:
_set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR)
return
app_code = cmd[2]
group_name = context.kwargs.get("msg").from_user_nickname
group_mapping = self.config.get("group_app_map")
if group_mapping:
group_mapping[group_name] = app_code
else:
self.config["group_app_map"] = {group_name: app_code}
# 保存插件配置
super().save_config(self.config)
_set_reply_text(f"应用设置成功: {app_code}", e_context, level=ReplyType.INFO)
else:
_set_reply_text(f"指令错误,请输入{_get_trigger_prefix()}linkai help 获取帮助", e_context, level=ReplyType.INFO)
return

# LinkAI 对话任务处理
def _is_chat_task(self, e_context: EventContext):
context = e_context['context']
# 群聊应用管理
return self.config.get("group_app_map") and context.kwargs.get("isgroup")

def _process_chat_task(self, e_context: EventContext):
"""
处理LinkAI对话任务
:param e_context: 对话上下文
"""
context = e_context['context']
# 群聊应用管理
group_name = context.kwargs.get("msg").from_user_nickname
app_code = self._fetch_group_app_code(group_name)
if app_code:
context.kwargs['app_code'] = app_code

def _fetch_group_app_code(self, group_name: str) -> str:
"""
根据群聊名称获取对应的应用code
:param group_name: 群聊名称
:return: 应用code
"""
group_mapping = self.config.get("group_app_map")
if group_mapping:
app_code = group_mapping.get(group_name) or group_mapping.get("ALL_GROUP")
return app_code

def get_help_text(self, verbose=False, **kwargs):
trigger_prefix = _get_trigger_prefix()
help_text = "用于集成 LinkAI 提供的知识库、Midjourney绘画等能力。\n\n"
if not verbose:
return help_text
help_text += f'📖 知识库\n - 群聊中指定应用: {trigger_prefix}linkai app 应用编码\n\n例如: \n"$linkai app Kv2fXJcH"\n\n'
help_text += f"🎨 绘画\n - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\""
return help_text


# 静态方法
def _set_reply_text(content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR):
reply = Reply(level, content)
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS


def _get_trigger_prefix():
return conf().get("plugin_trigger_prefix", "$")

+ 391
- 0
plugins/linkai/midjourney.py Целия файл

@@ -0,0 +1,391 @@
from enum import Enum
from config import conf
from common.log import logger
import requests
import threading
import time
from bridge.reply import Reply, ReplyType
import aiohttp
import asyncio
from bridge.context import ContextType
from plugins import EventContext, EventAction

INVALID_REQUEST = 410

class TaskType(Enum):
GENERATE = "generate"
UPSCALE = "upscale"
VARIATION = "variation"
RESET = "reset"


class Status(Enum):
PENDING = "pending"
FINISHED = "finished"
EXPIRED = "expired"
ABORTED = "aborted"

def __str__(self):
return self.name


class TaskMode(Enum):
FAST = "fast"
RELAX = "relax"


class MJTask:
def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int = 60 * 30,
status=Status.PENDING):
self.id = id
self.user_id = user_id
self.task_type = task_type
self.raw_prompt = raw_prompt
self.send_func = None # send_func(img_url)
self.expiry_time = time.time() + expires
self.status = status
self.img_url = None # url
self.img_id = None

def __str__(self):
return f"id={self.id}, user_id={self.user_id}, task_type={self.task_type}, status={self.status}, img_id={self.img_id}"


# midjourney bot
class MJBot:
def __init__(self, config):
self.base_url = "https://api.link-ai.chat/v1/img/midjourney"

self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
self.config = config
self.tasks = {}
self.temp_dict = {}
self.tasks_lock = threading.Lock()
self.event_loop = asyncio.new_event_loop()

def judge_mj_task_type(self, e_context: EventContext):
"""
判断MJ任务的类型
:param e_context: 上下文
:return: 任务类型枚举
"""
if not self.config or not self.config.get("enabled"):
return None
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
context = e_context['context']
if context.type == ContextType.TEXT:
cmd_list = context.content.split(maxsplit=1)
if cmd_list[0].lower() == f"{trigger_prefix}mj":
return TaskType.GENERATE
elif cmd_list[0].lower() == f"{trigger_prefix}mju":
return TaskType.UPSCALE
elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix"):
return TaskType.GENERATE

def process_mj_task(self, mj_type: TaskType, e_context: EventContext):
"""
处理mj任务
:param mj_type: mj任务类型
:param e_context: 对话上下文
"""
context = e_context['context']
session_id = context["session_id"]
cmd = context.content.split(maxsplit=1)
if len(cmd) == 1 and context.type == ContextType.TEXT:
self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO)
return

if not self._check_rate_limit(session_id, e_context):
logger.warn("[MJ] midjourney task exceed rate limit")
return

if mj_type == TaskType.GENERATE:
if context.type == ContextType.IMAGE_CREATE:
raw_prompt = context.content
else:
# 图片生成
raw_prompt = cmd[1]
reply = self.generate(raw_prompt, session_id, e_context)
e_context['reply'] = reply
e_context.action = EventAction.BREAK_PASS
return

elif mj_type == TaskType.UPSCALE:
# 图片放大
clist = cmd[1].split()
if len(clist) < 2:
self._set_reply_text(f"{cmd[0]} 命令缺少参数", e_context)
return
img_id = clist[0]
index = int(clist[1])
if index < 1 or index > 4:
self._set_reply_text(f"图片序号 {index} 错误,应在 1 至 4 之间", e_context)
return
key = f"{TaskType.UPSCALE.name}_{img_id}_{index}"
if self.temp_dict.get(key):
self._set_reply_text(f"第 {index} 张图片已经放大过了", e_context)
return
# 图片放大操作
reply = self.upscale(session_id, img_id, index, e_context)
e_context['reply'] = reply
e_context.action = EventAction.BREAK_PASS
return

else:
self._set_reply_text(f"暂不支持该命令", e_context)

def generate(self, prompt: str, user_id: str, e_context: EventContext) -> Reply:
"""
图片生成
:param prompt: 提示词
:param user_id: 用户id
:param e_context: 对话上下文
:return: 任务ID
"""
logger.info(f"[MJ] image generate, prompt={prompt}")
mode = self._fetch_mode(prompt)
body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")}
if not self.config.get("img_proxy"):
body["img_proxy"] = False
res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers, timeout=(5, 40))
if res.status_code == 200:
res = res.json()
logger.debug(f"[MJ] image generate, res={res}")
if res.get("code") == 200:
task_id = res.get("data").get("task_id")
real_prompt = res.get("data").get("real_prompt")
if mode == TaskMode.RELAX.value:
time_str = "1~10分钟"
else:
time_str = "1分钟"
content = f"🚀您的作品将在{time_str}左右完成,请耐心等待\n- - - - - - - - -\n"
if real_prompt:
content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}"
else:
content += f"prompt: {prompt}"
reply = Reply(ReplyType.INFO, content)
task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id,
task_type=TaskType.GENERATE)
# put to memory dict
self.tasks[task.id] = task
# asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
self._do_check_task(task, e_context)
return reply
else:
res_json = res.json()
logger.error(f"[MJ] generate error, msg={res_json.get('message')}, status_code={res.status_code}")
if res.status_code == INVALID_REQUEST:
reply = Reply(ReplyType.ERROR, "图片生成失败,请检查提示词参数或内容")
else:
reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试")
return reply

def upscale(self, user_id: str, img_id: str, index: int, e_context: EventContext) -> Reply:
logger.info(f"[MJ] image upscale, img_id={img_id}, index={index}")
body = {"type": TaskType.UPSCALE.name, "img_id": img_id, "index": index}
if not self.config.get("img_proxy"):
body["img_proxy"] = False
res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5, 40))
logger.debug(res)
if res.status_code == 200:
res = res.json()
if res.get("code") == 200:
task_id = res.get("data").get("task_id")
logger.info(f"[MJ] image upscale processing, task_id={task_id}")
content = f"🔎图片正在放大中,请耐心等待"
reply = Reply(ReplyType.INFO, content)
task = MJTask(id=task_id, status=Status.PENDING, user_id=user_id, task_type=TaskType.UPSCALE)
# put to memory dict
self.tasks[task.id] = task
key = f"{TaskType.UPSCALE.name}_{img_id}_{index}"
self.temp_dict[key] = True
# asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
self._do_check_task(task, e_context)
return reply
else:
error_msg = ""
if res.status_code == 461:
error_msg = "请输入正确的图片ID"
res_json = res.json()
logger.error(f"[MJ] upscale error, msg={res_json.get('message')}, status_code={res.status_code}")
reply = Reply(ReplyType.ERROR, error_msg or "图片生成失败,请稍后再试")
return reply

def check_task_sync(self, task: MJTask, e_context: EventContext):
logger.debug(f"[MJ] start check task status, {task}")
max_retry_times = 90
while max_retry_times > 0:
time.sleep(10)
url = f"{self.base_url}/tasks/{task.id}"
try:
res = requests.get(url, headers=self.headers, timeout=8)
if res.status_code == 200:
res_json = res.json()
logger.debug(f"[MJ] task check res sync, task_id={task.id}, status={res.status_code}, "
f"data={res_json.get('data')}, thread={threading.current_thread().name}")
if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name:
# process success res
if self.tasks.get(task.id):
self.tasks[task.id].status = Status.FINISHED
self._process_success_task(task, res_json.get("data"), e_context)
return
max_retry_times -= 1
else:
res_json = res.json()
logger.warn(f"[MJ] image check error, status_code={res.status_code}, res={res_json}")
max_retry_times -= 20
except Exception as e:
max_retry_times -= 20
logger.warn(e)
logger.warn("[MJ] end from poll")
if self.tasks.get(task.id):
self.tasks[task.id].status = Status.EXPIRED

async def check_task_async(self, task: MJTask, e_context: EventContext):
try:
logger.debug(f"[MJ] start check task status, {task}")
max_retry_times = 90
while max_retry_times > 0:
await asyncio.sleep(10)
async with aiohttp.ClientSession() as session:
url = f"{self.base_url}/tasks/{task.id}"
try:
async with session.get(url, headers=self.headers) as res:
if res.status == 200:
res_json = await res.json()
logger.debug(f"[MJ] task check res, task_id={task.id}, status={res.status}, "
f"data={res_json.get('data')}, thread={threading.current_thread().name}")
if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name:
# process success res
if self.tasks.get(task.id):
self.tasks[task.id].status = Status.FINISHED
self._process_success_task(task, res_json.get("data"), e_context)
return
else:
res_json = await res.json()
logger.warn(f"[MJ] image check error, status_code={res.status}, res={res_json}")
max_retry_times -= 20
except Exception as e:
max_retry_times -= 20
logger.warn(e)
max_retry_times -= 1
logger.warn("[MJ] end from poll")
if self.tasks.get(task.id):
self.tasks[task.id].status = Status.EXPIRED
except Exception as e:
logger.error(e)

def _do_check_task(self, task: MJTask, e_context: EventContext):
threading.Thread(target=self.check_task_sync, args=(task, e_context)).start()

def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext):
"""
处理任务成功的结果
:param task: MJ任务
:param res: 请求结果
:param e_context: 对话上下文
"""
# channel send img
task.status = Status.FINISHED
task.img_id = res.get("img_id")
task.img_url = res.get("img_url")
logger.info(f"[MJ] task success, task_id={task.id}, img_id={task.img_id}, img_url={task.img_url}")

# send img
reply = Reply(ReplyType.IMAGE_URL, task.img_url)
channel = e_context["channel"]
channel._send(reply, e_context["context"])

# send info
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
text = ""
if task.task_type == TaskType.GENERATE:
text = f"🎨绘画完成!\nprompt: {task.raw_prompt}\n- - - - - - - - -\n图片ID: {task.img_id}"
text += f"\n\n🔎可使用 {trigger_prefix}mju 命令放大指定图片\n"
text += f"例如:\n{trigger_prefix}mju {task.img_id} 1"
reply = Reply(ReplyType.INFO, text)
channel._send(reply, e_context["context"])

self._print_tasks()
return

def _check_rate_limit(self, user_id: str, e_context: EventContext) -> bool:
"""
midjourney任务限流控制
:param user_id: 用户id
:param e_context: 对话上下文
:return: 任务是否能够生成, True:可以生成, False: 被限流
"""
tasks = self.find_tasks_by_user_id(user_id)
task_count = len([t for t in tasks if t.status == Status.PENDING])
if task_count >= self.config.get("max_tasks_per_user"):
reply = Reply(ReplyType.INFO, "您的Midjourney作图任务数已达上限,请稍后再试")
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
return False
task_count = len([t for t in self.tasks.values() if t.status == Status.PENDING])
if task_count >= self.config.get("max_tasks"):
reply = Reply(ReplyType.INFO, "Midjourney作图任务数已达上限,请稍后再试")
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
return False
return True

def _fetch_mode(self, prompt) -> str:
mode = self.config.get("mode")
if "--relax" in prompt or mode == TaskMode.RELAX.value:
return TaskMode.RELAX.value
return mode or TaskMode.RELAX.value

def _run_loop(self, loop: asyncio.BaseEventLoop):
"""
运行事件循环,用于轮询任务的线程
:param loop: 事件循环
"""
loop.run_forever()
loop.stop()

def _print_tasks(self):
for id in self.tasks:
logger.debug(f"[MJ] current task: {self.tasks[id]}")

def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR):
"""
设置回复文本
:param content: 回复内容
:param e_context: 对话上下文
:param level: 回复等级
"""
reply = Reply(level, content)
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS

def get_help_text(self, verbose=False, **kwargs):
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
help_text = "🎨利用Midjourney进行画图\n\n"
if not verbose:
return help_text
help_text += f" - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\""

return help_text

def find_tasks_by_user_id(self, user_id) -> list:
result = []
with self.tasks_lock:
now = time.time()
for task in self.tasks.values():
if task.status == Status.PENDING and now > task.expiry_time:
task.status = Status.EXPIRED
logger.info(f"[MJ] {task} expired")
if task.user_id == user_id:
result.append(task)
return result


def check_prefix(content, prefix_list):
if not prefix_list:
return None
for prefix in prefix_list:
if content.startswith(prefix):
return prefix
return None

+ 20
- 3
plugins/plugin.py Целия файл

@@ -1,6 +1,6 @@
import os
import json
from config import pconf
from config import pconf, plugin_config, conf
from common.log import logger


@@ -15,8 +15,8 @@ class Plugin:
"""
# 优先获取 plugins/config.json 中的全局配置
plugin_conf = pconf(self.name)
if not plugin_conf:
# 全局配置不存在,则获取插件目录下的配置
if not plugin_conf or not conf().get("use_global_plugin_config"):
# 全局配置不存在 或者 未开启全局配置开关,则获取插件目录下的配置
plugin_config_path = os.path.join(self.path, "config.json")
if os.path.exists(plugin_config_path):
with open(plugin_config_path, "r") as f:
@@ -24,5 +24,22 @@ class Plugin:
logger.debug(f"loading plugin config, plugin_name={self.name}, conf={plugin_conf}")
return plugin_conf

def save_config(self, config: dict):
try:
plugin_config[self.name] = config
# 写入全局配置
global_config_path = "./plugins/config.json"
if os.path.exists(global_config_path):
with open(global_config_path, "w", encoding='utf-8') as f:
json.dump(plugin_config, f, indent=4, ensure_ascii=False)
# 写入插件配置
plugin_config_path = os.path.join(self.path, "config.json")
if os.path.exists(plugin_config_path):
with open(plugin_config_path, "w", encoding='utf-8') as f:
json.dump(config, f, indent=4, ensure_ascii=False)

except Exception as e:
logger.warn("save plugin config failed: {}".format(e))

def get_help_text(self, **kwargs):
return "暂无帮助信息"

Loading…
Отказ
Запис