@@ -28,4 +28,5 @@ plugins/banwords/__pycache__ | |||||
plugins/banwords/lib/__pycache__ | plugins/banwords/lib/__pycache__ | ||||
!plugins/hello | !plugins/hello | ||||
!plugins/role | !plugins/role | ||||
!plugins/keyword | |||||
!plugins/keyword | |||||
!plugins/linkai |
@@ -111,7 +111,7 @@ pip3 install azure-cognitiveservices-speech | |||||
{ | { | ||||
"open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY | "open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY | ||||
"model": "gpt-3.5-turbo", # 模型名称。当use_azure_chatgpt为true时,其名称为Azure上model deployment名称 | "model": "gpt-3.5-turbo", # 模型名称。当use_azure_chatgpt为true时,其名称为Azure上model deployment名称 | ||||
"proxy": "127.0.0.1:7890", # 代理客户端的ip和端口 | |||||
"proxy": "", # 代理客户端的ip和端口,国内环境开启代理的需要填写该项,如 "127.0.0.1:7890" | |||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复 | "single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复 | ||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人 | "single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人 | ||||
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复 | "group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复 | ||||
@@ -123,6 +123,7 @@ pip3 install azure-cognitiveservices-speech | |||||
"group_speech_recognition": false, # 是否开启群组语音识别 | "group_speech_recognition": false, # 是否开启群组语音识别 | ||||
"use_azure_chatgpt": false, # 是否使用Azure ChatGPT service代替openai ChatGPT service. 当设置为true时需要设置 open_ai_api_base,如 https://xxx.openai.azure.com/ | "use_azure_chatgpt": false, # 是否使用Azure ChatGPT service代替openai ChatGPT service. 当设置为true时需要设置 open_ai_api_base,如 https://xxx.openai.azure.com/ | ||||
"azure_deployment_id": "", # 采用Azure ChatGPT时,模型部署名称 | "azure_deployment_id": "", # 采用Azure ChatGPT时,模型部署名称 | ||||
"azure_api_version": "", # 采用Azure ChatGPT时,API版本 | |||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述 | "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述 | ||||
# 订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复,可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。 | # 订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复,可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。 | ||||
"subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持角色扮演和文字冒险等丰富插件。\n输入{trigger_prefix}#help 查看详细指令。" | "subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持角色扮演和文字冒险等丰富插件。\n输入{trigger_prefix}#help 查看详细指令。" | ||||
@@ -166,7 +166,7 @@ class AzureChatGPTBot(ChatGPTBot): | |||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
openai.api_type = "azure" | openai.api_type = "azure" | ||||
openai.api_version = "2023-03-15-preview" | |||||
openai.api_version = conf().get("azure_api_version", "2023-06-01-preview") | |||||
self.args["deployment_id"] = conf().get("azure_deployment_id") | self.args["deployment_id"] = conf().get("azure_deployment_id") | ||||
def create_img(self, query, retry_count=0, api_key=None): | def create_img(self, query, retry_count=0, api_key=None): | ||||
@@ -29,18 +29,24 @@ class LinkAIBot(Bot, OpenAIImage): | |||||
if context.type == ContextType.TEXT: | if context.type == ContextType.TEXT: | ||||
return self._chat(query, context) | return self._chat(query, context) | ||||
elif context.type == ContextType.IMAGE_CREATE: | elif context.type == ContextType.IMAGE_CREATE: | ||||
ok, retstring = self.create_img(query, 0) | |||||
reply = None | |||||
ok, res = self.create_img(query, 0) | |||||
if ok: | if ok: | ||||
reply = Reply(ReplyType.IMAGE_URL, retstring) | |||||
reply = Reply(ReplyType.IMAGE_URL, res) | |||||
else: | else: | ||||
reply = Reply(ReplyType.ERROR, retstring) | |||||
reply = Reply(ReplyType.ERROR, res) | |||||
return reply | return reply | ||||
else: | else: | ||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type)) | reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type)) | ||||
return reply | return reply | ||||
def _chat(self, query, context, retry_count=0): | |||||
def _chat(self, query, context, retry_count=0) -> Reply: | |||||
""" | |||||
发起对话请求 | |||||
:param query: 请求提示词 | |||||
:param context: 对话上下文 | |||||
:param retry_count: 当前递归重试次数 | |||||
:return: 回复 | |||||
""" | |||||
if retry_count >= 2: | if retry_count >= 2: | ||||
# exit from retry 2 times | # exit from retry 2 times | ||||
logger.warn("[LINKAI] failed after maximum number of retry times") | logger.warn("[LINKAI] failed after maximum number of retry times") | ||||
@@ -52,7 +58,7 @@ class LinkAIBot(Bot, OpenAIImage): | |||||
logger.info(f"[LINKAI] won't set appcode because a plugin ({context['generate_breaked_by']}) affected the context") | logger.info(f"[LINKAI] won't set appcode because a plugin ({context['generate_breaked_by']}) affected the context") | ||||
app_code = None | app_code = None | ||||
else: | else: | ||||
app_code = conf().get("linkai_app_code") | |||||
app_code = context.kwargs.get("app_code") or conf().get("linkai_app_code") | |||||
linkai_api_key = conf().get("linkai_api_key") | linkai_api_key = conf().get("linkai_api_key") | ||||
session_id = context["session_id"] | session_id = context["session_id"] | ||||
@@ -63,10 +69,8 @@ class LinkAIBot(Bot, OpenAIImage): | |||||
if app_code and session.messages[0].get("role") == "system": | if app_code and session.messages[0].get("role") == "system": | ||||
session.messages.pop(0) | session.messages.pop(0) | ||||
logger.info(f"[LINKAI] query={query}, app_code={app_code}") | |||||
body = { | body = { | ||||
"appCode": app_code, | |||||
"app_code": app_code, | |||||
"messages": session.messages, | "messages": session.messages, | ||||
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称 | "model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称 | ||||
"temperature": conf().get("temperature"), | "temperature": conf().get("temperature"), | ||||
@@ -74,31 +78,34 @@ class LinkAIBot(Bot, OpenAIImage): | |||||
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | ||||
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | ||||
} | } | ||||
logger.info(f"[LINKAI] query={query}, app_code={app_code}, mode={body.get('model')}") | |||||
headers = {"Authorization": "Bearer " + linkai_api_key} | headers = {"Authorization": "Bearer " + linkai_api_key} | ||||
# do http request | # do http request | ||||
res = requests.post(url=self.base_url + "/chat/completion", json=body, headers=headers).json() | |||||
if not res or not res["success"]: | |||||
if res.get("code") == self.AUTH_FAILED_CODE: | |||||
logger.exception(f"[LINKAI] please check your linkai_api_key, res={res}") | |||||
return Reply(ReplyType.ERROR, "请再问我一次吧") | |||||
res = requests.post(url=self.base_url + "/chat/completions", json=body, headers=headers, | |||||
timeout=conf().get("request_timeout", 180)) | |||||
if res.status_code == 200: | |||||
# execute success | |||||
response = res.json() | |||||
reply_content = response["choices"][0]["message"]["content"] | |||||
total_tokens = response["usage"]["total_tokens"] | |||||
logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}") | |||||
self.sessions.session_reply(reply_content, session_id, total_tokens) | |||||
return Reply(ReplyType.TEXT, reply_content) | |||||
elif res.get("code") == self.NO_QUOTA_CODE: | |||||
logger.exception(f"[LINKAI] please check your account quota, https://chat.link-ai.tech/console/account") | |||||
return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧") | |||||
else: | |||||
response = res.json() | |||||
error = response.get("error") | |||||
logger.error(f"[LINKAI] chat failed, status_code={res.status_code}, " | |||||
f"msg={error.get('message')}, type={error.get('type')}") | |||||
else: | |||||
# retry | |||||
if res.status_code >= 500: | |||||
# server error, need retry | |||||
time.sleep(2) | time.sleep(2) | ||||
logger.warn(f"[LINKAI] do retry, times={retry_count}") | logger.warn(f"[LINKAI] do retry, times={retry_count}") | ||||
return self._chat(query, context, retry_count + 1) | return self._chat(query, context, retry_count + 1) | ||||
# execute success | |||||
reply_content = res["data"]["content"] | |||||
logger.info(f"[LINKAI] reply={reply_content}") | |||||
self.sessions.session_reply(reply_content, session_id) | |||||
return Reply(ReplyType.TEXT, reply_content) | |||||
return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧") | |||||
except Exception as e: | except Exception as e: | ||||
logger.exception(e) | logger.exception(e) | ||||
@@ -56,3 +56,9 @@ class Bridge(object): | |||||
def fetch_translate(self, text, from_lang="", to_lang="en") -> Reply: | def fetch_translate(self, text, from_lang="", to_lang="en") -> Reply: | ||||
return self.get_bot("translate").translate(text, from_lang, to_lang) | return self.get_bot("translate").translate(text, from_lang, to_lang) | ||||
def reset_bot(self): | |||||
""" | |||||
重置bot路由 | |||||
""" | |||||
self.__init__() |
@@ -108,8 +108,12 @@ class ChatChannel(Channel): | |||||
if not conf().get("group_at_off", False): | if not conf().get("group_at_off", False): | ||||
flag = True | flag = True | ||||
pattern = f"@{re.escape(self.name)}(\u2005|\u0020)" | pattern = f"@{re.escape(self.name)}(\u2005|\u0020)" | ||||
content = re.sub(pattern, r"", content) | |||||
subtract_res = re.sub(pattern, r"", content) | |||||
if subtract_res == content and context["msg"].self_display_name: | |||||
# 前缀移除后没有变化,使用群昵称再次移除 | |||||
pattern = f"@{re.escape(context['msg'].self_display_name)}(\u2005|\u0020)" | |||||
subtract_res = re.sub(pattern, r"", content) | |||||
content = subtract_res | |||||
if not flag: | if not flag: | ||||
if context["origin_ctype"] == ContextType.VOICE: | if context["origin_ctype"] == ContextType.VOICE: | ||||
logger.info("[WX]receive group voice, but checkprefix didn't match") | logger.info("[WX]receive group voice, but checkprefix didn't match") | ||||
@@ -24,9 +24,7 @@ is_at: 是否被at | |||||
- (群消息时,一般会存在实际发送者,是群内某个成员的id和昵称,下列项仅在群消息时存在) | - (群消息时,一般会存在实际发送者,是群内某个成员的id和昵称,下列项仅在群消息时存在) | ||||
actual_user_id: 实际发送者id (群聊必填) | actual_user_id: 实际发送者id (群聊必填) | ||||
actual_user_nickname:实际发送者昵称 | actual_user_nickname:实际发送者昵称 | ||||
self_display_name: 自身的展示名,设置群昵称时,该字段表示群昵称 | |||||
_prepare_fn: 准备函数,用于准备消息的内容,比如下载图片等, | _prepare_fn: 准备函数,用于准备消息的内容,比如下载图片等, | ||||
_prepared: 是否已经调用过准备函数 | _prepared: 是否已经调用过准备函数 | ||||
@@ -48,6 +46,8 @@ class ChatMessage(object): | |||||
to_user_nickname = None | to_user_nickname = None | ||||
other_user_id = None | other_user_id = None | ||||
other_user_nickname = None | other_user_nickname = None | ||||
my_msg = False | |||||
self_display_name = None | |||||
is_group = False | is_group = False | ||||
is_at = False | is_at = False | ||||
@@ -58,6 +58,9 @@ def _check(func): | |||||
if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息 | if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息 | ||||
logger.debug("[WX]history message {} skipped".format(msgId)) | logger.debug("[WX]history message {} skipped".format(msgId)) | ||||
return | return | ||||
if cmsg.my_msg and not cmsg.is_group: | |||||
logger.debug("[WX]my message {} skipped".format(msgId)) | |||||
return | |||||
return func(self, cmsg) | return func(self, cmsg) | ||||
return wrapper | return wrapper | ||||
@@ -57,13 +57,19 @@ class WechatMessage(ChatMessage): | |||||
self.from_user_nickname = nickname | self.from_user_nickname = nickname | ||||
if self.to_user_id == user_id: | if self.to_user_id == user_id: | ||||
self.to_user_nickname = nickname | self.to_user_nickname = nickname | ||||
try: # 陌生人时候, 'User'字段可能不存在 | |||||
try: # 陌生人时候, User字段可能不存在 | |||||
# my_msg 为True是表示是自己发送的消息 | |||||
self.my_msg = itchat_msg["ToUserName"] == itchat_msg["User"]["UserName"] and \ | |||||
itchat_msg["ToUserName"] != itchat_msg["FromUserName"] | |||||
self.other_user_id = itchat_msg["User"]["UserName"] | self.other_user_id = itchat_msg["User"]["UserName"] | ||||
self.other_user_nickname = itchat_msg["User"]["NickName"] | self.other_user_nickname = itchat_msg["User"]["NickName"] | ||||
if self.other_user_id == self.from_user_id: | if self.other_user_id == self.from_user_id: | ||||
self.from_user_nickname = self.other_user_nickname | self.from_user_nickname = self.other_user_nickname | ||||
if self.other_user_id == self.to_user_id: | if self.other_user_id == self.to_user_id: | ||||
self.to_user_nickname = self.other_user_nickname | self.to_user_nickname = self.other_user_nickname | ||||
if itchat_msg["User"].get("Self"): | |||||
# 自身的展示名,当设置了群昵称时,该字段表示群昵称 | |||||
self.self_display_name = itchat_msg["User"].get("Self").get("DisplayName") | |||||
except KeyError as e: # 处理偶尔没有对方信息的情况 | except KeyError as e: # 处理偶尔没有对方信息的情况 | ||||
logger.warn("[WX]get other_user_id failed: " + str(e)) | logger.warn("[WX]get other_user_id failed: " + str(e)) | ||||
if self.from_user_id == user_id: | if self.from_user_id == user_id: | ||||
@@ -20,6 +20,7 @@ available_setting = { | |||||
"use_azure_chatgpt": False, # 是否使用azure的chatgpt | "use_azure_chatgpt": False, # 是否使用azure的chatgpt | ||||
"azure_deployment_id": "", # azure 模型部署名称 | "azure_deployment_id": "", # azure 模型部署名称 | ||||
"use_baidu_wenxin": False, # 是否使用baidu文心一言,优先级次于azure | "use_baidu_wenxin": False, # 是否使用baidu文心一言,优先级次于azure | ||||
"azure_api_version": "", # azure api版本 | |||||
# Bot触发配置 | # Bot触发配置 | ||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复 | "single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复 | ||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人 | "single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人 | ||||
@@ -107,6 +108,8 @@ available_setting = { | |||||
"appdata_dir": "", # 数据目录 | "appdata_dir": "", # 数据目录 | ||||
# 插件配置 | # 插件配置 | ||||
"plugin_trigger_prefix": "$", # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突 | "plugin_trigger_prefix": "$", # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突 | ||||
# 是否使用全局插件配置 | |||||
"use_global_plugin_config": False, | |||||
# 知识库平台配置 | # 知识库平台配置 | ||||
"use_linkai": False, | "use_linkai": False, | ||||
"linkai_api_key": "", | "linkai_api_key": "", | ||||
@@ -257,3 +260,9 @@ def pconf(plugin_name: str) -> dict: | |||||
:return: 该插件的配置项 | :return: 该插件的配置项 | ||||
""" | """ | ||||
return plugin_config.get(plugin_name.lower()) | return plugin_config.get(plugin_name.lower()) | ||||
# 全局配置,用于存放全局生效的状态 | |||||
global_config = { | |||||
"admin_users": [] | |||||
} |
@@ -18,6 +18,7 @@ services: | |||||
SPEECH_RECOGNITION: 'False' | SPEECH_RECOGNITION: 'False' | ||||
CHARACTER_DESC: '你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。' | CHARACTER_DESC: '你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。' | ||||
EXPIRES_IN_SECONDS: 3600 | EXPIRES_IN_SECONDS: 3600 | ||||
USE_GLOBAL_PLUGIN_CONFIG: 'True' | |||||
USE_LINKAI: 'False' | USE_LINKAI: 'False' | ||||
LINKAI_API_KEY: '' | LINKAI_API_KEY: '' | ||||
LINKAI_APP_CODE: '' | LINKAI_APP_CODE: '' |
@@ -20,5 +20,19 @@ | |||||
"no_default": false, | "no_default": false, | ||||
"model_name": "gpt-3.5-turbo" | "model_name": "gpt-3.5-turbo" | ||||
} | } | ||||
}, | |||||
"linkai": { | |||||
"group_app_map": { | |||||
"测试群1": "default", | |||||
"测试群2": "Kv2fXJcH" | |||||
}, | |||||
"midjourney": { | |||||
"enabled": true, | |||||
"auto_translate": true, | |||||
"img_proxy": true, | |||||
"max_tasks": 3, | |||||
"max_tasks_per_user": 1, | |||||
"use_image_create_prefix": true | |||||
} | |||||
} | } | ||||
} | } |
@@ -13,7 +13,7 @@ from bridge.context import ContextType | |||||
from bridge.reply import Reply, ReplyType | from bridge.reply import Reply, ReplyType | ||||
from common import const | from common import const | ||||
from common.log import logger | from common.log import logger | ||||
from config import conf, load_config | |||||
from config import conf, load_config, global_config | |||||
from plugins import * | from plugins import * | ||||
# 定义指令集 | # 定义指令集 | ||||
@@ -426,9 +426,11 @@ class Godcmd(Plugin): | |||||
password = args[0] | password = args[0] | ||||
if password == self.password: | if password == self.password: | ||||
self.admin_users.append(userid) | self.admin_users.append(userid) | ||||
global_config["admin_users"].append(userid) | |||||
return True, "认证成功" | return True, "认证成功" | ||||
elif password == self.temp_password: | elif password == self.temp_password: | ||||
self.admin_users.append(userid) | self.admin_users.append(userid) | ||||
global_config["admin_users"].append(userid) | |||||
return True, "认证成功,请尽快设置口令" | return True, "认证成功,请尽快设置口令" | ||||
else: | else: | ||||
return False, "认证失败" | return False, "认证失败" | ||||
@@ -54,9 +54,18 @@ class Keyword(Plugin): | |||||
logger.debug(f"[keyword] 匹配到关键字【{content}】") | logger.debug(f"[keyword] 匹配到关键字【{content}】") | ||||
reply_text = self.keyword[content] | reply_text = self.keyword[content] | ||||
reply = Reply() | |||||
reply.type = ReplyType.TEXT | |||||
reply.content = reply_text | |||||
# 判断匹配内容的类型 | |||||
if (reply_text.startswith("http://") or reply_text.startswith("https://")) and any(reply_text.endswith(ext) for ext in [".jpg", ".jpeg", ".png", ".gif", ".webp"]): | |||||
# 如果是以 http:// 或 https:// 开头,且.jpg/.jpeg/.png/.gif结尾,则认为是图片 URL | |||||
reply = Reply() | |||||
reply.type = ReplyType.IMAGE_URL | |||||
reply.content = reply_text | |||||
else: | |||||
# 否则认为是普通文本 | |||||
reply = Reply() | |||||
reply.type = ReplyType.TEXT | |||||
reply.content = reply_text | |||||
e_context["reply"] = reply | e_context["reply"] = reply | ||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | ||||
@@ -0,0 +1,66 @@ | |||||
## 插件说明 | |||||
基于 LinkAI 提供的知识库、Midjourney绘画等能力对机器人的功能进行增强。平台地址: https://chat.link-ai.tech/console | |||||
## 插件配置 | |||||
将 `plugins/linkai` 目录下的 `config.json.template` 配置模板复制为最终生效的 `config.json`: | |||||
以下是配置项说明: | |||||
```bash | |||||
{ | |||||
"group_app_map": { # 群聊 和 应用编码 的映射关系 | |||||
"测试群1": "default", # 表示在名称为 "测试群1" 的群聊中将使用app_code 为 default 的应用 | |||||
"测试群2": "Kv2fXJcH" | |||||
}, | |||||
"midjourney": { | |||||
"enabled": true, # midjourney 绘画开关 | |||||
"auto_translate": true, # 是否自动将提示词翻译为英文 | |||||
"img_proxy": true, # 是否对生成的图片使用代理,如果你是国外服务器,将这一项设置为false会获得更快的生成速度 | |||||
"max_tasks": 3, # 支持同时提交的总任务个数 | |||||
"max_tasks_per_user": 1, # 支持单个用户同时提交的任务个数 | |||||
"use_image_create_prefix": true # 是否使用全局的绘画触发词,如果开启将同时支持由`config.json`中的 image_create_prefix 配置触发 | |||||
} | |||||
} | |||||
``` | |||||
注意: | |||||
- 配置项中 `group_app_map` 部分是用于映射群聊与LinkAI平台上的应用, `midjourney` 部分是 mj 画图的配置,可根据需要进行填写,未填写配置时默认不开启相应功能 | |||||
- 实际 `config.json` 配置中应保证json格式,不应携带 '#' 及后面的注释 | |||||
- 如果是`docker`部署,可通过映射 `plugins/config.json` 到容器中来完成插件配置,参考[文档](https://github.com/zhayujie/chatgpt-on-wechat#3-%E6%8F%92%E4%BB%B6%E4%BD%BF%E7%94%A8) | |||||
## 插件使用 | |||||
> 使用插件中的知识库管理功能需要首先开启`linkai`对话,依赖全局 `config.json` 中的 `use_linkai` 和 `linkai_api_key` 配置;而midjourney绘画功能则只需填写 `linkai_api_key` 配置,`use_linkai` 无论是否关闭均可使用。具体可参考 [详细文档](https://link-ai.tech/platform/link-app/wechat)。 | |||||
完成配置后运行项目,会自动运行插件,输入 `#help linkai` 可查看插件功能。 | |||||
### 1.知识库管理功能 | |||||
提供在不同群聊使用不同应用的功能。可以在上述 `group_app_map` 配置中固定映射关系,也可以通过指令在群中快速完成切换。 | |||||
应用切换指令需要首先完成管理员 (`godcmd`) 插件的认证,然后按以下格式输入: | |||||
`$linkai app {app_code}` | |||||
例如输入 `$linkai app Kv2fXJcH`,即将当前群聊与 app_code为 Kv2fXJcH 的应用绑定。 | |||||
### 2.Midjourney绘画功能 | |||||
指令格式: | |||||
``` | |||||
- 图片生成: $mj 描述词1, 描述词2.. | |||||
- 图片放大: $mju 图片ID 图片序号 | |||||
``` | |||||
例如: | |||||
``` | |||||
"$mj a little cat, white --ar 9:16" | |||||
"$mju 1105592717188272288 2" | |||||
``` | |||||
注:开启 `use_image_create_prefix` 配置后可直接复用全局画图触发词,以"画"开头便可以生成图片。 |
@@ -0,0 +1 @@ | |||||
from .linkai import * |
@@ -0,0 +1,14 @@ | |||||
{ | |||||
"group_app_map": { | |||||
"测试群1": "default", | |||||
"测试群2": "Kv2fXJcH" | |||||
}, | |||||
"midjourney": { | |||||
"enabled": true, | |||||
"auto_translate": true, | |||||
"img_proxy": true, | |||||
"max_tasks": 3, | |||||
"max_tasks_per_user": 1, | |||||
"use_image_create_prefix": true | |||||
} | |||||
} |
@@ -0,0 +1,161 @@ | |||||
import plugins | |||||
from bridge.context import ContextType | |||||
from bridge.reply import Reply, ReplyType | |||||
from config import global_config | |||||
from plugins import * | |||||
from .midjourney import MJBot | |||||
from bridge import bridge | |||||
@plugins.register( | |||||
name="linkai", | |||||
desc="A plugin that supports knowledge base and midjourney drawing.", | |||||
version="0.1.0", | |||||
author="https://link-ai.tech", | |||||
) | |||||
class LinkAI(Plugin): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context | |||||
self.config = super().load_config() | |||||
if self.config: | |||||
self.mj_bot = MJBot(self.config.get("midjourney")) | |||||
logger.info("[LinkAI] inited") | |||||
def on_handle_context(self, e_context: EventContext): | |||||
""" | |||||
消息处理逻辑 | |||||
:param e_context: 消息上下文 | |||||
""" | |||||
if not self.config: | |||||
return | |||||
context = e_context['context'] | |||||
if context.type not in [ContextType.TEXT, ContextType.IMAGE, ContextType.IMAGE_CREATE]: | |||||
# filter content no need solve | |||||
return | |||||
mj_type = self.mj_bot.judge_mj_task_type(e_context) | |||||
if mj_type: | |||||
# MJ作图任务处理 | |||||
self.mj_bot.process_mj_task(mj_type, e_context) | |||||
return | |||||
if context.content.startswith(f"{_get_trigger_prefix()}linkai"): | |||||
# 应用管理功能 | |||||
self._process_admin_cmd(e_context) | |||||
return | |||||
if self._is_chat_task(e_context): | |||||
# 文本对话任务处理 | |||||
self._process_chat_task(e_context) | |||||
# 插件管理功能 | |||||
def _process_admin_cmd(self, e_context: EventContext): | |||||
context = e_context['context'] | |||||
cmd = context.content.split() | |||||
if len(cmd) == 1 or (len(cmd) == 2 and cmd[1] == "help"): | |||||
_set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO) | |||||
return | |||||
if len(cmd) == 2 and (cmd[1] == "open" or cmd[1] == "close"): | |||||
# 知识库开关指令 | |||||
if not _is_admin(e_context): | |||||
_set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR) | |||||
return | |||||
is_open = True | |||||
tips_text = "开启" | |||||
if cmd[1] == "close": | |||||
tips_text = "关闭" | |||||
is_open = False | |||||
conf()["use_linkai"] = is_open | |||||
bridge.Bridge().reset_bot() | |||||
_set_reply_text(f"知识库功能已{tips_text}", e_context, level=ReplyType.INFO) | |||||
return | |||||
if len(cmd) == 3 and cmd[1] == "app": | |||||
# 知识库应用切换指令 | |||||
if not context.kwargs.get("isgroup"): | |||||
_set_reply_text("该指令需在群聊中使用", e_context, level=ReplyType.ERROR) | |||||
return | |||||
if not _is_admin(e_context): | |||||
_set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR) | |||||
return | |||||
app_code = cmd[2] | |||||
group_name = context.kwargs.get("msg").from_user_nickname | |||||
group_mapping = self.config.get("group_app_map") | |||||
if group_mapping: | |||||
group_mapping[group_name] = app_code | |||||
else: | |||||
self.config["group_app_map"] = {group_name: app_code} | |||||
# 保存插件配置 | |||||
super().save_config(self.config) | |||||
_set_reply_text(f"应用设置成功: {app_code}", e_context, level=ReplyType.INFO) | |||||
else: | |||||
_set_reply_text(f"指令错误,请输入{_get_trigger_prefix()}linkai help 获取帮助", e_context, | |||||
level=ReplyType.INFO) | |||||
return | |||||
# LinkAI 对话任务处理 | |||||
def _is_chat_task(self, e_context: EventContext): | |||||
context = e_context['context'] | |||||
# 群聊应用管理 | |||||
return self.config.get("group_app_map") and context.kwargs.get("isgroup") | |||||
def _process_chat_task(self, e_context: EventContext): | |||||
""" | |||||
处理LinkAI对话任务 | |||||
:param e_context: 对话上下文 | |||||
""" | |||||
context = e_context['context'] | |||||
# 群聊应用管理 | |||||
group_name = context.kwargs.get("msg").from_user_nickname | |||||
app_code = self._fetch_group_app_code(group_name) | |||||
if app_code: | |||||
context.kwargs['app_code'] = app_code | |||||
def _fetch_group_app_code(self, group_name: str) -> str: | |||||
""" | |||||
根据群聊名称获取对应的应用code | |||||
:param group_name: 群聊名称 | |||||
:return: 应用code | |||||
""" | |||||
group_mapping = self.config.get("group_app_map") | |||||
if group_mapping: | |||||
app_code = group_mapping.get(group_name) or group_mapping.get("ALL_GROUP") | |||||
return app_code | |||||
def get_help_text(self, verbose=False, **kwargs): | |||||
trigger_prefix = _get_trigger_prefix() | |||||
help_text = "用于集成 LinkAI 提供的知识库、Midjourney绘画等能力。\n\n" | |||||
if not verbose: | |||||
return help_text | |||||
help_text += f'📖 知识库\n - 群聊中指定应用: {trigger_prefix}linkai app 应用编码\n\n例如: \n"$linkai app Kv2fXJcH"\n\n' | |||||
help_text += f"🎨 绘画\n - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n - 变换: mjv 图片ID 图片序号\n - 重置: mjr 图片ID" | |||||
help_text += f"\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 11055927171882 2\"" | |||||
help_text += f"\n\"{trigger_prefix}mjv 11055927171882 2\"\n\"{trigger_prefix}mjr 11055927171882\"" | |||||
return help_text | |||||
# 静态方法 | |||||
def _is_admin(e_context: EventContext) -> bool: | |||||
""" | |||||
判断消息是否由管理员用户发送 | |||||
:param e_context: 消息上下文 | |||||
:return: True: 是, False: 否 | |||||
""" | |||||
context = e_context["context"] | |||||
if context["isgroup"]: | |||||
return context.kwargs.get("msg").actual_user_id in global_config["admin_users"] | |||||
else: | |||||
return context["receiver"] in global_config["admin_users"] | |||||
def _set_reply_text(content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR): | |||||
reply = Reply(level, content) | |||||
e_context["reply"] = reply | |||||
e_context.action = EventAction.BREAK_PASS | |||||
def _get_trigger_prefix(): | |||||
return conf().get("plugin_trigger_prefix", "$") |
@@ -0,0 +1,415 @@ | |||||
from enum import Enum | |||||
from config import conf | |||||
from common.log import logger | |||||
import requests | |||||
import threading | |||||
import time | |||||
from bridge.reply import Reply, ReplyType | |||||
import aiohttp | |||||
import asyncio | |||||
from bridge.context import ContextType | |||||
from plugins import EventContext, EventAction | |||||
INVALID_REQUEST = 410 | |||||
NOT_FOUND_ORIGIN_IMAGE = 461 | |||||
NOT_FOUND_TASK = 462 | |||||
class TaskType(Enum): | |||||
GENERATE = "generate" | |||||
UPSCALE = "upscale" | |||||
VARIATION = "variation" | |||||
RESET = "reset" | |||||
def __str__(self): | |||||
return self.name | |||||
class Status(Enum): | |||||
PENDING = "pending" | |||||
FINISHED = "finished" | |||||
EXPIRED = "expired" | |||||
ABORTED = "aborted" | |||||
def __str__(self): | |||||
return self.name | |||||
class TaskMode(Enum): | |||||
FAST = "fast" | |||||
RELAX = "relax" | |||||
task_name_mapping = { | |||||
TaskType.GENERATE.name: "生成", | |||||
TaskType.UPSCALE.name: "放大", | |||||
TaskType.VARIATION.name: "变换", | |||||
TaskType.RESET.name: "重新生成", | |||||
} | |||||
class MJTask: | |||||
def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int = 60 * 30, | |||||
status=Status.PENDING): | |||||
self.id = id | |||||
self.user_id = user_id | |||||
self.task_type = task_type | |||||
self.raw_prompt = raw_prompt | |||||
self.send_func = None # send_func(img_url) | |||||
self.expiry_time = time.time() + expires | |||||
self.status = status | |||||
self.img_url = None # url | |||||
self.img_id = None | |||||
def __str__(self): | |||||
return f"id={self.id}, user_id={self.user_id}, task_type={self.task_type}, status={self.status}, img_id={self.img_id}" | |||||
# midjourney bot | |||||
class MJBot: | |||||
def __init__(self, config): | |||||
self.base_url = "https://api.link-ai.chat/v1/img/midjourney" | |||||
self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")} | |||||
self.config = config | |||||
self.tasks = {} | |||||
self.temp_dict = {} | |||||
self.tasks_lock = threading.Lock() | |||||
self.event_loop = asyncio.new_event_loop() | |||||
def judge_mj_task_type(self, e_context: EventContext): | |||||
""" | |||||
判断MJ任务的类型 | |||||
:param e_context: 上下文 | |||||
:return: 任务类型枚举 | |||||
""" | |||||
if not self.config: | |||||
return None | |||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$") | |||||
context = e_context['context'] | |||||
if context.type == ContextType.TEXT: | |||||
cmd_list = context.content.split(maxsplit=1) | |||||
if cmd_list[0].lower() == f"{trigger_prefix}mj": | |||||
return TaskType.GENERATE | |||||
elif cmd_list[0].lower() == f"{trigger_prefix}mju": | |||||
return TaskType.UPSCALE | |||||
elif cmd_list[0].lower() == f"{trigger_prefix}mjv": | |||||
return TaskType.VARIATION | |||||
elif cmd_list[0].lower() == f"{trigger_prefix}mjr": | |||||
return TaskType.RESET | |||||
elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix"): | |||||
return TaskType.GENERATE | |||||
def process_mj_task(self, mj_type: TaskType, e_context: EventContext): | |||||
""" | |||||
处理mj任务 | |||||
:param mj_type: mj任务类型 | |||||
:param e_context: 对话上下文 | |||||
""" | |||||
context = e_context['context'] | |||||
session_id = context["session_id"] | |||||
cmd = context.content.split(maxsplit=1) | |||||
if len(cmd) == 1 and context.type == ContextType.TEXT: | |||||
# midjourney 帮助指令 | |||||
self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO) | |||||
return | |||||
if len(cmd) == 2 and (cmd[1] == "open" or cmd[1] == "close"): | |||||
# midjourney 开关指令 | |||||
is_open = True | |||||
tips_text = "开启" | |||||
if cmd[1] == "close": | |||||
tips_text = "关闭" | |||||
is_open = False | |||||
self.config["enabled"] = is_open | |||||
self._set_reply_text(f"Midjourney绘画已{tips_text}", e_context, level=ReplyType.INFO) | |||||
return | |||||
if not self.config.get("enabled"): | |||||
logger.warn("Midjourney绘画未开启,请查看 plugins/linkai/config.json 中的配置") | |||||
self._set_reply_text(f"Midjourney绘画未开启", e_context, level=ReplyType.INFO) | |||||
return | |||||
if not self._check_rate_limit(session_id, e_context): | |||||
logger.warn("[MJ] midjourney task exceed rate limit") | |||||
return | |||||
if mj_type == TaskType.GENERATE: | |||||
if context.type == ContextType.IMAGE_CREATE: | |||||
raw_prompt = context.content | |||||
else: | |||||
# 图片生成 | |||||
raw_prompt = cmd[1] | |||||
reply = self.generate(raw_prompt, session_id, e_context) | |||||
e_context['reply'] = reply | |||||
e_context.action = EventAction.BREAK_PASS | |||||
return | |||||
elif mj_type == TaskType.UPSCALE or mj_type == TaskType.VARIATION: | |||||
# 图片放大/变换 | |||||
clist = cmd[1].split() | |||||
if len(clist) < 2: | |||||
self._set_reply_text(f"{cmd[0]} 命令缺少参数", e_context) | |||||
return | |||||
img_id = clist[0] | |||||
index = int(clist[1]) | |||||
if index < 1 or index > 4: | |||||
self._set_reply_text(f"图片序号 {index} 错误,应在 1 至 4 之间", e_context) | |||||
return | |||||
key = f"{str(mj_type)}_{img_id}_{index}" | |||||
if self.temp_dict.get(key): | |||||
self._set_reply_text(f"第 {index} 张图片已经{task_name_mapping.get(str(mj_type))}过了", e_context) | |||||
return | |||||
# 执行图片放大/变换操作 | |||||
reply = self.do_operate(mj_type, session_id, img_id, e_context, index) | |||||
e_context['reply'] = reply | |||||
e_context.action = EventAction.BREAK_PASS | |||||
return | |||||
elif mj_type == TaskType.RESET: | |||||
# 图片重新生成 | |||||
clist = cmd[1].split() | |||||
if len(clist) < 1: | |||||
self._set_reply_text(f"{cmd[0]} 命令缺少参数", e_context) | |||||
return | |||||
img_id = clist[0] | |||||
# 图片重新生成 | |||||
reply = self.do_operate(mj_type, session_id, img_id, e_context) | |||||
e_context['reply'] = reply | |||||
e_context.action = EventAction.BREAK_PASS | |||||
else: | |||||
self._set_reply_text(f"暂不支持该命令", e_context) | |||||
def generate(self, prompt: str, user_id: str, e_context: EventContext) -> Reply: | |||||
""" | |||||
图片生成 | |||||
:param prompt: 提示词 | |||||
:param user_id: 用户id | |||||
:param e_context: 对话上下文 | |||||
:return: 任务ID | |||||
""" | |||||
logger.info(f"[MJ] image generate, prompt={prompt}") | |||||
mode = self._fetch_mode(prompt) | |||||
body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")} | |||||
if not self.config.get("img_proxy"): | |||||
body["img_proxy"] = False | |||||
res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers, timeout=(5, 40)) | |||||
if res.status_code == 200: | |||||
res = res.json() | |||||
logger.debug(f"[MJ] image generate, res={res}") | |||||
if res.get("code") == 200: | |||||
task_id = res.get("data").get("task_id") | |||||
real_prompt = res.get("data").get("real_prompt") | |||||
if mode == TaskMode.RELAX.value: | |||||
time_str = "1~10分钟" | |||||
else: | |||||
time_str = "1分钟" | |||||
content = f"🚀您的作品将在{time_str}左右完成,请耐心等待\n- - - - - - - - -\n" | |||||
if real_prompt: | |||||
content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}" | |||||
else: | |||||
content += f"prompt: {prompt}" | |||||
reply = Reply(ReplyType.INFO, content) | |||||
task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id, | |||||
task_type=TaskType.GENERATE) | |||||
# put to memory dict | |||||
self.tasks[task.id] = task | |||||
# asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop) | |||||
self._do_check_task(task, e_context) | |||||
return reply | |||||
else: | |||||
res_json = res.json() | |||||
logger.error(f"[MJ] generate error, msg={res_json.get('message')}, status_code={res.status_code}") | |||||
if res.status_code == INVALID_REQUEST: | |||||
reply = Reply(ReplyType.ERROR, "图片生成失败,请检查提示词参数或内容") | |||||
else: | |||||
reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试") | |||||
return reply | |||||
def do_operate(self, task_type: TaskType, user_id: str, img_id: str, e_context: EventContext, | |||||
index: int = None) -> Reply: | |||||
logger.info(f"[MJ] image operate, task_type={task_type}, img_id={img_id}, index={index}") | |||||
body = {"type": task_type.name, "img_id": img_id} | |||||
if index: | |||||
body["index"] = index | |||||
if not self.config.get("img_proxy"): | |||||
body["img_proxy"] = False | |||||
res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5, 40)) | |||||
logger.debug(res) | |||||
if res.status_code == 200: | |||||
res = res.json() | |||||
if res.get("code") == 200: | |||||
task_id = res.get("data").get("task_id") | |||||
logger.info(f"[MJ] image operate processing, task_id={task_id}") | |||||
icon_map = {TaskType.UPSCALE: "🔎", TaskType.VARIATION: "🪄", TaskType.RESET: "🔄"} | |||||
content = f"{icon_map.get(task_type)}图片正在{task_name_mapping.get(task_type.name)}中,请耐心等待" | |||||
reply = Reply(ReplyType.INFO, content) | |||||
task = MJTask(id=task_id, status=Status.PENDING, user_id=user_id, task_type=task_type) | |||||
# put to memory dict | |||||
self.tasks[task.id] = task | |||||
key = f"{task_type.name}_{img_id}_{index}" | |||||
self.temp_dict[key] = True | |||||
# asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop) | |||||
self._do_check_task(task, e_context) | |||||
return reply | |||||
else: | |||||
error_msg = "" | |||||
if res.status_code == NOT_FOUND_ORIGIN_IMAGE: | |||||
error_msg = "请输入正确的图片ID" | |||||
res_json = res.json() | |||||
logger.error(f"[MJ] operate error, msg={res_json.get('message')}, status_code={res.status_code}") | |||||
reply = Reply(ReplyType.ERROR, error_msg or "图片生成失败,请稍后再试") | |||||
return reply | |||||
def check_task_sync(self, task: MJTask, e_context: EventContext): | |||||
logger.debug(f"[MJ] start check task status, {task}") | |||||
max_retry_times = 90 | |||||
while max_retry_times > 0: | |||||
time.sleep(10) | |||||
url = f"{self.base_url}/tasks/{task.id}" | |||||
try: | |||||
res = requests.get(url, headers=self.headers, timeout=8) | |||||
if res.status_code == 200: | |||||
res_json = res.json() | |||||
logger.debug(f"[MJ] task check res sync, task_id={task.id}, status={res.status_code}, " | |||||
f"data={res_json.get('data')}, thread={threading.current_thread().name}") | |||||
if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name: | |||||
# process success res | |||||
if self.tasks.get(task.id): | |||||
self.tasks[task.id].status = Status.FINISHED | |||||
self._process_success_task(task, res_json.get("data"), e_context) | |||||
return | |||||
max_retry_times -= 1 | |||||
else: | |||||
res_json = res.json() | |||||
logger.warn(f"[MJ] image check error, status_code={res.status_code}, res={res_json}") | |||||
max_retry_times -= 20 | |||||
except Exception as e: | |||||
max_retry_times -= 20 | |||||
logger.warn(e) | |||||
logger.warn("[MJ] end from poll") | |||||
if self.tasks.get(task.id): | |||||
self.tasks[task.id].status = Status.EXPIRED | |||||
def _do_check_task(self, task: MJTask, e_context: EventContext): | |||||
threading.Thread(target=self.check_task_sync, args=(task, e_context)).start() | |||||
def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext): | |||||
""" | |||||
处理任务成功的结果 | |||||
:param task: MJ任务 | |||||
:param res: 请求结果 | |||||
:param e_context: 对话上下文 | |||||
""" | |||||
# channel send img | |||||
task.status = Status.FINISHED | |||||
task.img_id = res.get("img_id") | |||||
task.img_url = res.get("img_url") | |||||
logger.info(f"[MJ] task success, task_id={task.id}, img_id={task.img_id}, img_url={task.img_url}") | |||||
# send img | |||||
reply = Reply(ReplyType.IMAGE_URL, task.img_url) | |||||
channel = e_context["channel"] | |||||
channel._send(reply, e_context["context"]) | |||||
# send info | |||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$") | |||||
text = "" | |||||
if task.task_type == TaskType.GENERATE or task.task_type == TaskType.VARIATION or task.task_type == TaskType.RESET: | |||||
text = f"🎨绘画完成!\n" | |||||
if task.raw_prompt: | |||||
text += f"prompt: {task.raw_prompt}\n" | |||||
text += f"- - - - - - - - -\n图片ID: {task.img_id}" | |||||
text += f"\n\n🔎使用 {trigger_prefix}mju 命令放大图片\n" | |||||
text += f"例如:\n{trigger_prefix}mju {task.img_id} 1" | |||||
text += f"\n\n🪄使用 {trigger_prefix}mjv 命令变换图片\n" | |||||
text += f"例如:\n{trigger_prefix}mjv {task.img_id} 1" | |||||
text += f"\n\n🔄使用 {trigger_prefix}mjr 命令重新生成图片\n" | |||||
text += f"例如:\n{trigger_prefix}mjr {task.img_id}" | |||||
reply = Reply(ReplyType.INFO, text) | |||||
channel._send(reply, e_context["context"]) | |||||
self._print_tasks() | |||||
return | |||||
def _check_rate_limit(self, user_id: str, e_context: EventContext) -> bool: | |||||
""" | |||||
midjourney任务限流控制 | |||||
:param user_id: 用户id | |||||
:param e_context: 对话上下文 | |||||
:return: 任务是否能够生成, True:可以生成, False: 被限流 | |||||
""" | |||||
tasks = self.find_tasks_by_user_id(user_id) | |||||
task_count = len([t for t in tasks if t.status == Status.PENDING]) | |||||
if task_count >= self.config.get("max_tasks_per_user"): | |||||
reply = Reply(ReplyType.INFO, "您的Midjourney作图任务数已达上限,请稍后再试") | |||||
e_context["reply"] = reply | |||||
e_context.action = EventAction.BREAK_PASS | |||||
return False | |||||
task_count = len([t for t in self.tasks.values() if t.status == Status.PENDING]) | |||||
if task_count >= self.config.get("max_tasks"): | |||||
reply = Reply(ReplyType.INFO, "Midjourney作图任务数已达上限,请稍后再试") | |||||
e_context["reply"] = reply | |||||
e_context.action = EventAction.BREAK_PASS | |||||
return False | |||||
return True | |||||
def _fetch_mode(self, prompt) -> str: | |||||
mode = self.config.get("mode") | |||||
if "--relax" in prompt or mode == TaskMode.RELAX.value: | |||||
return TaskMode.RELAX.value | |||||
return mode or TaskMode.FAST.value | |||||
def _run_loop(self, loop: asyncio.BaseEventLoop): | |||||
""" | |||||
运行事件循环,用于轮询任务的线程 | |||||
:param loop: 事件循环 | |||||
""" | |||||
loop.run_forever() | |||||
loop.stop() | |||||
def _print_tasks(self): | |||||
for id in self.tasks: | |||||
logger.debug(f"[MJ] current task: {self.tasks[id]}") | |||||
def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR): | |||||
""" | |||||
设置回复文本 | |||||
:param content: 回复内容 | |||||
:param e_context: 对话上下文 | |||||
:param level: 回复等级 | |||||
""" | |||||
reply = Reply(level, content) | |||||
e_context["reply"] = reply | |||||
e_context.action = EventAction.BREAK_PASS | |||||
def get_help_text(self, verbose=False, **kwargs): | |||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$") | |||||
help_text = "🎨利用Midjourney进行画图\n\n" | |||||
if not verbose: | |||||
return help_text | |||||
help_text += f" - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n - 变换: mjv 图片ID 图片序号\n - 重置: mjr 图片ID" | |||||
help_text += f"\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 11055927171882 2\"" | |||||
help_text += f"\n\"{trigger_prefix}mjv 11055927171882 2\"\n\"{trigger_prefix}mjr 11055927171882\"" | |||||
return help_text | |||||
def find_tasks_by_user_id(self, user_id) -> list: | |||||
result = [] | |||||
with self.tasks_lock: | |||||
now = time.time() | |||||
for task in self.tasks.values(): | |||||
if task.status == Status.PENDING and now > task.expiry_time: | |||||
task.status = Status.EXPIRED | |||||
logger.info(f"[MJ] {task} expired") | |||||
if task.user_id == user_id: | |||||
result.append(task) | |||||
return result | |||||
def check_prefix(content, prefix_list): | |||||
if not prefix_list: | |||||
return None | |||||
for prefix in prefix_list: | |||||
if content.startswith(prefix): | |||||
return prefix | |||||
return None |
@@ -1,6 +1,6 @@ | |||||
import os | import os | ||||
import json | import json | ||||
from config import pconf | |||||
from config import pconf, plugin_config, conf | |||||
from common.log import logger | from common.log import logger | ||||
@@ -15,14 +15,31 @@ class Plugin: | |||||
""" | """ | ||||
# 优先获取 plugins/config.json 中的全局配置 | # 优先获取 plugins/config.json 中的全局配置 | ||||
plugin_conf = pconf(self.name) | plugin_conf = pconf(self.name) | ||||
if not plugin_conf: | |||||
# 全局配置不存在,则获取插件目录下的配置 | |||||
if not plugin_conf or not conf().get("use_global_plugin_config"): | |||||
# 全局配置不存在 或者 未开启全局配置开关,则获取插件目录下的配置 | |||||
plugin_config_path = os.path.join(self.path, "config.json") | plugin_config_path = os.path.join(self.path, "config.json") | ||||
if os.path.exists(plugin_config_path): | if os.path.exists(plugin_config_path): | ||||
with open(plugin_config_path, "r") as f: | |||||
with open(plugin_config_path, "r", encoding="utf-8") as f: | |||||
plugin_conf = json.load(f) | plugin_conf = json.load(f) | ||||
logger.debug(f"loading plugin config, plugin_name={self.name}, conf={plugin_conf}") | logger.debug(f"loading plugin config, plugin_name={self.name}, conf={plugin_conf}") | ||||
return plugin_conf | return plugin_conf | ||||
def save_config(self, config: dict): | |||||
try: | |||||
plugin_config[self.name] = config | |||||
# 写入全局配置 | |||||
global_config_path = "./plugins/config.json" | |||||
if os.path.exists(global_config_path): | |||||
with open(global_config_path, "w", encoding='utf-8') as f: | |||||
json.dump(plugin_config, f, indent=4, ensure_ascii=False) | |||||
# 写入插件配置 | |||||
plugin_config_path = os.path.join(self.path, "config.json") | |||||
if os.path.exists(plugin_config_path): | |||||
with open(plugin_config_path, "w", encoding='utf-8') as f: | |||||
json.dump(config, f, indent=4, ensure_ascii=False) | |||||
except Exception as e: | |||||
logger.warn("save plugin config failed: {}".format(e)) | |||||
def get_help_text(self, **kwargs): | def get_help_text(self, **kwargs): | ||||
return "暂无帮助信息" | return "暂无帮助信息" |