Coverage for maze_dataset/utils.py: 57%

100 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-20 17:51 -0600

1"misc utilities for the `maze_dataset` package" 

2 

3import enum 

4import itertools 

5import math 

6import typing 

7from dataclasses import Field # noqa: TC003 

8from functools import cache, wraps 

9from types import UnionType 

10from typing import ( 

11 Callable, 

12 Generator, 

13 Iterable, 

14 Literal, 

15 TypeVar, 

16 get_args, 

17 get_origin, 

18 overload, 

19) 

20 

21import frozendict 

22import numpy as np 

23from jaxtyping import Bool, Int, Int8 

24from muutils.misc import IsDataclass, flatten, is_abstract 

25 

26 

def bool_array_from_string(
    string: str,
    shape: list[int],
    true_symbol: str = "T",
) -> Bool[np.ndarray, "*shape"]:
    """Transform a string into an ndarray of bools.

    Parameters
    ----------
    string: str
        The string representation of the array. All whitespace is removed before parsing.
    shape: list[int]
        The shape of the resulting array
    true_symbol: str
        The character to parse as True. All other (non-whitespace) characters will be parsed as False.

    Returns
    -------
    np.ndarray
        A ndarray with dtype bool of shape `shape`

    Raises
    ------
    ValueError
        If the number of non-whitespace characters in `string` is not the product of `shape`.

    Examples
    --------
    >>> bool_array_from_string(
    ...     "TT TF", shape=[2,2]
    ... )
    array([[ True,  True],
           [ True, False]])

    """
    stripped = "".join(string.split())

    expected_symbol_count = math.prod(shape)
    symbol_count = len(stripped)
    if symbol_count != expected_symbol_count:
        # no trailing comma inside the parentheses: that would turn `err_msg`
        # into a 1-tuple and the exception message into a tuple repr
        err_msg: str = (
            f"Connection List contains the wrong number of symbols. Expected {expected_symbol_count}. Found {symbol_count} in {stripped}."
        )
        raise ValueError(err_msg)

    bools = [(symbol == true_symbol) for symbol in stripped]
    # explicit dtype so an empty input still yields a bool array (np.array([]) is float64);
    # reshape with a tuple so an empty `shape` (a 0-d array) also works
    return np.asarray(bools, dtype=bool).reshape(tuple(shape))

69 

70 

def corner_first_ndindex(n: int, ndim: int = 2) -> list[tuple]:
    """Return all indices of an `ndim`-dimensional grid of side `n`, sorted by distance from the corner.

    Indices are grouped into "shells" by their largest coordinate (`max(idx)`), and
    the shells are emitted in increasing order. Within a shell, an index is ordered
    by its own coordinates if its first coordinate is even, or by its reversed
    coordinates if its first coordinate is odd. Ties keep `np.ndindex` (row-major)
    order.

    The sort key of an index does not depend on `n`. This gives the property that
    `corner_first_ndindex(n)` is equal to the first `n**ndim` elements of
    `corner_first_ndindex(n + 1)`. Plain `np.ndindex` does not have this property,
    because of its row-major order.

    ```
    >>> corner_first_ndindex(1)
    [(0, 0)]
    >>> corner_first_ndindex(2)
    [(0, 0), (0, 1), (1, 0), (1, 1)]
    >>> corner_first_ndindex(3)
    [(0, 0), (0, 1), (1, 0), (1, 1), (0, 2), (2, 0), (1, 2), (2, 1), (2, 2)]
    ```
    """
    grid_indices: list[tuple] = list(np.ndindex((n,) * ndim))

    def _sort_key(idx: tuple) -> tuple:
        # (shell, coordinates flipped when the first coordinate is odd)
        ordered = idx if idx[0] % 2 == 0 else idx[::-1]
        return (max(idx), ordered)

    # `sorted` is stable, so equal keys keep their row-major `np.ndindex` order
    return sorted(grid_indices, key=_sort_key)

88 

89 

# Alternate NumPy version of `corner_first_ndindex`, suggested by GPT-4.
# Kept for reference only: it is not executed or called anywhere.
# NOTE(review): this is NOT a drop-in replacement for `corner_first_ndindex`:
# - it returns an ndarray of shape (n**ndim, ndim) instead of a list of tuples;
# - it flips coordinates based on the parity of the max coordinate, whereas
#   `corner_first_ndindex` uses the parity of the first coordinate;
# - `np.lexsort` treats the LAST key as the primary one, so within a shell it
#   sorts by the last coordinate (of the possibly-flipped copy) first.
# For n=3, ndim=2 the max==2 shell comes out as (2, 0), (2, 1), (0, 2), (1, 2), (2, 2)
# instead of (0, 2), (2, 0), (1, 2), (2, 1), (2, 2).
#
# # Create all index combinations
# indices = np.indices([n]*ndim).reshape(ndim, -1).T
# # Find the max value for each index
# max_indices = np.max(indices, axis=1)
# # Identify the odd max values
# odd_mask = max_indices % 2 != 0
# # Make a copy of indices to avoid changing the original one
# indices_copy = indices.copy()
# # Reverse the order of the coordinates for indices with odd max value
# indices_copy[odd_mask] = indices_copy[odd_mask, ::-1]
# # Sort by max index value, then by coordinates
# sorted_order = np.lexsort((*indices_copy.T, max_indices))
# return indices[sorted_order]

106 

107 

@overload
def manhattan_distance(
    edges: Int[np.ndarray, "edges coord=2 row_col=2"],
) -> Int8[np.ndarray, " edges"]: ...
@overload
def manhattan_distance(
    edges: Int[np.ndarray, "coord=2 row_col=2"],
) -> int: ...
def manhattan_distance(
    edges: (
        Int[np.ndarray, "edges coord=2 row_col=2"]
        | Int[np.ndarray, "coord=2 row_col=2"]
    ),
) -> Int8[np.ndarray, " edges"] | int:
    """Return the Manhattan (L1) distance between the two coords of an edge.

    # Parameters
    - `edges`: either a single edge of shape `(2, 2)` (two `(row, col)` coords),
      or a batch of edges of shape `(n_edges, 2, 2)`.

    # Returns
    - for a single edge: the distance as a python `int`
    - for a batch: an `int8` array of shape `(n_edges,)`, one distance per edge.
      The int8 dtype follows the return annotation, so distances above 127 wrap.

    # Raises
    - `ValueError`: if `edges` is neither 2- nor 3-dimensional.
    """
    # magic values for dims fine here
    if edges.ndim == 3:  # noqa: PLR2004
        # exact integer L1 norm per edge, then cast to the annotated int8 dtype
        return np.abs(edges[:, 0, :] - edges[:, 1, :]).sum(axis=1).astype(np.int8)
    if edges.ndim == 2:  # noqa: PLR2004
        # summed exactly and returned as a python int: no int8 round-trip,
        # so distances above 127 are not wrapped
        return int(np.abs(edges[0, :] - edges[1, :]).sum())
    err_msg: str = f"{edges} has shape {edges.shape}, but must match the shape in the type hints."
    raise ValueError(err_msg)

133 

134 

def lattice_max_degrees(n: int) -> Int8[np.ndarray, "row col"]:
    """Return an array with the maximum possible degree of each coord in an `n x n` lattice.

    The maximum degree of a coord is its number of in-bounds orthogonal neighbors:
    - 2 at the corners
    - 3 on the rest of the border
    - 4 in the interior
    - 0 for the single coord of a `1 x 1` lattice

    # Raises
    - `ValueError`: if `n` is negative.
    """
    if n < 0:
        err_msg: str = f"lattice size must be non-negative, got {n}"
        raise ValueError(err_msg)
    coord = np.arange(n)
    # neighbors available along one axis: one before (coord > 0) plus one after (coord < n - 1)
    per_axis = (coord > 0).astype(np.int8) + (coord < n - 1).astype(np.int8)
    # a coord's degree is the neighbor count along its row axis plus along its col axis
    return per_axis[:, np.newaxis] + per_axis[np.newaxis, :]

141 

142 

def lattice_connection_array(
    n: int,
) -> Int8[np.ndarray, "edges=2*n*(n-1) leading_trailing_coord=2 row_col=2"]:
    """Return every edge of a 2D square `n x n` lattice as an int8 array.

    Thanks Claude.

    # Parameters
    - `n`: The size of the square lattice.

    # Returns
    np.ndarray of dtype int8 and shape `(2 * n * (n - 1), 2, 2)`. Each entry is one edge,
    given as a pair of `(row, col)` coords. The first `n * (n - 1)` entries are the
    horizontal edges, in row-major order. The remaining `n * (n - 1)` entries are the
    vertical edges, also in row-major order.
    In each pair, the coord with the smaller sum always comes first.
    """
    axis_values = np.arange(n, dtype=np.int8)
    rows, cols = np.meshgrid(axis_values, axis_values, indexing="ij")
    # grid[r, c] == (r, c)
    grid = np.stack((rows, cols), axis=-1)
    edge_count: int = n * (n - 1)

    # pair each coord with its neighbor in the same row, next col
    horizontal = np.stack((grid[:, :-1], grid[:, 1:]), axis=-2).reshape(edge_count, 2, 2)
    # pair each coord with its neighbor in the next row, same col
    vertical = np.stack((grid[:-1, :], grid[1:, :]), axis=-2).reshape(edge_count, 2, 2)

    return np.concatenate((horizontal, vertical), axis=0)

187 

188 

def adj_list_to_nested_set(adj_list: list) -> set:
    """Convert an adjacency list into a set of unordered edges, for order-insensitive comparison.

    `adj_list` looks like `[[[0, 1], [1, 1]], [[0, 0], [0, 1]], ...]`, i.e. a sequence of
    coordinate pairs. Each pair becomes a `frozenset` of two coordinate tuples. As a
    result, neither the order of pairs within `adj_list` nor the order of the two
    coordinates within a pair affects the result.
    """
    edge_set: set = set()
    for coord_a, coord_b in adj_list:
        edge_set.add(frozenset((tuple(coord_a), tuple(coord_b))))
    return edge_set

200 

201 

FiniteValued = TypeVar("FiniteValued", bound=bool | IsDataclass | enum.Enum)
"""
# `FiniteValued`
The details of this type cannot be fully expressed with the Python 3.10 typing library.
This custom generic type is a generic domain of many types which have a finite, discrete, and well-defined range space.
`FiniteValued` defines the domain of supported types for the `all_instances` function, since that function relies heavily on static typing.
These types may be nested in an arbitrarily deep tree via Container Types and Superclass Types (see below).
The leaves of the tree must always be Primitive Types.

# `FiniteValued` Subtypes
*: Indicates that this subtype is not yet supported by `all_instances`

## Non-`FiniteValued` (Unbounded) Types
These are NOT valid subtypes, and are listed for illustrative purposes only.
This list is not comprehensive.
While the finite and discrete nature of digital computers means that the cardinality of these types is technically finite,
they are considered unbounded types in this context.
- No Container subtype may contain any of these unbounded subtypes.
- `int`
- `float`
- `str`
- `list`
- `set`: Set types without a `FiniteValued` argument are unbounded
- `tuple`: Tuple types without a fixed length (e.g., `tuple[bool, ...]`) are unbounded

## Primitive Types
Primitive types are non-nested types which resolve directly to a concrete range of values.
- `bool`: has 2 possible values
- *`enum.Enum`: The range of a concrete `Enum` subclass is its set of enum members
- `typing.Literal`: Every type constructed using `Literal` has a finite set of possible literal values in its definition.
This is the preferred way to include limited ranges of non-`FiniteValued` types such as `int` or `str` in a `FiniteValued` hierarchy.

## Container Types
Container types are types which contain zero or more fields of `FiniteValued` type.
The range of a container type is the cartesian product of the ranges of its field types, except for `set[FiniteValued]`.
- `tuple[FiniteValued]`: Tuples of fixed length whose elements are each `FiniteValued`.
- `IsDataclass`: Concrete dataclasses whose fields are `FiniteValued`.
- *Standard concrete class: Regular classes could be supported just like dataclasses if all their data members are `FiniteValued`-typed.
- *`set[FiniteValued]`: Sets of fixed length of a `FiniteValued` type.

## Superclass Types
Superclass types don't directly contain data members like container types.
Their range is the union of the ranges of their subtypes.
- Abstract dataclasses: Abstract dataclasses whose subclasses are all `FiniteValued` superclass or container types
- *`IsDataclass`: Concrete dataclasses which also have their own subclasses.
- *Standard abstract classes: Abstract (non-dataclass) classes whose subclasses are all `FiniteValued` superclass or container types
- `UnionType`: Any union of `FiniteValued` types, e.g., `bool | Literal[2, 3]`
"""

250 

251 

def _apply_validation_func(
    type_: FiniteValued,
    vals: Generator[FiniteValued, None, None],
    validation_funcs: (
        frozendict.frozendict[FiniteValued, Callable[[FiniteValued], bool]] | None
    ) = None,
) -> Generator[FiniteValued, None, None]:
    """Helper for `all_instances`: filter `vals` using the validator registered for `type_`.

    Lookup order:
    1. An exact `type_` key in `validation_funcs`.
    2. For classes, the first class in `type_.__mro__` that has a registered validator.
    3. For `Literal[...]`, each literal value is filtered by the validator of its runtime type.
    If nothing matches, `vals` is returned unchanged.

    # Parameters
    - `type_: FiniteValued`: A type
    - `vals: Generator[FiniteValued, None, None]`: Instances of `type_`
    - `validation_funcs: dict`: Collection of types mapped to filtering validation functions
    """
    if validation_funcs is None:
        return vals

    # exact key match: the only way a union type (which has no `__mro__`) is filtered
    if type_ in validation_funcs:
        # TYPING: Incompatible return value type (got "filter[FiniteValued]", expected "Generator[FiniteValued, None, None]") [return-value]
        return filter(validation_funcs[type_], vals)

    # classes: use the validator of the nearest registered ancestor only.
    # Generic forms like UnionType and Literal don't have `__mro__` and skip this.
    if hasattr(type_, "__mro__"):
        nearest = next(
            (ancestor for ancestor in type_.__mro__ if ancestor in validation_funcs),
            None,
        )
        if nearest is None:
            return vals
        # TYPING: Incompatible return value type (got "filter[FiniteValued]", expected "Generator[FiniteValued, None, None]") [return-value]
        return filter(validation_funcs[nearest], vals)

    # Literal: validate each literal value against the validator for its own runtime type
    if get_origin(type_) == Literal:
        per_value = (
            _apply_validation_func(type(value), [value], validation_funcs)
            for value in get_args(type_)
        )
        return flatten(per_value, levels_to_flatten=1)

    return vals

294 

295 

# TYPING: some better type hints would be nice here
def _all_instances_wrapper(f: Callable) -> Callable:
    """Decorator for `all_instances`: freezes `validation_funcs` and applies `_apply_validation_func`.

    The wrapped function is always called as `f(type_, validation_funcs)`, where
    `validation_funcs` is a `frozendict.frozendict` (or `None`). Its output is passed
    through `_apply_validation_func` before being returned.

    Results are deliberately not memoized. The wrapped function returns a single-use
    generator, so a cached result would already be exhausted for every caller after
    the first.
    """

    @wraps(f)
    def wrapper(*args, **kwargs):  # noqa: ANN202
        # `type_` is the first parameter of the wrapped function; accept it
        # either positionally or by keyword
        if args:
            type_ = args[0]
        elif "type_" in kwargs:
            type_ = kwargs["type_"]
        else:
            err_msg: str = f"{f.__name__}() missing required argument: 'type_'"
            raise TypeError(err_msg)

        # `validation_funcs` is the second parameter (index 1) of the wrapped
        # function, so it was passed positionally iff there are >= 2 args
        if len(args) >= 2 and args[1] is not None:  # noqa: PLR2004
            raw_validation_funcs = args[1]
        else:
            raw_validation_funcs = kwargs.get("validation_funcs")

        validation_funcs: frozendict.frozendict | None = (
            None
            if raw_validation_funcs is None
            else frozendict.frozendict(raw_validation_funcs)
        )
        return _apply_validation_func(
            type_,
            f(type_, validation_funcs),
            validation_funcs,
        )

    return wrapper

328 

329 

class UnsupportedAllInstancesError(TypeError):
    """Raised when `all_instances` is called on an unsupported type

    i.e. one that either has unbounded possible values or is not (yet) supported (e.g. `Enum`)
    """

    def __init__(self, type_: type) -> None:
        "constructs an error message with the type and mro of the type"
        # Not every unsupported `type_` is a class. `None`, `Ellipsis`, string forward
        # references and `typing` aliases such as `typing.List[bool]` have no `__mro__`.
        # Reading it directly would raise AttributeError and mask this error.
        mro = getattr(type_, "__mro__", None)
        msg: str = f"Type {type_} is not supported by `all_instances`. See docstring for details. type_.__mro__ = {mro!r}"
        super().__init__(msg)

340 

341 

@_all_instances_wrapper
def all_instances(
    type_: FiniteValued,
    validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None = None,
) -> Generator[FiniteValued, None, None]:
    """Returns all possible values of an instance of `type_` if finite instances exist.

    Uses type hinting to construct the possible values.
    All nested elements of `type_` must themselves be typed.
    Do not use with types whose members contain circular references.
    Function is susceptible to infinite recursion if `type_` is a dataclass whose member tree includes another instance of `type_`.

    # Parameters
    - `type_: FiniteValued`
        A finite-valued type. See docstring on `FiniteValued` for full details.
    - `validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None`
        A mapping of types to auxiliary functions to validate instances of that type.
        This optional argument can provide an additional, more precise layer of validation for the instances generated beyond what type hinting alone can provide.
        See `validation_funcs` Details section below.
        (default: `None`)

    ## Supported `type_` Values
    See docstring on `FiniteValued` for full details.
    `type_` may be:
    - `FiniteValued`
    - A finite-valued, fixed-length Generic tuple type.
        E.g., `tuple[bool]`, `tuple[bool, Literal[1, 2]]` are OK.
        `tuple[bool, ...]` is NOT supported, since the length of the tuple is not fixed.
    - Nested versions of any of the types in this list
    - A `UnionType` of any of the types in this list

    For concrete dataclasses, only the fields accepted by the generated `__init__` are
    enumerated. `init=False` fields and `ClassVar` pseudo-fields are skipped, and
    `InitVar` pseudo-fields are not supported.

    ## `validation_funcs` Details
    - `validation_funcs` is applied after all instances have been generated according to type hints.
    - If `type_` is in `validation_funcs`, then the list of instances is filtered by `validation_funcs[type_](instance)`.
    - `validation_funcs` is passed down for all recursive calls of `all_instances`.
        - This allows for improved performance through maximal pruning of the exponential tree.
    - `validation_funcs` supports subclass checking.
    - If `type_` is not found in `validation_funcs`, then the search is performed iteratively in mro order.
    - If a superclass of `type_` is found while searching in mro order, that validation function is applied and the list is returned.
    - If no superclass of `type_` is found, then no filter is applied.

    # Raises:
    - `UnsupportedAllInstancesError`: If `type_` is not supported by `all_instances`.
        This is a generator, so the error is raised while iterating the result, not at call time.
    """
    if type_ == bool:  # noqa: E721
        yield from [True, False]
    elif hasattr(type_, "__dataclass_fields__"):
        if is_abstract(type_):
            # Abstract dataclass: call `all_instances` on each direct subclass
            yield from flatten(
                (
                    all_instances(sub, validation_funcs)
                    for sub in type_.__subclasses__()
                ),
                levels_to_flatten=1,
            )
        else:
            # Concrete dataclass: construct dataclass instances with all possible
            # combinations of the fields that `__init__` accepts
            import dataclasses  # noqa: PLC0415

            # `dataclasses.fields` (unlike `__dataclass_fields__`) omits `ClassVar` and
            # `InitVar` pseudo-fields. `init=False` fields are dropped because passing
            # them to the constructor raises `TypeError`.
            init_fields: list[Field] = [
                fld for fld in dataclasses.fields(type_) if fld.init
            ]
            field_names: list[str] = [fld.name for fld in init_fields]
            all_arg_sequences: Iterable = itertools.product(
                *[all_instances(fld.type, validation_funcs) for fld in init_fields],
            )
            yield from (
                type_(**dict(zip(field_names, args, strict=True)))
                for args in all_arg_sequences
            )
    else:
        type_origin = get_origin(type_)
        if type_origin == tuple:  # noqa: E721
            # Only matches Generic type tuple since regular tuple is not finite-valued
            item_types = get_args(type_)
            if any(item is Ellipsis for item in item_types):
                # `tuple[X, ...]` has no fixed length, so it is unbounded
                raise UnsupportedAllInstancesError(type_)
            # Generic tuple: Similar to concrete dataclass. Construct all possible combinations of tuple fields.
            yield from (
                tuple(combo)
                for combo in itertools.product(
                    *(all_instances(item, validation_funcs) for item in item_types),
                )
            )
        elif type_origin in (UnionType, typing.Union):
            # Union: call `all_instances` for each type in the Union
            yield from flatten(
                [all_instances(sub, validation_funcs) for sub in get_args(type_)],
                levels_to_flatten=1,
            )
        elif type_origin is Literal:
            # Literal: return all Literal arguments
            yield from get_args(type_)
        else:
            raise UnsupportedAllInstancesError(type_)