Coverage for src/extratools_core/trie.py: 62%

93 statements  

« prev     ^ index     » next       coverage.py v7.8.1, created at 2025-06-11 20:59 -0700

1from __future__ import annotations 

2 

3from collections.abc import Callable, Iterable, Iterator, Mapping, MutableMapping 

4from typing import Any 

5 

6 

7class TrieDict[VT: Any](MutableMapping[str, VT]): 

8 def __init__( 

9 self, 

10 initial_data: Mapping[str, VT] | Iterable[tuple[str, VT]] | None = None, 

11 ) -> None: 

12 self.root: dict[str, Any] = {} 

13 

14 self.__len: int = 0 

15 

16 if initial_data: 

17 for key, value in ( 

18 initial_data.items() if isinstance(initial_data, Mapping) 

19 else initial_data 

20 ): 

21 self.__setitem__(key, value) 

22 

23 def __len__(self) -> int: 

24 return self.__len 

25 

26 def __find(self, s: str, func: Callable[[dict[str, Any], str], Any]) -> Any: 

27 node: dict[str, Any] = self.root 

28 

29 while True: 

30 c: str = s[0] if s else "" 

31 rest: str = s[1:] if s else "" 

32 

33 next_node: dict[str, Any] | tuple[str, VT] | None = node.get(c) 

34 if next_node is None: 

35 raise KeyError 

36 

37 if isinstance(next_node, dict): 

38 node = next_node 

39 s = rest 

40 continue 

41 

42 if rest == next_node[0]: 

43 return func(node, c) 

44 

45 raise KeyError 

46 

47 def __delitem__(self, s: str) -> None: 

48 def delitem(node: dict[str, Any], c: str) -> None: 

49 del node[c] 

50 self.__len -= 1 

51 

52 return self.__find(s, delitem) 

53 

54 def __getitem__(self, s: str) -> VT: 

55 def getitem(node: dict[str, Any], c: str) -> VT: 

56 return node[c][1] 

57 

58 return self.__find(s, getitem) 

59 

60 def __setitem__(self, s: str, v: VT) -> None: 

61 self.__set(s, v, self.root, is_new=True) 

62 

63 def __set(self, s: str, v: VT, node: dict[str, Any], *, is_new: bool) -> None: 

64 if not s: 

65 is_new = is_new and "" not in node 

66 node[""] = ("", v) 

67 if is_new: 

68 self.__len += 1 

69 

70 return 

71 

72 c: str = s[0] 

73 rest: str = s[1:] 

74 

75 next_node: dict[str, Any] | tuple[str, VT] | None = node.get(c) 

76 if next_node is None: 

77 node[c] = (rest, v) 

78 if is_new: 

79 self.__len += 1 

80 elif isinstance(next_node, dict): 

81 self.__set(rest, v, next_node, is_new=is_new) 

82 else: 

83 other_rest: str 

84 other_value: VT 

85 other_rest, other_value = next_node 

86 

87 if rest == other_rest: 

88 node[c] = (rest, v) 

89 return 

90 

91 next_node = node[c] = {} 

92 

93 self.__set(other_rest, other_value, next_node, is_new=False) 

94 self.__set(rest, v, next_node, is_new=is_new) 

95 

96 def __iter__(self) -> Iterator[str]: 

97 for _, value in self.__prefixes("", self.root): 

98 yield value 

99 

100 def prefixes(self) -> Iterator[tuple[str, str]]: 

101 yield from self.__prefixes("", self.root) 

102 

103 def __prefixes(self, prefix: str, node: dict[str, Any]) -> Iterator[tuple[str, str]]: 

104 for key, next_node in node.items(): 

105 new_prefix = prefix + key 

106 if isinstance(next_node, dict): 

107 yield from self.__prefixes(new_prefix, next_node) 

108 else: 

109 yield (new_prefix, new_prefix + next_node[0]) 

110 

111 def match(self, prefix: str) -> Iterator[str]: 

112 node: dict[str, Any] = self.root 

113 s: str = prefix 

114 

115 matched: str = "" 

116 

117 while s: 

118 c: str = s[0] 

119 rest: str = s[1:] 

120 matched += c 

121 

122 next_node: dict[str, Any] | tuple[str, VT] | None = node.get(c) 

123 if next_node is None: 

124 return 

125 

126 if isinstance(next_node, dict): 

127 node = next_node 

128 s = rest 

129 continue 

130 

131 other_rest: str = next_node[0] 

132 if other_rest.startswith(rest): 

133 yield matched + other_rest 

134 

135 return 

136 

137 for _, value in self.__prefixes(prefix, node): 

138 yield value