Coverage for maze_dataset/utils.py: 57%
100 statements
coverage.py v7.6.12, created at 2025-03-20 17:51 -0600
1"misc utilities for the `maze_dataset` package"
3import enum
4import itertools
5import math
6import typing
7from dataclasses import Field # noqa: TC003
8from functools import cache, wraps
9from types import UnionType
10from typing import (
11 Callable,
12 Generator,
13 Iterable,
14 Literal,
15 TypeVar,
16 get_args,
17 get_origin,
18 overload,
19)
21import frozendict
22import numpy as np
23from jaxtyping import Bool, Int, Int8
24from muutils.misc import IsDataclass, flatten, is_abstract


def bool_array_from_string(
    string: str,
    shape: list[int],
    true_symbol: str = "T",
) -> Bool[np.ndarray, "*shape"]:
    """Transform a string into an ndarray of bools.

    Parameters
    ----------
    string: str
        The string representation of the array
    shape: list[int]
        The shape of the resulting array
    true_symbol: str
        The character to parse as True. Whitespace will be removed. All other characters will be parsed as False.

    Returns
    -------
    np.ndarray
        An ndarray with dtype bool of shape `shape`

    Examples
    --------
    >>> bool_array_from_string(
    ...     "TT TF", shape=[2,2]
    ... )
    array([[ True,  True],
           [ True, False]])

    """
    stripped = "".join(string.split())

    expected_symbol_count = math.prod(shape)
    symbol_count = len(stripped)
    if symbol_count != expected_symbol_count:
        err_msg: str = (
            f"Connection List contains the wrong number of symbols. Expected {expected_symbol_count}. Found {symbol_count} in {stripped}."
        )
        raise ValueError(err_msg)

    bools = [(symbol == true_symbol) for symbol in stripped]
    return np.array(bools).reshape(*shape)


def corner_first_ndindex(n: int, ndim: int = 2) -> list[tuple]:
    """returns a list of indices, sorted by distance from the corner

    this gives the property that `corner_first_ndindex(n)` is equal to
    the first n^2 elements of `corner_first_ndindex(n+1)`

    ```
    >>> corner_first_ndindex(1)
    [(0, 0)]
    >>> corner_first_ndindex(2)
    [(0, 0), (0, 1), (1, 0), (1, 1)]
    >>> corner_first_ndindex(3)
    [(0, 0), (0, 1), (1, 0), (1, 1), (0, 2), (2, 0), (1, 2), (2, 1), (2, 2)]
    ```
    """
    unsorted: list = list(np.ndindex(tuple([n for _ in range(ndim)])))
    return sorted(unsorted, key=lambda x: (max(x), x if x[0] % 2 == 0 else x[::-1]))
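

# Illustrative check (not part of the original module): the prefix property above can
# be verified directly, since `corner_first_ndindex(n)` is the first n**2 entries of
# `corner_first_ndindex(n + 1)`.
# >>> corner_first_ndindex(2) == corner_first_ndindex(3)[: 2**2]
# True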


# alternate numpy version from GPT-4:
"""
# Create all index combinations
indices = np.indices([n]*ndim).reshape(ndim, -1).T
# Find the max value for each index
max_indices = np.max(indices, axis=1)
# Identify the odd max values
odd_mask = max_indices % 2 != 0
# Make a copy of indices to avoid changing the original one
indices_copy = indices.copy()
# Reverse the order of the coordinates for indices with odd max value
indices_copy[odd_mask] = indices_copy[odd_mask, ::-1]
# Sort by max index value, then by coordinates
sorted_order = np.lexsort((*indices_copy.T, max_indices))
return indices[sorted_order]
"""


@overload
def manhattan_distance(
    edges: Int[np.ndarray, "edges coord=2 row_col=2"],
) -> Int8[np.ndarray, " edges"]: ...
@overload
def manhattan_distance(
    edges: Int[np.ndarray, "coord=2 row_col=2"],
) -> int: ...
def manhattan_distance(
    edges: (
        Int[np.ndarray, "edges coord=2 row_col=2"]
        | Int[np.ndarray, "coord=2 row_col=2"]
    ),
) -> Int8[np.ndarray, " edges"] | int:
    """Returns the Manhattan distance between the two coords of an edge, or of each edge in a batch."""
    # magic values for dims fine here
    if len(edges.shape) == 3:  # noqa: PLR2004
        return np.linalg.norm(edges[:, 0, :] - edges[:, 1, :], axis=1, ord=1).astype(
            np.int8,
        )
    elif len(edges.shape) == 2:  # noqa: PLR2004
        return int(np.linalg.norm(edges[0, :] - edges[1, :], ord=1).astype(np.int8))
    else:
        err_msg: str = f"{edges} has shape {edges.shape}, but must match one of the shapes in the type hints."
        raise ValueError(err_msg)
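

# Usage sketch (illustrative, not part of the original module): a single edge gives an
# `int`, a batch of edges gives an `int8` array.
# >>> manhattan_distance(np.array([[0, 0], [2, 3]]))
# 5
# >>> manhattan_distance(np.array([[[0, 0], [2, 3]], [[1, 1], [1, 4]]]))
# array([5, 3], dtype=int8)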


def lattice_max_degrees(n: int) -> Int8[np.ndarray, "row col"]:
    """Returns an array with the maximum possible degree for each coord."""
    out = np.full((n, n), 2)
    out[1:-1, :] += 1
    out[:, 1:-1] += 1
    return out
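

# Illustrative sketch (not part of the original module): for a 3x3 lattice, corner
# cells can have at most 2 neighbors, edge cells 3, and the center cell 4.
# >>> lattice_max_degrees(3)
# array([[2, 3, 2],
#        [3, 4, 3],
#        [2, 3, 2]])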


def lattice_connection_array(
    n: int,
) -> Int8[np.ndarray, "edges=2*n*(n-1) leading_trailing_coord=2 row_col=2"]:
    """Returns a 3D NumPy array containing all the edges in a 2D square lattice of size n x n.

    Thanks Claude.

    # Parameters
    - `n`: The size of the square lattice.

    # Returns
    np.ndarray: A 3D NumPy array of shape `(2*n*(n-1), 2, 2)` containing the coordinates of the edges in the 2D square lattice.
    In each pair, the coord with the smaller sum always comes first.
    """
    row_coords, col_coords = np.meshgrid(
        np.arange(n, dtype=np.int8),
        np.arange(n, dtype=np.int8),
        indexing="ij",
    )

    # Horizontal edges
    horiz_edges = np.column_stack(
        (
            row_coords[:, :-1].ravel(),
            col_coords[:, :-1].ravel(),
            row_coords[:, 1:].ravel(),
            col_coords[:, 1:].ravel(),
        ),
    )

    # Vertical edges
    vert_edges = np.column_stack(
        (
            row_coords[:-1, :].ravel(),
            col_coords[:-1, :].ravel(),
            row_coords[1:, :].ravel(),
            col_coords[1:, :].ravel(),
        ),
    )

    return np.concatenate(
        (horiz_edges.reshape(n**2 - n, 2, 2), vert_edges.reshape(n**2 - n, 2, 2)),
        axis=0,
    )
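

# Usage sketch (illustrative, not part of the original module): a 2x2 lattice has
# 2*2*(2-1) = 4 edges, two horizontal and two vertical.
# >>> lattice_connection_array(2).shape
# (4, 2, 2)
# >>> lattice_connection_array(2)[0]
# array([[0, 0],
#        [0, 1]], dtype=int8)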


def adj_list_to_nested_set(adj_list: list) -> set:
    """Used for comparison of adj_lists

    Adj_list looks like [[[0, 1], [1, 1]], [[0, 0], [0, 1]], ...]
    We don't care about order of coordinate pairs within
    the adj_list or coordinates within each coordinate pair.
    """
    return {
        frozenset([tuple(start_coord), tuple(end_coord)])
        for start_coord, end_coord in adj_list
    }
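

# Illustrative sketch (not part of the original module): two adjacency lists that
# differ only in edge order and within-edge coordinate order compare equal.
# >>> a = [[[0, 1], [1, 1]], [[0, 0], [0, 1]]]
# >>> b = [[[0, 1], [0, 0]], [[1, 1], [0, 1]]]
# >>> adj_list_to_nested_set(a) == adj_list_to_nested_set(b)
# True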


FiniteValued = TypeVar("FiniteValued", bound=bool | IsDataclass | enum.Enum)
"""
# `FiniteValued`
The details of this type are not possible to fully define via the Python 3.10 typing library.
This custom generic type covers the domain of types which have a finite, discrete, and well-defined range of possible values.
`FiniteValued` defines the domain of supported types for the `all_instances` function, since that function relies heavily on static typing.
These types may be nested in an arbitrarily deep tree via Container Types and Superclass Types (see below).
The leaves of the tree must always be Primitive Types.

# `FiniteValued` Subtypes
*: Indicates that this subtype is not yet supported by `all_instances`

## Non-`FiniteValued` (Unbounded) Types
These are NOT valid subtypes, and are listed for illustrative purposes only.
This list is not comprehensive.
While the finite and discrete nature of digital computers means that the cardinality of these types is technically finite,
they are considered unbounded types in this context.
No Container subtype may contain any of these unbounded subtypes.
- `int`
- `float`
- `str`
- `list`
- `set`: Set types without a `FiniteValued` argument are unbounded
- `tuple`: Tuple types without a fixed length are unbounded

## Primitive Types
Primitive types are non-nested types which resolve directly to a concrete range of values
- `bool`: has 2 possible values
- *`enum.Enum`: The range of a concrete `Enum` subclass is its set of enum members
- `typing.Literal`: Every type constructed using `Literal` has a finite set of possible literal values in its definition.
This is the preferred way to include limited ranges of non-`FiniteValued` types such as `int` or `str` in a `FiniteValued` hierarchy.

## Container Types
Container types are types which contain zero or more fields of `FiniteValued` type.
The range of a container type is the cartesian product of the ranges of its field types, except for `set[FiniteValued]`.
- `tuple[FiniteValued]`: Tuples of fixed length whose elements are each `FiniteValued`.
- `IsDataclass`: Concrete dataclasses whose fields are `FiniteValued`.
- *Standard concrete class: Regular classes could be supported just like dataclasses if all their data members are `FiniteValued`-typed.
- *`set[FiniteValued]`: Sets whose elements are of a `FiniteValued` type.

## Superclass Types
Superclass types don't directly contain data members like container types do.
Their range is the union of the ranges of their subtypes.
- Abstract dataclasses: Abstract dataclasses whose subclasses are all `FiniteValued` superclass or container types
- *`IsDataclass`: Concrete dataclasses which also have their own subclasses.
- *Standard abstract classes: Abstract (non-dataclass) classes whose subclasses are all `FiniteValued` superclass or container types
- `UnionType`: Any union of `FiniteValued` types, e.g. `bool | Literal[2, 3]`
"""


def _apply_validation_func(
    type_: FiniteValued,
    vals: Generator[FiniteValued, None, None],
    validation_funcs: (
        frozendict.frozendict[FiniteValued, Callable[[FiniteValued], bool]] | None
    ) = None,
) -> Generator[FiniteValued, None, None]:
    """Helper function for `all_instances`.

    Filters `vals` according to `validation_funcs`.
    If `type_` is a regular type, searches in MRO order in `validation_funcs` and applies the first match, if any.
    Handles generic types supported by `all_instances` with special `if` clauses.

    # Parameters
    - `type_: FiniteValued`: A type
    - `vals: Generator[FiniteValued, None, None]`: Instances of `type_`
    - `validation_funcs: dict`: Collection of types mapped to filtering validation functions
    """
    if validation_funcs is None:
        return vals
    if type_ in validation_funcs:  # Only possible catch of UnionTypes
        # TYPING: Incompatible return value type (got "filter[FiniteValued]", expected "Generator[FiniteValued, None, None]") [return-value]
        return filter(validation_funcs[type_], vals)
    elif hasattr(
        type_,
        "__mro__",
    ):  # Generic types like UnionType, Literal don't have `__mro__`
        for superclass in type_.__mro__:
            if superclass not in validation_funcs:
                continue
            # TYPING: error: Incompatible types in assignment (expression has type "filter[FiniteValued]", variable has type "Generator[FiniteValued, None, None]") [assignment]
            vals = filter(validation_funcs[superclass], vals)
            break  # Only the first validation function hit in the mro is applied
    elif get_origin(type_) == Literal:
        return flatten(
            (
                _apply_validation_func(type(v), [v], validation_funcs)
                for v in get_args(type_)
            ),
            levels_to_flatten=1,
        )
    return vals


# TYPING: some better type hints would be nice here
def _all_instances_wrapper(f: Callable) -> Callable:
    """Converts dicts to frozendicts to allow caching and applies `_apply_validation_func`."""

    @wraps(f)
    def wrapper(*args, **kwargs):  # noqa: ANN202
        @cache
        def cached_wrapper(  # noqa: ANN202
            type_: type,
            all_instances_func: Callable,
            validation_funcs: (
                frozendict.frozendict[FiniteValued, Callable[[FiniteValued], bool]]
                | None
            ),
        ):
            return _apply_validation_func(
                type_,
                all_instances_func(type_, validation_funcs),
                validation_funcs,
            )

        validation_funcs: frozendict.frozendict | None
        # TODO: what is this magic value here exactly?
        if len(args) >= 2 and args[1] is not None:  # noqa: PLR2004
            validation_funcs = frozendict.frozendict(args[1])
        elif "validation_funcs" in kwargs and kwargs["validation_funcs"] is not None:
            validation_funcs = frozendict.frozendict(kwargs["validation_funcs"])
        else:
            validation_funcs = None
        return cached_wrapper(args[0], f, validation_funcs)

    return wrapper


class UnsupportedAllInstancesError(TypeError):
    """Raised when `all_instances` is called on an unsupported type

    The type either has unbounded possible values or is not yet supported (e.g. `Enum` is not supported).
    """

    def __init__(self, type_: type) -> None:
        "constructs an error message with the type and mro of the type"
        msg: str = f"Type {type_} is not supported by `all_instances`. See docstring for details. {type_.__mro__ = }"
        super().__init__(msg)
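

# Illustrative sketch (not part of the original module): unbounded types such as `int`
# are rejected once the generator is consumed.
# >>> list(all_instances(int))
# Traceback (most recent call last):
#     ...
# UnsupportedAllInstancesError: Type <class 'int'> is not supported by `all_instances`. ...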


@_all_instances_wrapper
def all_instances(
    type_: FiniteValued,
    validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None = None,
) -> Generator[FiniteValued, None, None]:
    """Returns all possible values of an instance of `type_` if finite instances exist.

    Uses type hinting to construct the possible values.
    All nested elements of `type_` must themselves be typed.
    Do not use with types whose members contain circular references.
    Function is susceptible to infinite recursion if `type_` is a dataclass whose member tree includes another instance of `type_`.

    # Parameters
    - `type_: FiniteValued`
        A finite-valued type. See docstring on `FiniteValued` for full details.
    - `validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None`
        A mapping of types to auxiliary functions to validate instances of that type.
        This optional argument can provide an additional, more precise layer of validation for the instances generated beyond what type hinting alone can provide.
        See `validation_funcs` Details section below.
        (default: `None`)

    ## Supported `type_` Values
    See docstring on `FiniteValued` for full details.
    `type_` may be:
    - `FiniteValued`
    - A finite-valued, fixed-length Generic tuple type.
        E.g., `tuple[bool]`, `tuple[bool, MyEnum]` are OK.
        `tuple[bool, ...]` is NOT supported, since the length of the tuple is not fixed.
    - Nested versions of any of the types in this list
    - A `UnionType` of any of the types in this list

    ## `validation_funcs` Details
    - `validation_funcs` is applied after all instances have been generated according to type hints.
    - If `type_` is in `validation_funcs`, then the list of instances is filtered by `validation_funcs[type_](instance)`.
    - `validation_funcs` is passed down for all recursive calls of `all_instances`.
        - This allows for improved performance through maximal pruning of the exponential tree.
    - `validation_funcs` supports subclass checking.
        - If `type_` is not found in `validation_funcs`, then the search is performed iteratively in mro order.
        - If a superclass of `type_` is found while searching in mro order, that validation function is applied and the list is returned.
        - If no superclass of `type_` is found, then no filter is applied.

    # Raises
    - `UnsupportedAllInstancesError`: If `type_` is not supported by `all_instances`.
    """
    if type_ == bool:  # noqa: E721
        yield from [True, False]
    elif hasattr(type_, "__dataclass_fields__"):
        if is_abstract(type_):
            # Abstract dataclass: call `all_instances` on each subclass
            yield from flatten(
                (
                    all_instances(sub, validation_funcs)
                    for sub in type_.__subclasses__()
                ),
                levels_to_flatten=1,
            )
        else:
            # Concrete dataclass: construct dataclass instances with all possible combinations of fields
            fields: dict[str, Field] = type_.__dataclass_fields__
            fields_to_types: dict[str, type] = {f: fields[f].type for f in fields}
            all_arg_sequences: Iterable = itertools.product(
                *[
                    all_instances(arg_type, validation_funcs)
                    for arg_type in fields_to_types.values()
                ],
            )
            yield from (
                type_(
                    **dict(zip(fields_to_types.keys(), args, strict=False)),
                )
                for args in all_arg_sequences
            )
    else:
        type_origin = get_origin(type_)
        if type_origin == tuple:  # noqa: E721
            # Only matches Generic type tuple since regular tuple is not finite-valued
            # Generic tuple: Similar to concrete dataclass. Construct all possible combinations of tuple fields.
            yield from (
                tuple(combo)
                for combo in itertools.product(
                    *(
                        all_instances(tup_item, validation_funcs)
                        for tup_item in get_args(type_)
                    ),
                )
            )
        elif type_origin in (UnionType, typing.Union):
            # Union: call `all_instances` for each type in the Union
            yield from flatten(
                [all_instances(sub, validation_funcs) for sub in get_args(type_)],
                levels_to_flatten=1,
            )
        elif type_origin is Literal:
            # Literal: return all Literal arguments
            yield from get_args(type_)
        else:
            raise UnsupportedAllInstancesError(type_)
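

# Demo sketch (not part of the original module): `_DemoCell` is a hypothetical
# dataclass used only to illustrate `all_instances`; the `validation_funcs` argument
# prunes the generated instances, here keeping only those with `blocked=False`.
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class _DemoCell:
        blocked: bool
        kind: Literal["wall", "path"]

    # 2 bool values x 2 literal values -> 4 instances
    print(list(all_instances(_DemoCell)))
    # filtering `bool` fields keeps only the 2 instances with blocked=False
    print(list(all_instances(_DemoCell, {bool: lambda b: not b})))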