docs for maze-dataset v1.3.0
View Source on GitHub

maze_dataset.generation.generators

generation functions have signature (grid_shape: Coord, **kwargs) -> LatticeMaze and are methods in LatticeMazeGenerators


  1"""generation functions have signature `(grid_shape: Coord, **kwargs) -> LatticeMaze` and are methods in `LatticeMazeGenerators`"""
  2
  3import random
  4import warnings
  5from typing import Any, Callable
  6
  7import numpy as np
  8from jaxtyping import Bool
  9
 10from maze_dataset.constants import CoordArray, CoordTup
 11from maze_dataset.generation.seed import GLOBAL_SEED
 12from maze_dataset.maze import ConnectionList, Coord, LatticeMaze, SolvedMaze
 13from maze_dataset.maze.lattice_maze import NEIGHBORS_MASK, _fill_edges_with_walls
 14
 15numpy_rng = np.random.default_rng(GLOBAL_SEED)
 16random.seed(GLOBAL_SEED)
 17
 18
 19def _random_start_coord(
 20	grid_shape: Coord,
 21	start_coord: Coord | CoordTup | None,
 22) -> Coord:
 23	"picking a random start coord within the bounds of `grid_shape` if none is provided"
 24	start_coord_: Coord
 25	if start_coord is None:
 26		start_coord_ = np.random.randint(
 27			0,  # lower bound
 28			np.maximum(grid_shape - 1, 1),  # upper bound (at least 1)
 29			size=len(grid_shape),  # dimensionality
 30		)
 31	else:
 32		start_coord_ = np.array(start_coord)
 33
 34	return start_coord_
 35
 36
 37def get_neighbors_in_bounds(
 38	coord: Coord,
 39	grid_shape: Coord,
 40) -> CoordArray:
 41	"get all neighbors of a coordinate that are within the bounds of the grid"
 42	# get all neighbors
 43	neighbors: CoordArray = coord + NEIGHBORS_MASK
 44
 45	# filter neighbors by being within grid bounds
 46	neighbors_in_bounds: CoordArray = neighbors[
 47		(neighbors >= 0).all(axis=1) & (neighbors < grid_shape).all(axis=1)
 48	]
 49
 50	return neighbors_in_bounds
 51
 52
 53class LatticeMazeGenerators:
 54	"""namespace for lattice maze generation algorithms"""
 55
 56	@staticmethod
 57	def gen_dfs(
 58		grid_shape: Coord | CoordTup,
 59		lattice_dim: int = 2,
 60		accessible_cells: float | None = None,
 61		max_tree_depth: float | None = None,
 62		do_forks: bool = True,
 63		randomized_stack: bool = False,
 64		start_coord: Coord | None = None,
 65	) -> LatticeMaze:
 66		"""generate a lattice maze using depth first search, iterative
 67
 68		# Arguments
 69		- `grid_shape: Coord`: the shape of the grid
 70		- `lattice_dim: int`: the dimension of the lattice
 71			(default: `2`)
 72		- `accessible_cells: int | float |None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells**
 73			(default: `None`)
 74		- `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * accessible_cells`. if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape**
 75			(default: `None`)
 76		- `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will be have no forks and will be a simple hallway.
 77		- `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.
 78
 79		# algorithm
 80		1. Choose the initial cell, mark it as visited and push it to the stack
 81		2. While the stack is not empty
 82			1. Pop a cell from the stack and make it a current cell
 83			2. If the current cell has any neighbours which have not been visited
 84				1. Push the current cell to the stack
 85				2. Choose one of the unvisited neighbours
 86				3. Remove the wall between the current cell and the chosen cell
 87				4. Mark the chosen cell as visited and push it to the stack
 88		"""
 89		# Default values if no constraints have been passed
 90		grid_shape_: Coord = np.array(grid_shape)
 91		n_total_cells: int = int(np.prod(grid_shape_))
 92
 93		n_accessible_cells: int
 94		if accessible_cells is None:
 95			n_accessible_cells = n_total_cells
 96		elif isinstance(accessible_cells, float):
 97			assert accessible_cells <= 1, (
 98				f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
 99			)
100
101			n_accessible_cells = int(accessible_cells * n_total_cells)
102		else:
103			assert isinstance(accessible_cells, int)
104			n_accessible_cells = accessible_cells
105
106		if max_tree_depth is None:
107			max_tree_depth = (
108				2 * n_total_cells
109			)  # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.
110		elif isinstance(max_tree_depth, float):
111			assert max_tree_depth <= 1, (
112				f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
113			)
114
115			max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))
116
117		# choose a random start coord
118		start_coord = _random_start_coord(grid_shape_, start_coord)
119
120		# initialize the maze with no connections
121		connection_list: ConnectionList = np.zeros(
122			(lattice_dim, grid_shape_[0], grid_shape_[1]),
123			dtype=np.bool_,
124		)
125
126		# initialize the stack with the target coord
127		visited_cells: set[tuple[int, int]] = set()
128		visited_cells.add(tuple(start_coord))  # this wasnt a bug after all lol
129		stack: list[Coord] = [start_coord]
130
131		# initialize tree_depth_counter
132		current_tree_depth: int = 1
133
134		# loop until the stack is empty or n_connected_cells is reached
135		while stack and (len(visited_cells) < n_accessible_cells):
136			# get the current coord from the stack
137			current_coord: Coord
138			if randomized_stack:
139				# we dont care about S311 because this isnt security related
140				current_coord = stack.pop(random.randint(0, len(stack) - 1))  # noqa: S311
141			else:
142				current_coord = stack.pop()
143
144			# filter neighbors by being within grid bounds and being unvisited
145			unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
146				(neighbor, delta)
147				for neighbor, delta in zip(
148					current_coord + NEIGHBORS_MASK,
149					NEIGHBORS_MASK,
150					strict=False,
151				)
152				if (
153					(tuple(neighbor) not in visited_cells)
154					and (0 <= neighbor[0] < grid_shape_[0])
155					and (0 <= neighbor[1] < grid_shape_[1])
156				)
157			]
158
159			# don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)
160			if unvisited_neighbors_deltas and (
161				current_tree_depth <= max_tree_depth / 2
162			):
163				# if we want a maze without forks, simply don't add the current coord back to the stack
164				if do_forks and (len(unvisited_neighbors_deltas) > 1):
165					stack.append(current_coord)
166
167				# choose one of the unvisited neighbors
168				# we dont care about S311 because this isn't security related
169				chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)  # noqa: S311
170
171				# add connection
172				dim: int = int(np.argmax(np.abs(delta)))
173				# if positive, down/right from current coord
174				# if negative, up/left from current coord (down/right from neighbor)
175				clist_node: Coord = (
176					current_coord if (delta.sum() > 0) else chosen_neighbor
177				)
178				connection_list[dim, clist_node[0], clist_node[1]] = True
179
180				# add to visited cells and stack
181				visited_cells.add(tuple(chosen_neighbor))
182				stack.append(chosen_neighbor)
183
184				# Update current tree depth
185				current_tree_depth += 1
186			else:
187				current_tree_depth -= 1
188
189		return LatticeMaze(
190			connection_list=connection_list,
191			generation_meta=dict(
192				func_name="gen_dfs",
193				grid_shape=grid_shape_,
194				start_coord=start_coord,
195				n_accessible_cells=int(n_accessible_cells),
196				max_tree_depth=int(max_tree_depth),
197				# oh my god this took so long to track down. its almost 5am and I've spent like 2 hours on this bug
198				# it was checking that len(visited_cells) == n_accessible_cells, but this means that the maze is
199				# treated as fully connected even when it is most certainly not, causing solving the maze to break
200				fully_connected=bool(len(visited_cells) == n_total_cells),
201				visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
202			),
203		)
204
205	@staticmethod
206	def gen_prim(
207		grid_shape: Coord | CoordTup,
208		lattice_dim: int = 2,
209		accessible_cells: float | None = None,
210		max_tree_depth: float | None = None,
211		do_forks: bool = True,
212		start_coord: Coord | None = None,
213	) -> LatticeMaze:
214		"(broken!) generate a lattice maze using Prim's algorithm"
215		warnings.warn(
216			"gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
217		)
218		return LatticeMazeGenerators.gen_dfs(
219			grid_shape=grid_shape,
220			lattice_dim=lattice_dim,
221			accessible_cells=accessible_cells,
222			max_tree_depth=max_tree_depth,
223			do_forks=do_forks,
224			start_coord=start_coord,
225			randomized_stack=True,
226		)
227
228	@staticmethod
229	def gen_wilson(
230		grid_shape: Coord | CoordTup,
231		**kwargs,
232	) -> LatticeMaze:
233		"""Generate a lattice maze using Wilson's algorithm.
234
235		# Algorithm
236		Wilson's algorithm generates an unbiased (random) maze
237		sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
238		acyclic and all cells are part of a unique connected space.
239		https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm
240		"""
241		assert not kwargs, (
242			f"gen_wilson does not take any additional arguments, got {kwargs = }"
243		)
244
245		grid_shape_: Coord = np.array(grid_shape)
246
247		# Initialize grid and visited cells
248		connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
249		visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)
250
251		# Choose a random cell and mark it as visited
252		start_coord: Coord = _random_start_coord(grid_shape_, None)
253		visited[start_coord[0], start_coord[1]] = True
254		del start_coord
255
256		while not visited.all():
257			# Perform loop-erased random walk from another random cell
258
259			# Choose walk_start only from unvisited cells
260			unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
261			walk_start: Coord = unvisited_coords[
262				np.random.choice(unvisited_coords.shape[0])
263			]
264
265			# Perform the random walk
266			path: list[Coord] = [walk_start]
267			current: Coord = walk_start
268
269			# exit the loop once the current path hits a visited cell
270			while not visited[current[0], current[1]]:
271				# find a valid neighbor (one always exists on a lattice)
272				neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_)
273				next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]
274
275				# Check for loop
276				loop_exit: int | None = None
277				for i, p in enumerate(path):
278					if np.array_equal(next_cell, p):
279						loop_exit = i
280						break
281
282				# erase the loop, or continue the walk
283				if loop_exit is not None:
284					# this removes everything after and including the loop start
285					path = path[: loop_exit + 1]
286					# reset current cell to end of path
287					current = path[-1]
288				else:
289					path.append(next_cell)
290					current = next_cell
291
292			# Add the path to the maze
293			for i in range(len(path) - 1):
294				c_1: Coord = path[i]
295				c_2: Coord = path[i + 1]
296
297				# find the dimension of the connection
298				delta: Coord = c_2 - c_1
299				dim: int = int(np.argmax(np.abs(delta)))
300
301				# if positive, down/right from current coord
302				# if negative, up/left from current coord (down/right from neighbor)
303				clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
304				connection_list[dim, clist_node[0], clist_node[1]] = True
305				visited[c_1[0], c_1[1]] = True
306				# we dont add c_2 because the last c_2 will have already been visited
307
308		return LatticeMaze(
309			connection_list=connection_list,
310			generation_meta=dict(
311				func_name="gen_wilson",
312				grid_shape=grid_shape_,
313				fully_connected=True,
314			),
315		)
316
317	@staticmethod
318	def gen_percolation(
319		grid_shape: Coord | CoordTup,
320		p: float = 0.4,
321		lattice_dim: int = 2,
322		start_coord: Coord | None = None,
323	) -> LatticeMaze:
324		"""generate a lattice maze using simple percolation
325
326		note that p in the range (0.4, 0.7) gives the most interesting mazes
327
328		# Arguments
329		- `grid_shape: Coord`: the shape of the grid
330		- `lattice_dim: int`: the dimension of the lattice (default: `2`)
331		- `p: float`: the probability of a cell being accessible (default: `0.5`)
332		- `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)
333		"""
334		assert p >= 0 and p <= 1, f"p must be between 0 and 1, got {p}"  # noqa: PT018
335		grid_shape_: Coord = np.array(grid_shape)
336
337		start_coord = _random_start_coord(grid_shape_, start_coord)
338
339		connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p
340
341		connection_list = _fill_edges_with_walls(connection_list)
342
343		output: LatticeMaze = LatticeMaze(
344			connection_list=connection_list,
345			generation_meta=dict(
346				func_name="gen_percolation",
347				grid_shape=grid_shape_,
348				percolation_p=p,
349				start_coord=start_coord,
350			),
351		)
352
353		# generation_meta is sometimes None, but not here since we just made it a dict above
354		output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
355			start_coord,
356		)
357
358		return output
359
360	@staticmethod
361	def gen_dfs_percolation(
362		grid_shape: Coord | CoordTup,
363		p: float = 0.4,
364		lattice_dim: int = 2,
365		accessible_cells: int | None = None,
366		max_tree_depth: int | None = None,
367		start_coord: Coord | None = None,
368	) -> LatticeMaze:
369		"""dfs and then percolation (adds cycles)"""
370		grid_shape_: Coord = np.array(grid_shape)
371		start_coord = _random_start_coord(grid_shape_, start_coord)
372
373		# generate initial maze via dfs
374		maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
375			grid_shape=grid_shape_,
376			lattice_dim=lattice_dim,
377			accessible_cells=accessible_cells,
378			max_tree_depth=max_tree_depth,
379			start_coord=start_coord,
380		)
381
382		# percolate
383		connection_list_perc: np.ndarray = (
384			np.random.rand(*maze.connection_list.shape) < p
385		)
386		connection_list_perc = _fill_edges_with_walls(connection_list_perc)
387
388		maze.__dict__["connection_list"] = np.logical_or(
389			maze.connection_list,
390			connection_list_perc,
391		)
392
393		# generation_meta is sometimes None, but not here since we just made it a dict above
394		maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
395		maze.generation_meta["percolation_p"] = p  # type: ignore[index]
396		maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
397			start_coord,
398		)
399
400		return maze
401
402
# cant automatically populate this because it messes with pickling :(
# every generator takes `grid_shape` positionally plus generator-specific
# keyword arguments, so the value type is `Callable[..., LatticeMaze]`
# (the previous `Callable[[Coord | CoordTup, Any], ...]` described a
# two-positional-argument function and forced `type: ignore`s)
GENERATORS_MAP: dict[str, Callable[..., LatticeMaze]] = {
	"gen_dfs": LatticeMazeGenerators.gen_dfs,
	"gen_wilson": LatticeMazeGenerators.gen_wilson,
	"gen_percolation": LatticeMazeGenerators.gen_percolation,
	"gen_dfs_percolation": LatticeMazeGenerators.gen_dfs_percolation,
	"gen_prim": LatticeMazeGenerators.gen_prim,
}
"mapping of generator names to generator functions, useful for loading `MazeDatasetConfig`"
409	# gen_wilson takes no kwargs and we check that the kwargs are empty
410	# but mypy doesnt like this, `Any` != `KwArg(Any)`
411	"gen_wilson": LatticeMazeGenerators.gen_wilson,  # type: ignore[dict-item]
412	"gen_percolation": LatticeMazeGenerators.gen_percolation,
413	"gen_dfs_percolation": LatticeMazeGenerators.gen_dfs_percolation,
414	"gen_prim": LatticeMazeGenerators.gen_prim,
415}
416"mapping of generator names to generator functions, useful for loading `MazeDatasetConfig`"
417
418_GENERATORS_PERCOLATED: list[str] = [
419	"gen_percolation",
420	"gen_dfs_percolation",
421]
422"""list of generator names that generate percolated mazes
423we use this to figure out the expected success rate, since depending on the endpoint kwargs this might fail
424this variable is primarily used in `MazeDatasetConfig._to_ps_array` and `MazeDatasetConfig._from_ps_array`
425"""
426
427
428def get_maze_with_solution(
429	gen_name: str,
430	grid_shape: Coord | CoordTup,
431	maze_ctor_kwargs: dict | None = None,
432) -> SolvedMaze:
433	"helper function to get a maze already with a solution"
434	if maze_ctor_kwargs is None:
435		maze_ctor_kwargs = dict()
436	# TYPING: error: Too few arguments  [call-arg]
437	# not sure why this is happening -- doesnt recognize the kwargs?
438	maze: LatticeMaze = GENERATORS_MAP[gen_name](grid_shape, **maze_ctor_kwargs)  # type: ignore[call-arg]
439	solution: CoordArray = np.array(maze.generate_random_path())
440	return SolvedMaze.from_lattice_maze(lattice_maze=maze, solution=solution)

numpy_rng = Generator(PCG64) at 0x7D1A85F4DB60
def get_neighbors_in_bounds( coord: jaxtyping.Int8[ndarray, 'row_col=2'], grid_shape: jaxtyping.Int8[ndarray, 'row_col=2']) -> jaxtyping.Int8[ndarray, 'coord row_col=2']:
def get_neighbors_in_bounds(
	coord: Coord,
	grid_shape: Coord,
) -> CoordArray:
	"""return the neighbors of `coord` (one candidate per row of `NEIGHBORS_MASK`) that lie inside the grid

	a candidate is kept when every component is `>= 0` and strictly less than
	the matching entry of `grid_shape`; candidates keep their `NEIGHBORS_MASK` order
	"""
	candidates: CoordArray = NEIGHBORS_MASK + coord
	above_lower: np.ndarray = np.all(candidates >= 0, axis=1)
	below_upper: np.ndarray = np.all(candidates < grid_shape, axis=1)
	return candidates[np.logical_and(above_lower, below_upper)]

get all neighbors of a coordinate that are within the bounds of the grid

class LatticeMazeGenerators:
 54class LatticeMazeGenerators:
 55	"""namespace for lattice maze generation algorithms"""
 56
 57	@staticmethod
 58	def gen_dfs(
 59		grid_shape: Coord | CoordTup,
 60		lattice_dim: int = 2,
 61		accessible_cells: float | None = None,
 62		max_tree_depth: float | None = None,
 63		do_forks: bool = True,
 64		randomized_stack: bool = False,
 65		start_coord: Coord | None = None,
 66	) -> LatticeMaze:
 67		"""generate a lattice maze using depth first search, iterative
 68
 69		# Arguments
 70		- `grid_shape: Coord`: the shape of the grid
 71		- `lattice_dim: int`: the dimension of the lattice
 72			(default: `2`)
 73		- `accessible_cells: int | float |None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells**
 74			(default: `None`)
 75		- `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * accessible_cells`. if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape**
 76			(default: `None`)
 77		- `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will be have no forks and will be a simple hallway.
 78		- `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.
 79
 80		# algorithm
 81		1. Choose the initial cell, mark it as visited and push it to the stack
 82		2. While the stack is not empty
 83			1. Pop a cell from the stack and make it a current cell
 84			2. If the current cell has any neighbours which have not been visited
 85				1. Push the current cell to the stack
 86				2. Choose one of the unvisited neighbours
 87				3. Remove the wall between the current cell and the chosen cell
 88				4. Mark the chosen cell as visited and push it to the stack
 89		"""
 90		# Default values if no constraints have been passed
 91		grid_shape_: Coord = np.array(grid_shape)
 92		n_total_cells: int = int(np.prod(grid_shape_))
 93
 94		n_accessible_cells: int
 95		if accessible_cells is None:
 96			n_accessible_cells = n_total_cells
 97		elif isinstance(accessible_cells, float):
 98			assert accessible_cells <= 1, (
 99				f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
100			)
101
102			n_accessible_cells = int(accessible_cells * n_total_cells)
103		else:
104			assert isinstance(accessible_cells, int)
105			n_accessible_cells = accessible_cells
106
107		if max_tree_depth is None:
108			max_tree_depth = (
109				2 * n_total_cells
110			)  # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.
111		elif isinstance(max_tree_depth, float):
112			assert max_tree_depth <= 1, (
113				f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
114			)
115
116			max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))
117
118		# choose a random start coord
119		start_coord = _random_start_coord(grid_shape_, start_coord)
120
121		# initialize the maze with no connections
122		connection_list: ConnectionList = np.zeros(
123			(lattice_dim, grid_shape_[0], grid_shape_[1]),
124			dtype=np.bool_,
125		)
126
127		# initialize the stack with the target coord
128		visited_cells: set[tuple[int, int]] = set()
129		visited_cells.add(tuple(start_coord))  # this wasnt a bug after all lol
130		stack: list[Coord] = [start_coord]
131
132		# initialize tree_depth_counter
133		current_tree_depth: int = 1
134
135		# loop until the stack is empty or n_connected_cells is reached
136		while stack and (len(visited_cells) < n_accessible_cells):
137			# get the current coord from the stack
138			current_coord: Coord
139			if randomized_stack:
140				# we dont care about S311 because this isnt security related
141				current_coord = stack.pop(random.randint(0, len(stack) - 1))  # noqa: S311
142			else:
143				current_coord = stack.pop()
144
145			# filter neighbors by being within grid bounds and being unvisited
146			unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
147				(neighbor, delta)
148				for neighbor, delta in zip(
149					current_coord + NEIGHBORS_MASK,
150					NEIGHBORS_MASK,
151					strict=False,
152				)
153				if (
154					(tuple(neighbor) not in visited_cells)
155					and (0 <= neighbor[0] < grid_shape_[0])
156					and (0 <= neighbor[1] < grid_shape_[1])
157				)
158			]
159
160			# don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)
161			if unvisited_neighbors_deltas and (
162				current_tree_depth <= max_tree_depth / 2
163			):
164				# if we want a maze without forks, simply don't add the current coord back to the stack
165				if do_forks and (len(unvisited_neighbors_deltas) > 1):
166					stack.append(current_coord)
167
168				# choose one of the unvisited neighbors
169				# we dont care about S311 because this isn't security related
170				chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)  # noqa: S311
171
172				# add connection
173				dim: int = int(np.argmax(np.abs(delta)))
174				# if positive, down/right from current coord
175				# if negative, up/left from current coord (down/right from neighbor)
176				clist_node: Coord = (
177					current_coord if (delta.sum() > 0) else chosen_neighbor
178				)
179				connection_list[dim, clist_node[0], clist_node[1]] = True
180
181				# add to visited cells and stack
182				visited_cells.add(tuple(chosen_neighbor))
183				stack.append(chosen_neighbor)
184
185				# Update current tree depth
186				current_tree_depth += 1
187			else:
188				current_tree_depth -= 1
189
190		return LatticeMaze(
191			connection_list=connection_list,
192			generation_meta=dict(
193				func_name="gen_dfs",
194				grid_shape=grid_shape_,
195				start_coord=start_coord,
196				n_accessible_cells=int(n_accessible_cells),
197				max_tree_depth=int(max_tree_depth),
198				# oh my god this took so long to track down. its almost 5am and I've spent like 2 hours on this bug
199				# it was checking that len(visited_cells) == n_accessible_cells, but this means that the maze is
200				# treated as fully connected even when it is most certainly not, causing solving the maze to break
201				fully_connected=bool(len(visited_cells) == n_total_cells),
202				visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
203			),
204		)
205
206	@staticmethod
207	def gen_prim(
208		grid_shape: Coord | CoordTup,
209		lattice_dim: int = 2,
210		accessible_cells: float | None = None,
211		max_tree_depth: float | None = None,
212		do_forks: bool = True,
213		start_coord: Coord | None = None,
214	) -> LatticeMaze:
215		"(broken!) generate a lattice maze using Prim's algorithm"
216		warnings.warn(
217			"gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
218		)
219		return LatticeMazeGenerators.gen_dfs(
220			grid_shape=grid_shape,
221			lattice_dim=lattice_dim,
222			accessible_cells=accessible_cells,
223			max_tree_depth=max_tree_depth,
224			do_forks=do_forks,
225			start_coord=start_coord,
226			randomized_stack=True,
227		)
228
229	@staticmethod
230	def gen_wilson(
231		grid_shape: Coord | CoordTup,
232		**kwargs,
233	) -> LatticeMaze:
234		"""Generate a lattice maze using Wilson's algorithm.
235
236		# Algorithm
237		Wilson's algorithm generates an unbiased (random) maze
238		sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
239		acyclic and all cells are part of a unique connected space.
240		https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm
241		"""
242		assert not kwargs, (
243			f"gen_wilson does not take any additional arguments, got {kwargs = }"
244		)
245
246		grid_shape_: Coord = np.array(grid_shape)
247
248		# Initialize grid and visited cells
249		connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
250		visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)
251
252		# Choose a random cell and mark it as visited
253		start_coord: Coord = _random_start_coord(grid_shape_, None)
254		visited[start_coord[0], start_coord[1]] = True
255		del start_coord
256
257		while not visited.all():
258			# Perform loop-erased random walk from another random cell
259
260			# Choose walk_start only from unvisited cells
261			unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
262			walk_start: Coord = unvisited_coords[
263				np.random.choice(unvisited_coords.shape[0])
264			]
265
266			# Perform the random walk
267			path: list[Coord] = [walk_start]
268			current: Coord = walk_start
269
270			# exit the loop once the current path hits a visited cell
271			while not visited[current[0], current[1]]:
272				# find a valid neighbor (one always exists on a lattice)
273				neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_)
274				next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]
275
276				# Check for loop
277				loop_exit: int | None = None
278				for i, p in enumerate(path):
279					if np.array_equal(next_cell, p):
280						loop_exit = i
281						break
282
283				# erase the loop, or continue the walk
284				if loop_exit is not None:
285					# this removes everything after and including the loop start
286					path = path[: loop_exit + 1]
287					# reset current cell to end of path
288					current = path[-1]
289				else:
290					path.append(next_cell)
291					current = next_cell
292
293			# Add the path to the maze
294			for i in range(len(path) - 1):
295				c_1: Coord = path[i]
296				c_2: Coord = path[i + 1]
297
298				# find the dimension of the connection
299				delta: Coord = c_2 - c_1
300				dim: int = int(np.argmax(np.abs(delta)))
301
302				# if positive, down/right from current coord
303				# if negative, up/left from current coord (down/right from neighbor)
304				clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
305				connection_list[dim, clist_node[0], clist_node[1]] = True
306				visited[c_1[0], c_1[1]] = True
307				# we dont add c_2 because the last c_2 will have already been visited
308
309		return LatticeMaze(
310			connection_list=connection_list,
311			generation_meta=dict(
312				func_name="gen_wilson",
313				grid_shape=grid_shape_,
314				fully_connected=True,
315			),
316		)
317
318	@staticmethod
319	def gen_percolation(
320		grid_shape: Coord | CoordTup,
321		p: float = 0.4,
322		lattice_dim: int = 2,
323		start_coord: Coord | None = None,
324	) -> LatticeMaze:
325		"""generate a lattice maze using simple percolation
326
327		note that p in the range (0.4, 0.7) gives the most interesting mazes
328
329		# Arguments
330		- `grid_shape: Coord`: the shape of the grid
331		- `lattice_dim: int`: the dimension of the lattice (default: `2`)
332		- `p: float`: the probability of a cell being accessible (default: `0.5`)
333		- `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)
334		"""
335		assert p >= 0 and p <= 1, f"p must be between 0 and 1, got {p}"  # noqa: PT018
336		grid_shape_: Coord = np.array(grid_shape)
337
338		start_coord = _random_start_coord(grid_shape_, start_coord)
339
340		connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p
341
342		connection_list = _fill_edges_with_walls(connection_list)
343
344		output: LatticeMaze = LatticeMaze(
345			connection_list=connection_list,
346			generation_meta=dict(
347				func_name="gen_percolation",
348				grid_shape=grid_shape_,
349				percolation_p=p,
350				start_coord=start_coord,
351			),
352		)
353
354		# generation_meta is sometimes None, but not here since we just made it a dict above
355		output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
356			start_coord,
357		)
358
359		return output
360
361	@staticmethod
362	def gen_dfs_percolation(
363		grid_shape: Coord | CoordTup,
364		p: float = 0.4,
365		lattice_dim: int = 2,
366		accessible_cells: int | None = None,
367		max_tree_depth: int | None = None,
368		start_coord: Coord | None = None,
369	) -> LatticeMaze:
370		"""dfs and then percolation (adds cycles)"""
371		grid_shape_: Coord = np.array(grid_shape)
372		start_coord = _random_start_coord(grid_shape_, start_coord)
373
374		# generate initial maze via dfs
375		maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
376			grid_shape=grid_shape_,
377			lattice_dim=lattice_dim,
378			accessible_cells=accessible_cells,
379			max_tree_depth=max_tree_depth,
380			start_coord=start_coord,
381		)
382
383		# percolate
384		connection_list_perc: np.ndarray = (
385			np.random.rand(*maze.connection_list.shape) < p
386		)
387		connection_list_perc = _fill_edges_with_walls(connection_list_perc)
388
389		maze.__dict__["connection_list"] = np.logical_or(
390			maze.connection_list,
391			connection_list_perc,
392		)
393
394		# generation_meta is sometimes None, but not here since we just made it a dict above
395		maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
396		maze.generation_meta["percolation_p"] = p  # type: ignore[index]
397		maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
398			start_coord,
399		)
400
401		return maze

namespace for lattice maze generation algorithms

@staticmethod
def gen_dfs( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], lattice_dim: int = 2, accessible_cells: float | None = None, max_tree_depth: float | None = None, do_forks: bool = True, randomized_stack: bool = False, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
	@staticmethod
	def gen_dfs(
		grid_shape: Coord | CoordTup,
		lattice_dim: int = 2,
		accessible_cells: float | None = None,
		max_tree_depth: float | None = None,
		do_forks: bool = True,
		randomized_stack: bool = False,
		start_coord: Coord | None = None,
	) -> LatticeMaze:
		"""generate a lattice maze using depth first search, iterative

		# Arguments
		- `grid_shape: Coord`: the shape of the grid
		- `lattice_dim: int`: the dimension of the lattice
			(default: `2`)
		- `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells**
			(default: `None`)
		- `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to twice the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape**
			(default: `None`)
		- `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have no forks and will be a simple hallway.
			(default: `True`)
		- `randomized_stack: bool`: if `True`, pop a uniformly random element of the stack instead of the last one (this is what `gen_prim` uses)
			(default: `False`)
		- `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.
			(default: `None`)

		# Returns
		- `LatticeMaze` whose `generation_meta` records the resolved parameters, the set of
		  `visited_cells`, and `fully_connected` (whether every cell of the grid was reached)

		# algorithm
		1. Choose the initial cell, mark it as visited and push it to the stack
		2. While the stack is not empty and fewer than `accessible_cells` cells have been visited
			1. Pop a cell from the stack and make it a current cell
			2. If the current cell has any neighbours which have not been visited (and the depth limit is not reached)
				1. Push the current cell back to the stack (only if `do_forks` and it has more than one unvisited neighbour)
				2. Choose one of the unvisited neighbours
				3. Remove the wall between the current cell and the chosen cell
				4. Mark the chosen cell as visited and push it to the stack
		"""
		# resolve defaults / proportions into absolute counts
		grid_shape_: Coord = np.array(grid_shape)
		n_total_cells: int = int(np.prod(grid_shape_))

		n_accessible_cells: int
		if accessible_cells is None:
			n_accessible_cells = n_total_cells
		elif isinstance(accessible_cells, float):
			assert accessible_cells <= 1, (
				f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
			)

			n_accessible_cells = int(accessible_cells * n_total_cells)
		else:
			# accept numpy integer scalars as well as python ints
			assert isinstance(accessible_cells, (int, np.integer)), (
				f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
			)
			n_accessible_cells = int(accessible_cells)

		if max_tree_depth is None:
			# depth is counted from the start coord in two directions, hence the factor
			# of 2 here and the division by 2 in the depth check inside the loop
			max_tree_depth = 2 * n_total_cells
		elif isinstance(max_tree_depth, float):
			assert max_tree_depth <= 1, (
				f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
			)

			max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))

		# choose a random start coord
		start_coord = _random_start_coord(grid_shape_, start_coord)

		# initialize the maze with no connections
		connection_list: ConnectionList = np.zeros(
			(lattice_dim, grid_shape_[0], grid_shape_[1]),
			dtype=np.bool_,
		)

		# initialize the visited set and the stack with the start coord
		# (visited cells are stored as tuples so they are hashable)
		visited_cells: set[tuple[int, int]] = set()
		visited_cells.add(tuple(start_coord))
		stack: list[Coord] = [start_coord]

		# depth counter: +1 each time the maze is extended, -1 on each dead end
		current_tree_depth: int = 1

		# loop until the stack is empty or enough cells have been visited
		while stack and (len(visited_cells) < n_accessible_cells):
			# get the current coord from the stack
			current_coord: Coord
			if randomized_stack:
				# we dont care about S311 because this isnt security related
				current_coord = stack.pop(random.randint(0, len(stack) - 1))  # noqa: S311
			else:
				current_coord = stack.pop()

			# neighbors that are inside the grid and not yet visited, paired with the step to them
			unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
				(neighbor, delta)
				for neighbor, delta in zip(
					current_coord + NEIGHBORS_MASK,
					NEIGHBORS_MASK,
					strict=False,
				)
				if (
					(tuple(neighbor) not in visited_cells)
					and (0 <= neighbor[0] < grid_shape_[0])
					and (0 <= neighbor[1] < grid_shape_[1])
				)
			]

			# don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)
			if unvisited_neighbors_deltas and (
				current_tree_depth <= max_tree_depth / 2
			):
				# re-push the current coord so its remaining unvisited neighbors can be
				# explored later; without forks it is dropped, giving a single hallway
				if do_forks and (len(unvisited_neighbors_deltas) > 1):
					stack.append(current_coord)

				# choose one of the unvisited neighbors
				# we dont care about S311 because this isn't security related
				chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)  # noqa: S311

				# add connection
				dim: int = int(np.argmax(np.abs(delta)))
				# if positive, down/right from current coord
				# if negative, up/left from current coord (down/right from neighbor)
				clist_node: Coord = (
					current_coord if (delta.sum() > 0) else chosen_neighbor
				)
				connection_list[dim, clist_node[0], clist_node[1]] = True

				# add to visited cells and stack
				visited_cells.add(tuple(chosen_neighbor))
				stack.append(chosen_neighbor)

				# Update current tree depth
				current_tree_depth += 1
			else:
				# dead end (or depth limit reached): backtrack
				current_tree_depth -= 1

		return LatticeMaze(
			connection_list=connection_list,
			generation_meta=dict(
				func_name="gen_dfs",
				grid_shape=grid_shape_,
				start_coord=start_coord,
				n_accessible_cells=int(n_accessible_cells),
				max_tree_depth=int(max_tree_depth),
				# compare against the TOTAL cell count, not n_accessible_cells: a maze that only
				# reached `accessible_cells` cells is not fully connected, and marking it as such
				# causes solving the maze to break
				fully_connected=bool(len(visited_cells) == n_total_cells),
				visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
			),
		)

generate a lattice maze using depth first search, iterative

Arguments

  • grid_shape: Coord: the shape of the grid
  • lattice_dim: int: the dimension of the lattice (default: 2)
  • accessible_cells: int | float |None: the number of accessible cells in the maze. If None, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of total cells (default: None)
  • max_tree_depth: int | float | None: the maximum depth of the tree. If None, defaults to twice the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of the sum of the grid shape (default: None)
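  • do_forks: bool: whether to allow forks in the maze. If False, the maze will have no forks and will be a simple hallway. (default: True)
  • randomized_stack: bool: if True, pop a uniformly random element of the stack instead of the last one (this is what gen_prim uses). (default: False)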
  • start_coord: Coord | None: the starting coordinate of the generation algorithm. If None, defaults to a random coordinate.

algorithm

  1. Choose the initial cell, mark it as visited and push it to the stack
  2. While the stack is not empty and fewer than accessible_cells cells have been visited
    1. Pop a cell from the stack and make it a current cell
    2. If the current cell has any neighbours which have not been visited
      1. Push the current cell to the stack
      2. Choose one of the unvisited neighbours
      3. Remove the wall between the current cell and the chosen cell
      4. Mark the chosen cell as visited and push it to the stack
@staticmethod
def gen_prim( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], lattice_dim: int = 2, accessible_cells: float | None = None, max_tree_depth: float | None = None, do_forks: bool = True, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
206	@staticmethod
207	def gen_prim(
208		grid_shape: Coord | CoordTup,
209		lattice_dim: int = 2,
210		accessible_cells: float | None = None,
211		max_tree_depth: float | None = None,
212		do_forks: bool = True,
213		start_coord: Coord | None = None,
214	) -> LatticeMaze:
215		"(broken!) generate a lattice maze using Prim's algorithm"
216		warnings.warn(
217			"gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
218		)
219		return LatticeMazeGenerators.gen_dfs(
220			grid_shape=grid_shape,
221			lattice_dim=lattice_dim,
222			accessible_cells=accessible_cells,
223			max_tree_depth=max_tree_depth,
224			do_forks=do_forks,
225			start_coord=start_coord,
226			randomized_stack=True,
227		)

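(broken!) generate a lattice maze using Prim's algorithm — currently this only emits a warning and falls back to gen_dfs with randomized_stack=True (see issue #12).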

@staticmethod
def gen_wilson( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], **kwargs) -> maze_dataset.LatticeMaze:
229	@staticmethod
230	def gen_wilson(
231		grid_shape: Coord | CoordTup,
232		**kwargs,
233	) -> LatticeMaze:
234		"""Generate a lattice maze using Wilson's algorithm.
235
236		# Algorithm
237		Wilson's algorithm generates an unbiased (random) maze
238		sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
239		acyclic and all cells are part of a unique connected space.
240		https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm
241		"""
242		assert not kwargs, (
243			f"gen_wilson does not take any additional arguments, got {kwargs = }"
244		)
245
246		grid_shape_: Coord = np.array(grid_shape)
247
248		# Initialize grid and visited cells
249		connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
250		visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)
251
252		# Choose a random cell and mark it as visited
253		start_coord: Coord = _random_start_coord(grid_shape_, None)
254		visited[start_coord[0], start_coord[1]] = True
255		del start_coord
256
257		while not visited.all():
258			# Perform loop-erased random walk from another random cell
259
260			# Choose walk_start only from unvisited cells
261			unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
262			walk_start: Coord = unvisited_coords[
263				np.random.choice(unvisited_coords.shape[0])
264			]
265
266			# Perform the random walk
267			path: list[Coord] = [walk_start]
268			current: Coord = walk_start
269
270			# exit the loop once the current path hits a visited cell
271			while not visited[current[0], current[1]]:
272				# find a valid neighbor (one always exists on a lattice)
273				neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_)
274				next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]
275
276				# Check for loop
277				loop_exit: int | None = None
278				for i, p in enumerate(path):
279					if np.array_equal(next_cell, p):
280						loop_exit = i
281						break
282
283				# erase the loop, or continue the walk
284				if loop_exit is not None:
285					# this removes everything after and including the loop start
286					path = path[: loop_exit + 1]
287					# reset current cell to end of path
288					current = path[-1]
289				else:
290					path.append(next_cell)
291					current = next_cell
292
293			# Add the path to the maze
294			for i in range(len(path) - 1):
295				c_1: Coord = path[i]
296				c_2: Coord = path[i + 1]
297
298				# find the dimension of the connection
299				delta: Coord = c_2 - c_1
300				dim: int = int(np.argmax(np.abs(delta)))
301
302				# if positive, down/right from current coord
303				# if negative, up/left from current coord (down/right from neighbor)
304				clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
305				connection_list[dim, clist_node[0], clist_node[1]] = True
306				visited[c_1[0], c_1[1]] = True
307				# we dont add c_2 because the last c_2 will have already been visited
308
309		return LatticeMaze(
310			connection_list=connection_list,
311			generation_meta=dict(
312				func_name="gen_wilson",
313				grid_shape=grid_shape_,
314				fully_connected=True,
315			),
316		)

Generate a lattice maze using Wilson's algorithm.

Algorithm

Wilson's algorithm generates an unbiased (random) maze sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is acyclic and all cells are part of a unique connected space. https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm

@staticmethod
def gen_percolation( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], p: float = 0.4, lattice_dim: int = 2, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
318	@staticmethod
319	def gen_percolation(
320		grid_shape: Coord | CoordTup,
321		p: float = 0.4,
322		lattice_dim: int = 2,
323		start_coord: Coord | None = None,
324	) -> LatticeMaze:
325		"""generate a lattice maze using simple percolation
326
327		note that p in the range (0.4, 0.7) gives the most interesting mazes
328
329		# Arguments
330		- `grid_shape: Coord`: the shape of the grid
331		- `lattice_dim: int`: the dimension of the lattice (default: `2`)
332		- `p: float`: the probability of a cell being accessible (default: `0.5`)
333		- `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)
334		"""
335		assert p >= 0 and p <= 1, f"p must be between 0 and 1, got {p}"  # noqa: PT018
336		grid_shape_: Coord = np.array(grid_shape)
337
338		start_coord = _random_start_coord(grid_shape_, start_coord)
339
340		connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p
341
342		connection_list = _fill_edges_with_walls(connection_list)
343
344		output: LatticeMaze = LatticeMaze(
345			connection_list=connection_list,
346			generation_meta=dict(
347				func_name="gen_percolation",
348				grid_shape=grid_shape_,
349				percolation_p=p,
350				start_coord=start_coord,
351			),
352		)
353
354		# generation_meta is sometimes None, but not here since we just made it a dict above
355		output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
356			start_coord,
357		)
358
359		return output

generate a lattice maze using simple percolation

note that p in the range (0.4, 0.7) gives the most interesting mazes

Arguments

  • grid_shape: Coord: the shape of the grid
  • lattice_dim: int: the dimension of the lattice (default: 2)
  • p: float: the probability of each connection being open, must be in [0, 1] (default: 0.4)
  • start_coord: Coord | None: the starting coordinate for the connected component (default: None will give a random start)
@staticmethod
def gen_dfs_percolation( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], p: float = 0.4, lattice_dim: int = 2, accessible_cells: int | None = None, max_tree_depth: int | None = None, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
361	@staticmethod
362	def gen_dfs_percolation(
363		grid_shape: Coord | CoordTup,
364		p: float = 0.4,
365		lattice_dim: int = 2,
366		accessible_cells: int | None = None,
367		max_tree_depth: int | None = None,
368		start_coord: Coord | None = None,
369	) -> LatticeMaze:
370		"""dfs and then percolation (adds cycles)"""
371		grid_shape_: Coord = np.array(grid_shape)
372		start_coord = _random_start_coord(grid_shape_, start_coord)
373
374		# generate initial maze via dfs
375		maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
376			grid_shape=grid_shape_,
377			lattice_dim=lattice_dim,
378			accessible_cells=accessible_cells,
379			max_tree_depth=max_tree_depth,
380			start_coord=start_coord,
381		)
382
383		# percolate
384		connection_list_perc: np.ndarray = (
385			np.random.rand(*maze.connection_list.shape) < p
386		)
387		connection_list_perc = _fill_edges_with_walls(connection_list_perc)
388
389		maze.__dict__["connection_list"] = np.logical_or(
390			maze.connection_list,
391			connection_list_perc,
392		)
393
394		# generation_meta is sometimes None, but not here since we just made it a dict above
395		maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
396		maze.generation_meta["percolation_p"] = p  # type: ignore[index]
397		maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
398			start_coord,
399		)
400
401		return maze

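dfs and then percolation (adds cycles): first generates a maze with gen_dfs, then additionally opens each connection independently with probability p.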

GENERATORS_MAP: dict[str, typing.Callable[[jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], typing.Any], maze_dataset.LatticeMaze]] = {'gen_dfs': <function LatticeMazeGenerators.gen_dfs>, 'gen_wilson': <function LatticeMazeGenerators.gen_wilson>, 'gen_percolation': <function LatticeMazeGenerators.gen_percolation>, 'gen_dfs_percolation': <function LatticeMazeGenerators.gen_dfs_percolation>, 'gen_prim': <function LatticeMazeGenerators.gen_prim>}

mapping of generator names to generator functions, useful for loading MazeDatasetConfig

def get_maze_with_solution( gen_name: str, grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], maze_ctor_kwargs: dict | None = None) -> maze_dataset.SolvedMaze:
def get_maze_with_solution(
	gen_name: str,
	grid_shape: Coord | CoordTup,
	maze_ctor_kwargs: dict | None = None,
) -> SolvedMaze:
	"""helper function to get a maze already with a solution

	# Arguments
	- `gen_name: str`: key into `GENERATORS_MAP` selecting the generator
	- `grid_shape: Coord | CoordTup`: the shape of the grid
	- `maze_ctor_kwargs: dict | None`: extra keyword arguments forwarded to the generator (default: `None`, meaning none)

	# Returns
	- `SolvedMaze` built from the generated maze and the path returned by `maze.generate_random_path()`

	# Raises
	- `KeyError`: if `gen_name` is not a key of `GENERATORS_MAP`; the message lists the valid names
	"""
	if maze_ctor_kwargs is None:
		maze_ctor_kwargs = {}
	try:
		generator = GENERATORS_MAP[gen_name]
	except KeyError as err:
		msg: str = (
			f"unknown generator {gen_name!r}, expected one of {sorted(GENERATORS_MAP)}"
		)
		raise KeyError(msg) from err
	# TYPING: error: Too few arguments  [call-arg]
	# not sure why this is happening -- doesnt recognize the kwargs?
	maze: LatticeMaze = generator(grid_shape, **maze_ctor_kwargs)  # type: ignore[call-arg]
	solution: CoordArray = np.array(maze.generate_random_path())
	return SolvedMaze.from_lattice_maze(lattice_maze=maze, solution=solution)

helper function to get a maze already with a solution