maze_dataset.utils
misc utilities for the `maze_dataset` package
"misc utilities for the `maze_dataset` package"

import enum
import itertools
import math
import typing
from dataclasses import Field  # noqa: TC003
from functools import cache, wraps
from types import UnionType
from typing import (
    Callable,
    Generator,
    Iterable,
    Literal,
    TypeVar,
    get_args,
    get_origin,
    overload,
)

import frozendict
import numpy as np
from jaxtyping import Bool, Int, Int8
from muutils.misc import IsDataclass, flatten, is_abstract


def bool_array_from_string(
    string: str,
    shape: list[int],
    true_symbol: str = "T",
) -> Bool[np.ndarray, "*shape"]:
    """Transform a string into an ndarray of bools.

    Parameters
    ----------
    string: str
        The string representation of the array
    shape: list[int]
        The shape of the resulting array
    true_symbol:
        The character to parse as True. Whitespace will be removed. All other characters will be parsed as False.

    Returns
    -------
    np.ndarray
        A ndarray with dtype bool of shape `shape`

    Raises
    ------
    ValueError
        If the number of non-whitespace symbols in `string` differs from `math.prod(shape)`.

    Examples
    --------
    >>> bool_array_from_string(
    ...     "TT TF", shape=[2,2]
    ... )
    array([[ True,  True],
           [ True, False]])

    """
    stripped = "".join(string.split())

    expected_symbol_count = math.prod(shape)
    symbol_count = len(stripped)
    if symbol_count != expected_symbol_count:
        # must be a plain `str`: a trailing comma here would turn the message into a 1-tuple
        err_msg: str = (
            f"Connection List contains the wrong number of symbols. Expected {expected_symbol_count}. "
            f"Found {symbol_count} in {stripped}."
        )
        raise ValueError(err_msg)

    bools = [(symbol == true_symbol) for symbol in stripped]
    # explicit dtype keeps zero-size results boolean; a tuple shape also allows `shape=[]`
    return np.array(bools, dtype=bool).reshape(tuple(shape))


def corner_first_ndindex(n: int, ndim: int = 2) -> list[tuple]:
    """returns a list of indices, sorted by distance from the corner

    this gives the property that `corner_first_ndindex(n)` is equal to
    the first n^2 elements of `corner_first_ndindex(n+1)`

    ```
    >>> corner_first_ndindex(1)
    [(0, 0)]
    >>> corner_first_ndindex(2)
    [(0, 0), (0, 1), (1, 0), (1, 1)]
    >>> corner_first_ndindex(3)
    [(0, 0), (0, 1), (1, 0), (1, 1), (0, 2), (2, 0), (1, 2), (2, 1), (2, 2)]
    ```
    """
    unsorted: list = list(np.ndindex((n,) * ndim))
    # primary key: largest coordinate; secondary key: the index itself,
    # reversed when its first coordinate is odd
    return sorted(unsorted, key=lambda x: (max(x), x if x[0] % 2 == 0 else x[::-1]))


# alternate numpy version from GPT-4:
"""
# Create all index combinations
indices = np.indices([n]*ndim).reshape(ndim, -1).T
# Find the max value for each index
max_indices = np.max(indices, axis=1)
# Identify the odd max values
odd_mask = max_indices % 2 != 0
# Make a copy of indices to avoid changing the original one
indices_copy = indices.copy()
# Reverse the order of the coordinates for indices with odd max value
indices_copy[odd_mask] = indices_copy[odd_mask, ::-1]
# Sort by max index value, then by coordinates
sorted_order = np.lexsort((*indices_copy.T, max_indices))
return indices[sorted_order]
"""


@overload
def manhattan_distance(
    edges: Int[np.ndarray, "edges coord=2 row_col=2"],
) -> Int8[np.ndarray, " edges"]: ...
@overload
def manhattan_distance(
    edges: Int[np.ndarray, "coord=2 row_col=2"],
) -> int: ...
def manhattan_distance(
    edges: (
        Int[np.ndarray, "edges coord=2 row_col=2"]
        | Int[np.ndarray, "coord=2 row_col=2"]
    ),
) -> Int8[np.ndarray, " edges"] | int:
    """Returns the Manhattan distance between two coords.

    # Parameters
    - `edges`: either a single edge (2D array, one coord per row)
      or a batch of edges (3D array, one edge per leading index).

    # Returns
    - for a single edge: the distance as a python `int`
    - for a batch: an `int8` array with one distance per edge
      (distances above 127 do not fit in `int8`)

    # Raises
    - `ValueError`: if `edges` is neither 2D nor 3D
    """
    # magic values for dims fine here
    if len(edges.shape) == 3:  # noqa: PLR2004
        # widen before subtracting so narrow or unsigned input dtypes cannot wrap
        diffs = edges[:, 0, :].astype(np.int64) - edges[:, 1, :]
        return np.abs(diffs).sum(axis=1).astype(np.int8)
    elif len(edges.shape) == 2:  # noqa: PLR2004
        diff = edges[0, :].astype(np.int64) - edges[1, :]
        # no int8 round trip here: the result is a python int of any size
        return int(np.abs(diff).sum())
    else:
        err_msg: str = f"{edges} has shape {edges.shape}, but must be match the shape in the type hints."
        raise ValueError(err_msg)


def lattice_max_degrees(n: int) -> Int8[np.ndarray, "row col"]:
    """Returns an array with the maximum possible degree for each coord.

    Every coord starts with 4 potential neighbors and loses one for each
    lattice border it lies on, so a 1x1 lattice yields `[[0]]`.

    NOTE(review): the array uses numpy's default integer dtype, not `int8` as annotated.
    """
    out = np.full((n, n), 4)
    # slices instead of integer indices so that `n == 0` stays valid
    out[:1, :] -= 1
    out[-1:, :] -= 1
    out[:, :1] -= 1
    out[:, -1:] -= 1
    return out


def lattice_connection_array(
    n: int,
) -> Int8[np.ndarray, "edges=2*n*(n-1) leading_trailing_coord=2 row_col=2"]:
    """Returns a 3D NumPy array containing all the edges in a 2D square lattice of size n x n.

    Thanks Claude.

    # Parameters
    - `n`: The size of the square lattice.

    # Returns
    np.ndarray: A 3D NumPy array of shape `(2*n*(n-1), 2, 2)` containing the coordinates of the edges in the 2D square lattice.
    Horizontal edges come first, then vertical edges.
    In each pair, the coord with the smaller sum always comes first.
    """
    row_coords, col_coords = np.meshgrid(
        np.arange(n, dtype=np.int8),
        np.arange(n, dtype=np.int8),
        indexing="ij",
    )

    # Horizontal edges
    horiz_edges = np.column_stack(
        (
            row_coords[:, :-1].ravel(),
            col_coords[:, :-1].ravel(),
            row_coords[:, 1:].ravel(),
            col_coords[:, 1:].ravel(),
        ),
    )

    # Vertical edges
    vert_edges = np.column_stack(
        (
            row_coords[:-1, :].ravel(),
            col_coords[:-1, :].ravel(),
            row_coords[1:, :].ravel(),
            col_coords[1:, :].ravel(),
        ),
    )

    return np.concatenate(
        (horiz_edges.reshape(n**2 - n, 2, 2), vert_edges.reshape(n**2 - n, 2, 2)),
        axis=0,
    )


def adj_list_to_nested_set(adj_list: list) -> set:
    """Used for comparison of adj_lists

    Adj_list looks like [[[0, 1], [1, 1]], [[0, 0], [0, 1]], ...]
    We don't care about order of coordinate pairs within
    the adj_list or coordinates within each coordinate pair.
    """
    return {
        frozenset([tuple(start_coord), tuple(end_coord)])
        for start_coord, end_coord in adj_list
    }


FiniteValued = TypeVar("FiniteValued", bound=bool | IsDataclass | enum.Enum)
"""
# `FiniteValued`
The details of this type are not possible to fully define via the Python 3.10 typing library.
This custom generic type is a generic domain of many types which have a finite, discrete, and well-defined range space.
`FiniteValued` defines the domain of supported types for the `all_instances` function, since that function relies heavily on static typing.
These types may be nested in an arbitrarily deep tree via Container Types and Superclass Types (see below).
The leaves of the tree must always be Primitive Types.

# `FiniteValued` Subtypes
*: Indicates that this subtype is not yet supported by `all_instances`

## Non-`FiniteValued` (Unbounded) Types
These are NOT valid subtypes, and are listed for illustrative purposes only.
This list is not comprehensive.
While the finite and discrete nature of digital computers means that the cardinality of these types is technically finite,
they are considered unbounded types in this context.
- No Container subtype may contain any of these unbounded subtypes.
- `int`
- `float`
- `str`
- `list`
- `set`: Set types without a `FiniteValued` argument are unbounded
- `tuple`: Tuple types without a fixed length are unbounded

## Primitive Types
Primitive types are non-nested types which resolve directly to a concrete range of values
- `bool`: has 2 possible values
- *`enum.Enum`: The range of a concrete `Enum` subclass is its set of enum members
- `typing.Literal`: Every type constructed using `Literal` has a finite set of possible literal values in its definition.
This is the preferred way to include limited ranges of non-`FiniteValued` types such as `int` or `str` in a `FiniteValued` hierarchy.

## Container Types
Container types are types which contain zero or more fields of `FiniteValued` type.
The range of a container type is the cartesian product of their field types, except for `set[FiniteValued]`.
- `tuple[FiniteValued]`: Tuples of fixed length whose elements are each `FiniteValued`.
- `IsDataclass`: Concrete dataclasses whose fields are `FiniteValued`.
- *Standard concrete class: Regular classes could be supported just like dataclasses if all their data members are `FiniteValued`-typed.
- *`set[FiniteValued]`: Sets of fixed length of a `FiniteValued` type.

## Superclass Types
Superclass types don't directly contain data members like container types.
Their range is the union of the ranges of their subtypes.
- Abstract dataclasses: Abstract dataclasses whose subclasses are all `FiniteValued` superclass or container types
- *`IsDataclass`: Concrete dataclasses which also have their own subclasses.
- *Standard abstract classes: Abstract dataclasses whose subclasses are all `FiniteValued` superclass or container types
- `UnionType`: Any union of `FiniteValued` types, e.g., bool | Literal[2, 3]
"""


def _apply_validation_func(
    type_: FiniteValued,
    vals: Generator[FiniteValued, None, None],
    validation_funcs: (
        frozendict.frozendict[FiniteValued, Callable[[FiniteValued], bool]] | None
    ) = None,
) -> Generator[FiniteValued, None, None]:
    """Helper function for `all_instances`.

    Filters `vals` according to `validation_funcs`.
    If `type_` is a regular type, searches in MRO order in `validation_funcs` and applies the first match, if any.
    Handles generic types supported by `all_instances` with special `if` clauses.

    # Parameters
    - `type_: FiniteValued`: A type
    - `vals: Generator[FiniteValued, None, None]`: Instances of `type_`
    - `validation_funcs: dict`: Collection of types mapped to filtering validation functions
    """
    if validation_funcs is None:
        return vals
    if type_ in validation_funcs:  # Only possible catch of UnionTypes
        # TYPING: Incompatible return value type (got "filter[FiniteValued]", expected "Generator[FiniteValued, None, None]")  [return-value]
        return filter(validation_funcs[type_], vals)
    elif hasattr(
        type_,
        "__mro__",
    ):  # Generic types like UnionType, Literal don't have `__mro__`
        for superclass in type_.__mro__:
            if superclass not in validation_funcs:
                continue
            # TYPING: error: Incompatible types in assignment (expression has type "filter[FiniteValued]", variable has type "Generator[FiniteValued, None, None]")  [assignment]
            vals = filter(validation_funcs[superclass], vals)
            break  # Only the first validation function hit in the mro is applied
    elif get_origin(type_) == Literal:
        # each literal value is validated against the rules for its own runtime type
        return flatten(
            (
                _apply_validation_func(type(v), [v], validation_funcs)
                for v in get_args(type_)
            ),
            levels_to_flatten=1,
        )
    return vals


# TYPING: some better type hints would be nice here
def _all_instances_wrapper(f: Callable) -> Callable:
    """Converts dicts to frozendicts to allow caching and applies `_apply_validation_func`."""

    @wraps(f)
    def wrapper(*args, **kwargs):  # noqa: ANN202
        # NOTE(review): this cache is re-created on every call of `wrapper`,
        # so it never carries results from one call to the next
        @cache
        def cached_wrapper(  # noqa: ANN202
            type_: type,
            all_instances_func: Callable,
            validation_funcs: (
                frozendict.frozendict[FiniteValued, Callable[[FiniteValued], bool]]
                | None
            ),
        ):
            return _apply_validation_func(
                type_,
                all_instances_func(type_, validation_funcs),
                validation_funcs,
            )

        validation_funcs: frozendict.frozendict | None
        # `validation_funcs` is the second positional parameter of the wrapped function
        if len(args) >= 2 and args[1] is not None:  # noqa: PLR2004
            validation_funcs = frozendict.frozendict(args[1])
        elif "validation_funcs" in kwargs and kwargs["validation_funcs"] is not None:
            validation_funcs = frozendict.frozendict(kwargs["validation_funcs"])
        else:
            validation_funcs = None
        # accept `type_` both positionally and as a keyword argument
        type_ = args[0] if args else kwargs["type_"]
        return cached_wrapper(type_, f, validation_funcs)

    return wrapper


class UnsupportedAllInstancesError(TypeError):
    """Raised when `all_instances` is called on an unsupported type

    either has unbounded possible values or is not supported (Enum is not supported)
    """

    def __init__(self, type_: type) -> None:
        "constructs an error message with the type and mro of the type"
        # some objects passed as `type_` (e.g. `typing.List[int]`, string annotations)
        # have no `__mro__`; reading it directly would raise `AttributeError` here
        mro = getattr(type_, "__mro__", None)
        msg: str = f"Type {type_} is not supported by `all_instances`. See docstring for details. type_.__mro__ = {mro!r}"
        super().__init__(msg)


@_all_instances_wrapper
def all_instances(
    type_: FiniteValued,
    validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None = None,
) -> Generator[FiniteValued, None, None]:
    """Returns all possible values of an instance of `type_` if finite instances exist.

    Uses type hinting to construct the possible values.
    All nested elements of `type_` must themselves be typed.
    Do not use with types whose members contain circular references.
    Function is susceptible to infinite recursion if `type_` is a dataclass whose member tree includes another instance of `type_`.

    # Parameters
    - `type_: FiniteValued`
        A finite-valued type. See docstring on `FiniteValued` for full details.
    - `validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None`
        A mapping of types to auxiliary functions to validate instances of that type.
        This optional argument can provide an additional, more precise layer of validation for the instances generated beyond what type hinting alone can provide.
        See `validation_funcs` Details section below.
        (default: `None`)

    ## Supported `type_` Values
    See docstring on `FiniteValued` for full details.
    `type_` may be:
    - `bool`
    - A `Literal` type
    - A dataclass (abstract or concrete) whose fields are all supported types
    - A finite-valued, fixed-length Generic tuple type.
      E.g., `tuple[bool]`, `tuple[bool, Literal[1, 2]]` are OK.
      `tuple[bool, ...]` is NOT supported, since the length of the tuple is not fixed.
    - Nested versions of any of the types in this list
    - A `UnionType` of any of the types in this list

    `enum.Enum` subclasses are not handled by this function and raise `UnsupportedAllInstancesError`.

    ## `validation_funcs` Details
    - `validation_funcs` is applied after all instances have been generated according to type hints.
    - If `type_` is in `validation_funcs`, then the list of instances is filtered by `validation_funcs[type_](instance)`.
    - `validation_funcs` is passed down for all recursive calls of `all_instances`.
        - This allows for improved performance through maximal pruning of the exponential tree.
    - `validation_funcs` supports subclass checking.
        - If `type_` is not found in `validation_funcs`, then the search is performed iteratively in mro order.
        - If a superclass of `type_` is found while searching in mro order, that validation function is applied and the list is returned.
        - If no superclass of `type_` is found, then no filter is applied.

    # Raises:
    - `UnsupportedAllInstancesError`: If `type_` is not supported by `all_instances`.
    """
    if type_ == bool:  # noqa: E721
        yield from [True, False]
    elif hasattr(type_, "__dataclass_fields__"):
        if is_abstract(type_):
            # Abstract dataclass: call `all_instances` on each subclass
            yield from flatten(
                (
                    all_instances(sub, validation_funcs)
                    for sub in type_.__subclasses__()
                ),
                levels_to_flatten=1,
            )
        else:
            # Concrete dataclass: construct dataclass instances with all possible combinations of fields
            # `__dataclass_fields__` maps field names to their `Field` objects
            fields: dict[str, Field] = type_.__dataclass_fields__
            fields_to_types: dict[str, type] = {
                name: field.type for name, field in fields.items()
            }
            all_arg_sequences: Iterable = itertools.product(
                *[
                    all_instances(arg_type, validation_funcs)
                    for arg_type in fields_to_types.values()
                ],
            )
            yield from (
                type_(
                    **dict(zip(fields_to_types.keys(), args, strict=False)),
                )
                for args in all_arg_sequences
            )
    else:
        type_origin = get_origin(type_)
        if type_origin == tuple:  # noqa: E721
            # Only matches Generic type tuple since regular tuple is not finite-valued
            # Generic tuple: Similar to concrete dataclass. Construct all possible combinations of tuple fields.
            yield from (
                tuple(combo)
                for combo in itertools.product(
                    *(
                        all_instances(tup_item, validation_funcs)
                        for tup_item in get_args(type_)
                    ),
                )
            )
        elif type_origin in (UnionType, typing.Union):
            # Union: call `all_instances` for each type in the Union
            yield from flatten(
                [all_instances(sub, validation_funcs) for sub in get_args(type_)],
                levels_to_flatten=1,
            )
        elif type_origin is Literal:
            # Literal: return all Literal arguments
            yield from get_args(type_)
        else:
            raise UnsupportedAllInstancesError(type_)
def bool_array_from_string(
    string: str,
    shape: list[int],
    true_symbol: str = "T",
) -> Bool[np.ndarray, "*shape"]:
    """Transform a string into an ndarray of bools.

    Parameters
    ----------
    string: str
        The string representation of the array
    shape: list[int]
        The shape of the resulting array
    true_symbol:
        The character to parse as True. Whitespace will be removed. All other characters will be parsed as False.

    Returns
    -------
    np.ndarray
        A ndarray with dtype bool of shape `shape`

    Raises
    ------
    ValueError
        If the number of non-whitespace symbols in `string` differs from `math.prod(shape)`.

    Examples
    --------
    >>> bool_array_from_string(
    ...     "TT TF", shape=[2,2]
    ... )
    array([[ True,  True],
           [ True, False]])

    """
    stripped = "".join(string.split())

    expected_symbol_count = math.prod(shape)
    symbol_count = len(stripped)
    if symbol_count != expected_symbol_count:
        # must be a plain `str`: a trailing comma here would turn the message into a 1-tuple
        err_msg: str = (
            f"Connection List contains the wrong number of symbols. Expected {expected_symbol_count}. "
            f"Found {symbol_count} in {stripped}."
        )
        raise ValueError(err_msg)

    bools = [(symbol == true_symbol) for symbol in stripped]
    # explicit dtype keeps zero-size results boolean; a tuple shape also allows `shape=[]`
    return np.array(bools, dtype=bool).reshape(tuple(shape))
Transform a string into an ndarray of bools.
Parameters
string: str — The string representation of the array.
shape: list[int] — The shape of the resulting array.
true_symbol: str — The character to parse as True. Whitespace will be removed. All other characters will be parsed as False.
Returns
np.ndarray
An ndarray with dtype bool and shape `shape`
Examples
>>> bool_array_from_string(
... "TT TF", shape=[2,2]
... )
array([[ True, True],
[ True, False]])
def corner_first_ndindex(n: int, ndim: int = 2) -> list[tuple]:
    """List every index of an `ndim`-dimensional grid of side `n`, nearest the origin corner first.

    Indices are ordered primarily by their largest coordinate. Within one
    such "shell" they are ordered by the index itself, read in reverse when
    its first coordinate is odd; ties keep `np.ndindex` order.

    As a consequence `corner_first_ndindex(n)` is equal to
    the first n^2 elements of `corner_first_ndindex(n+1)`

    ```
    >>> corner_first_ndindex(1)
    [(0, 0)]
    >>> corner_first_ndindex(2)
    [(0, 0), (0, 1), (1, 0), (1, 1)]
    >>> corner_first_ndindex(3)
    [(0, 0), (0, 1), (1, 0), (1, 1), (0, 2), (2, 0), (1, 2), (2, 1), (2, 2)]
    ```
    """

    def _shell_key(idx: tuple) -> tuple:
        # secondary key: the index itself, flipped when its first coordinate is odd
        tie_break = idx if idx[0] % 2 == 0 else idx[::-1]
        return (max(idx), tie_break)

    grid_shape = (n,) * ndim
    # `sorted` is stable, so equal keys stay in `np.ndindex` order
    return sorted(np.ndindex(grid_shape), key=_shell_key)
returns an array of indices, sorted by distance from the corner
this gives the property that `corner_first_ndindex(n)` is equal to the first n^2 elements of `corner_first_ndindex(n+1)`
>>> corner_first_ndindex(1)
[(0, 0)]
>>> corner_first_ndindex(2)
[(0, 0), (0, 1), (1, 0), (1, 1)]
>>> corner_first_ndindex(3)
[(0, 0), (0, 1), (1, 0), (1, 1), (0, 2), (2, 0), (1, 2), (2, 1), (2, 2)]
def manhattan_distance(
    edges: (
        Int[np.ndarray, "edges coord=2 row_col=2"]
        | Int[np.ndarray, "coord=2 row_col=2"]
    ),
) -> Int8[np.ndarray, " edges"] | int:
    """Returns the Manhattan distance between two coords.

    # Parameters
    - `edges`: either a single edge (2D array, one coord per row)
      or a batch of edges (3D array, one edge per leading index).

    # Returns
    - for a single edge: the distance as a python `int`
    - for a batch: an `int8` array with one distance per edge
      (distances above 127 do not fit in `int8`)

    # Raises
    - `ValueError`: if `edges` is neither 2D nor 3D
    """
    # magic values for dims fine here
    if len(edges.shape) == 3:  # noqa: PLR2004
        # widen before subtracting so narrow or unsigned input dtypes cannot wrap
        diffs = edges[:, 0, :].astype(np.int64) - edges[:, 1, :]
        return np.abs(diffs).sum(axis=1).astype(np.int8)
    elif len(edges.shape) == 2:  # noqa: PLR2004
        diff = edges[0, :].astype(np.int64) - edges[1, :]
        # no int8 round trip here: the result is a python int of any size
        return int(np.abs(diff).sum())
    else:
        err_msg: str = f"{edges} has shape {edges.shape}, but must be match the shape in the type hints."
        raise ValueError(err_msg)
Returns the Manhattan distance between two coords.
def lattice_max_degrees(n: int) -> Int8[np.ndarray, "row col"]:
    """Returns an array with the maximum possible degree for each coord.

    Every coord starts with 4 potential neighbors and loses one for each
    lattice border it lies on, so a 1x1 lattice yields `[[0]]`.

    # Parameters
    - `n`: side length of the square lattice

    # Returns
    An `(n, n)` integer array: 2 at corners, 3 on borders and 4 in the interior (for `n >= 2`).

    NOTE(review): the array uses numpy's default integer dtype, not `int8` as annotated.
    """
    out = np.full((n, n), 4)
    # slices instead of integer indices so that `n == 0` stays valid
    out[:1, :] -= 1
    out[-1:, :] -= 1
    out[:, :1] -= 1
    out[:, -1:] -= 1
    return out
Returns an array with the maximum possible degree for each coord.
def lattice_connection_array(
    n: int,
) -> Int8[np.ndarray, "edges=2*n*(n-1) leading_trailing_coord=2 row_col=2"]:
    """Enumerate every edge of an n x n square lattice as a pair of coordinates.

    # Parameters
    - `n`: The size of the square lattice.

    # Returns
    np.ndarray: An `int8` array of shape `(2*n*(n-1), 2, 2)`.
    Entry `[e, 0]` is the leading `(row, col)` coord of edge `e` and `[e, 1]` the trailing one.
    The first `n*(n-1)` edges are horizontal (same row, adjacent columns),
    the remaining ones vertical (same column, adjacent rows).
    In each pair, the coord with the smaller sum always comes first.
    """
    # coords[i, j] == (i, j)
    coords = np.moveaxis(np.indices((n, n), dtype=np.int8), 0, -1)
    edges_per_direction = n * (n - 1)

    # pair each coord with its right-hand neighbor
    horizontal = np.stack((coords[:, :-1], coords[:, 1:]), axis=2)
    # pair each coord with the neighbor below it
    vertical = np.stack((coords[:-1, :], coords[1:, :]), axis=2)

    return np.concatenate(
        (
            horizontal.reshape(edges_per_direction, 2, 2),
            vertical.reshape(edges_per_direction, 2, 2),
        ),
        axis=0,
    )
Returns a 3D NumPy array containing all the edges in a 2D square lattice of size n x n.
Thanks Claude.
Parameters
`n`: The size of the square lattice.
Returns
np.ndarray: A 3D NumPy array of shape `(2*n*(n-1), 2, 2)` containing the coordinates of the edges in the 2D square lattice. In each pair, the coord with the smaller sum always comes first.
def adj_list_to_nested_set(adj_list: list) -> set:
    """Convert an adjacency list into an order-insensitive form for comparison.

    The input looks like [[[0, 1], [1, 1]], [[0, 0], [0, 1]], ...].
    Each edge becomes a frozenset of its two coordinate tuples and the
    edges are collected in a set, so neither the order of the edges nor
    the order of the two coordinates inside an edge matters.
    """
    nested: set = set()
    for edge in adj_list:
        # every edge must consist of exactly two coordinates
        first_coord, second_coord = edge
        nested.add(frozenset((tuple(first_coord), tuple(second_coord))))
    return nested
Used for comparison of adj_lists
Adj_list looks like [[[0, 1], [1, 1]], [[0, 0], [0, 1]], ...] We don't care about order of coordinate pairs within the adj_list or coordinates within each coordinate pair.
FiniteValued
The details of this type are not possible to fully define via the Python 3.10 typing library.
This custom generic type is a generic domain of many types which have a finite, discrete, and well-defined range space.
`FiniteValued` defines the domain of supported types for the `all_instances` function, since that function relies heavily on static typing.
These types may be nested in an arbitrarily deep tree via Container Types and Superclass Types (see below).
The leaves of the tree must always be Primitive Types.
FiniteValued
Subtypes
*: Indicates that this subtype is not yet supported by all_instances
Non-FiniteValued
(Unbounded) Types
These are NOT valid subtypes, and are listed for illustrative purposes only. This list is not comprehensive. While the finite and discrete nature of digital computers means that the cardinality of these types is technically finite, they are considered unbounded types in this context.
- No Container subtype may contain any of these unbounded subtypes.
- `int`
- `float`
- `str`
- `list`
- `set`: Set types without a `FiniteValued` argument are unbounded
- `tuple`: Tuple types without a fixed length are unbounded
Primitive Types
Primitive types are non-nested types which resolve directly to a concrete range of values
- `bool`: has 2 possible values
- *`enum.Enum`: The range of a concrete `Enum` subclass is its set of enum members
- `typing.Literal`: Every type constructed using `Literal` has a finite set of possible literal values in its definition. This is the preferred way to include limited ranges of non-`FiniteValued` types such as `int` or `str` in a `FiniteValued` hierarchy.
Container Types
Container types are types which contain zero or more fields of `FiniteValued` type.
The range of a container type is the cartesian product of their field types, except for `set[FiniteValued]`.
- `tuple[FiniteValued]`: Tuples of fixed length whose elements are each `FiniteValued`.
- `IsDataclass`: Concrete dataclasses whose fields are `FiniteValued`.
- *Standard concrete class: Regular classes could be supported just like dataclasses if all their data members are `FiniteValued`-typed.
- *`set[FiniteValued]`: Sets of fixed length of a `FiniteValued` type.
Superclass Types
Superclass types don't directly contain data members like container types. Their range is the union of the ranges of their subtypes.
- Abstract dataclasses: Abstract dataclasses whose subclasses are all `FiniteValued` superclass or container types
- *`IsDataclass`: Concrete dataclasses which also have their own subclasses.
- *Standard abstract classes: Abstract classes whose subclasses are all `FiniteValued` superclass or container types
- `UnionType`: Any union of `FiniteValued` types, e.g., `bool | Literal[2, 3]`
class UnsupportedAllInstancesError(TypeError):
    """Raised when `all_instances` is called on an unsupported type

    either has unbounded possible values or is not supported (Enum is not supported)
    """

    def __init__(self, type_: type) -> None:
        """constructs an error message with the type and mro of the type

        # Parameters
        - `type_`: the offending type; its `__mro__` is included in the message
          when it has one, otherwise `None` is shown instead
        """
        # some objects passed as `type_` (e.g. `typing.List[int]`, string annotations)
        # have no `__mro__`; reading it directly would raise `AttributeError` here
        mro = getattr(type_, "__mro__", None)
        msg: str = f"Type {type_} is not supported by `all_instances`. See docstring for details. type_.__mro__ = {mro!r}"
        super().__init__(msg)
Raised when `all_instances` is called on an unsupported type
either has unbounded possible values or is not supported (Enum is not supported)
337 def __init__(self, type_: type) -> None: 338 "constructs an error message with the type and mro of the type" 339 msg: str = f"Type {type_} is not supported by `all_instances`. See docstring for details. {type_.__mro__ = }" 340 super().__init__(msg)
constructs an error message with the type and mro of the type
Inherited Members
- builtins.BaseException
- with_traceback
- add_note
- args
@_all_instances_wrapper
def all_instances(
    type_: FiniteValued,
    validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None = None,
) -> Generator[FiniteValued, None, None]:
    """Returns all possible values of an instance of `type_` if finite instances exist.

    Uses type hinting to construct the possible values.
    All nested elements of `type_` must themselves be typed.
    Do not use with types whose members contain circular references.
    Function is susceptible to infinite recursion if `type_` is a dataclass whose member tree includes another instance of `type_`.

    # Parameters
    - `type_: FiniteValued`
        A finite-valued type. See docstring on `FiniteValued` for full details.
    - `validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None`
        A mapping of types to auxiliary functions to validate instances of that type.
        This optional argument can provide an additional, more precise layer of validation for the instances generated beyond what type hinting alone can provide.
        See `validation_funcs` Details section below.
        (default: `None`)

    ## Supported `type_` Values
    See docstring on `FiniteValued` for full details.
    `type_` may be:
    - `bool`
    - A `Literal` type
    - A dataclass (abstract or concrete) whose fields are all supported types
    - A finite-valued, fixed-length Generic tuple type.
      E.g., `tuple[bool]`, `tuple[bool, Literal[1, 2]]` are OK.
      `tuple[bool, ...]` is NOT supported, since the length of the tuple is not fixed.
    - Nested versions of any of the types in this list
    - A `UnionType` of any of the types in this list

    `enum.Enum` subclasses are not handled by this function and raise `UnsupportedAllInstancesError`.

    ## `validation_funcs` Details
    - `validation_funcs` is applied after all instances have been generated according to type hints.
    - If `type_` is in `validation_funcs`, then the list of instances is filtered by `validation_funcs[type_](instance)`.
    - `validation_funcs` is passed down for all recursive calls of `all_instances`.
        - This allows for improved performance through maximal pruning of the exponential tree.
    - `validation_funcs` supports subclass checking.
        - If `type_` is not found in `validation_funcs`, then the search is performed iteratively in mro order.
        - If a superclass of `type_` is found while searching in mro order, that validation function is applied and the list is returned.
        - If no superclass of `type_` is found, then no filter is applied.

    # Raises:
    - `UnsupportedAllInstancesError`: If `type_` is not supported by `all_instances`.
    """
    if type_ == bool:  # noqa: E721
        yield from [True, False]
    elif hasattr(type_, "__dataclass_fields__"):
        if is_abstract(type_):
            # Abstract dataclass: call `all_instances` on each subclass
            yield from flatten(
                (
                    all_instances(sub, validation_funcs)
                    for sub in type_.__subclasses__()
                ),
                levels_to_flatten=1,
            )
        else:
            # Concrete dataclass: construct dataclass instances with all possible combinations of fields
            # `__dataclass_fields__` maps field names to their `Field` objects
            fields: dict[str, Field] = type_.__dataclass_fields__
            fields_to_types: dict[str, type] = {
                name: field.type for name, field in fields.items()
            }
            all_arg_sequences: Iterable = itertools.product(
                *[
                    all_instances(arg_type, validation_funcs)
                    for arg_type in fields_to_types.values()
                ],
            )
            yield from (
                type_(
                    **dict(zip(fields_to_types.keys(), args, strict=False)),
                )
                for args in all_arg_sequences
            )
    else:
        type_origin = get_origin(type_)
        if type_origin == tuple:  # noqa: E721
            # Only matches Generic type tuple since regular tuple is not finite-valued
            # Generic tuple: Similar to concrete dataclass. Construct all possible combinations of tuple fields.
            yield from (
                tuple(combo)
                for combo in itertools.product(
                    *(
                        all_instances(tup_item, validation_funcs)
                        for tup_item in get_args(type_)
                    ),
                )
            )
        elif type_origin in (UnionType, typing.Union):
            # Union: call `all_instances` for each type in the Union
            yield from flatten(
                [all_instances(sub, validation_funcs) for sub in get_args(type_)],
                levels_to_flatten=1,
            )
        elif type_origin is Literal:
            # Literal: return all Literal arguments
            yield from get_args(type_)
        else:
            raise UnsupportedAllInstancesError(type_)
Returns all possible values of an instance of `type_` if finite instances exist.
Uses type hinting to construct the possible values.
All nested elements of `type_` must themselves be typed.
Do not use with types whose members contain circular references.
Function is susceptible to infinite recursion if `type_` is a dataclass whose member tree includes another instance of `type_`.
Parameters
- `type_: FiniteValued` — A finite-valued type. See docstring on `FiniteValued` for full details.
- `validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None` — A mapping of types to auxiliary functions to validate instances of that type. This optional argument can provide an additional, more precise layer of validation for the instances generated beyond what type hinting alone can provide. See `validation_funcs` Details section below. (default: `None`)
Supported type_
Values
See docstring on FiniteValued
for full details.
type_
may be:
- `FiniteValued`
- A finite-valued, fixed-length Generic tuple type. E.g., `tuple[bool]`, `tuple[bool, Literal[1, 2]]` are OK. `tuple[bool, ...]` is NOT supported, since the length of the tuple is not fixed.
- Nested versions of any of the types in this list
- A `UnionType` of any of the types in this list
validation_funcs
Details
- `validation_funcs` is applied after all instances have been generated according to type hints.
- If `type_` is in `validation_funcs`, then the list of instances is filtered by `validation_funcs[type_](instance)`.
- `validation_funcs` is passed down for all recursive calls of `all_instances`.
  - This allows for improved performance through maximal pruning of the exponential tree.
- `validation_funcs` supports subclass checking.
  - If `type_` is not found in `validation_funcs`, then the search is performed iteratively in mro order.
  - If a superclass of `type_` is found while searching in mro order, that validation function is applied and the list is returned.
  - If no superclass of `type_` is found, then no filter is applied.
Raises:
- `UnsupportedAllInstancesError`: If `type_` is not supported by `all_instances`.