@@ -86,14 +86,14 @@ class ChatChannel(Channel): | |||
if e_context.is_pass() or context is None: | |||
return context | |||
if cmsg.from_user_id == self.user_id and not config.get("trigger_by_self", True): | |||
logger.debug("[WX]self message skipped") | |||
logger.debug("[chat_channel]self message skipped") | |||
return None | |||
# 消息内容匹配过程,并处理content | |||
if ctype == ContextType.TEXT: | |||
if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息 | |||
logger.debug(content) | |||
logger.debug("[WX]reference query skipped") | |||
logger.debug("[chat_channel]reference query skipped") | |||
return None | |||
nick_name_black_list = conf().get("nick_name_black_list", []) | |||
@@ -111,10 +111,10 @@ class ChatChannel(Channel): | |||
nick_name = context["msg"].actual_user_nickname | |||
if nick_name and nick_name in nick_name_black_list: | |||
# 黑名单过滤 | |||
logger.warning(f"[WX] Nickname {nick_name} in In BlackList, ignore") | |||
logger.warning(f"[chat_channel] Nickname {nick_name} in In BlackList, ignore") | |||
return None | |||
logger.info("[WX]receive group at") | |||
logger.info("[chat_channel]receive group at") | |||
if not conf().get("group_at_off", False): | |||
flag = True | |||
pattern = f"@{re.escape(self.name)}(\u2005|\u0020)" | |||
@@ -130,13 +130,13 @@ class ChatChannel(Channel): | |||
content = subtract_res | |||
if not flag: | |||
if context["origin_ctype"] == ContextType.VOICE: | |||
logger.info("[WX]receive group voice, but checkprefix didn't match") | |||
logger.info("[chat_channel]receive group voice, but checkprefix didn't match") | |||
return None | |||
else: # 单聊 | |||
nick_name = context["msg"].from_user_nickname | |||
if nick_name and nick_name in nick_name_black_list: | |||
# 黑名单过滤 | |||
logger.warning(f"[WX] Nickname '{nick_name}' in In BlackList, ignore") | |||
logger.warning(f"[chat_channel] Nickname '{nick_name}' in In BlackList, ignore") | |||
return None | |||
match_prefix = check_prefix(content, conf().get("single_chat_prefix", [""])) | |||
@@ -147,7 +147,7 @@ class ChatChannel(Channel): | |||
else: | |||
return None | |||
content = content.strip() | |||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix")) | |||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix",[""])) | |||
if img_match_prefix: | |||
content = content.replace(img_match_prefix, "", 1) | |||
context.type = ContextType.IMAGE_CREATE | |||
@@ -159,17 +159,16 @@ class ChatChannel(Channel): | |||
elif context.type == ContextType.VOICE: | |||
if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: | |||
context["desire_rtype"] = ReplyType.VOICE | |||
return context | |||
def _handle(self, context: Context): | |||
if context is None or not context.content: | |||
return | |||
logger.debug("[WX] ready to handle context: {}".format(context)) | |||
logger.debug("[chat_channel] ready to handle context: {}".format(context)) | |||
# reply的构建步骤 | |||
reply = self._generate_reply(context) | |||
logger.debug("[WX] ready to decorate reply: {}".format(reply)) | |||
logger.debug("[chat_channel] ready to decorate reply: {}".format(reply)) | |||
# reply的包装步骤 | |||
if reply and reply.content: | |||
@@ -187,7 +186,7 @@ class ChatChannel(Channel): | |||
) | |||
reply = e_context["reply"] | |||
if not e_context.is_pass(): | |||
logger.debug("[WX] ready to handle context: type={}, content={}".format(context.type, context.content)) | |||
logger.debug("[chat_channel] ready to handle context: type={}, content={}".format(context.type, context.content)) | |||
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息 | |||
context["channel"] = e_context["channel"] | |||
reply = super().build_reply_content(context.content, context) | |||
@@ -199,7 +198,7 @@ class ChatChannel(Channel): | |||
try: | |||
any_to_wav(file_path, wav_path) | |||
except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别 | |||
logger.warning("[WX]any to wav error, use raw path. " + str(e)) | |||
logger.warning("[chat_channel]any to wav error, use raw path. " + str(e)) | |||
wav_path = file_path | |||
# 语音识别 | |||
reply = super().build_voice_to_text(wav_path) | |||
@@ -210,7 +209,7 @@ class ChatChannel(Channel): | |||
os.remove(wav_path) | |||
except Exception as e: | |||
pass | |||
# logger.warning("[WX]delete temp file error: " + str(e)) | |||
# logger.warning("[chat_channel]delete temp file error: " + str(e)) | |||
if reply.type == ReplyType.TEXT: | |||
new_context = self._compose_context(ContextType.TEXT, reply.content, **context.kwargs) | |||
@@ -228,7 +227,7 @@ class ChatChannel(Channel): | |||
elif context.type == ContextType.FUNCTION or context.type == ContextType.FILE: # 文件消息及函数调用等,当前无默认逻辑 | |||
pass | |||
else: | |||
logger.warning("[WX] unknown context type: {}".format(context.type)) | |||
logger.warning("[chat_channel] unknown context type: {}".format(context.type)) | |||
return | |||
return reply | |||
@@ -244,7 +243,7 @@ class ChatChannel(Channel): | |||
desire_rtype = context.get("desire_rtype") | |||
if not e_context.is_pass() and reply and reply.type: | |||
if reply.type in self.NOT_SUPPORT_REPLYTYPE: | |||
logger.error("[WX]reply type not support: " + str(reply.type)) | |||
logger.error("[chat_channel]reply type not support: " + str(reply.type)) | |||
reply.type = ReplyType.ERROR | |||
reply.content = "不支持发送的消息类型: " + str(reply.type) | |||
@@ -265,10 +264,10 @@ class ChatChannel(Channel): | |||
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE or reply.type == ReplyType.FILE or reply.type == ReplyType.VIDEO or reply.type == ReplyType.VIDEO_URL: | |||
pass | |||
else: | |||
logger.error("[WX] unknown reply type: {}".format(reply.type)) | |||
logger.error("[chat_channel] unknown reply type: {}".format(reply.type)) | |||
return | |||
if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]: | |||
logger.warning("[WX] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type)) | |||
logger.warning("[chat_channel] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type)) | |||
return reply | |||
def _send_reply(self, context: Context, reply: Reply): | |||
@@ -281,14 +280,14 @@ class ChatChannel(Channel): | |||
) | |||
reply = e_context["reply"] | |||
if not e_context.is_pass() and reply and reply.type: | |||
logger.debug("[WX] ready to send reply: {}, context: {}".format(reply, context)) | |||
logger.debug("[chat_channel] ready to send reply: {}, context: {}".format(reply, context)) | |||
self._send(reply, context) | |||
def _send(self, reply: Reply, context: Context, retry_cnt=0): | |||
try: | |||
self.send(reply, context) | |||
except Exception as e: | |||
logger.error("[WX] sendMsg error: {}".format(str(e))) | |||
logger.error("[chat_channel] sendMsg error: {}".format(str(e))) | |||
if isinstance(e, NotImplementedError): | |||
return | |||
logger.exception(e) | |||
@@ -342,7 +341,7 @@ class ChatChannel(Channel): | |||
if semaphore.acquire(blocking=False): # 等线程处理完毕才能删除 | |||
if not context_queue.empty(): | |||
context = context_queue.get() | |||
logger.debug("[WX] consume context: {}".format(context)) | |||
logger.debug("[chat_channel] consume context: {}".format(context)) | |||
future: Future = handler_pool.submit(self._handle, context) | |||
future.add_done_callback(self._thread_pool_callback(session_id, context=context)) | |||
if session_id not in self.futures: | |||
@@ -4,20 +4,81 @@ | |||
@author huiwen | |||
@Date 2023/11/28 | |||
""" | |||
import copy | |||
import json | |||
# -*- coding=utf-8 -*- | |||
import logging | |||
import time | |||
import dingtalk_stream | |||
from dingtalk_stream import AckMessage | |||
from dingtalk_stream.card_replier import AICardReplier | |||
from dingtalk_stream.card_replier import AICardStatus | |||
from dingtalk_stream.card_replier import CardReplier | |||
from bridge.context import Context, ContextType | |||
from bridge.reply import Reply, ReplyType | |||
from channel.chat_channel import ChatChannel | |||
from channel.dingtalk.dingtalk_message import DingTalkMessage | |||
from bridge.context import Context | |||
from bridge.reply import Reply | |||
from common.expired_dict import ExpiredDict | |||
from common.log import logger | |||
from common.singleton import singleton | |||
from common.time_check import time_checker | |||
from config import conf | |||
from common.expired_dict import ExpiredDict | |||
from bridge.context import ContextType | |||
from channel.chat_channel import ChatChannel | |||
import logging | |||
from dingtalk_stream import AckMessage | |||
import dingtalk_stream | |||
class CustomAICardReplier(CardReplier): | |||
def __init__(self, dingtalk_client, incoming_message): | |||
super(AICardReplier, self).__init__(dingtalk_client, incoming_message) | |||
def start( | |||
self, | |||
card_template_id: str, | |||
card_data: dict, | |||
recipients: list = None, | |||
support_forward: bool = True, | |||
) -> str: | |||
""" | |||
AI卡片的创建接口 | |||
:param support_forward: | |||
:param recipients: | |||
:param card_template_id: | |||
:param card_data: | |||
:return: | |||
""" | |||
card_data_with_status = copy.deepcopy(card_data) | |||
card_data_with_status["flowStatus"] = AICardStatus.PROCESSING | |||
return self.create_and_send_card( | |||
card_template_id, | |||
card_data_with_status, | |||
at_sender=True, | |||
at_all=False, | |||
recipients=recipients, | |||
support_forward=support_forward, | |||
) | |||
# 对 AICardReplier 进行猴子补丁 | |||
AICardReplier.start = CustomAICardReplier.start | |||
def _check(func): | |||
def wrapper(self, cmsg: DingTalkMessage): | |||
msgId = cmsg.msg_id | |||
if msgId in self.receivedMsgs: | |||
logger.info("DingTalk message {} already received, ignore".format(msgId)) | |||
return | |||
self.receivedMsgs[msgId] = True | |||
create_time = cmsg.create_time # 消息时间戳 | |||
if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息 | |||
logger.debug("[DingTalk] History message {} skipped".format(msgId)) | |||
return | |||
if cmsg.my_msg and not cmsg.is_group: | |||
logger.debug("[DingTalk] My message {} skipped".format(msgId)) | |||
return | |||
return func(self, cmsg) | |||
return wrapper | |||
@singleton | |||
@@ -39,11 +100,13 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler): | |||
super(dingtalk_stream.ChatbotHandler, self).__init__() | |||
self.logger = self.setup_logger() | |||
# 历史消息id暂存,用于幂等控制 | |||
self.receivedMsgs = ExpiredDict(60 * 60 * 7.1) | |||
logger.info("[dingtalk] client_id={}, client_secret={} ".format( | |||
self.receivedMsgs = ExpiredDict(conf().get("expires_in_seconds")) | |||
logger.info("[DingTalk] client_id={}, client_secret={} ".format( | |||
self.dingtalk_client_id, self.dingtalk_client_secret)) | |||
# 无需群校验和前缀 | |||
conf()["group_name_white_list"] = ["ALL_GROUP"] | |||
# 单聊无需前缀 | |||
conf()["single_chat_prefix"] = [""] | |||
def startup(self): | |||
credential = dingtalk_stream.Credential(self.dingtalk_client_id, self.dingtalk_client_secret) | |||
@@ -51,50 +114,107 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler): | |||
client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self) | |||
client.start_forever() | |||
async def process(self, callback: dingtalk_stream.CallbackMessage): | |||
try: | |||
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data) | |||
image_download_handler = self # 传入方法所在的类实例 | |||
dingtalk_msg = DingTalkMessage(incoming_message, image_download_handler) | |||
if dingtalk_msg.is_group: | |||
self.handle_group(dingtalk_msg) | |||
else: | |||
self.handle_single(dingtalk_msg) | |||
return AckMessage.STATUS_OK, 'OK' | |||
except Exception as e: | |||
logger.error(f"dingtalk process error={e}") | |||
return AckMessage.STATUS_SYSTEM_EXCEPTION, 'ERROR' | |||
@time_checker | |||
@_check | |||
def handle_single(self, cmsg: DingTalkMessage): | |||
# 处理单聊消息 | |||
if cmsg.ctype == ContextType.VOICE: | |||
logger.debug("[dingtalk]receive voice msg: {}".format(cmsg.content)) | |||
logger.debug("[DingTalk]receive voice msg: {}".format(cmsg.content)) | |||
elif cmsg.ctype == ContextType.IMAGE: | |||
logger.debug("[dingtalk]receive image msg: {}".format(cmsg.content)) | |||
logger.debug("[DingTalk]receive image msg: {}".format(cmsg.content)) | |||
elif cmsg.ctype == ContextType.IMAGE_CREATE: | |||
logger.debug("[DingTalk]receive image create msg: {}".format(cmsg.content)) | |||
elif cmsg.ctype == ContextType.PATPAT: | |||
logger.debug("[dingtalk]receive patpat msg: {}".format(cmsg.content)) | |||
logger.debug("[DingTalk]receive patpat msg: {}".format(cmsg.content)) | |||
elif cmsg.ctype == ContextType.TEXT: | |||
expression = cmsg.my_msg | |||
cmsg.content = conf()["single_chat_prefix"][0] + cmsg.content | |||
logger.debug("[DingTalk]receive text msg: {}".format(cmsg.content)) | |||
else: | |||
logger.debug("[DingTalk]receive other msg: {}".format(cmsg.content)) | |||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg) | |||
if context: | |||
self.produce(context) | |||
@time_checker | |||
@_check | |||
def handle_group(self, cmsg: DingTalkMessage): | |||
# 处理群聊消息 | |||
if cmsg.ctype == ContextType.VOICE: | |||
logger.debug("[dingtalk]receive voice msg: {}".format(cmsg.content)) | |||
logger.debug("[DingTalk]receive voice msg: {}".format(cmsg.content)) | |||
elif cmsg.ctype == ContextType.IMAGE: | |||
logger.debug("[dingtalk]receive image msg: {}".format(cmsg.content)) | |||
logger.debug("[DingTalk]receive image msg: {}".format(cmsg.content)) | |||
elif cmsg.ctype == ContextType.IMAGE_CREATE: | |||
logger.debug("[DingTalk]receive image create msg: {}".format(cmsg.content)) | |||
elif cmsg.ctype == ContextType.PATPAT: | |||
logger.debug("[dingtalk]receive patpat msg: {}".format(cmsg.content)) | |||
logger.debug("[DingTalk]receive patpat msg: {}".format(cmsg.content)) | |||
elif cmsg.ctype == ContextType.TEXT: | |||
expression = cmsg.my_msg | |||
cmsg.content = conf()["group_chat_prefix"][0] + cmsg.content | |||
logger.debug("[DingTalk]receive patpat msg: {}".format(cmsg.content)) | |||
else: | |||
logger.debug("[DingTalk]receive other msg: {}".format(cmsg.content)) | |||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg) | |||
context['no_need_at'] = True | |||
if context: | |||
self.produce(context) | |||
async def process(self, callback: dingtalk_stream.CallbackMessage): | |||
try: | |||
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data) | |||
dingtalk_msg = DingTalkMessage(incoming_message) | |||
if incoming_message.conversation_type == '1': | |||
self.handle_single(dingtalk_msg) | |||
else: | |||
self.handle_group(dingtalk_msg) | |||
return AckMessage.STATUS_OK, 'OK' | |||
except Exception as e: | |||
logger.error(e) | |||
return self.FAILED_MSG | |||
def send(self, reply: Reply, context: Context): | |||
receiver = context["receiver"] | |||
isgroup = context.kwargs['msg'].is_group | |||
incoming_message = context.kwargs['msg'].incoming_message | |||
self.reply_text(reply.content, incoming_message) | |||
logger.info("[Dingtalk] sendMsg={}, receiver={}".format(reply, receiver)) | |||
def reply_with_text(): | |||
self.reply_text(reply.content, incoming_message) | |||
def reply_with_at_text(): | |||
self.reply_text("📢 您有一条新的消息,请查看。", incoming_message) | |||
def reply_with_ai_markdown(): | |||
button_list, markdown_content = self.generate_button_markdown_content(context, reply) | |||
self.reply_ai_markdown_button(incoming_message, markdown_content, button_list, "", "📌 内容由AI-Bot生成", "",[incoming_message.sender_staff_id]) | |||
if reply.type in [ReplyType.IMAGE_URL, ReplyType.IMAGE, ReplyType.TEXT]: | |||
if isgroup: | |||
reply_with_ai_markdown() | |||
reply_with_at_text() | |||
else: | |||
reply_with_ai_markdown() | |||
else: | |||
# 暂不支持其它类型消息回复 | |||
reply_with_text() | |||
def generate_button_markdown_content(self, context, reply): | |||
image_url = context.kwargs.get("image_url") | |||
promptEn = context.kwargs.get("promptEn") | |||
reply_text = reply.content | |||
button_list = [] | |||
markdown_content = f""" | |||
{reply.content} | |||
""" | |||
if image_url is not None and promptEn is not None: | |||
button_list = [ | |||
{"text": "查看原图", "url": image_url, "iosUrl": image_url, "color": "blue"} | |||
] | |||
markdown_content = f""" | |||
{promptEn} | |||
!["图片"]({image_url}) | |||
{reply_text} | |||
""" | |||
logger.debug(f"[Dingtalk] generate_button_markdown_content, button_list={button_list} , markdown_content={markdown_content}") | |||
return button_list, markdown_content |
@@ -1,44 +1,83 @@ | |||
import os | |||
import requests | |||
from dingtalk_stream import ChatbotMessage | |||
from bridge.context import ContextType | |||
from channel.chat_message import ChatMessage | |||
import json | |||
import requests | |||
# -*- coding=utf-8 -*- | |||
from common.log import logger | |||
from common.tmp_dir import TmpDir | |||
from common import utils | |||
from dingtalk_stream import ChatbotMessage | |||
class DingTalkMessage(ChatMessage): | |||
def __init__(self, event: ChatbotMessage): | |||
def __init__(self, event: ChatbotMessage, image_download_handler): | |||
super().__init__(event) | |||
self.image_download_handler = image_download_handler | |||
self.msg_id = event.message_id | |||
msg_type = event.message_type | |||
self.incoming_message =event | |||
self.message_type = event.message_type | |||
self.incoming_message = event | |||
self.sender_staff_id = event.sender_staff_id | |||
self.other_user_id = event.conversation_id | |||
self.create_time = event.create_at | |||
if event.conversation_type=="1": | |||
self.image_content = event.image_content | |||
self.rich_text_content = event.rich_text_content | |||
if event.conversation_type == "1": | |||
self.is_group = False | |||
else: | |||
self.is_group = True | |||
if msg_type == "text": | |||
if self.message_type == "text": | |||
self.ctype = ContextType.TEXT | |||
self.content = event.text.content.strip() | |||
elif msg_type == "audio": | |||
elif self.message_type == "audio": | |||
# 钉钉支持直接识别语音,所以此处将直接提取文字,当文字处理 | |||
self.content = event.extensions['content']['recognition'].strip() | |||
self.ctype = ContextType.TEXT | |||
elif (self.message_type == 'picture') or (self.message_type == 'richText'): | |||
self.ctype = ContextType.IMAGE | |||
# 钉钉图片类型或富文本类型消息处理 | |||
image_list = event.get_image_list() | |||
if len(image_list) > 0: | |||
download_code = image_list[0] | |||
download_url = image_download_handler.get_image_download_url(download_code) | |||
self.content = download_image_file(download_url, TmpDir().path()) | |||
else: | |||
logger.debug(f"[Dingtalk] messageType :{self.message_type} , imageList isEmpty") | |||
if self.is_group: | |||
self.from_user_id = event.conversation_id | |||
self.actual_user_id = event.sender_id | |||
else: | |||
self.from_user_id = event.sender_id | |||
self.actual_user_id = event.sender_id | |||
self.to_user_id = event.chatbot_user_id | |||
self.other_user_nickname = event.conversation_title | |||
user_id = event.sender_id | |||
nickname =event.sender_nick | |||
def download_image_file(image_url, temp_dir): | |||
headers = { | |||
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36' | |||
} | |||
# 设置代理 | |||
# self.proxies | |||
# , proxies=self.proxies | |||
response = requests.get(image_url, headers=headers, stream=True, timeout=60 * 5) | |||
if response.status_code == 200: | |||
# 生成文件名 | |||
file_name = image_url.split("/")[-1].split("?")[0] | |||
# 检查临时目录是否存在,如果不存在则创建 | |||
if not os.path.exists(temp_dir): | |||
os.makedirs(temp_dir) | |||
# 将文件保存到临时目录 | |||
file_path = os.path.join(temp_dir, file_name) | |||
with open(file_path, 'wb') as file: | |||
file.write(response.content) | |||
return file_path | |||
else: | |||
logger.info(f"[Dingtalk] Failed to download image file, {response.content}") | |||
return None |
@@ -40,7 +40,7 @@ class FeiShuChanel(ChatChannel): | |||
self.feishu_app_id, self.feishu_app_secret, self.feishu_token)) | |||
# 无需群校验和前缀 | |||
conf()["group_name_white_list"] = ["ALL_GROUP"] | |||
conf()["single_chat_prefix"] = [] | |||
conf()["single_chat_prefix"] = [""] | |||
def startup(self): | |||
urls = ( | |||
@@ -109,7 +109,7 @@ class WechatChannel(ChatChannel): | |||
def __init__(self): | |||
super().__init__() | |||
self.receivedMsgs = ExpiredDict(60 * 60) | |||
self.receivedMsgs = ExpiredDict(conf().get("expires_in_seconds")) | |||
self.auto_login_times = 0 | |||
def startup(self): | |||
@@ -151,7 +151,7 @@ available_setting = { | |||
# chatgpt指令自定义触发词 | |||
"clear_memory_commands": ["#清除记忆"], # 重置会话指令,必须以#开头 | |||
# channel配置 | |||
"channel_type": "wx", # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service,wechatcom_app} | |||
"channel_type": "", # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service,wechatcom_app,dingtalk} | |||
"subscribe_msg": "", # 订阅消息, 支持: wechatmp, wechatmp_service, wechatcom_app | |||
"debug": False, # 是否开启debug模式,开启后会打印更多日志 | |||
"appdata_dir": "", # 数据目录 | |||
@@ -1,5 +1,9 @@ | |||
{ | |||
"repo": { | |||
"midjourney": { | |||
"url": "https://github.com/baojingyu/midjourney.git", | |||
"desc": "利用midjourney实现ai绘图的的插件" | |||
}, | |||
"sdwebui": { | |||
"url": "https://github.com/lanvent/plugin_sdwebui.git", | |||
"desc": "利用stable-diffusion画图的插件" | |||