docs for maze-dataset v1.3.0
View Source on GitHub

maze_dataset.benchmark.config_sweep

Benchmarking of how successful maze generation is for various values of percolation


  1"""Benchmarking of how successful maze generation is for various values of percolation"""
  2
  3import functools
  4import json
  5import warnings
  6from pathlib import Path
  7from typing import Any, Callable, Generic, Sequence, TypeVar
  8
  9import matplotlib.pyplot as plt
 10import numpy as np
 11from jaxtyping import Float
 12from muutils.dictmagic import dotlist_to_nested_dict, update_with_nested_dict
 13from muutils.json_serialize import (
 14	JSONitem,
 15	SerializableDataclass,
 16	json_serialize,
 17	serializable_dataclass,
 18	serializable_field,
 19)
 20from muutils.parallel import run_maybe_parallel
 21from zanj import ZANJ
 22
 23from maze_dataset import MazeDataset, MazeDatasetConfig
 24from maze_dataset.generation import LatticeMazeGenerators
 25
 26SweepReturnType = TypeVar("SweepReturnType")
 27ParamType = TypeVar("ParamType")
 28AnalysisFunc = Callable[[MazeDatasetConfig], SweepReturnType]
 29
 30
 31def dataset_success_fraction(cfg: MazeDatasetConfig) -> float:
 32	"""empirical success fraction of maze generation
 33
 34	for use as an `analyze_func` in `sweep()`
 35	"""
 36	dataset: MazeDataset = MazeDataset.from_config(
 37		cfg,
 38		do_download=False,
 39		load_local=False,
 40		save_local=False,
 41		verbose=False,
 42	)
 43
 44	return len(dataset) / cfg.n_mazes
 45
 46
 47ANALYSIS_FUNCS: dict[str, AnalysisFunc] = dict(
 48	dataset_success_fraction=dataset_success_fraction,
 49)
 50
 51
 52def sweep(
 53	cfg_base: MazeDatasetConfig,
 54	param_values: list[ParamType],
 55	param_key: str,
 56	analyze_func: Callable[[MazeDatasetConfig], SweepReturnType],
 57) -> list[SweepReturnType]:
 58	"""given a base config, parameter values list, key, and analysis function, return the results of the analysis function for each parameter value
 59
 60	# Parameters:
 61	- `cfg_base : MazeDatasetConfig`
 62		base config on which we will modify the value at `param_key` with values from `param_values`
 63	- `param_values : list[ParamType]`
 64		list of values to try
 65	- `param_key : str`
 66		value to modify in `cfg_base`
 67	- `analyze_func : Callable[[MazeDatasetConfig], SweepReturnType]`
 68		function which analyzes the resulting config; originally built for `dataset_success_fraction`
 69
 70	# Returns:
 71	- `list[SweepReturnType]`
 72		results of `analyze_func` for each value in `param_values`
 73	"""
 74	outputs: list[SweepReturnType] = []
 75
 76	for p in param_values:
 77		# update the config
 78		cfg_dict: dict = cfg_base.serialize()
 79		update_with_nested_dict(
 80			cfg_dict,
 81			dotlist_to_nested_dict({param_key: p}),
 82		)
 83		cfg_test: MazeDatasetConfig = MazeDatasetConfig.load(cfg_dict)
 84
 85		outputs.append(analyze_func(cfg_test))
 86
 87	return outputs
 88
 89
 90@serializable_dataclass()
 91class SweepResult(SerializableDataclass, Generic[ParamType, SweepReturnType]):
 92	"""result of a parameter sweep"""
 93
 94	configs: list[MazeDatasetConfig] = serializable_field(
 95		serialization_fn=lambda cfgs: [cfg.serialize() for cfg in cfgs],
 96		deserialize_fn=lambda cfgs: [MazeDatasetConfig.load(cfg) for cfg in cfgs],
 97	)
 98	param_values: list[ParamType] = serializable_field(
 99		serialization_fn=lambda x: json_serialize(x),
100		deserialize_fn=lambda x: x,
101		assert_type=False,
102	)
103	result_values: dict[str, Sequence[SweepReturnType]] = serializable_field(
104		serialization_fn=lambda x: json_serialize(x),
105		deserialize_fn=lambda x: x,
106		assert_type=False,
107	)
108	param_key: str
109	analyze_func: Callable[[MazeDatasetConfig], SweepReturnType] = serializable_field(
110		serialization_fn=lambda f: f.__name__,
111		deserialize_fn=ANALYSIS_FUNCS.get,
112		assert_type=False,
113	)
114
115	def summary(self) -> JSONitem:
116		"human-readable and json-dumpable short summary of the result"
117		return {
118			"len(configs)": len(self.configs),
119			"len(param_values)": len(self.param_values),
120			"len(result_values)": len(self.result_values),
121			"param_key": self.param_key,
122			"analyze_func": self.analyze_func.__name__,
123		}
124
125	def save(self, path: str | Path, z: ZANJ | None = None) -> None:
126		"save to a file with zanj"
127		if z is None:
128			z = ZANJ()
129
130		z.save(self, path)
131
132	@classmethod
133	def read(cls, path: str | Path, z: ZANJ | None = None) -> "SweepResult":
134		"read from a file with zanj"
135		if z is None:
136			z = ZANJ()
137
138		return z.read(path)
139
140	def configs_by_name(self) -> dict[str, MazeDatasetConfig]:
141		"return configs by name"
142		return {cfg.name: cfg for cfg in self.configs}
143
144	def configs_by_key(self) -> dict[str, MazeDatasetConfig]:
145		"return configs by the key used in `result_values`, which is the filename of the config"
146		return {cfg.to_fname(): cfg for cfg in self.configs}
147
148	def configs_shared(self) -> dict[str, Any]:
149		"return key: value pairs that are shared across all configs"
150		# we know that the configs all have the same keys,
151		# so this way of doing it is fine
152		config_vals: dict[str, set[Any]] = dict()
153		for cfg in self.configs:
154			for k, v in cfg.serialize().items():
155				if k not in config_vals:
156					config_vals[k] = set()
157				config_vals[k].add(json.dumps(v))
158
159		shared_vals: dict[str, Any] = dict()
160
161		cfg_ser: dict = self.configs[0].serialize()
162		for k, v in config_vals.items():
163			if len(v) == 1:
164				shared_vals[k] = cfg_ser[k]
165
166		return shared_vals
167
168	def configs_differing_keys(self) -> set[str]:
169		"return keys that differ across configs"
170		shared_vals: dict[str, Any] = self.configs_shared()
171		differing_keys: set[str] = set()
172
173		for k in MazeDatasetConfig.__dataclass_fields__:
174			if k not in shared_vals:
175				differing_keys.add(k)
176
177		return differing_keys
178
179	def configs_value_set(self, key: str) -> list[Any]:
180		"return a list of the unique values for a given key"
181		d: dict[str, Any] = {
182			json.dumps(json_serialize(getattr(cfg, key))): getattr(cfg, key)
183			for cfg in self.configs
184		}
185
186		return list(d.values())
187
188	def get_where(self, key: str, val_check: Callable[[Any], bool]) -> "SweepResult":
 189		"get a subset of this `SweepResult` whose configs have `key` satisfying `val_check`"
190		configs_list: list[MazeDatasetConfig] = [
191			cfg for cfg in self.configs if val_check(getattr(cfg, key))
192		]
193		configs_keys: set[str] = {cfg.to_fname() for cfg in configs_list}
194		result_values: dict[str, Sequence[SweepReturnType]] = {
195			k: self.result_values[k] for k in configs_keys
196		}
197
198		return SweepResult(
199			configs=configs_list,
200			param_values=self.param_values,
201			result_values=result_values,
202			param_key=self.param_key,
203			analyze_func=self.analyze_func,
204		)
205
206	@classmethod
207	def analyze(
208		cls,
209		configs: list[MazeDatasetConfig],
210		param_values: list[ParamType],
211		param_key: str,
212		analyze_func: Callable[[MazeDatasetConfig], SweepReturnType],
213		parallel: bool | int = False,
214		**kwargs,
215	) -> "SweepResult":
216		"""Analyze success rate of maze generation for different percolation values
217
218		# Parameters:
219		- `configs : list[MazeDatasetConfig]`
220		configs to try
 221		- `param_values : list[ParamType]`
 222		list of values to try for `param_key`
223
224		# Returns:
225		- `SweepResult`
226		"""
227		n_pvals: int = len(param_values)
228
229		result_values_list: list[float] = run_maybe_parallel(
230			# TYPING: error: Argument "func" to "run_maybe_parallel" has incompatible type "partial[list[SweepReturnType]]"; expected "Callable[[MazeDatasetConfig], float]"  [arg-type]
231			func=functools.partial(  # type: ignore[arg-type]
232				sweep,
233				param_values=param_values,
234				param_key=param_key,
235				analyze_func=analyze_func,
236			),
237			iterable=configs,
238			keep_ordered=True,
239			parallel=parallel,
240			pbar_kwargs=dict(total=len(configs)),
241			**kwargs,
242		)
243		result_values: dict[str, Float[np.ndarray, n_pvals]] = {
244			cfg.to_fname(): np.array(res)
245			for cfg, res in zip(configs, result_values_list, strict=False)
246		}
247		return cls(
248			configs=configs,
249			param_values=param_values,
250			# TYPING: error: Argument "result_values" to "SweepResult" has incompatible type "dict[str, ndarray[Any, Any]]"; expected "dict[str, Sequence[SweepReturnType]]"  [arg-type]
251			result_values=result_values,  # type: ignore[arg-type]
252			param_key=param_key,
253			analyze_func=analyze_func,
254		)
255
256	def plot(
257		self,
258		save_path: str | None = None,
259		cfg_keys: list[str] | None = None,
260		cmap_name: str | None = "viridis",
261		plot_only: bool = False,
262		show: bool = True,
263		ax: plt.Axes | None = None,
264	) -> plt.Axes:
265		"""Plot the results of percolation analysis"""
266		# set up figure
267		if not ax:
268			fig: plt.Figure
269			ax_: plt.Axes
270			fig, ax_ = plt.subplots(1, 1, figsize=(22, 10))
271		else:
272			ax_ = ax
273
274		# plot
275		cmap = plt.get_cmap(cmap_name)
276		n_cfgs: int = len(self.result_values)
277		for i, (ep_cfg_name, result_values) in enumerate(
278			sorted(
279				self.result_values.items(),
280				# HACK: sort by grid size
281				#                 |--< name of config
282				#                 |    |-----------< gets 'g{n}'
283				#                 |    |            |--< gets '{n}'
284				#                 |    |            |
285				key=lambda x: int(x[0].split("-")[0][1:]),
286			),
287		):
288			ax_.plot(
289				# TYPING: error: Argument 1 to "plot" of "Axes" has incompatible type "list[ParamType]"; expected "float | Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes] | str"  [arg-type]
290				self.param_values,  # type: ignore[arg-type]
291				# TYPING: error: Argument 2 to "plot" of "Axes" has incompatible type "Sequence[SweepReturnType]"; expected "float | Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes] | str"  [arg-type]
292				result_values,  # type: ignore[arg-type]
293				".-",
294				label=self.configs_by_key()[ep_cfg_name].name,
295				color=cmap((i + 0.5) / (n_cfgs - 0.5)),
296			)
297
298		# repr of config
299		cfg_shared: dict = self.configs_shared()
300		cfg_repr: str = (
301			str(cfg_shared)
302			if cfg_keys is None
303			else (
304				"MazeDatasetConfig("
305				+ ", ".join(
306					[
307						f"{k}={cfg_shared[k].__name__}"
308						# TYPING: error: Argument 2 to "isinstance" has incompatible type "<typing special form>"; expected "_ClassInfo"  [arg-type]
309						if isinstance(cfg_shared[k], Callable)  # type: ignore[arg-type]
310						else f"{k}={cfg_shared[k]}"
311						for k in cfg_keys
312					],
313				)
314				+ ")"
315			)
316		)
317
318		# add title and stuff
319		if not plot_only:
320			ax_.set_xlabel(self.param_key)
321			ax_.set_ylabel(self.analyze_func.__name__)
322			ax_.set_title(
323				f"{self.param_key} vs {self.analyze_func.__name__}\n{cfg_repr}",
324			)
325			ax_.grid(True)
326			ax_.legend(loc="center left")
327
328		# save and show
329		if save_path:
330			plt.savefig(save_path)
331
332		if show:
333			plt.show()
334
335		return ax_
336
337
338DEFAULT_ENDPOINT_KWARGS: list[tuple[str, dict]] = [
339	(
340		"any",
341		dict(deadend_start=False, deadend_end=False, except_on_no_valid_endpoint=False),
342	),
343	(
344		"deadends",
345		dict(
346			deadend_start=True,
347			deadend_end=True,
348			endpoints_not_equal=False,
349			except_on_no_valid_endpoint=False,
350		),
351	),
352	(
353		"deadends_unique",
354		dict(
355			deadend_start=True,
356			deadend_end=True,
357			endpoints_not_equal=True,
358			except_on_no_valid_endpoint=False,
359		),
360	),
361]
362
363
364def endpoint_kwargs_to_name(ep_kwargs: dict) -> str:
365	"""convert endpoint kwargs options to a human-readable name"""
366	if ep_kwargs.get("deadend_start", False) or ep_kwargs.get("deadend_end", False):
367		if ep_kwargs.get("endpoints_not_equal", False):
368			return "deadends_unique"
369		else:
370			return "deadends"
371	else:
372		return "any"
373
374
375def full_percolation_analysis(
376	n_mazes: int,
377	p_val_count: int,
378	grid_sizes: list[int],
379	ep_kwargs: list[tuple[str, dict]] | None = None,
380	generators: Sequence[Callable] = (
381		LatticeMazeGenerators.gen_percolation,
382		LatticeMazeGenerators.gen_dfs_percolation,
383	),
384	save_dir: Path = Path("../docs/benchmarks/percolation_fractions"),
385	parallel: bool | int = False,
386	**analyze_kwargs,
387) -> SweepResult:
388	"run the full analysis of how percolation affects maze generation success"
389	if ep_kwargs is None:
390		ep_kwargs = DEFAULT_ENDPOINT_KWARGS
391
392	# configs
393	configs: list[MazeDatasetConfig] = list()
394
 395	# TODO: B007 noqa'd because we don't use `ep_kw_name` or `gf_idx`
396	for ep_kw_name, ep_kw in ep_kwargs:  # noqa: B007
397		for gf_idx, gen_func in enumerate(generators):  # noqa: B007
398			configs.extend(
399				[
400					MazeDatasetConfig(
401						name=f"g{grid_n}-{gen_func.__name__.removeprefix('gen_').removesuffix('olation')}",
402						grid_n=grid_n,
403						n_mazes=n_mazes,
404						maze_ctor=gen_func,
405						maze_ctor_kwargs=dict(p=float("nan")),
406						endpoint_kwargs=ep_kw,
407					)
408					for grid_n in grid_sizes
409				],
410			)
411
412	# get results
413	result: SweepResult = SweepResult.analyze(
414		configs=configs,  # type: ignore[misc]
415		# TYPING: error: Argument "param_values" to "analyze" of "SweepResult" has incompatible type "float | list[float] | list[list[float]] | list[list[list[Any]]]"; expected "list[Any]"  [arg-type]
416		param_values=np.linspace(0.0, 1.0, p_val_count).tolist(),  # type: ignore[arg-type]
417		param_key="maze_ctor_kwargs.p",
418		analyze_func=dataset_success_fraction,
419		parallel=parallel,
420		**analyze_kwargs,
421	)
422
423	# save the result
424	results_path: Path = (
425		save_dir / f"result-n{n_mazes}-c{len(configs)}-p{p_val_count}.zanj"
426	)
427	print(f"Saving results to {results_path.as_posix()}")
428	result.save(results_path)
429
430	return result
431
432
433def _is_eq(a, b) -> bool:  # noqa: ANN001
434	"""check if two objects are equal"""
435	return a == b
436
437
438def plot_grouped(  # noqa: C901
439	results: SweepResult,
440	predict_fn: Callable[[MazeDatasetConfig], float] | None = None,
441	prediction_density: int = 50,
442	save_dir: Path | None = None,
443	show: bool = True,
444	logy: bool = False,
445) -> None:
446	"""Plot grouped sweep percolation value results for each distinct `endpoint_kwargs` in the configs
447
448	with separate colormaps for each maze generator function
449
450	# Parameters:
451	- `results : SweepResult`
452		The sweep results to plot
453	- `predict_fn : Callable[[MazeDatasetConfig], float] | None`
454		Optional function that predicts success rate from a config. If provided, will plot predictions as dashed lines.
455	- `prediction_density : int`
456		Number of points to use for prediction curves (default: 50)
457	- `save_dir : Path | None`
458		Directory to save plots (defaults to `None`, meaning no saving)
459	- `show : bool`
460		Whether to display the plots (defaults to `True`)
461
462	# Usage:
463	```python
 464	>>> result = full_percolation_analysis(n_mazes=100, p_val_count=11, grid_sizes=[8, 16])
465	>>> plot_grouped(result, save_dir=Path("./plots"), show=False)
466	```
467	"""
468	# groups
469	endpoint_kwargs_set: list[dict] = results.configs_value_set("endpoint_kwargs")  # type: ignore[assignment]
470	generator_funcs_names: list[str] = list(
471		{cfg.maze_ctor.__name__ for cfg in results.configs},
472	)
473
474	# if predicting, create denser p values
475	if predict_fn is not None:
476		p_dense = np.linspace(0.0, 1.0, prediction_density)
477
478	# separate plot for each set of endpoint kwargs
479	for ep_kw in endpoint_kwargs_set:
480		results_epkw: SweepResult = results.get_where(
481			"endpoint_kwargs",
482			functools.partial(_is_eq, b=ep_kw),
483			# lambda x: x == ep_kw,
484		)
485		shared_keys: set[str] = set(results_epkw.configs_shared().keys())
486		cfg_keys: set[str] = shared_keys.intersection({"n_mazes", "endpoint_kwargs"})
487		fig, ax = plt.subplots(1, 1, figsize=(22, 10))
488		for gf_idx, gen_func in enumerate(generator_funcs_names):
489			results_filtered: SweepResult = results_epkw.get_where(
490				"maze_ctor",
491				# HACK: big hassle to do this without a lambda, is it really that bad?
492				lambda x: x.__name__ == gen_func,  # noqa: B023
493			)
494			if len(results_filtered.configs) < 1:
495				warnings.warn(
496					f"No results for {gen_func} and {ep_kw}. Skipping.",
497				)
498				continue
499
500			cmap_name = "Reds" if gf_idx == 0 else "Blues"
501			cmap = plt.get_cmap(cmap_name)
502
503			# Plot actual results
504			ax = results_filtered.plot(
505				cfg_keys=list(cfg_keys),
506				ax=ax,
507				show=False,
508				cmap_name=cmap_name,
509			)
510			if logy:
511				ax.set_yscale("log")
512
513			# Plot predictions if function provided
514			if predict_fn is not None:
515				for cfg_idx, cfg in enumerate(results_filtered.configs):
516					predictions = []
517					for p in p_dense:
518						cfg_temp = MazeDatasetConfig.load(cfg.serialize())
519						cfg_temp.maze_ctor_kwargs["p"] = p
520						predictions.append(predict_fn(cfg_temp))
521
522					# Get the same color as the actual data
523					n_cfgs: int = len(results_filtered.configs)
524					color = cmap((cfg_idx + 0.5) / (n_cfgs - 0.5))
525
526					# Plot prediction as dashed line
527					ax.plot(p_dense, predictions, "--", color=color, alpha=0.8)
528
529		# save and show
530		if save_dir:
531			save_path: Path = save_dir / f"ep_{endpoint_kwargs_to_name(ep_kw)}.svg"
532			print(f"Saving plot to {save_path.as_posix()}")
533			save_path.parent.mkdir(exist_ok=True, parents=True)
534			plt.savefig(save_path)
535
536		if show:
537			plt.show()

AnalysisFunc = typing.Callable[[maze_dataset.MazeDatasetConfig], ~SweepReturnType]
def dataset_success_fraction(cfg: maze_dataset.MazeDatasetConfig) -> float:
32def dataset_success_fraction(cfg: MazeDatasetConfig) -> float:
33	"""empirical success fraction of maze generation
34
35	for use as an `analyze_func` in `sweep()`
36	"""
37	dataset: MazeDataset = MazeDataset.from_config(
38		cfg,
39		do_download=False,
40		load_local=False,
41		save_local=False,
42		verbose=False,
43	)
44
45	return len(dataset) / cfg.n_mazes

empirical success fraction of maze generation

for use as an analyze_func in sweep()
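
For example, a minimal sketch of calling it directly on a single config; the field names and the `except_on_no_valid_endpoint=False` flag mirror the configs built by `full_percolation_analysis` below, and the specific values are illustrative:

```python
from maze_dataset import MazeDatasetConfig
from maze_dataset.benchmark.config_sweep import dataset_success_fraction
from maze_dataset.generation import LatticeMazeGenerators

# small config with a fixed percolation probability
cfg = MazeDatasetConfig(
    name="demo",
    grid_n=5,
    n_mazes=32,
    maze_ctor=LatticeMazeGenerators.gen_dfs_percolation,
    maze_ctor_kwargs=dict(p=0.7),
    endpoint_kwargs=dict(except_on_no_valid_endpoint=False),
)

# fraction of the requested n_mazes that were actually generated
frac: float = dataset_success_fraction(cfg)
print(f"success fraction: {frac:.2f}")
```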

ANALYSIS_FUNCS: dict[str, typing.Callable[[maze_dataset.MazeDatasetConfig], ~SweepReturnType]] = {'dataset_success_fraction': <function dataset_success_fraction>}
def sweep( cfg_base: maze_dataset.MazeDatasetConfig, param_values: list[~ParamType], param_key: str, analyze_func: Callable[[maze_dataset.MazeDatasetConfig], ~SweepReturnType]) -> list[~SweepReturnType]:
53def sweep(
54	cfg_base: MazeDatasetConfig,
55	param_values: list[ParamType],
56	param_key: str,
57	analyze_func: Callable[[MazeDatasetConfig], SweepReturnType],
58) -> list[SweepReturnType]:
59	"""given a base config, parameter values list, key, and analysis function, return the results of the analysis function for each parameter value
60
61	# Parameters:
62	- `cfg_base : MazeDatasetConfig`
63		base config on which we will modify the value at `param_key` with values from `param_values`
64	- `param_values : list[ParamType]`
65		list of values to try
66	- `param_key : str`
67		value to modify in `cfg_base`
68	- `analyze_func : Callable[[MazeDatasetConfig], SweepReturnType]`
 69		function which analyzes the resulting config; originally built for `dataset_success_fraction`
70
71	# Returns:
72	- `list[SweepReturnType]`
 73		results of `analyze_func` for each value in `param_values`
74	"""
75	outputs: list[SweepReturnType] = []
76
77	for p in param_values:
78		# update the config
79		cfg_dict: dict = cfg_base.serialize()
80		update_with_nested_dict(
81			cfg_dict,
82			dotlist_to_nested_dict({param_key: p}),
83		)
84		cfg_test: MazeDatasetConfig = MazeDatasetConfig.load(cfg_dict)
85
86		outputs.append(analyze_func(cfg_test))
87
88	return outputs

given a base config, parameter values list, key, and analysis function, return the results of the analysis function for each parameter value

Parameters:

  • cfg_base : MazeDatasetConfig base config on which we will modify the value at param_key with values from param_values
  • param_values : list[ParamType] list of values to try
  • param_key : str value to modify in cfg_base
  • analyze_func : Callable[[MazeDatasetConfig], SweepReturnType] function which analyzes the resulting config; originally built for dataset_success_fraction

Returns:

  • list[SweepReturnType] results of analyze_func for each value in param_values
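
A sketch of a single sweep over the percolation probability, reusing the `maze_ctor_kwargs.p` dotlist key and the `np.linspace(...).tolist()` pattern from `full_percolation_analysis` below (the base config itself is illustrative):

```python
import numpy as np

from maze_dataset import MazeDatasetConfig
from maze_dataset.benchmark.config_sweep import dataset_success_fraction, sweep
from maze_dataset.generation import LatticeMazeGenerators

base_cfg = MazeDatasetConfig(
    name="sweep-demo",
    grid_n=5,
    n_mazes=16,
    maze_ctor=LatticeMazeGenerators.gen_dfs_percolation,
    maze_ctor_kwargs=dict(p=0.0),  # overwritten at each step of the sweep
    endpoint_kwargs=dict(except_on_no_valid_endpoint=False),
)

# one success fraction per percolation value
success_fractions: list[float] = sweep(
    cfg_base=base_cfg,
    param_values=np.linspace(0.0, 1.0, 11).tolist(),
    param_key="maze_ctor_kwargs.p",
    analyze_func=dataset_success_fraction,
)
```
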
@serializable_dataclass()
class SweepResult(muutils.json_serialize.serializable_dataclass.SerializableDataclass, typing.Generic[~ParamType, ~SweepReturnType]):
 91@serializable_dataclass()
 92class SweepResult(SerializableDataclass, Generic[ParamType, SweepReturnType]):
 93	"""result of a parameter sweep"""
 94
 95	configs: list[MazeDatasetConfig] = serializable_field(
 96		serialization_fn=lambda cfgs: [cfg.serialize() for cfg in cfgs],
 97		deserialize_fn=lambda cfgs: [MazeDatasetConfig.load(cfg) for cfg in cfgs],
 98	)
 99	param_values: list[ParamType] = serializable_field(
100		serialization_fn=lambda x: json_serialize(x),
101		deserialize_fn=lambda x: x,
102		assert_type=False,
103	)
104	result_values: dict[str, Sequence[SweepReturnType]] = serializable_field(
105		serialization_fn=lambda x: json_serialize(x),
106		deserialize_fn=lambda x: x,
107		assert_type=False,
108	)
109	param_key: str
110	analyze_func: Callable[[MazeDatasetConfig], SweepReturnType] = serializable_field(
111		serialization_fn=lambda f: f.__name__,
112		deserialize_fn=ANALYSIS_FUNCS.get,
113		assert_type=False,
114	)
115
116	def summary(self) -> JSONitem:
117		"human-readable and json-dumpable short summary of the result"
118		return {
119			"len(configs)": len(self.configs),
120			"len(param_values)": len(self.param_values),
121			"len(result_values)": len(self.result_values),
122			"param_key": self.param_key,
123			"analyze_func": self.analyze_func.__name__,
124		}
125
126	def save(self, path: str | Path, z: ZANJ | None = None) -> None:
127		"save to a file with zanj"
128		if z is None:
129			z = ZANJ()
130
131		z.save(self, path)
132
133	@classmethod
134	def read(cls, path: str | Path, z: ZANJ | None = None) -> "SweepResult":
135		"read from a file with zanj"
136		if z is None:
137			z = ZANJ()
138
139		return z.read(path)
140
141	def configs_by_name(self) -> dict[str, MazeDatasetConfig]:
142		"return configs by name"
143		return {cfg.name: cfg for cfg in self.configs}
144
145	def configs_by_key(self) -> dict[str, MazeDatasetConfig]:
146		"return configs by the key used in `result_values`, which is the filename of the config"
147		return {cfg.to_fname(): cfg for cfg in self.configs}
148
149	def configs_shared(self) -> dict[str, Any]:
150		"return key: value pairs that are shared across all configs"
151		# we know that the configs all have the same keys,
152		# so this way of doing it is fine
153		config_vals: dict[str, set[Any]] = dict()
154		for cfg in self.configs:
155			for k, v in cfg.serialize().items():
156				if k not in config_vals:
157					config_vals[k] = set()
158				config_vals[k].add(json.dumps(v))
159
160		shared_vals: dict[str, Any] = dict()
161
162		cfg_ser: dict = self.configs[0].serialize()
163		for k, v in config_vals.items():
164			if len(v) == 1:
165				shared_vals[k] = cfg_ser[k]
166
167		return shared_vals
168
169	def configs_differing_keys(self) -> set[str]:
170		"return keys that differ across configs"
171		shared_vals: dict[str, Any] = self.configs_shared()
172		differing_keys: set[str] = set()
173
174		for k in MazeDatasetConfig.__dataclass_fields__:
175			if k not in shared_vals:
176				differing_keys.add(k)
177
178		return differing_keys
179
180	def configs_value_set(self, key: str) -> list[Any]:
181		"return a list of the unique values for a given key"
182		d: dict[str, Any] = {
183			json.dumps(json_serialize(getattr(cfg, key))): getattr(cfg, key)
184			for cfg in self.configs
185		}
186
187		return list(d.values())
188
189	def get_where(self, key: str, val_check: Callable[[Any], bool]) -> "SweepResult":
 190		"get a subset of this `SweepResult` whose configs have `key` satisfying `val_check`"
191		configs_list: list[MazeDatasetConfig] = [
192			cfg for cfg in self.configs if val_check(getattr(cfg, key))
193		]
194		configs_keys: set[str] = {cfg.to_fname() for cfg in configs_list}
195		result_values: dict[str, Sequence[SweepReturnType]] = {
196			k: self.result_values[k] for k in configs_keys
197		}
198
199		return SweepResult(
200			configs=configs_list,
201			param_values=self.param_values,
202			result_values=result_values,
203			param_key=self.param_key,
204			analyze_func=self.analyze_func,
205		)
206
207	@classmethod
208	def analyze(
209		cls,
210		configs: list[MazeDatasetConfig],
211		param_values: list[ParamType],
212		param_key: str,
213		analyze_func: Callable[[MazeDatasetConfig], SweepReturnType],
214		parallel: bool | int = False,
215		**kwargs,
216	) -> "SweepResult":
217		"""Analyze success rate of maze generation for different percolation values
218
219		# Parameters:
220		- `configs : list[MazeDatasetConfig]`
221		configs to try
 222		- `param_values : list[ParamType]`
 223		list of values to try for `param_key`
224
225		# Returns:
226		- `SweepResult`
227		"""
228		n_pvals: int = len(param_values)
229
230		result_values_list: list[float] = run_maybe_parallel(
231			# TYPING: error: Argument "func" to "run_maybe_parallel" has incompatible type "partial[list[SweepReturnType]]"; expected "Callable[[MazeDatasetConfig], float]"  [arg-type]
232			func=functools.partial(  # type: ignore[arg-type]
233				sweep,
234				param_values=param_values,
235				param_key=param_key,
236				analyze_func=analyze_func,
237			),
238			iterable=configs,
239			keep_ordered=True,
240			parallel=parallel,
241			pbar_kwargs=dict(total=len(configs)),
242			**kwargs,
243		)
244		result_values: dict[str, Float[np.ndarray, n_pvals]] = {
245			cfg.to_fname(): np.array(res)
246			for cfg, res in zip(configs, result_values_list, strict=False)
247		}
248		return cls(
249			configs=configs,
250			param_values=param_values,
251			# TYPING: error: Argument "result_values" to "SweepResult" has incompatible type "dict[str, ndarray[Any, Any]]"; expected "dict[str, Sequence[SweepReturnType]]"  [arg-type]
252			result_values=result_values,  # type: ignore[arg-type]
253			param_key=param_key,
254			analyze_func=analyze_func,
255		)
256
257	def plot(
258		self,
259		save_path: str | None = None,
260		cfg_keys: list[str] | None = None,
261		cmap_name: str | None = "viridis",
262		plot_only: bool = False,
263		show: bool = True,
264		ax: plt.Axes | None = None,
265	) -> plt.Axes:
266		"""Plot the results of percolation analysis"""
267		# set up figure
268		if not ax:
269			fig: plt.Figure
270			ax_: plt.Axes
271			fig, ax_ = plt.subplots(1, 1, figsize=(22, 10))
272		else:
273			ax_ = ax
274
275		# plot
276		cmap = plt.get_cmap(cmap_name)
277		n_cfgs: int = len(self.result_values)
278		for i, (ep_cfg_name, result_values) in enumerate(
279			sorted(
280				self.result_values.items(),
281				# HACK: sort by grid size
282				#                 |--< name of config
283				#                 |    |-----------< gets 'g{n}'
284				#                 |    |            |--< gets '{n}'
285				#                 |    |            |
286				key=lambda x: int(x[0].split("-")[0][1:]),
287			),
288		):
289			ax_.plot(
290				# TYPING: error: Argument 1 to "plot" of "Axes" has incompatible type "list[ParamType]"; expected "float | Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes] | str"  [arg-type]
291				self.param_values,  # type: ignore[arg-type]
292				# TYPING: error: Argument 2 to "plot" of "Axes" has incompatible type "Sequence[SweepReturnType]"; expected "float | Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes] | str"  [arg-type]
293				result_values,  # type: ignore[arg-type]
294				".-",
295				label=self.configs_by_key()[ep_cfg_name].name,
296				color=cmap((i + 0.5) / (n_cfgs - 0.5)),
297			)
298
299		# repr of config
300		cfg_shared: dict = self.configs_shared()
301		cfg_repr: str = (
302			str(cfg_shared)
303			if cfg_keys is None
304			else (
305				"MazeDatasetConfig("
306				+ ", ".join(
307					[
308						f"{k}={cfg_shared[k].__name__}"
309						# TYPING: error: Argument 2 to "isinstance" has incompatible type "<typing special form>"; expected "_ClassInfo"  [arg-type]
310						if isinstance(cfg_shared[k], Callable)  # type: ignore[arg-type]
311						else f"{k}={cfg_shared[k]}"
312						for k in cfg_keys
313					],
314				)
315				+ ")"
316			)
317		)
318
319		# add title and stuff
320		if not plot_only:
321			ax_.set_xlabel(self.param_key)
322			ax_.set_ylabel(self.analyze_func.__name__)
323			ax_.set_title(
324				f"{self.param_key} vs {self.analyze_func.__name__}\n{cfg_repr}",
325			)
326			ax_.grid(True)
327			ax_.legend(loc="center left")
328
329		# save and show
330		if save_path:
331			plt.savefig(save_path)
332
333		if show:
334			plt.show()
335
336		return ax_

result of a parameter sweep

SweepResult( configs: list[maze_dataset.MazeDatasetConfig], param_values: list[~ParamType], result_values: dict[str, typing.Sequence[~SweepReturnType]], param_key: str, analyze_func: Callable[[maze_dataset.MazeDatasetConfig], ~SweepReturnType])
param_values: list[~ParamType]
result_values: dict[str, typing.Sequence[~SweepReturnType]]
param_key: str
analyze_func: Callable[[maze_dataset.MazeDatasetConfig], ~SweepReturnType]
def summary( self) -> Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]:
116	def summary(self) -> JSONitem:
117		"human-readable and json-dumpable short summary of the result"
118		return {
119			"len(configs)": len(self.configs),
120			"len(param_values)": len(self.param_values),
121			"len(result_values)": len(self.result_values),
122			"param_key": self.param_key,
123			"analyze_func": self.analyze_func.__name__,
124		}

human-readable and json-dumpable short summary of the result

def save(self, path: str | pathlib.Path, z: zanj.zanj.ZANJ | None = None) -> None:
126	def save(self, path: str | Path, z: ZANJ | None = None) -> None:
127		"save to a file with zanj"
128		if z is None:
129			z = ZANJ()
130
131		z.save(self, path)

save to a file with zanj

@classmethod
def read( cls, path: str | pathlib.Path, z: zanj.zanj.ZANJ | None = None) -> SweepResult:
133	@classmethod
134	def read(cls, path: str | Path, z: ZANJ | None = None) -> "SweepResult":
135		"read from a file with zanj"
136		if z is None:
137			z = ZANJ()
138
139		return z.read(path)

read from a file with zanj
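
A sketch of a round trip through a `.zanj` file, assuming `result` is a `SweepResult` returned by `SweepResult.analyze` or `full_percolation_analysis` (the filename is illustrative):

```python
from maze_dataset.benchmark.config_sweep import SweepResult

result.save("percolation_sweep.zanj")
loaded: SweepResult = SweepResult.read("percolation_sweep.zanj")
assert loaded.summary() == result.summary()
```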

def configs_by_name( self) -> dict[str, maze_dataset.MazeDatasetConfig]:
141	def configs_by_name(self) -> dict[str, MazeDatasetConfig]:
142		"return configs by name"
143		return {cfg.name: cfg for cfg in self.configs}

return configs by name

def configs_by_key( self) -> dict[str, maze_dataset.MazeDatasetConfig]:
145	def configs_by_key(self) -> dict[str, MazeDatasetConfig]:
146		"return configs by the key used in `result_values`, which is the filename of the config"
147		return {cfg.to_fname(): cfg for cfg in self.configs}

return configs by the key used in result_values, which is the filename of the config

def configs_shared(self) -> dict[str, typing.Any]:
149	def configs_shared(self) -> dict[str, Any]:
150		"return key: value pairs that are shared across all configs"
151		# we know that the configs all have the same keys,
152		# so this way of doing it is fine
153		config_vals: dict[str, set[Any]] = dict()
154		for cfg in self.configs:
155			for k, v in cfg.serialize().items():
156				if k not in config_vals:
157					config_vals[k] = set()
158				config_vals[k].add(json.dumps(v))
159
160		shared_vals: dict[str, Any] = dict()
161
162		cfg_ser: dict = self.configs[0].serialize()
163		for k, v in config_vals.items():
164			if len(v) == 1:
165				shared_vals[k] = cfg_ser[k]
166
167		return shared_vals

return key: value pairs that are shared across all configs

def configs_differing_keys(self) -> set[str]:
169	def configs_differing_keys(self) -> set[str]:
170		"return keys that differ across configs"
171		shared_vals: dict[str, Any] = self.configs_shared()
172		differing_keys: set[str] = set()
173
174		for k in MazeDatasetConfig.__dataclass_fields__:
175			if k not in shared_vals:
176				differing_keys.add(k)
177
178		return differing_keys

return keys that differ across configs

def configs_value_set(self, key: str) -> list[typing.Any]:
180	def configs_value_set(self, key: str) -> list[Any]:
181		"return a list of the unique values for a given key"
182		d: dict[str, Any] = {
183			json.dumps(json_serialize(getattr(cfg, key))): getattr(cfg, key)
184			for cfg in self.configs
185		}
186
187		return list(d.values())

return a list of the unique values for a given key

def get_where( self, key: str, val_check: Callable[[Any], bool]) -> SweepResult:
189	def get_where(self, key: str, val_check: Callable[[Any], bool]) -> "SweepResult":
 190		"get a subset of this `SweepResult` whose configs have `key` satisfying `val_check`"
191		configs_list: list[MazeDatasetConfig] = [
192			cfg for cfg in self.configs if val_check(getattr(cfg, key))
193		]
194		configs_keys: set[str] = {cfg.to_fname() for cfg in configs_list}
195		result_values: dict[str, Sequence[SweepReturnType]] = {
196			k: self.result_values[k] for k in configs_keys
197		}
198
199		return SweepResult(
200			configs=configs_list,
201			param_values=self.param_values,
202			result_values=result_values,
203			param_key=self.param_key,
204			analyze_func=self.analyze_func,
205		)

get a subset of this SweepResult whose configs have key satisfying val_check
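
For instance, mirroring the filtering done in `plot_grouped` below and assuming `result` is an existing `SweepResult`:

```python
# keep only configs that used the DFS-percolation generator
dfs_only = result.get_where(
    "maze_ctor",
    lambda ctor: ctor.__name__ == "gen_dfs_percolation",
)

# keep only configs with at least 100 mazes
large_only = result.get_where("n_mazes", lambda n: n >= 100)
```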

@classmethod
def analyze( cls, configs: list[maze_dataset.MazeDatasetConfig], param_values: list[~ParamType], param_key: str, analyze_func: Callable[[maze_dataset.MazeDatasetConfig], ~SweepReturnType], parallel: bool | int = False, **kwargs) -> SweepResult:
207	@classmethod
208	def analyze(
209		cls,
210		configs: list[MazeDatasetConfig],
211		param_values: list[ParamType],
212		param_key: str,
213		analyze_func: Callable[[MazeDatasetConfig], SweepReturnType],
214		parallel: bool | int = False,
215		**kwargs,
216	) -> "SweepResult":
217		"""Analyze success rate of maze generation for different percolation values
218
219		# Parameters:
220		- `configs : list[MazeDatasetConfig]`
221		configs to try
 222		- `param_values : list[ParamType]`
 223		list of values to try for `param_key`
224
225		# Returns:
226		- `SweepResult`
227		"""
228		n_pvals: int = len(param_values)
229
230		result_values_list: list[float] = run_maybe_parallel(
231			# TYPING: error: Argument "func" to "run_maybe_parallel" has incompatible type "partial[list[SweepReturnType]]"; expected "Callable[[MazeDatasetConfig], float]"  [arg-type]
232			func=functools.partial(  # type: ignore[arg-type]
233				sweep,
234				param_values=param_values,
235				param_key=param_key,
236				analyze_func=analyze_func,
237			),
238			iterable=configs,
239			keep_ordered=True,
240			parallel=parallel,
241			pbar_kwargs=dict(total=len(configs)),
242			**kwargs,
243		)
244		result_values: dict[str, Float[np.ndarray, n_pvals]] = {
245			cfg.to_fname(): np.array(res)
246			for cfg, res in zip(configs, result_values_list, strict=False)
247		}
248		return cls(
249			configs=configs,
250			param_values=param_values,
251			# TYPING: error: Argument "result_values" to "SweepResult" has incompatible type "dict[str, ndarray[Any, Any]]"; expected "dict[str, Sequence[SweepReturnType]]"  [arg-type]
252			result_values=result_values,  # type: ignore[arg-type]
253			param_key=param_key,
254			analyze_func=analyze_func,
255		)

Analyze success rate of maze generation for different percolation values

Parameters:

  • configs : list[MazeDatasetConfig] configs to try
  • param_values : list[ParamType] list of values to try for param_key

Returns:

  • SweepResult
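
A sketch of building a `SweepResult` directly, assuming `configs` is a list of `MazeDatasetConfig` objects like the ones constructed in `full_percolation_analysis` below:

```python
import numpy as np

from maze_dataset.benchmark.config_sweep import SweepResult, dataset_success_fraction

result: SweepResult = SweepResult.analyze(
    configs=configs,
    param_values=np.linspace(0.0, 1.0, 11).tolist(),
    param_key="maze_ctor_kwargs.p",
    analyze_func=dataset_success_fraction,
    parallel=False,  # serial here; see muutils.parallel.run_maybe_parallel for parallel options
)
print(result.summary())
```
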
def plot( self, save_path: str | None = None, cfg_keys: list[str] | None = None, cmap_name: str | None = 'viridis', plot_only: bool = False, show: bool = True, ax: matplotlib.axes._axes.Axes | None = None) -> matplotlib.axes._axes.Axes:
257	def plot(
258		self,
259		save_path: str | None = None,
260		cfg_keys: list[str] | None = None,
261		cmap_name: str | None = "viridis",
262		plot_only: bool = False,
263		show: bool = True,
264		ax: plt.Axes | None = None,
265	) -> plt.Axes:
266		"""Plot the results of percolation analysis"""
267		# set up figure
268		if not ax:
269			fig: plt.Figure
270			ax_: plt.Axes
271			fig, ax_ = plt.subplots(1, 1, figsize=(22, 10))
272		else:
273			ax_ = ax
274
275		# plot
276		cmap = plt.get_cmap(cmap_name)
277		n_cfgs: int = len(self.result_values)
278		for i, (ep_cfg_name, result_values) in enumerate(
279			sorted(
280				self.result_values.items(),
281				# HACK: sort by grid size
282				#                 |--< name of config
283				#                 |    |-----------< gets 'g{n}'
284				#                 |    |            |--< gets '{n}'
285				#                 |    |            |
286				key=lambda x: int(x[0].split("-")[0][1:]),
287			),
288		):
289			ax_.plot(
290				# TYPING: error: Argument 1 to "plot" of "Axes" has incompatible type "list[ParamType]"; expected "float | Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes] | str"  [arg-type]
291				self.param_values,  # type: ignore[arg-type]
292				# TYPING: error: Argument 2 to "plot" of "Axes" has incompatible type "Sequence[SweepReturnType]"; expected "float | Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes] | str"  [arg-type]
293				result_values,  # type: ignore[arg-type]
294				".-",
295				label=self.configs_by_key()[ep_cfg_name].name,
296				color=cmap((i + 0.5) / (n_cfgs - 0.5)),
297			)
298
299		# repr of config
300		cfg_shared: dict = self.configs_shared()
301		cfg_repr: str = (
302			str(cfg_shared)
303			if cfg_keys is None
304			else (
305				"MazeDatasetConfig("
306				+ ", ".join(
307					[
308						f"{k}={cfg_shared[k].__name__}"
309						# TYPING: error: Argument 2 to "isinstance" has incompatible type "<typing special form>"; expected "_ClassInfo"  [arg-type]
310						if isinstance(cfg_shared[k], Callable)  # type: ignore[arg-type]
311						else f"{k}={cfg_shared[k]}"
312						for k in cfg_keys
313					],
314				)
315				+ ")"
316			)
317		)
318
319		# add title and stuff
320		if not plot_only:
321			ax_.set_xlabel(self.param_key)
322			ax_.set_ylabel(self.analyze_func.__name__)
323			ax_.set_title(
324				f"{self.param_key} vs {self.analyze_func.__name__}\n{cfg_repr}",
325			)
326			ax_.grid(True)
327			ax_.legend(loc="center left")
328
329		# save and show
330		if save_path:
331			plt.savefig(save_path)
332
333		if show:
334			plt.show()
335
336		return ax_

Plot the results of percolation analysis
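
A sketch of plotting a `SweepResult` on its own, assuming `result` comes from `full_percolation_analysis` (whose config names give the `g{n}-...` prefix that the legend sorting expects); `cfg_keys` restricts the config repr shown in the title, matching how `plot_grouped` calls it:

```python
ax = result.plot(
    cfg_keys=["n_mazes", "endpoint_kwargs"],
    save_path="percolation_success.svg",  # illustrative output path
    show=False,
)
```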

Inherited Members
muutils.json_serialize.serializable_dataclass.SerializableDataclass
serialize
load
validate_fields_types
validate_field_type
diff
update_from_nested_dict
DEFAULT_ENDPOINT_KWARGS: list[tuple[str, dict]] = [('any', {'deadend_start': False, 'deadend_end': False, 'except_on_no_valid_endpoint': False}), ('deadends', {'deadend_start': True, 'deadend_end': True, 'endpoints_not_equal': False, 'except_on_no_valid_endpoint': False}), ('deadends_unique', {'deadend_start': True, 'deadend_end': True, 'endpoints_not_equal': True, 'except_on_no_valid_endpoint': False})]
def endpoint_kwargs_to_name(ep_kwargs: dict) -> str:
365def endpoint_kwargs_to_name(ep_kwargs: dict) -> str:
366	"""convert endpoint kwargs options to a human-readable name"""
367	if ep_kwargs.get("deadend_start", False) or ep_kwargs.get("deadend_end", False):
368		if ep_kwargs.get("endpoints_not_equal", False):
369			return "deadends_unique"
370		else:
371			return "deadends"
372	else:
373		return "any"

convert endpoint kwargs options to a human-readable name
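
The mapping inverts the names used in `DEFAULT_ENDPOINT_KWARGS`, for example:

```python
from maze_dataset.benchmark.config_sweep import endpoint_kwargs_to_name

assert endpoint_kwargs_to_name(
    dict(deadend_start=True, deadend_end=True, endpoints_not_equal=True),
) == "deadends_unique"
assert endpoint_kwargs_to_name(dict(deadend_start=True, deadend_end=True)) == "deadends"
assert endpoint_kwargs_to_name(dict(except_on_no_valid_endpoint=False)) == "any"
```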

def full_percolation_analysis( n_mazes: int, p_val_count: int, grid_sizes: list[int], ep_kwargs: list[tuple[str, dict]] | None = None, generators: Sequence[Callable] = (<function LatticeMazeGenerators.gen_percolation>, <function LatticeMazeGenerators.gen_dfs_percolation>), save_dir: pathlib.Path = PosixPath('../docs/benchmarks/percolation_fractions'), parallel: bool | int = False, **analyze_kwargs) -> SweepResult:
376def full_percolation_analysis(
377	n_mazes: int,
378	p_val_count: int,
379	grid_sizes: list[int],
380	ep_kwargs: list[tuple[str, dict]] | None = None,
381	generators: Sequence[Callable] = (
382		LatticeMazeGenerators.gen_percolation,
383		LatticeMazeGenerators.gen_dfs_percolation,
384	),
385	save_dir: Path = Path("../docs/benchmarks/percolation_fractions"),
386	parallel: bool | int = False,
387	**analyze_kwargs,
388) -> SweepResult:
389	"run the full analysis of how percolation affects maze generation success"
390	if ep_kwargs is None:
391		ep_kwargs = DEFAULT_ENDPOINT_KWARGS
392
393	# configs
394	configs: list[MazeDatasetConfig] = list()
395
 396	# TODO: B007 noqa'd because we don't use `ep_kw_name` or `gf_idx`
397	for ep_kw_name, ep_kw in ep_kwargs:  # noqa: B007
398		for gf_idx, gen_func in enumerate(generators):  # noqa: B007
399			configs.extend(
400				[
401					MazeDatasetConfig(
402						name=f"g{grid_n}-{gen_func.__name__.removeprefix('gen_').removesuffix('olation')}",
403						grid_n=grid_n,
404						n_mazes=n_mazes,
405						maze_ctor=gen_func,
406						maze_ctor_kwargs=dict(p=float("nan")),
407						endpoint_kwargs=ep_kw,
408					)
409					for grid_n in grid_sizes
410				],
411			)
412
413	# get results
414	result: SweepResult = SweepResult.analyze(
415		configs=configs,  # type: ignore[misc]
416		# TYPING: error: Argument "param_values" to "analyze" of "SweepResult" has incompatible type "float | list[float] | list[list[float]] | list[list[list[Any]]]"; expected "list[Any]"  [arg-type]
417		param_values=np.linspace(0.0, 1.0, p_val_count).tolist(),  # type: ignore[arg-type]
418		param_key="maze_ctor_kwargs.p",
419		analyze_func=dataset_success_fraction,
420		parallel=parallel,
421		**analyze_kwargs,
422	)
423
424	# save the result
425	results_path: Path = (
426		save_dir / f"result-n{n_mazes}-c{len(configs)}-p{p_val_count}.zanj"
427	)
428	print(f"Saving results to {results_path.as_posix()}")
429	result.save(results_path)
430
431	return result

run the full analysis of how percolation affects maze generation success
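
A small end-to-end run might look like the sketch below; the values are chosen to keep it quick, and `save_dir` overrides the repository-relative default:

```python
from pathlib import Path

from maze_dataset.benchmark.config_sweep import SweepResult, full_percolation_analysis

out_dir = Path("percolation_fractions")
out_dir.mkdir(parents=True, exist_ok=True)  # a .zanj results file is written here

result: SweepResult = full_percolation_analysis(
    n_mazes=64,
    p_val_count=11,
    grid_sizes=[4, 8],
    save_dir=out_dir,
    parallel=False,
)
```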

def plot_grouped( results: SweepResult, predict_fn: Optional[Callable[[maze_dataset.MazeDatasetConfig], float]] = None, prediction_density: int = 50, save_dir: pathlib.Path | None = None, show: bool = True, logy: bool = False) -> None:
439def plot_grouped(  # noqa: C901
440	results: SweepResult,
441	predict_fn: Callable[[MazeDatasetConfig], float] | None = None,
442	prediction_density: int = 50,
443	save_dir: Path | None = None,
444	show: bool = True,
445	logy: bool = False,
446) -> None:
447	"""Plot grouped sweep percolation value results for each distinct `endpoint_kwargs` in the configs
448
449	with separate colormaps for each maze generator function
450
451	# Parameters:
452	- `results : SweepResult`
453		The sweep results to plot
454	- `predict_fn : Callable[[MazeDatasetConfig], float] | None`
455		Optional function that predicts success rate from a config. If provided, will plot predictions as dashed lines.
456	- `prediction_density : int`
457		Number of points to use for prediction curves (default: 50)
458	- `save_dir : Path | None`
459		Directory to save plots (defaults to `None`, meaning no saving)
460	- `show : bool`
461		Whether to display the plots (defaults to `True`)
462
463	# Usage:
464	```python
 465	>>> result = full_percolation_analysis(n_mazes=100, p_val_count=11, grid_sizes=[8, 16])
466	>>> plot_grouped(result, save_dir=Path("./plots"), show=False)
467	```
468	"""
469	# groups
470	endpoint_kwargs_set: list[dict] = results.configs_value_set("endpoint_kwargs")  # type: ignore[assignment]
471	generator_funcs_names: list[str] = list(
472		{cfg.maze_ctor.__name__ for cfg in results.configs},
473	)
474
475	# if predicting, create denser p values
476	if predict_fn is not None:
477		p_dense = np.linspace(0.0, 1.0, prediction_density)
478
479	# separate plot for each set of endpoint kwargs
480	for ep_kw in endpoint_kwargs_set:
481		results_epkw: SweepResult = results.get_where(
482			"endpoint_kwargs",
483			functools.partial(_is_eq, b=ep_kw),
484			# lambda x: x == ep_kw,
485		)
486		shared_keys: set[str] = set(results_epkw.configs_shared().keys())
487		cfg_keys: set[str] = shared_keys.intersection({"n_mazes", "endpoint_kwargs"})
488		fig, ax = plt.subplots(1, 1, figsize=(22, 10))
489		for gf_idx, gen_func in enumerate(generator_funcs_names):
490			results_filtered: SweepResult = results_epkw.get_where(
491				"maze_ctor",
492				# HACK: big hassle to do this without a lambda, is it really that bad?
493				lambda x: x.__name__ == gen_func,  # noqa: B023
494			)
495			if len(results_filtered.configs) < 1:
496				warnings.warn(
497					f"No results for {gen_func} and {ep_kw}. Skipping.",
498				)
499				continue
500
501			cmap_name = "Reds" if gf_idx == 0 else "Blues"
502			cmap = plt.get_cmap(cmap_name)
503
504			# Plot actual results
505			ax = results_filtered.plot(
506				cfg_keys=list(cfg_keys),
507				ax=ax,
508				show=False,
509				cmap_name=cmap_name,
510			)
511			if logy:
512				ax.set_yscale("log")
513
514			# Plot predictions if function provided
515			if predict_fn is not None:
516				for cfg_idx, cfg in enumerate(results_filtered.configs):
517					predictions = []
518					for p in p_dense:
519						cfg_temp = MazeDatasetConfig.load(cfg.serialize())
520						cfg_temp.maze_ctor_kwargs["p"] = p
521						predictions.append(predict_fn(cfg_temp))
522
523					# Get the same color as the actual data
524					n_cfgs: int = len(results_filtered.configs)
525					color = cmap((cfg_idx + 0.5) / (n_cfgs - 0.5))
526
527					# Plot prediction as dashed line
528					ax.plot(p_dense, predictions, "--", color=color, alpha=0.8)
529
530		# save and show
531		if save_dir:
532			save_path: Path = save_dir / f"ep_{endpoint_kwargs_to_name(ep_kw)}.svg"
533			print(f"Saving plot to {save_path.as_posix()}")
534			save_path.parent.mkdir(exist_ok=True, parents=True)
535			plt.savefig(save_path)
536
537		if show:
538			plt.show()

Plot grouped sweep percolation value results for each distinct endpoint_kwargs in the configs

with separate colormaps for each maze generator function

Parameters:

  • results : SweepResult The sweep results to plot
  • predict_fn : Callable[[MazeDatasetConfig], float] | None Optional function that predicts success rate from a config. If provided, will plot predictions as dashed lines.
  • prediction_density : int Number of points to use for prediction curves (default: 50)
  • save_dir : Path | None Directory to save plots (defaults to None, meaning no saving)
  • show : bool Whether to display the plots (defaults to True)

Usage:

>>> result = full_percolation_analysis(n_mazes=100, p_val_count=11, grid_sizes=[8, 16])
>>> plot_grouped(result, save_dir=Path("./plots"), show=False)
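
As a sketch of the `predict_fn` overlay, `toy_predict` below is a hypothetical placeholder (not a model shipped with the library) that maps a config's percolation probability to an expected success fraction; `result` is assumed to come from `full_percolation_analysis`:

```python
from pathlib import Path

from maze_dataset import MazeDatasetConfig
from maze_dataset.benchmark.config_sweep import plot_grouped

def toy_predict(cfg: MazeDatasetConfig) -> float:
    """hypothetical placeholder: pretend success decays linearly with p"""
    p: float = float(cfg.maze_ctor_kwargs.get("p", 0.0))
    return max(0.0, 1.0 - p)

plot_grouped(result, predict_fn=toy_predict, save_dir=Path("./plots"), show=False)
```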