You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

66 lines
2.2KB

  1. import heapq
  2. class SortedDict(dict):
  3. def __init__(self, sort_func=lambda k, v: k, init_dict=None, reverse=False):
  4. if init_dict is None:
  5. init_dict = []
  6. if isinstance(init_dict, dict):
  7. init_dict = init_dict.items()
  8. self.sort_func = sort_func
  9. self.sorted_keys = None
  10. self.reverse = reverse
  11. self.heap = []
  12. for k, v in init_dict:
  13. self[k] = v
  14. def __setitem__(self, key, value):
  15. if key in self:
  16. super().__setitem__(key, value)
  17. for i, (priority, k) in enumerate(self.heap):
  18. if k == key:
  19. self.heap[i] = (self.sort_func(key, value), key)
  20. heapq.heapify(self.heap)
  21. break
  22. self.sorted_keys = None
  23. else:
  24. super().__setitem__(key, value)
  25. heapq.heappush(self.heap, (self.sort_func(key, value), key))
  26. self.sorted_keys = None
  27. def __delitem__(self, key):
  28. super().__delitem__(key)
  29. for i, (priority, k) in enumerate(self.heap):
  30. if k == key:
  31. del self.heap[i]
  32. heapq.heapify(self.heap)
  33. break
  34. self.sorted_keys = None
  35. def keys(self):
  36. if self.sorted_keys is None:
  37. self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)]
  38. return self.sorted_keys
  39. def items(self):
  40. if self.sorted_keys is None:
  41. self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)]
  42. sorted_items = [(k, self[k]) for k in self.sorted_keys]
  43. return sorted_items
  44. def _update_heap(self, key):
  45. for i, (priority, k) in enumerate(self.heap):
  46. if k == key:
  47. new_priority = self.sort_func(key, self[key])
  48. if new_priority != priority:
  49. self.heap[i] = (new_priority, key)
  50. heapq.heapify(self.heap)
  51. self.sorted_keys = None
  52. break
  53. def __iter__(self):
  54. return iter(self.keys())
  55. def __repr__(self):
  56. return f"{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})"