docs for maze-dataset v1.3.0
View Source on GitHub

maze_dataset.utils

misc utilities for the maze_dataset package


  1"misc utilities for the `maze_dataset` package"
  2
  3import enum
  4import itertools
  5import math
  6import typing
  7from dataclasses import Field  # noqa: TC003
  8from functools import cache, wraps
  9from types import UnionType
 10from typing import (
 11	Callable,
 12	Generator,
 13	Iterable,
 14	Literal,
 15	TypeVar,
 16	get_args,
 17	get_origin,
 18	overload,
 19)
 20
 21import frozendict
 22import numpy as np
 23from jaxtyping import Bool, Int, Int8
 24from muutils.misc import IsDataclass, flatten, is_abstract
 25
 26
 27def bool_array_from_string(
 28	string: str,
 29	shape: list[int],
 30	true_symbol: str = "T",
 31) -> Bool[np.ndarray, "*shape"]:
 32	"""Transform a string into an ndarray of bools.
 33
 34	Parameters
 35	----------
 36	string: str
 37		The string representation of the array
 38	shape: list[int]
 39		The shape of the resulting array
 40	true_symbol:
 41		The character to parse as True. Whitespace will be removed. All other characters will be parsed as False.
 42
 43	Returns
 44	-------
 45	np.ndarray
 46		A ndarray with dtype bool of shape `shape`
 47
 48	Examples
 49	--------
 50	>>> bool_array_from_string(
 51	...	 "TT TF", shape=[2,2]
 52	... )
 53	array([[ True,  True],
 54		[ True, False]])
 55
 56	"""
 57	stripped = "".join(string.split())
 58
 59	expected_symbol_count = math.prod(shape)
 60	symbol_count = len(stripped)
 61	if len(stripped) != expected_symbol_count:
 62		err_msg: str = (
 63			f"Connection List contains the wrong number of symbols. Expected {expected_symbol_count}. Found {symbol_count} in {stripped}.",
 64		)
 65		raise ValueError(err_msg)
 66
 67	bools = [(symbol == true_symbol) for symbol in stripped]
 68	return np.array(bools).reshape(*shape)
 69
 70
 71def corner_first_ndindex(n: int, ndim: int = 2) -> list[tuple]:
 72	"""returns an array of indices, sorted by distance from the corner
 73
 74	this gives the property that `np.ndindex((n,n))` is equal to
 75	the first n^2 elements of `np.ndindex((n+1, n+1))`
 76
 77	```
 78	>>> corner_first_ndindex(1)
 79	[(0, 0)]
 80	>>> corner_first_ndindex(2)
 81	[(0, 0), (0, 1), (1, 0), (1, 1)]
 82	>>> corner_first_ndindex(3)
 83	[(0, 0), (0, 1), (1, 0), (1, 1), (0, 2), (2, 0), (1, 2), (2, 1), (2, 2)]
 84	```
 85	"""
 86	unsorted: list = list(np.ndindex(tuple([n for _ in range(ndim)])))
 87	return sorted(unsorted, key=lambda x: (max(x), x if x[0] % 2 == 0 else x[::-1]))
 88
 89
 90# alternate numpy version from GPT-4:
 91"""
 92# Create all index combinations
 93indices = np.indices([n]*ndim).reshape(ndim, -1).T
 94# Find the max value for each index
 95max_indices = np.max(indices, axis=1)
 96# Identify the odd max values
 97odd_mask = max_indices % 2 != 0
 98# Make a copy of indices to avoid changing the original one
 99indices_copy = indices.copy()
100# Reverse the order of the coordinates for indices with odd max value
101indices_copy[odd_mask] = indices_copy[odd_mask, ::-1]
102# Sort by max index value, then by coordinates
103sorted_order = np.lexsort((*indices_copy.T, max_indices))
104return indices[sorted_order]
105"""
106
107
108@overload
109def manhattan_distance(
110	edges: Int[np.ndarray, "edges coord=2 row_col=2"],
111) -> Int8[np.ndarray, " edges"]: ...
112@overload
113def manhattan_distance(
114	edges: Int[np.ndarray, "coord=2 row_col=2"],
115) -> int: ...
116def manhattan_distance(
117	edges: (
118		Int[np.ndarray, "edges coord=2 row_col=2"]
119		| Int[np.ndarray, "coord=2 row_col=2"]
120	),
121) -> Int8[np.ndarray, " edges"] | int:
122	"""Returns the Manhattan distance between two coords."""
123	# magic values for dims fine here
124	if len(edges.shape) == 3:  # noqa: PLR2004
125		return np.linalg.norm(edges[:, 0, :] - edges[:, 1, :], axis=1, ord=1).astype(
126			np.int8,
127		)
128	elif len(edges.shape) == 2:  # noqa: PLR2004
129		return int(np.linalg.norm(edges[0, :] - edges[1, :], ord=1).astype(np.int8))
130	else:
131		err_msg: str = f"{edges} has shape {edges.shape}, but must be match the shape in the type hints."
132		raise ValueError(err_msg)
133
134
def lattice_max_degrees(n: int) -> Int8[np.ndarray, "row col"]:
	"""Returns an array with the maximum possible degree for each coord.

	In an `n x n` square lattice, a coord can connect to at most one neighbor in each of the
	four directions. Corners therefore have degree 2, other border coords 3, and interior
	coords 4. A `1 x 1` lattice has no edges, so its single coord has degree 0.

	# Parameters
	- `n`: side length of the lattice

	# Returns
	- array of shape `(n, n)` where entry `[row, col]` is the maximum degree of that coord.
	  NOTE(review): dtype is numpy's default int, not int8 as the annotation says.
	"""
	idx = np.arange(n)
	# in-bounds neighbors along one axis: one toward lower indices (idx > 0)
	# plus one toward higher indices (idx < n - 1)
	per_axis = (idx > 0).astype(int) + (idx < n - 1)
	# total degree = row-axis contribution + col-axis contribution
	return per_axis[:, np.newaxis] + per_axis[np.newaxis, :]
141
142
def lattice_connection_array(
	n: int,
) -> Int8[np.ndarray, "edges=2*n*(n-1) leading_trailing_coord=2 row_col=2"]:
	"""Returns every edge of an `n x n` square lattice as a 3D NumPy array.

	Thanks Claude.

	# Parameters
	- `n`: The side length of the square lattice.

	# Returns
	np.ndarray: an `int8` array of shape `(2*n*(n-1), 2, 2)`. Entry `[i]` is the pair of
	`(row, col)` coords joined by edge `i`; the coord with the smaller sum always comes first.
	The first `n*(n-1)` entries are horizontal edges `(r, c) -> (r, c+1)`, and the remaining
	`n*(n-1)` entries are vertical edges `(r, c) -> (r+1, c)`, each in row-major order.
	"""
	axis_values = np.arange(n, dtype=np.int8)
	rows, cols = np.meshgrid(axis_values, axis_values, indexing="ij")
	n_edges_per_direction: int = n * (n - 1)

	def _stack_edges(head_slice, tail_slice) -> np.ndarray:
		# one row of [head_row, head_col, tail_row, tail_col] per edge, then split into 2 coords
		flat = np.stack(
			[
				rows[head_slice].ravel(),
				cols[head_slice].ravel(),
				rows[tail_slice].ravel(),
				cols[tail_slice].ravel(),
			],
			axis=-1,
		)
		return flat.reshape(n_edges_per_direction, 2, 2)

	# horizontal: (r, c) -> (r, c + 1)
	horizontal = _stack_edges(np.s_[:, :-1], np.s_[:, 1:])
	# vertical: (r, c) -> (r + 1, c)
	vertical = _stack_edges(np.s_[:-1, :], np.s_[1:, :])
	return np.concatenate((horizontal, vertical), axis=0)
187
188
def adj_list_to_nested_set(adj_list: list) -> set:
	"""Convert an adjacency list into an order-insensitive set of edges, for comparing adj_lists.

	`adj_list` looks like `[[[0, 1], [1, 1]], [[0, 0], [0, 1]], ...]`, a sequence of coordinate
	pairs. The order of the pairs within `adj_list` does not matter, and neither does the order
	of the two coordinates within a pair. Each pair becomes a `frozenset` of two coordinate tuples.
	"""
	edge_set: set = set()
	for pair in adj_list:
		first, second = pair
		edge_set.add(frozenset((tuple(first), tuple(second))))
	return edge_set
200
201
# NOTE: `enum.Enum` is part of the bound for typing purposes, but `all_instances` has no
# branch for Enum types yet (marked with `*` below)
FiniteValued = TypeVar("FiniteValued", bound=bool | IsDataclass | enum.Enum)
"""
# `FiniteValued`
The details of this type are not possible to fully define via the Python 3.10 typing library.
This custom generic type is a generic domain of many types which have a finite, discrete, and well-defined range space.
`FiniteValued` defines the domain of supported types for the `all_instances` function, since that function relies heavily on static typing.
These types may be nested in an arbitrarily deep tree via Container Types and Superclass Types (see below).
The leaves of the tree must always be Primitive Types.

# `FiniteValued` Subtypes
*: Indicates that this subtype is not yet supported by `all_instances`

## Non-`FiniteValued` (Unbounded) Types
These are NOT valid subtypes, and are listed for illustrative purposes only.
This list is not comprehensive.
While the finite and discrete nature of digital computers means that the cardinality of these types is technically finite,
they are considered unbounded types in this context.
No Container subtype may contain any of these unbounded subtypes.
- `int`
- `float`
- `str`
- `list`
- `set`: Set types without a `FiniteValued` argument are unbounded
- `tuple`: Tuple types without a fixed length are unbounded

## Primitive Types
Primitive types are non-nested types which resolve directly to a concrete range of values
- `bool`: has 2 possible values
- *`enum.Enum`: The range of a concrete `Enum` subclass is its set of enum members
- `typing.Literal`: Every type constructed using `Literal` has a finite set of possible literal values in its definition.
This is the preferred way to include limited ranges of non-`FiniteValued` types such as `int` or `str` in a `FiniteValued` hierarchy.

## Container Types
Container types are types which contain zero or more fields of `FiniteValued` type.
The range of a container type is the Cartesian product of their field types, except for `set[FiniteValued]`.
- `tuple[FiniteValued]`: Tuples of fixed length whose elements are each `FiniteValued`.
- `IsDataclass`: Concrete dataclasses whose fields are `FiniteValued`.
- *Standard concrete class: Regular classes could be supported just like dataclasses if all their data members are `FiniteValued`-typed.
- *`set[FiniteValued]`: Sets of fixed length of a `FiniteValued` type.

## Superclass Types
Superclass types don't directly contain data members like container types.
Their range is the union of the ranges of their subtypes.
- Abstract dataclasses: Abstract dataclasses whose subclasses are all `FiniteValued` superclass or container types
- *`IsDataclass`: Concrete dataclasses which also have their own subclasses.
- *Standard abstract classes: Abstract (non-dataclass) classes whose subclasses are all `FiniteValued` superclass or container types
- `UnionType`: Any union of `FiniteValued` types, e.g., `bool | Literal[2, 3]`
"""
250
251
def _apply_validation_func(
	type_: FiniteValued,
	vals: Generator[FiniteValued, None, None],
	validation_funcs: (
		frozendict.frozendict[FiniteValued, Callable[[FiniteValued], bool]] | None
	) = None,
) -> Generator[FiniteValued, None, None]:
	"""Helper for `all_instances`: filter `vals` with the matching entry of `validation_funcs`.

	Lookup order:
	1. `validation_funcs` is `None`: `vals` is returned unchanged
	2. `type_` itself is a key: filter with that function (the only way a `UnionType` key can match)
	3. `type_` has an `__mro__`: filter with the function of the first class in the MRO that is
	   a key, if any; otherwise return `vals` unchanged
	4. `type_` is a `Literal[...]`: each literal value is validated against the function
	   registered for its own runtime type
	5. otherwise: `vals` is returned unchanged

	# Parameters
	- `type_: FiniteValued`: the type whose instances are in `vals`
	- `vals: Generator[FiniteValued, None, None]`: instances of `type_`
	- `validation_funcs`: mapping of types to predicates; an instance is kept iff its predicate returns truthy

	# Returns
	- an iterable of the instances that passed validation (the original `vals` object when no rule applies)
	"""
	if validation_funcs is None:
		return vals

	# exact key match; this is the only way a UnionType key can be hit
	if type_ in validation_funcs:
		# TYPING: Incompatible return value type (got "filter[FiniteValued]", expected "Generator[FiniteValued, None, None]")  [return-value]
		return filter(validation_funcs[type_], vals)

	# generic aliases such as UnionType or Literal[...] have no `__mro__`
	if hasattr(type_, "__mro__"):
		first_match = next(
			(cls for cls in type_.__mro__ if cls in validation_funcs),
			None,
		)
		# only the first validation function found along the MRO is applied
		if first_match is None:
			return vals
		# TYPING: filter[...] returned where Generator[...] is annotated
		return filter(validation_funcs[first_match], vals)

	if get_origin(type_) == Literal:
		# validate each literal value against the function registered for its own type
		per_value = (
			_apply_validation_func(type(literal_value), [literal_value], validation_funcs)
			for literal_value in get_args(type_)
		)
		return flatten(per_value, levels_to_flatten=1)

	return vals
294
295
# TYPING: some better type hints would be nice here
def _all_instances_wrapper(f: Callable) -> Callable:
	"""Decorator for `all_instances`: converts `validation_funcs` to a frozendict and applies `_apply_validation_func`.

	The decorated function is always invoked as `f(type_, validation_funcs)`. Here
	`validation_funcs` is either `None` or a `frozendict.frozendict` copy of what the caller
	passed. The output is then filtered by `_apply_validation_func(type_, ..., validation_funcs)`.
	Both arguments may be passed positionally or by keyword; other keyword arguments are ignored.

	No result caching is done here. The decorated function returns single-use generators,
	so caching its results across calls would hand out already-exhausted iterators.
	"""

	@wraps(f)
	def wrapper(*args, **kwargs):  # noqa: ANN202
		# `type_` may be passed positionally or as a keyword
		if args:
			type_ = args[0]
		elif "type_" in kwargs:
			type_ = kwargs["type_"]
		else:
			err_msg: str = f"{getattr(f, '__name__', f)}() missing required argument: 'type_'"
			raise TypeError(err_msg)

		# `validation_funcs` is the second positional parameter of `f`
		raw_validation_funcs = (
			args[1] if len(args) >= 2 else kwargs.get("validation_funcs")  # noqa: PLR2004
		)
		# an immutable copy is what gets passed down through all recursive calls
		validation_funcs: frozendict.frozendict | None = (
			None
			if raw_validation_funcs is None
			else frozendict.frozendict(raw_validation_funcs)
		)
		return _apply_validation_func(
			type_,
			f(type_, validation_funcs),
			validation_funcs,
		)

	return wrapper
328
329
class UnsupportedAllInstancesError(TypeError):
	"""Raised when `all_instances` is called on an unsupported type

	the type either has unbounded possible values or is otherwise not supported
	(e.g. `Enum` is not supported)
	"""

	def __init__(self, type_: type) -> None:
		"""constructs an error message with the type and mro of the type

		`type_` need not be a class: `all_instances` can reach non-class objects such as the
		`...` in `tuple[bool, ...]`, which have no `__mro__`. In that case `None` is
		reported in place of the mro.
		"""
		# getattr so that building this error can never itself raise AttributeError
		mro: tuple[type, ...] | None = getattr(type_, "__mro__", None)
		msg: str = f"Type {type_} is not supported by `all_instances`. See docstring for details. type_.__mro__ = {mro!r}"
		super().__init__(msg)
340
341
@_all_instances_wrapper
def all_instances(
	type_: FiniteValued,
	validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None = None,
) -> Generator[FiniteValued, None, None]:
	"""Returns all possible values of an instance of `type_` if finite instances exist.

	Uses type hinting to construct the possible values.
	All nested elements of `type_` must themselves be typed.
	Do not use with types whose members contain circular references.
	Function is susceptible to infinite recursion if `type_` is a dataclass whose member tree includes another instance of `type_`.

	# Parameters
	- `type_: FiniteValued`
		A finite-valued type. See docstring on `FiniteValued` for full details.
	- `validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None`
		A mapping of types to auxiliary functions to validate instances of that type.
		This optional argument can provide an additional, more precise layer of validation for the instances generated beyond what type hinting alone can provide.
		See `validation_funcs` Details section below.
		(default: `None`)

	## Supported `type_` Values
	See docstring on `FiniteValued` for full details.
	`type_` may be:
	- `bool`
	- A `Literal[...]` type: its literal arguments are yielded as-is
	- A dataclass: a concrete dataclass yields one instance per combination of field values;
	an abstract dataclass yields the instances of each of its direct subclasses
	- A finite-valued, fixed-length Generic tuple type.
	E.g., `tuple[bool]`, `tuple[bool, Literal[1, 2]]` are OK.
	`tuple[bool, ...]` is NOT supported, since the length of the tuple is not fixed.
	- Nested versions of any of the types in this list
	- A `UnionType` (or `typing.Union` / `typing.Optional`) of any of the types in this list

	`enum.Enum` subclasses are not yet supported and raise `UnsupportedAllInstancesError`.

	## `validation_funcs` Details
	- `validation_funcs` is applied after all instances have been generated according to type hints.
	- If `type_` is in `validation_funcs`, then the list of instances is filtered by `validation_funcs[type_](instance)`.
	- `validation_funcs` is passed down for all recursive calls of `all_instances`.
	- This allows for improved performance through maximal pruning of the exponential tree.
	- `validation_funcs` supports subclass checking.
	- If `type_` is not found in `validation_funcs`, then the search is performed iteratively in mro order.
	- If a superclass of `type_` is found while searching in mro order, that validation function is applied and the list is returned.
	- If no superclass of `type_` is found, then no filter is applied.

	# Raises:
	- `UnsupportedAllInstancesError`: If `type_` is not supported by `all_instances`.
	"""
	if type_ == bool:  # noqa: E721
		yield from [True, False]
	elif hasattr(type_, "__dataclass_fields__"):
		if is_abstract(type_):
			# Abstract dataclass: call `all_instances` on each direct subclass
			yield from flatten(
				(
					all_instances(sub, validation_funcs)
					for sub in type_.__subclasses__()
				),
				levels_to_flatten=1,
			)
		else:
			# Concrete dataclass: construct dataclass instances with all possible combinations of fields
			# NOTE(review): `__dataclass_fields__` can also hold ClassVar/InitVar pseudo-fields, and every
			# entry is passed to the constructor below -- confirm ClassVars never occur in supported types
			fields: dict[str, Field] = type_.__dataclass_fields__
			fields_to_types: dict[str, type] = {f: fields[f].type for f in fields}
			all_arg_sequences: Iterable = itertools.product(
				*[
					all_instances(arg_type, validation_funcs)
					for arg_type in fields_to_types.values()
				],
			)
			yield from (
				type_(
					**dict(zip(fields_to_types.keys(), args, strict=False)),
				)
				for args in all_arg_sequences
			)
	else:
		type_origin = get_origin(type_)
		if type_origin == tuple:  # noqa: E721
			# Only matches Generic type tuple since regular tuple is not finite-valued
			# Generic tuple: Similar to concrete dataclass. Construct all possible combinations of tuple fields.
			yield from (
				tuple(combo)
				for combo in itertools.product(
					*(
						all_instances(tup_item, validation_funcs)
						for tup_item in get_args(type_)
					),
				)
			)
		elif type_origin in (UnionType, typing.Union):
			# Union: call `all_instances` for each type in the Union
			yield from flatten(
				[all_instances(sub, validation_funcs) for sub in get_args(type_)],
				levels_to_flatten=1,
			)
		elif type_origin is Literal:
			# Literal: return all Literal arguments
			yield from get_args(type_)
		else:
			# includes enum.Enum subclasses, which have no dedicated branch above
			raise UnsupportedAllInstancesError(type_)

def bool_array_from_string( string: str, shape: list[int], true_symbol: str = 'T') -> jaxtyping.Bool[ndarray, '*shape']:
def bool_array_from_string(
	string: str,
	shape: list[int],
	true_symbol: str = "T",
) -> Bool[np.ndarray, "*shape"]:
	"""Transform a string into an ndarray of bools.

	Parameters
	----------
	string: str
		The string representation of the array. All whitespace is removed before parsing.
	shape: list[int]
		The shape of the resulting array
	true_symbol: str
		The character to parse as True. All other characters will be parsed as False.
		(default: `"T"`)

	Returns
	-------
	np.ndarray
		An ndarray with dtype bool of shape `shape`

	Raises
	------
	ValueError
		If the number of non-whitespace symbols is not equal to the product of `shape`.

	Examples
	--------
	>>> bool_array_from_string(
	...     "TT TF", shape=[2,2]
	... )
	array([[ True,  True],
	       [ True, False]])

	"""
	# whitespace (spaces, tabs, newlines) is only layout and carries no value
	stripped: str = "".join(string.split())

	expected_symbol_count: int = math.prod(shape)
	symbol_count: int = len(stripped)
	if symbol_count != expected_symbol_count:
		# a plain string (not a 1-tuple) so the exception message renders cleanly
		err_msg: str = f"Connection List contains the wrong number of symbols. Expected {expected_symbol_count}. Found {symbol_count} in {stripped}."
		raise ValueError(err_msg)

	bools: list[bool] = [symbol == true_symbol for symbol in stripped]
	# explicit dtype so that an empty input still yields a bool array;
	# reshape with a tuple so that `shape=[]` produces a 0-d array
	return np.array(bools, dtype=bool).reshape(tuple(shape))

Transform a string into an ndarray of bools.

Parameters

string: str The string representation of the array shape: list[int] The shape of the resulting array true_symbol: The character to parse as True. Whitespace will be removed. All other characters will be parsed as False.

Returns

np.ndarray An ndarray with dtype bool of shape shape

Examples

>>> bool_array_from_string(
...      "TT TF", shape=[2,2]
... )
array([[ True,  True],
        [ True, False]])
def corner_first_ndindex(n: int, ndim: int = 2) -> list[tuple]:
def corner_first_ndindex(n: int, ndim: int = 2) -> list[tuple]:
	"""returns a list of index tuples, sorted by distance from the corner

	indices are ordered primarily by their largest coordinate (the "shell" around the
	origin corner they lie on). this gives the property that `corner_first_ndindex(n)`
	is equal to the first `n**ndim` elements of `corner_first_ndindex(n + 1)`
	(plain `np.ndindex` does not have this property).

	```
	>>> corner_first_ndindex(1)
	[(0, 0)]
	>>> corner_first_ndindex(2)
	[(0, 0), (0, 1), (1, 0), (1, 1)]
	>>> corner_first_ndindex(3)
	[(0, 0), (0, 1), (1, 0), (1, 1), (0, 2), (2, 0), (1, 2), (2, 1), (2, 2)]
	```
	"""
	unsorted: list = list(np.ndindex((n,) * ndim))
	# within a shell, compare the index itself when x[0] is even and its reversal when odd;
	# `sorted` is stable, so remaining ties keep np.ndindex's lexicographic order,
	# which does not depend on `n` -- hence the prefix property above
	return sorted(unsorted, key=lambda x: (max(x), x if x[0] % 2 == 0 else x[::-1]))

returns a list of index tuples, sorted by distance from the corner

this gives the property that corner_first_ndindex(n) is equal to the first n^2 elements of corner_first_ndindex(n+1) (plain np.ndindex does not have this property)

>>> corner_first_ndindex(1)
[(0, 0)]
>>> corner_first_ndindex(2)
[(0, 0), (0, 1), (1, 0), (1, 1)]
>>> corner_first_ndindex(3)
[(0, 0), (0, 1), (1, 0), (1, 1), (0, 2), (2, 0), (1, 2), (2, 1), (2, 2)]
def manhattan_distance( edges: jaxtyping.Int[ndarray, 'edges coord=2 row_col=2'] | jaxtyping.Int[ndarray, 'coord=2 row_col=2']) -> jaxtyping.Int8[ndarray, 'edges'] | int:
117def manhattan_distance(
118	edges: (
119		Int[np.ndarray, "edges coord=2 row_col=2"]
120		| Int[np.ndarray, "coord=2 row_col=2"]
121	),
122) -> Int8[np.ndarray, " edges"] | int:
123	"""Returns the Manhattan distance between two coords."""
124	# magic values for dims fine here
125	if len(edges.shape) == 3:  # noqa: PLR2004
126		return np.linalg.norm(edges[:, 0, :] - edges[:, 1, :], axis=1, ord=1).astype(
127			np.int8,
128		)
129	elif len(edges.shape) == 2:  # noqa: PLR2004
130		return int(np.linalg.norm(edges[0, :] - edges[1, :], ord=1).astype(np.int8))
131	else:
132		err_msg: str = f"{edges} has shape {edges.shape}, but must be match the shape in the type hints."
133		raise ValueError(err_msg)

Returns the Manhattan distance between two coords.

def lattice_max_degrees(n: int) -> jaxtyping.Int8[ndarray, 'row col']:
def lattice_max_degrees(n: int) -> Int8[np.ndarray, "row col"]:
	"""Returns an array with the maximum possible degree for each coord.

	In an `n x n` square lattice, a coord can connect to at most one neighbor in each of the
	four directions. Corners therefore have degree 2, other border coords 3, and interior
	coords 4. A `1 x 1` lattice has no edges, so its single coord has degree 0.

	# Parameters
	- `n`: side length of the lattice

	# Returns
	- array of shape `(n, n)` where entry `[row, col]` is the maximum degree of that coord.
	  NOTE(review): dtype is numpy's default int, not int8 as the annotation says.
	"""
	idx = np.arange(n)
	# in-bounds neighbors along one axis: one toward lower indices (idx > 0)
	# plus one toward higher indices (idx < n - 1)
	per_axis = (idx > 0).astype(int) + (idx < n - 1)
	# total degree = row-axis contribution + col-axis contribution
	return per_axis[:, np.newaxis] + per_axis[np.newaxis, :]

Returns an array with the maximum possible degree for each coord.

def lattice_connection_array( n: int) -> jaxtyping.Int8[ndarray, 'edges=2*n*(n-1) leading_trailing_coord=2 row_col=2']:
def lattice_connection_array(
	n: int,
) -> Int8[np.ndarray, "edges=2*n*(n-1) leading_trailing_coord=2 row_col=2"]:
	"""Returns every edge of an `n x n` square lattice as a 3D NumPy array.

	Thanks Claude.

	# Parameters
	- `n`: The side length of the square lattice.

	# Returns
	np.ndarray: an `int8` array of shape `(2*n*(n-1), 2, 2)`. Entry `[i]` is the pair of
	`(row, col)` coords joined by edge `i`; the coord with the smaller sum always comes first.
	The first `n*(n-1)` entries are horizontal edges `(r, c) -> (r, c+1)`, and the remaining
	`n*(n-1)` entries are vertical edges `(r, c) -> (r+1, c)`, each in row-major order.
	"""
	axis_values = np.arange(n, dtype=np.int8)
	rows, cols = np.meshgrid(axis_values, axis_values, indexing="ij")
	n_edges_per_direction: int = n * (n - 1)

	def _stack_edges(head_slice, tail_slice) -> np.ndarray:
		# one row of [head_row, head_col, tail_row, tail_col] per edge, then split into 2 coords
		flat = np.stack(
			[
				rows[head_slice].ravel(),
				cols[head_slice].ravel(),
				rows[tail_slice].ravel(),
				cols[tail_slice].ravel(),
			],
			axis=-1,
		)
		return flat.reshape(n_edges_per_direction, 2, 2)

	# horizontal: (r, c) -> (r, c + 1)
	horizontal = _stack_edges(np.s_[:, :-1], np.s_[:, 1:])
	# vertical: (r, c) -> (r + 1, c)
	vertical = _stack_edges(np.s_[:-1, :], np.s_[1:, :])
	return np.concatenate((horizontal, vertical), axis=0)

Returns a 3D NumPy array containing all the edges in a 2D square lattice of size n x n.

Thanks Claude.

Parameters

  • n: The size of the square lattice.

Returns

np.ndarray: A 3D NumPy array of shape (2*n*(n-1), 2, 2) containing the coordinates of the edges in the 2D square lattice. In each pair, the coord with the smaller sum always comes first.

def adj_list_to_nested_set(adj_list: list) -> set:
def adj_list_to_nested_set(adj_list: list) -> set:
	"""Convert an adjacency list into an order-insensitive set of edges, for comparing adj_lists.

	`adj_list` looks like `[[[0, 1], [1, 1]], [[0, 0], [0, 1]], ...]`, a sequence of coordinate
	pairs. The order of the pairs within `adj_list` does not matter, and neither does the order
	of the two coordinates within a pair. Each pair becomes a `frozenset` of two coordinate tuples.
	"""
	edge_set: set = set()
	for pair in adj_list:
		first, second = pair
		edge_set.add(frozenset((tuple(first), tuple(second))))
	return edge_set

Used for comparison of adj_lists

Adj_list looks like [[[0, 1], [1, 1]], [[0, 0], [0, 1]], ...] We don't care about order of coordinate pairs within the adj_list or coordinates within each coordinate pair.

FiniteValued = ~FiniteValued

FiniteValued

The details of this type are not possible to fully define via the Python 3.10 typing library. This custom generic type is a generic domain of many types which have a finite, discrete, and well-defined range space. FiniteValued defines the domain of supported types for the all_instances function, since that function relies heavily on static typing. These types may be nested in an arbitrarily deep tree via Container Types and Superclass Types (see below). The leaves of the tree must always be Primitive Types.

FiniteValued Subtypes

*: Indicates that this subtype is not yet supported by all_instances

Non-FiniteValued (Unbounded) Types

These are NOT valid subtypes, and are listed for illustrative purposes only. This list is not comprehensive. While the finite and discrete nature of digital computers means that the cardinality of these types is technically finite, they are considered unbounded types in this context.

  • No Container subtype may contain any of these unbounded subtypes.
  • int
  • float
  • str
  • list
  • set: Set types without a FiniteValued argument are unbounded
  • tuple: Tuple types without a fixed length are unbounded

Primitive Types

Primitive types are non-nested types which resolve directly to a concrete range of values

  • bool: has 2 possible values
  • *enum.Enum: The range of a concrete Enum subclass is its set of enum members
  • typing.Literal: Every type constructed using Literal has a finite set of possible literal values in its definition. This is the preferred way to include limited ranges of non-FiniteValued types such as int or str in a FiniteValued hierarchy.

Container Types

Container types are types which contain zero or more fields of FiniteValued type. The range of a container type is the Cartesian product of their field types, except for set[FiniteValued].

  • tuple[FiniteValued]: Tuples of fixed length whose elements are each FiniteValued.
  • IsDataclass: Concrete dataclasses whose fields are FiniteValued.
  • *Standard concrete class: Regular classes could be supported just like dataclasses if all their data members are FiniteValued-typed.
  • *set[FiniteValued]: Sets of fixed length of a FiniteValued type.

Superclass Types

Superclass types don't directly contain data members like container types. Their range is the union of the ranges of their subtypes.

  • Abstract dataclasses: Abstract dataclasses whose subclasses are all FiniteValued superclass or container types
  • *IsDataclass: Concrete dataclasses which also have their own subclasses.
  • *Standard abstract classes: Abstract (non-dataclass) classes whose subclasses are all FiniteValued superclass or container types
  • UnionType: Any union of FiniteValued types, e.g., bool | Literal[2, 3]
class UnsupportedAllInstancesError(builtins.TypeError):
class UnsupportedAllInstancesError(TypeError):
	"""Raised when `all_instances` is called on an unsupported type

	the type either has unbounded possible values or is otherwise not supported
	(e.g. `Enum` is not supported)
	"""

	def __init__(self, type_: type) -> None:
		"""constructs an error message with the type and mro of the type

		`type_` need not be a class: `all_instances` can reach non-class objects such as the
		`...` in `tuple[bool, ...]`, which have no `__mro__`. In that case `None` is
		reported in place of the mro.
		"""
		# getattr so that building this error can never itself raise AttributeError
		mro: tuple[type, ...] | None = getattr(type_, "__mro__", None)
		msg: str = f"Type {type_} is not supported by `all_instances`. See docstring for details. type_.__mro__ = {mro!r}"
		super().__init__(msg)

Raised when all_instances is called on an unsupported type

The type either has unbounded possible values or is otherwise not supported (e.g. Enum is not supported).

UnsupportedAllInstancesError(type_: type)
337	def __init__(self, type_: type) -> None:
338		"constructs an error message with the type and mro of the type"
339		msg: str = f"Type {type_} is not supported by `all_instances`. See docstring for details. {type_.__mro__ = }"
340		super().__init__(msg)

constructs an error message with the type and mro of the type

Inherited Members
builtins.BaseException
with_traceback
add_note
args
def all_instances( type_: ~FiniteValued, validation_funcs: dict[~FiniteValued, typing.Callable[[~FiniteValued], bool]] | None = None) -> Generator[~FiniteValued, NoneType, NoneType]:
@_all_instances_wrapper
def all_instances(
	type_: FiniteValued,
	validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None = None,
) -> Generator[FiniteValued, None, None]:
	"""Returns all possible values of an instance of `type_` if finite instances exist.

	Uses type hinting to construct the possible values.
	All nested elements of `type_` must themselves be typed.
	Do not use with types whose members contain circular references.
	Function is susceptible to infinite recursion if `type_` is a dataclass whose member tree includes another instance of `type_`.

	# Parameters
	- `type_: FiniteValued`
		A finite-valued type. See docstring on `FiniteValued` for full details.
	- `validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None`
		A mapping of types to auxiliary functions to validate instances of that type.
		This optional argument can provide an additional, more precise layer of validation for the instances generated beyond what type hinting alone can provide.
		See `validation_funcs` Details section below.
		(default: `None`)

	## Supported `type_` Values
	See docstring on `FiniteValued` for full details.
	`type_` may be:
	- `bool`
	- A `Literal[...]` type: its literal arguments are yielded as-is
	- A dataclass: a concrete dataclass yields one instance per combination of field values;
	an abstract dataclass yields the instances of each of its direct subclasses
	- A finite-valued, fixed-length Generic tuple type.
	E.g., `tuple[bool]`, `tuple[bool, Literal[1, 2]]` are OK.
	`tuple[bool, ...]` is NOT supported, since the length of the tuple is not fixed.
	- Nested versions of any of the types in this list
	- A `UnionType` (or `typing.Union` / `typing.Optional`) of any of the types in this list

	`enum.Enum` subclasses are not yet supported and raise `UnsupportedAllInstancesError`.

	## `validation_funcs` Details
	- `validation_funcs` is applied after all instances have been generated according to type hints.
	- If `type_` is in `validation_funcs`, then the list of instances is filtered by `validation_funcs[type_](instance)`.
	- `validation_funcs` is passed down for all recursive calls of `all_instances`.
	- This allows for improved performance through maximal pruning of the exponential tree.
	- `validation_funcs` supports subclass checking.
	- If `type_` is not found in `validation_funcs`, then the search is performed iteratively in mro order.
	- If a superclass of `type_` is found while searching in mro order, that validation function is applied and the list is returned.
	- If no superclass of `type_` is found, then no filter is applied.

	# Raises:
	- `UnsupportedAllInstancesError`: If `type_` is not supported by `all_instances`.
	"""
	if type_ == bool:  # noqa: E721
		yield from [True, False]
	elif hasattr(type_, "__dataclass_fields__"):
		if is_abstract(type_):
			# Abstract dataclass: call `all_instances` on each direct subclass
			yield from flatten(
				(
					all_instances(sub, validation_funcs)
					for sub in type_.__subclasses__()
				),
				levels_to_flatten=1,
			)
		else:
			# Concrete dataclass: construct dataclass instances with all possible combinations of fields
			# NOTE(review): `__dataclass_fields__` can also hold ClassVar/InitVar pseudo-fields, and every
			# entry is passed to the constructor below -- confirm ClassVars never occur in supported types
			fields: dict[str, Field] = type_.__dataclass_fields__
			fields_to_types: dict[str, type] = {f: fields[f].type for f in fields}
			all_arg_sequences: Iterable = itertools.product(
				*[
					all_instances(arg_type, validation_funcs)
					for arg_type in fields_to_types.values()
				],
			)
			yield from (
				type_(
					**dict(zip(fields_to_types.keys(), args, strict=False)),
				)
				for args in all_arg_sequences
			)
	else:
		type_origin = get_origin(type_)
		if type_origin == tuple:  # noqa: E721
			# Only matches Generic type tuple since regular tuple is not finite-valued
			# Generic tuple: Similar to concrete dataclass. Construct all possible combinations of tuple fields.
			yield from (
				tuple(combo)
				for combo in itertools.product(
					*(
						all_instances(tup_item, validation_funcs)
						for tup_item in get_args(type_)
					),
				)
			)
		elif type_origin in (UnionType, typing.Union):
			# Union: call `all_instances` for each type in the Union
			yield from flatten(
				[all_instances(sub, validation_funcs) for sub in get_args(type_)],
				levels_to_flatten=1,
			)
		elif type_origin is Literal:
			# Literal: return all Literal arguments
			yield from get_args(type_)
		else:
			# includes enum.Enum subclasses, which have no dedicated branch above
			raise UnsupportedAllInstancesError(type_)

Returns all possible values of an instance of type_ if finite instances exist.

Uses type hinting to construct the possible values. All nested elements of type_ must themselves be typed. Do not use with types whose members contain circular references. Function is susceptible to infinite recursion if type_ is a dataclass whose member tree includes another instance of type_.

Parameters

  • type_: FiniteValued A finite-valued type. See docstring on FiniteValued for full details.
  • validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None A mapping of types to auxiliary functions to validate instances of that type. This optional argument can provide an additional, more precise layer of validation for the instances generated beyond what type hinting alone can provide. See validation_funcs Details section below. (default: None)

Supported type_ Values

See docstring on FiniteValued for full details. type_ may be:

  • FiniteValued
  • A finite-valued, fixed-length Generic tuple type. E.g., tuple[bool], tuple[bool, Literal[1, 2]] are OK. tuple[bool, ...] is NOT supported, since the length of the tuple is not fixed. (Enum types are not yet supported.)
  • Nested versions of any of the types in this list
  • A UnionType of any of the types in this list

validation_funcs Details

  • validation_funcs is applied after all instances have been generated according to type hints.
  • If type_ is in validation_funcs, then the list of instances is filtered by validation_funcs[type_](instance).
  • validation_funcs is passed down for all recursive calls of all_instances.
  • This allows for improved performance through maximal pruning of the exponential tree.
  • validation_funcs supports subclass checking.
  • If type_ is not found in validation_funcs, then the search is performed iteratively in mro order.
  • If a superclass of type_ is found while searching in mro order, that validation function is applied and the list is returned.
  • If no superclass of type_ is found, then no filter is applied.

Raises: