您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

198 行
7.3KB

  1. import aioredis
  2. import os
  3. import uuid
  4. import asyncio
  5. import threading
  6. from fastapi import FastAPI, Depends
  7. from fastapi import Request
  8. import time
  9. # 定义全局 redis_helper 作为单例
  10. class RedisService:
  11. _instance = None
  12. def __new__(cls, host='localhost', port=6379, password=None, db=0):
  13. if not cls._instance:
  14. cls._instance = super(RedisService, cls).__new__(cls)
  15. cls._instance.client = None
  16. cls._instance.lock_renewal_thread = None
  17. return cls._instance
  18. async def init(self, host='localhost', port=6379, password=None, db=0):
  19. """初始化 Redis 连接"""
  20. #self.client = await aioredis.Redis(f'redis://{host}:{port}', password=password, db=db)
  21. self.client = await aioredis.Redis(host=host, port=port, password=password, db=db)
  22. # async def set_hash(self, hash_key, data, timeout=None):
  23. # """添加或更新哈希,并设置有效期"""
  24. # await self.client.hmset_dict(hash_key, data)
  25. # if timeout:
  26. # # 设置有效期(单位:秒)
  27. # await self.client.expire(hash_key, timeout)
  28. async def set_hash(self, hash_key, data, timeout=None):
  29. """添加或更新哈希,并设置有效期"""
  30. # 使用 hmset 方法设置哈希表数据
  31. await self.client.hmset(hash_key, data)
  32. if timeout:
  33. # 设置有效期(单位:秒)
  34. await self.client.expire(hash_key, timeout)
  35. async def get_hash(self, hash_key):
  36. """获取整个哈希表数据"""
  37. result = await self.client.hgetall(hash_key)
  38. # 将字节数据解码成字符串格式返回
  39. return {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()}
  40. async def get_hash_field(self, hash_key, field):
  41. """获取哈希表中的单个字段值"""
  42. result = await self.client.hget(hash_key, field)
  43. return result.decode('utf-8') if result else None
  44. async def delete_hash(self, hash_key):
  45. """删除整个哈希表"""
  46. await self.client.delete(hash_key)
  47. async def delete_hash_field(self, hash_key, field):
  48. """删除哈希表中的某个字段"""
  49. await self.client.hdel(hash_key, field)
  50. async def update_hash_field(self, hash_key, field, value):
  51. """更新哈希表中的某个字段"""
  52. await self.client.hset(hash_key, field, value)
  53. # async def acquire_lock(self, lock_name, timeout=60):
  54. # """
  55. # 尝试获取分布式锁,成功返回 True,失败返回 False
  56. # :param lock_name: 锁的名称
  57. # :param timeout: 锁的超时时间(秒)
  58. # :return: bool
  59. # """
  60. # # identifier = str(time.time()) # 使用时间戳作为唯一标识
  61. # # if await self.client.setnx(lock_name, identifier):
  62. # # await self.client.expire(lock_name, timeout)
  63. # # self.lock_renewal_thread = threading.Thread(target=self.renew_lock, args=(lock_name, identifier, timeout))
  64. # # self.lock_renewal_thread.start()
  65. # # return True
  66. # # return False
  67. # identifier = str(time.time()) # 使用 UUID 作为唯一标识
  68. # if await self.client.setnx(lock_name, identifier):
  69. # await self.client.expire(lock_name, timeout)
  70. # self.lock_renewal_thread = threading.Thread(
  71. # target=self.run_renew_lock,
  72. # args=(lock_name, identifier, timeout)
  73. # )
  74. # self.lock_renewal_thread.start()
  75. # return True
  76. # return False
  77. # def run_renew_lock(self, lock_name, identifier, timeout):
  78. # """
  79. # 在单独的线程中运行异步的 renew_lock 方法
  80. # """
  81. # loop = asyncio.new_event_loop()
  82. # asyncio.set_event_loop(loop)
  83. # loop.run_until_complete(self.renew_lock(lock_name, identifier, timeout))
  84. # loop.close()
  85. async def acquire_lock(self, lock_name, timeout=60):
  86. """
  87. 尝试获取分布式锁,成功返回 True,失败返回 False
  88. :param lock_name: 锁的名称
  89. :param timeout: 锁的超时时间(秒)
  90. :return: bool
  91. """
  92. identifier = str(uuid.uuid4()) # 使用 UUID 作为唯一标识
  93. if await self.client.setnx(lock_name, identifier):
  94. await self.client.expire(lock_name, timeout)
  95. # 启动异步任务来续期锁
  96. self.lock_renewal_task = asyncio.create_task(
  97. self.renew_lock(lock_name, identifier, timeout)
  98. )
  99. return True
  100. return False
  101. async def renew_lock(self, lock_name, identifier, timeout):
  102. """
  103. 锁的自动续期
  104. :param lock_name: 锁的名称
  105. :param identifier: 锁的唯一标识
  106. :param timeout: 锁的超时时间(秒)
  107. """
  108. while True:
  109. await asyncio.sleep(timeout / 2)
  110. if await self.client.get(lock_name) == identifier.encode():
  111. await self.client.expire(lock_name, timeout)
  112. else:
  113. break
  114. async def release_lock(self, lock_name, identifier):
  115. """
  116. 释放分布式锁
  117. :param lock_name: 锁的名称
  118. :param identifier: 锁的唯一标识
  119. """
  120. if await self.client.get(lock_name) == identifier.encode():
  121. await self.client.delete(lock_name)
  122. if self.lock_renewal_thread:
  123. self.lock_renewal_thread.join()
  124. async def enqueue(self, queue_name, item):
  125. """
  126. 将元素添加到队列的尾部(右侧)
  127. :param queue_name: 队列名称
  128. :param item: 要添加到队列的元素
  129. """
  130. await self.client.rpush(queue_name, item)
  131. print(f"Enqueued: {item} to queue {queue_name}")
  132. async def dequeue(self, queue_name):
  133. """
  134. 从队列的头部(左侧)移除并返回元素
  135. :param queue_name: 队列名称
  136. :return: 移除的元素,如果队列为空则返回 None
  137. """
  138. item = await self.client.lpop(queue_name)
  139. if item:
  140. print(f"Dequeued: {item.decode('utf-8')} from queue {queue_name}")
  141. return item.decode('utf-8')
  142. print(f"Queue {queue_name} is empty")
  143. return None
  144. async def get_queue_length(self, queue_name):
  145. """
  146. 获取队列的长度
  147. :param queue_name: 队列名称
  148. :return: 队列的长度
  149. """
  150. length = await self.client.llen(queue_name)
  151. print(f"Queue {queue_name} length: {length}")
  152. return length
  153. async def peek_queue(self, queue_name):
  154. """
  155. 查看队列的头部元素,但不移除
  156. :param queue_name: 队列名称
  157. :return: 队列的头部元素,如果队列为空则返回 None
  158. """
  159. item = await self.client.lrange(queue_name, 0, 0)
  160. if item:
  161. print(f"Peeked: {item[0].decode('utf-8')} from queue {queue_name}")
  162. return item[0].decode('utf-8')
  163. print(f"Queue {queue_name} is empty")
  164. return None
  165. async def clear_queue(self, queue_name):
  166. """
  167. 清空队列
  168. :param queue_name: 队列名称
  169. """
  170. await self.client.delete(queue_name)
  171. print(f"Cleared queue {queue_name}")
  172. # Dependency injection helper function
  173. async def get_redis_service(request: Request) -> RedisService:
  174. return request.app.state.redis_serive