You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

230 satır
8.4KB

  1. import aioredis
  2. import os
  3. import uuid
  4. import asyncio
  5. import threading
  6. from fastapi import FastAPI, Depends
  7. from fastapi import Request
  8. import time
  9. # 定义全局 redis_helper 作为单例
  10. class RedisService:
  11. _instance = None
  12. def __new__(cls, host='localhost', port=6379, password=None, db=0):
  13. if not cls._instance:
  14. cls._instance = super(RedisService, cls).__new__(cls)
  15. cls._instance.client = None
  16. cls._instance.lock_renewal_thread = None
  17. return cls._instance
  18. async def init(self, host='localhost', port=6379, password=None, db=0):
  19. """初始化 Redis 连接"""
  20. #self.client = await aioredis.Redis(f'redis://{host}:{port}', password=password, db=db)
  21. self.client = await aioredis.Redis(host=host, port=port, password=password, db=db)
  22. # async def set_hash(self, hash_key, data, timeout=None):
  23. # """添加或更新哈希,并设置有效期"""
  24. # await self.client.hmset_dict(hash_key, data)
  25. # if timeout:
  26. # # 设置有效期(单位:秒)
  27. # await self.client.expire(hash_key, timeout)
  28. async def set_hash(self, hash_key, data, timeout=None):
  29. """添加或更新哈希,并设置有效期"""
  30. # 使用 hmset 方法设置哈希表数据
  31. await self.client.hmset(hash_key, data)
  32. if timeout:
  33. # 设置有效期(单位:秒)
  34. await self.client.expire(hash_key, timeout)
  35. async def get_hash(self, hash_key):
  36. """获取整个哈希表数据"""
  37. result = await self.client.hgetall(hash_key)
  38. # 将字节数据解码成字符串格式返回
  39. return {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()}
  40. async def get_hash_field(self, hash_key, field):
  41. """获取哈希表中的单个字段值"""
  42. result = await self.client.hget(hash_key, field)
  43. return result.decode('utf-8') if result else None
  44. async def delete_hash(self, hash_key):
  45. """删除整个哈希表"""
  46. await self.client.delete(hash_key)
  47. async def delete_hash_field(self, hash_key, field):
  48. """删除哈希表中的某个字段"""
  49. await self.client.hdel(hash_key, field)
  50. async def update_hash_field(self, hash_key, field, value):
  51. """更新哈希表中的某个字段"""
  52. await self.client.hset(hash_key, field, value)
  53. # async def acquire_lock(self, lock_name, timeout=60):
  54. # """
  55. # 尝试获取分布式锁,成功返回 True,失败返回 False
  56. # :param lock_name: 锁的名称
  57. # :param timeout: 锁的超时时间(秒)
  58. # :return: bool
  59. # """
  60. # # identifier = str(time.time()) # 使用时间戳作为唯一标识
  61. # # if await self.client.setnx(lock_name, identifier):
  62. # # await self.client.expire(lock_name, timeout)
  63. # # self.lock_renewal_thread = threading.Thread(target=self.renew_lock, args=(lock_name, identifier, timeout))
  64. # # self.lock_renewal_thread.start()
  65. # # return True
  66. # # return False
  67. # identifier = str(time.time()) # 使用 UUID 作为唯一标识
  68. # if await self.client.setnx(lock_name, identifier):
  69. # await self.client.expire(lock_name, timeout)
  70. # self.lock_renewal_thread = threading.Thread(
  71. # target=self.run_renew_lock,
  72. # args=(lock_name, identifier, timeout)
  73. # )
  74. # self.lock_renewal_thread.start()
  75. # return True
  76. # return False
  77. # def run_renew_lock(self, lock_name, identifier, timeout):
  78. # """
  79. # 在单独的线程中运行异步的 renew_lock 方法
  80. # """
  81. # loop = asyncio.new_event_loop()
  82. # asyncio.set_event_loop(loop)
  83. # loop.run_until_complete(self.renew_lock(lock_name, identifier, timeout))
  84. # loop.close()
  85. async def acquire_lock(self, lock_name, timeout=60):
  86. """
  87. 尝试获取分布式锁,成功返回 True,失败返回 False
  88. :param lock_name: 锁的名称
  89. :param timeout: 锁的超时时间(秒)
  90. :return: bool
  91. """
  92. identifier = str(uuid.uuid4()) # 使用 UUID 作为唯一标识
  93. if await self.client.setnx(lock_name, identifier):
  94. await self.client.expire(lock_name, timeout)
  95. # 启动异步任务来续期锁
  96. self.lock_renewal_task = asyncio.create_task(
  97. self.renew_lock(lock_name, identifier, timeout)
  98. )
  99. return True
  100. return False
  101. async def renew_lock(self, lock_name, identifier, timeout):
  102. """
  103. 锁的自动续期
  104. :param lock_name: 锁的名称
  105. :param identifier: 锁的唯一标识
  106. :param timeout: 锁的超时时间(秒)
  107. """
  108. while True:
  109. await asyncio.sleep(timeout / 2)
  110. if await self.client.get(lock_name) == identifier.encode():
  111. await self.client.expire(lock_name, timeout)
  112. else:
  113. break
  114. async def release_lock(self, lock_name, identifier):
  115. """
  116. 释放分布式锁
  117. :param lock_name: 锁的名称
  118. :param identifier: 锁的唯一标识
  119. """
  120. if await self.client.get(lock_name) == identifier.encode():
  121. await self.client.delete(lock_name)
  122. if self.lock_renewal_thread:
  123. self.lock_renewal_thread.join()
  124. async def enqueue(self, queue_name, item):
  125. """
  126. 将元素添加到队列的尾部(右侧)
  127. :param queue_name: 队列名称
  128. :param item: 要添加到队列的元素
  129. """
  130. await self.client.rpush(queue_name, item)
  131. print(f"Enqueued: {item} to queue {queue_name}")
  132. async def dequeue(self, queue_name):
  133. """
  134. 从队列的头部(左侧)移除并返回元素
  135. :param queue_name: 队列名称
  136. :return: 移除的元素,如果队列为空则返回 None
  137. """
  138. item = await self.client.lpop(queue_name)
  139. if item:
  140. print(f"Dequeued: {item.decode('utf-8')} from queue {queue_name}")
  141. return item.decode('utf-8')
  142. print(f"Queue {queue_name} is empty")
  143. return None
  144. async def get_queue_length(self, queue_name):
  145. """
  146. 获取队列的长度
  147. :param queue_name: 队列名称
  148. :return: 队列的长度
  149. """
  150. length = await self.client.llen(queue_name)
  151. print(f"Queue {queue_name} length: {length}")
  152. return length
  153. async def peek_queue(self, queue_name):
  154. """
  155. 查看队列的头部元素,但不移除
  156. :param queue_name: 队列名称
  157. :return: 队列的头部元素,如果队列为空则返回 None
  158. """
  159. item = await self.client.lrange(queue_name, 0, 0)
  160. if item:
  161. print(f"Peeked: {item[0].decode('utf-8')} from queue {queue_name}")
  162. return item[0].decode('utf-8')
  163. print(f"Queue {queue_name} is empty")
  164. return None
  165. async def clear_queue(self, queue_name):
  166. """
  167. 清空队列
  168. :param queue_name: 队列名称
  169. """
  170. await self.client.delete(queue_name)
  171. print(f"Cleared queue {queue_name}")
  172. async def increment_hash_field(self, hash_key, field, amount=1):
  173. """
  174. 对哈希表中的指定字段进行递增操作(原子性)。
  175. :param hash_key: 哈希表的 key
  176. :param field: 要递增的字段
  177. :param amount: 递增的数值(默认为 1)
  178. :return: 递增后的值
  179. """
  180. return await self.client.hincrby(hash_key, field, amount)
  181. async def expire(self, key, timeout):
  182. """
  183. 设置 Redis 键的过期时间(单位:秒)
  184. :param key: Redis 键
  185. :param timeout: 过期时间(秒)
  186. """
  187. await self.client.expire(key, timeout)
  188. async def expire_field(self, hash_key, field, timeout):
  189. """
  190. 通过辅助键方式设置哈希表中某个字段的过期时间
  191. :param hash_key: 哈希表的 key
  192. :param field: 要设置过期的字段
  193. :param timeout: 过期时间(秒)
  194. """
  195. expire_key = f"{hash_key}:{field}:expire"
  196. await self.client.set(expire_key, "1")
  197. await self.client.expire(expire_key, timeout)
  198. # Dependency injection helper function
  199. async def get_redis_service(request: Request) -> RedisService:
  200. return request.app.state.redis_serive