import aioredis
import os
import uuid
import asyncio
import threading
from fastapi import FastAPI, Depends
from fastapi import Request
import time

# 定义全局 redis_helper 作为单例
class RedisService:
    _instance = None
    def __new__(cls, host='localhost', port=6379, password=None, db=0):
        if not cls._instance:
            cls._instance = super(RedisService, cls).__new__(cls)
            cls._instance.client = None
            cls._instance.lock_renewal_thread = None
        return cls._instance

    async def init(self, host='localhost', port=6379, password=None, db=0):
        """初始化 Redis 连接"""
        #self.client = await aioredis.Redis(f'redis://{host}:{port}', password=password, db=db)
        self.client = await aioredis.Redis(host=host, port=port, password=password, db=db)
    

    # async def set_hash(self, hash_key, data, timeout=None):
    #     """添加或更新哈希,并设置有效期"""
    #     await self.client.hmset_dict(hash_key, data)
    #     if timeout:
    #         # 设置有效期(单位:秒)
    #         await self.client.expire(hash_key, timeout)

    async def set_hash(self, hash_key, data, timeout=None):
        """添加或更新哈希,并设置有效期"""
        # 使用 hmset 方法设置哈希表数据
        await self.client.hmset(hash_key, data)
        if timeout:
            # 设置有效期(单位:秒)
            await self.client.expire(hash_key, timeout)

    async def get_hash(self, hash_key):
        """获取整个哈希表数据"""
        result = await self.client.hgetall(hash_key)
        # 将字节数据解码成字符串格式返回
        return {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()}

    async def get_hash_field(self, hash_key, field):
        """获取哈希表中的单个字段值"""
        result = await self.client.hget(hash_key, field)
        return result.decode('utf-8') if result else None

    async def delete_hash(self, hash_key):
        """删除整个哈希表"""
        await self.client.delete(hash_key)

    async def delete_hash_field(self, hash_key, field):
        """删除哈希表中的某个字段"""
        await self.client.hdel(hash_key, field)

    async def update_hash_field(self, hash_key, field, value):
        """更新哈希表中的某个字段"""
        await self.client.hset(hash_key, field, value)

    # async def acquire_lock(self, lock_name, timeout=60):
    #     """
    #     尝试获取分布式锁,成功返回 True,失败返回 False
    #     :param lock_name: 锁的名称
    #     :param timeout: 锁的超时时间(秒)
    #     :return: bool
    #     """
    #     # identifier = str(time.time())  # 使用时间戳作为唯一标识
    #     # if await self.client.setnx(lock_name, identifier):
    #     #     await self.client.expire(lock_name, timeout)
    #     #     self.lock_renewal_thread = threading.Thread(target=self.renew_lock, args=(lock_name, identifier, timeout))
    #     #     self.lock_renewal_thread.start()
    #     #     return True
    #     # return False

    #     identifier = str(time.time())  # 使用 UUID 作为唯一标识
    #     if await self.client.setnx(lock_name, identifier):
    #         await self.client.expire(lock_name, timeout)
    #         self.lock_renewal_thread = threading.Thread(
    #             target=self.run_renew_lock, 
    #             args=(lock_name, identifier, timeout)
    #         )
    #         self.lock_renewal_thread.start()
    #         return True
    #     return False

    
    # def run_renew_lock(self, lock_name, identifier, timeout):
    #     """
    #     在单独的线程中运行异步的 renew_lock 方法
    #     """
    #     loop = asyncio.new_event_loop()
    #     asyncio.set_event_loop(loop)
    #     loop.run_until_complete(self.renew_lock(lock_name, identifier, timeout))
    #     loop.close()    


    async def acquire_lock(self, lock_name, timeout=60):
        """
        尝试获取分布式锁,成功返回 True,失败返回 False
        :param lock_name: 锁的名称
        :param timeout: 锁的超时时间(秒)
        :return: bool
        """
        identifier = str(uuid.uuid4())  # 使用 UUID 作为唯一标识
        if await self.client.setnx(lock_name, identifier):
            await self.client.expire(lock_name, timeout)
            # 启动异步任务来续期锁
            self.lock_renewal_task = asyncio.create_task(
                self.renew_lock(lock_name, identifier, timeout)
            )
            return True
        return False

    async def renew_lock(self, lock_name, identifier, timeout):
        """
        锁的自动续期
        :param lock_name: 锁的名称
        :param identifier: 锁的唯一标识
        :param timeout: 锁的超时时间(秒)
        """
        while True:
            await asyncio.sleep(timeout / 2)
            if await self.client.get(lock_name) == identifier.encode():
                await self.client.expire(lock_name, timeout)
            else:
                break

    async def release_lock(self, lock_name, identifier):
        """
        释放分布式锁
        :param lock_name: 锁的名称
        :param identifier: 锁的唯一标识
        """
        if await self.client.get(lock_name) == identifier.encode():
            await self.client.delete(lock_name)
            if self.lock_renewal_thread:
                self.lock_renewal_thread.join()

    async def enqueue(self, queue_name, item):
        """
        将元素添加到队列的尾部(右侧)
        :param queue_name: 队列名称
        :param item: 要添加到队列的元素
        """
        await self.client.rpush(queue_name, item)
        print(f"Enqueued: {item} to queue {queue_name}")

    async def dequeue(self, queue_name):
        """
        从队列的头部(左侧)移除并返回元素
        :param queue_name: 队列名称
        :return: 移除的元素,如果队列为空则返回 None
        """
        item = await self.client.lpop(queue_name)
        if item:
            print(f"Dequeued: {item.decode('utf-8')} from queue {queue_name}")
            return item.decode('utf-8')
        print(f"Queue {queue_name} is empty")
        return None

    async def get_queue_length(self, queue_name):
        """
        获取队列的长度
        :param queue_name: 队列名称
        :return: 队列的长度
        """
        length = await self.client.llen(queue_name)
        print(f"Queue {queue_name} length: {length}")
        return length

    async def peek_queue(self, queue_name):
        """
        查看队列的头部元素,但不移除
        :param queue_name: 队列名称
        :return: 队列的头部元素,如果队列为空则返回 None
        """
        item = await self.client.lrange(queue_name, 0, 0)
        if item:
            print(f"Peeked: {item[0].decode('utf-8')} from queue {queue_name}")
            return item[0].decode('utf-8')
        print(f"Queue {queue_name} is empty")
        return None

    async def clear_queue(self, queue_name):
        """
        清空队列
        :param queue_name: 队列名称
        """
        await self.client.delete(queue_name)
        print(f"Cleared queue {queue_name}")

    async def increment_hash_field(self, hash_key, field, amount=1):
        """
        对哈希表中的指定字段进行递增操作(原子性)。
        
        :param hash_key: 哈希表的 key
        :param field: 要递增的字段
        :param amount: 递增的数值(默认为 1)
        :return: 递增后的值
        """
        return await self.client.hincrby(hash_key, field, amount)
    
    async def expire(self, key, timeout):
        """
        设置 Redis 键的过期时间(单位:秒)
        
        :param key: Redis 键
        :param timeout: 过期时间(秒)
        """
        await self.client.expire(key, timeout)

    async def expire_field(self, hash_key, field, timeout):
        """
        通过辅助键方式设置哈希表中某个字段的过期时间
        
        :param hash_key: 哈希表的 key
        :param field: 要设置过期的字段
        :param timeout: 过期时间(秒)
        """
        expire_key = f"{hash_key}:{field}:expire"
        await self.client.set(expire_key, "1")
        await self.client.expire(expire_key, timeout)

# Dependency injection helper function
async def get_redis_service(request: Request) -> RedisService:
    return request.app.state.redis_serive