|
- import aioredis
- import os
- import uuid
- import asyncio
- import threading
- from fastapi import FastAPI, Depends
- from fastapi import Request
- import time
-
- # 定义全局 redis_helper 作为单例
- class RedisService:
- _instance = None
- def __new__(cls, host='localhost', port=6379, password=None, db=0):
- if not cls._instance:
- cls._instance = super(RedisService, cls).__new__(cls)
- cls._instance.client = None
- cls._instance.lock_renewal_thread = None
- return cls._instance
-
- async def init(self, host='localhost', port=6379, password=None, db=0):
- """初始化 Redis 连接"""
- #self.client = await aioredis.Redis(f'redis://{host}:{port}', password=password, db=db)
- self.client = await aioredis.Redis(host=host, port=port, password=password, db=db)
-
-
- async def set_hash(self, hash_key, data, timeout=None):
- """添加或更新哈希,并设置有效期"""
- # 使用 hmset 方法设置哈希表数据
- await self.client.hmset(hash_key, data)
- if timeout:
- # 设置有效期(单位:秒)
- await self.client.expire(hash_key, timeout)
-
-
- # async def set_hash(self, hash_key, data, timeout=None):
- # """添加或更新哈希,并设置有效期"""
- # # 使用 hset 方法设置哈希表数据
- # await self.client.hset(hash_key, mapping=data)
-
- # if timeout:
- # # 设置有效期(单位:秒)
- # await self.client.expire(hash_key, timeout)
-
- # def flatten_dict(self,d, parent_key="", sep="."):
- # """
- # 将嵌套字典扁平化
- # :param d: 嵌套字典
- # :param parent_key: 父键(用于递归)
- # :param sep: 分隔符
- # :return: 扁平化字典
- # """
- # items = []
- # for k, v in d.items():
- # new_key = f"{parent_key}{sep}{k}" if parent_key else k
- # if isinstance(v, dict):
- # items.extend(self.flatten_dict(v, new_key, sep=sep).items())
- # else:
- # items.append((new_key, v))
- # return dict(items)
-
- # async def set_hash(self, hash_key, data, timeout=None):
- # """添加或更新哈希,并设置有效期"""
- # # 扁平化嵌套字典
- # flat_data = self.flatten_dict(data)
-
- # # 使用 hset 方法设置哈希表数据
- # await self.client.hset(hash_key, mapping=flat_data)
-
- # if timeout:
- # # 设置有效期(单位:秒)
- # await self.client.expire(hash_key, timeout)
-
- async def get_hash(self, hash_key):
- """获取整个哈希表数据"""
- result = await self.client.hgetall(hash_key)
- # 将字节数据解码成字符串格式返回
- return {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()}
-
- async def get_hash_field(self, hash_key, field):
- """获取哈希表中的单个字段值"""
- result = await self.client.hget(hash_key, field)
- return result.decode('utf-8') if result else None
-
- async def delete_hash(self, hash_key):
- """删除整个哈希表"""
- await self.client.delete(hash_key)
-
- async def delete_hash_field(self, hash_key, field):
- """删除哈希表中的某个字段"""
- await self.client.hdel(hash_key, field)
-
- async def update_hash_field(self, hash_key, field, value):
- """更新哈希表中的某个字段"""
- await self.client.hset(hash_key, field, value)
-
- # async def acquire_lock(self, lock_name, timeout=60):
- # """
- # 尝试获取分布式锁,成功返回 True,失败返回 False
- # :param lock_name: 锁的名称
- # :param timeout: 锁的超时时间(秒)
- # :return: bool
- # """
- # # identifier = str(time.time()) # 使用时间戳作为唯一标识
- # # if await self.client.setnx(lock_name, identifier):
- # # await self.client.expire(lock_name, timeout)
- # # self.lock_renewal_thread = threading.Thread(target=self.renew_lock, args=(lock_name, identifier, timeout))
- # # self.lock_renewal_thread.start()
- # # return True
- # # return False
-
- # identifier = str(time.time()) # 使用 UUID 作为唯一标识
- # if await self.client.setnx(lock_name, identifier):
- # await self.client.expire(lock_name, timeout)
- # self.lock_renewal_thread = threading.Thread(
- # target=self.run_renew_lock,
- # args=(lock_name, identifier, timeout)
- # )
- # self.lock_renewal_thread.start()
- # return True
- # return False
-
-
- # def run_renew_lock(self, lock_name, identifier, timeout):
- # """
- # 在单独的线程中运行异步的 renew_lock 方法
- # """
- # loop = asyncio.new_event_loop()
- # asyncio.set_event_loop(loop)
- # loop.run_until_complete(self.renew_lock(lock_name, identifier, timeout))
- # loop.close()
-
-
- async def acquire_lock(self, lock_name, timeout=60):
- """
- 尝试获取分布式锁,成功返回 True,失败返回 False
- :param lock_name: 锁的名称
- :param timeout: 锁的超时时间(秒)
- :return: bool
- """
- identifier = str(uuid.uuid4()) # 使用 UUID 作为唯一标识
- if await self.client.setnx(lock_name, identifier):
- await self.client.expire(lock_name, timeout)
- # 启动异步任务来续期锁
- self.lock_renewal_task = asyncio.create_task(
- self.renew_lock(lock_name, identifier, timeout)
- )
- return True
- return False
-
- async def renew_lock(self, lock_name, identifier, timeout):
- """
- 锁的自动续期
- :param lock_name: 锁的名称
- :param identifier: 锁的唯一标识
- :param timeout: 锁的超时时间(秒)
- """
- while True:
- await asyncio.sleep(timeout / 2)
- if await self.client.get(lock_name) == identifier.encode():
- await self.client.expire(lock_name, timeout)
- else:
- break
-
- async def release_lock(self, lock_name, identifier):
- """
- 释放分布式锁
- :param lock_name: 锁的名称
- :param identifier: 锁的唯一标识
- """
- if await self.client.get(lock_name) == identifier.encode():
- await self.client.delete(lock_name)
- if self.lock_renewal_thread:
- self.lock_renewal_thread.join()
-
- async def enqueue(self, queue_name, item):
- """
- 将元素添加到队列的尾部(右侧)
- :param queue_name: 队列名称
- :param item: 要添加到队列的元素
- """
- await self.client.rpush(queue_name, item)
- print(f"Enqueued: {item} to queue {queue_name}")
-
- async def dequeue(self, queue_name):
- """
- 从队列的头部(左侧)移除并返回元素
- :param queue_name: 队列名称
- :return: 移除的元素,如果队列为空则返回 None
- """
- item = await self.client.lpop(queue_name)
- if item:
- print(f"Dequeued: {item.decode('utf-8')} from queue {queue_name}")
- return item.decode('utf-8')
- print(f"Queue {queue_name} is empty")
- return None
-
- async def get_queue_length(self, queue_name):
- """
- 获取队列的长度
- :param queue_name: 队列名称
- :return: 队列的长度
- """
- length = await self.client.llen(queue_name)
- print(f"Queue {queue_name} length: {length}")
- return length
-
- async def peek_queue(self, queue_name):
- """
- 查看队列的头部元素,但不移除
- :param queue_name: 队列名称
- :return: 队列的头部元素,如果队列为空则返回 None
- """
- item = await self.client.lrange(queue_name, 0, 0)
- if item:
- print(f"Peeked: {item[0].decode('utf-8')} from queue {queue_name}")
- return item[0].decode('utf-8')
- print(f"Queue {queue_name} is empty")
- return None
-
- async def clear_queue(self, queue_name):
- """
- 清空队列
- :param queue_name: 队列名称
- """
- await self.client.delete(queue_name)
- print(f"Cleared queue {queue_name}")
-
- async def increment_hash_field(self, hash_key, field, amount=1):
- """
- 对哈希表中的指定字段进行递增操作(原子性)。
-
- :param hash_key: 哈希表的 key
- :param field: 要递增的字段
- :param amount: 递增的数值(默认为 1)
- :return: 递增后的值
- """
- return await self.client.hincrby(hash_key, field, amount)
-
- async def expire(self, key, timeout):
- """
- 设置 Redis 键的过期时间(单位:秒)
-
- :param key: Redis 键
- :param timeout: 过期时间(秒)
- """
- await self.client.expire(key, timeout)
-
- async def expire_field(self, hash_key, field, timeout):
- """
- 通过辅助键方式设置哈希表中某个字段的过期时间
-
- :param hash_key: 哈希表的 key
- :param field: 要设置过期的字段
- :param timeout: 过期时间(秒)
- """
- expire_key = f"{hash_key}:{field}:expire"
- await self.client.set(expire_key, "1")
- await self.client.expire(expire_key, timeout)
-
- # Dependency injection helper function
- async def get_redis_service(request: Request) -> RedisService:
- return request.app.state.redis_serive
|