@@ -0,0 +1,39 @@ | |||||
.DS_Store | |||||
.idea | |||||
.vscode | |||||
.venv | |||||
.vs | |||||
.wechaty/ | |||||
__pycache__/ | |||||
venv* | |||||
*.pyc | |||||
config.json | |||||
QR.png | |||||
nohup.out | |||||
tmp | |||||
plugins.json | |||||
itchat.pkl | |||||
*.log | |||||
user_datas.pkl | |||||
chatgpt_tool_hub/ | |||||
plugins/**/ | |||||
!plugins/bdunit | |||||
!plugins/dungeon | |||||
!plugins/finish | |||||
!plugins/godcmd | |||||
!plugins/tool | |||||
!plugins/banwords | |||||
!plugins/banwords/**/ | |||||
plugins/banwords/__pycache__ | |||||
plugins/banwords/lib/__pycache__ | |||||
!plugins/hello | |||||
!plugins/role | |||||
!plugins/keyword | |||||
!plugins/linkai | |||||
!plugins/healthai | |||||
client_config.json | |||||
!config.json | |||||
!plugins.json | |||||
tmp/ | |||||
logs/ | |||||
cmd.txt |
@@ -0,0 +1,29 @@ | |||||
#!/bin/bash | |||||
environment=$1 | |||||
version=$2 | |||||
echo "环境变量为${environment},版本为$version!" | |||||
if [[ ${environment} == 'production' ]]; then | |||||
echo "开始远程构建容器" | |||||
docker stop ai-ops-wx || true | |||||
docker rm ai-ops-wx || true | |||||
docker rmi -f $(docker images | grep registry.cn-shanghai.aliyuncs.com/gps_card/ai-ops-wechat | awk '{print $3}') | |||||
#docker login --username=telpo_linwl@1111649216405698 --password=telpo#1234 registry.cn-shanghai.aliyuncs.com | |||||
docker login --username=rzl_wangjx@1111649216405698 --password=telpo.123 registry.cn-shanghai.aliyuncs.com | |||||
docker pull registry.cn-shanghai.aliyuncs.com/gps_card/ai-ops-wechat:$version | |||||
docker run -p 6560:5000 -d -e environment=production -v /home/data/ai-ops-wx/logs:/app/logs -v /home/data/ai-ops-wx/tmp:/app/tmp --restart=always --name ai-ops-wx registry.cn-shanghai.aliyuncs.com/gps_card/ai-ops-wechat:$version; | |||||
#删除产生的None镜像 | |||||
docker rmi -f $(docker images | grep none | awk '{print $3}') | |||||
docker ps -a | |||||
elif [[ ${environment} == 'test' || ${environment} == 'presure' ]]; then | |||||
echo "开始在测试环境远程构建容器" | |||||
docker stop ai-ops-wx || true | |||||
docker rm ai-ops-wx || true | |||||
docker rmi -f $(docker images | grep 139.224.254.18:5000/ssjl/ai-ops-wechat | awk '{print $3}') | |||||
docker pull 139.224.254.18:5000/ssjl/ai-ops-wechat:$version | |||||
docker run -p 6560:5000 -d -e environment=test -v /home/data/ai-ops-wx/logs:/app/logs -v /home/data/ai-ops-wx/tmp:/app/tmp --restart=always --name ai-ops-wx 139.224.254.18:5000/ssjl/ai-ops-wechat:$version; | |||||
#删除产生的None镜像 | |||||
docker rmi -f $(docker images | grep none | awk '{print $3}') | |||||
docker ps -a | |||||
fi |
@@ -0,0 +1,54 @@ | |||||
# from celery import Celery | |||||
# # 创建 Celery 应用 | |||||
# celery_app = Celery( | |||||
# 'ai_ops_wechat_app', | |||||
# broker='redis://:telpo%231234@192.168.2.121:8090/3', | |||||
# backend='redis://:telpo%231234@192.168.2.121:8090/3', | |||||
# ) | |||||
# # 配置 Celery | |||||
# celery_app.conf.update( | |||||
# task_serializer='json', | |||||
# accept_content=['json'], | |||||
# result_serializer='json', | |||||
# timezone='Asia/Shanghai', | |||||
# enable_utc=True, | |||||
# ) | |||||
# #celery_app.autodiscover_tasks(['app.tasks']) | |||||
# from celery import Celery | |||||
# def make_celery(app): | |||||
# celery = Celery( | |||||
# app.import_name, | |||||
# backend=app.config['CELERY_RESULT_BACKEND'], | |||||
# broker=app.config['CELERY_BROKER_URL'] | |||||
# ) | |||||
# celery.conf.update(app.config) | |||||
# # 自动发现任务 | |||||
# celery.autodiscover_tasks(['app.tasks']) | |||||
# return celery | |||||
# # 初始化 Flask | |||||
# app = Flask(__name__) | |||||
# app.config.update( | |||||
# CELERY_BROKER_URL='redis://:telpo%231234@192.168.2.121:8090/3', | |||||
# CELERY_RESULT_BACKEND='redis://:telpo%231234@192.168.2.121:8090/3' | |||||
# ) | |||||
# celery = make_celery(app) | |||||
from celery import Celery | |||||
# 配置 Celery | |||||
celery = Celery( | |||||
"worker", | |||||
broker="redis://:telpo%231234@192.168.2.121:8090/3", | |||||
backend="redis://:telpo%231234@192.168.2.121:8090/3", | |||||
include=['app.tasks'] | |||||
) | |||||
# 自动发现任务 | |||||
celery.autodiscover_tasks(['app.tasks']) |
@@ -0,0 +1,200 @@ | |||||
from fastapi import APIRouter,Request,FastAPI | |||||
from pydantic import BaseModel | |||||
from fastapi import APIRouter, Depends | |||||
from pydantic import BaseModel, ValidationError | |||||
from common.log import logger | |||||
from model.models import AgentConfig,validate_wxid | |||||
from services.gewe_service import GeWeService,get_gewe_service | |||||
from services.redis_service import RedisService | |||||
from model.models import AgentConfig,validate_wxid | |||||
from common.utils import * | |||||
import time,asyncio | |||||
agent_router = APIRouter(prefix="/api/agent") | |||||
class GetAgentLoginRequest(BaseModel): | |||||
tel: str | |||||
class GetWxQRCodeRequest(BaseModel): | |||||
tel: str | |||||
tokenId:str | |||||
regionId:str | |||||
agentTokenId:str | |||||
class LogincCaptchCode(BaseModel): | |||||
tokenId: str | |||||
captchCode:str | |||||
@agent_router.post("/getlogin", response_model=None) | |||||
async def get_login(request: Request, body: GetAgentLoginRequest, ): | |||||
tel = body.tel | |||||
return await request.app.state.gewe_service.get_login_info_from_cache_async(tel) | |||||
@agent_router.post("/getwxqrcode", response_model=None) | |||||
async def get_wx_qrcode(request: Request, body: GetWxQRCodeRequest, ): | |||||
tel = body.tel | |||||
token_id =body.tokenId | |||||
region_id= body.regionId | |||||
agent_token_id= body.agentTokenId | |||||
loginfo=await request.app.state.gewe_service.get_login_info_from_cache_async(tel) | |||||
status=loginfo.get('status','0') | |||||
if status=='1': | |||||
msg=f'手机号{tel},wx_token{token_id} 已经微信登录,终止登录流程' | |||||
logger.info(msg) | |||||
return {'code': 501, 'message': msg} | |||||
now=time.time() | |||||
expried_time=int(now)+150 | |||||
flag=await request.app.state.gewe_service.acquire_login_lock_async(token_id,150) | |||||
if not flag: | |||||
msg=f'手机号{tel}, wx_token{token_id} 登录进行中,稍后再试' | |||||
logger.info(msg) | |||||
return {'code': 501, 'message': msg} | |||||
app_id=loginfo.get('appId','') | |||||
qr_code = await request.app.state.gewe_service.get_login_qr_code_async(token_id, app_id,region_id) | |||||
if not qr_code: | |||||
msg=f"获取二维码失败,qr_code: {qr_code}" | |||||
await request.app.state.gewe_service.release_login_lock_async(token_id) | |||||
logger.info(msg) | |||||
return {'code': 501, 'message': msg} | |||||
uuid = qr_code.get('uuid',None) | |||||
if not uuid: | |||||
msg=f"uuid获取二维码失败,uuid: {uuid}" | |||||
await request.app.state.gewe_service.release_login_lock_async(token_id) | |||||
logger.info(msg) | |||||
return {'code': 501, 'message': msg} | |||||
app_id = app_id or qr_code.get('appId') | |||||
base64_string = qr_code.get('qrImgBase64',None) | |||||
await request.app.state.gewe_service.qrCallback(uuid,base64_string) | |||||
hash_key = f"__AI_OPS_WX__:LOGININFO:{tel}" | |||||
print(hash_key) | |||||
# thread = threading.Thread(target=waitting_login_result, args=(gewe_chat.wxchat,token_id, app_id,region_id, agent_token_id,hash_key, uuid,now)) | |||||
# thread.daemon = True | |||||
# thread.start() | |||||
loop = asyncio.get_event_loop() | |||||
future = asyncio.run_coroutine_threadsafe( | |||||
waitting_login_result(request,token_id, app_id,region_id, agent_token_id,hash_key, uuid,now), | |||||
loop | |||||
) | |||||
return { | |||||
"tokenId": token_id, | |||||
"tel": tel, | |||||
"base64Img": base64_string, | |||||
"expiredTime": expried_time, | |||||
} | |||||
async def waitting_login_result(request: Request, token_id, app_id,region_id, agent_token_id,hash_key, uuid,start_time): | |||||
agent_tel=hash_key.split(":")[-1] | |||||
try: | |||||
while True: | |||||
now = time.time() | |||||
if now - start_time > 150: | |||||
logger.info(f'{token_id} 使用 {app_id} 扫二维码登录超时') | |||||
break | |||||
logger.info(f"{token_id} 使用 {app_id},等待扫码登录,二维码有效时间 {150 - int(now - start_time)} 秒") | |||||
captch_code = await request.app.state.gewe_service.get_login_wx_captch_code_from_cache_async(token_id) | |||||
captch_code= captch_code if captch_code else '' | |||||
logger.info(f"{token_id} 使用 {app_id} 的验证码 {captch_code}") | |||||
ret,msg,res = await request.app.state.gewe_service.check_login_async(token_id, app_id, uuid,captch_code) | |||||
if ret == 200: | |||||
flag = res.get('status') | |||||
if flag == 2: | |||||
logger.info(f"登录成功: {res}") | |||||
head_img_url=res.get('headImgUrl','') | |||||
login_info = res.get('loginInfo', {}) | |||||
wxid=login_info.get('wxid',agent_tel) | |||||
cache_login_info=await request.app.state.gewe_service.get_login_info_from_cache_async(agent_tel) | |||||
cache_wxid=cache_login_info.get('wxid','') | |||||
if not cache_login_info and cache_wxid!=wxid and cache_wxid!='': | |||||
logger.warning(f"agent_tel {agent_tel} , wxid {wxid} 与 cache_wxid {cache_wxid} 不匹配,登录失败") | |||||
await request.app.state.gewe_service.logout_async(token_id,app_id) | |||||
# k_message=utils.login_result_message(token_id,agent_tel,region_id,agent_token_id,'') | |||||
# kafka_helper.kafka_client.produce_message(k_message) | |||||
break | |||||
login_info.update({'appId': app_id, 'uuid': uuid, 'tokenId': token_id,'status': 1,'headImgUrl':head_img_url,'regionId':region_id}) | |||||
cache_login_info=await request.app.state.redis_service.get_hash(hash_key) | |||||
if 'appId' not in cache_login_info: | |||||
login_info.update({"create_at":int(time.time()),"modify_at":int(time.time())}) | |||||
# 默认配置 | |||||
config=AgentConfig.model_validate({ | |||||
"chatroomIdWhiteList": [], | |||||
"agentTokenId": agent_token_id, | |||||
"agentEnabled": False, | |||||
"addContactsFromChatroomIdWhiteList": [], | |||||
"chatWaitingMsgEnabled": True | |||||
}) | |||||
else: | |||||
login_info.update({"modify_at":int(time.time())}) | |||||
# 已有配置 | |||||
config_cache=await request.app.state.gewe_service.get_wxchat_config_from_cache_async(wxid) | |||||
config=AgentConfig.model_validate(config_cache) | |||||
cleaned_login_info = {k: (v if v is not None else '') for k, v in login_info.items()} | |||||
#wxid=cleaned_login_info.get('wxid',agent_tel) | |||||
# 保存配置信息 | |||||
config_dict=config.model_dump() | |||||
await request.app.state.gewe_service.save_wxchat_config_async(wxid,config_dict) | |||||
# 保存登录信息 | |||||
await request.app.state.redis_service.set_hash(hash_key, cleaned_login_info) | |||||
await request.app.state.gewe_service.release_login_lock_async(token_id) | |||||
# 登录结果推送到kafka | |||||
k_message=login_result_message(token_id,agent_tel,region_id,agent_token_id,wxid) | |||||
await request.app.state.kafka_service.send_message_async(k_message) | |||||
# 同步联系人列表 | |||||
ret,msg,contacts_list=await request.app.state.gewe_service.fetch_contacts_list_async(token_id,app_id) | |||||
if ret!=200: | |||||
logger.warning(f"同步联系人列表失败: {ret}-{msg}") | |||||
break | |||||
friend_wxids = [c for c in contacts_list['friends'] if c not in ['fmessage', 'medianote','weixin','weixingongzhong']] # 可以调整截取范围 | |||||
data=await request.app.state.gewe_service.save_contacts_brief_to_cache_async(token_id, app_id, wxid, friend_wxids) | |||||
chatrooms=contacts_list['chatrooms'] | |||||
# 同步群列表 | |||||
logger.info(f'{wxid} 的群数量 {len(chatrooms)}') | |||||
logger.info(f'{wxid} 同步群列表') | |||||
await request.app.state.gewe_service.save_groups_info_to_cache_async(token_id, app_id, wxid, chatrooms) | |||||
logger.info(f'{wxid} 同步群成员') | |||||
# 同步群成员 | |||||
await request.app.state.gewe_service.save_groups_members_to_cache_async(token_id, app_id, wxid, chatrooms) | |||||
logger.info(f'{wxid} 好友信息推送到kafka') | |||||
# 联系人推送到kafka | |||||
k_message=wx_all_contacts(wxid,data) | |||||
await request.app.state.kafka_service.send_message_async(k_message) | |||||
break | |||||
else: | |||||
logger.info(f"登录检查中: {ret}-{msg}-{res}") | |||||
await asyncio.sleep(5) | |||||
finally: | |||||
await request.app.state.gewe_service.release_login_lock_async(token_id) | |||||
@agent_router.post("/logincaptchcode", response_model=None) | |||||
async def login_captch_code(request: Request, body: LogincCaptchCode, ): | |||||
token_id = body.tokenId | |||||
captch_code=body.captchCode | |||||
res=await request.app.state.gewe_service.save_login_wx_captch_code_to_cache_async(token_id,captch_code) | |||||
return {'message': '操作成功'} |
@@ -0,0 +1,53 @@ | |||||
from fastapi import APIRouter,Request | |||||
from pydantic import BaseModel | |||||
from fastapi import APIRouter, Depends | |||||
from pydantic import BaseModel, ValidationError | |||||
from services.gewe_service import GeWeService,get_gewe_service | |||||
from services.redis_service import RedisService | |||||
from model.models import AgentConfig,validate_wxid | |||||
config_router = APIRouter(prefix="/api/wxchat") | |||||
# 定义请求体的 Pydantic 模型 | |||||
class GetConfigRequest(BaseModel): | |||||
wxid: str | |||||
class SaveConfigRequest(BaseModel): | |||||
wxid: str | |||||
@config_router.post("/getconfig",response_model=None) | |||||
@validate_wxid | |||||
async def get_config(request: Request, body: GetConfigRequest): | |||||
wxid = body.wxid | |||||
# k,loginfo=await request.app.state.gewe_service.get_login_info_by_wxid_async(wxid) | |||||
# if not k: | |||||
# return {"code":404,"message":f"{wxid} 没有对应的登录信息"} | |||||
config=await request.app.state.gewe_service.get_wxchat_config_from_cache_async(wxid) | |||||
return config | |||||
@config_router.post("/saveconfig",response_model=None) | |||||
@validate_wxid | |||||
async def save_config(request: Request, body: SaveConfigRequest): | |||||
wxid = body.get("wxid") | |||||
data = body.get("data") | |||||
# k,loginfo=await request.app.state.gewe_service.get_login_info_by_wxid_async(wxid) | |||||
# if not k: | |||||
# return {"code":404,"message":f"{wxid} 没有对应的登录信息"} | |||||
try: | |||||
# 使用 Pydantic 严格校验数据类型和结构 | |||||
validated_config = AgentConfig.model_validate(data) | |||||
except ValidationError as e: | |||||
return {'code': 407, 'message': e.errors().__str__()} | |||||
await request.app.state.gewe_service.save_wxchat_config_async(wxid, data) | |||||
return {'wxid': wxid, 'config': data} | |||||
@@ -0,0 +1,72 @@ | |||||
from fastapi import APIRouter,Request | |||||
from pydantic import BaseModel | |||||
from fastapi import APIRouter, Depends | |||||
from pydantic import BaseModel, ValidationError | |||||
from model.models import AgentConfig,validate_wxid | |||||
contacts_router = APIRouter(prefix="/api/contacts") | |||||
class GetContactsRequest(BaseModel): | |||||
wxid: str | |||||
cache: bool = True # 默认为 True | |||||
class DeleteContactsRequest(BaseModel): | |||||
wxid: str | |||||
friendWxid:str | |||||
@contacts_router.post("/getfriends",response_model=None) | |||||
async def get_friends(request: Request, body: GetContactsRequest): | |||||
wxid = body.wxid | |||||
cache=body.cache | |||||
k,loginfo=await request.app.state.gewe_service.get_login_info_by_wxid_async(wxid) | |||||
print(k,loginfo) | |||||
if not k: | |||||
return {"code":404,"message":f"{wxid} 没有对应的登录信息"} | |||||
if cache: | |||||
return await request.app.state.gewe_service.get_contacts_brief_from_cache_async(wxid) | |||||
else: | |||||
token_id=loginfo.get('tokenId','') | |||||
app_id=loginfo.get('appId','') | |||||
ret,msg,contacts_list=await request.app.state.gewe_service.fetch_contacts_list_async(token_id,app_id) | |||||
if ret!=200: | |||||
return {'code':ret,'message':msg} | |||||
friend_wxids = [c for c in contacts_list['friends'] if c not in ['fmessage', 'medianote','weixin','weixingongzhong']] # 可以调整截取范围 | |||||
data=await request.app.state.gewe_service.save_contacts_brief_to_cache_async(token_id, app_id, wxid, friend_wxids) | |||||
print(f'{wxid}获取实时好友信息') | |||||
return data | |||||
@contacts_router.post("/deletefriend",response_model=None) | |||||
async def delete_friends(request: Request, body: DeleteContactsRequest): | |||||
wxid = body.wxid | |||||
friend_wxid = body.friendWxid | |||||
k,loginfo=await request.app.state.gewe_service.get_login_info_by_wxid_async(wxid) | |||||
if not k: | |||||
return {"code":404,"message":f"{wxid} 没有用户信息"} | |||||
token_id=loginfo.get('tokenId','') | |||||
app_id=loginfo.get('appId','') | |||||
ret, msg, data = await request.app.state.gewe_service.wxchat.delete_friend_async(token_id, app_id, friend_wxid) | |||||
if ret !=200: | |||||
return { | |||||
'code': ret, | |||||
'message': "删除好友失败" | |||||
} | |||||
ret, msg, contacts_list = await request.app.state.gewe_service.fetch_contacts_list_async(token_id, app_id) | |||||
if ret !=200: | |||||
return { | |||||
'code': ret, | |||||
'message': '获取联系人列表失败' | |||||
} | |||||
friend_wxids = [c for c in contacts_list['friends'] if c not in ['fmessage', 'medianote','weixin','weixingongzhong']] # 可以调整截取范围 | |||||
data=await request.app.state.gewe_service.save_contacts_brief_to_cache_async(token_id, app_id, wxid, friend_wxids) | |||||
return data |
@@ -0,0 +1,48 @@ | |||||
from fastapi import APIRouter,Request | |||||
from pydantic import BaseModel | |||||
from fastapi import APIRouter, Depends | |||||
from pydantic import BaseModel, ValidationError | |||||
from model.models import AgentConfig,validate_wxid,auth_required_time | |||||
import threading | |||||
import asyncio | |||||
groups_router = APIRouter(prefix="/api/groups") | |||||
class GetChatroomInfoRequest(BaseModel): | |||||
wxid: str | |||||
class GetChatroomMenberListRequest(BaseModel): | |||||
wxid: str | |||||
chatroomId:str | |||||
@groups_router.post("/getchatroominfo",response_model=None) | |||||
async def get_chatroominfo(request: Request, body: GetChatroomInfoRequest): | |||||
wxid = body.wxid | |||||
groups= await request.app.state.gewe_service.get_groups_info_from_cache_async(wxid) | |||||
return groups | |||||
@groups_router.post("/getmenberlist",response_model=None) | |||||
async def get_menberlist(request: Request, body: GetChatroomMenberListRequest): | |||||
wxid = body.wxid | |||||
chatroom_id=body.chatroomId | |||||
_,loginfo=await request.app.state.gewe_service.get_login_info_by_wxid_async(wxid) | |||||
token_id=loginfo.get('tokenId') | |||||
app_id=loginfo.get('appId') | |||||
ret, msg, data = await request.app.state.gewe_service.get_group_memberlist_async(token_id, app_id, chatroom_id) | |||||
if ret != 200: | |||||
return { | |||||
'code': ret, | |||||
'message': msg | |||||
} | |||||
loop = asyncio.get_event_loop() | |||||
future = asyncio.run_coroutine_threadsafe( | |||||
request.app.state.gewe_service.save_groups_members_to_cache_async(token_id, app_id, wxid, [chatroom_id]), | |||||
loop | |||||
) | |||||
# Optionally, you can wait for the future to complete if needed | |||||
#future.result() | |||||
return data |
@@ -0,0 +1,905 @@ | |||||
from voice.ali.ali_voice import AliVoice | |||||
from common.log import logger | |||||
import xml.etree.ElementTree as ET | |||||
import os,json,asyncio,aiohttp | |||||
from voice import audio_convert | |||||
from fastapi import APIRouter,Request | |||||
from pydantic import BaseModel | |||||
from fastapi import APIRouter, Depends | |||||
from typing import Dict, Any | |||||
from model.models import AgentConfig,OperationType | |||||
from common.utils import * | |||||
from common.memory import * | |||||
timeout_duration = 2.0 | |||||
messages_router = APIRouter() | |||||
WX_BACKLIST=['fmessage', 'medianote','weixin','weixingongzhong','tmessage'] | |||||
@messages_router.post("/messages",response_model=None) | |||||
async def get_chatroominfo(request: Request, body: Dict[str, Any]): | |||||
msg = body | |||||
logger.info(f"收到微信回调消息: {json.dumps(msg, separators=(',', ':'),ensure_ascii=False)}") | |||||
type_name =msg.get("TypeName") | |||||
app_id = msg.get("Appid") | |||||
k, loginfo = await request.app.state.gewe_service.get_login_info_by_app_id_async(app_id) | |||||
if not k: | |||||
logger.warning('找不到登录信息,不处理') | |||||
return {"message": "收到微信回调消息"} | |||||
token_id=loginfo.get('tokenId','') | |||||
wxid = msg.get("Wxid",'') | |||||
if type_name == 'AddMsg': | |||||
await handle_self_cmd_async(request,wxid,msg) | |||||
msg_data = msg.get("Data") | |||||
from_wxid = msg_data["FromUserName"]["string"] | |||||
config=await request.app.state.redis_service.get_hash(f"__AI_OPS_WX__:WXCHAT_CONFIG") | |||||
wxids=config.keys() | |||||
WX_BACKLIST.extend(wxids) | |||||
if from_wxid in WX_BACKLIST: | |||||
logger.warning(f'微信ID {wxid} 在黑名单,不处理') | |||||
return {"message": "收到微信回调消息"} | |||||
await handle_messages_async(request,token_id,msg) | |||||
return {"message": "收到微信回调消息"} | |||||
async def handle_self_cmd_async(request: Request,wxid,msg): | |||||
''' | |||||
个人微信命令处理 | |||||
如果是个人微信的指令,wxid == from_wxid | |||||
commands = { | |||||
'启用托管': True, | |||||
'关闭托管': False | |||||
} | |||||
''' | |||||
msg_data=msg.get("Data") | |||||
from_wxid=msg_data["FromUserName"]["string"] | |||||
msg_content=msg_data["Content"]["string"] | |||||
if wxid == from_wxid: | |||||
commands = { | |||||
'启用托管': True, | |||||
'关闭托管': False | |||||
} | |||||
if msg_content in commands: | |||||
c_dict = await request.app.state.gewe_service.get_wxchat_config_from_cache_async(wxid) | |||||
if c_dict: | |||||
agent_config=AgentConfig.model_validate(c_dict) | |||||
agent_config.agentEnabled=commands[msg_content] | |||||
await request.app.state.gewe_service.wxchat.save_wxchat_config_async(wxid, agent_config.model_dump()) | |||||
logger.info(f'{wxid} {"启动" if commands[msg_content] else "关闭"}托管') | |||||
async def handle_messages_async(request: Request,token_id,msg): | |||||
msg_data=msg.get("Data") | |||||
type_name =msg.get("TypeName") | |||||
app_id = msg.get("Appid") | |||||
from_wxid=msg_data["FromUserName"]["string"] | |||||
msg_content=msg_data["Content"]["string"] | |||||
wxid = msg.get("Wxid",'') | |||||
match type_name: | |||||
case 'AddMsg': | |||||
await handle_add_messages_async(request,token_id,msg,wxid) | |||||
case 'ModContacts': | |||||
await handle_mod_contacts_async(request,token_id,msg,wxid) | |||||
case 'DelContacts': | |||||
await handle_del_contacts_async(request,token_id,msg,wxid) | |||||
case 'Offline': | |||||
await handle_offline_async(request,token_id,msg,wxid) | |||||
case _: | |||||
logger.warning(f'未知消息类型:{type_name}') | |||||
async def gpt_client_async(request,messages: list, wixd: str, friend_wxid: str): | |||||
c = await request.app.state.gewe_service.get_wxchat_config_from_cache_async(wixd) | |||||
api_key = c.get('agentTokenId', "sk-jr69ONIehfGKe9JFphuNk4DU5Y5wooHKHhQv7oSnFzVbwCnW65fXO9kvH") | |||||
print(f'流程key:{api_key}\n') | |||||
#api_key="sk-jr69ONIehfGKe9JFphuNk4DU5Y5wooHKHhQv7oSnFzVbwCnW65fXO9kvH" #测试 | |||||
#api_key="sk-uJDBdKmJVb2cmfldGOvlIY6Qx0AzqWMPD3lS1IzgQYzHNOXv9SKNI" #开发2 | |||||
api_url = "http://106.15.182.218:3000/api/v1/chat/completions" | |||||
headers = { | |||||
"Content-Type": "application/json", | |||||
"Authorization": f"Bearer {api_key}" | |||||
} | |||||
session_id = f'{wixd}-{friend_wxid}' | |||||
data = { | |||||
"model": "", | |||||
"messages": messages, | |||||
"chatId": session_id, | |||||
"detail": True | |||||
} | |||||
logger.info("[CHATGPT] 请求={}".format(json.dumps(data, ensure_ascii=False))) | |||||
async with aiohttp.ClientSession() as session: | |||||
try: | |||||
async with session.post(url=api_url, headers=headers, data=json.dumps(data), timeout=600) as response: | |||||
response.raise_for_status() | |||||
response_data = await response.json() | |||||
logger.info("[CHATGPT] 响应={}".format(json.dumps(response_data, separators=(',', ':'), ensure_ascii=False))) | |||||
return response_data | |||||
except aiohttp.ClientError as e: | |||||
logger.error(f"请求失败: {e}") | |||||
raise | |||||
async def handle_add_messages_async(request: Request,token_id,msg,wxid): | |||||
wx_config =await request.app.state.gewe_service.get_wxchat_config_from_cache_async(wxid) | |||||
if not bool(wx_config.get("agentEnabled",False)): | |||||
logger.info(f'微信ID {wxid} 未托管,不处理') | |||||
return | |||||
app_id = msg.get("Appid") | |||||
msg_data = msg.get("Data") | |||||
msg_type = msg_data.get("MsgType",None) | |||||
from_wxid = msg_data["FromUserName"]["string"] | |||||
to_wxid = msg_data["ToUserName"]["string"] | |||||
msg_push_content=msg_data.get("PushContent") | |||||
handlers = { | |||||
1: handle_text_async, | |||||
3: handle_image_async, | |||||
34: handle_voice_async, | |||||
42: handle_name_card_async, | |||||
49: handle_xml_async, | |||||
37: handle_add_friend_notice_async, | |||||
10002: handle_10002_msg | |||||
} | |||||
# (扫码进群情况)判断受否是群聊,并添加到通信录 | |||||
if check_chatroom(from_wxid) or check_chatroom(to_wxid): | |||||
logger.info('群信息') | |||||
chatroom_id=from_wxid | |||||
ret,msg,data=await request.app.state.gewe_service.save_contract_list_async(token_id,app_id,chatroom_id,3) | |||||
logger.info(f'保存到通讯录 chatroom_id {chatroom_id} {msg}') | |||||
await request.app.state.gewe_service.update_group_info_to_cache_async(token_id,app_id,wxid,chatroom_id) | |||||
await request.app.state.gewe_service.update_group_members_to_cache_async(token_id,app_id,wxid,chatroom_id) | |||||
handlers[1]=handle_text_group_async | |||||
handlers[3]=handle_image_group_async | |||||
handlers[34]=handle_voice_group_async | |||||
handler = handlers.get(msg_type) | |||||
if handler: | |||||
return await handler(request,token_id,app_id, wxid,msg_data,from_wxid, to_wxid) | |||||
else: | |||||
logger.warning(f"微信回调消息类型 {msg_type} 未处理") | |||||
async def handle_mod_contacts_async(request: Request,token_id,msg,wxid): | |||||
''' | |||||
好友通过验证及好友资料变更的通知消息 | |||||
''' | |||||
wxid = msg.get("Wxid") | |||||
msg_data = msg.get("Data") | |||||
app_id = msg.get("Appid") | |||||
# | |||||
#handle_mod_contacts(token_id,app_id,wxid,msg_data) | |||||
# | |||||
loop = asyncio.get_event_loop() | |||||
future = asyncio.run_coroutine_threadsafe( | |||||
handle_mod_contacts_worker_async(request,token_id,app_id,wxid,msg_data), | |||||
loop | |||||
) | |||||
contact_wxid = msg_data["UserName"]["string"] | |||||
nickname=msg_data["NickName"]["string"] | |||||
city=msg_data.get("City","") | |||||
signature=msg_data.get("Signature","") | |||||
province=msg_data.get("Province","") | |||||
bigHeadImgUrl=msg_data["SnsUserInfo"]["SnsBgimgId"] | |||||
country=msg_data.get("Country","") | |||||
sex=msg_data.get("Sex",None) | |||||
pyInitial=msg_data["PyInitial"]["string"] | |||||
quanPin=msg_data["QuanPin"]["string"] | |||||
remark=msg_data.get("Remark").get("string","") | |||||
remarkPyInitial=msg_data.get("RemarkPyInitial").get("string","") | |||||
remarkQuanPin=msg_data.get("RemarkQuanPin").get("string","") | |||||
smallHeadImgUrl=msg_data.get("smallHeadImgUrl","") | |||||
# data=gewe_chat.wxchat.get_brief_info(token_id,app_id,[contact_wxid]) | |||||
# contact=data[0] | |||||
# alias=contact.get("alias") | |||||
#推动到kafka | |||||
contact_data = { | |||||
"alias": None, | |||||
"bigHeadImgUrl": bigHeadImgUrl, | |||||
"cardImgUrl": None, | |||||
"city": city, | |||||
"country": country, | |||||
"description": None, | |||||
"labelList": None, | |||||
"nickName": nickname, | |||||
"phoneNumList": None, | |||||
"province": province, | |||||
"pyInitial": pyInitial, | |||||
"quanPin": quanPin, | |||||
"remark": remark, | |||||
"remarkPyInitial": remarkPyInitial, | |||||
"remarkQuanPin": remarkQuanPin, | |||||
"sex": sex, | |||||
"signature": signature, | |||||
"smallHeadImgUrl": smallHeadImgUrl, | |||||
"snsBgImg": None, | |||||
"userName": contact_wxid | |||||
} | |||||
input_message=wx_mod_contact_message(wxid,contact_data) | |||||
await request.app.state.kafka_service.send_message_async(input_message) | |||||
async def handle_del_contacts_async(request: Request,token_id,msg,wxid): | |||||
''' | |||||
删除好友通知/退出群聊 | |||||
''' | |||||
msg_data = msg.get("Data") | |||||
username=msg_data["UserName"]["string"] | |||||
if check_chatroom(username): | |||||
logger.info('退出群聊') | |||||
wxid = msg.get("Wxid") | |||||
chatroom_id=username | |||||
await request.app.state.redis_service.delete_hash_field(f'__AI_OPS_WX__:GROUPS_INFO:{wxid}',chatroom_id) | |||||
logger.info(f'清除 chatroom_id{chatroom_id} 数据') | |||||
else: | |||||
logger.info('删除好友通知') | |||||
# 推送到kafka | |||||
input_message=wx_del_contact_message(wxid,username) | |||||
await request.app.state.kafka_service.send_message_async(input_message) | |||||
async def handle_offline_async(request: Request,token_id,msg,wxid): | |||||
''' | |||||
已经离线 | |||||
''' | |||||
wxid = msg.get("Wxid") | |||||
app_id = msg.get("Appid") | |||||
logger.warning(f'微信ID {wxid}在设备{app_id}已经离线') | |||||
k,r=await request.app.state.gewe_service.get_login_info_by_app_id_async(app_id) | |||||
print(k) | |||||
await request.app.state.redis_service.update_hash_field(k,'status',0) | |||||
await request.app.state.redis_service.update_hash_field(k,'modify_at',int(time.time())) | |||||
# 推送到kafka | |||||
input_message=wx_offline_message(app_id,wxid) | |||||
await request.app.state.kafka_service.send_message_async(input_message) | |||||
async def handle_mod_contacts_worker_async(request:Request,token_id,app_id,wxid,msg_data): | |||||
''' | |||||
好友通过验证及好友资料变更的通知消息 | |||||
''' | |||||
logger.info('好友通过验证及好友资料变更的通知消息') | |||||
if not check_chatroom(msg_data["UserName"]["string"]): | |||||
contact_wxid = msg_data["UserName"]["string"] | |||||
# 更新好友信息 | |||||
# 检查好友关系,不是好友则删除 | |||||
# ret,msg,check_relation=gewe_chat.wxchat.check_relation(token_id, app_id,[contact_wxid]) | |||||
# first_item = check_relation[0] | |||||
# check_relation_status=first_item.get('relation') | |||||
# logger.info(f'{wxid} 好友 {contact_wxid} 关系检查:{check_relation_status}') | |||||
# if check_relation_status != 0: | |||||
# gewe_chat.wxchat.delete_contacts_brief_from_cache(wxid, [contact_wxid]) | |||||
# logger.info(f'好友关系异常:{check_relation_status},删除好友 {contact_wxid} 信息') | |||||
# else: | |||||
# gewe_chat.wxchat.save_contacts_brief_to_cache(token_id, app_id, wxid, [contact_wxid]) | |||||
ret,msg,contacts_list = await request.app.state.gewe_service.fetch_contacts_list_async(token_id, app_id) | |||||
# friend_wxids = contacts_list['friends'][3:] # 可以调整截取范围 | |||||
# print(friend_wxids) | |||||
#friend_wxids.remove('fmessage') | |||||
#friend_wxids.remove('weixin') | |||||
friend_wxids = [c for c in contacts_list['friends'] if c not in ['fmessage', 'medianote','weixin','weixingongzhong','tmessage']] # 可以调整截取范围 | |||||
print(f'{wxid}的好友数量 {len(friend_wxids)}') | |||||
await request.app.state.gewe_service.save_contacts_brief_to_cache_async(token_id, app_id, wxid, friend_wxids) | |||||
else: | |||||
logger.info('群聊好友通过验证及好友资料变更的通知消息') | |||||
async def handle_text_async(request: Request,token_id,app_id, wxid,msg_data,from_wxid, to_wxid): | |||||
''' | |||||
私聊文本消息 | |||||
''' | |||||
msg_content=msg_data["Content"]["string"] | |||||
if wxid == from_wxid: #手动发送消息 | |||||
logger.info("Active message sending detected") | |||||
await request.app.state.gewe_service.save_contacts_brief_to_cache_async(token_id,app_id,wxid,[to_wxid]) | |||||
callback_to_user=msg_data["ToUserName"]["string"] | |||||
input_wx_content_dialogue_message=[{"type": "text", "text": msg_content}] | |||||
input_message=dialogue_message(from_wxid,to_wxid,input_wx_content_dialogue_message) | |||||
await request.app.state.kafaka_service.send_message_async(input_message) | |||||
logger.info("发送对话 %s",input_message) | |||||
else: | |||||
callback_to_user=msg_data["FromUserName"]["string"] | |||||
# 创建并启动任务协程,将参数传递给 ai_chat_text 函数 | |||||
task = asyncio.create_task( | |||||
ai_chat_text_async(request,token_id, app_id, wxid, msg_data, msg_content) | |||||
) | |||||
# 设置定时器,1秒后检查任务是否超时。这里需要使用 lambda 来传递参数 | |||||
timeout_timer = asyncio.create_task( | |||||
check_timeout_async(task, request,token_id, wxid, app_id, callback_to_user) | |||||
) | |||||
# 等待任务协程完成 | |||||
await task | |||||
# 取消定时器 | |||||
timeout_timer.cancel() | |||||
async def check_timeout_async(task: asyncio.Task, request: Request,token_id, wxid, app_id, callback_to_user): | |||||
await asyncio.sleep(timeout_duration) # 等待超时时间 | |||||
if not task.done(): | |||||
print(f"任务运行时间超过{timeout_duration}秒,token_id={token_id}, app_id={app_id}, callback_to_user={callback_to_user}") | |||||
wx_config = await request.app.state.gewe_service.get_wxchat_config_from_cache_async(wxid) | |||||
if bool(wx_config.get("chatWaitingMsgEnabled", True)): | |||||
await request.app.state.gewe_service.post_text_async(token_id, app_id, callback_to_user, "亲,我正在组织回复的信息,请稍等一会") | |||||
async def ai_chat_text_async(request: Request,token_id, app_id, wxid, msg_data, msg_content): | |||||
start_time = time.time() # 记录任务开始时间 | |||||
callback_to_user = msg_data["FromUserName"]["string"] | |||||
hash_key = f'__AI_OPS_WX__:MESSAGES:{wxid}:{callback_to_user}' | |||||
prompt = {"role": "user", "content": [{ | |||||
"type": "text", | |||||
"text": msg_content | |||||
}]} | |||||
messages_to_send = await request.app.state.gewe_service.save_session_messages_to_cache_async(hash_key, prompt) | |||||
# 收到的对话 | |||||
input_wx_content_dialogue_message = [{"type": "text", "text": msg_content}] | |||||
input_message = dialogue_message(callback_to_user, wxid, input_wx_content_dialogue_message) | |||||
await request.app.state.kafka_service.send_message_async(input_message) | |||||
logger.info("发送对话 %s", input_message) | |||||
cache_data = USER_INTERACTIVE_CACHE.get(wxid) | |||||
if cache_data and cache_data.get('interactive'): | |||||
o = get_first_char_if_digit(msg_content) | |||||
if o is not None: | |||||
userSelectOptions = cache_data.get('options') | |||||
if o < len(userSelectOptions): | |||||
o = o - 1 | |||||
msg_content = userSelectOptions[o].get("value") | |||||
messages_to_send = [{"role": "user", "content": msg_content}] | |||||
else: | |||||
messages_to_send = [{"role": "user", "content": msg_content}] | |||||
else: | |||||
messages_to_send = [{"role": "user", "content": msg_content}] | |||||
res = await gpt_client_async(request,messages_to_send, wxid, callback_to_user) | |||||
reply_content = remove_markdown_symbol(res["choices"][0]["message"]["content"]) | |||||
description = '' | |||||
userSelectOptions = [] | |||||
if isinstance(reply_content, list) and any(item.get("type") == "interactive" for item in reply_content): | |||||
for item in reply_content: | |||||
if item["type"] == "interactive" and item["interactive"]["type"] == "userSelect": | |||||
params = item["interactive"]["params"] | |||||
description = params.get("description") | |||||
userSelectOptions = params.get("userSelectOptions", []) | |||||
values_string = "\n".join(option["value"] for option in userSelectOptions) | |||||
if description is not None: | |||||
USER_INTERACTIVE_CACHE[wxid] = { | |||||
"interactive": True, | |||||
"options": userSelectOptions, | |||||
} | |||||
reply_content = description + '------------------------------\n' + values_string | |||||
elif isinstance(reply_content, list) and any(item.get("type") == "text" for item in reply_content): | |||||
USER_INTERACTIVE_CACHE[wxid] = { | |||||
"interactive": False | |||||
} | |||||
text = '' | |||||
for item in reply_content: | |||||
if item["type"] == "text": | |||||
text = item["text"]["content"] | |||||
if text == '': | |||||
# 去除上次上一轮对话再次请求 | |||||
cache_messages_str = await request.app.state.redis_service.get_hash_field(hash_key, "data") | |||||
cache_messages = json.loads(cache_messages_str) if cache_messages_str else [] | |||||
if len(cache_messages) >= 3: | |||||
cache_messages = cache_messages[:-3] | |||||
await request.app.state.redis_service.update_hash_field(hash_key, "data", json.dumps(cache_messages, ensure_ascii=False)) | |||||
messages_to_send = await request.app.state.redis_service.save_session_messages_to_cache_async(hash_key, prompt) | |||||
res = await gpt_client_async(request,messages_to_send, wxid, callback_to_user) | |||||
reply_content = remove_markdown_symbol(res["choices"][0]["message"]["content"]) | |||||
if isinstance(reply_content, list): | |||||
reply_content = remove_markdown_symbol(reply_content[0].get('text').get("content")) | |||||
else: | |||||
reply_content = text | |||||
else: | |||||
USER_INTERACTIVE_CACHE[wxid] = { | |||||
"interactive": False | |||||
} | |||||
reply_content = remove_markdown_symbol(res["choices"][0]["message"]["content"]) | |||||
await request.app.state.gewe_service.post_text_async(token_id, app_id, callback_to_user, reply_content) | |||||
await request.app.state.gewe_service.save_session_messages_to_cache_async(hash_key, {"role": "assistant", "content": reply_content}) | |||||
# 回复的对话 | |||||
input_wx_content_dialogue_message = [{"type": "text", "text": reply_content}] | |||||
input_message = dialogue_message(wxid, callback_to_user, input_wx_content_dialogue_message, True) | |||||
await request.app.state.kafka_service.send_message_async(input_message) | |||||
logger.info("发送对话 %s", input_message) | |||||
end_time = time.time() # 记录任务结束时间 | |||||
execution_time = end_time - start_time # 计算执行时间 | |||||
logger.info(f"AI回答任务完成,耗时 {execution_time:.2f} 秒") | |||||
async def handle_text_group_async(request: Request,token_id,app_id, wxid,msg_data,from_wxid, to_wxid): | |||||
''' | |||||
群聊文本消息 | |||||
''' | |||||
msg_content=msg_data["Content"]["string"] | |||||
msg_push_content=msg_data.get("PushContent") | |||||
k,login_info=await request.app.state.gewe_service.get_login_info_by_app_id_async(app_id) | |||||
nickname=login_info.get("nickName") | |||||
if wxid == from_wxid: #手动发送消息 | |||||
logger.info("Active message sending detected") | |||||
await request.app.state.gewe_service.save_contacts_brief_to_cache_async(token_id,app_id,wxid,[to_wxid]) | |||||
callback_to_user=msg_data["ToUserName"]["string"] | |||||
input_wx_content_dialogue_message=[{"type": "text", "text": msg_content}] | |||||
input_message=dialogue_message(from_wxid,to_wxid,input_wx_content_dialogue_message) | |||||
await request.app.state.kafka_service.send_message_async(input_message) | |||||
logger.info("发送对话 %s",input_message) | |||||
else: | |||||
c = await request.app.state.gewe_service.get_wxchat_config_from_cache_async(wxid) | |||||
chatroom_id_white_list = c.get("chatroomIdWhiteList", []) | |||||
if not chatroom_id_white_list: | |||||
logger.info('白名单为空或未定义,不处理') | |||||
return | |||||
if from_wxid not in chatroom_id_white_list: | |||||
logger.info(f'群ID {from_wxid} 不在白名单中,不处理') | |||||
return | |||||
if '在群聊中@了你' in msg_push_content or '@'+nickname in msg_push_content: | |||||
callback_to_user=msg_data["FromUserName"]["string"] | |||||
hash_key = f'__AI_OPS_WX__:MESSAGES:{wxid}:{callback_to_user}' | |||||
prompt={"role": "user", "content": [{ | |||||
"type": "text", | |||||
"text": msg_content | |||||
}]} | |||||
messages_to_send=await request.app.state.gewe_service.save_session_messages_to_cache_async(hash_key, prompt) | |||||
# 收到的对话 | |||||
input_wx_content_dialogue_message=[{"type": "text", "text": msg_content}] | |||||
input_message=dialogue_message(callback_to_user,wxid,input_wx_content_dialogue_message) | |||||
await request.app.state.kafka_service.send_message_async(input_message) | |||||
logger.info("发送对话 %s",input_message) | |||||
cache_data = USER_INTERACTIVE_CACHE.get(wxid) | |||||
if cache_data and cache_data.get('interactive') : | |||||
o=get_first_char_if_digit(msg_content) | |||||
if o is not None: | |||||
userSelectOptions=cache_data.get('options') | |||||
if o < len(userSelectOptions): | |||||
o=o-1 | |||||
msg_content=userSelectOptions[o].get("value") | |||||
messages_to_send=[{"role": "user", "content": msg_content}] | |||||
else: | |||||
messages_to_send=[{"role": "user", "content": msg_content}] | |||||
else: | |||||
messages_to_send=[{"role": "user", "content": msg_content}] | |||||
res=await gpt_client_async(request,messages_to_send,wxid,callback_to_user) | |||||
reply_content=remove_markdown_symbol(res["choices"][0]["message"]["content"]) | |||||
description = '' | |||||
userSelectOptions = [] | |||||
if isinstance(reply_content, list) and any(item.get("type") == "interactive" for item in reply_content): | |||||
for item in reply_content: | |||||
if item["type"] == "interactive" and item["interactive"]["type"] == "userSelect": | |||||
params = item["interactive"]["params"] | |||||
description = params.get("description") | |||||
userSelectOptions = params.get("userSelectOptions", []) | |||||
values_string = "\n".join(option["value"] for option in userSelectOptions) | |||||
if description is not None: | |||||
USER_INTERACTIVE_CACHE[wxid] = { | |||||
"interactive":True, | |||||
"options": userSelectOptions, | |||||
} | |||||
reply_content=description + '------------------------------\n'+values_string | |||||
elif isinstance(reply_content, list) and any(item.get("type") == "text" for item in reply_content): | |||||
USER_INTERACTIVE_CACHE[wxid] = { | |||||
"interactive":False | |||||
} | |||||
text='' | |||||
for item in reply_content: | |||||
if item["type"] == "text": | |||||
text=item["text"]["content"] | |||||
if text=='': | |||||
# 去除上次上一轮对话再次请求 | |||||
cache_messages_str=await request.app.state.redis_service.get_hash_field(hash_key,"data") | |||||
cache_messages = json.loads(cache_messages_str) if cache_messages_str else [] | |||||
if len(cache_messages) >= 3: | |||||
cache_messages = cache_messages[:-3] | |||||
await request.app.state.redis_service.update_hash_field(hash_key,"data",json.dumps(cache_messages,ensure_ascii=False)) | |||||
messages_to_send=await request.app.state.gewe_service.save_session_messages_to_cache_async(hash_key, prompt) | |||||
res=await gpt_client_async(request,messages_to_send,wxid,callback_to_user) | |||||
reply_content=remove_markdown_symbol(res["choices"][0]["message"]["content"]) | |||||
else: | |||||
reply_content=text | |||||
else: | |||||
USER_INTERACTIVE_CACHE[wxid] = { | |||||
"interactive":False | |||||
} | |||||
reply_content=res["choices"][0]["message"]["content"] | |||||
reply_content='@'+extract_nickname(msg_push_content) + reply_content | |||||
await request.app.state.gewe_service.post_text_async(token_id,app_id,callback_to_user,reply_content) | |||||
await request.app.state.gewe_service.save_session_messages_to_cache_async(hash_key, {"role": "assistant", "content": reply_content}) | |||||
# 回复的对话 | |||||
input_wx_content_dialogue_message=[{"type": "text", "text": reply_content}] | |||||
input_message=dialogue_message(wxid,callback_to_user,input_wx_content_dialogue_message,True) | |||||
await request.app.state.kafka_service.send_message_async(input_message) | |||||
logger.info("发送对话 %s",input_message) | |||||
else: | |||||
logger.info('群聊公开消息') | |||||
callback_to_user=msg_data["FromUserName"]["string"] | |||||
group_dialogue_message=[{"type": "text", "text": msg_content}] | |||||
input_message=dialogue_message(callback_to_user,wxid,group_dialogue_message) | |||||
await request.app.state.kafka_service.send_message_async(input_message) | |||||
logger.info("发送对话 %s",input_message) | |||||
return | |||||
async def handle_image_async(request: Request,token_id,app_id, wxid,msg_data,from_wxid, to_wxid): | |||||
''' | |||||
私聊图片消息 | |||||
''' | |||||
msg_content=msg_data["Content"]["string"] | |||||
callback_to_user=from_wxid | |||||
hash_key = f'__AI_OPS_WX__:MESSAGES:{wxid}:{callback_to_user}' | |||||
wx_img_url= await request.app.state.gewe_service.download_image_msg_async(token_id,app_id,msg_content) | |||||
oss_access_key_id="LTAI5tRTG6pLhTpKACJYoPR5" | |||||
oss_access_key_secret="E7dMzeeMxq4VQvLg7Tq7uKf3XWpYfN" | |||||
oss_endpoint="http://oss-cn-shanghai.aliyuncs.com" | |||||
oss_bucket_name="cow-agent" | |||||
oss_prefix="cow" | |||||
img_url=upload_oss(oss_access_key_id, oss_access_key_secret, oss_endpoint, oss_bucket_name, wx_img_url, oss_prefix) | |||||
prompt={ | |||||
"role": "user", | |||||
"content": [{ | |||||
"type": "image_url", | |||||
"image_url": {"url": img_url} | |||||
}] | |||||
} | |||||
await request.app.state.gewe_service.save_session_messages_to_cache_async(hash_key, prompt) | |||||
await request.app.state.gewe_service.post_text_async(token_id,app_id,callback_to_user,'已经上传了图片,有什么可以为您服务') | |||||
logger.info(f"上传图片 URL: {img_url}") | |||||
wx_content_dialogue_message=[{"type": "image_url", "image_url": {"url": img_url}}] | |||||
input_message=dialogue_message(wxid,callback_to_user,wx_content_dialogue_message) | |||||
await request.app.state.kafka_service.send_message_async(input_message) | |||||
logger.info("发送对话 %s",input_message) | |||||
async def handle_image_group_async(request: Request,token_id,app_id, wxid,msg_data,from_wxid, to_wxid): | |||||
logger.info('群聊图片消息') | |||||
async def handle_voice_async(request: Request,token_id,app_id, wxid,msg_data,from_wxid, to_wxid): | |||||
''' | |||||
私聊语音消息 | |||||
''' | |||||
callback_to_user=from_wxid | |||||
msg_content=msg_data["Content"]["string"] | |||||
msg_id=msg_data["MsgId"] | |||||
file_url=await request.app.state.gewe_service.download_audio_msg_async(token_id,app_id,msg_id,msg_content) | |||||
react_silk_path=await save_to_local_from_url_async(file_url) | |||||
react_wav_path = os.path.splitext(react_silk_path)[0] + ".wav" | |||||
audio_convert.any_to_wav(react_silk_path,react_wav_path) | |||||
react_voice_text=AliVoice().voiceToText(react_wav_path) | |||||
os.remove(react_silk_path) | |||||
os.remove(react_wav_path) | |||||
hash_key = f'__AI_OPS_WX__:MESSAGES:{wxid}:{callback_to_user}' | |||||
messages=await request.app.state.gewe_service.save_session_messages_to_cache_async(hash_key, {"role": "user", "content": react_voice_text}) | |||||
ai_res=await gpt_client_async(request,messages,wxid,callback_to_user) | |||||
ai_res_content=remove_markdown_symbol(ai_res["choices"][0]["message"]["content"]) | |||||
has_url=contains_url(ai_res_content) | |||||
if not has_url: | |||||
voice_during,voice_url=wx_voice(ai_res_content) | |||||
if voice_during < 60 * 1000: | |||||
ret,ret_msg,res=await request.app.state.gewe_service.post_voice_async(token_id,app_id,callback_to_user,voice_url,voice_during) | |||||
else: | |||||
ret,ret_msg,res=await request.app.state.gewe_service.post_text_async(token_id,app_id,callback_to_user,ai_res_content) | |||||
logger.warning(f'回应语音消息长度 {voice_during/1000}秒,超过60秒,转为文本回复') | |||||
if ret==200: | |||||
logger.info((f'{wxid} 向 {callback_to_user} 发送语音文本【{ai_res_content}】{ret_msg}')) | |||||
else: | |||||
logger.warning((f'{wxid} 向 {callback_to_user} 发送语音文本【{ai_res_content}】{ret_msg}')) | |||||
ret,ret_msg,res==await request.app.state.gewe_service.post_text_async(token_id,app_id,callback_to_user,ai_res_content) | |||||
logger.info((f'{wxid} 向 {callback_to_user} 发送文本【{ai_res_content}】{ret_msg}')) | |||||
else: | |||||
logger.info(f"回复内容包含网址,不发送语音,回复文字内容:{ai_res_content}") | |||||
ret,ret_msg,res=await request.app.state.gewe_service.post_text_async(token_id,app_id,callback_to_user,ai_res_content) | |||||
await request.app.state.gewe_service.save_session_messages_to_cache_async(hash_key, {"role": "assistant", "content": ai_res_content}) | |||||
# 构造对话消息并发送到 Kafka | |||||
input_wx_content_dialogue_message = [{"type": "text", "text": ai_res_content}] | |||||
input_message = dialogue_message(wxid, callback_to_user, input_wx_content_dialogue_message,True) | |||||
await request.app.state.kafka_service.send_message_async(input_message) | |||||
logger.info("发送对话 %s", input_message) | |||||
async def handle_voice_group_async(request: Request,token_id,app_id, wxid,msg_data,from_wxid, to_wxid): | |||||
logger.info('群聊语音消息') | |||||
async def handle_name_card_async(request: Request,token_id,app_id, wxid,msg_data,from_wxid, to_wxid): | |||||
logger.info('名片消息') | |||||
try: | |||||
msg_content_xml=msg_data["Content"]["string"] | |||||
# 解析XML字符串 | |||||
root = ET.fromstring(msg_content_xml) | |||||
# 提取alias属性 | |||||
alias_value = root.get("alias") | |||||
# 加好友资料 | |||||
scene = int(root.get("scene")) | |||||
v3 = root.get("username") | |||||
v4 = root.get("antispamticket") | |||||
logger.info(f"alias_value: {alias_value}, scene: {scene}, v3: {v3}, v4: {v4}") | |||||
# 判断appid 是否已经创建3天 | |||||
k, login_info = await request.app.state.gewe_service.get_login_info_by_wxid_async(wxid) | |||||
creation_timestamp=int(login_info.get('create_at',time.time())) | |||||
current_timestamp = time.time() | |||||
three_days_seconds = 3 * 24 * 60 * 60 # 三天的秒数 | |||||
diff_flag=(current_timestamp - creation_timestamp) >= three_days_seconds | |||||
if not diff_flag: | |||||
log_content=f'名片添加好友功能,{wxid} 用户创建不够三天,不能使用该功能' | |||||
logger.warning(log_content) | |||||
return | |||||
# 将加好友资料添加到待加好友队列 | |||||
#gewe_chat.wxchat.enqueue_to_add_contacts(wxid,scene,v3,v4) | |||||
_,loginfo=await request.app.state.gewe_service.get_login_info_by_wxid_async(wxid) | |||||
nickname=loginfo.get('nickName') | |||||
add_contact_content=f'您好,我是{nickname}' | |||||
#gewe_chat.wxchat.add_contacts(token_id,app_id,scene,Models.OperationType.ADD_FRIEND,v3,v4,add_contact_content) | |||||
await request.app.state.gewe_service.add_contacts_async(token_id,app_id,scene,OperationType.ADD_FRIEND.value,v3,v4,add_contact_content) | |||||
except ET.ParseError as e: | |||||
logger.error(f"XML解析错误: {e}") | |||||
except KeyError as e: | |||||
logger.error(f"字典键错误: {e}") | |||||
except Exception as e: | |||||
logger.error(f"未知错误: {e}") | |||||
async def handle_xml_async(request: Request,token_id,app_id, wxid,msg_data,from_wxid, to_wxid): | |||||
''' | |||||
处理xml | |||||
''' | |||||
try: | |||||
msg_content_xml=msg_data["Content"]["string"] | |||||
root = ET.fromstring(msg_content_xml) | |||||
type_value = int(root.find(".//appmsg/type").text) | |||||
handlers = { | |||||
57: handle_xml_reference_async, | |||||
5: handle_xml_invite_group_async | |||||
} | |||||
handler = handlers.get(type_value) | |||||
if handler: | |||||
return await handler(request,token_id,app_id, wxid,msg_data,from_wxid, to_wxid) | |||||
# elif "邀请你加入了群聊" in msg_content_xml: # 邀请加入群聊 | |||||
# logger.warning(f"xml消息 {type_value} 邀请你加入了群聊.todo") | |||||
else: | |||||
print(f"xml消息 {type_value} 未解析") | |||||
except ET.ParseError as e: | |||||
logger.error(f"解析XML失败: {e}") | |||||
except Exception as e: | |||||
logger.error(f"未知错误: {e}") | |||||
return | |||||
async def handle_xml_reference_async(request: Request,token_id,app_id, wxid,msg_data,from_wxid, to_wxid): | |||||
''' | |||||
引用消息 | |||||
判断此类消息的逻辑:$.Data.MsgType=49 并且 解析$.Data.Content.string中的xml msg.appmsg.type=57 | |||||
''' | |||||
callback_to_user=from_wxid | |||||
hash_key = f'__AI_OPS_WX__:MESSAGES:{wxid}:{callback_to_user}' | |||||
msg_content= msg_data["PushContent"] | |||||
prompt={"role": "user", "content": [{ | |||||
"type": "text", | |||||
"text": msg_content | |||||
}]} | |||||
# 收到的对话 | |||||
messages_to_send=await request.app.state.gewe_service.save_session_messages_to_cache_async(hash_key, prompt) | |||||
input_wx_content_dialogue_message=[{"type": "text", "text": msg_content}] | |||||
input_message=dialogue_message(callback_to_user,wxid,input_wx_content_dialogue_message) | |||||
await request.app.state.kafka_service.send_message_async(input_message) | |||||
logger.info("发送对话 %s",input_message) | |||||
# 回复的对话 | |||||
res=await gpt_client_async(request,messages_to_send,wxid,callback_to_user) | |||||
reply_content=remove_markdown_symbol(res["choices"][0]["message"]["content"]) | |||||
input_wx_content_dialogue_message=[{"type": "text", "text": reply_content}] | |||||
input_message=dialogue_message(wxid,callback_to_user,input_wx_content_dialogue_message,True) | |||||
await request.app.state.kafka_service.kafka_client.produce_message(input_message) | |||||
logger.info("发送对话 %s",input_message) | |||||
await request.app.state.kafka_service.save_session_messages_to_cache_async(hash_key, {"role": "assistant", "content": reply_content}) | |||||
await request.app.state.kafka_service.post_text_async(token_id,app_id,callback_to_user,reply_content) | |||||
async def handle_xml_invite_group_async (request: Request,token_id,app_id, wxid,msg_data,from_wxid, to_wxid): | |||||
''' | |||||
群聊邀请 | |||||
判断此类消息的逻辑:$.Data.MsgType=49 | |||||
并且 解析$.Data.Content.string中的xml msg.appmsg.title=邀请你加入群聊(根据手机设置的系统语言title会有调整,不同语言关键字不同) | |||||
''' | |||||
logger.info(f'{wxid} 群聊邀请') | |||||
msg_content_xml=msg_data["Content"]["string"] | |||||
root = ET.fromstring(msg_content_xml) | |||||
title_value = root.find(".//appmsg/title").text | |||||
if '邀请你加入群聊' in title_value: | |||||
invite_url = root.find('.//url').text | |||||
ret,msg,data=await request.app.state.gewe_service.agree_join_room_async(token_id,app_id,invite_url) | |||||
if ret==200: | |||||
logger.info(f'群聊邀请,同意加入群聊 {msg} {data}') | |||||
chatroom_id=data.get('chatroomId','') | |||||
# if not chatroom_id: | |||||
# logger.warning(f'群聊邀请,同意加入群聊失败 {msg} {data}') | |||||
# return | |||||
ret,msg,data=await request.app.state.gewe_service.save_contract_list_async(token_id,app_id,chatroom_id,3) | |||||
logger.info(f'群聊邀请,保存到通讯录 chatroom_id {chatroom_id} {msg}') | |||||
await request.app.state.gewe_service.update_group_info_to_cache_async(token_id,app_id,wxid,chatroom_id) | |||||
await request.app.state.gewe_service.update_group_members_to_cache_async(token_id,app_id,wxid,chatroom_id) | |||||
else: | |||||
logger.warning(f'群聊邀请,同意加入群聊失败 {msg} {data}') | |||||
async def handle_add_friend_notice_async(request: Request,token_id,app_id, wxid,msg_data,from_wxid, to_wxid): | |||||
''' | |||||
好友添加请求通知 | |||||
''' | |||||
logger.info('好友添加请求通知') | |||||
try: | |||||
msg_content_xml=msg_data["Content"]["string"] | |||||
root = ET.fromstring(msg_content_xml) | |||||
msg_content = root.attrib.get('content', None) | |||||
v3= root.attrib.get('encryptusername', None) | |||||
v4= root.attrib.get('ticket', None) | |||||
scene=root.attrib.get('scene', None) | |||||
to_contact_wxid=root.attrib.get('fromusername', None) | |||||
wxid=msg_data["ToUserName"]["string"] | |||||
# 自动同意好友 | |||||
# print(v3) | |||||
# print(v4) | |||||
# print(scene) | |||||
# print(msg_content) | |||||
# 操作类型,2添加好友 3同意好友 4拒绝好友 | |||||
#option=2 | |||||
option=3 | |||||
reply_add_contact_contact="亲,我是你的好友" | |||||
ret,ret_msg=await request.app.state.gewe_service.add_contacts_async(token_id,app_id,scene,option,v3,v4,reply_add_contact_contact) | |||||
if ret==200: | |||||
logger.info('自动添加好友成功') | |||||
# 好友发送的文字 | |||||
hash_key = f'__AI_OPS_WX__:MESSAGES:{wxid}:{to_contact_wxid}' | |||||
prompt={"role": "user", "content": [{"type": "text","text": msg_content}]} | |||||
messages_to_send=await request.app.state.gewe_service.save_session_messages_to_cache_async(hash_key, prompt) | |||||
input_wx_content_dialogue_message=[{"type": "text", "text": msg_content}] | |||||
input_message=dialogue_message(to_contact_wxid,wxid,input_wx_content_dialogue_message) | |||||
await request.app.state.gewe_service.send_message_async(input_message) | |||||
logger.info("发送对话 %s",input_message) | |||||
callback_to_user=to_contact_wxid | |||||
res=await gpt_client_async(messages_to_send,wxid,callback_to_user) | |||||
reply_content=remove_markdown_symbol(res["choices"][0]["message"]["content"]) | |||||
#保存好友信息 | |||||
await request.app.state.gewe_service.save_contacts_brief_to_cache_async(token_id,app_id, wxid,[to_contact_wxid]) | |||||
# 保存到缓存 | |||||
await request.app.state.gewe_service.save_session_messages_to_cache_async(hash_key, {"role": "assistant", "content": reply_content}) | |||||
# 发送信息 | |||||
await request.app.state.gewe_service.post_text_async(token_id,app_id, to_contact_wxid,reply_content) | |||||
# 发送到kafka | |||||
input_wx_content_dialogue_message=[{"type": "text", "text": reply_content}] | |||||
input_message=dialogue_message(wxid,to_contact_wxid,input_wx_content_dialogue_message,True) | |||||
request.app.state.kafka_service.send_message_async(input_message) | |||||
logger.info("发送对话 %s",input_message) | |||||
else: | |||||
logger.warning("添加好友失败") | |||||
except ET.ParseError as e: | |||||
logger.error(f"解析XML失败: {e}") | |||||
except Exception as e: | |||||
logger.error(f"未知错误: {e}") | |||||
return | |||||
async def handle_10002_msg(request: Request,token_id,app_id, wxid,msg_data,from_wxid, to_wxid): | |||||
''' | |||||
群聊邀请 | |||||
撤回消息 | |||||
拍一拍消息 | |||||
地理位置 | |||||
踢出群聊通知 | |||||
解散群聊通知 | |||||
发布群公告 | |||||
''' | |||||
try: | |||||
msg_content_xml=msg_data["Content"]["string"] | |||||
# 群聊邀请 | |||||
if '邀请你加入了群聊' in msg_content_xml and check_chatroom(msg_data["FromUserName"]["string"]): | |||||
chatroom_id=msg_data["FromUserName"]["string"] | |||||
ret,msg,data=await request.app.state.gewe_service.save_contract_list_async(token_id,app_id,chatroom_id,3) | |||||
logger.info(f'群聊邀请,保存到通讯录 chatroom_id {chatroom_id} {msg}') | |||||
await request.app.state.gewe_service.update_group_info_to_cache_async(token_id,app_id,wxid,chatroom_id) | |||||
await request.app.state.gewe_service.update_group_members_to_cache_async(token_id,app_id,wxid,chatroom_id) | |||||
if '移出了群聊' in msg_content_xml and 'sysmsgtemplate' in msg_content_xml : | |||||
chatroom_id=msg_data["FromUserName"]["string"] | |||||
ret,msg,data=await request.app.state.gewe_service.save_contract_list_async(token_id,app_id,chatroom_id,2) | |||||
logger.info(f'踢出群聊,移除从通讯录 chatroom_id {chatroom_id} {msg}') | |||||
await request.app.state.redis_service.delete_hash_field(f'__AI_OPS_WX__:GROUPS_INFO:{wxid}',chatroom_id) | |||||
logger.info(f'清除 chatroom_id{chatroom_id} 数据') | |||||
if '已解散该群聊' in msg_content_xml and 'sysmsgtemplate' in msg_content_xml : | |||||
chatroom_id=msg_data["FromUserName"]["string"] | |||||
ret,msg,data=await request.app.state.gewe_service.save_contract_list_async(token_id,app_id,chatroom_id,2) | |||||
logger.info(f'解散群聊,移除从通讯录 chatroom_id {chatroom_id} {msg}') | |||||
await request.app.state.redis_service.delete_hash_field(f'__AI_OPS_WX__:GROUPS_INFO:{wxid}',chatroom_id) | |||||
logger.info(f'清除 chatroom_id{chatroom_id} 数据') | |||||
print('撤回消息,拍一拍消息,地理位置') | |||||
except ET.ParseError as e: | |||||
logger.error(f"解析XML失败: {e}") | |||||
except Exception as e: | |||||
logger.error(f"未知错误: {e}") | |||||
return | |||||
@@ -0,0 +1,126 @@ | |||||
from fastapi import APIRouter,Request,HTTPException | |||||
from pydantic import BaseModel | |||||
from fastapi import APIRouter, Depends | |||||
from pydantic import BaseModel, ValidationError | |||||
from model.models import AgentConfig,validate_wxid,auth_required_time | |||||
import threading | |||||
import asyncio | |||||
import time | |||||
from typing import Dict, Tuple, Any | |||||
sns_router = APIRouter(prefix="/api/sns") | |||||
class SendTextRequest(BaseModel): | |||||
wxid: str | |||||
content:str | |||||
class SendImagesRequest(BaseModel): | |||||
wxid: str | |||||
content:str | |||||
imageUrls:list | |||||
class SendVideoRequest(BaseModel): | |||||
wxid: str | |||||
content:str | |||||
videoUrl:str | |||||
videoThumbUrl:str | |||||
async def auth_required_time(request: Request,wxid:str): | |||||
if not wxid: | |||||
return {"code": 400, "message": "wxid 不能为空"} | |||||
# 模拟获取登录信息 | |||||
k, loginfo = await request.app.state.gewe_service.get_login_info_by_wxid_async(wxid) | |||||
if not loginfo: | |||||
return {"code": 404, "message": f"{wxid} 微信信息不存在"} | |||||
login_status = loginfo.get('status', '0') | |||||
if login_status != '1': | |||||
return {"code": 401, "message": f"{wxid} 已经离线"} | |||||
creation_timestamp = int(loginfo.get('create_at', time.time())) | |||||
current_timestamp = time.time() | |||||
three_days_seconds = 3 * 24 * 60 * 60 # 三天的秒数 | |||||
diff_flag = (current_timestamp - creation_timestamp) >= three_days_seconds | |||||
if not diff_flag: | |||||
return {'code': 401, 'message': '用户创建不够三天,不能使用该功能'} | |||||
return k, loginfo | |||||
@sns_router.post("/sendtext", response_model=None) | |||||
async def send_text(request: Request, body: SendTextRequest, ): | |||||
wxid = body.wxid | |||||
content = body.content | |||||
auth = await auth_required_time(request, wxid) | |||||
if isinstance(auth, Dict): | |||||
return auth | |||||
k, loginfo = auth | |||||
token_id = loginfo.get('tokenId') | |||||
app_id = loginfo.get('appId') | |||||
ret, msg, data = await request.app.state.gewe_service.send_text_sns_async(token_id, app_id, content) | |||||
if ret != 200: | |||||
return { | |||||
'code': ret, | |||||
'message': msg | |||||
} | |||||
return data | |||||
@sns_router.post("/sendimages", response_model=None) | |||||
async def send_text(request: Request, body: SendImagesRequest, ): | |||||
wxid = body.wxid | |||||
content = body.content | |||||
image_urls=body.imageUrls | |||||
auth = await auth_required_time(request, wxid) | |||||
if isinstance(auth, Dict): | |||||
return auth | |||||
k, loginfo = auth | |||||
token_id = loginfo.get('tokenId') | |||||
app_id = loginfo.get('appId') | |||||
ret, msg, data =await request.app.state.gewe_service.upload_sns_image_async(token_id, app_id, image_urls) | |||||
if ret != 200: | |||||
return { | |||||
'code': ret, | |||||
'message': msg | |||||
} | |||||
ret, msg, data = await request.app.state.gewe_service.send_image_sns_async(token_id, app_id, content,data) | |||||
if ret != 200: | |||||
return { | |||||
'code': ret, | |||||
'message': msg | |||||
} | |||||
return data | |||||
@sns_router.post("/sendvideo", response_model=None) | |||||
async def send_video(request: Request, body: SendVideoRequest, ): | |||||
wxid = body.wxid | |||||
content = body.content | |||||
video_url=body.videoUrl | |||||
video_thumb_url=body.videoThumbUrl | |||||
auth = await auth_required_time(request, wxid) | |||||
if isinstance(auth, Dict): | |||||
return auth | |||||
k, loginfo = auth | |||||
token_id = loginfo.get('tokenId') | |||||
app_id = loginfo.get('appId') | |||||
ret, msg, data = await request.app.state.gewe_service.upload_sns_video_async(token_id, app_id, video_url,video_thumb_url) | |||||
if ret != 200: | |||||
return { | |||||
'code': ret, | |||||
'message': msg | |||||
} | |||||
ret, msg, data = await request.app.state.gewe_service.send_video_sns_async(token_id, app_id, content,data) | |||||
if ret != 200: | |||||
return { | |||||
'code': ret, | |||||
'message': msg | |||||
} | |||||
return data |
@@ -0,0 +1,233 @@ | |||||
from fastapi import FastAPI,Request | |||||
from pydantic import BaseModel | |||||
from contextlib import asynccontextmanager | |||||
from celery import Celery | |||||
# from aiokafka import AIOKafkaConsumer | |||||
import asyncio | |||||
import json | |||||
import time | |||||
import uvicorn | |||||
import logging | |||||
from logging.handlers import TimedRotatingFileHandler,RotatingFileHandler | |||||
from starlette.middleware.base import BaseHTTPMiddleware | |||||
from services.gewe_service import GeWeService # 导入 GeWeChatCom | |||||
from common.log import logger | |||||
from app.endpoints.config_endpoint import config_router | |||||
from app.endpoints.contacts_endpoint import contacts_router | |||||
from app.endpoints.groups_endpoint import groups_router | |||||
from app.endpoints.sns_endpoint import sns_router | |||||
from app.endpoints.agent_endpoint import agent_router | |||||
from app.endpoints.pipeline_endpoint import messages_router | |||||
from services.redis_service import RedisService | |||||
from services.kafka_service import KafkaService | |||||
from services.biz_service import BizService | |||||
from app.middleware import http_context | |||||
from celery.result import AsyncResult | |||||
from app.tasks import add_task,sync_contacts_task | |||||
from config import load_config | |||||
from config import conf | |||||
from common.utils import * | |||||
load_config() | |||||
# Kafka 配置 | |||||
#KAFKA_BOOTSTRAP_SERVERS = '192.168.2.121:9092' | |||||
KAFKA_BOOTSTRAP_SERVERS = conf().get("kafka_bootstrap_servers") | |||||
KAFKA_TOPIC = 'topic.ai.ops.wx' | |||||
KAFKA_GROUP_ID = 'ai-ops-wx' | |||||
# 用于存储后台任务的全局变量 | |||||
background_tasks = set() | |||||
async def kafka_consumer(): | |||||
while True: | |||||
# 这里模拟 Kafka 消费者的逻辑 | |||||
# print("Kafka consumer is running...") | |||||
await asyncio.sleep(1) | |||||
async def background_worker(redis_service:RedisService,kafka_service:KafkaService,gewe_service:GeWeService): | |||||
lock_name = "background_wxchat_thread_lock" | |||||
lock_identifier = str(time.time()) # 使用时间戳作为唯一标识 | |||||
while True: | |||||
# 尝试获取分布式锁 | |||||
if await redis_service.acquire_lock(lock_name, timeout=10): | |||||
try: | |||||
logger.info("分布式锁已成功获取") | |||||
# 启动任务 | |||||
print('启动任务') | |||||
# 启动后台任务 | |||||
await startup_sync_data_task_async(redis_service, kafka_service, gewe_service) # 确保传递了正确的参数 | |||||
print('启动任务完成') | |||||
# 保持锁的续期 | |||||
while True: | |||||
await asyncio.sleep(30) # 每30秒检查一次锁的状态 | |||||
if not await redis_service.renew_lock(lock_name, lock_identifier, timeout=60): | |||||
break # 如果无法续期锁,退出循环 | |||||
finally: | |||||
# 释放锁 | |||||
await redis_service.release_lock(lock_name, lock_identifier) | |||||
break # 任务完成后退出循环 | |||||
else: | |||||
# 如果获取锁失败,等待一段时间后重试 | |||||
logger.info("获取分布式锁失败,等待10秒后重试...") | |||||
await asyncio.sleep(10) | |||||
async def startup_sync_data_task_async(redis_service: RedisService, kafka_service: KafkaService, gewe_service: GeWeService): | |||||
try: | |||||
login_keys = [] | |||||
async for key in redis_service.client.scan_iter(match='__AI_OPS_WX__:LOGININFO:*'): # 使用 async for 遍历异步生成器 | |||||
login_keys.append(key) | |||||
for k in login_keys: | |||||
r = await redis_service.get_hash(k) | |||||
app_id = r.get("appId") | |||||
token_id = r.get("tokenId") | |||||
wxid = r.get("wxid") | |||||
status = r.get('status') | |||||
if status == '0': | |||||
continue | |||||
# 同步联系人列表 | |||||
ret, msg, contacts_list = await gewe_service.fetch_contacts_list_async(token_id, app_id) | |||||
if ret != 200: | |||||
logger.warning(f"同步联系人列表失败: {ret}-{msg}") | |||||
continue | |||||
friend_wxids = [c for c in contacts_list['friends'] if c not in ['fmessage', 'medianote', 'weixin', 'weixingongzhong']] # 可以调整截取范围 | |||||
data = await gewe_service.save_contacts_brief_to_cache_async(token_id, app_id, wxid, friend_wxids) | |||||
chatrooms = contacts_list['chatrooms'] | |||||
# 同步群列表 | |||||
logger.info(f'{wxid} 的群数量 {len(chatrooms)}') | |||||
logger.info(f'{wxid} 同步群列表') | |||||
await gewe_service.save_groups_info_to_cache_async(token_id, app_id, wxid, chatrooms) | |||||
logger.info(f'{wxid} 同步群成员') | |||||
# 同步群成员 | |||||
await gewe_service.save_groups_members_to_cache_async(token_id, app_id, wxid, chatrooms) | |||||
logger.info(f'{wxid} 好友信息推送到kafka') | |||||
# 联系人推送到kafka | |||||
k_message = wx_all_contacts(wxid, data) | |||||
await kafka_service.send_message_async(k_message) | |||||
except Exception as e: | |||||
logger.error(f"任务执行过程中发生异常: {e}") | |||||
@asynccontextmanager | |||||
async def lifespan(app: FastAPI): | |||||
#app.state.redis_helper = RedisHelper(host='192.168.2.121',password='telpo#1234', port=8090, db=3) | |||||
# 初始化 RedisHelper | |||||
redis_service = RedisService() | |||||
redis_host=conf().get("redis_host") | |||||
redis_port=conf().get("redis_port") | |||||
redis_password=conf().get("redis_password") | |||||
redis_db=conf().get("redis_db") | |||||
await redis_service.init(host=redis_host,port=redis_port, password=redis_password, db=redis_db) | |||||
app.state.redis_service = redis_service | |||||
# 初始化 KafkaService | |||||
kafka_service= KafkaService(KAFKA_BOOTSTRAP_SERVERS, KAFKA_TOPIC, KAFKA_TOPIC,KAFKA_GROUP_ID) | |||||
await kafka_service.start() | |||||
app.state.kafka_service = kafka_service | |||||
# redis_service_instance=app.state.redis_service | |||||
# 初始化 GeWeChatCom | |||||
app.state.gewe_service = await GeWeService.get_instance(app,"http://api.geweapi.com/gewe") | |||||
gewe_service=app.state.gewe_service | |||||
# # 初始化 GeWeChatCom | |||||
#app.state.gwechat_service = GeWeService(app) | |||||
# 初始化业务服务 | |||||
biz_service = BizService(app) | |||||
app.state.biz_service = biz_service | |||||
biz_service.setup_handlers() | |||||
# 在应用程序启动时启动 Kafka 消费者任务 | |||||
# try: | |||||
# yield # 应用程序运行期间 | |||||
# finally: | |||||
# # 在应用程序关闭时取消所有后台任务 | |||||
# await kafka_service.stop() | |||||
#task = asyncio.create_task(kafka_consumer()) | |||||
task=asyncio.create_task(background_worker(redis_service,kafka_service,gewe_service)) | |||||
background_tasks.add(task) | |||||
try: | |||||
yield # 应用程序运行期间 | |||||
finally: | |||||
# # 在应用程序关闭时取消所有后台任务 | |||||
task.cancel() | |||||
try: | |||||
await task | |||||
except asyncio.CancelledError: | |||||
pass | |||||
background_tasks.clear() | |||||
# 关闭 KafkaService | |||||
print('应用关闭') | |||||
await kafka_service.stop() | |||||
app = FastAPI(lifespan=lifespan) | |||||
# 配置日志:输出到文件,文件最大 10MB,保留 5 个备份文件 | |||||
# log_handler = RotatingFileHandler( | |||||
# "app.log", # 日志文件名 | |||||
# maxBytes=10 * 1024 * 1024, # 文件大小限制:10MB | |||||
# backupCount=5, # 保留 5 个备份文件 | |||||
# encoding="utf-8" # 日志文件的编码 | |||||
# ) | |||||
# # 设置日志格式 | |||||
# log_handler.setFormatter( | |||||
# logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') | |||||
# ) | |||||
# 获取根日志记录器,并设置日志级别为 INFO | |||||
# logging.basicConfig( | |||||
# level=logging.INFO, # 设置日志记录级别 | |||||
# handlers=[log_handler] # 配置文件日志处理器 | |||||
# ) | |||||
app.add_middleware(BaseHTTPMiddleware, dispatch=http_context) | |||||
app.include_router(config_router) | |||||
app.include_router(contacts_router) | |||||
app.include_router(groups_router) | |||||
app.include_router(sns_router) | |||||
app.include_router(agent_router) | |||||
app.include_router(messages_router) | |||||
@app.get("/") | |||||
async def root(): | |||||
logger.info("Root route is called") | |||||
return {"message": "Kafka consumer is running in the background"} | |||||
class AddRequest(BaseModel): | |||||
x: int | |||||
y: int | |||||
@app.post("/add") | |||||
async def add_numbers(request: AddRequest): | |||||
task = add_task.delay(request.x, request.y) | |||||
return {"task_id": task.id} | |||||
@app.get("/task/{task_id}") | |||||
async def get_task_status(task_id: str): | |||||
task_result = AsyncResult(task_id) | |||||
return { | |||||
"task_id": task_id, | |||||
"task_status": task_result.status, | |||||
"task_result": task_result.result | |||||
} | |||||
@@ -0,0 +1,180 @@ | |||||
import json | |||||
import time | |||||
from fastapi import FastAPI, Request,HTTPException | |||||
from starlette.middleware.base import BaseHTTPMiddleware | |||||
from starlette.responses import JSONResponse | |||||
from pydantic import BaseModel | |||||
from datetime import datetime | |||||
import logging | |||||
from common.log import logger | |||||
class Result(BaseModel): | |||||
code: int | |||||
message: str | |||||
status: str | |||||
class ResponseData(BaseModel): | |||||
data: dict|list|None | |||||
result: Result | |||||
timestamp: str | |||||
async def http_context(request: Request, call_next): | |||||
# 记录请求信息 | |||||
request_info = { | |||||
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), | |||||
"method": request.method, | |||||
"url": str(request.url), | |||||
"body": await request.body() if request.method in ["POST", "PUT", "PATCH"] else None, | |||||
} | |||||
logger.info(f"请求: {json.dumps(request_info, separators=(',', ':'), default=str, ensure_ascii=False)}") | |||||
# 调用下一个中间件或路由处理程序 | |||||
response = await call_next(request) | |||||
# 如果响应状态码为 200,则格式化响应为统一格式 | |||||
if response.status_code == 200: | |||||
try: | |||||
response_body = b"" | |||||
async for chunk in response.body_iterator: | |||||
response_body += chunk | |||||
response_body_str = response_body.decode("utf-8") | |||||
business_data = json.loads(response_body_str) | |||||
except Exception as e: | |||||
business_data = {"error": f"Unable to decode response body: {str(e)}"} | |||||
if "code" in business_data: | |||||
message=business_data.get("message","请求失败!") | |||||
result = ResponseData( | |||||
data=None, | |||||
result=Result(code=business_data.get("code",500), message=message, status="failed"), | |||||
timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f") | |||||
) | |||||
else: | |||||
# 构造统一格式的响应 | |||||
result = ResponseData( | |||||
data=business_data, | |||||
result=Result(code=200, message="请求成功!", status="succeed"), | |||||
timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f") | |||||
) | |||||
response_info = { | |||||
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), | |||||
"status_code": response.status_code, | |||||
"headers": dict(response.headers), | |||||
"body": result.dict(), | |||||
} | |||||
logger.info(f"响应: {json.dumps(response_info, separators=(',', ':'), default=str, ensure_ascii=False)}") | |||||
# 返回修改后的响应 | |||||
return JSONResponse(content=result.model_dump()) | |||||
else: | |||||
print(response) | |||||
print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~') | |||||
message = "请求失败!" | |||||
# 如果响应状态码不为 200,则记录响应信息 | |||||
try: | |||||
response_body = b"" | |||||
async for chunk in response.body_iterator: | |||||
response_body += chunk | |||||
response_body_str = response_body.decode("utf-8") | |||||
business_data = json.loads(response_body_str) | |||||
except Exception as e: | |||||
business_data = {"error": f"Unable to decode response body: {str(e)}"} | |||||
# 根据不同状态码定制 message 字段 | |||||
if response.status_code == 404: | |||||
message = e.detail | |||||
elif response.status_code == 400: | |||||
message = "请求参数错误" | |||||
elif response.status_code == 500: | |||||
message = "服务器内部错误" | |||||
# 你可以根据不同的状态码设置更详细的错误消息 | |||||
# 构造统一格式的响应 | |||||
result = ResponseData( | |||||
data={}, # 返回空数据 | |||||
result=Result( | |||||
code=response.status_code, | |||||
message=message, # 根据状态码返回详细信息 | |||||
status="failed" | |||||
), | |||||
timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f") | |||||
) | |||||
response_info = { | |||||
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), | |||||
"status_code": response.status_code, | |||||
"headers": dict(response.headers), | |||||
"body": result.dict(), | |||||
} | |||||
logger.info(f"响应: {json.dumps(response_info, separators=(',', ':'), default=str, ensure_ascii=False)}") | |||||
# 返回修改后的响应 | |||||
return JSONResponse(content=result.model_dump(), status_code=response.status_code) | |||||
async def http_context_2(request: Request, call_next): | |||||
# 记录请求信息 | |||||
request_body = None | |||||
if request.method in ["POST", "PUT", "PATCH"]: | |||||
try: | |||||
request_body = await request.json() # 使用 .json(),避免影响 FastAPI 解析 | |||||
except Exception: | |||||
request_body = "无法解析 JSON" | |||||
request_info = { | |||||
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), | |||||
"method": request.method, | |||||
"url": str(request.url), | |||||
"body": request_body, | |||||
} | |||||
logger.info(f"请求: {json.dumps(request_info, separators=(',', ':'), ensure_ascii=False)}") | |||||
# 继续处理请求 | |||||
response = await call_next(request) | |||||
# 如果是 422 错误,直接返回,避免 Pydantic 解析错误 | |||||
if response.status_code == 422: | |||||
return response | |||||
# 处理正常请求 | |||||
try: | |||||
response_body = b"" | |||||
async for chunk in response.body_iterator: | |||||
response_body += chunk | |||||
response_body_str = response_body.decode("utf-8") | |||||
business_data = json.loads(response_body_str) | |||||
except Exception as e: | |||||
business_data = {"error": f"无法解析响应体: {str(e)}"} | |||||
if "code" in business_data: | |||||
message = business_data.get("message", "请求失败!") | |||||
result = ResponseData( | |||||
data=None, | |||||
result=Result(code=business_data.get("code", 500), message=message, status="failed"), | |||||
timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f") | |||||
) | |||||
else: | |||||
result = ResponseData( | |||||
data=business_data, | |||||
result=Result(code=200, message="请求成功!", status="succeed"), | |||||
timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f") | |||||
) | |||||
response_info = { | |||||
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), | |||||
"status_code": response.status_code, | |||||
"headers": dict(response.headers), | |||||
"body": result.dict(), | |||||
} | |||||
logger.info(f"响应: {json.dumps(response_info, separators=(',', ':'), ensure_ascii=False)}") | |||||
return JSONResponse(content=result.model_dump()) |
@@ -0,0 +1,28 @@ | |||||
from app.celery_app import celery | |||||
from fastapi import Request,FastAPI | |||||
import time | |||||
@celery.task(name='app.tasks.add_task', bind=True, acks_late=True) | |||||
def add_task(self, x, y): | |||||
time.sleep(5) # 模拟长时间计算 | |||||
return x + y | |||||
@celery.task(name='app.tasks.mul_task', bind=True, acks_late=True) | |||||
def mul_task(self, x, y): | |||||
time.sleep(5) # 模拟长时间计算 | |||||
return x * y | |||||
# @celery.task(name='app.tasks.sync_contacts', bind=True, acks_late=True) | |||||
# async def sync_contacts_task(self,app): | |||||
# login_keys = list(await app.state.redis_service.client.scan_iter(match='__AI_OPS_WX__:LOGININFO:*')) | |||||
# return login_keys | |||||
# # for k in login_keys: | |||||
# # print(k) | |||||
@celery.task(name='app.tasks.sync_contacts', bind=True, acks_late=True) | |||||
async def sync_contacts_task(self, redis_service): | |||||
# Use the redis_service passed as an argument | |||||
login_keys = list(await redis_service.client.scan_iter(match='__AI_OPS_WX__:LOGININFO:*')) | |||||
return login_keys |
@@ -0,0 +1,43 @@ | |||||
from datetime import datetime, timedelta | |||||
class ExpiredDict(dict): | |||||
def __init__(self, expires_in_seconds): | |||||
super().__init__() | |||||
self.expires_in_seconds = expires_in_seconds | |||||
def __getitem__(self, key): | |||||
value, expiry_time = super().__getitem__(key) | |||||
if datetime.now() > expiry_time: | |||||
del self[key] | |||||
raise KeyError("expired {}".format(key)) | |||||
self.__setitem__(key, value) | |||||
return value | |||||
def __setitem__(self, key, value): | |||||
expiry_time = datetime.now() + timedelta(seconds=self.expires_in_seconds) | |||||
# print(f'{key} 缓存过期时间:{expiry_time}') | |||||
super().__setitem__(key, (value, expiry_time)) | |||||
def get(self, key, default=None): | |||||
try: | |||||
return self[key] | |||||
except KeyError: | |||||
return default | |||||
def __contains__(self, key): | |||||
try: | |||||
self[key] | |||||
return True | |||||
except KeyError: | |||||
return False | |||||
def keys(self): | |||||
keys = list(super().keys()) | |||||
return [key for key in keys if key in self] | |||||
def items(self): | |||||
return [(key, self[key]) for key in self.keys()] | |||||
def __iter__(self): | |||||
return self.keys().__iter__() |
@@ -0,0 +1,176 @@ | |||||
# import logging | |||||
# import sys | |||||
# import os | |||||
# from datetime import datetime, timedelta | |||||
# LOG_DIR = "./logs" # 日志文件目录 | |||||
# LOG_RETENTION_DAYS = 7 # 日志保留天数 | |||||
# def _remove_old_logs(log_dir, retention_days): | |||||
# """删除超过保留天数的日志文件""" | |||||
# if not os.path.exists(log_dir): | |||||
# os.makedirs(log_dir) | |||||
# now = datetime.now() | |||||
# for filename in os.listdir(log_dir): | |||||
# file_path = os.path.join(log_dir, filename) | |||||
# if os.path.isfile(file_path) and filename.startswith("run_") and filename.endswith(".log"): | |||||
# # 提取文件日期 | |||||
# try: | |||||
# log_date_str = filename[4:14] | |||||
# log_date = datetime.strptime(log_date_str, "%Y-%m-%d") | |||||
# if now - log_date > timedelta(days=retention_days): | |||||
# os.remove(file_path) | |||||
# print(f"删除旧日志: {filename}") | |||||
# except ValueError: | |||||
# continue | |||||
# def _reset_logger(log): | |||||
# """重置日志配置,移除旧的 Handler 并添加新的 Handler""" | |||||
# for handler in log.handlers: | |||||
# handler.close() | |||||
# log.removeHandler(handler) | |||||
# del handler | |||||
# log.handlers.clear() | |||||
# log.propagate = False | |||||
# # 控制台输出的日志处理器 | |||||
# console_handle = logging.StreamHandler(sys.stdout) | |||||
# console_handle.setFormatter( | |||||
# logging.Formatter( | |||||
# "[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s", | |||||
# datefmt="%Y-%m-%d %H:%M:%S", | |||||
# ) | |||||
# ) | |||||
# # 生成带有当前日期的日志文件路径 | |||||
# date_str = datetime.now().strftime("%Y-%m-%d") | |||||
# log_file_path = os.path.join(LOG_DIR, f"run_{date_str}.log") | |||||
# # 文件日志处理器 | |||||
# file_handle = logging.FileHandler(log_file_path, encoding="utf-8") | |||||
# file_handle.setFormatter( | |||||
# logging.Formatter( | |||||
# "[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s", | |||||
# datefmt="%Y-%m-%d %H:%M:%S", | |||||
# ) | |||||
# ) | |||||
# # 将处理器添加到日志 | |||||
# log.addHandler(file_handle) | |||||
# log.addHandler(console_handle) | |||||
# # 删除旧的日志文件 | |||||
# _remove_old_logs(LOG_DIR, LOG_RETENTION_DAYS) | |||||
# def setup_logging(): | |||||
# """设置日志配置""" | |||||
# log = logging.getLogger("log") | |||||
# _reset_logger(log) | |||||
# log.setLevel(logging.INFO) # 日志级别 | |||||
# return log | |||||
# def setup_logging(): | |||||
# """设置日志配置""" | |||||
# log = logging.getLogger() # 获取 Flask 默认的日志记录器 | |||||
# _reset_logger(log) | |||||
# log.setLevel(logging.INFO) # 设置日志级别为 INFO | |||||
# return log | |||||
# # 创建日志实例 | |||||
# logger = setup_logging() | |||||
# def log_exception(sender, exception, **extra): | |||||
# """记录异常日志""" | |||||
# sender.logger.debug('处理过程发生异常: %s', exception) | |||||
import logging | |||||
import sys | |||||
import os | |||||
from datetime import datetime, timedelta | |||||
# 日志文件目录 | |||||
LOG_DIR = "./logs" | |||||
# 日志保留天数 | |||||
LOG_RETENTION_DAYS = 7 | |||||
def _remove_old_logs(log_dir, retention_days): | |||||
"""删除超过保留天数的日志文件""" | |||||
if not os.path.exists(log_dir): | |||||
os.makedirs(log_dir) | |||||
now = datetime.now() | |||||
for filename in os.listdir(log_dir): | |||||
file_path = os.path.join(log_dir, filename) | |||||
if os.path.isfile(file_path) and filename.startswith("run_") and filename.endswith(".log"): | |||||
# 提取文件日期 | |||||
try: | |||||
log_date_str = filename[4:14] | |||||
log_date = datetime.strptime(log_date_str, "%Y-%m-%d") | |||||
if now - log_date > timedelta(days=retention_days): | |||||
os.remove(file_path) | |||||
print(f"删除旧日志: {filename}") | |||||
except ValueError: | |||||
continue | |||||
def _reset_logger(log): | |||||
"""重置日志配置,移除旧的 Handler 并添加新的 Handler""" | |||||
for handler in log.handlers: | |||||
handler.close() | |||||
log.removeHandler(handler) | |||||
del handler | |||||
log.handlers.clear() | |||||
log.propagate = False | |||||
# 控制台输出的日志处理器 | |||||
console_handle = logging.StreamHandler(sys.stdout) | |||||
console_handle.setFormatter( | |||||
logging.Formatter( | |||||
"[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s", | |||||
datefmt="%Y-%m-%d %H:%M:%S", | |||||
) | |||||
) | |||||
# 生成带有当前日期的日志文件路径 | |||||
date_str = datetime.now().strftime("%Y-%m-%d") | |||||
log_file_path = os.path.join(LOG_DIR, f"run_{date_str}.log") | |||||
# 文件日志处理器 | |||||
file_handle = logging.FileHandler(log_file_path, encoding="utf-8") | |||||
file_handle.setFormatter( | |||||
logging.Formatter( | |||||
"[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s", | |||||
datefmt="%Y-%m-%d %H:%M:%S", | |||||
) | |||||
) | |||||
# 将处理器添加到日志 | |||||
log.addHandler(file_handle) | |||||
log.addHandler(console_handle) | |||||
# 删除旧的日志文件 | |||||
_remove_old_logs(LOG_DIR, LOG_RETENTION_DAYS) | |||||
def setup_logging(): | |||||
"""设置日志配置""" | |||||
log = logging.getLogger() # 获取默认的日志记录器 | |||||
_reset_logger(log) | |||||
log.setLevel(logging.INFO) # 设置日志级别为 INFO | |||||
return log | |||||
# 创建日志实例 | |||||
logger = setup_logging() | |||||
def log_exception(sender, exception, **extra): | |||||
"""记录异常日志""" | |||||
logger.error(f"处理过程发生异常: {exception}", exc_info=True) | |||||
# FastAPI 日志配置示例 | |||||
def configure_fastapi_logging(): | |||||
"""配置 FastAPI 的日志记录""" | |||||
fastapi_logger = logging.getLogger("uvicorn") | |||||
_reset_logger(fastapi_logger) | |||||
fastapi_logger.setLevel(logging.INFO) | |||||
# 配置 FastAPI 日志 | |||||
configure_fastapi_logging() |
@@ -0,0 +1,5 @@ | |||||
from common.expired_dict import ExpiredDict | |||||
USER_IMAGE_CACHE = ExpiredDict(60 * 3) | |||||
USER_INTERACTIVE_CACHE=ExpiredDict(60 * 1) | |||||
USER_LOGIN_QRCODE=ExpiredDict(80) |
@@ -0,0 +1,9 @@ | |||||
def singleton(cls): | |||||
instances = {} | |||||
def get_instance(*args, **kwargs): | |||||
if cls not in instances: | |||||
instances[cls] = cls(*args, **kwargs) | |||||
return instances[cls] | |||||
return get_instance |
@@ -0,0 +1,17 @@ | |||||
import os | |||||
import pathlib | |||||
class TmpDir(object): | |||||
"""A temporary directory that is deleted when the object is destroyed.""" | |||||
tmpFilePath = pathlib.Path("./tmp/") | |||||
def __init__(self): | |||||
pathExists = os.path.exists(self.tmpFilePath) | |||||
if not pathExists: | |||||
os.makedirs(self.tmpFilePath) | |||||
def path(self): | |||||
return str(self.tmpFilePath) + "/" |
@@ -0,0 +1,409 @@ | |||||
import io | |||||
import os | |||||
import uuid | |||||
import requests | |||||
from urllib.parse import urlparse | |||||
from PIL import Image | |||||
from common.log import logger | |||||
import oss2,time,json | |||||
from urllib.parse import urlparse, unquote | |||||
from voice.ali.ali_voice import AliVoice | |||||
from voice import audio_convert | |||||
import aiohttp,aiofiles | |||||
import cv2,re | |||||
import os | |||||
import tempfile | |||||
from moviepy.editor import VideoFileClip | |||||
from datetime import datetime | |||||
def clean_json_string(json_str): | |||||
# 删除所有控制字符(非打印字符),包括换行符、回车符等 | |||||
return re.sub(r'[\x00-\x1f\x7f]', '', json_str) | |||||
def dialogue_message(wxid_from:str,wxid_to:str,wx_content:list,is_ai:bool=False): | |||||
""" | |||||
构造消息的 JSON 数据 | |||||
:param contents: list,包含多个消息内容,每个内容为字典,如: | |||||
[{"type": "text", "text": "AAAAAAA"}, | |||||
{"type": "image_url", "image_url": {"url": "https://AAAAA.jpg"}}, | |||||
{"type":"file","file_url":{"url":"https://AAAAA.pdf"}} | |||||
] | |||||
:return: JSON 字符串 | |||||
""" | |||||
# 获取当前时间戳,精确到毫秒 | |||||
current_timestamp = int(time.time() * 1000) | |||||
# 获取当前时间,格式化为 "YYYY-MM-DD HH:MM:SS" | |||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) | |||||
# 构造 JSON 数据 | |||||
data = { | |||||
"message_id": str(current_timestamp), | |||||
"topic": "topic.ai.ops.wx", | |||||
"time": current_time, | |||||
"data": { | |||||
"msg_type": "dialogue", | |||||
"is_ai":is_ai, | |||||
"content": { | |||||
"wxid_from": wxid_from, | |||||
"wxid_to": wxid_to, | |||||
"wx_content":wx_content | |||||
} | |||||
} | |||||
} | |||||
return json.dumps(data, separators=(',', ':'), ensure_ascii=False) | |||||
def kafka_base_message(msg_type:str,content: dict|list)->dict: | |||||
""" | |||||
构造消息的 JSON 数据 | |||||
:param wxid: 微信ID | |||||
:param data: 一个包含了所有联系人的数据,格式为list, | |||||
每个元素为字典,包含wxid、alias、remark、sex、city、province、country, | |||||
headimgurl、signature、skey、uin、nickname这10个字段 | |||||
:return: JSON 字符串 | |||||
""" | |||||
# 获取当前时间戳,精确到毫秒 | |||||
current_timestamp = int(time.time() * 1000) | |||||
# 获取当前时间,格式化为 "YYYY-MM-DD HH:MM:SS" | |||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) | |||||
# 构造 JSON 数据 | |||||
data = { | |||||
"message_id": str(current_timestamp), | |||||
"topic": "topic.ai.ops.wx", | |||||
"time": current_time, | |||||
"data": { | |||||
#"msg_type": "login-qrcode", | |||||
"msg_type": msg_type, | |||||
"content": content | |||||
} | |||||
} | |||||
return data | |||||
def wx_offline_message(appid:str,wxid:str)->str: | |||||
content = {"appid": appid,"wxid":wxid} | |||||
data=kafka_base_message("wx-offline",content) | |||||
return json.dumps(data, separators=(',', ':'), ensure_ascii=False) | |||||
def wx_del_contact_message(wxid:str,contact_wixd:str)->str: | |||||
content = {"wxid": wxid,"contact_wixd":contact_wixd} | |||||
data=kafka_base_message("del-contact",content) | |||||
return json.dumps(data, separators=(',', ':'), ensure_ascii=False) | |||||
def wx_mod_contact_message(wxid:str,contact_data:dict)->str: | |||||
content = {"wxid": wxid,"contact_data":contact_data} | |||||
data=kafka_base_message("mod-contact",content) | |||||
return json.dumps(data, separators=(',', ':'), ensure_ascii=False) | |||||
def wx_all_contacts(wxid:str,data:dict|list)->str: | |||||
content = {"wxid": wxid,"contacts_data":data} | |||||
data=kafka_base_message("all-contacts",content) | |||||
return json.dumps(data, separators=(',', ':'), ensure_ascii=False) | |||||
def login_qrcode_message(token_id: str,agent_tel:str,qr_code_img_base64:str,qr_code_url:list)->str: | |||||
""" | |||||
构造消息的 JSON 数据 | |||||
:param contents: list,包含多个消息内容,每个内容为字典,如: | |||||
{ | |||||
"tel":"18029274615", | |||||
"token_id":"f828cb3c-1039-489f-b9ae-7494d1778a15", | |||||
"qr_code_urls":["url1","url2","url3","url4",], | |||||
"qr_code_img_base64":"aaaaaaaaaaaaaa" | |||||
} | |||||
:return: JSON 字符串 | |||||
""" | |||||
content = { | |||||
"tel":agent_tel, | |||||
"token_id":token_id, | |||||
"qr_code_urls":qr_code_url, | |||||
"qr_code_img_base64":qr_code_img_base64 | |||||
} | |||||
data=kafka_base_message("login-qrcode",content) | |||||
return json.dumps(data, separators=(',', ':'), ensure_ascii=False) | |||||
def login_result_message(token_id: str,agent_tel:str,region_id:str,agent_token_id:str,wxid:str)->str: | |||||
content = { | |||||
"tel":agent_tel, | |||||
"token_id":token_id, | |||||
"region_id":region_id, | |||||
"agent_token_id":agent_token_id, | |||||
"wxid":wxid | |||||
} | |||||
data=kafka_base_message("login-result",content) | |||||
return json.dumps(data, separators=(',', ':'), ensure_ascii=False) | |||||
def wx_voice(text: str): | |||||
try: | |||||
# 将文本转换为语音 | |||||
reply_text_voice = AliVoice().textToVoice(text) | |||||
reply_text_voice_path = os.path.join(os.getcwd(), reply_text_voice) | |||||
# 转换为 Silk 格式 | |||||
reply_silk_path = os.path.splitext(reply_text_voice_path)[0] + ".silk" | |||||
reply_silk_during = audio_convert.any_to_sil(reply_text_voice_path, reply_silk_path) | |||||
# OSS 配置(建议将凭证存储在安全的地方) | |||||
oss_access_key_id="LTAI5tRTG6pLhTpKACJYoPR5" | |||||
oss_access_key_secret="E7dMzeeMxq4VQvLg7Tq7uKf3XWpYfN" | |||||
oss_endpoint="http://oss-cn-shanghai.aliyuncs.com" | |||||
oss_bucket_name="cow-agent" | |||||
oss_prefix="cow" | |||||
# 上传文件到 OSS | |||||
file_path = reply_silk_path | |||||
file_url = upload_oss(oss_access_key_id, oss_access_key_secret, oss_endpoint, oss_bucket_name, file_path, oss_prefix) | |||||
# 删除临时文件 | |||||
try: | |||||
os.remove(reply_text_voice_path) | |||||
except FileNotFoundError: | |||||
pass # 如果文件未找到,跳过删除 | |||||
try: | |||||
os.remove(reply_silk_path) | |||||
except FileNotFoundError: | |||||
pass # 如果文件未找到,跳过删除 | |||||
return int(reply_silk_during), file_url | |||||
except Exception as e: | |||||
print(f"发生错误:{e}") | |||||
return None, None # 发生错误时返回 None | |||||
def upload_oss( | |||||
access_key_id, | |||||
access_key_secret, | |||||
endpoint, | |||||
bucket_name, | |||||
file_source, | |||||
prefix, | |||||
expiration_days=7 | |||||
): | |||||
""" | |||||
上传文件到阿里云OSS并设置生命周期规则,同时返回文件的公共访问地址。 | |||||
:param access_key_id: 阿里云AccessKey ID | |||||
:param access_key_secret: 阿里云AccessKey Secret | |||||
:param endpoint: OSS区域对应的Endpoint | |||||
:param bucket_name: OSS中的Bucket名称 | |||||
:param file_source: 本地文件路径或HTTP链接 | |||||
:param prefix: 设置规则应用的前缀为文件所在目录 | |||||
:param expiration_days: 文件保存天数,默认7天后删除 | |||||
:return: 文件的公共访问地址 | |||||
""" | |||||
# 创建Bucket实例 | |||||
auth = oss2.Auth(access_key_id, access_key_secret) | |||||
bucket = oss2.Bucket(auth, endpoint, bucket_name) | |||||
### 1. 设置生命周期规则 ### | |||||
rule_id = f'delete_after_{expiration_days}_days' # 规则ID | |||||
# prefix = oss_file_name.split('/')[0] + '/' # 设置规则应用的前缀为文件所在目录 | |||||
# 定义生命周期规则 | |||||
rule = oss2.models.LifecycleRule(rule_id, prefix, status=oss2.models.LifecycleRule.ENABLED, | |||||
expiration=oss2.models.LifecycleExpiration(days=expiration_days)) | |||||
# 设置Bucket的生命周期 | |||||
lifecycle = oss2.models.BucketLifecycle([rule]) | |||||
bucket.put_bucket_lifecycle(lifecycle) | |||||
print(f"已设置生命周期规则:文件将在{expiration_days}天后自动删除") | |||||
### 2. 判断文件来源并上传到OSS ### | |||||
if file_source.startswith('http://') or file_source.startswith('https://'): | |||||
# HTTP 链接,先下载文件 | |||||
try: | |||||
response = requests.get(file_source, stream=True) | |||||
response.raise_for_status() | |||||
parsed_url = urlparse(file_source) | |||||
# 提取路径部分并解码 | |||||
path = unquote(parsed_url.path) | |||||
# 获取路径的最后一部分作为文件名 | |||||
filename = path.split('/')[-1] | |||||
oss_file_name=prefix+'/'+ filename | |||||
bucket.put_object(oss_file_name, response.content) | |||||
print(f"文件从 HTTP 链接上传成功:{file_source}") | |||||
except requests.exceptions.RequestException as e: | |||||
print(f"从 HTTP 链接下载文件失败: {e}") | |||||
return None | |||||
else: | |||||
# 本地文件路径 | |||||
try: | |||||
filename=os.path.basename(file_source) | |||||
oss_file_name=prefix+'/'+ filename | |||||
bucket.put_object_from_file(oss_file_name, file_source) | |||||
print(f"文件从本地路径上传成功:{file_source}") | |||||
except oss2.exceptions.OssError as e: | |||||
print(f"从本地路径上传文件失败: {e}") | |||||
return None | |||||
### 3. 构建公共访问URL ### | |||||
file_url = f"http://{bucket_name}.{endpoint.replace('http://', '')}/{oss_file_name}" | |||||
print(f"文件上传成功,公共访问地址:{file_url}") | |||||
return file_url | |||||
def download_video_and_get_thumbnail(url, thumbnail_path): | |||||
""" | |||||
从指定URL下载MP4视频,提取首帧作为缩略图,并返回缩略图路径及视频时长。 | |||||
参数: | |||||
url (str): 视频的URL地址。 | |||||
thumbnail_path (str): 缩略图的保存路径。 | |||||
返回: | |||||
tuple: (缩略图路径, 视频时长(秒)) | |||||
异常: | |||||
可能抛出requests.exceptions.RequestException,cv2.error,IOError等异常。 | |||||
""" | |||||
logger.info("处理视频开始") | |||||
# 创建临时目录以下载视频 | |||||
with tempfile.TemporaryDirectory() as tmp_dir: | |||||
# 下载视频到临时文件 | |||||
video_path = os.path.join(tmp_dir, 'temp_video.mp4') | |||||
response = requests.get(url, stream=True) | |||||
response.raise_for_status() # 确保请求成功 | |||||
with open(video_path, 'wb') as f: | |||||
for chunk in response.iter_content(chunk_size=8192): | |||||
if chunk: # 过滤掉保持连接的空白块 | |||||
f.write(chunk) | |||||
# 提取视频首帧作为缩略图 | |||||
vidcap = cv2.VideoCapture(video_path) | |||||
success, image = vidcap.read() | |||||
vidcap.release() | |||||
if not success: | |||||
raise RuntimeError("无法读取视频的首帧,请检查视频文件是否有效。") | |||||
# 确保缩略图的目录存在 | |||||
thumbnail_dir = os.path.dirname(thumbnail_path) | |||||
if thumbnail_dir: | |||||
os.makedirs(thumbnail_dir, exist_ok=True) | |||||
# 保存缩略图 | |||||
cv2.imwrite(thumbnail_path, image) | |||||
# 使用moviepy计算视频时长 | |||||
clip = VideoFileClip(video_path) | |||||
duration = clip.duration | |||||
clip.close() | |||||
logger.info("处理视频完成") | |||||
# OSS 配置(建议将凭证存储在安全的地方) | |||||
oss_access_key_id="LTAI5tRTG6pLhTpKACJYoPR5" | |||||
oss_access_key_secret="E7dMzeeMxq4VQvLg7Tq7uKf3XWpYfN" | |||||
oss_endpoint="http://oss-cn-shanghai.aliyuncs.com" | |||||
oss_bucket_name="cow-agent" | |||||
oss_prefix="cow" | |||||
# 上传文件到 OSS | |||||
file_path = thumbnail_path | |||||
file_url = upload_oss(oss_access_key_id, oss_access_key_secret, oss_endpoint, oss_bucket_name, file_path, oss_prefix) | |||||
logger.info("上传缩略图") | |||||
# 删除临时文件 | |||||
try: | |||||
os.remove(thumbnail_path) | |||||
except FileNotFoundError: | |||||
pass # 如果文件未找到,跳过删除 | |||||
return file_url, duration | |||||
def contains_url(text): | |||||
# 定义检测网址的正则表达式 | |||||
url_pattern = re.compile( | |||||
r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+' | |||||
) | |||||
# 检查字符串是否包含网址 | |||||
return bool(url_pattern.search(text)) | |||||
def get_first_char_if_digit(s): | |||||
if s and s[0].isdigit(): # 判断字符串是否非空且首字符为数字 | |||||
return int(s[0]) # 返回数字形式 | |||||
return None # 如果不是数字则返回 None | |||||
def remove_at_mention_regex(text): | |||||
# 使用正则表达式去掉“在群聊中@了你” | |||||
return re.sub(r"在群聊中@了你", "", text) | |||||
def extract_nickname(text)->str: | |||||
if "在群聊中@了你" in text: | |||||
# 如果包含 "在群聊中@了你",提取其前面的名字 | |||||
match = re.search(r"^(.*?)在群聊中@了你", text) | |||||
if match: | |||||
return match.group(1).strip() | |||||
elif ": @" in text: | |||||
# 如果包含 ": @",提取其前面的名字 | |||||
return text.split(": @")[0].strip() | |||||
return '' | |||||
def check_chatroom(userName): | |||||
pattern = r'^\d+@chatroom$' | |||||
if re.match(pattern, userName): | |||||
return True | |||||
return False | |||||
def remove_markdown_symbol(text: str): | |||||
# 移除markdown格式,目前先移除** | |||||
if not text or not isinstance(text, str): | |||||
return text | |||||
# 去除加粗、斜体等格式 | |||||
#text = re.sub(r'\*\*([^*]+)\*\*', r'\1', text) # 去除加粗 | |||||
text=re.sub(r'\*\*(.*?)\*\*', r'\1', text) | |||||
text = re.sub(r'\*([^*]+)\*', r'\1', text) # 去除斜体 | |||||
text = re.sub(r'__([^_]+)__', r'\1', text) # 去除加粗(下划线) | |||||
text = re.sub(r'_(.*?)_', r'\1', text) # 去除斜体(下划线) | |||||
# 去除行内代码块 | |||||
text = re.sub(r'`([^`]+)`', r'\1', text) | |||||
# 去除换行符\n,或者多余的空格 | |||||
#text = re.sub(r'\n+', ' ', text) | |||||
# 去除列表编号等 | |||||
#text = re.sub(r'^\d+\.\s*', '', text, flags=re.MULTILINE) | |||||
#text = re.sub('[\\\`\*\_\[\]\#\+\-\!\>]', '', text) | |||||
text = re.sub('[\\\`\*\_\[\]\#\+\!\>]', '', text) | |||||
print(text) | |||||
return text | |||||
async def save_to_local_from_url_async(url): | |||||
''' | |||||
从url保存到本地tmp目录 | |||||
''' | |||||
parsed_url = urlparse(url) | |||||
# 从 URL 提取文件名 | |||||
filename = os.path.basename(parsed_url.path) | |||||
# 拼接完整路径 | |||||
tmp_file_path = os.path.join(os.getcwd(), 'tmp', filename) | |||||
# 检查是否存在同名文件 | |||||
if os.path.exists(tmp_file_path): | |||||
logger.info(f"文件已存在,将覆盖:{tmp_file_path}") | |||||
# 异步下载文件并保存到临时目录 | |||||
async with aiohttp.ClientSession() as session: | |||||
async with session.get(url) as response: | |||||
if response.status == 200: | |||||
async with aiofiles.open(tmp_file_path, 'wb') as f: | |||||
async for chunk in response.content.iter_chunked(1024): | |||||
await f.write(chunk) | |||||
else: | |||||
logger.error(f"无法下载文件,HTTP状态码:{response.status}") | |||||
return None | |||||
return tmp_file_path |
@@ -0,0 +1,9 @@ | |||||
{ | |||||
"debug": false, | |||||
"redis_host":"192.168.2.121", | |||||
"redis_port":8090, | |||||
"redis_password":"telpo#1234", | |||||
"redis_db":3, | |||||
"kafka_bootstrap_servers":"192.168.2.121:9092", | |||||
"aiops_api":"https://id.ssjlai.com/aiopsadmin" | |||||
} |
@@ -0,0 +1,9 @@ | |||||
{ | |||||
"debug": false, | |||||
"redis_host":"172.19.42.40", | |||||
"redis_port":8090, | |||||
"redis_password":"telpo#1234", | |||||
"redis_db":3, | |||||
"kafka_bootstrap_servers":"172.19.42.40:9092,172.19.42.41:9092,172.19.42.48:9092", | |||||
"aiops_api":"https://ai.ssjlai.com/aiopsadmin" | |||||
} |
@@ -0,0 +1,37 @@ | |||||
{ | |||||
"channel_type": "wx", | |||||
"model": "", | |||||
"open_ai_api_key": "YOUR API KEY", | |||||
"claude_api_key": "YOUR API KEY", | |||||
"text_to_image": "dall-e-2", | |||||
"voice_to_text": "openai", | |||||
"text_to_voice": "openai", | |||||
"proxy": "", | |||||
"hot_reload": false, | |||||
"single_chat_prefix": [ | |||||
"bot", | |||||
"@bot" | |||||
], | |||||
"single_chat_reply_prefix": "[bot] ", | |||||
"group_chat_prefix": [ | |||||
"@bot" | |||||
], | |||||
"group_name_white_list": [ | |||||
"ChatGPT测试群", | |||||
"ChatGPT测试群2" | |||||
], | |||||
"image_create_prefix": [ | |||||
"画" | |||||
], | |||||
"speech_recognition": true, | |||||
"group_speech_recognition": false, | |||||
"voice_reply_voice": false, | |||||
"conversation_max_tokens": 2500, | |||||
"expires_in_seconds": 3600, | |||||
"character_desc": "你是基于大语言模型的AI智能助手,旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", | |||||
"temperature": 0.7, | |||||
"subscribe_msg": "感谢您的关注!\n这里是AI智能助手,可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。", | |||||
"use_linkai": false, | |||||
"linkai_api_key": "", | |||||
"linkai_app_code": "" | |||||
} |
@@ -0,0 +1,9 @@ | |||||
{ | |||||
"debug": false, | |||||
"redis_host":"47.116.142.20", | |||||
"redis_port":8090, | |||||
"redis_password":"telpo#1234", | |||||
"redis_db":3, | |||||
"kafka_bootstrap_servers":"172.19.42.53:9092", | |||||
"aiops_api":"https://id.ssjlai.com/aiopsadmin" | |||||
} |
@@ -0,0 +1,9 @@ | |||||
{ | |||||
"debug": false, | |||||
"redis_host":"192.168.2.121", | |||||
"redis_port":8090, | |||||
"redis_password":"telpo#1234", | |||||
"redis_db":3, | |||||
"kafka_bootstrap_servers":"192.168.2.121:9092", | |||||
"aiops_api":"https://id.ssjlai.com/aiopsadmin" | |||||
} |
@@ -0,0 +1,179 @@ | |||||
import json | |||||
import logging | |||||
import os | |||||
import pickle | |||||
import copy | |||||
from common.log import logger | |||||
# from common.log import logger | |||||
# 示例配置文件 | |||||
DEBUG = True | |||||
available_setting = { | |||||
"qwen_access_key_id": "", | |||||
"qwen_access_key_secret": "", | |||||
"debug": False, | |||||
#redis 配置 | |||||
"redis_host":"", | |||||
"redis_port":0, | |||||
"redis_password":"", | |||||
"redis_db":0, | |||||
# kafka配置 | |||||
"kafka_bootstrap_servers":"", | |||||
# aiops平台 | |||||
"aiops_api":"" | |||||
} | |||||
class Config(dict): | |||||
def __init__(self, d=None): | |||||
super().__init__() | |||||
if d is None: | |||||
d = {} | |||||
for k, v in d.items(): | |||||
self[k] = v | |||||
# user_datas: 用户数据,key为用户名,value为用户数据,也是dict | |||||
self.user_datas = {} | |||||
def __getitem__(self, key): | |||||
if key not in available_setting: | |||||
raise Exception("key {} not in available_setting".format(key)) | |||||
return super().__getitem__(key) | |||||
def __setitem__(self, key, value): | |||||
if key not in available_setting: | |||||
raise Exception("key {} not in available_setting".format(key)) | |||||
return super().__setitem__(key, value) | |||||
def get(self, key, default=None): | |||||
try: | |||||
return self[key] | |||||
except KeyError as e: | |||||
return default | |||||
except Exception as e: | |||||
raise e | |||||
config = Config() | |||||
def drag_sensitive(config): | |||||
try: | |||||
if isinstance(config, str): | |||||
conf_dict: dict = json.loads(config) | |||||
conf_dict_copy = copy.deepcopy(conf_dict) | |||||
for key in conf_dict_copy: | |||||
if "key" in key or "secret" in key: | |||||
if isinstance(conf_dict_copy[key], str): | |||||
conf_dict_copy[key] = conf_dict_copy[key][0:3] + "*" * 5 + conf_dict_copy[key][-3:] | |||||
return json.dumps(conf_dict_copy, indent=4) | |||||
elif isinstance(config, dict): | |||||
config_copy = copy.deepcopy(config) | |||||
for key in config: | |||||
if "key" in key or "secret" in key: | |||||
if isinstance(config_copy[key], str): | |||||
config_copy[key] = config_copy[key][0:3] + "*" * 5 + config_copy[key][-3:] | |||||
return config_copy | |||||
except Exception as e: | |||||
logger.exception(e) | |||||
return config | |||||
return config | |||||
def load_config(): | |||||
global config | |||||
# config_path = "./config.json" | |||||
# if not os.path.exists(config_path): | |||||
# logger.info("配置文件不存在,将使用config-template.json模板") | |||||
# config_path = "./config-template.json" | |||||
# 默认加载 config.json 或者 config-template.json | |||||
environment = os.environ.get('environment', 'default') # 默认是生产环境 | |||||
logger.info(f"当前环境: {environment}") | |||||
if environment == "test": | |||||
config_path = "./config-test.json" | |||||
elif environment == "production": | |||||
config_path = "./config-production.json" | |||||
elif environment == "dev": | |||||
config_path = "./config-dev.json" | |||||
elif environment == "default": | |||||
config_path = "./config.json" | |||||
else: | |||||
logger.error("无效的环境配置,使用默认的 config-template.json") | |||||
config_path = "./config-template.json" | |||||
# 加载配置文件 | |||||
if not os.path.exists(config_path): | |||||
logger.info(f"配置文件 {config_path} 不存在,将使用 config-template.json 模板") | |||||
config_path = "./config-template.json" | |||||
config_str = read_file(config_path) | |||||
logger.debug("[INIT] config str: {}".format(drag_sensitive(config_str))) | |||||
# 将json字符串反序列化为dict类型 | |||||
config = Config(json.loads(config_str)) | |||||
# override config with environment variables. | |||||
# Some online deployment platforms (e.g. Railway) deploy project from github directly. So you shouldn't put your secrets like api key in a config file, instead use environment variables to override the default config. | |||||
for name, value in os.environ.items(): | |||||
name = name.lower() | |||||
if name in available_setting: | |||||
logger.info("[INIT] override config by environ args: {}={}".format(name, value)) | |||||
try: | |||||
config[name] = eval(value) | |||||
except: | |||||
if value == "false": | |||||
config[name] = False | |||||
elif value == "true": | |||||
config[name] = True | |||||
else: | |||||
config[name] = value | |||||
if config.get("debug", False): | |||||
logger.setLevel(logging.DEBUG) | |||||
logger.debug("[INIT] set log level to DEBUG") | |||||
logger.info("[INIT] load config: {}".format(drag_sensitive(config))) | |||||
def get_root(): | |||||
return os.path.dirname(os.path.abspath(__file__)) | |||||
def read_file(path): | |||||
with open(path, mode="r", encoding="utf-8") as f: | |||||
return f.read() | |||||
def conf(): | |||||
return config | |||||
# # global plugin config | |||||
# plugin_config = {} | |||||
# def write_plugin_config(pconf: dict): | |||||
# """ | |||||
# 写入插件全局配置 | |||||
# :param pconf: 全量插件配置 | |||||
# """ | |||||
# global plugin_config | |||||
# for k in pconf: | |||||
# plugin_config[k.lower()] = pconf[k] | |||||
# def pconf(plugin_name: str) -> dict: | |||||
# """ | |||||
# 根据插件名称获取配置 | |||||
# :param plugin_name: 插件名称 | |||||
# :return: 该插件的配置项 | |||||
# """ | |||||
# return plugin_config.get(plugin_name.lower()) | |||||
# # 全局配置,用于存放全局生效的状态 | |||||
# global_config = {"admin_users": []} |
@@ -0,0 +1,42 @@ | |||||
FROM python:3.10-slim-bullseye | |||||
LABEL maintainer="foo@bar.com" | |||||
ARG TZ='Asia/Shanghai' | |||||
# RUN echo /etc/apt/sources.list | |||||
# RUN sed -i 's/deb.debian.org/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list | |||||
RUN sed -i 's/deb.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list | |||||
# Set the timezone and configure tzdata | |||||
RUN apt-get update \ | |||||
&& apt-get install -y --no-install-recommends tzdata \ | |||||
&& ln -sf /usr/share/zoneinfo/$TZ /etc/localtime \ | |||||
&& dpkg-reconfigure --frontend noninteractive tzdata \ | |||||
&& apt-get clean | |||||
ENV BUILD_PREFIX=/app | |||||
ADD . ${BUILD_PREFIX} | |||||
RUN apt-get update \ | |||||
&&apt-get install -y --no-install-recommends bash ffmpeg espeak libavcodec-extra\ | |||||
&& cd ${BUILD_PREFIX} \ | |||||
&& cp config-template.json config.json \ | |||||
&& /usr/local/bin/python -m pip install --no-cache --upgrade pip \ | |||||
&& pip install --no-cache -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ | |||||
WORKDIR ${BUILD_PREFIX} | |||||
ADD docker/entrypoint.sh /entrypoint.sh | |||||
RUN chmod +x /entrypoint.sh \ | |||||
&& mkdir -p /home/noroot \ | |||||
&& groupadd -r noroot \ | |||||
&& useradd -r -g noroot -s /bin/bash -d /home/noroot noroot \ | |||||
&& chown -R noroot:noroot /home/noroot ${BUILD_PREFIX} /usr/local/lib | |||||
USER noroot | |||||
ENTRYPOINT ["/entrypoint.sh"] |
@@ -0,0 +1,8 @@ | |||||
#!/bin/bash | |||||
unset KUBECONFIG | |||||
cd .. && docker build -f docker/Dockerfile.latest \ | |||||
-t 139.224.254.18:5000/ssjl/ai-ops-wechat . | |||||
docker tag 139.224.254.18:5000/ssjl/ai-ops-wechat 139.224.254.18:5000/ssjl/ai-ops-wechat:$(date +%y%m%d) |
@@ -0,0 +1,45 @@ | |||||
#!/bin/bash | |||||
set -e | |||||
# build prefix | |||||
AI_OPS_WECHAT_PREFIX=${AI_OPS_WECHAT_PREFIX:-""} | |||||
# path to config.json | |||||
AI_OPS_WECHAT_CONFIG_PATH=${AI_OPS_WECHAT_CONFIG_PATH:-""} | |||||
# execution command line | |||||
AI_OPS_WECHAT_EXEC=${AI_OPS_WECHAT_EXEC:-""} | |||||
# Determine the environment and set the config file accordingly | |||||
if [ "$environment" == "test" ]; then | |||||
AI_OPS_WECHAT_CONFIG_PATH=${AI_OPS_WECHAT_CONFIG_PATH:-$AI_OPS_WECHAT_PREFIX/config-test.json} | |||||
elif [ "$environment" == "production" ]; then | |||||
AI_OPS_WECHAT_CONFIG_PATH=${AI_OPS_WECHAT_CONFIG_PATH:-$AI_OPS_WECHAT_PREFIX/config-production.json} | |||||
elif [ "$environment" == "dev" ]; then | |||||
AI_OPS_WECHAT_CONFIG_PATH=${AI_OPS_WECHAT_CONFIG_PATH:-$AI_OPS_WECHAT_PREFIX/config-dev.json} | |||||
else | |||||
echo "Invalid environment specified. Please set environment to 'test' or 'prod' or 'dev'." | |||||
exit 1 | |||||
fi | |||||
# AI_OPS_WECHAT_PREFIX is empty, use /app | |||||
if [ "$AI_OPS_WECHAT_PREFIX" == "" ]; then | |||||
AI_OPS_WECHAT_PREFIX=/app | |||||
fi | |||||
# AI_OPS_WECHAT_EXEC is empty, use ‘python app.py’ | |||||
if [ "$AI_OPS_WECHAT_EXEC" == "" ]; then | |||||
AI_OPS_WECHAT_EXEC="python run.py" | |||||
fi | |||||
# go to prefix dir | |||||
cd $AI_OPS_WECHAT_PREFIX | |||||
# # execute | |||||
# $AI_OPS_WECHAT_EXEC | |||||
if [ "$environment" == "default" ]; then | |||||
$AI_OPS_WECHAT_EXEC | |||||
else | |||||
#uvicorn app.main:app --host 0.0.0.0 --port 5000 | |||||
#$AI_OPS_WECHAT_EXEC | |||||
$AI_OPS_WECHAT_EXEC | |||||
fi |
@@ -0,0 +1,171 @@ | |||||
from pydantic import BaseModel, ValidationError | |||||
from dataclasses import dataclass, asdict | |||||
from typing import List | |||||
from enum import Enum, unique | |||||
from fastapi import HTTPException | |||||
from functools import wraps | |||||
from fastapi import Request | |||||
import time | |||||
@dataclass | |||||
class AgentConfig(BaseModel): | |||||
chatroomIdWhiteList: List[str] = [] | |||||
agentTokenId: str | |||||
agentEnabled: bool | |||||
addContactsFromChatroomIdWhiteList: List[str] = [] | |||||
chatWaitingMsgEnabled: bool | |||||
@dataclass | |||||
class AddGroupContactsHistory(BaseModel): | |||||
chatroomId:str | |||||
wxid:str | |||||
contactWixd:str | |||||
addTime:int | |||||
@unique | |||||
class OperationType(Enum): | |||||
ADD_FRIEND = 2 | |||||
ACCEPT_FRIEND = 3 | |||||
REJECT_FRIEND = 4 | |||||
def validate_wxid(func): | |||||
@wraps(func) | |||||
async def wrapper(request: Request, *args, **kwargs): | |||||
# 从 kwargs 中获取 wxid,如果不存在,则从请求体中获取 | |||||
wxid = kwargs.get("wxid") | |||||
if wxid is None: | |||||
# 异步获取请求体 | |||||
body = await request.json() | |||||
wxid = body.get("wxid") | |||||
# 如果 wxid 仍然为空,返回错误 | |||||
if not wxid: | |||||
return {"code": 400, "message": "wxid 不能为空"} | |||||
# 验证 wxid 是否存在 | |||||
k, loginfo = await request.app.state.gewe_service.get_login_info_by_wxid_async(wxid) | |||||
if not k: | |||||
return {"code": 404, "message": f"{wxid} 没有对应的登录信息"} | |||||
login_status=loginfo.get('status','0') | |||||
if login_status != '1': | |||||
return {"code": 401, "message": f"{wxid} 已经离线"} | |||||
# 将 k 和 loginfo 注入到路由函数的参数中 | |||||
# kwargs["loginfo_key"] = k | |||||
# kwargs["loginfo"] = loginfo | |||||
# 如果验证通过,继续执行原始函数 | |||||
return await func(request, *args, **kwargs) | |||||
return wrapper | |||||
def auth_required_time(f): | |||||
@wraps(f) | |||||
async def decorated_function(request: Request, *args, **kwargs): | |||||
try: | |||||
body = await request.json() | |||||
wxid = body.get("wxid") | |||||
if not wxid: | |||||
return {"code": 400, "message": "wxid 不能为空"} | |||||
# 模拟获取登录信息 | |||||
loginfo = {"status": "1", "create_at": time.time() - 2 * 24 * 60 * 60, "tokenId": "token123", "appId": "app123"} | |||||
login_status = loginfo.get('status', '0') | |||||
if login_status != '1': | |||||
return {"code": 401, "message": f"{wxid} 已经离线"} | |||||
creation_timestamp = int(loginfo.get('create_at', time.time())) | |||||
current_timestamp = time.time() | |||||
three_days_seconds = 3 * 24 * 60 * 60 # 三天的秒数 | |||||
diff_flag = (current_timestamp - creation_timestamp) >= three_days_seconds | |||||
if not diff_flag: | |||||
return {'code': 401, 'message': '用户创建不够三天,不能使用该功能'} | |||||
kwargs['token_id'] = loginfo.get('tokenId') | |||||
kwargs['app_id'] = loginfo.get('appId') | |||||
return await f(*args, **kwargs) | |||||
except ValidationError as e: | |||||
raise HTTPException(status_code=422, detail=str(e)) | |||||
return decorated_function | |||||
# def auth_required_time(f): | |||||
# @wraps(f) | |||||
# async def decorated_function(request: Request, *args, **kwargs): | |||||
# try: | |||||
# body = await request.json() | |||||
# print("Received body:", body) # 打印请求体 | |||||
# wxid = body.get("wxid") | |||||
# if not wxid: | |||||
# return {"code": 400, "message": "wxid 不能为空"} | |||||
# # 模拟获取登录信息 | |||||
# loginfo = {"status": "1", "create_at": time.time() - 2 * 24 * 60 * 60, "tokenId": "token123", "appId": "app123"} | |||||
# login_status = loginfo.get('status', '0') | |||||
# if login_status != '1': | |||||
# return {"code": 401, "message": f"{wxid} 已经离线"} | |||||
# creation_timestamp = int(loginfo.get('create_at', time.time())) | |||||
# current_timestamp = time.time() | |||||
# three_days_seconds = 3 * 24 * 60 * 60 # 三天的秒数 | |||||
# diff_flag = (current_timestamp - creation_timestamp) >= three_days_seconds | |||||
# if not diff_flag: | |||||
# return {'code': 401, 'message': '用户创建不够三天,不能使用该功能'} | |||||
# kwargs['token_id'] = loginfo.get('tokenId') | |||||
# kwargs['app_id'] = loginfo.get('appId') | |||||
# return await f(*args, **kwargs) | |||||
# except Exception as e: | |||||
# raise HTTPException(status_code=422, detail=f"请求体解析失败: {str(e)}") | |||||
# return decorated_function | |||||
from functools import wraps | |||||
import time | |||||
from fastapi import Request, HTTPException | |||||
def auth_required_time(f): | |||||
@wraps(f) | |||||
async def decorated_function(request: Request, *args, **kwargs): | |||||
try: | |||||
# 解析 JSON 只调用一次 | |||||
body = await request.json() | |||||
wxid = body.get("wxid") | |||||
if not wxid: | |||||
raise HTTPException(status_code=400, detail="wxid 不能为空") | |||||
# 调用异步方法获取登录信息 | |||||
k, loginfo = await request.app.state.gewe_service.get_login_info_by_wxid_async(wxid) | |||||
if not k: | |||||
raise HTTPException(status_code=404, detail=f"{wxid} 没有对应的登录信息") | |||||
login_status = loginfo.get('status', '0') | |||||
if login_status != '1': | |||||
raise HTTPException(status_code=401, detail=f"{wxid} 已经离线") | |||||
# 计算创建时间差 | |||||
creation_timestamp = int(loginfo.get('create_at', time.time())) | |||||
current_timestamp = time.time() | |||||
three_days_seconds = 3 * 24 * 60 * 60 # 三天的秒数 | |||||
if (current_timestamp - creation_timestamp) < three_days_seconds: | |||||
raise HTTPException(status_code=401, detail="用户创建不够三天,不能使用该功能") | |||||
# 注入 token_id 和 app_id | |||||
kwargs['token_id'] = loginfo.get('tokenId') | |||||
kwargs['app_id'] = loginfo.get('appId') | |||||
# 需要 `await` 调用被装饰的异步函数 | |||||
return await f(request, *args, **kwargs) | |||||
except HTTPException as e: | |||||
return {"code": e.status_code, "message": e.detail} | |||||
return decorated_function |
@@ -0,0 +1,58 @@ | |||||
#voice | |||||
pydub>=0.25.1 # need ffmpeg | |||||
SpeechRecognition # google speech to text | |||||
gTTS>=2.3.1 # google text to speech | |||||
pyttsx3>=2.90 # pytsx text to speech | |||||
baidu_aip>=4.16.10 # baidu voice | |||||
azure-cognitiveservices-speech # azure voice | |||||
edge-tts # edge-tts | |||||
numpy<=1.24.2 | |||||
langid # language detect | |||||
elevenlabs==1.0.3 # elevenlabs TTS | |||||
tiktoken>=0.3.2 # openai calculate token | |||||
openai==0.27.8 | |||||
HTMLParser>=0.0.2 | |||||
PyQRCode>=1.2.1 | |||||
qrcode>=7.4.2 | |||||
requests>=2.28.2 | |||||
chardet>=5.1.0 | |||||
Pillow | |||||
pre-commit | |||||
web.py | |||||
linkai>=0.0.6.0 | |||||
pypng | |||||
pypinyin | |||||
redis | |||||
flask | |||||
flask_restful | |||||
confluent_kafka | |||||
av | |||||
#pilk | |||||
# silk-python | |||||
# pysilk | |||||
pysilk-mod | |||||
#pip3 install pysilk-mod | |||||
oss2 | |||||
gunicorn | |||||
opencv-python | |||||
moviepy | |||||
fastapi | |||||
uvicorn | |||||
celery | |||||
pydantic | |||||
aioredis>=2.0.0 | |||||
requests | |||||
aiokafka | |||||
aiofiles |
@@ -0,0 +1,27 @@ | |||||
import subprocess | |||||
import sys | |||||
import os | |||||
def start_fastapi(): | |||||
environment = os.environ.get('environment', 'default') | |||||
if environment == 'default': | |||||
process = subprocess.Popen(["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80"]) | |||||
else: | |||||
process = subprocess.Popen(["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "5000"]) | |||||
return process | |||||
def start_celery(): | |||||
if sys.platform == "win32": | |||||
process = subprocess.Popen(["celery", "-A", "app.celery_app", "worker", "--loglevel=info", "-P", "solo"]) | |||||
else: | |||||
process = subprocess.Popen(["celery", "-A", "app.celery_app", "worker", "--loglevel=info"]) | |||||
return process | |||||
if __name__ == "__main__": | |||||
# 启动 FastAPI 和 Celery | |||||
fastapi_process = start_fastapi() | |||||
# celery_process = start_celery() # 如果需要启动 Celery,取消注释 | |||||
# 等待子进程完成 | |||||
fastapi_process.wait() | |||||
# celery_process.wait() # 如果需要等待 Celery,取消注释 |
@@ -0,0 +1,263 @@ | |||||
import aiohttp | |||||
import asyncio | |||||
import json | |||||
import base64 | |||||
import io | |||||
import json | |||||
import os | |||||
import threading | |||||
import time | |||||
import uuid,random | |||||
from fastapi import FastAPI, Depends | |||||
from common.log import logger | |||||
from common.singleton import singleton | |||||
from services.kafka_service import KafkaService | |||||
from fastapi import Request | |||||
from common.utils import * | |||||
@singleton | |||||
class BizService(): | |||||
def __init__(self,app:FastAPI): | |||||
if not hasattr(self, 'initialized'): | |||||
#self.kafka_service =kafka_service # 获取 KafkaService 单例 | |||||
self.kafka_service =app.state.kafka_service | |||||
self.wxchat=app.state.gewe_service | |||||
self.redis_service=app.state.redis_service | |||||
self.initialized = True | |||||
def setup_handlers(self): | |||||
"""设置默认的消息处理器""" | |||||
# 这里可以添加业务逻辑 | |||||
# 注册默认处理器 | |||||
self.kafka_service.add_handler( | |||||
self.kafka_service.consumer_topic, | |||||
self.ops_messages_process_handler | |||||
) | |||||
async def ops_messages_process_handler(self, message: str): | |||||
"""消息处理器""" | |||||
#print(f"BizService handling message: {message}") | |||||
try: | |||||
msg_content = message | |||||
cleaned_content = clean_json_string(msg_content) | |||||
content = json.loads(cleaned_content) | |||||
data = content.get("data", {}) | |||||
msg_type_data = data.get("msg_type", None) | |||||
content_data = data.get("content", {}) | |||||
if msg_type_data=="login": | |||||
await self.login_handler_async(content_data) | |||||
elif msg_type_data == 'group-sending': | |||||
print(f'处理消息类型group-sending') | |||||
await self.group_sending_handler_async(content_data) | |||||
elif msg_type_data == 'login_wx_captch_code': | |||||
pass | |||||
else: | |||||
print(f'kakfa 未处理息类型 {msg_type_data}') | |||||
except Exception as e: | |||||
print(f"处理消息时发生错误: {e}, 消息内容: {message}") | |||||
async def login_handler_async(self, content_data: dict): | |||||
tel=content_data.get('tel', '18733438393') | |||||
token_id=content_data.get('token_id', 'c50b7d57-2efa-4a53-8c11-104a06d1e1fa') | |||||
region_id=content_data.get('region_id', '440000') | |||||
agent_token_id=content_data.get('agent_token_id', 'sk-fAOIdANeGXjWKW5mFybnsNZZGYU2lFLmqVY9rVFaFmjiOaWt3tcWMi') | |||||
loginfo= await self.wxchat.get_login_info_from_cache_async(tel,token_id,region_id,agent_token_id) | |||||
print(loginfo) | |||||
status=loginfo.get('status','0') | |||||
if status=='1': | |||||
logger.info(f'手机号{tel},wx_token{token_id} 已经微信登录,终止登录流程') | |||||
return | |||||
async def group_sending_handler_async(self,content_data: dict): | |||||
agent_tel=content_data.get('agent_tel', '18733438393') | |||||
hash_key = f"__AI_OPS_WX__:LOGININFO:{agent_tel}" | |||||
logininfo = await self.redis_service.get_hash(hash_key) | |||||
if not logininfo: | |||||
logger.warning(f"未找到 {agent_tel} 的登录信息") | |||||
return | |||||
token_id = logininfo.get('tokenId') | |||||
app_id = logininfo.get('appId') | |||||
agent_wxid = logininfo.get('wxid') | |||||
# 获取联系人列表并计算交集 | |||||
hash_key = f"__AI_OPS_WX__:CONTACTS_BRIEF:{agent_wxid}" | |||||
cache_friend_wxids_str=await self.redis_service.get_hash_field(hash_key,"data") | |||||
cache_friend_wxids_list=json.loads(cache_friend_wxids_str) if cache_friend_wxids_str else [] | |||||
cache_friend_wxids=[f["userName"] for f in cache_friend_wxids_list] | |||||
# 获取群交集 | |||||
hash_key = f"__AI_OPS_WX__:GROUPS_INFO:{agent_wxid}" | |||||
cache_chatrooms = await self.redis_service.get_hash(hash_key) | |||||
cache_chatroom_ids=cache_chatrooms.keys() | |||||
wxid_contact_list_content_data = [c['wxid'] for c in content_data.get("contact_list", [])] | |||||
intersection_friend_wxids = list(set(cache_friend_wxids) & set(wxid_contact_list_content_data)) | |||||
intersection_chatroom_ids = list(set(cache_chatroom_ids) & set(wxid_contact_list_content_data)) | |||||
intersection_wxids=intersection_friend_wxids+intersection_chatroom_ids | |||||
# 发送消息 | |||||
wx_content_list = content_data.get("wx_content", []) | |||||
self.wxchat.forward_video_aeskey = '' | |||||
self.wxchat.forward_video_cdnvideourl = '' | |||||
self.wxchat.forward_video_length = 0 | |||||
for intersection_wxid in intersection_wxids: | |||||
for wx_content in wx_content_list: | |||||
if wx_content["type"] == "text": | |||||
await self.send_text_message_async(token_id, app_id, agent_wxid, [intersection_wxid], wx_content["text"]) | |||||
elif wx_content["type"] == "image_url": | |||||
await self.send_image_messagae_sync(token_id, app_id, agent_wxid, [intersection_wxid], wx_content.get("image_url", {}).get("url")) | |||||
elif wx_content["type"] == "tts": | |||||
await self.send_tts_message(token_id, app_id, agent_wxid, [intersection_wxid], wx_content["text"]) | |||||
elif wx_content["type"] == "file": | |||||
await self.send_file_message(token_id, app_id, agent_wxid, [intersection_wxid], wx_content.get("file_url", {}).get("url")) | |||||
async def send_text_message_async(self, token_id, app_id, agent_wxid, intersection_wxids, text): | |||||
for t in intersection_wxids: | |||||
# 发送文本消息 | |||||
ret,ret_msg,res = await self.wxchat.post_text(token_id, app_id, t, text) | |||||
logger.info(f'{agent_wxid} 向 {t} 发送文字【{text}】') | |||||
# 构造对话消息并发送到 Kafka | |||||
input_wx_content_dialogue_message = [{"type": "text", "text": text}] | |||||
input_message = dialogue_message(agent_wxid, t, input_wx_content_dialogue_message) | |||||
await self.kafka_service.send_message_async(input_message) | |||||
logger.info("发送对话 %s", input_message) | |||||
# 等待随机时间 | |||||
await asyncio.sleep(random.uniform(5, 15)) | |||||
async def send_image_messagae_sync(self,token_id, app_id, agent_wxid, intersection_wxids, image_url): | |||||
aeskey, cdnthumburl, cdnthumblength, cdnthumbheight, cdnthumbwidth, length, md5 = "", "", 0, 0, 0, 0, "" | |||||
for t in intersection_wxids: | |||||
if t == intersection_wxids[0]: | |||||
# 发送图片 | |||||
ret,ret_msg,res = await self.wxchat.post_image(token_id, app_id, t, image_url) | |||||
if ret==200: | |||||
aeskey = res["aesKey"] | |||||
cdnthumburl = res["fileId"] | |||||
cdnthumblength = res["cdnThumbLength"] | |||||
cdnthumbheight = res["height"] | |||||
cdnthumbwidth = res["width"] | |||||
length = res["length"] | |||||
md5 = res["md5"] | |||||
logger.info(f'{agent_wxid} 向 {t} 发送图片【{image_url}】{ret_msg}') | |||||
else: | |||||
logger.warning(f'{agent_wxid} 向 {t} 发送图片【{image_url}】{ret_msg}') | |||||
else: | |||||
if aeskey !="": | |||||
# 转发图片 | |||||
ret,ret_msg,res = await self.wxchat.forward_image(token_id, app_id, t, aeskey, cdnthumburl, cdnthumblength, cdnthumbheight, cdnthumbwidth, length, md5) | |||||
logger.info(f'{agent_wxid} 向 {t} 转发图片【{image_url}】{ret_msg}') | |||||
else: | |||||
# 发送图片 | |||||
ret,ret_msg,res = await self.wxchat.post_image(token_id, app_id, t, image_url) | |||||
if ret==200: | |||||
aeskey = res["aesKey"] | |||||
cdnthumburl = res["fileId"] | |||||
cdnthumblength = res["cdnThumbLength"] | |||||
cdnthumbheight = res["height"] | |||||
cdnthumbwidth = res["width"] | |||||
length = res["length"] | |||||
md5 = res["md5"] | |||||
logger.info(f'{agent_wxid} 向 {t} 发送图片【{image_url}】{ret_msg}') | |||||
else: | |||||
logger.warning(f'{agent_wxid} 向 {t} 发送图片【{image_url}】{ret_msg}') | |||||
# 构造对话消息并发送到 Kafka | |||||
wx_content_dialogue_message = [{"type": "image_url", "image_url": {"url": image_url}}] | |||||
input_message = dialogue_message(agent_wxid, t, wx_content_dialogue_message) | |||||
await self.kafka_service.send_message_async(input_message) | |||||
logger.info("发送对话 %s", input_message) | |||||
# 等待随机时间 | |||||
await asyncio.sleep(random.uniform(5, 15)) | |||||
async def send_tts_message(self, token_id, app_id, agent_wxid, intersection_wxids, text): | |||||
voice_during,voice_url=wx_voice(text) | |||||
for t in intersection_wxids: | |||||
# 发送送语音消息 | |||||
if voice_url: | |||||
ret,ret_msg,res = await self.wxchat.post_voice(token_id, app_id, t, voice_url,voice_during) | |||||
if ret==200: | |||||
logger.info(f'{agent_wxid} 向 {t} 发送语音文本【{text}】{ret_msg}') | |||||
# 构造对话消息并发送到 Kafka | |||||
input_wx_content_dialogue_message = [{"type": "text", "text": text}] | |||||
input_message = dialogue_message(agent_wxid, t, input_wx_content_dialogue_message) | |||||
await self.kafka_service.send_message_async(input_message) | |||||
logger.info("发送对话 %s", input_message) | |||||
else: | |||||
logger.warning((f'{agent_wxid} 向 {t} 发送语音文本【{text}】{ret_msg}')) | |||||
else: | |||||
logger.warning((f'{agent_wxid} 向 {t} 发送语音文本【{text}】出错')) | |||||
# 等待随机时间 | |||||
await asyncio.sleep(random.uniform(5, 15)) | |||||
async def send_file_message(self,token_id, app_id, agent_wxid, intersection_wxids, file_url): | |||||
parsed_url = urlparse(file_url) | |||||
path = parsed_url.path | |||||
# 从路径中提取文件名 | |||||
filename = path.split('/')[-1] | |||||
# 获取扩展名 | |||||
_, ext = os.path.splitext(filename) | |||||
if ext == '.mp4': | |||||
await self.send_video_message(self, token_id, app_id, agent_wxid, intersection_wxids, file_url) | |||||
else: | |||||
await self.send_other_file_message(self, token_id, app_id, agent_wxid, intersection_wxids, file_url) | |||||
#time.sleep(random.uniform(5, 15)) | |||||
async def send_video_message(self, token_id, app_id, agent_wxid, intersection_wxids, file_url): | |||||
for t in intersection_wxids: | |||||
# 发送视频消息 | |||||
parsed_url = urlparse(file_url) | |||||
filename = os.path.basename(parsed_url.path) | |||||
tmp_file_path = os.path.join(os.getcwd(),'tmp', filename) # 拼接完整路径 | |||||
thumbnail_path=tmp_file_path.replace('.mp4','.jpg') | |||||
if self.wxchat.forward_video_aeskey == '': | |||||
video_thumb_url,video_duration =download_video_and_get_thumbnail(file_url,thumbnail_path) | |||||
print(f'视频缩略图 {video_thumb_url} 时长 {video_duration}') | |||||
ret,ret_msg,res = await self.wxchat.post_video(token_id, app_id, t, file_url,video_thumb_url,video_duration) | |||||
if ret==200: | |||||
self.wxchat.forward_video_aeskey = res["aesKey"] | |||||
self.wxchat.forward_video_cdnvideourl = res["cdnThumbUrl"] | |||||
self.wxchat.forward_video_length = res["length"] | |||||
else: | |||||
ret,ret_msg,res = await self.wxchat.forward_video(token_id, app_id, t, self.wxchat.forward_video_aeskey, self.wxchat.forward_video_cdnvideourl, self.wxchat.forward_video_length) | |||||
print('转发视频') | |||||
if ret==200: | |||||
logger.info(f'{agent_wxid} 向 {t} 发送视频【{file_url}】{ret_msg}') | |||||
# 构造对话消息并发送到 Kafka | |||||
input_wx_content_dialogue_message = [{"type": "file", "file_url": {"url": file_url}}] | |||||
input_message = dialogue_message(agent_wxid, t, input_wx_content_dialogue_message) | |||||
await self.kafka_service.send_message_async(input_message) | |||||
logger.info("发送对话 %s", input_message) | |||||
else: | |||||
logger.warning((f'{agent_wxid} 向 {t} 发送视频【{file_url}】{ret_msg}')) | |||||
# 等待随机时间 | |||||
await asyncio.sleep(random.uniform(5, 15)) | |||||
async def send_other_file_message(self, token_id, app_id, agent_wxid, intersection_wxids, file_url): | |||||
print('send_otherfile_message') |
@@ -0,0 +1,166 @@ | |||||
import asyncio | |||||
from typing import Dict, Callable, Optional | |||||
from aiokafka import AIOKafkaProducer, AIOKafkaConsumer | |||||
from aiokafka.errors import KafkaError | |||||
from fastapi import FastAPI | |||||
import json | |||||
from common.log import logger | |||||
class KafkaService: | |||||
_instance = None | |||||
def __new__(cls, *args, **kwargs): | |||||
if not cls._instance: | |||||
cls._instance = super().__new__(cls) | |||||
return cls._instance | |||||
def __init__( | |||||
self, | |||||
bootstrap_servers: str = "localhost:9092", | |||||
producer_topic: str = "default_topic", | |||||
consumer_topic: str = "default_topic", | |||||
group_id: str = "fastapi-group" | |||||
): | |||||
if not hasattr(self, 'initialized'): | |||||
self.bootstrap_servers = bootstrap_servers | |||||
self.producer_topic = producer_topic | |||||
self.consumer_topic = consumer_topic | |||||
self.group_id = group_id | |||||
self.producer: Optional[AIOKafkaProducer] = None | |||||
self.consumer: Optional[AIOKafkaConsumer] = None | |||||
self.consumer_task: Optional[asyncio.Task] = None | |||||
self.message_handlers: Dict[str, Callable] = {} | |||||
self.initialized = True | |||||
async def connect_producer(self): | |||||
"""Initialize Kafka producer""" | |||||
try: | |||||
self.producer = AIOKafkaProducer( | |||||
bootstrap_servers=self.bootstrap_servers, | |||||
compression_type="gzip" | |||||
) | |||||
await self.producer.start() | |||||
except KafkaError as e: | |||||
print(f"Producer connection failed: {e}") | |||||
raise | |||||
async def connect_consumer(self): | |||||
"""Initialize Kafka consumer""" | |||||
try: | |||||
self.consumer = AIOKafkaConsumer( | |||||
self.consumer_topic, | |||||
bootstrap_servers=self.bootstrap_servers, | |||||
group_id=self.group_id, | |||||
auto_offset_reset="earliest", | |||||
session_timeout_ms=30000, # 增加会话超时时间 | |||||
heartbeat_interval_ms=10000 # 增加心跳间隔时间 | |||||
) | |||||
await self.consumer.start() | |||||
except KafkaError as e: | |||||
print(f"Consumer connection failed: {e}") | |||||
raise | |||||
async def send_message_async(self, message: str, topic: str = None): | |||||
"""Send message to Kafka topic""" | |||||
if not self.producer: | |||||
raise RuntimeError("Producer not initialized") | |||||
target_topic = topic or self.producer_topic | |||||
print(f'生产者topic:{target_topic}') | |||||
logger.info(f"生产者topic:{target_topic}\n生产者消息:{json.dumps(json.loads(message), separators=(',', ':'), default=str, ensure_ascii=False)}") | |||||
try: | |||||
await self.producer.send_and_wait( | |||||
target_topic, | |||||
message.encode('utf-8') | |||||
) | |||||
except KafkaError as e: | |||||
print(f"Error sending message: {e}") | |||||
raise | |||||
# async def consume_messages(self): | |||||
# """Start consuming messages from Kafka""" | |||||
# if not self.consumer: | |||||
# raise RuntimeError("Consumer not initialized") | |||||
# try: | |||||
# async for msg in self.consumer: | |||||
# #print(f"Received message: {msg.value.decode()}") | |||||
# logger.info(f"接收到kafka消息: {json.dumps(json.loads(msg.value.decode()), ensure_ascii=False)}") | |||||
# topic = msg.topic | |||||
# if topic in self.message_handlers: | |||||
# handler = self.message_handlers[topic] | |||||
# await handler(msg.value.decode()) | |||||
# except Exception as e: | |||||
# print(f"Consuming error: {e}") | |||||
# raise | |||||
# finally: | |||||
# await self.consumer.stop() | |||||
# async def consume_messages(self): | |||||
# """Start consuming messages from Kafka""" | |||||
# if not self.consumer: | |||||
# raise RuntimeError("Consumer not initialized") | |||||
# try: | |||||
# async for msg in self.consumer: | |||||
# try: | |||||
# logger.info(f"接收到kafka消息: {json.dumps(json.loads(msg.value.decode()), ensure_ascii=False)}") | |||||
# topic = msg.topic | |||||
# if topic in self.message_handlers: | |||||
# handler = self.message_handlers[topic] | |||||
# await handler(msg.value.decode()) | |||||
# else: | |||||
# logger.warning(f"未处理消息类型: {topic}") | |||||
# except Exception as e: | |||||
# logger.error(f"处理消息失败: {e}") | |||||
# except Exception as e: | |||||
# logger.error(f"消费消息异常: {e}") | |||||
# raise | |||||
# finally: | |||||
# await self.consumer.stop() | |||||
async def consume_messages(self): | |||||
"""Start consuming messages from Kafka""" | |||||
if not self.consumer: | |||||
raise RuntimeError("Consumer not initialized") | |||||
while True: | |||||
try: | |||||
async for msg in self.consumer: | |||||
try: | |||||
logger.info(f"接收到kafka消息: {json.dumps(json.loads(msg.value.decode()), ensure_ascii=False)}") | |||||
topic = msg.topic | |||||
if topic in self.message_handlers: | |||||
handler = self.message_handlers[topic] | |||||
await handler(msg.value.decode()) | |||||
else: | |||||
logger.warning(f"未处理消息类型: {topic}") | |||||
except Exception as e: | |||||
logger.error(f"处理消息失败: {e}") | |||||
except Exception as e: | |||||
logger.error(f"消费消息异常: {e}") | |||||
await asyncio.sleep(5) # 等待一段时间后重试 | |||||
def add_handler(self, topic: str, handler: Callable): | |||||
"""Add message handler for specific topic""" | |||||
self.message_handlers[topic] = handler | |||||
async def start(self): | |||||
"""Start both producer and consumer""" | |||||
await self.connect_producer() | |||||
await self.connect_consumer() | |||||
self.consumer_task = asyncio.create_task(self.consume_messages()) | |||||
async def stop(self): | |||||
"""Graceful shutdown""" | |||||
if self.producer: | |||||
await self.producer.stop() | |||||
if self.consumer: | |||||
await self.consumer.stop() | |||||
if self.consumer_task: | |||||
self.consumer_task.cancel() | |||||
try: | |||||
await self.consumer_task | |||||
except asyncio.CancelledError: | |||||
pass |
@@ -0,0 +1,198 @@ | |||||
import aioredis | |||||
import os | |||||
import uuid | |||||
import asyncio | |||||
import threading | |||||
from fastapi import FastAPI, Depends | |||||
from fastapi import Request | |||||
import time | |||||
# 定义全局 redis_helper 作为单例 | |||||
class RedisService: | |||||
_instance = None | |||||
def __new__(cls, host='localhost', port=6379, password=None, db=0): | |||||
if not cls._instance: | |||||
cls._instance = super(RedisService, cls).__new__(cls) | |||||
cls._instance.client = None | |||||
cls._instance.lock_renewal_thread = None | |||||
return cls._instance | |||||
async def init(self, host='localhost', port=6379, password=None, db=0): | |||||
"""初始化 Redis 连接""" | |||||
#self.client = await aioredis.Redis(f'redis://{host}:{port}', password=password, db=db) | |||||
self.client = await aioredis.Redis(host=host, port=port, password=password, db=db) | |||||
# async def set_hash(self, hash_key, data, timeout=None): | |||||
# """添加或更新哈希,并设置有效期""" | |||||
# await self.client.hmset_dict(hash_key, data) | |||||
# if timeout: | |||||
# # 设置有效期(单位:秒) | |||||
# await self.client.expire(hash_key, timeout) | |||||
async def set_hash(self, hash_key, data, timeout=None): | |||||
"""添加或更新哈希,并设置有效期""" | |||||
# 使用 hmset 方法设置哈希表数据 | |||||
await self.client.hmset(hash_key, data) | |||||
if timeout: | |||||
# 设置有效期(单位:秒) | |||||
await self.client.expire(hash_key, timeout) | |||||
async def get_hash(self, hash_key): | |||||
"""获取整个哈希表数据""" | |||||
result = await self.client.hgetall(hash_key) | |||||
# 将字节数据解码成字符串格式返回 | |||||
return {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()} | |||||
async def get_hash_field(self, hash_key, field): | |||||
"""获取哈希表中的单个字段值""" | |||||
result = await self.client.hget(hash_key, field) | |||||
return result.decode('utf-8') if result else None | |||||
async def delete_hash(self, hash_key): | |||||
"""删除整个哈希表""" | |||||
await self.client.delete(hash_key) | |||||
async def delete_hash_field(self, hash_key, field): | |||||
"""删除哈希表中的某个字段""" | |||||
await self.client.hdel(hash_key, field) | |||||
async def update_hash_field(self, hash_key, field, value): | |||||
"""更新哈希表中的某个字段""" | |||||
await self.client.hset(hash_key, field, value) | |||||
# async def acquire_lock(self, lock_name, timeout=60): | |||||
# """ | |||||
# 尝试获取分布式锁,成功返回 True,失败返回 False | |||||
# :param lock_name: 锁的名称 | |||||
# :param timeout: 锁的超时时间(秒) | |||||
# :return: bool | |||||
# """ | |||||
# # identifier = str(time.time()) # 使用时间戳作为唯一标识 | |||||
# # if await self.client.setnx(lock_name, identifier): | |||||
# # await self.client.expire(lock_name, timeout) | |||||
# # self.lock_renewal_thread = threading.Thread(target=self.renew_lock, args=(lock_name, identifier, timeout)) | |||||
# # self.lock_renewal_thread.start() | |||||
# # return True | |||||
# # return False | |||||
# identifier = str(time.time()) # 使用 UUID 作为唯一标识 | |||||
# if await self.client.setnx(lock_name, identifier): | |||||
# await self.client.expire(lock_name, timeout) | |||||
# self.lock_renewal_thread = threading.Thread( | |||||
# target=self.run_renew_lock, | |||||
# args=(lock_name, identifier, timeout) | |||||
# ) | |||||
# self.lock_renewal_thread.start() | |||||
# return True | |||||
# return False | |||||
# def run_renew_lock(self, lock_name, identifier, timeout): | |||||
# """ | |||||
# 在单独的线程中运行异步的 renew_lock 方法 | |||||
# """ | |||||
# loop = asyncio.new_event_loop() | |||||
# asyncio.set_event_loop(loop) | |||||
# loop.run_until_complete(self.renew_lock(lock_name, identifier, timeout)) | |||||
# loop.close() | |||||
async def acquire_lock(self, lock_name, timeout=60): | |||||
""" | |||||
尝试获取分布式锁,成功返回 True,失败返回 False | |||||
:param lock_name: 锁的名称 | |||||
:param timeout: 锁的超时时间(秒) | |||||
:return: bool | |||||
""" | |||||
identifier = str(uuid.uuid4()) # 使用 UUID 作为唯一标识 | |||||
if await self.client.setnx(lock_name, identifier): | |||||
await self.client.expire(lock_name, timeout) | |||||
# 启动异步任务来续期锁 | |||||
self.lock_renewal_task = asyncio.create_task( | |||||
self.renew_lock(lock_name, identifier, timeout) | |||||
) | |||||
return True | |||||
return False | |||||
async def renew_lock(self, lock_name, identifier, timeout): | |||||
""" | |||||
锁的自动续期 | |||||
:param lock_name: 锁的名称 | |||||
:param identifier: 锁的唯一标识 | |||||
:param timeout: 锁的超时时间(秒) | |||||
""" | |||||
while True: | |||||
await asyncio.sleep(timeout / 2) | |||||
if await self.client.get(lock_name) == identifier.encode(): | |||||
await self.client.expire(lock_name, timeout) | |||||
else: | |||||
break | |||||
async def release_lock(self, lock_name, identifier): | |||||
""" | |||||
释放分布式锁 | |||||
:param lock_name: 锁的名称 | |||||
:param identifier: 锁的唯一标识 | |||||
""" | |||||
if await self.client.get(lock_name) == identifier.encode(): | |||||
await self.client.delete(lock_name) | |||||
if self.lock_renewal_thread: | |||||
self.lock_renewal_thread.join() | |||||
async def enqueue(self, queue_name, item): | |||||
""" | |||||
将元素添加到队列的尾部(右侧) | |||||
:param queue_name: 队列名称 | |||||
:param item: 要添加到队列的元素 | |||||
""" | |||||
await self.client.rpush(queue_name, item) | |||||
print(f"Enqueued: {item} to queue {queue_name}") | |||||
async def dequeue(self, queue_name): | |||||
""" | |||||
从队列的头部(左侧)移除并返回元素 | |||||
:param queue_name: 队列名称 | |||||
:return: 移除的元素,如果队列为空则返回 None | |||||
""" | |||||
item = await self.client.lpop(queue_name) | |||||
if item: | |||||
print(f"Dequeued: {item.decode('utf-8')} from queue {queue_name}") | |||||
return item.decode('utf-8') | |||||
print(f"Queue {queue_name} is empty") | |||||
return None | |||||
async def get_queue_length(self, queue_name): | |||||
""" | |||||
获取队列的长度 | |||||
:param queue_name: 队列名称 | |||||
:return: 队列的长度 | |||||
""" | |||||
length = await self.client.llen(queue_name) | |||||
print(f"Queue {queue_name} length: {length}") | |||||
return length | |||||
async def peek_queue(self, queue_name): | |||||
""" | |||||
查看队列的头部元素,但不移除 | |||||
:param queue_name: 队列名称 | |||||
:return: 队列的头部元素,如果队列为空则返回 None | |||||
""" | |||||
item = await self.client.lrange(queue_name, 0, 0) | |||||
if item: | |||||
print(f"Peeked: {item[0].decode('utf-8')} from queue {queue_name}") | |||||
return item[0].decode('utf-8') | |||||
print(f"Queue {queue_name} is empty") | |||||
return None | |||||
async def clear_queue(self, queue_name): | |||||
""" | |||||
清空队列 | |||||
:param queue_name: 队列名称 | |||||
""" | |||||
await self.client.delete(queue_name) | |||||
print(f"Cleared queue {queue_name}") | |||||
# Dependency injection helper function | |||||
async def get_redis_service(request: Request) -> RedisService: | |||||
return request.app.state.redis_serive |
@@ -0,0 +1,20 @@ | |||||
#!/usr/bin/env bash | |||||
image_version=$version | |||||
# 删除镜像 | |||||
docker rmi -f $( | |||||
docker images | grep registry.cn-shanghai.aliyuncs.com/gps_card/ai-ops-wechat | awk '{print $3}' | |||||
) | |||||
# 构建telpo/mrp:$image_version镜像 | |||||
docker build -f ./docker/Dockerfile.latest . -t ssjl/ai-ops-wechat:$image_version | |||||
#TODO:推送镜像到阿里仓库 | |||||
echo '=================开始推送镜像=======================' | |||||
#docker login --username=telpo_linwl@1111649216405698 --password=telpo#1234 registry.cn-shanghai.aliyuncs.com | |||||
docker login --username=rzl_wangjx@1111649216405698 --password=telpo.123 registry.cn-shanghai.aliyuncs.com | |||||
docker tag ssjl/ai-ops-wechat:$image_version registry.cn-shanghai.aliyuncs.com/gps_card/ai-ops-wechat:$image_version | |||||
docker push registry.cn-shanghai.aliyuncs.com/gps_card/ai-ops-wechat:$image_version | |||||
echo '=================推送镜像完成=======================' | |||||
#删除产生的None镜像 | |||||
docker rmi -f $(docker images | grep none | awk '{print $3}') | |||||
# 查看镜像列表 | |||||
docker images |
@@ -0,0 +1,17 @@ | |||||
#!/usr/bin/env bash | |||||
image_version=$version | |||||
# 删除镜像 | |||||
docker rmi -f $( | |||||
docker images | grep 139.224.254.18:5000/ssjl/ai-ops-wechat | awk '{print $3}' | |||||
) | |||||
# 构建telpo/mrp:$image_version镜像 | |||||
docker build -f ./docker/Dockerfile.latest . -t ssjl/ai-ops-wechat:$image_version | |||||
#TODO:推送镜像到私有仓库 | |||||
echo '=================开始推送镜像=======================' | |||||
docker tag ssjl/ai-ops-wechat:$image_version 139.224.254.18:5000/ssjl/ai-ops-wechat:$image_version | |||||
docker push 139.224.254.18:5000/ssjl/ai-ops-wechat:$image_version | |||||
echo '=================推送镜像完成=======================' | |||||
#删除产生的None镜像 | |||||
docker rmi -f $(docker images | grep none | awk '{print $3}') | |||||
# 查看镜像列表 | |||||
docker images |
@@ -0,0 +1,216 @@ | |||||
# coding=utf-8 | |||||
""" | |||||
Author: chazzjimel | |||||
Email: chazzjimel@gmail.com | |||||
wechat:cheung-z-x | |||||
Description: | |||||
""" | |||||
import http.client | |||||
import json | |||||
import time | |||||
import requests | |||||
import datetime | |||||
import hashlib | |||||
import hmac | |||||
import base64 | |||||
import urllib.parse | |||||
import uuid | |||||
from common.log import logger | |||||
from common.tmp_dir import TmpDir | |||||
def text_to_speech_aliyun(url, text, appkey, token): | |||||
""" | |||||
使用阿里云的文本转语音服务将文本转换为语音。 | |||||
参数: | |||||
- url (str): 阿里云文本转语音服务的端点URL。 | |||||
- text (str): 要转换为语音的文本。 | |||||
- appkey (str): 您的阿里云appkey。 | |||||
- token (str): 阿里云API的认证令牌。 | |||||
返回值: | |||||
- str: 成功时输出音频文件的路径,否则为None。 | |||||
""" | |||||
headers = { | |||||
"Content-Type": "application/json", | |||||
} | |||||
data = { | |||||
"text": text, | |||||
"appkey": appkey, | |||||
"token": token, | |||||
"format": "wav" | |||||
} | |||||
response = requests.post(url, headers=headers, data=json.dumps(data)) | |||||
if response.status_code == 200 and response.headers['Content-Type'] == 'audio/mpeg': | |||||
output_file = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".wav" | |||||
with open(output_file, 'wb') as file: | |||||
file.write(response.content) | |||||
logger.debug(f"音频文件保存成功,文件名:{output_file}") | |||||
else: | |||||
logger.debug("响应状态码: {}".format(response.status_code)) | |||||
logger.debug("响应内容: {}".format(response.text)) | |||||
output_file = None | |||||
return output_file | |||||
def speech_to_text_aliyun(url, audioContent, appkey, token): | |||||
""" | |||||
使用阿里云的语音识别服务识别音频文件中的语音。 | |||||
参数: | |||||
- url (str): 阿里云语音识别服务的端点URL。 | |||||
- audioContent (byte): pcm音频数据。 | |||||
- appkey (str): 您的阿里云appkey。 | |||||
- token (str): 阿里云API的认证令牌。 | |||||
返回值: | |||||
- str: 成功时输出识别到的文本,否则为None。 | |||||
""" | |||||
format = 'pcm' | |||||
sample_rate = 16000 | |||||
enablePunctuationPrediction = True | |||||
enableInverseTextNormalization = True | |||||
enableVoiceDetection = False | |||||
# 设置RESTful请求参数 | |||||
request = url + '?appkey=' + appkey | |||||
request = request + '&format=' + format | |||||
request = request + '&sample_rate=' + str(sample_rate) | |||||
if enablePunctuationPrediction : | |||||
request = request + '&enable_punctuation_prediction=' + 'true' | |||||
if enableInverseTextNormalization : | |||||
request = request + '&enable_inverse_text_normalization=' + 'true' | |||||
if enableVoiceDetection : | |||||
request = request + '&enable_voice_detection=' + 'true' | |||||
host = 'nls-gateway-cn-shanghai.aliyuncs.com' | |||||
# 设置HTTPS请求头部 | |||||
httpHeaders = { | |||||
'X-NLS-Token': token, | |||||
'Content-type': 'application/octet-stream', | |||||
'Content-Length': len(audioContent) | |||||
} | |||||
conn = http.client.HTTPSConnection(host) | |||||
conn.request(method='POST', url=request, body=audioContent, headers=httpHeaders) | |||||
response = conn.getresponse() | |||||
body = response.read() | |||||
try: | |||||
body = json.loads(body) | |||||
status = body['status'] | |||||
if status == 20000000 : | |||||
result = body['result'] | |||||
if result : | |||||
logger.info(f"阿里云语音识别到了:{result}") | |||||
conn.close() | |||||
return result | |||||
else : | |||||
logger.error(f"语音识别失败,状态码: {status}") | |||||
except ValueError: | |||||
logger.error(f"语音识别失败,收到非JSON格式的数据: {body}") | |||||
conn.close() | |||||
return None | |||||
class AliyunTokenGenerator: | |||||
""" | |||||
用于生成阿里云服务认证令牌的类。 | |||||
属性: | |||||
- access_key_id (str): 您的阿里云访问密钥ID。 | |||||
- access_key_secret (str): 您的阿里云访问密钥秘密。 | |||||
""" | |||||
def __init__(self, access_key_id, access_key_secret): | |||||
self.access_key_id = access_key_id | |||||
self.access_key_secret = access_key_secret | |||||
def sign_request(self, parameters): | |||||
""" | |||||
为阿里云服务签名请求。 | |||||
参数: | |||||
- parameters (dict): 请求的参数字典。 | |||||
返回值: | |||||
- str: 请求的签名签章。 | |||||
""" | |||||
# 将参数按照字典顺序排序 | |||||
sorted_params = sorted(parameters.items()) | |||||
# 构造待签名的查询字符串 | |||||
canonicalized_query_string = '' | |||||
for (k, v) in sorted_params: | |||||
canonicalized_query_string += '&' + self.percent_encode(k) + '=' + self.percent_encode(v) | |||||
# 构造用于签名的字符串 | |||||
string_to_sign = 'GET&%2F&' + self.percent_encode(canonicalized_query_string[1:]) # 使用GET方法 | |||||
# 使用HMAC算法计算签名 | |||||
h = hmac.new((self.access_key_secret + "&").encode('utf-8'), string_to_sign.encode('utf-8'), hashlib.sha1) | |||||
signature = base64.encodebytes(h.digest()).strip() | |||||
return signature | |||||
def percent_encode(self, encode_str): | |||||
""" | |||||
对字符串进行百分比编码。 | |||||
参数: | |||||
- encode_str (str): 要编码的字符串。 | |||||
返回值: | |||||
- str: 编码后的字符串。 | |||||
""" | |||||
encode_str = str(encode_str) | |||||
res = urllib.parse.quote(encode_str, '') | |||||
res = res.replace('+', '%20') | |||||
res = res.replace('*', '%2A') | |||||
res = res.replace('%7E', '~') | |||||
return res | |||||
def get_token(self): | |||||
""" | |||||
获取阿里云服务的令牌。 | |||||
返回值: | |||||
- str: 获取到的令牌。 | |||||
""" | |||||
# 设置请求参数 | |||||
params = { | |||||
'Format': 'JSON', | |||||
'Version': '2019-02-28', | |||||
'AccessKeyId': self.access_key_id, | |||||
'SignatureMethod': 'HMAC-SHA1', | |||||
'Timestamp': datetime.datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ"), | |||||
'SignatureVersion': '1.0', | |||||
'SignatureNonce': str(uuid.uuid4()), # 使用uuid生成唯一的随机数 | |||||
'Action': 'CreateToken', | |||||
'RegionId': 'cn-shanghai' | |||||
} | |||||
# 计算签名 | |||||
signature = self.sign_request(params) | |||||
params['Signature'] = signature | |||||
# 构造请求URL | |||||
url = 'http://nls-meta.cn-shanghai.aliyuncs.com/?' + urllib.parse.urlencode(params) | |||||
# 发送请求 | |||||
response = requests.get(url) | |||||
return response.text |
@@ -0,0 +1,109 @@ | |||||
# -*- coding: utf-8 -*- | |||||
""" | |||||
Author: chazzjimel | |||||
Email: chazzjimel@gmail.com | |||||
wechat:cheung-z-x | |||||
Description: | |||||
ali voice service | |||||
""" | |||||
import json | |||||
import os | |||||
import re | |||||
import time | |||||
from common.log import logger | |||||
from voice.audio_convert import get_pcm_from_wav | |||||
from voice.voice import Voice | |||||
from voice.ali.ali_api import AliyunTokenGenerator, speech_to_text_aliyun, text_to_speech_aliyun | |||||
from config import conf | |||||
from pypinyin import pinyin, Style | |||||
from common.singleton import singleton | |||||
@singleton | |||||
class AliVoice(Voice): | |||||
def __init__(self): | |||||
""" | |||||
初始化AliVoice类,从配置文件加载必要的配置。 | |||||
""" | |||||
try: | |||||
curdir = os.path.dirname(__file__) | |||||
config_path = os.path.join(curdir, "config.json") | |||||
print(config_path) | |||||
with open(config_path, "r") as fr: | |||||
config = json.load(fr) | |||||
self.token = None | |||||
self.token_expire_time = 0 | |||||
# 默认复用阿里云千问的 access_key 和 access_secret | |||||
self.api_url_voice_to_text = config.get("api_url_voice_to_text") | |||||
self.api_url_text_to_voice = config.get("api_url_text_to_voice") | |||||
self.app_key = config.get("app_key") | |||||
self.access_key_id = conf().get("qwen_access_key_id") or config.get("access_key_id") | |||||
self.access_key_secret = conf().get("qwen_access_key_secret") or config.get("access_key_secret") | |||||
except Exception as e: | |||||
logger.warn("AliVoice init failed: %s, ignore " % e) | |||||
def textToVoice(self, text): | |||||
""" | |||||
将文本转换为语音文件。 | |||||
:param text: 要转换的文本。 | |||||
:return: 返回一个Reply对象,其中包含转换得到的语音文件或错误信息。 | |||||
""" | |||||
# 清除文本中的非中文、非英文和非基本字符 | |||||
text = re.sub(r'[^\u4e00-\u9fa5\u3040-\u30FF\uAC00-\uD7AFa-zA-Z0-9' | |||||
r'äöüÄÖÜáéíóúÁÉÍÓÚàèìòùÀÈÌÒÙâêîôûÂÊÎÔÛçÇñÑ,。!?,.]', '', text) | |||||
# 提取有效的token | |||||
token_id = self.get_valid_token() | |||||
fileName = text_to_speech_aliyun(self.api_url_text_to_voice, text, self.app_key, token_id) | |||||
if fileName: | |||||
logger.info("[Ali] textToVoice text={} voice file name={}".format(text, fileName)) | |||||
return fileName | |||||
# reply = Reply(ReplyType.VOICE, fileName) | |||||
# else: | |||||
# reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败") | |||||
# return reply | |||||
def voiceToText(self, voice_file): | |||||
""" | |||||
将语音文件转换为文本。 | |||||
:param voice_file: 要转换的语音文件。 | |||||
:return: 返回一个Reply对象,其中包含转换得到的文本或错误信息。 | |||||
""" | |||||
# 提取有效的token | |||||
token_id = self.get_valid_token() | |||||
logger.debug("[Ali] voice file name={}".format(voice_file)) | |||||
pcm = get_pcm_from_wav(voice_file) | |||||
text = speech_to_text_aliyun(self.api_url_voice_to_text, pcm, self.app_key, token_id) | |||||
# print(text) | |||||
if text: | |||||
return text | |||||
# reply = Reply(ReplyType.TEXT, text) | |||||
# else: | |||||
# reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败") | |||||
# return reply | |||||
def get_valid_token(self): | |||||
""" | |||||
获取有效的阿里云token。 | |||||
:return: 返回有效的token字符串。 | |||||
""" | |||||
current_time = time.time() | |||||
if self.token is None or current_time >= self.token_expire_time: | |||||
get_token = AliyunTokenGenerator(self.access_key_id, self.access_key_secret) | |||||
token_str = get_token.get_token() | |||||
token_data = json.loads(token_str) | |||||
self.token = token_data["Token"]["Id"] | |||||
# 将过期时间减少一小段时间(例如5分钟),以避免在边界条件下的过期 | |||||
self.token_expire_time = token_data["Token"]["ExpireTime"] - 300 | |||||
logger.debug(f"新获取的阿里云token:{self.token}") | |||||
else: | |||||
logger.debug("使用缓存的token") | |||||
return self.token |
@@ -0,0 +1,7 @@ | |||||
{ | |||||
"api_url_text_to_voice": "https://nls-gateway-cn-shanghai.aliyuncs.com/stream/v1/tts", | |||||
"api_url_voice_to_text": "https://nls-gateway.cn-shanghai.aliyuncs.com/stream/v1/asr", | |||||
"app_key": "F3VB6magxpjpPgKH", | |||||
"access_key_id": "LTAI5tJS8kD1mh2fLzVJ4u3w", | |||||
"access_key_secret": "ahiLuHLiSeqBDMCgtmc9Qe3uvgo6pJ" | |||||
} |
@@ -0,0 +1,7 @@ | |||||
{ | |||||
"api_url_text_to_voice": "https://nls-gateway-cn-shanghai.aliyuncs.com/stream/v1/tts", | |||||
"api_url_voice_to_text": "https://nls-gateway.cn-shanghai.aliyuncs.com/stream/v1/asr", | |||||
"app_key": "", | |||||
"access_key_id": "", | |||||
"access_key_secret": "" | |||||
} |
@@ -0,0 +1,136 @@ | |||||
import shutil | |||||
import wave | |||||
from common.log import logger | |||||
try: | |||||
import pysilk | |||||
except ImportError: | |||||
logger.debug("import pysilk failed, wechaty voice message will not be supported.") | |||||
from pydub import AudioSegment | |||||
sil_supports = [8000, 12000, 16000, 24000, 32000, 44100, 48000] # slk转wav时,支持的采样率 | |||||
def find_closest_sil_supports(sample_rate): | |||||
""" | |||||
找到最接近的支持的采样率 | |||||
""" | |||||
if sample_rate in sil_supports: | |||||
return sample_rate | |||||
closest = 0 | |||||
mindiff = 9999999 | |||||
for rate in sil_supports: | |||||
diff = abs(rate - sample_rate) | |||||
if diff < mindiff: | |||||
closest = rate | |||||
mindiff = diff | |||||
return closest | |||||
def get_pcm_from_wav(wav_path): | |||||
""" | |||||
从 wav 文件中读取 pcm | |||||
:param wav_path: wav 文件路径 | |||||
:returns: pcm 数据 | |||||
""" | |||||
wav = wave.open(wav_path, "rb") | |||||
return wav.readframes(wav.getnframes()) | |||||
def any_to_mp3(any_path, mp3_path): | |||||
""" | |||||
把任意格式转成mp3文件 | |||||
""" | |||||
if any_path.endswith(".mp3"): | |||||
shutil.copy2(any_path, mp3_path) | |||||
return | |||||
if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"): | |||||
sil_to_wav(any_path, any_path) | |||||
any_path = mp3_path | |||||
audio = AudioSegment.from_file(any_path) | |||||
audio.export(mp3_path, format="mp3") | |||||
def any_to_wav(any_path, wav_path): | |||||
""" | |||||
把任意格式转成wav文件 | |||||
""" | |||||
if any_path.endswith(".wav"): | |||||
shutil.copy2(any_path, wav_path) | |||||
return | |||||
if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"): | |||||
return sil_to_wav(any_path, wav_path) | |||||
audio = AudioSegment.from_file(any_path) | |||||
audio.set_frame_rate(8000) # 百度语音转写支持8000采样率, pcm_s16le, 单通道语音识别 | |||||
audio.set_channels(1) | |||||
audio.export(wav_path, format="wav", codec='pcm_s16le') | |||||
def any_to_sil(any_path, sil_path): | |||||
""" | |||||
把任意格式转成sil文件 | |||||
""" | |||||
if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"): | |||||
shutil.copy2(any_path, sil_path) | |||||
return 10000 | |||||
audio = AudioSegment.from_file(any_path) | |||||
rate = find_closest_sil_supports(audio.frame_rate) | |||||
# Convert to PCM_s16 | |||||
pcm_s16 = audio.set_sample_width(2) | |||||
pcm_s16 = pcm_s16.set_frame_rate(rate) | |||||
wav_data = pcm_s16.raw_data | |||||
silk_data = pysilk.encode(wav_data, data_rate=rate, sample_rate=rate) | |||||
with open(sil_path, "wb") as f: | |||||
f.write(silk_data) | |||||
return audio.duration_seconds * 1000 | |||||
def any_to_amr(any_path, amr_path): | |||||
""" | |||||
把任意格式转成amr文件 | |||||
""" | |||||
if any_path.endswith(".amr"): | |||||
shutil.copy2(any_path, amr_path) | |||||
return | |||||
if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"): | |||||
raise NotImplementedError("Not support file type: {}".format(any_path)) | |||||
audio = AudioSegment.from_file(any_path) | |||||
audio = audio.set_frame_rate(8000) # only support 8000 | |||||
audio.export(amr_path, format="amr") | |||||
return audio.duration_seconds * 1000 | |||||
def sil_to_wav(silk_path, wav_path, rate: int = 24000): | |||||
""" | |||||
silk 文件转 wav | |||||
""" | |||||
wav_data = pysilk.decode_file(silk_path, to_wav=True, sample_rate=rate) | |||||
with open(wav_path, "wb") as f: | |||||
f.write(wav_data) | |||||
def split_audio(file_path, max_segment_length_ms=60000): | |||||
""" | |||||
分割音频文件 | |||||
""" | |||||
audio = AudioSegment.from_file(file_path) | |||||
audio_length_ms = len(audio) | |||||
if audio_length_ms <= max_segment_length_ms: | |||||
return audio_length_ms, [file_path] | |||||
segments = [] | |||||
for start_ms in range(0, audio_length_ms, max_segment_length_ms): | |||||
end_ms = min(audio_length_ms, start_ms + max_segment_length_ms) | |||||
segment = audio[start_ms:end_ms] | |||||
segments.append(segment) | |||||
file_prefix = file_path[: file_path.rindex(".")] | |||||
format = file_path[file_path.rindex(".") + 1 :] | |||||
files = [] | |||||
for i, segment in enumerate(segments): | |||||
path = f"{file_prefix}_{i+1}" + f".{format}" | |||||
segment.export(path, format=format) | |||||
files.append(path) | |||||
return audio_length_ms, files |
@@ -0,0 +1,17 @@ | |||||
""" | |||||
Voice service abstract class | |||||
""" | |||||
class Voice(object): | |||||
def voiceToText(self, voice_file): | |||||
""" | |||||
Send voice to voice service and get text | |||||
""" | |||||
raise NotImplementedError | |||||
def textToVoice(self, text): | |||||
""" | |||||
Send text to voice service and get voice | |||||
""" | |||||
raise NotImplementedError |