import aioredis import os import uuid import asyncio import threading from fastapi import FastAPI, Depends from fastapi import Request import time # 定义全局 redis_helper 作为单例 class RedisService: _instance = None def __new__(cls, host='localhost', port=6379, password=None, db=0): if not cls._instance: cls._instance = super(RedisService, cls).__new__(cls) cls._instance.client = None cls._instance.lock_renewal_thread = None return cls._instance async def init(self, host='localhost', port=6379, password=None, db=0): """初始化 Redis 连接""" #self.client = await aioredis.Redis(f'redis://{host}:{port}', password=password, db=db) self.client = await aioredis.Redis(host=host, port=port, password=password, db=db) async def set_hash(self, hash_key, data, timeout=None): """添加或更新哈希,并设置有效期""" await self.client.hmset_dict(hash_key, data) if timeout: # 设置有效期(单位:秒) await self.client.expire(hash_key, timeout) # async def set_hash(self, hash_key, data, timeout=None): # """添加或更新哈希,并设置有效期""" # # 使用 hmset 方法设置哈希表数据 # await self.client.hmset(hash_key, data) # if timeout: # # 设置有效期(单位:秒) # await self.client.expire(hash_key, timeout) # async def set_hash(self, hash_key, data, timeout=None): # """添加或更新哈希,并设置有效期""" # # 使用 hset 方法设置哈希表数据 # await self.client.hset(hash_key, mapping=data) # if timeout: # # 设置有效期(单位:秒) # await self.client.expire(hash_key, timeout) # def flatten_dict(self,d, parent_key="", sep="."): # """ # 将嵌套字典扁平化 # :param d: 嵌套字典 # :param parent_key: 父键(用于递归) # :param sep: 分隔符 # :return: 扁平化字典 # """ # items = [] # for k, v in d.items(): # new_key = f"{parent_key}{sep}{k}" if parent_key else k # if isinstance(v, dict): # items.extend(self.flatten_dict(v, new_key, sep=sep).items()) # else: # items.append((new_key, v)) # return dict(items) # async def set_hash(self, hash_key, data, timeout=None): # """添加或更新哈希,并设置有效期""" # # 扁平化嵌套字典 # flat_data = self.flatten_dict(data) # # 使用 hset 方法设置哈希表数据 # await self.client.hset(hash_key, mapping=flat_data) # if timeout: # # 设置有效期(单位:秒) # await self.client.expire(hash_key, timeout) async def get_hash(self, hash_key): """获取整个哈希表数据""" result = await self.client.hgetall(hash_key) # 将字节数据解码成字符串格式返回 return {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()} async def get_hash_field(self, hash_key, field): """获取哈希表中的单个字段值""" result = await self.client.hget(hash_key, field) return result.decode('utf-8') if result else None async def delete_hash(self, hash_key): """删除整个哈希表""" await self.client.delete(hash_key) async def delete_hash_field(self, hash_key, field): """删除哈希表中的某个字段""" await self.client.hdel(hash_key, field) async def update_hash_field(self, hash_key, field, value): """更新哈希表中的某个字段""" await self.client.hset(hash_key, field, value) # async def acquire_lock(self, lock_name, timeout=60): # """ # 尝试获取分布式锁,成功返回 True,失败返回 False # :param lock_name: 锁的名称 # :param timeout: 锁的超时时间(秒) # :return: bool # """ # # identifier = str(time.time()) # 使用时间戳作为唯一标识 # # if await self.client.setnx(lock_name, identifier): # # await self.client.expire(lock_name, timeout) # # self.lock_renewal_thread = threading.Thread(target=self.renew_lock, args=(lock_name, identifier, timeout)) # # self.lock_renewal_thread.start() # # return True # # return False # identifier = str(time.time()) # 使用 UUID 作为唯一标识 # if await self.client.setnx(lock_name, identifier): # await self.client.expire(lock_name, timeout) # self.lock_renewal_thread = threading.Thread( # target=self.run_renew_lock, # args=(lock_name, identifier, timeout) # ) # self.lock_renewal_thread.start() # return True # return False # def run_renew_lock(self, lock_name, identifier, timeout): # """ # 在单独的线程中运行异步的 renew_lock 方法 # """ # loop = asyncio.new_event_loop() # asyncio.set_event_loop(loop) # loop.run_until_complete(self.renew_lock(lock_name, identifier, timeout)) # loop.close() async def acquire_lock(self, lock_name, timeout=60): """ 尝试获取分布式锁,成功返回 True,失败返回 False :param lock_name: 锁的名称 :param timeout: 锁的超时时间(秒) :return: bool """ identifier = str(uuid.uuid4()) # 使用 UUID 作为唯一标识 if await self.client.setnx(lock_name, identifier): await self.client.expire(lock_name, timeout) # 启动异步任务来续期锁 self.lock_renewal_task = asyncio.create_task( self.renew_lock(lock_name, identifier, timeout) ) return True return False async def renew_lock(self, lock_name, identifier, timeout): """ 锁的自动续期 :param lock_name: 锁的名称 :param identifier: 锁的唯一标识 :param timeout: 锁的超时时间(秒) """ while True: await asyncio.sleep(timeout / 2) if await self.client.get(lock_name) == identifier.encode(): await self.client.expire(lock_name, timeout) else: break async def release_lock(self, lock_name, identifier): """ 释放分布式锁 :param lock_name: 锁的名称 :param identifier: 锁的唯一标识 """ if await self.client.get(lock_name) == identifier.encode(): await self.client.delete(lock_name) if self.lock_renewal_thread: self.lock_renewal_thread.join() async def enqueue(self, queue_name, item): """ 将元素添加到队列的尾部(右侧) :param queue_name: 队列名称 :param item: 要添加到队列的元素 """ await self.client.rpush(queue_name, item) print(f"Enqueued: {item} to queue {queue_name}") async def dequeue(self, queue_name): """ 从队列的头部(左侧)移除并返回元素 :param queue_name: 队列名称 :return: 移除的元素,如果队列为空则返回 None """ item = await self.client.lpop(queue_name) if item: print(f"Dequeued: {item.decode('utf-8')} from queue {queue_name}") return item.decode('utf-8') print(f"Queue {queue_name} is empty") return None async def get_queue_length(self, queue_name): """ 获取队列的长度 :param queue_name: 队列名称 :return: 队列的长度 """ length = await self.client.llen(queue_name) print(f"Queue {queue_name} length: {length}") return length async def peek_queue(self, queue_name): """ 查看队列的头部元素,但不移除 :param queue_name: 队列名称 :return: 队列的头部元素,如果队列为空则返回 None """ item = await self.client.lrange(queue_name, 0, 0) if item: print(f"Peeked: {item[0].decode('utf-8')} from queue {queue_name}") return item[0].decode('utf-8') print(f"Queue {queue_name} is empty") return None async def clear_queue(self, queue_name): """ 清空队列 :param queue_name: 队列名称 """ await self.client.delete(queue_name) print(f"Cleared queue {queue_name}") async def increment_hash_field(self, hash_key, field, amount=1): """ 对哈希表中的指定字段进行递增操作(原子性)。 :param hash_key: 哈希表的 key :param field: 要递增的字段 :param amount: 递增的数值(默认为 1) :return: 递增后的值 """ return await self.client.hincrby(hash_key, field, amount) async def expire(self, key, timeout): """ 设置 Redis 键的过期时间(单位:秒) :param key: Redis 键 :param timeout: 过期时间(秒) """ await self.client.expire(key, timeout) async def expire_field(self, hash_key, field, timeout): """ 通过辅助键方式设置哈希表中某个字段的过期时间 :param hash_key: 哈希表的 key :param field: 要设置过期的字段 :param timeout: 过期时间(秒) """ expire_key = f"{hash_key}:{field}:expire" await self.client.set(expire_key, "1") await self.client.expire(expire_key, timeout) # Dependency injection helper function async def get_redis_service(request: Request) -> RedisService: return request.app.state.redis_serive