您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

redis_service.py 9.5KB

1 个月前
1周前
1 个月前
1 个月前
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. import aioredis
  2. import os
  3. import uuid
  4. import asyncio
  5. import threading
  6. from fastapi import FastAPI, Depends
  7. from fastapi import Request
  8. import time
  9. # 定义全局 redis_helper 作为单例
  10. class RedisService:
  11. _instance = None
  12. def __new__(cls, host='localhost', port=6379, password=None, db=0):
  13. if not cls._instance:
  14. cls._instance = super(RedisService, cls).__new__(cls)
  15. cls._instance.client = None
  16. cls._instance.lock_renewal_thread = None
  17. return cls._instance
  18. async def init(self, host='localhost', port=6379, password=None, db=0):
  19. """初始化 Redis 连接"""
  20. #self.client = await aioredis.Redis(f'redis://{host}:{port}', password=password, db=db)
  21. self.client = await aioredis.Redis(host=host, port=port, password=password, db=db)
  22. async def set_hash(self, hash_key, data, timeout=None):
  23. """添加或更新哈希,并设置有效期"""
  24. # 使用 hmset 方法设置哈希表数据
  25. await self.client.hmset(hash_key, data)
  26. if timeout:
  27. # 设置有效期(单位:秒)
  28. await self.client.expire(hash_key, timeout)
  29. # async def set_hash(self, hash_key, data, timeout=None):
  30. # """添加或更新哈希,并设置有效期"""
  31. # # 使用 hset 方法设置哈希表数据
  32. # await self.client.hset(hash_key, mapping=data)
  33. # if timeout:
  34. # # 设置有效期(单位:秒)
  35. # await self.client.expire(hash_key, timeout)
  36. # def flatten_dict(self,d, parent_key="", sep="."):
  37. # """
  38. # 将嵌套字典扁平化
  39. # :param d: 嵌套字典
  40. # :param parent_key: 父键(用于递归)
  41. # :param sep: 分隔符
  42. # :return: 扁平化字典
  43. # """
  44. # items = []
  45. # for k, v in d.items():
  46. # new_key = f"{parent_key}{sep}{k}" if parent_key else k
  47. # if isinstance(v, dict):
  48. # items.extend(self.flatten_dict(v, new_key, sep=sep).items())
  49. # else:
  50. # items.append((new_key, v))
  51. # return dict(items)
  52. # async def set_hash(self, hash_key, data, timeout=None):
  53. # """添加或更新哈希,并设置有效期"""
  54. # # 扁平化嵌套字典
  55. # flat_data = self.flatten_dict(data)
  56. # # 使用 hset 方法设置哈希表数据
  57. # await self.client.hset(hash_key, mapping=flat_data)
  58. # if timeout:
  59. # # 设置有效期(单位:秒)
  60. # await self.client.expire(hash_key, timeout)
  61. async def get_hash(self, hash_key):
  62. """获取整个哈希表数据"""
  63. result = await self.client.hgetall(hash_key)
  64. # 将字节数据解码成字符串格式返回
  65. return {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()}
  66. async def get_hash_field(self, hash_key, field):
  67. """获取哈希表中的单个字段值"""
  68. result = await self.client.hget(hash_key, field)
  69. return result.decode('utf-8') if result else None
  70. async def delete_hash(self, hash_key):
  71. """删除整个哈希表"""
  72. await self.client.delete(hash_key)
  73. async def delete_hash_field(self, hash_key, field):
  74. """删除哈希表中的某个字段"""
  75. await self.client.hdel(hash_key, field)
  76. async def update_hash_field(self, hash_key, field, value):
  77. """更新哈希表中的某个字段"""
  78. await self.client.hset(hash_key, field, value)
  79. # async def acquire_lock(self, lock_name, timeout=60):
  80. # """
  81. # 尝试获取分布式锁,成功返回 True,失败返回 False
  82. # :param lock_name: 锁的名称
  83. # :param timeout: 锁的超时时间(秒)
  84. # :return: bool
  85. # """
  86. # # identifier = str(time.time()) # 使用时间戳作为唯一标识
  87. # # if await self.client.setnx(lock_name, identifier):
  88. # # await self.client.expire(lock_name, timeout)
  89. # # self.lock_renewal_thread = threading.Thread(target=self.renew_lock, args=(lock_name, identifier, timeout))
  90. # # self.lock_renewal_thread.start()
  91. # # return True
  92. # # return False
  93. # identifier = str(time.time()) # 使用 UUID 作为唯一标识
  94. # if await self.client.setnx(lock_name, identifier):
  95. # await self.client.expire(lock_name, timeout)
  96. # self.lock_renewal_thread = threading.Thread(
  97. # target=self.run_renew_lock,
  98. # args=(lock_name, identifier, timeout)
  99. # )
  100. # self.lock_renewal_thread.start()
  101. # return True
  102. # return False
  103. # def run_renew_lock(self, lock_name, identifier, timeout):
  104. # """
  105. # 在单独的线程中运行异步的 renew_lock 方法
  106. # """
  107. # loop = asyncio.new_event_loop()
  108. # asyncio.set_event_loop(loop)
  109. # loop.run_until_complete(self.renew_lock(lock_name, identifier, timeout))
  110. # loop.close()
  111. async def acquire_lock(self, lock_name, timeout=60):
  112. """
  113. 尝试获取分布式锁,成功返回 True,失败返回 False
  114. :param lock_name: 锁的名称
  115. :param timeout: 锁的超时时间(秒)
  116. :return: bool
  117. """
  118. identifier = str(uuid.uuid4()) # 使用 UUID 作为唯一标识
  119. if await self.client.setnx(lock_name, identifier):
  120. await self.client.expire(lock_name, timeout)
  121. # 启动异步任务来续期锁
  122. self.lock_renewal_task = asyncio.create_task(
  123. self.renew_lock(lock_name, identifier, timeout)
  124. )
  125. return True
  126. return False
  127. async def renew_lock(self, lock_name, identifier, timeout):
  128. """
  129. 锁的自动续期
  130. :param lock_name: 锁的名称
  131. :param identifier: 锁的唯一标识
  132. :param timeout: 锁的超时时间(秒)
  133. """
  134. while True:
  135. await asyncio.sleep(timeout / 2)
  136. if await self.client.get(lock_name) == identifier.encode():
  137. await self.client.expire(lock_name, timeout)
  138. else:
  139. break
  140. async def release_lock(self, lock_name, identifier):
  141. """
  142. 释放分布式锁
  143. :param lock_name: 锁的名称
  144. :param identifier: 锁的唯一标识
  145. """
  146. if await self.client.get(lock_name) == identifier.encode():
  147. await self.client.delete(lock_name)
  148. if self.lock_renewal_thread:
  149. self.lock_renewal_thread.join()
  150. async def enqueue(self, queue_name, item):
  151. """
  152. 将元素添加到队列的尾部(右侧)
  153. :param queue_name: 队列名称
  154. :param item: 要添加到队列的元素
  155. """
  156. await self.client.rpush(queue_name, item)
  157. print(f"Enqueued: {item} to queue {queue_name}")
  158. async def dequeue(self, queue_name):
  159. """
  160. 从队列的头部(左侧)移除并返回元素
  161. :param queue_name: 队列名称
  162. :return: 移除的元素,如果队列为空则返回 None
  163. """
  164. item = await self.client.lpop(queue_name)
  165. if item:
  166. print(f"Dequeued: {item.decode('utf-8')} from queue {queue_name}")
  167. return item.decode('utf-8')
  168. print(f"Queue {queue_name} is empty")
  169. return None
  170. async def get_queue_length(self, queue_name):
  171. """
  172. 获取队列的长度
  173. :param queue_name: 队列名称
  174. :return: 队列的长度
  175. """
  176. length = await self.client.llen(queue_name)
  177. print(f"Queue {queue_name} length: {length}")
  178. return length
  179. async def peek_queue(self, queue_name):
  180. """
  181. 查看队列的头部元素,但不移除
  182. :param queue_name: 队列名称
  183. :return: 队列的头部元素,如果队列为空则返回 None
  184. """
  185. item = await self.client.lrange(queue_name, 0, 0)
  186. if item:
  187. print(f"Peeked: {item[0].decode('utf-8')} from queue {queue_name}")
  188. return item[0].decode('utf-8')
  189. print(f"Queue {queue_name} is empty")
  190. return None
  191. async def clear_queue(self, queue_name):
  192. """
  193. 清空队列
  194. :param queue_name: 队列名称
  195. """
  196. await self.client.delete(queue_name)
  197. print(f"Cleared queue {queue_name}")
  198. async def increment_hash_field(self, hash_key, field, amount=1):
  199. """
  200. 对哈希表中的指定字段进行递增操作(原子性)。
  201. :param hash_key: 哈希表的 key
  202. :param field: 要递增的字段
  203. :param amount: 递增的数值(默认为 1)
  204. :return: 递增后的值
  205. """
  206. return await self.client.hincrby(hash_key, field, amount)
  207. async def expire(self, key, timeout):
  208. """
  209. 设置 Redis 键的过期时间(单位:秒)
  210. :param key: Redis 键
  211. :param timeout: 过期时间(秒)
  212. """
  213. await self.client.expire(key, timeout)
  214. async def expire_field(self, hash_key, field, timeout):
  215. """
  216. 通过辅助键方式设置哈希表中某个字段的过期时间
  217. :param hash_key: 哈希表的 key
  218. :param field: 要设置过期的字段
  219. :param timeout: 过期时间(秒)
  220. """
  221. expire_key = f"{hash_key}:{field}:expire"
  222. await self.client.set(expire_key, "1")
  223. await self.client.expire(expire_key, timeout)
  224. # Dependency injection helper function
  225. async def get_redis_service(request: Request) -> RedisService:
  226. return request.app.state.redis_serive