From 44cb54a9ea499c0c86bbcf83bf343e5c98f1de64 Mon Sep 17 00:00:00 2001 From: SSMario Date: Tue, 16 May 2023 09:38:38 +0800 Subject: [PATCH 01/23] =?UTF-8?q?feat:=20=E6=89=8B=E6=9C=BA=E4=B8=8A?= =?UTF-8?q?=E5=9B=9E=E5=A4=8D=E6=B6=88=E6=81=AF=EF=BC=8C=E4=B8=8D=E8=A7=A6?= =?UTF-8?q?=E5=8F=91=E6=9C=BA=E5=99=A8=E4=BA=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- channel/chat_message.py | 1 + channel/wechat/wechat_channel.py | 3 +++ channel/wechat/wechat_message.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/channel/chat_message.py b/channel/chat_message.py index fdd4d90..0e2f652 100644 --- a/channel/chat_message.py +++ b/channel/chat_message.py @@ -48,6 +48,7 @@ class ChatMessage(object): to_user_nickname = None other_user_id = None other_user_nickname = None + my_msg = False is_group = False is_at = False diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index c888157..10c5003 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -59,6 +59,9 @@ def _check(func): if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息 logger.debug("[WX]history message {} skipped".format(msgId)) return + if cmsg.my_msg: + logger.debug("[WX]my message {} skipped".format(msgId)) + return return func(self, cmsg) return wrapper diff --git a/channel/wechat/wechat_message.py b/channel/wechat/wechat_message.py index 5d9bf28..b9824f9 100644 --- a/channel/wechat/wechat_message.py +++ b/channel/wechat/wechat_message.py @@ -58,6 +58,8 @@ class WechatMessage(ChatMessage): if self.to_user_id == user_id: self.to_user_nickname = nickname try: # 陌生人时候, 'User'字段可能不存在 + self.my_msg = itchat_msg["ToUserName"] == itchat_msg["User"]["UserName"] and \ + itchat_msg["ToUserName"] != itchat_msg["FromUserName"] self.other_user_id = itchat_msg["User"]["UserName"] self.other_user_nickname = itchat_msg["User"]["NickName"] if self.other_user_id == self.from_user_id: From 1d4ff796d79a49cd173e27121a365d7925305932 Mon Sep 17 00:00:00 2001 From: SSMario Date: Tue, 16 May 2023 11:50:54 +0800 Subject: [PATCH 02/23] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0eleventLabs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.py | 2 +- requirements-optional.txt | 1 + voice/elevent/elevent_voice.py | 32 ++++++++++++++++++++++++++++++++ voice/factory.py | 4 ++++ 4 files changed, 38 insertions(+), 1 deletion(-) create mode 100644 voice/elevent/elevent_voice.py diff --git a/config.py b/config.py index ae1cfd7..782beac 100644 --- a/config.py +++ b/config.py @@ -53,7 +53,7 @@ available_setting = { "voice_reply_voice": False, # 是否使用语音回复语音,需要设置对应语音合成引擎的api key "always_reply_voice": False, # 是否一直使用语音回复 "voice_to_text": "openai", # 语音识别引擎,支持openai,baidu,google,azure - "text_to_voice": "baidu", # 语音合成引擎,支持baidu,google,pytts(offline),azure + "text_to_voice": "baidu", # 语音合成引擎,支持baidu,google,pytts(offline),azure,eleven # baidu 语音api配置, 使用百度语音识别和语音合成时需要 "baidu_app_id": "", "baidu_api_key": "", diff --git a/requirements-optional.txt b/requirements-optional.txt index c248689..9901de4 100644 --- a/requirements-optional.txt +++ b/requirements-optional.txt @@ -9,6 +9,7 @@ baidu_aip>=4.16.10 # baidu voice azure-cognitiveservices-speech # azure voice numpy<=1.24.2 langid # language detect +elevenlabs==0.2.15 #install plugin dulwich diff --git a/voice/elevent/elevent_voice.py b/voice/elevent/elevent_voice.py new file mode 100644 index 0000000..72d5bcd --- /dev/null +++ b/voice/elevent/elevent_voice.py @@ -0,0 +1,32 @@ +""" +eleventLabs voice service +""" + +import time + +from elevenlabs import generate + +from bridge.reply import Reply, ReplyType +from common.log import logger +from common.tmp_dir import TmpDir +from voice.voice import Voice + + +class ElevenLabsVoice(Voice): + + def __init__(self): + pass + + def voiceToText(self, voice_file): + pass + + def textToVoice(self, text): + audio = generate( + text=text + ) + fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3" + with open(fileName, "wb") as f: + f.write(audio) + logger.info("[ElevenLabs] textToVoice text={} voice file name={}".format(text, fileName)) + return Reply(ReplyType.VOICE, fileName) + diff --git a/voice/factory.py b/voice/factory.py index 45fe0d1..0cf1a05 100644 --- a/voice/factory.py +++ b/voice/factory.py @@ -29,4 +29,8 @@ def create_voice(voice_type): from voice.azure.azure_voice import AzureVoice return AzureVoice() + elif voice_type == "eleven": + from voice.elevent.elevent_voice import ElevenLabsVoice + + return ElevenLabsVoice() raise RuntimeError From 4dbc54fa15ff77b7acda570cc494f8fef6a87a42 Mon Sep 17 00:00:00 2001 From: SSMario Date: Tue, 16 May 2023 12:00:05 +0800 Subject: [PATCH 03/23] =?UTF-8?q?Revert=20"feat:=20=E5=A2=9E=E5=8A=A0eleve?= =?UTF-8?q?ntLabs"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 1d4ff796d79a49cd173e27121a365d7925305932. --- config.py | 2 +- requirements-optional.txt | 1 - voice/elevent/elevent_voice.py | 32 -------------------------------- voice/factory.py | 4 ---- 4 files changed, 1 insertion(+), 38 deletions(-) delete mode 100644 voice/elevent/elevent_voice.py diff --git a/config.py b/config.py index 782beac..ae1cfd7 100644 --- a/config.py +++ b/config.py @@ -53,7 +53,7 @@ available_setting = { "voice_reply_voice": False, # 是否使用语音回复语音,需要设置对应语音合成引擎的api key "always_reply_voice": False, # 是否一直使用语音回复 "voice_to_text": "openai", # 语音识别引擎,支持openai,baidu,google,azure - "text_to_voice": "baidu", # 语音合成引擎,支持baidu,google,pytts(offline),azure,eleven + "text_to_voice": "baidu", # 语音合成引擎,支持baidu,google,pytts(offline),azure # baidu 语音api配置, 使用百度语音识别和语音合成时需要 "baidu_app_id": "", "baidu_api_key": "", diff --git a/requirements-optional.txt b/requirements-optional.txt index 9901de4..c248689 100644 --- a/requirements-optional.txt +++ b/requirements-optional.txt @@ -9,7 +9,6 @@ baidu_aip>=4.16.10 # baidu voice azure-cognitiveservices-speech # azure voice numpy<=1.24.2 langid # language detect -elevenlabs==0.2.15 #install plugin dulwich diff --git a/voice/elevent/elevent_voice.py b/voice/elevent/elevent_voice.py deleted file mode 100644 index 72d5bcd..0000000 --- a/voice/elevent/elevent_voice.py +++ /dev/null @@ -1,32 +0,0 @@ -""" -eleventLabs voice service -""" - -import time - -from elevenlabs import generate - -from bridge.reply import Reply, ReplyType -from common.log import logger -from common.tmp_dir import TmpDir -from voice.voice import Voice - - -class ElevenLabsVoice(Voice): - - def __init__(self): - pass - - def voiceToText(self, voice_file): - pass - - def textToVoice(self, text): - audio = generate( - text=text - ) - fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3" - with open(fileName, "wb") as f: - f.write(audio) - logger.info("[ElevenLabs] textToVoice text={} voice file name={}".format(text, fileName)) - return Reply(ReplyType.VOICE, fileName) - diff --git a/voice/factory.py b/voice/factory.py index 0cf1a05..45fe0d1 100644 --- a/voice/factory.py +++ b/voice/factory.py @@ -29,8 +29,4 @@ def create_voice(voice_type): from voice.azure.azure_voice import AzureVoice return AzureVoice() - elif voice_type == "eleven": - from voice.elevent.elevent_voice import ElevenLabsVoice - - return ElevenLabsVoice() raise RuntimeError From 74a253f52192183b0ef485e4e83f9038b2eb5df5 Mon Sep 17 00:00:00 2001 From: "zyqcn@live.com" Date: Mon, 24 Jul 2023 15:44:03 +0800 Subject: [PATCH 04/23] azure api add api-version:https://learn.microsoft.com/zh-cn/azure/ai-services/openai/reference --- README.md | 1 + bot/chatgpt/chat_gpt_bot.py | 2 +- config.py | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 341c734..1be2569 100644 --- a/README.md +++ b/README.md @@ -123,6 +123,7 @@ pip3 install azure-cognitiveservices-speech "group_speech_recognition": false, # 是否开启群组语音识别 "use_azure_chatgpt": false, # 是否使用Azure ChatGPT service代替openai ChatGPT service. 当设置为true时需要设置 open_ai_api_base,如 https://xxx.openai.azure.com/ "azure_deployment_id": "", # 采用Azure ChatGPT时,模型部署名称 + "azure_api_version": "", # 采用Azure ChatGPT时,API版本 "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述 # 订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复,可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。 "subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持角色扮演和文字冒险等丰富插件。\n输入{trigger_prefix}#help 查看详细指令。" diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index ba357c0..8c9a250 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -166,7 +166,7 @@ class AzureChatGPTBot(ChatGPTBot): def __init__(self): super().__init__() openai.api_type = "azure" - openai.api_version = "2023-03-15-preview" + openai.api_version = conf().get("azure_api_version", "2023-06-01-preview") self.args["deployment_id"] = conf().get("azure_deployment_id") def create_img(self, query, retry_count=0, api_key=None): diff --git a/config.py b/config.py index 85c5436..6bcde09 100644 --- a/config.py +++ b/config.py @@ -19,6 +19,7 @@ available_setting = { "model": "gpt-3.5-turbo", "use_azure_chatgpt": False, # 是否使用azure的chatgpt "azure_deployment_id": "", # azure 模型部署名称 + "azure_api_version": "", # azure api版本 # Bot触发配置 "single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复 "single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人 From f81ac31fe15664a366dc36aa7fd6174a9d9490cc Mon Sep 17 00:00:00 2001 From: zhayujie Date: Thu, 27 Jul 2023 21:21:36 +0800 Subject: [PATCH 05/23] feat: add linkai plugin to support midjourney and distinguish app between groups --- .gitignore | 3 +- bot/linkai/link_ai_bot.py | 2 +- plugins/linkai/__init__.py | 1 + plugins/linkai/linkai.py | 93 +++++++++++++ plugins/linkai/midjourney.py | 256 +++++++++++++++++++++++++++++++++++ 5 files changed, 353 insertions(+), 2 deletions(-) create mode 100644 plugins/linkai/__init__.py create mode 100644 plugins/linkai/linkai.py create mode 100644 plugins/linkai/midjourney.py diff --git a/.gitignore b/.gitignore index 4eb71e5..6c6fb46 100644 --- a/.gitignore +++ b/.gitignore @@ -24,4 +24,5 @@ plugins/**/ !plugins/banwords/**/ !plugins/hello !plugins/role -!plugins/keyword \ No newline at end of file +!plugins/keyword +!plugins/linkai \ No newline at end of file diff --git a/bot/linkai/link_ai_bot.py b/bot/linkai/link_ai_bot.py index 804b53b..8b8ca8b 100644 --- a/bot/linkai/link_ai_bot.py +++ b/bot/linkai/link_ai_bot.py @@ -52,7 +52,7 @@ class LinkAIBot(Bot, OpenAIImage): logger.info(f"[LINKAI] won't set appcode because a plugin ({context['generate_breaked_by']}) affected the context") app_code = None else: - app_code = conf().get("linkai_app_code") + app_code = context.kwargs.get("app_code") or conf().get("linkai_app_code") linkai_api_key = conf().get("linkai_api_key") session_id = context["session_id"] diff --git a/plugins/linkai/__init__.py b/plugins/linkai/__init__.py new file mode 100644 index 0000000..e7414be --- /dev/null +++ b/plugins/linkai/__init__.py @@ -0,0 +1 @@ +from .linkai import * diff --git a/plugins/linkai/linkai.py b/plugins/linkai/linkai.py new file mode 100644 index 0000000..8698ad5 --- /dev/null +++ b/plugins/linkai/linkai.py @@ -0,0 +1,93 @@ +import asyncio +import json +import threading +from concurrent.futures import ThreadPoolExecutor + +import plugins +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType +from channel.chat_message import ChatMessage +from common.log import logger +from config import conf +from plugins import * +from .midjourney import MJBot, TaskType + +# 任务线程池 +task_thread_pool = ThreadPoolExecutor(max_workers=4) + + +@plugins.register( + name="linkai", + desc="A plugin that supports knowledge base and midjourney drawing.", + version="0.1.0", + author="https://link-ai.tech", +) +class LinkAI(Plugin): + def __init__(self): + super().__init__() + self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context + self.config = super().load_config() + self.mj_bot = MJBot(self.config.get("midjourney")) + logger.info("[LinkAI] inited") + + def on_handle_context(self, e_context: EventContext): + """ + 消息处理逻辑 + :param e_context: 消息上下文 + """ + context = e_context['context'] + if context.type not in [ContextType.TEXT, ContextType.IMAGE]: + # filter content no need solve + return + + mj_type = self.mj_bot.judge_mj_task_type(e_context) + if mj_type: + # MJ作图任务处理 + self.mj_bot.process_mj_task(mj_type, e_context) + return + + if self._is_chat_task(e_context): + self._process_chat_task(e_context) + + # LinkAI 对话任务处理 + def _is_chat_task(self, e_context: EventContext): + context = e_context['context'] + # 群聊应用管理 + return self.config.get("knowledge_base") and context.kwargs.get("isgroup") + + def _process_chat_task(self, e_context: EventContext): + """ + 处理LinkAI对话任务 + :param e_context: 对话上下文 + """ + context = e_context['context'] + # 群聊应用管理 + group_name = context.kwargs.get("msg").from_user_nickname + app_code = self._fetch_group_app_code(group_name) + if app_code: + context.kwargs['app_code'] = app_code + + def _fetch_group_app_code(self, group_name: str) -> str: + """ + 根据群聊名称获取对应的应用code + :param group_name: 群聊名称 + :return: 应用code + """ + knowledge_base_config = self.config.get("knowledge_base") + if knowledge_base_config and knowledge_base_config.get("group_mapping"): + app_code = knowledge_base_config.get("group_mapping").get(group_name) \ + or knowledge_base_config.get("group_mapping").get("ALL_GROUP") + return app_code + + def get_help_text(self, verbose=False, **kwargs): + trigger_prefix = conf().get("plugin_trigger_prefix", "$") + help_text = "利用midjourney来画图。\n" + if not verbose: + return help_text + help_text += f"{trigger_prefix}mj 描述词1,描述词2 ... : 利用描述词作画,参数请放在提示词之后。\n{trigger_prefix}mjimage 描述词1,描述词2 ... : 利用描述词进行图生图,参数请放在提示词之后。\n{trigger_prefix}mjr ID: 对指定ID消息重新生成图片。\n{trigger_prefix}mju ID 图片序号: 对指定ID消息中的第x张图片进行放大。\n{trigger_prefix}mjv ID 图片序号: 对指定ID消息中的第x张图片进行变换。\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mjimage a white cat --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\"" + return help_text + + def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR): + reply = Reply(level, content) + e_context["reply"] = reply + e_context.action = EventAction.BREAK_PASS diff --git a/plugins/linkai/midjourney.py b/plugins/linkai/midjourney.py new file mode 100644 index 0000000..c7c6c35 --- /dev/null +++ b/plugins/linkai/midjourney.py @@ -0,0 +1,256 @@ +from enum import Enum +from config import conf +from common.log import logger +import requests +import threading +import time +from bridge.reply import Reply, ReplyType +import aiohttp +import asyncio +from bridge.context import ContextType +from plugins import EventContext, EventAction + + +class TaskType(Enum): + GENERATE = "generate" + UPSCALE = "upscale" + VARIATION = "variation" + RESET = "reset" + + +class Status(Enum): + PENDING = "pending" + FINISHED = "finished" + EXPIRED = "expired" + ABORTED = "aborted" + + def __str__(self): + return self.name + + +class MJTask: + def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int=60*30, status=Status.PENDING): + self.id = id + self.user_id = user_id + self.task_type = task_type + self.raw_prompt = raw_prompt + self.send_func = None # send_func(img_url) + self.expiry_time = time.time() + expires + self.status = status + self.img_url = None # url + self.img_id = None + + def __str__(self): + return f"id={self.id}, user_id={self.user_id}, task_type={self.task_type}, status={self.status}, img_id={self.img_id}" + +# midjourney bot +class MJBot: + def __init__(self, config): + self.base_url = "https://api.link-ai.chat/v1/img/midjourney" + # self.base_url = "http://127.0.0.1:8911/v1/img/midjourney" + self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")} + self.config = config + self.tasks = {} + self.temp_dict = {} + self.tasks_lock = threading.Lock() + self.event_loop = asyncio.new_event_loop() + threading.Thread(name="mj-check-thread", target=self._run_loop, args=(self.event_loop,)).start() + + def judge_mj_task_type(self, e_context: EventContext) -> TaskType: + """ + 判断MJ任务的类型 + :param e_context: 上下文 + :return: 任务类型枚举 + """ + trigger_prefix = conf().get("plugin_trigger_prefix", "$") + context = e_context['context'] + if context.type == ContextType.TEXT: + if self.config and self.config.get("enabled"): + cmd_list = context.content.split(maxsplit=1) + if cmd_list[0].lower() == f"{trigger_prefix}mj": + return TaskType.GENERATE + elif cmd_list[0].lower() == f"{trigger_prefix}mju": + return TaskType.UPSCALE + # elif cmd_list[0].lower() == f"{trigger_prefix}mjv": + # return TaskType.VARIATION + # elif cmd_list[0].lower() == f"{trigger_prefix}mjr": + # return TaskType.RESET + + def process_mj_task(self, mj_type: TaskType, e_context: EventContext): + """ + 处理mj任务 + :param mj_type: mj任务类型 + :param e_context: 对话上下文 + """ + context = e_context['context'] + session_id = context["session_id"] + cmd = context.content.split(maxsplit=1) + if len(cmd) == 1: + self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.ERROR) + return + + if mj_type == TaskType.GENERATE: + # 图片生成 + raw_prompt = cmd[1] + reply = self.generate(raw_prompt, session_id, e_context) + e_context['reply'] = reply + e_context.action = EventAction.BREAK_PASS + return + + elif mj_type == TaskType.UPSCALE: + # 图片放大 + clist = cmd[1].split() + if len(clist) < 2: + self._set_reply_text(f"{cmd[0]} 命令缺少参数", e_context) + return + img_id = clist[0] + index = int(clist[1]) + if index < 1 or index > 4: + self._set_reply_text(f"图片序号 {index} 错误,应在 1 至 4 之间", e_context) + return + key = f"{TaskType.UPSCALE.name}_{img_id}_{index}" + if self.temp_dict.get(key): + self._set_reply_text(f"第 {index} 张图片已经放大过了", e_context) + return + # 图片放大操作 + reply = self.upscale(session_id, img_id, index, e_context) + e_context['reply'] = reply + e_context.action = EventAction.BREAK_PASS + return + + else: + self._set_reply_text(f"暂不支持该命令", e_context) + + def generate(self, prompt: str, user_id: str, e_context: EventContext) -> Reply: + """ + 图片生成 + :param prompt: 提示词 + :param user_id: 用户id + :return: 任务ID + """ + logger.info(f"[MJ] image generate, prompt={prompt}") + body = {"prompt": prompt} + res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers) + if res.status_code == 200: + res = res.json() + logger.debug(f"[MJ] image generate, res={res}") + if res.get("code") == 200: + task_id = res.get("data").get("taskId") + real_prompt = res.get("data").get("realPrompt") + content = f"🚀你的作品将在1~2分钟左右完成,请耐心等待\n- - - - - - - - -\n" + if real_prompt: + content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}" + else: + content += f"prompt: {prompt}" + reply = Reply(ReplyType.INFO, content) + task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id, task_type=TaskType.GENERATE) + # put to memory dict + self.tasks[task.id] = task + asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop) + return reply + else: + res_json = res.json() + logger.error(f"[MJ] generate error, msg={res_json.get('message')}, status_code={res.status_code}") + reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试") + return reply + + def upscale(self, user_id: str, img_id: str, index: int, e_context: EventContext) -> Reply: + logger.info(f"[MJ] image upscale, img_id={img_id}, index={index}") + body = {"type": TaskType.UPSCALE.name, "imgId": img_id, "index": index} + res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers) + if res.status_code == 200: + res = res.json() + logger.info(res) + if res.get("code") == 200: + task_id = res.get("data").get("taskId") + content = f"🔎图片正在放大中,请耐心等待" + reply = Reply(ReplyType.INFO, content) + task = MJTask(id=task_id, status=Status.PENDING, user_id=user_id, task_type=TaskType.UPSCALE) + # put to memory dict + self.tasks[task.id] = task + key = f"{TaskType.UPSCALE.name}_{img_id}_{index}" + self.temp_dict[key] = True + asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop) + return reply + else: + error_msg = "" + if res.status_code == 461: + error_msg = "请输入正确的图片ID" + res_json = res.json() + logger.error(f"[MJ] upscale error, msg={res_json.get('message')}, status_code={res.status_code}") + reply = Reply(ReplyType.ERROR, error_msg or "图片生成失败,请稍后再试") + return reply + + async def check_task(self, task: MJTask, e_context: EventContext): + max_retry_time = 80 + while max_retry_time > 0: + async with aiohttp.ClientSession() as session: + url = f"{self.base_url}/tasks/{task.id}" + async with session.get(url, headers=self.headers) as res: + if res.status == 200: + res_json = await res.json() + logger.debug(f"[MJ] task check res, task_id={task.id}, status={res.status}, " + f"data={res_json.get('data')}, thread={threading.current_thread().name}") + if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name: + # process success res + self._process_success_task(task, res_json.get("data"), e_context) + return + else: + logger.warn(f"[MJ] image check error, status_code={res.status}") + max_retry_time -= 20 + await asyncio.sleep(10) + max_retry_time -= 1 + logger.warn("[MJ] end from poll") + + def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext): + """ + 处理任务成功的结果 + :param task: MJ任务 + :param res: 请求结果 + :param e_context: 对话上下文 + """ + # channel send img + task.status = Status.FINISHED + task.img_id = res.get("imgId") + task.img_url = res.get("imgUrl") + logger.info(f"[MJ] task success, task_id={task.id}, img_id={task.img_id}, img_url={task.img_url}") + + # send img + reply = Reply(ReplyType.IMAGE_URL, task.img_url) + channel = e_context["channel"] + channel._send(reply, e_context["context"]) + + # send info + trigger_prefix = conf().get("plugin_trigger_prefix", "$") + text = "" + if task.task_type == TaskType.GENERATE: + text = f"🎨绘画完成!\nprompt: {task.raw_prompt}\n- - - - - - - - -\n图片ID: {task.img_id}" + text += f"\n\n🔎可使用 {trigger_prefix}mju 命令放大指定图片\n" + text += f"例如:\n{trigger_prefix}mju {task.img_id} 1" + reply = Reply(ReplyType.INFO, text) + channel._send(reply, e_context["context"]) + + self._print_tasks() + return + + def _run_loop(self, loop: asyncio.BaseEventLoop): + loop.run_forever() + loop.stop() + + def _print_tasks(self): + for id in self.tasks: + logger.debug(f"[MJ] current task: {self.tasks[id]}") + + + def get_help_text(self, verbose=False, **kwargs): + trigger_prefix = conf().get("plugin_trigger_prefix", "$") + help_text = "利用midjourney来画图。\n" + if not verbose: + return help_text + help_text += f"{trigger_prefix}mj 描述词1,描述词2 ... : 利用描述词作画,参数请放在提示词之后。\n{trigger_prefix}mjimage 描述词1,描述词2 ... : 利用描述词进行图生图,参数请放在提示词之后。\n{trigger_prefix}mjr ID: 对指定ID消息重新生成图片。\n{trigger_prefix}mju ID 图片序号: 对指定ID消息中的第x张图片进行放大。\n{trigger_prefix}mjv ID 图片序号: 对指定ID消息中的第x张图片进行变换。\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mjimage a white cat --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\"" + return help_text + + def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR): + reply = Reply(level, content) + e_context["reply"] = reply + e_context.action = EventAction.BREAK_PASS \ No newline at end of file From dd36b8b150e5b1ad3d5e076e22f7a1772e2692e3 Mon Sep 17 00:00:00 2001 From: zhayujie Date: Thu, 27 Jul 2023 21:29:50 +0800 Subject: [PATCH 06/23] config: add config template --- plugins/linkai/config.json.template | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 plugins/linkai/config.json.template diff --git a/plugins/linkai/config.json.template b/plugins/linkai/config.json.template new file mode 100644 index 0000000..bdfed7a --- /dev/null +++ b/plugins/linkai/config.json.template @@ -0,0 +1,13 @@ +{ + "group_app_map": { + "测试群1": "default", + "测试群2": "Kv2fXJcH" + }, + "midjourney": { + "enabled": true, + "mode": "relax", + "auto_translate": true, + "max_tasks": 3, + "max_tasks_per_user": 1 + } +} From 2f9e5b1219aa32f1c7b11a8c36c9eda71b8438d3 Mon Sep 17 00:00:00 2001 From: zhayujie Date: Fri, 28 Jul 2023 12:40:06 +0800 Subject: [PATCH 07/23] feat: check app_code dynamically --- config.py | 6 ++ plugins/godcmd/godcmd.py | 4 +- plugins/linkai/README.md | 0 plugins/linkai/config.json.template | 3 +- plugins/linkai/linkai.py | 66 +++++++++++++--- plugins/linkai/midjourney.py | 118 +++++++++++++++++++++++----- plugins/plugin.py | 19 ++++- 7 files changed, 181 insertions(+), 35 deletions(-) create mode 100644 plugins/linkai/README.md diff --git a/config.py b/config.py index 85c5436..289f872 100644 --- a/config.py +++ b/config.py @@ -252,3 +252,9 @@ def pconf(plugin_name: str) -> dict: :return: 该插件的配置项 """ return plugin_config.get(plugin_name.lower()) + + +# 全局配置,用于存放全局生效的状态 +global_config = { + "admin_users": [] +} diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index 08bc09e..0ff204e 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -13,7 +13,7 @@ from bridge.context import ContextType from bridge.reply import Reply, ReplyType from common import const from common.log import logger -from config import conf, load_config +from config import conf, load_config, global_config from plugins import * # 定义指令集 @@ -426,9 +426,11 @@ class Godcmd(Plugin): password = args[0] if password == self.password: self.admin_users.append(userid) + global_config["admin_users"].append(userid) return True, "认证成功" elif password == self.temp_password: self.admin_users.append(userid) + global_config["admin_users"].append(userid) return True, "认证成功,请尽快设置口令" else: return False, "认证失败" diff --git a/plugins/linkai/README.md b/plugins/linkai/README.md new file mode 100644 index 0000000..e69de29 diff --git a/plugins/linkai/config.json.template b/plugins/linkai/config.json.template index bdfed7a..98d114b 100644 --- a/plugins/linkai/config.json.template +++ b/plugins/linkai/config.json.template @@ -8,6 +8,7 @@ "mode": "relax", "auto_translate": true, "max_tasks": 3, - "max_tasks_per_user": 1 + "max_tasks_per_user": 1, + "use_image_create_prefix": true } } diff --git a/plugins/linkai/linkai.py b/plugins/linkai/linkai.py index 8698ad5..cff6e16 100644 --- a/plugins/linkai/linkai.py +++ b/plugins/linkai/linkai.py @@ -8,7 +8,7 @@ from bridge.context import ContextType from bridge.reply import Reply, ReplyType from channel.chat_message import ChatMessage from common.log import logger -from config import conf +from config import conf, global_config from plugins import * from .midjourney import MJBot, TaskType @@ -46,14 +46,48 @@ class LinkAI(Plugin): self.mj_bot.process_mj_task(mj_type, e_context) return + if context.content.startswith(f"{_get_trigger_prefix()}linkai"): + # 应用管理功能 + self._process_admin_cmd(e_context) + return + if self._is_chat_task(e_context): + # 文本对话任务处理 self._process_chat_task(e_context) + # 插件管理功能 + def _process_admin_cmd(self, e_context: EventContext): + context = e_context['context'] + cmd = context.content.split() + if len(cmd) == 1 or (len(cmd) == 2 and cmd[1] == "help"): + _set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO) + return + if len(cmd) == 3 and cmd[1] == "app": + if not context.kwargs.get("isgroup"): + _set_reply_text("该指令需在群聊中使用", e_context, level=ReplyType.ERROR) + return + if e_context["context"]["session_id"] not in global_config["admin_users"]: + _set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR) + return + app_code = cmd[2] + group_name = context.kwargs.get("msg").from_user_nickname + group_mapping = self.config.get("group_app_map") + if group_mapping: + group_mapping[group_name] = app_code + else: + self.config["group_app_map"] = {group_name: app_code} + # 保存插件配置 + super().save_config(self.config) + _set_reply_text(f"应用设置成功: {app_code}", e_context, level=ReplyType.INFO) + else: + _set_reply_text(f"指令错误,请输入{_get_trigger_prefix()}linkai help 获取帮助", e_context, level=ReplyType.INFO) + return + # LinkAI 对话任务处理 def _is_chat_task(self, e_context: EventContext): context = e_context['context'] # 群聊应用管理 - return self.config.get("knowledge_base") and context.kwargs.get("isgroup") + return self.config.get("group_app_map") and context.kwargs.get("isgroup") def _process_chat_task(self, e_context: EventContext): """ @@ -73,21 +107,27 @@ class LinkAI(Plugin): :param group_name: 群聊名称 :return: 应用code """ - knowledge_base_config = self.config.get("knowledge_base") - if knowledge_base_config and knowledge_base_config.get("group_mapping"): - app_code = knowledge_base_config.get("group_mapping").get(group_name) \ - or knowledge_base_config.get("group_mapping").get("ALL_GROUP") + group_mapping = self.config.get("group_app_map") + if group_mapping: + app_code = group_mapping.get(group_name) or group_mapping.get("ALL_GROUP") return app_code def get_help_text(self, verbose=False, **kwargs): - trigger_prefix = conf().get("plugin_trigger_prefix", "$") - help_text = "利用midjourney来画图。\n" + trigger_prefix = _get_trigger_prefix() + help_text = "用于集成 LinkAI 提供的文本对话、知识库、绘画等能力。\n" if not verbose: return help_text - help_text += f"{trigger_prefix}mj 描述词1,描述词2 ... : 利用描述词作画,参数请放在提示词之后。\n{trigger_prefix}mjimage 描述词1,描述词2 ... : 利用描述词进行图生图,参数请放在提示词之后。\n{trigger_prefix}mjr ID: 对指定ID消息重新生成图片。\n{trigger_prefix}mju ID 图片序号: 对指定ID消息中的第x张图片进行放大。\n{trigger_prefix}mjv ID 图片序号: 对指定ID消息中的第x张图片进行变换。\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mjimage a white cat --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\"" + help_text += "" + help_text += f"{trigger_prefix}mj 描述词1,描述词2 ... : 利用描述词作画,参数请放在提示词之后。\n\n{trigger_prefix}mju ID 图片序号: 对指定ID消息中的第x张图片进行放大。\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mjimage a white cat --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\"" return help_text - def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR): - reply = Reply(level, content) - e_context["reply"] = reply - e_context.action = EventAction.BREAK_PASS + +# 静态方法 +def _set_reply_text(content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR): + reply = Reply(level, content) + e_context["reply"] = reply + e_context.action = EventAction.BREAK_PASS + + +def _get_trigger_prefix(): + return conf().get("plugin_trigger_prefix", "$") diff --git a/plugins/linkai/midjourney.py b/plugins/linkai/midjourney.py index c7c6c35..a2a4662 100644 --- a/plugins/linkai/midjourney.py +++ b/plugins/linkai/midjourney.py @@ -28,6 +28,11 @@ class Status(Enum): return self.name +class TaskMode(Enum): + FAST = "fast" + RELAX = "relax" + + class MJTask: def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int=60*30, status=Status.PENDING): self.id = id @@ -47,7 +52,6 @@ class MJTask: class MJBot: def __init__(self, config): self.base_url = "https://api.link-ai.chat/v1/img/midjourney" - # self.base_url = "http://127.0.0.1:8911/v1/img/midjourney" self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")} self.config = config self.tasks = {} @@ -71,10 +75,10 @@ class MJBot: return TaskType.GENERATE elif cmd_list[0].lower() == f"{trigger_prefix}mju": return TaskType.UPSCALE - # elif cmd_list[0].lower() == f"{trigger_prefix}mjv": - # return TaskType.VARIATION - # elif cmd_list[0].lower() == f"{trigger_prefix}mjr": - # return TaskType.RESET + elif self.config.get("use_image_create_prefix") and \ + check_prefix(context.content, conf().get("image_create_prefix")): + return TaskType.GENERATE + def process_mj_task(self, mj_type: TaskType, e_context: EventContext): """ @@ -86,12 +90,20 @@ class MJBot: session_id = context["session_id"] cmd = context.content.split(maxsplit=1) if len(cmd) == 1: - self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.ERROR) + self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO) + return + + if not self._check_rate_limit(session_id, e_context): + logger.warn("[MJ] midjourney task exceed rate limit") return if mj_type == TaskType.GENERATE: - # 图片生成 - raw_prompt = cmd[1] + image_prefix = check_prefix(context.content, conf().get("image_create_prefix")) + if image_prefix: + raw_prompt = context.content.replace(image_prefix, "", 1) + else: + # 图片生成 + raw_prompt = cmd[1] reply = self.generate(raw_prompt, session_id, e_context) e_context['reply'] = reply e_context.action = EventAction.BREAK_PASS @@ -126,10 +138,12 @@ class MJBot: 图片生成 :param prompt: 提示词 :param user_id: 用户id + :param e_context: 对话上下文 :return: 任务ID """ logger.info(f"[MJ] image generate, prompt={prompt}") - body = {"prompt": prompt} + mode = self._fetch_mode(prompt) + body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")} res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers) if res.status_code == 200: res = res.json() @@ -137,7 +151,11 @@ class MJBot: if res.get("code") == 200: task_id = res.get("data").get("taskId") real_prompt = res.get("data").get("realPrompt") - content = f"🚀你的作品将在1~2分钟左右完成,请耐心等待\n- - - - - - - - -\n" + if mode == TaskMode.RELAX.name: + time_str = "1~10分钟" + else: + time_str = "1~2分钟" + content = f"🚀你的作品将在{time_str}左右完成,请耐心等待\n- - - - - - - - -\n" if real_prompt: content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}" else: @@ -182,8 +200,9 @@ class MJBot: return reply async def check_task(self, task: MJTask, e_context: EventContext): - max_retry_time = 80 - while max_retry_time > 0: + max_retry_times = 90 + while max_retry_times > 0: + await asyncio.sleep(10) async with aiohttp.ClientSession() as session: url = f"{self.base_url}/tasks/{task.id}" async with session.get(url, headers=self.headers) as res: @@ -193,14 +212,17 @@ class MJBot: f"data={res_json.get('data')}, thread={threading.current_thread().name}") if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name: # process success res + if self.tasks.get(task.id): + self.tasks[task.id].status = Status.FINISHED self._process_success_task(task, res_json.get("data"), e_context) return else: logger.warn(f"[MJ] image check error, status_code={res.status}") - max_retry_time -= 20 - await asyncio.sleep(10) - max_retry_time -= 1 + max_retry_times -= 20 + max_retry_times -= 1 logger.warn("[MJ] end from poll") + if self.tasks.get(task.id): + self.tasks[task.id].status = Status.EXPIRED def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext): """ @@ -233,7 +255,39 @@ class MJBot: self._print_tasks() return + def _check_rate_limit(self, user_id: str, e_context: EventContext) -> bool: + """ + midjourney任务限流控制 + :param user_id: 用户id + :param e_context: 对话上下文 + :return: 任务是否能够生成, True:可以生成, False: 被限流 + """ + tasks = self.find_tasks_by_user_id(user_id) + task_count = len([t for t in tasks if t.status == Status.PENDING]) + if task_count >= self.config.get("max_tasks_per_user"): + reply = Reply(ReplyType.INFO, "您的Midjourney作图任务数已达上限,请稍后再试") + e_context["reply"] = reply + e_context.action = EventAction.BREAK_PASS + return False + task_count = len([t for t in self.tasks.values() if t.status == Status.PENDING]) + if task_count >= self.config.get("max_tasks"): + reply = Reply(ReplyType.INFO, "Midjourney服务的总任务数已达上限,请稍后再试") + e_context["reply"] = reply + e_context.action = EventAction.BREAK_PASS + return False + return True + + def _fetch_mode(self, prompt) -> str: + mode = self.config.get("mode") + if "--relax" in prompt or mode == TaskMode.RELAX.name: + return TaskMode.RELAX.name + return TaskMode.FAST.name + def _run_loop(self, loop: asyncio.BaseEventLoop): + """ + 运行事件循环,用于轮询任务的线程 + :param loop: 事件循环 + """ loop.run_forever() loop.stop() @@ -241,6 +295,16 @@ class MJBot: for id in self.tasks: logger.debug(f"[MJ] current task: {self.tasks[id]}") + def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR): + """ + 设置回复文本 + :param content: 回复内容 + :param e_context: 对话上下文 + :param level: 回复等级 + """ + reply = Reply(level, content) + e_context["reply"] = reply + e_context.action = EventAction.BREAK_PASS def get_help_text(self, verbose=False, **kwargs): trigger_prefix = conf().get("plugin_trigger_prefix", "$") @@ -250,7 +314,23 @@ class MJBot: help_text += f"{trigger_prefix}mj 描述词1,描述词2 ... : 利用描述词作画,参数请放在提示词之后。\n{trigger_prefix}mjimage 描述词1,描述词2 ... : 利用描述词进行图生图,参数请放在提示词之后。\n{trigger_prefix}mjr ID: 对指定ID消息重新生成图片。\n{trigger_prefix}mju ID 图片序号: 对指定ID消息中的第x张图片进行放大。\n{trigger_prefix}mjv ID 图片序号: 对指定ID消息中的第x张图片进行变换。\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mjimage a white cat --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\"" return help_text - def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR): - reply = Reply(level, content) - e_context["reply"] = reply - e_context.action = EventAction.BREAK_PASS \ No newline at end of file + def find_tasks_by_user_id(self, user_id) -> list[MJTask]: + result = [] + with self.tasks_lock: + now = time.time() + for task in self.tasks.values(): + if task.status == Status.PENDING and now > task.expiry_time: + task.status = Status.EXPIRED + logger.info(f"[MJ] {task} expired") + if task.user_id == user_id: + result.append(task) + return result + + +def check_prefix(content, prefix_list): + if not prefix_list: + return None + for prefix in prefix_list: + if content.startswith(prefix): + return prefix + return None \ No newline at end of file diff --git a/plugins/plugin.py b/plugins/plugin.py index e7444d2..1d74a7c 100644 --- a/plugins/plugin.py +++ b/plugins/plugin.py @@ -1,6 +1,6 @@ import os import json -from config import pconf +from config import pconf, plugin_config from common.log import logger @@ -24,5 +24,22 @@ class Plugin: logger.debug(f"loading plugin config, plugin_name={self.name}, conf={plugin_conf}") return plugin_conf + def save_config(self, config: dict): + try: + plugin_config[self.name] = config + # 写入全局配置 + global_config_path = "./plugins/config.json" + if os.path.exists(global_config_path): + with open(global_config_path, "w", encoding='utf-8') as f: + json.dump(plugin_config, f, indent=4, ensure_ascii=False) + # 写入插件配置 + plugin_config_path = os.path.join(self.path, "config.json") + if os.path.exists(plugin_config_path): + with open(plugin_config_path, "w", encoding='utf-8') as f: + json.dump(config, f, indent=4, ensure_ascii=False) + + except Exception as e: + logger.warn("save plugin config failed: {}".format(e)) + def get_help_text(self, **kwargs): return "暂无帮助信息" From 233b24ab0fa91db77fcd462235b4334c3d391db1 Mon Sep 17 00:00:00 2001 From: zhayujie Date: Fri, 28 Jul 2023 16:33:41 +0800 Subject: [PATCH 08/23] feat: add global admin config --- config.py | 2 ++ docker/docker-compose.yml | 1 + plugins/linkai/linkai.py | 8 +++--- plugins/linkai/midjourney.py | 52 +++++++++++++++++++----------------- plugins/plugin.py | 6 ++--- 5 files changed, 38 insertions(+), 31 deletions(-) diff --git a/config.py b/config.py index 289f872..0858172 100644 --- a/config.py +++ b/config.py @@ -102,6 +102,8 @@ available_setting = { "appdata_dir": "", # 数据目录 # 插件配置 "plugin_trigger_prefix": "$", # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突 + # 是否使用全局插件配置 + "use_global_plugin_config": False, # 知识库平台配置 "use_linkai": False, "linkai_api_key": "", diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index b70f3ef..8dbb1e4 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -18,6 +18,7 @@ services: SPEECH_RECOGNITION: 'False' CHARACTER_DESC: '你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。' EXPIRES_IN_SECONDS: 3600 + USE_GLOBAL_PLUGIN_CONFIG: 'True' USE_LINKAI: 'False' LINKAI_API_KEY: '' LINKAI_APP_CODE: '' diff --git a/plugins/linkai/linkai.py b/plugins/linkai/linkai.py index cff6e16..7482459 100644 --- a/plugins/linkai/linkai.py +++ b/plugins/linkai/linkai.py @@ -36,7 +36,7 @@ class LinkAI(Plugin): :param e_context: 消息上下文 """ context = e_context['context'] - if context.type not in [ContextType.TEXT, ContextType.IMAGE]: + if context.type not in [ContextType.TEXT, ContextType.IMAGE, ContextType.IMAGE_CREATE]: # filter content no need solve return @@ -114,11 +114,11 @@ class LinkAI(Plugin): def get_help_text(self, verbose=False, **kwargs): trigger_prefix = _get_trigger_prefix() - help_text = "用于集成 LinkAI 提供的文本对话、知识库、绘画等能力。\n" + help_text = "用于集成 LinkAI 提供的知识库、Midjourney绘画等能力。\n\n" if not verbose: return help_text - help_text += "" - help_text += f"{trigger_prefix}mj 描述词1,描述词2 ... : 利用描述词作画,参数请放在提示词之后。\n\n{trigger_prefix}mju ID 图片序号: 对指定ID消息中的第x张图片进行放大。\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mjimage a white cat --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\"" + help_text += f'📖 知识库\n - 群聊中指定应用: {trigger_prefix}linkai app 应用编码\n\n例如: \n"$linkai app Kv2fXJcH"\n\n' + help_text += f"🎨 绘画\n - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\"" return help_text diff --git a/plugins/linkai/midjourney.py b/plugins/linkai/midjourney.py index a2a4662..9df8f34 100644 --- a/plugins/linkai/midjourney.py +++ b/plugins/linkai/midjourney.py @@ -75,9 +75,8 @@ class MJBot: return TaskType.GENERATE elif cmd_list[0].lower() == f"{trigger_prefix}mju": return TaskType.UPSCALE - elif self.config.get("use_image_create_prefix") and \ - check_prefix(context.content, conf().get("image_create_prefix")): - return TaskType.GENERATE + elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix"): + return TaskType.GENERATE def process_mj_task(self, mj_type: TaskType, e_context: EventContext): @@ -89,7 +88,7 @@ class MJBot: context = e_context['context'] session_id = context["session_id"] cmd = context.content.split(maxsplit=1) - if len(cmd) == 1: + if len(cmd) == 1 and context.type == ContextType.TEXT: self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO) return @@ -98,9 +97,8 @@ class MJBot: return if mj_type == TaskType.GENERATE: - image_prefix = check_prefix(context.content, conf().get("image_create_prefix")) - if image_prefix: - raw_prompt = context.content.replace(image_prefix, "", 1) + if context.type == ContextType.IMAGE_CREATE: + raw_prompt = context.content else: # 图片生成 raw_prompt = cmd[1] @@ -155,7 +153,7 @@ class MJBot: time_str = "1~10分钟" else: time_str = "1~2分钟" - content = f"🚀你的作品将在{time_str}左右完成,请耐心等待\n- - - - - - - - -\n" + content = f"🚀您的作品将在{time_str}左右完成,请耐心等待\n- - - - - - - - -\n" if real_prompt: content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}" else: @@ -205,20 +203,25 @@ class MJBot: await asyncio.sleep(10) async with aiohttp.ClientSession() as session: url = f"{self.base_url}/tasks/{task.id}" - async with session.get(url, headers=self.headers) as res: - if res.status == 200: - res_json = await res.json() - logger.debug(f"[MJ] task check res, task_id={task.id}, status={res.status}, " - f"data={res_json.get('data')}, thread={threading.current_thread().name}") - if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name: - # process success res - if self.tasks.get(task.id): - self.tasks[task.id].status = Status.FINISHED - self._process_success_task(task, res_json.get("data"), e_context) - return - else: - logger.warn(f"[MJ] image check error, status_code={res.status}") - max_retry_times -= 20 + try: + async with session.get(url, headers=self.headers) as res: + if res.status == 200: + res_json = await res.json() + logger.debug(f"[MJ] task check res, task_id={task.id}, status={res.status}, " + f"data={res_json.get('data')}, thread={threading.current_thread().name}") + if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name: + # process success res + if self.tasks.get(task.id): + self.tasks[task.id].status = Status.FINISHED + self._process_success_task(task, res_json.get("data"), e_context) + return + else: + res_json = await res.json() + logger.warn(f"[MJ] image check error, status_code={res.status}, res={res_json}") + max_retry_times -= 20 + except Exception as e: + max_retry_times -= 20 + logger.warn(e) max_retry_times -= 1 logger.warn("[MJ] end from poll") if self.tasks.get(task.id): @@ -308,10 +311,11 @@ class MJBot: def get_help_text(self, verbose=False, **kwargs): trigger_prefix = conf().get("plugin_trigger_prefix", "$") - help_text = "利用midjourney来画图。\n" + help_text = "🎨利用Midjourney进行画图\n\n" if not verbose: return help_text - help_text += f"{trigger_prefix}mj 描述词1,描述词2 ... : 利用描述词作画,参数请放在提示词之后。\n{trigger_prefix}mjimage 描述词1,描述词2 ... : 利用描述词进行图生图,参数请放在提示词之后。\n{trigger_prefix}mjr ID: 对指定ID消息重新生成图片。\n{trigger_prefix}mju ID 图片序号: 对指定ID消息中的第x张图片进行放大。\n{trigger_prefix}mjv ID 图片序号: 对指定ID消息中的第x张图片进行变换。\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mjimage a white cat --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\"" + help_text += f" - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\"" + return help_text def find_tasks_by_user_id(self, user_id) -> list[MJTask]: diff --git a/plugins/plugin.py b/plugins/plugin.py index 1d74a7c..2bb6c26 100644 --- a/plugins/plugin.py +++ b/plugins/plugin.py @@ -1,6 +1,6 @@ import os import json -from config import pconf, plugin_config +from config import pconf, plugin_config, conf from common.log import logger @@ -15,8 +15,8 @@ class Plugin: """ # 优先获取 plugins/config.json 中的全局配置 plugin_conf = pconf(self.name) - if not plugin_conf: - # 全局配置不存在,则获取插件目录下的配置 + if not plugin_conf or not conf().get("use_global_plugin_config"): + # 全局配置不存在 或者 未开启全局配置开关,则获取插件目录下的配置 plugin_config_path = os.path.join(self.path, "config.json") if os.path.exists(plugin_config_path): with open(plugin_config_path, "r") as f: From de26dc05975d962ce5085dd52a80ec26c47aab9b Mon Sep 17 00:00:00 2001 From: zhayujie Date: Fri, 28 Jul 2023 18:50:21 +0800 Subject: [PATCH 09/23] fix: fast mode and relax mode checkout --- plugins/linkai/linkai.py | 2 +- plugins/linkai/midjourney.py | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/plugins/linkai/linkai.py b/plugins/linkai/linkai.py index 7482459..c48da9d 100644 --- a/plugins/linkai/linkai.py +++ b/plugins/linkai/linkai.py @@ -66,7 +66,7 @@ class LinkAI(Plugin): if not context.kwargs.get("isgroup"): _set_reply_text("该指令需在群聊中使用", e_context, level=ReplyType.ERROR) return - if e_context["context"]["session_id"] not in global_config["admin_users"]: + if context.kwargs.get("msg").actual_user_id not in global_config["admin_users"]: _set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR) return app_code = cmd[2] diff --git a/plugins/linkai/midjourney.py b/plugins/linkai/midjourney.py index 9df8f34..93463cc 100644 --- a/plugins/linkai/midjourney.py +++ b/plugins/linkai/midjourney.py @@ -149,7 +149,7 @@ class MJBot: if res.get("code") == 200: task_id = res.get("data").get("taskId") real_prompt = res.get("data").get("realPrompt") - if mode == TaskMode.RELAX.name: + if mode == TaskMode.RELAX.value: time_str = "1~10分钟" else: time_str = "1~2分钟" @@ -174,11 +174,12 @@ class MJBot: logger.info(f"[MJ] image upscale, img_id={img_id}, index={index}") body = {"type": TaskType.UPSCALE.name, "imgId": img_id, "index": index} res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers) + logger.debug(res) if res.status_code == 200: res = res.json() - logger.info(res) if res.get("code") == 200: task_id = res.get("data").get("taskId") + logger.info(f"[MJ] image upscale processing, task_id={task_id}") content = f"🔎图片正在放大中,请耐心等待" reply = Reply(ReplyType.INFO, content) task = MJTask(id=task_id, status=Status.PENDING, user_id=user_id, task_type=TaskType.UPSCALE) @@ -274,7 +275,7 @@ class MJBot: return False task_count = len([t for t in self.tasks.values() if t.status == Status.PENDING]) if task_count >= self.config.get("max_tasks"): - reply = Reply(ReplyType.INFO, "Midjourney服务的总任务数已达上限,请稍后再试") + reply = Reply(ReplyType.INFO, "Midjourney作图任务数已达上限,请稍后再试") e_context["reply"] = reply e_context.action = EventAction.BREAK_PASS return False @@ -282,9 +283,9 @@ class MJBot: def _fetch_mode(self, prompt) -> str: mode = self.config.get("mode") - if "--relax" in prompt or mode == TaskMode.RELAX.name: - return TaskMode.RELAX.name - return TaskMode.FAST.name + if "--relax" in prompt or mode == TaskMode.RELAX.value: + return TaskMode.RELAX.value + return mode or TaskMode.RELAX.value def _run_loop(self, loop: asyncio.BaseEventLoop): """ From 782bff3a51346efe83b918cda5775781d10f39c0 Mon Sep 17 00:00:00 2001 From: zhayujie Date: Sat, 29 Jul 2023 12:22:45 +0800 Subject: [PATCH 10/23] fix: add debug log --- plugins/linkai/midjourney.py | 60 +++++++++++++++++++----------------- 1 file changed, 32 insertions(+), 28 deletions(-) diff --git a/plugins/linkai/midjourney.py b/plugins/linkai/midjourney.py index 93463cc..3a94b24 100644 --- a/plugins/linkai/midjourney.py +++ b/plugins/linkai/midjourney.py @@ -199,34 +199,38 @@ class MJBot: return reply async def check_task(self, task: MJTask, e_context: EventContext): - max_retry_times = 90 - while max_retry_times > 0: - await asyncio.sleep(10) - async with aiohttp.ClientSession() as session: - url = f"{self.base_url}/tasks/{task.id}" - try: - async with session.get(url, headers=self.headers) as res: - if res.status == 200: - res_json = await res.json() - logger.debug(f"[MJ] task check res, task_id={task.id}, status={res.status}, " - f"data={res_json.get('data')}, thread={threading.current_thread().name}") - if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name: - # process success res - if self.tasks.get(task.id): - self.tasks[task.id].status = Status.FINISHED - self._process_success_task(task, res_json.get("data"), e_context) - return - else: - res_json = await res.json() - logger.warn(f"[MJ] image check error, status_code={res.status}, res={res_json}") - max_retry_times -= 20 - except Exception as e: - max_retry_times -= 20 - logger.warn(e) - max_retry_times -= 1 - logger.warn("[MJ] end from poll") - if self.tasks.get(task.id): - self.tasks[task.id].status = Status.EXPIRED + try: + logger.debug(f"[MJ] start check task status, {task}") + max_retry_times = 90 + while max_retry_times > 0: + await asyncio.sleep(10) + async with aiohttp.ClientSession() as session: + url = f"{self.base_url}/tasks/{task.id}" + try: + async with session.get(url, headers=self.headers) as res: + if res.status == 200: + res_json = await res.json() + logger.debug(f"[MJ] task check res, task_id={task.id}, status={res.status}, " + f"data={res_json.get('data')}, thread={threading.current_thread().name}") + if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name: + # process success res + if self.tasks.get(task.id): + self.tasks[task.id].status = Status.FINISHED + self._process_success_task(task, res_json.get("data"), e_context) + return + else: + res_json = await res.json() + logger.warn(f"[MJ] image check error, status_code={res.status}, res={res_json}") + max_retry_times -= 20 + except Exception as e: + max_retry_times -= 20 + logger.warn(e) + max_retry_times -= 1 + logger.warn("[MJ] end from poll") + if self.tasks.get(task.id): + self.tasks[task.id].status = Status.EXPIRED + except Exception as e: + logger.error(e) def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext): """ From d6e16995e08498fda654da4f81dc6428a78ee88e Mon Sep 17 00:00:00 2001 From: befantasy <31535803+befantasy@users.noreply.github.com> Date: Sun, 30 Jul 2023 14:40:07 +0800 Subject: [PATCH 11/23] =?UTF-8?q?Update=20keyword.py=20=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E8=BF=94=E5=9B=9E=E5=9B=BE=E7=89=87=E7=9A=84=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 增加返回图片的功能。以http/https开头,且以.jpg/.jpeg/.png/.gif结尾的内容,识别为URL,自动以图片发送。 --- plugins/keyword/keyword.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/plugins/keyword/keyword.py b/plugins/keyword/keyword.py index 97ebe26..2dc87ff 100644 --- a/plugins/keyword/keyword.py +++ b/plugins/keyword/keyword.py @@ -54,9 +54,18 @@ class Keyword(Plugin): logger.debug(f"[keyword] 匹配到关键字【{content}】") reply_text = self.keyword[content] - reply = Reply() - reply.type = ReplyType.TEXT - reply.content = reply_text + # 判断匹配内容的类型 + if (reply_text.startswith("http://") or reply_text.startswith("https://")) and any(reply_text.endswith(ext) for ext in [".jpg", ".jpeg", ".png", ".gif", ".webp"]): + # 如果是以 http:// 或 https:// 开头,且.jpg/.jpeg/.png/.gif结尾,则认为是图片 URL + reply = Reply() + reply.type = ReplyType.IMAGE_URL + reply.content = reply_text + else: + # 否则认为是普通文本 + reply = Reply() + reply.type = ReplyType.TEXT + reply.content = reply_text + e_context["reply"] = reply e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 From e027286b6de50e21dedce5b5663fed8c7183b758 Mon Sep 17 00:00:00 2001 From: zhayujie Date: Sun, 30 Jul 2023 15:16:19 +0800 Subject: [PATCH 12/23] fix: midjourney check task thread --- plugins/linkai/linkai.py | 6 +++- plugins/linkai/midjourney.py | 63 ++++++++++++++++++++++++++++-------- 2 files changed, 54 insertions(+), 15 deletions(-) diff --git a/plugins/linkai/linkai.py b/plugins/linkai/linkai.py index c48da9d..1031a24 100644 --- a/plugins/linkai/linkai.py +++ b/plugins/linkai/linkai.py @@ -27,7 +27,8 @@ class LinkAI(Plugin): super().__init__() self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context self.config = super().load_config() - self.mj_bot = MJBot(self.config.get("midjourney")) + if self.config: + self.mj_bot = MJBot(self.config.get("midjourney")) logger.info("[LinkAI] inited") def on_handle_context(self, e_context: EventContext): @@ -35,6 +36,9 @@ class LinkAI(Plugin): 消息处理逻辑 :param e_context: 消息上下文 """ + if not self.config: + return + context = e_context['context'] if context.type not in [ContextType.TEXT, ContextType.IMAGE, ContextType.IMAGE_CREATE]: # filter content no need solve diff --git a/plugins/linkai/midjourney.py b/plugins/linkai/midjourney.py index 3a94b24..9d92ee6 100644 --- a/plugins/linkai/midjourney.py +++ b/plugins/linkai/midjourney.py @@ -58,23 +58,24 @@ class MJBot: self.temp_dict = {} self.tasks_lock = threading.Lock() self.event_loop = asyncio.new_event_loop() - threading.Thread(name="mj-check-thread", target=self._run_loop, args=(self.event_loop,)).start() + # threading.Thread(name="mj-check-thread", target=self._run_loop, args=(self.event_loop,)).start() - def judge_mj_task_type(self, e_context: EventContext) -> TaskType: + def judge_mj_task_type(self, e_context: EventContext): """ 判断MJ任务的类型 :param e_context: 上下文 :return: 任务类型枚举 """ + if not self.config or not self.config.get("enabled"): + return None trigger_prefix = conf().get("plugin_trigger_prefix", "$") context = e_context['context'] if context.type == ContextType.TEXT: - if self.config and self.config.get("enabled"): - cmd_list = context.content.split(maxsplit=1) - if cmd_list[0].lower() == f"{trigger_prefix}mj": - return TaskType.GENERATE - elif cmd_list[0].lower() == f"{trigger_prefix}mju": - return TaskType.UPSCALE + cmd_list = context.content.split(maxsplit=1) + if cmd_list[0].lower() == f"{trigger_prefix}mj": + return TaskType.GENERATE + elif cmd_list[0].lower() == f"{trigger_prefix}mju": + return TaskType.UPSCALE elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix"): return TaskType.GENERATE @@ -142,7 +143,7 @@ class MJBot: logger.info(f"[MJ] image generate, prompt={prompt}") mode = self._fetch_mode(prompt) body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")} - res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers) + res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers, timeout=(5,15)) if res.status_code == 200: res = res.json() logger.debug(f"[MJ] image generate, res={res}") @@ -152,7 +153,7 @@ class MJBot: if mode == TaskMode.RELAX.value: time_str = "1~10分钟" else: - time_str = "1~2分钟" + time_str = "1分钟" content = f"🚀您的作品将在{time_str}左右完成,请耐心等待\n- - - - - - - - -\n" if real_prompt: content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}" @@ -162,7 +163,8 @@ class MJBot: task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id, task_type=TaskType.GENERATE) # put to memory dict self.tasks[task.id] = task - asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop) + # asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop) + self._do_check_task(task, e_context) return reply else: res_json = res.json() @@ -173,7 +175,7 @@ class MJBot: def upscale(self, user_id: str, img_id: str, index: int, e_context: EventContext) -> Reply: logger.info(f"[MJ] image upscale, img_id={img_id}, index={index}") body = {"type": TaskType.UPSCALE.name, "imgId": img_id, "index": index} - res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers) + res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5,15)) logger.debug(res) if res.status_code == 200: res = res.json() @@ -187,7 +189,8 @@ class MJBot: self.tasks[task.id] = task key = f"{TaskType.UPSCALE.name}_{img_id}_{index}" self.temp_dict[key] = True - asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop) + # asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop) + self._do_check_task(task, e_context) return reply else: error_msg = "" @@ -198,7 +201,36 @@ class MJBot: reply = Reply(ReplyType.ERROR, error_msg or "图片生成失败,请稍后再试") return reply - async def check_task(self, task: MJTask, e_context: EventContext): + def check_task_sync(self, task: MJTask, e_context: EventContext): + logger.debug(f"[MJ] start check task status, {task}") + max_retry_times = 90 + while max_retry_times > 0: + time.sleep(10) + url = f"{self.base_url}/tasks/{task.id}" + try: + res = requests.get(url, headers=self.headers, timeout=5) + if res.status_code == 200: + res_json = res.json() + logger.debug(f"[MJ] task check res sync, task_id={task.id}, status={res.status_code}, " + f"data={res_json.get('data')}, thread={threading.current_thread().name}") + if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name: + # process success res + if self.tasks.get(task.id): + self.tasks[task.id].status = Status.FINISHED + self._process_success_task(task, res_json.get("data"), e_context) + return + else: + res_json = res.json() + logger.warn(f"[MJ] image check error, status_code={res.status_code}, res={res_json}") + max_retry_times -= 20 + except Exception as e: + max_retry_times -= 20 + logger.warn(e) + logger.warn("[MJ] end from poll") + if self.tasks.get(task.id): + self.tasks[task.id].status = Status.EXPIRED + + async def check_task_async(self, task: MJTask, e_context: EventContext): try: logger.debug(f"[MJ] start check task status, {task}") max_retry_times = 90 @@ -232,6 +264,9 @@ class MJBot: except Exception as e: logger.error(e) + def _do_check_task(self, task: MJTask, e_context: EventContext): + threading.Thread(target=self.check_task_sync, args=(task, e_context)).start() + def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext): """ 处理任务成功的结果 From b22994c2d2bf72d16b5cf270dc7623291b3a414c Mon Sep 17 00:00:00 2001 From: zhayujie Date: Sun, 30 Jul 2023 19:55:56 +0800 Subject: [PATCH 13/23] fix: some image bug --- plugins/linkai/config.json.template | 1 + plugins/linkai/midjourney.py | 45 ++++++++++++++++++----------- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/plugins/linkai/config.json.template b/plugins/linkai/config.json.template index 98d114b..28fe2ca 100644 --- a/plugins/linkai/config.json.template +++ b/plugins/linkai/config.json.template @@ -7,6 +7,7 @@ "enabled": true, "mode": "relax", "auto_translate": true, + "img_proxy": true, "max_tasks": 3, "max_tasks_per_user": 1, "use_image_create_prefix": true diff --git a/plugins/linkai/midjourney.py b/plugins/linkai/midjourney.py index 9d92ee6..06fe9b3 100644 --- a/plugins/linkai/midjourney.py +++ b/plugins/linkai/midjourney.py @@ -10,6 +10,7 @@ import asyncio from bridge.context import ContextType from plugins import EventContext, EventAction +INVALID_REQUEST = 410 class TaskType(Enum): GENERATE = "generate" @@ -34,7 +35,8 @@ class TaskMode(Enum): class MJTask: - def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int=60*30, status=Status.PENDING): + def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int = 60 * 30, + status=Status.PENDING): self.id = id self.user_id = user_id self.task_type = task_type @@ -48,17 +50,18 @@ class MJTask: def __str__(self): return f"id={self.id}, user_id={self.user_id}, task_type={self.task_type}, status={self.status}, img_id={self.img_id}" + # midjourney bot class MJBot: def __init__(self, config): self.base_url = "https://api.link-ai.chat/v1/img/midjourney" + self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")} self.config = config self.tasks = {} self.temp_dict = {} self.tasks_lock = threading.Lock() self.event_loop = asyncio.new_event_loop() - # threading.Thread(name="mj-check-thread", target=self._run_loop, args=(self.event_loop,)).start() def judge_mj_task_type(self, e_context: EventContext): """ @@ -79,7 +82,6 @@ class MJBot: elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix"): return TaskType.GENERATE - def process_mj_task(self, mj_type: TaskType, e_context: EventContext): """ 处理mj任务 @@ -143,13 +145,15 @@ class MJBot: logger.info(f"[MJ] image generate, prompt={prompt}") mode = self._fetch_mode(prompt) body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")} - res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers, timeout=(5,15)) + if not self.config.get("img_proxy"): + body["img_proxy"] = False + res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers, timeout=(5, 40)) if res.status_code == 200: res = res.json() logger.debug(f"[MJ] image generate, res={res}") if res.get("code") == 200: - task_id = res.get("data").get("taskId") - real_prompt = res.get("data").get("realPrompt") + task_id = res.get("data").get("task_id") + real_prompt = res.get("data").get("real_prompt") if mode == TaskMode.RELAX.value: time_str = "1~10分钟" else: @@ -160,7 +164,8 @@ class MJBot: else: content += f"prompt: {prompt}" reply = Reply(ReplyType.INFO, content) - task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id, task_type=TaskType.GENERATE) + task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id, + task_type=TaskType.GENERATE) # put to memory dict self.tasks[task.id] = task # asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop) @@ -169,18 +174,23 @@ class MJBot: else: res_json = res.json() logger.error(f"[MJ] generate error, msg={res_json.get('message')}, status_code={res.status_code}") - reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试") + if res.status_code == INVALID_REQUEST: + reply = Reply(ReplyType.ERROR, "图片生成失败,请检查提示词参数或内容") + else: + reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试") return reply def upscale(self, user_id: str, img_id: str, index: int, e_context: EventContext) -> Reply: logger.info(f"[MJ] image upscale, img_id={img_id}, index={index}") - body = {"type": TaskType.UPSCALE.name, "imgId": img_id, "index": index} - res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5,15)) + body = {"type": TaskType.UPSCALE.name, "img_id": img_id, "index": index} + if not self.config.get("img_proxy"): + body["img_proxy"] = False + res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5, 40)) logger.debug(res) if res.status_code == 200: res = res.json() if res.get("code") == 200: - task_id = res.get("data").get("taskId") + task_id = res.get("data").get("task_id") logger.info(f"[MJ] image upscale processing, task_id={task_id}") content = f"🔎图片正在放大中,请耐心等待" reply = Reply(ReplyType.INFO, content) @@ -208,7 +218,7 @@ class MJBot: time.sleep(10) url = f"{self.base_url}/tasks/{task.id}" try: - res = requests.get(url, headers=self.headers, timeout=5) + res = requests.get(url, headers=self.headers, timeout=8) if res.status_code == 200: res_json = res.json() logger.debug(f"[MJ] task check res sync, task_id={task.id}, status={res.status_code}, " @@ -219,6 +229,7 @@ class MJBot: self.tasks[task.id].status = Status.FINISHED self._process_success_task(task, res_json.get("data"), e_context) return + max_retry_times -= 1 else: res_json = res.json() logger.warn(f"[MJ] image check error, status_code={res.status_code}, res={res_json}") @@ -276,8 +287,8 @@ class MJBot: """ # channel send img task.status = Status.FINISHED - task.img_id = res.get("imgId") - task.img_url = res.get("imgUrl") + task.img_id = res.get("img_id") + task.img_url = res.get("img_url") logger.info(f"[MJ] task success, task_id={task.id}, img_id={task.img_id}, img_url={task.img_url}") # send img @@ -338,7 +349,7 @@ class MJBot: for id in self.tasks: logger.debug(f"[MJ] current task: {self.tasks[id]}") - def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType=ReplyType.ERROR): + def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR): """ 设置回复文本 :param content: 回复内容 @@ -358,7 +369,7 @@ class MJBot: return help_text - def find_tasks_by_user_id(self, user_id) -> list[MJTask]: + def find_tasks_by_user_id(self, user_id) -> list: result = [] with self.tasks_lock: now = time.time() @@ -377,4 +388,4 @@ def check_prefix(content, prefix_list): for prefix in prefix_list: if content.startswith(prefix): return prefix - return None \ No newline at end of file + return None From 9bd7d09f2029dbd2ce554fc5e72f5f5895ca3cc6 Mon Sep 17 00:00:00 2001 From: zhayujie Date: Mon, 31 Jul 2023 14:42:50 +0800 Subject: [PATCH 14/23] fix: remove relax mode temporarily --- plugins/linkai/config.json.template | 1 - 1 file changed, 1 deletion(-) diff --git a/plugins/linkai/config.json.template b/plugins/linkai/config.json.template index 28fe2ca..8e6b22c 100644 --- a/plugins/linkai/config.json.template +++ b/plugins/linkai/config.json.template @@ -5,7 +5,6 @@ }, "midjourney": { "enabled": true, - "mode": "relax", "auto_translate": true, "img_proxy": true, "max_tasks": 3, From cda21acb4320864ae228fc0a495d2f444b9905e7 Mon Sep 17 00:00:00 2001 From: zhayujie Date: Mon, 31 Jul 2023 16:11:33 +0800 Subject: [PATCH 15/23] feat: use new linkai completion api --- README.md | 2 +- bot/linkai/link_ai_bot.py | 55 ++++++++++++++++++++++----------------- 2 files changed, 32 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index 341c734..5179827 100644 --- a/README.md +++ b/README.md @@ -111,7 +111,7 @@ pip3 install azure-cognitiveservices-speech { "open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY "model": "gpt-3.5-turbo", # 模型名称。当use_azure_chatgpt为true时,其名称为Azure上model deployment名称 - "proxy": "127.0.0.1:7890", # 代理客户端的ip和端口 + "proxy": "", # 代理客户端的ip和端口,国内网络环境需要填该项,如 "127.0.0.1:7890" "single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复 "single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人 "group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复 diff --git a/bot/linkai/link_ai_bot.py b/bot/linkai/link_ai_bot.py index 8b8ca8b..c95113a 100644 --- a/bot/linkai/link_ai_bot.py +++ b/bot/linkai/link_ai_bot.py @@ -29,18 +29,24 @@ class LinkAIBot(Bot, OpenAIImage): if context.type == ContextType.TEXT: return self._chat(query, context) elif context.type == ContextType.IMAGE_CREATE: - ok, retstring = self.create_img(query, 0) - reply = None + ok, res = self.create_img(query, 0) if ok: - reply = Reply(ReplyType.IMAGE_URL, retstring) + reply = Reply(ReplyType.IMAGE_URL, res) else: - reply = Reply(ReplyType.ERROR, retstring) + reply = Reply(ReplyType.ERROR, res) return reply else: reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type)) return reply - def _chat(self, query, context, retry_count=0): + def _chat(self, query, context, retry_count=0) -> Reply: + """ + 发起对话请求 + :param query: 请求提示词 + :param context: 对话上下文 + :param retry_count: 当前递归重试次数 + :return: 回复 + """ if retry_count >= 2: # exit from retry 2 times logger.warn("[LINKAI] failed after maximum number of retry times") @@ -63,10 +69,8 @@ class LinkAIBot(Bot, OpenAIImage): if app_code and session.messages[0].get("role") == "system": session.messages.pop(0) - logger.info(f"[LINKAI] query={query}, app_code={app_code}") - body = { - "appCode": app_code, + "app_code": app_code, "messages": session.messages, "model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称 "temperature": conf().get("temperature"), @@ -74,31 +78,34 @@ class LinkAIBot(Bot, OpenAIImage): "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 } + logger.info(f"[LINKAI] query={query}, app_code={app_code}, mode={body.get('model')}") headers = {"Authorization": "Bearer " + linkai_api_key} # do http request - res = requests.post(url=self.base_url + "/chat/completion", json=body, headers=headers).json() - - if not res or not res["success"]: - if res.get("code") == self.AUTH_FAILED_CODE: - logger.exception(f"[LINKAI] please check your linkai_api_key, res={res}") - return Reply(ReplyType.ERROR, "请再问我一次吧") + res = requests.post(url=self.base_url + "/chat/completions", json=body, headers=headers, + timeout=conf().get("request_timeout", 180)) + if res.status_code == 200: + # execute success + response = res.json() + reply_content = response["choices"][0]["message"]["content"] + total_tokens = response["usage"]["total_tokens"] + logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}") + self.sessions.session_reply(reply_content, session_id, total_tokens) + return Reply(ReplyType.TEXT, reply_content) - elif res.get("code") == self.NO_QUOTA_CODE: - logger.exception(f"[LINKAI] please check your account quota, https://chat.link-ai.tech/console/account") - return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧") + else: + response = res.json() + error = response.get("error") + logger.error(f"[LINKAI] chat failed, status_code={res.status_code}, " + f"msg={error.get('message')}, type={error.get('type')}") - else: - # retry + if res.status_code >= 500: + # server error, need retry time.sleep(2) logger.warn(f"[LINKAI] do retry, times={retry_count}") return self._chat(query, context, retry_count + 1) - # execute success - reply_content = res["data"]["content"] - logger.info(f"[LINKAI] reply={reply_content}") - self.sessions.session_reply(reply_content, session_id) - return Reply(ReplyType.TEXT, reply_content) + return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧") except Exception as e: logger.exception(e) From d689d204822b99befd97426ed580619a95b0e79c Mon Sep 17 00:00:00 2001 From: zhayujie Date: Mon, 31 Jul 2023 17:52:05 +0800 Subject: [PATCH 16/23] docs: update README.md --- README.md | 2 +- plugins/config.json.template | 14 +++++++++ plugins/linkai/README.md | 58 ++++++++++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 5179827..083c1d2 100644 --- a/README.md +++ b/README.md @@ -111,7 +111,7 @@ pip3 install azure-cognitiveservices-speech { "open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY "model": "gpt-3.5-turbo", # 模型名称。当use_azure_chatgpt为true时,其名称为Azure上model deployment名称 - "proxy": "", # 代理客户端的ip和端口,国内网络环境需要填该项,如 "127.0.0.1:7890" + "proxy": "", # 代理客户端的ip和端口,国内环境开启代理的需要填写该项,如 "127.0.0.1:7890" "single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复 "single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人 "group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复 diff --git a/plugins/config.json.template b/plugins/config.json.template index 5c2b19b..3334a62 100644 --- a/plugins/config.json.template +++ b/plugins/config.json.template @@ -20,5 +20,19 @@ "no_default": false, "model_name": "gpt-3.5-turbo" } + }, + "linkai": { + "group_app_map": { + "测试群1": "default", + "测试群2": "Kv2fXJcH" + }, + "midjourney": { + "enabled": true, + "auto_translate": true, + "img_proxy": true, + "max_tasks": 3, + "max_tasks_per_user": 1, + "use_image_create_prefix": true + } } } diff --git a/plugins/linkai/README.md b/plugins/linkai/README.md index e69de29..02dfae2 100644 --- a/plugins/linkai/README.md +++ b/plugins/linkai/README.md @@ -0,0 +1,58 @@ +## 插件说明 + +基于 LinkAI 提供的知识库、Midjourney绘画等能力对机器人的功能进行增强。地址: https://chat.link-ai.tech/console + +## 插件配置 + +将 `plugins/linkai` 下的 `config.json.template` 复制为 `config.json`。如果是`docker`部署,可通过映射 plugins/config.json 来完成配置。以下是配置项说明: + +```bash +{ + "group_app_map": { # 群聊 和 应用编码 的映射关系 + "测试群1": "default", # 表示在名称为 "测试群1" 的群聊中将使用app_code 为 default 的应用 + "测试群2": "Kv2fXJcH" + }, + "midjourney": { + "enabled": true, # midjourney 绘画开关 + "auto_translate": true, # 是否自动将提示词翻译为英文 + "img_proxy": true, # 是否对生成的图片使用代理,如果你是国外服务器,将这一项设置为false会获得更快的生成速度 + "max_tasks": 3, # 支持同时提交的总任务个数 + "max_tasks_per_user": 1, # 支持单个用户同时提交的任务个数 + "use_image_create_prefix": true # 是否使用全局的绘画触发词,如果开启将同时支持由`config.json`中的 image_create_prefix 配置触发 + } +} + +``` +注意:实际 `config.json` 配置中应保证json格式,不应携带 '#' 及后面的注释。 + +## 插件使用 + +> 使用插件中的知识库管理功能需要首先开启`linkai`对话,依赖于全局 `config.json` 中的 `use_linkai` 和 `linkai_api_key` 配置;midjourney绘画功能则只需填写 `linkai_api_key` 配置。 + +完成配置后运行项目,会自动运行插件,输入 `#help linkai` 可查看插件功能。 + +### 1.知识库管理功能 + +提供在不同群聊使用不同应用的功能。可以在上述 `group_app_map` 配置中固定映射关系,也可以通过指令在群中快速完成切换。 + +应用切换指令需要首先完成管理员 (`godcmd`) 插件的认证,然后按以下格式输入: + +`$linkai app {app_code}` + +例如输入 `$linkai app Kv2fXJcH`,即将当前群聊与 app_code为 Kv2fXJcH 的应用绑定。 + +### 2.Midjourney绘画功能 + +指令格式: + +``` + - 图片生成: $mj 描述词1, 描述词2.. + - 图片放大: $mju 图片ID 图片序号 +``` + +例如: + +``` +"$mj a little cat, white --ar 9:16" +"$mju 1105592717188272288 2" +``` From ca916b7ce52fe07d00e40c3af1ffb979b4e1df82 Mon Sep 17 00:00:00 2001 From: zhayujie Date: Mon, 31 Jul 2023 21:40:50 +0800 Subject: [PATCH 17/23] fix: default to fast mode --- plugins/linkai/midjourney.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/linkai/midjourney.py b/plugins/linkai/midjourney.py index 06fe9b3..506a8bc 100644 --- a/plugins/linkai/midjourney.py +++ b/plugins/linkai/midjourney.py @@ -335,7 +335,7 @@ class MJBot: mode = self.config.get("mode") if "--relax" in prompt or mode == TaskMode.RELAX.value: return TaskMode.RELAX.value - return mode or TaskMode.RELAX.value + return mode or TaskMode.FAST.value def _run_loop(self, loop: asyncio.BaseEventLoop): """ From 68208f82a0ef94398912cb1da1b407b0c65fb456 Mon Sep 17 00:00:00 2001 From: zhayujie Date: Tue, 1 Aug 2023 00:08:39 +0800 Subject: [PATCH 18/23] docs: update README.md --- plugins/linkai/README.md | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/plugins/linkai/README.md b/plugins/linkai/README.md index 02dfae2..b2f806d 100644 --- a/plugins/linkai/README.md +++ b/plugins/linkai/README.md @@ -1,10 +1,12 @@ ## 插件说明 -基于 LinkAI 提供的知识库、Midjourney绘画等能力对机器人的功能进行增强。地址: https://chat.link-ai.tech/console +基于 LinkAI 提供的知识库、Midjourney绘画等能力对机器人的功能进行增强。平台地址: https://chat.link-ai.tech/console ## 插件配置 -将 `plugins/linkai` 下的 `config.json.template` 复制为 `config.json`。如果是`docker`部署,可通过映射 plugins/config.json 来完成配置。以下是配置项说明: +将 `plugins/linkai` 目录下的 `config.json.template` 配置模板复制为最终生效的 `config.json`: + +以下是配置项说明: ```bash { @@ -23,11 +25,15 @@ } ``` -注意:实际 `config.json` 配置中应保证json格式,不应携带 '#' 及后面的注释。 +注意: + + - 配置项中 `group_app_map` 部分是用于映射群聊与LinkAI平台上的应用, `midjourney` 部分是 mj 画图的配置,可根据需要进行填写,未填写配置时默认不开启相应功能 + - 实际 `config.json` 配置中应保证json格式,不应携带 '#' 及后面的注释 + - 如果是`docker`部署,可通过映射 `plugins/config.json` 到容器中来完成插件配置,参考[文档](https://github.com/zhayujie/chatgpt-on-wechat#3-%E6%8F%92%E4%BB%B6%E4%BD%BF%E7%94%A8) ## 插件使用 -> 使用插件中的知识库管理功能需要首先开启`linkai`对话,依赖于全局 `config.json` 中的 `use_linkai` 和 `linkai_api_key` 配置;midjourney绘画功能则只需填写 `linkai_api_key` 配置。 +> 使用插件中的知识库管理功能需要首先开启`linkai`对话,依赖全局 `config.json` 中的 `use_linkai` 和 `linkai_api_key` 配置;而midjourney绘画功能则只需填写 `linkai_api_key` 配置。具体可参考 [详细文档](https://link-ai.tech/platform/link-app/wechat)。 完成配置后运行项目,会自动运行插件,输入 `#help linkai` 可查看插件功能。 @@ -56,3 +62,5 @@ "$mj a little cat, white --ar 9:16" "$mju 1105592717188272288 2" ``` + +注:开启 `use_image_create_prefix` 配置后可直接复用全局画图触发词,以"画"开头便可以生成图片。 \ No newline at end of file From 2386eb8fc2996bd876a0f35d471d669d87460425 Mon Sep 17 00:00:00 2001 From: zhayujie Date: Sun, 6 Aug 2023 15:44:48 +0800 Subject: [PATCH 19/23] fix: unable to use plugin when group nickname is set --- channel/chat_channel.py | 8 ++++++-- channel/chat_message.py | 5 ++--- channel/wechat/wechat_message.py | 6 +++++- plugins/linkai/README.md | 2 +- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/channel/chat_channel.py b/channel/chat_channel.py index ffb49c8..911170d 100644 --- a/channel/chat_channel.py +++ b/channel/chat_channel.py @@ -108,8 +108,12 @@ class ChatChannel(Channel): if not conf().get("group_at_off", False): flag = True pattern = f"@{re.escape(self.name)}(\u2005|\u0020)" - content = re.sub(pattern, r"", content) - + subtract_res = re.sub(pattern, r"", content) + if subtract_res == content and context["msg"].self_display_name: + # 前缀移除后没有变化,使用群昵称再次移除 + pattern = f"@{re.escape(context['msg'].self_display_name)}(\u2005|\u0020)" + subtract_res = re.sub(pattern, r"", content) + content = subtract_res if not flag: if context["origin_ctype"] == ContextType.VOICE: logger.info("[WX]receive group voice, but checkprefix didn't match") diff --git a/channel/chat_message.py b/channel/chat_message.py index 0e2f652..c1b025d 100644 --- a/channel/chat_message.py +++ b/channel/chat_message.py @@ -24,9 +24,7 @@ is_at: 是否被at - (群消息时,一般会存在实际发送者,是群内某个成员的id和昵称,下列项仅在群消息时存在) actual_user_id: 实际发送者id (群聊必填) actual_user_nickname:实际发送者昵称 - - - +self_display_name: 自身的展示名,设置群昵称时,该字段表示群昵称 _prepare_fn: 准备函数,用于准备消息的内容,比如下载图片等, _prepared: 是否已经调用过准备函数 @@ -49,6 +47,7 @@ class ChatMessage(object): other_user_id = None other_user_nickname = None my_msg = False + self_display_name = None is_group = False is_at = False diff --git a/channel/wechat/wechat_message.py b/channel/wechat/wechat_message.py index b9824f9..7c71a1e 100644 --- a/channel/wechat/wechat_message.py +++ b/channel/wechat/wechat_message.py @@ -57,7 +57,8 @@ class WechatMessage(ChatMessage): self.from_user_nickname = nickname if self.to_user_id == user_id: self.to_user_nickname = nickname - try: # 陌生人时候, 'User'字段可能不存在 + try: # 陌生人时候, User字段可能不存在 + # my_msg 为True是表示是自己发送的消息 self.my_msg = itchat_msg["ToUserName"] == itchat_msg["User"]["UserName"] and \ itchat_msg["ToUserName"] != itchat_msg["FromUserName"] self.other_user_id = itchat_msg["User"]["UserName"] @@ -66,6 +67,9 @@ class WechatMessage(ChatMessage): self.from_user_nickname = self.other_user_nickname if self.other_user_id == self.to_user_id: self.to_user_nickname = self.other_user_nickname + if itchat_msg["User"].get("Self"): + # 自身的展示名,当设置了群昵称时,该字段表示群昵称 + self.self_display_name = itchat_msg["User"].get("Self").get("DisplayName") except KeyError as e: # 处理偶尔没有对方信息的情况 logger.warn("[WX]get other_user_id failed: " + str(e)) if self.from_user_id == user_id: diff --git a/plugins/linkai/README.md b/plugins/linkai/README.md index b2f806d..3397e0f 100644 --- a/plugins/linkai/README.md +++ b/plugins/linkai/README.md @@ -33,7 +33,7 @@ ## 插件使用 -> 使用插件中的知识库管理功能需要首先开启`linkai`对话,依赖全局 `config.json` 中的 `use_linkai` 和 `linkai_api_key` 配置;而midjourney绘画功能则只需填写 `linkai_api_key` 配置。具体可参考 [详细文档](https://link-ai.tech/platform/link-app/wechat)。 +> 使用插件中的知识库管理功能需要首先开启`linkai`对话,依赖全局 `config.json` 中的 `use_linkai` 和 `linkai_api_key` 配置;而midjourney绘画功能则只需填写 `linkai_api_key` 配置,`use_linkai` 无论是否关闭均可使用。具体可参考 [详细文档](https://link-ai.tech/platform/link-app/wechat)。 完成配置后运行项目,会自动运行插件,输入 `#help linkai` 可查看插件功能。 From 395edbd9f47e948543fdeabd95e4c7a6632246b9 Mon Sep 17 00:00:00 2001 From: zhayujie Date: Sun, 6 Aug 2023 16:02:02 +0800 Subject: [PATCH 20/23] fix: only filter messages sent by the bot itself in private chat --- channel/wechat/wechat_channel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index c1ceacd..34bb9b8 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -58,7 +58,7 @@ def _check(func): if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息 logger.debug("[WX]history message {} skipped".format(msgId)) return - if cmsg.my_msg: + if cmsg.my_msg and not cmsg.is_group: logger.debug("[WX]my message {} skipped".format(msgId)) return return func(self, cmsg) From 8abf18ab2561222281f6a0c920557ed91eadf5e7 Mon Sep 17 00:00:00 2001 From: zhayujie Date: Sun, 6 Aug 2023 17:57:07 +0800 Subject: [PATCH 21/23] feat: add knowledge base and midjourney switch instruction --- bridge/bridge.py | 6 +++++ plugins/linkai/linkai.py | 50 ++++++++++++++++++++++++++---------- plugins/linkai/midjourney.py | 19 +++++++++++++- 3 files changed, 60 insertions(+), 15 deletions(-) diff --git a/bridge/bridge.py b/bridge/bridge.py index d3fbd95..524cb8c 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -54,3 +54,9 @@ class Bridge(object): def fetch_translate(self, text, from_lang="", to_lang="en") -> Reply: return self.get_bot("translate").translate(text, from_lang, to_lang) + + def reset_bot(self): + """ + 重置bot路由 + """ + self.__init__() diff --git a/plugins/linkai/linkai.py b/plugins/linkai/linkai.py index 1031a24..9f21e60 100644 --- a/plugins/linkai/linkai.py +++ b/plugins/linkai/linkai.py @@ -1,19 +1,10 @@ -import asyncio -import json -import threading -from concurrent.futures import ThreadPoolExecutor - import plugins from bridge.context import ContextType from bridge.reply import Reply, ReplyType -from channel.chat_message import ChatMessage -from common.log import logger -from config import conf, global_config +from config import global_config from plugins import * -from .midjourney import MJBot, TaskType - -# 任务线程池 -task_thread_pool = ThreadPoolExecutor(max_workers=4) +from .midjourney import MJBot +from bridge import bridge @plugins.register( @@ -66,11 +57,28 @@ class LinkAI(Plugin): if len(cmd) == 1 or (len(cmd) == 2 and cmd[1] == "help"): _set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO) return + + if len(cmd) == 2 and (cmd[1] == "open" or cmd[1] == "close"): + # 知识库开关指令 + if not _is_admin(e_context): + _set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR) + return + is_open = True + tips_text = "开启" + if cmd[1] == "close": + tips_text = "关闭" + is_open = False + conf()["use_linkai"] = is_open + bridge.Bridge().reset_bot() + _set_reply_text(f"知识库功能已{tips_text}", e_context, level=ReplyType.INFO) + return + if len(cmd) == 3 and cmd[1] == "app": + # 知识库应用切换指令 if not context.kwargs.get("isgroup"): _set_reply_text("该指令需在群聊中使用", e_context, level=ReplyType.ERROR) return - if context.kwargs.get("msg").actual_user_id not in global_config["admin_users"]: + if not _is_admin(e_context): _set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR) return app_code = cmd[2] @@ -84,7 +92,8 @@ class LinkAI(Plugin): super().save_config(self.config) _set_reply_text(f"应用设置成功: {app_code}", e_context, level=ReplyType.INFO) else: - _set_reply_text(f"指令错误,请输入{_get_trigger_prefix()}linkai help 获取帮助", e_context, level=ReplyType.INFO) + _set_reply_text(f"指令错误,请输入{_get_trigger_prefix()}linkai help 获取帮助", e_context, + level=ReplyType.INFO) return # LinkAI 对话任务处理 @@ -127,6 +136,19 @@ class LinkAI(Plugin): # 静态方法 +def _is_admin(e_context: EventContext) -> bool: + """ + 判断消息是否由管理员用户发送 + :param e_context: 消息上下文 + :return: True: 是, False: 否 + """ + context = e_context["context"] + if context["isgroup"]: + return context.kwargs.get("msg").actual_user_id in global_config["admin_users"] + else: + return context["receiver"] in global_config["admin_users"] + + def _set_reply_text(content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR): reply = Reply(level, content) e_context["reply"] = reply diff --git a/plugins/linkai/midjourney.py b/plugins/linkai/midjourney.py index 506a8bc..9512db7 100644 --- a/plugins/linkai/midjourney.py +++ b/plugins/linkai/midjourney.py @@ -69,7 +69,7 @@ class MJBot: :param e_context: 上下文 :return: 任务类型枚举 """ - if not self.config or not self.config.get("enabled"): + if not self.config: return None trigger_prefix = conf().get("plugin_trigger_prefix", "$") context = e_context['context'] @@ -92,9 +92,26 @@ class MJBot: session_id = context["session_id"] cmd = context.content.split(maxsplit=1) if len(cmd) == 1 and context.type == ContextType.TEXT: + # midjourney 帮助指令 self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO) return + if len(cmd) == 2 and (cmd[1] == "open" or cmd[1] == "close"): + # midjourney 开关指令 + is_open = True + tips_text = "开启" + if cmd[1] == "close": + tips_text = "关闭" + is_open = False + self.config["enabled"] = is_open + self._set_reply_text(f"Midjourney绘画已{tips_text}", e_context, level=ReplyType.INFO) + return + + if not self.config.get("enabled"): + logger.warn("Midjourney绘画未开启,请查看 plugins/linkai/config.json 中的配置") + self._set_reply_text(f"Midjourney绘画未开启", e_context, level=ReplyType.INFO) + return + if not self._check_rate_limit(session_id, e_context): logger.warn("[MJ] midjourney task exceed rate limit") return From 5176b56d3b14ee6eda412c40559cec590c028666 Mon Sep 17 00:00:00 2001 From: zhayujie Date: Mon, 7 Aug 2023 14:42:24 +0800 Subject: [PATCH 22/23] fix: global plugin read encoding --- plugins/plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/plugin.py b/plugins/plugin.py index 2bb6c26..2e3e465 100644 --- a/plugins/plugin.py +++ b/plugins/plugin.py @@ -19,7 +19,7 @@ class Plugin: # 全局配置不存在 或者 未开启全局配置开关,则获取插件目录下的配置 plugin_config_path = os.path.join(self.path, "config.json") if os.path.exists(plugin_config_path): - with open(plugin_config_path, "r") as f: + with open(plugin_config_path, "r", encoding="utf-8") as f: plugin_conf = json.load(f) logger.debug(f"loading plugin config, plugin_name={self.name}, conf={plugin_conf}") return plugin_conf From 6b247ae880ac3e1210243f1e4d58bc9de5e431eb Mon Sep 17 00:00:00 2001 From: zhayujie Date: Mon, 7 Aug 2023 19:14:09 +0800 Subject: [PATCH 23/23] feat: add midjourney variation and reset --- plugins/linkai/linkai.py | 4 +- plugins/linkai/midjourney.py | 115 +++++++++++++++++++---------------- 2 files changed, 64 insertions(+), 55 deletions(-) diff --git a/plugins/linkai/linkai.py b/plugins/linkai/linkai.py index 9f21e60..a71b3b1 100644 --- a/plugins/linkai/linkai.py +++ b/plugins/linkai/linkai.py @@ -131,7 +131,9 @@ class LinkAI(Plugin): if not verbose: return help_text help_text += f'📖 知识库\n - 群聊中指定应用: {trigger_prefix}linkai app 应用编码\n\n例如: \n"$linkai app Kv2fXJcH"\n\n' - help_text += f"🎨 绘画\n - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\"" + help_text += f"🎨 绘画\n - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n - 变换: mjv 图片ID 图片序号\n - 重置: mjr 图片ID" + help_text += f"\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 11055927171882 2\"" + help_text += f"\n\"{trigger_prefix}mjv 11055927171882 2\"\n\"{trigger_prefix}mjr 11055927171882\"" return help_text diff --git a/plugins/linkai/midjourney.py b/plugins/linkai/midjourney.py index 9512db7..735195a 100644 --- a/plugins/linkai/midjourney.py +++ b/plugins/linkai/midjourney.py @@ -11,6 +11,9 @@ from bridge.context import ContextType from plugins import EventContext, EventAction INVALID_REQUEST = 410 +NOT_FOUND_ORIGIN_IMAGE = 461 +NOT_FOUND_TASK = 462 + class TaskType(Enum): GENERATE = "generate" @@ -18,6 +21,9 @@ class TaskType(Enum): VARIATION = "variation" RESET = "reset" + def __str__(self): + return self.name + class Status(Enum): PENDING = "pending" @@ -34,6 +40,14 @@ class TaskMode(Enum): RELAX = "relax" +task_name_mapping = { + TaskType.GENERATE.name: "生成", + TaskType.UPSCALE.name: "放大", + TaskType.VARIATION.name: "变换", + TaskType.RESET.name: "重新生成", +} + + class MJTask: def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int = 60 * 30, status=Status.PENDING): @@ -79,6 +93,10 @@ class MJBot: return TaskType.GENERATE elif cmd_list[0].lower() == f"{trigger_prefix}mju": return TaskType.UPSCALE + elif cmd_list[0].lower() == f"{trigger_prefix}mjv": + return TaskType.VARIATION + elif cmd_list[0].lower() == f"{trigger_prefix}mjr": + return TaskType.RESET elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix"): return TaskType.GENERATE @@ -127,8 +145,8 @@ class MJBot: e_context.action = EventAction.BREAK_PASS return - elif mj_type == TaskType.UPSCALE: - # 图片放大 + elif mj_type == TaskType.UPSCALE or mj_type == TaskType.VARIATION: + # 图片放大/变换 clist = cmd[1].split() if len(clist) < 2: self._set_reply_text(f"{cmd[0]} 命令缺少参数", e_context) @@ -138,16 +156,27 @@ class MJBot: if index < 1 or index > 4: self._set_reply_text(f"图片序号 {index} 错误,应在 1 至 4 之间", e_context) return - key = f"{TaskType.UPSCALE.name}_{img_id}_{index}" + key = f"{str(mj_type)}_{img_id}_{index}" if self.temp_dict.get(key): - self._set_reply_text(f"第 {index} 张图片已经放大过了", e_context) + self._set_reply_text(f"第 {index} 张图片已经{task_name_mapping.get(str(mj_type))}过了", e_context) return - # 图片放大操作 - reply = self.upscale(session_id, img_id, index, e_context) + # 执行图片放大/变换操作 + reply = self.do_operate(mj_type, session_id, img_id, e_context, index) e_context['reply'] = reply e_context.action = EventAction.BREAK_PASS return + elif mj_type == TaskType.RESET: + # 图片重新生成 + clist = cmd[1].split() + if len(clist) < 1: + self._set_reply_text(f"{cmd[0]} 命令缺少参数", e_context) + return + img_id = clist[0] + # 图片重新生成 + reply = self.do_operate(mj_type, session_id, img_id, e_context) + e_context['reply'] = reply + e_context.action = EventAction.BREAK_PASS else: self._set_reply_text(f"暂不支持该命令", e_context) @@ -197,9 +226,12 @@ class MJBot: reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试") return reply - def upscale(self, user_id: str, img_id: str, index: int, e_context: EventContext) -> Reply: - logger.info(f"[MJ] image upscale, img_id={img_id}, index={index}") - body = {"type": TaskType.UPSCALE.name, "img_id": img_id, "index": index} + def do_operate(self, task_type: TaskType, user_id: str, img_id: str, e_context: EventContext, + index: int = None) -> Reply: + logger.info(f"[MJ] image operate, task_type={task_type}, img_id={img_id}, index={index}") + body = {"type": task_type.name, "img_id": img_id} + if index: + body["index"] = index if not self.config.get("img_proxy"): body["img_proxy"] = False res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5, 40)) @@ -208,23 +240,24 @@ class MJBot: res = res.json() if res.get("code") == 200: task_id = res.get("data").get("task_id") - logger.info(f"[MJ] image upscale processing, task_id={task_id}") - content = f"🔎图片正在放大中,请耐心等待" + logger.info(f"[MJ] image operate processing, task_id={task_id}") + icon_map = {TaskType.UPSCALE: "🔎", TaskType.VARIATION: "🪄", TaskType.RESET: "🔄"} + content = f"{icon_map.get(task_type)}图片正在{task_name_mapping.get(task_type.name)}中,请耐心等待" reply = Reply(ReplyType.INFO, content) - task = MJTask(id=task_id, status=Status.PENDING, user_id=user_id, task_type=TaskType.UPSCALE) + task = MJTask(id=task_id, status=Status.PENDING, user_id=user_id, task_type=task_type) # put to memory dict self.tasks[task.id] = task - key = f"{TaskType.UPSCALE.name}_{img_id}_{index}" + key = f"{task_type.name}_{img_id}_{index}" self.temp_dict[key] = True # asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop) self._do_check_task(task, e_context) return reply else: error_msg = "" - if res.status_code == 461: + if res.status_code == NOT_FOUND_ORIGIN_IMAGE: error_msg = "请输入正确的图片ID" res_json = res.json() - logger.error(f"[MJ] upscale error, msg={res_json.get('message')}, status_code={res.status_code}") + logger.error(f"[MJ] operate error, msg={res_json.get('message')}, status_code={res.status_code}") reply = Reply(ReplyType.ERROR, error_msg or "图片生成失败,请稍后再试") return reply @@ -258,40 +291,6 @@ class MJBot: if self.tasks.get(task.id): self.tasks[task.id].status = Status.EXPIRED - async def check_task_async(self, task: MJTask, e_context: EventContext): - try: - logger.debug(f"[MJ] start check task status, {task}") - max_retry_times = 90 - while max_retry_times > 0: - await asyncio.sleep(10) - async with aiohttp.ClientSession() as session: - url = f"{self.base_url}/tasks/{task.id}" - try: - async with session.get(url, headers=self.headers) as res: - if res.status == 200: - res_json = await res.json() - logger.debug(f"[MJ] task check res, task_id={task.id}, status={res.status}, " - f"data={res_json.get('data')}, thread={threading.current_thread().name}") - if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name: - # process success res - if self.tasks.get(task.id): - self.tasks[task.id].status = Status.FINISHED - self._process_success_task(task, res_json.get("data"), e_context) - return - else: - res_json = await res.json() - logger.warn(f"[MJ] image check error, status_code={res.status}, res={res_json}") - max_retry_times -= 20 - except Exception as e: - max_retry_times -= 20 - logger.warn(e) - max_retry_times -= 1 - logger.warn("[MJ] end from poll") - if self.tasks.get(task.id): - self.tasks[task.id].status = Status.EXPIRED - except Exception as e: - logger.error(e) - def _do_check_task(self, task: MJTask, e_context: EventContext): threading.Thread(target=self.check_task_sync, args=(task, e_context)).start() @@ -316,10 +315,17 @@ class MJBot: # send info trigger_prefix = conf().get("plugin_trigger_prefix", "$") text = "" - if task.task_type == TaskType.GENERATE: - text = f"🎨绘画完成!\nprompt: {task.raw_prompt}\n- - - - - - - - -\n图片ID: {task.img_id}" - text += f"\n\n🔎可使用 {trigger_prefix}mju 命令放大指定图片\n" + if task.task_type == TaskType.GENERATE or task.task_type == TaskType.VARIATION or task.task_type == TaskType.RESET: + text = f"🎨绘画完成!\n" + if task.raw_prompt: + text += f"prompt: {task.raw_prompt}\n" + text += f"- - - - - - - - -\n图片ID: {task.img_id}" + text += f"\n\n🔎使用 {trigger_prefix}mju 命令放大图片\n" text += f"例如:\n{trigger_prefix}mju {task.img_id} 1" + text += f"\n\n🪄使用 {trigger_prefix}mjv 命令变换图片\n" + text += f"例如:\n{trigger_prefix}mjv {task.img_id} 1" + text += f"\n\n🔄使用 {trigger_prefix}mjr 命令重新生成图片\n" + text += f"例如:\n{trigger_prefix}mjr {task.img_id}" reply = Reply(ReplyType.INFO, text) channel._send(reply, e_context["context"]) @@ -382,8 +388,9 @@ class MJBot: help_text = "🎨利用Midjourney进行画图\n\n" if not verbose: return help_text - help_text += f" - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 1105592717188272288 2\"" - + help_text += f" - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n - 变换: mjv 图片ID 图片序号\n - 重置: mjr 图片ID" + help_text += f"\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 11055927171882 2\"" + help_text += f"\n\"{trigger_prefix}mjv 11055927171882 2\"\n\"{trigger_prefix}mjr 11055927171882\"" return help_text def find_tasks_by_user_id(self, user_id) -> list: