docs for maze-dataset v1.3.0
View Source on GitHub

maze_dataset.generation

generation functions have signature (grid_shape: Coord, **kwargs) -> LatticeMaze and are methods in LatticeMazeGenerators

DEFAULT_GENERATORS is a list of (generator name, generator kwargs) pairs used in tests and demos.


 1"""generation functions have signature `(grid_shape: Coord, **kwargs) -> LatticeMaze` and are methods in `LatticeMazeGenerators`
 2
 3`DEFAULT_GENERATORS` is a list of generator name, generator kwargs pairs used in tests and demos
 4
 5"""
 6
 7from maze_dataset.generation.generators import (
 8	GENERATORS_MAP,
 9	LatticeMazeGenerators,
10	get_maze_with_solution,
11	numpy_rng,
12)
13
14__all__ = [
15	# submodules
16	"default_generators",
17	"generators",
18	"seed",
19	# imports
20	"LatticeMazeGenerators",
21	"GENERATORS_MAP",
22	"get_maze_with_solution",
23	"numpy_rng",
24]

class LatticeMazeGenerators:
 54class LatticeMazeGenerators:
 55	"""namespace for lattice maze generation algorithms"""
 56
 57	@staticmethod
 58	def gen_dfs(
 59		grid_shape: Coord | CoordTup,
 60		lattice_dim: int = 2,
 61		accessible_cells: float | None = None,
 62		max_tree_depth: float | None = None,
 63		do_forks: bool = True,
 64		randomized_stack: bool = False,
 65		start_coord: Coord | None = None,
 66	) -> LatticeMaze:
 67		"""generate a lattice maze using depth first search, iterative
 68
 69		# Arguments
 70		- `grid_shape: Coord`: the shape of the grid
 71		- `lattice_dim: int`: the dimension of the lattice
 72			(default: `2`)
 73		- `accessible_cells: int | float |None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells**
 74			(default: `None`)
 75		- `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * accessible_cells`. if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape**
 76			(default: `None`)
 77		- `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will be have no forks and will be a simple hallway.
 78		- `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.
 79
 80		# algorithm
 81		1. Choose the initial cell, mark it as visited and push it to the stack
 82		2. While the stack is not empty
 83			1. Pop a cell from the stack and make it a current cell
 84			2. If the current cell has any neighbours which have not been visited
 85				1. Push the current cell to the stack
 86				2. Choose one of the unvisited neighbours
 87				3. Remove the wall between the current cell and the chosen cell
 88				4. Mark the chosen cell as visited and push it to the stack
 89		"""
 90		# Default values if no constraints have been passed
 91		grid_shape_: Coord = np.array(grid_shape)
 92		n_total_cells: int = int(np.prod(grid_shape_))
 93
 94		n_accessible_cells: int
 95		if accessible_cells is None:
 96			n_accessible_cells = n_total_cells
 97		elif isinstance(accessible_cells, float):
 98			assert accessible_cells <= 1, (
 99				f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
100			)
101
102			n_accessible_cells = int(accessible_cells * n_total_cells)
103		else:
104			assert isinstance(accessible_cells, int)
105			n_accessible_cells = accessible_cells
106
107		if max_tree_depth is None:
108			max_tree_depth = (
109				2 * n_total_cells
110			)  # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.
111		elif isinstance(max_tree_depth, float):
112			assert max_tree_depth <= 1, (
113				f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
114			)
115
116			max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))
117
118		# choose a random start coord
119		start_coord = _random_start_coord(grid_shape_, start_coord)
120
121		# initialize the maze with no connections
122		connection_list: ConnectionList = np.zeros(
123			(lattice_dim, grid_shape_[0], grid_shape_[1]),
124			dtype=np.bool_,
125		)
126
127		# initialize the stack with the target coord
128		visited_cells: set[tuple[int, int]] = set()
129		visited_cells.add(tuple(start_coord))  # this wasnt a bug after all lol
130		stack: list[Coord] = [start_coord]
131
132		# initialize tree_depth_counter
133		current_tree_depth: int = 1
134
135		# loop until the stack is empty or n_connected_cells is reached
136		while stack and (len(visited_cells) < n_accessible_cells):
137			# get the current coord from the stack
138			current_coord: Coord
139			if randomized_stack:
140				# we dont care about S311 because this isnt security related
141				current_coord = stack.pop(random.randint(0, len(stack) - 1))  # noqa: S311
142			else:
143				current_coord = stack.pop()
144
145			# filter neighbors by being within grid bounds and being unvisited
146			unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
147				(neighbor, delta)
148				for neighbor, delta in zip(
149					current_coord + NEIGHBORS_MASK,
150					NEIGHBORS_MASK,
151					strict=False,
152				)
153				if (
154					(tuple(neighbor) not in visited_cells)
155					and (0 <= neighbor[0] < grid_shape_[0])
156					and (0 <= neighbor[1] < grid_shape_[1])
157				)
158			]
159
160			# don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)
161			if unvisited_neighbors_deltas and (
162				current_tree_depth <= max_tree_depth / 2
163			):
164				# if we want a maze without forks, simply don't add the current coord back to the stack
165				if do_forks and (len(unvisited_neighbors_deltas) > 1):
166					stack.append(current_coord)
167
168				# choose one of the unvisited neighbors
169				# we dont care about S311 because this isn't security related
170				chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)  # noqa: S311
171
172				# add connection
173				dim: int = int(np.argmax(np.abs(delta)))
174				# if positive, down/right from current coord
175				# if negative, up/left from current coord (down/right from neighbor)
176				clist_node: Coord = (
177					current_coord if (delta.sum() > 0) else chosen_neighbor
178				)
179				connection_list[dim, clist_node[0], clist_node[1]] = True
180
181				# add to visited cells and stack
182				visited_cells.add(tuple(chosen_neighbor))
183				stack.append(chosen_neighbor)
184
185				# Update current tree depth
186				current_tree_depth += 1
187			else:
188				current_tree_depth -= 1
189
190		return LatticeMaze(
191			connection_list=connection_list,
192			generation_meta=dict(
193				func_name="gen_dfs",
194				grid_shape=grid_shape_,
195				start_coord=start_coord,
196				n_accessible_cells=int(n_accessible_cells),
197				max_tree_depth=int(max_tree_depth),
198				# oh my god this took so long to track down. its almost 5am and I've spent like 2 hours on this bug
199				# it was checking that len(visited_cells) == n_accessible_cells, but this means that the maze is
200				# treated as fully connected even when it is most certainly not, causing solving the maze to break
201				fully_connected=bool(len(visited_cells) == n_total_cells),
202				visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
203			),
204		)
205
206	@staticmethod
207	def gen_prim(
208		grid_shape: Coord | CoordTup,
209		lattice_dim: int = 2,
210		accessible_cells: float | None = None,
211		max_tree_depth: float | None = None,
212		do_forks: bool = True,
213		start_coord: Coord | None = None,
214	) -> LatticeMaze:
215		"(broken!) generate a lattice maze using Prim's algorithm"
216		warnings.warn(
217			"gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
218		)
219		return LatticeMazeGenerators.gen_dfs(
220			grid_shape=grid_shape,
221			lattice_dim=lattice_dim,
222			accessible_cells=accessible_cells,
223			max_tree_depth=max_tree_depth,
224			do_forks=do_forks,
225			start_coord=start_coord,
226			randomized_stack=True,
227		)
228
229	@staticmethod
230	def gen_wilson(
231		grid_shape: Coord | CoordTup,
232		**kwargs,
233	) -> LatticeMaze:
234		"""Generate a lattice maze using Wilson's algorithm.
235
236		# Algorithm
237		Wilson's algorithm generates an unbiased (random) maze
238		sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
239		acyclic and all cells are part of a unique connected space.
240		https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm
241		"""
242		assert not kwargs, (
243			f"gen_wilson does not take any additional arguments, got {kwargs = }"
244		)
245
246		grid_shape_: Coord = np.array(grid_shape)
247
248		# Initialize grid and visited cells
249		connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
250		visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)
251
252		# Choose a random cell and mark it as visited
253		start_coord: Coord = _random_start_coord(grid_shape_, None)
254		visited[start_coord[0], start_coord[1]] = True
255		del start_coord
256
257		while not visited.all():
258			# Perform loop-erased random walk from another random cell
259
260			# Choose walk_start only from unvisited cells
261			unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
262			walk_start: Coord = unvisited_coords[
263				np.random.choice(unvisited_coords.shape[0])
264			]
265
266			# Perform the random walk
267			path: list[Coord] = [walk_start]
268			current: Coord = walk_start
269
270			# exit the loop once the current path hits a visited cell
271			while not visited[current[0], current[1]]:
272				# find a valid neighbor (one always exists on a lattice)
273				neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_)
274				next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]
275
276				# Check for loop
277				loop_exit: int | None = None
278				for i, p in enumerate(path):
279					if np.array_equal(next_cell, p):
280						loop_exit = i
281						break
282
283				# erase the loop, or continue the walk
284				if loop_exit is not None:
285					# this removes everything after and including the loop start
286					path = path[: loop_exit + 1]
287					# reset current cell to end of path
288					current = path[-1]
289				else:
290					path.append(next_cell)
291					current = next_cell
292
293			# Add the path to the maze
294			for i in range(len(path) - 1):
295				c_1: Coord = path[i]
296				c_2: Coord = path[i + 1]
297
298				# find the dimension of the connection
299				delta: Coord = c_2 - c_1
300				dim: int = int(np.argmax(np.abs(delta)))
301
302				# if positive, down/right from current coord
303				# if negative, up/left from current coord (down/right from neighbor)
304				clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
305				connection_list[dim, clist_node[0], clist_node[1]] = True
306				visited[c_1[0], c_1[1]] = True
307				# we dont add c_2 because the last c_2 will have already been visited
308
309		return LatticeMaze(
310			connection_list=connection_list,
311			generation_meta=dict(
312				func_name="gen_wilson",
313				grid_shape=grid_shape_,
314				fully_connected=True,
315			),
316		)
317
318	@staticmethod
319	def gen_percolation(
320		grid_shape: Coord | CoordTup,
321		p: float = 0.4,
322		lattice_dim: int = 2,
323		start_coord: Coord | None = None,
324	) -> LatticeMaze:
325		"""generate a lattice maze using simple percolation
326
327		note that p in the range (0.4, 0.7) gives the most interesting mazes
328
329		# Arguments
330		- `grid_shape: Coord`: the shape of the grid
331		- `lattice_dim: int`: the dimension of the lattice (default: `2`)
332		- `p: float`: the probability of a cell being accessible (default: `0.5`)
333		- `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)
334		"""
335		assert p >= 0 and p <= 1, f"p must be between 0 and 1, got {p}"  # noqa: PT018
336		grid_shape_: Coord = np.array(grid_shape)
337
338		start_coord = _random_start_coord(grid_shape_, start_coord)
339
340		connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p
341
342		connection_list = _fill_edges_with_walls(connection_list)
343
344		output: LatticeMaze = LatticeMaze(
345			connection_list=connection_list,
346			generation_meta=dict(
347				func_name="gen_percolation",
348				grid_shape=grid_shape_,
349				percolation_p=p,
350				start_coord=start_coord,
351			),
352		)
353
354		# generation_meta is sometimes None, but not here since we just made it a dict above
355		output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
356			start_coord,
357		)
358
359		return output
360
361	@staticmethod
362	def gen_dfs_percolation(
363		grid_shape: Coord | CoordTup,
364		p: float = 0.4,
365		lattice_dim: int = 2,
366		accessible_cells: int | None = None,
367		max_tree_depth: int | None = None,
368		start_coord: Coord | None = None,
369	) -> LatticeMaze:
370		"""dfs and then percolation (adds cycles)"""
371		grid_shape_: Coord = np.array(grid_shape)
372		start_coord = _random_start_coord(grid_shape_, start_coord)
373
374		# generate initial maze via dfs
375		maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
376			grid_shape=grid_shape_,
377			lattice_dim=lattice_dim,
378			accessible_cells=accessible_cells,
379			max_tree_depth=max_tree_depth,
380			start_coord=start_coord,
381		)
382
383		# percolate
384		connection_list_perc: np.ndarray = (
385			np.random.rand(*maze.connection_list.shape) < p
386		)
387		connection_list_perc = _fill_edges_with_walls(connection_list_perc)
388
389		maze.__dict__["connection_list"] = np.logical_or(
390			maze.connection_list,
391			connection_list_perc,
392		)
393
394		# generation_meta is sometimes None, but not here since we just made it a dict above
395		maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
396		maze.generation_meta["percolation_p"] = p  # type: ignore[index]
397		maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
398			start_coord,
399		)
400
401		return maze

namespace for lattice maze generation algorithms

@staticmethod
def gen_dfs( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], lattice_dim: int = 2, accessible_cells: float | None = None, max_tree_depth: float | None = None, do_forks: bool = True, randomized_stack: bool = False, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
 57	@staticmethod
 58	def gen_dfs(
 59		grid_shape: Coord | CoordTup,
 60		lattice_dim: int = 2,
 61		accessible_cells: float | None = None,
 62		max_tree_depth: float | None = None,
 63		do_forks: bool = True,
 64		randomized_stack: bool = False,
 65		start_coord: Coord | None = None,
 66	) -> LatticeMaze:
 67		"""generate a lattice maze using depth first search, iterative
 68
 69		# Arguments
 70		- `grid_shape: Coord`: the shape of the grid
 71		- `lattice_dim: int`: the dimension of the lattice
 72			(default: `2`)
 73		- `accessible_cells: int | float |None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells**
 74			(default: `None`)
 75		- `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * accessible_cells`. if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape**
 76			(default: `None`)
 77		- `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will be have no forks and will be a simple hallway.
 78		- `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.
 79
 80		# algorithm
 81		1. Choose the initial cell, mark it as visited and push it to the stack
 82		2. While the stack is not empty
 83			1. Pop a cell from the stack and make it a current cell
 84			2. If the current cell has any neighbours which have not been visited
 85				1. Push the current cell to the stack
 86				2. Choose one of the unvisited neighbours
 87				3. Remove the wall between the current cell and the chosen cell
 88				4. Mark the chosen cell as visited and push it to the stack
 89		"""
 90		# Default values if no constraints have been passed
 91		grid_shape_: Coord = np.array(grid_shape)
 92		n_total_cells: int = int(np.prod(grid_shape_))
 93
 94		n_accessible_cells: int
 95		if accessible_cells is None:
 96			n_accessible_cells = n_total_cells
 97		elif isinstance(accessible_cells, float):
 98			assert accessible_cells <= 1, (
 99				f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
100			)
101
102			n_accessible_cells = int(accessible_cells * n_total_cells)
103		else:
104			assert isinstance(accessible_cells, int)
105			n_accessible_cells = accessible_cells
106
107		if max_tree_depth is None:
108			max_tree_depth = (
109				2 * n_total_cells
110			)  # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.
111		elif isinstance(max_tree_depth, float):
112			assert max_tree_depth <= 1, (
113				f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
114			)
115
116			max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))
117
118		# choose a random start coord
119		start_coord = _random_start_coord(grid_shape_, start_coord)
120
121		# initialize the maze with no connections
122		connection_list: ConnectionList = np.zeros(
123			(lattice_dim, grid_shape_[0], grid_shape_[1]),
124			dtype=np.bool_,
125		)
126
127		# initialize the stack with the target coord
128		visited_cells: set[tuple[int, int]] = set()
129		visited_cells.add(tuple(start_coord))  # this wasnt a bug after all lol
130		stack: list[Coord] = [start_coord]
131
132		# initialize tree_depth_counter
133		current_tree_depth: int = 1
134
135		# loop until the stack is empty or n_connected_cells is reached
136		while stack and (len(visited_cells) < n_accessible_cells):
137			# get the current coord from the stack
138			current_coord: Coord
139			if randomized_stack:
140				# we dont care about S311 because this isnt security related
141				current_coord = stack.pop(random.randint(0, len(stack) - 1))  # noqa: S311
142			else:
143				current_coord = stack.pop()
144
145			# filter neighbors by being within grid bounds and being unvisited
146			unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
147				(neighbor, delta)
148				for neighbor, delta in zip(
149					current_coord + NEIGHBORS_MASK,
150					NEIGHBORS_MASK,
151					strict=False,
152				)
153				if (
154					(tuple(neighbor) not in visited_cells)
155					and (0 <= neighbor[0] < grid_shape_[0])
156					and (0 <= neighbor[1] < grid_shape_[1])
157				)
158			]
159
160			# don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)
161			if unvisited_neighbors_deltas and (
162				current_tree_depth <= max_tree_depth / 2
163			):
164				# if we want a maze without forks, simply don't add the current coord back to the stack
165				if do_forks and (len(unvisited_neighbors_deltas) > 1):
166					stack.append(current_coord)
167
168				# choose one of the unvisited neighbors
169				# we dont care about S311 because this isn't security related
170				chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)  # noqa: S311
171
172				# add connection
173				dim: int = int(np.argmax(np.abs(delta)))
174				# if positive, down/right from current coord
175				# if negative, up/left from current coord (down/right from neighbor)
176				clist_node: Coord = (
177					current_coord if (delta.sum() > 0) else chosen_neighbor
178				)
179				connection_list[dim, clist_node[0], clist_node[1]] = True
180
181				# add to visited cells and stack
182				visited_cells.add(tuple(chosen_neighbor))
183				stack.append(chosen_neighbor)
184
185				# Update current tree depth
186				current_tree_depth += 1
187			else:
188				current_tree_depth -= 1
189
190		return LatticeMaze(
191			connection_list=connection_list,
192			generation_meta=dict(
193				func_name="gen_dfs",
194				grid_shape=grid_shape_,
195				start_coord=start_coord,
196				n_accessible_cells=int(n_accessible_cells),
197				max_tree_depth=int(max_tree_depth),
198				# oh my god this took so long to track down. its almost 5am and I've spent like 2 hours on this bug
199				# it was checking that len(visited_cells) == n_accessible_cells, but this means that the maze is
200				# treated as fully connected even when it is most certainly not, causing solving the maze to break
201				fully_connected=bool(len(visited_cells) == n_total_cells),
202				visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
203			),
204		)

generate a lattice maze using depth first search, iterative

Arguments

  • grid_shape: Coord: the shape of the grid
  • lattice_dim: int: the dimension of the lattice (default: 2)
  • accessible_cells: int | float | None: the number of accessible cells in the maze. If None, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of total cells (default: None)
  • max_tree_depth: int | float | None: the maximum depth of the tree. If None, defaults to 2 * (total number of cells in the grid). if a float, asserts it is <= 1 and treats it as a proportion of the sum of the grid shape (default: None)
  • do_forks: bool: whether to allow forks in the maze. If False, the maze will have no forks and will be a simple hallway. (default: True)
  • randomized_stack: bool: if True, pop a random element of the stack at each step instead of the most recently pushed one; gen_prim sets this to True. (default: False)
  • start_coord: Coord | None: the starting coordinate of the generation algorithm. If None, defaults to a random coordinate. (default: None)

algorithm

  1. Choose the initial cell, mark it as visited and push it to the stack
  2. While the stack is not empty
    1. Pop a cell from the stack and make it a current cell
    2. If the current cell has any neighbours which have not been visited
      1. Push the current cell to the stack
      2. Choose one of the unvisited neighbours
      3. Remove the wall between the current cell and the chosen cell
      4. Mark the chosen cell as visited and push it to the stack
@staticmethod
def gen_prim( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], lattice_dim: int = 2, accessible_cells: float | None = None, max_tree_depth: float | None = None, do_forks: bool = True, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
@staticmethod
def gen_prim(
	grid_shape: Coord | CoordTup,
	lattice_dim: int = 2,
	accessible_cells: float | None = None,
	max_tree_depth: float | None = None,
	do_forks: bool = True,
	start_coord: Coord | None = None,
) -> LatticeMaze:
	"""(broken!) generate a lattice maze using Prim's algorithm

	This does not actually implement Prim's algorithm: it emits a warning and
	delegates to `gen_dfs` with `randomized_stack=True`, forwarding every other
	argument unchanged. See
	https://github.com/understanding-search/maze-dataset/issues/12
	"""
	warnings.warn(
		"gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
		# attribute the warning to the caller of gen_prim, not to this line
		stacklevel=2,
	)
	return LatticeMazeGenerators.gen_dfs(
		grid_shape=grid_shape,
		lattice_dim=lattice_dim,
		accessible_cells=accessible_cells,
		max_tree_depth=max_tree_depth,
		do_forks=do_forks,
		start_coord=start_coord,
		randomized_stack=True,
	)

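(broken!) generate a lattice maze using Prim's algorithm. This does not actually implement Prim's algorithm: it emits a warning and delegates to gen_dfs with randomized_stack=True.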

@staticmethod
def gen_wilson( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], **kwargs) -> maze_dataset.LatticeMaze:
@staticmethod
def gen_wilson(
	grid_shape: Coord | CoordTup,
	**kwargs,
) -> LatticeMaze:
	"""Generate a lattice maze using Wilson's algorithm.

	# Algorithm
	Wilson's algorithm samples a maze uniformly at random from all spanning-tree
	mazes of the grid using loop-erased random walks: starting from one random
	cell, it repeatedly walks randomly from an unvisited cell until the walk hits
	an already-visited cell, erases any loops the walk made, and carves the
	remaining path into the maze. The generated maze is acyclic and all cells are
	part of a single connected component.
	https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm

	# Arguments
	- `grid_shape: Coord`: the shape of the grid
	- `**kwargs`: not supported; passing any raises an `AssertionError`
	"""
	assert not kwargs, (
		f"gen_wilson does not take any additional arguments, got {kwargs = }"
	)

	shape: Coord = np.array(grid_shape)

	# no connections yet; `in_tree` marks cells already carved into the maze
	connection_list: ConnectionList = np.zeros((2, *shape), dtype=np.bool_)
	in_tree: Bool[np.ndarray, "x y"] = np.zeros(shape, dtype=np.bool_)

	# the tree starts as a single random cell
	seed_cell: Coord = _random_start_coord(shape, None)
	in_tree[seed_cell[0], seed_cell[1]] = True

	while not in_tree.all():
		# begin a walk at a uniformly chosen cell that is not yet in the tree
		candidates: CoordArray = np.column_stack(np.where(~in_tree))
		head: Coord = candidates[np.random.choice(candidates.shape[0])]
		walk: list[Coord] = [head]

		# wander until the walk steps onto the tree, erasing loops as they close
		while not in_tree[head[0], head[1]]:
			# NOTE: assumes get_neighbors_in_bounds always returns at least one cell here
			options: CoordArray = get_neighbors_in_bounds(head, shape)
			step: Coord = options[np.random.choice(options.shape[0])]

			# index of the first earlier occurrence of `step` in the walk, if any
			revisit_idx: int | None = next(
				(idx for idx, cell in enumerate(walk) if np.array_equal(step, cell)),
				None,
			)

			if revisit_idx is None:
				walk.append(step)
				head = step
			else:
				# loop closed: drop everything after the first occurrence of `step`
				del walk[revisit_idx + 1 :]
				head = walk[-1]

		# carve each consecutive pair of the loop-erased walk into the maze
		for here, there in zip(walk[:-1], walk[1:], strict=True):
			delta: Coord = there - here
			axis: int = int(np.argmax(np.abs(delta)))
			# connection_list[axis, r, c] links (r, c) to its down/right neighbour,
			# so index it by whichever cell of the pair is up/left
			anchor: Coord = here if delta.sum() > 0 else there
			connection_list[axis, anchor[0], anchor[1]] = True
			in_tree[here[0], here[1]] = True
		# the last cell of the walk is not marked above: it was already in the
		# tree, which is what ended the walk

	return LatticeMaze(
		connection_list=connection_list,
		generation_meta=dict(
			func_name="gen_wilson",
			grid_shape=shape,
			fully_connected=True,
		),
	)

Generate a lattice maze using Wilson's algorithm.

Algorithm

Wilson's algorithm generates an unbiased (random) maze sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is acyclic and all cells are part of a unique connected space. https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm

@staticmethod
def gen_percolation( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], p: float = 0.4, lattice_dim: int = 2, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
@staticmethod
def gen_percolation(
	grid_shape: Coord | CoordTup,
	p: float = 0.4,
	lattice_dim: int = 2,
	start_coord: Coord | None = None,
) -> LatticeMaze:
	"""generate a lattice maze using simple percolation

	Each entry of the `(lattice_dim, *grid_shape)` connection list is independently
	open with probability `p`; `_fill_edges_with_walls` is then applied to the result.
	The maze is not guaranteed to be connected; `generation_meta["visited_cells"]` is
	set from `gen_connected_component_from(start_coord)`.

	note that p in the range (0.4, 0.7) gives the most interesting mazes

	# Arguments
	- `grid_shape: Coord`: the shape of the grid
	- `p: float`: the probability of each connection being open, must be in `[0, 1]` (default: `0.4`)
	- `lattice_dim: int`: the dimension of the lattice (default: `2`)
	- `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)
	"""
	assert 0 <= p <= 1, f"p must be between 0 and 1, got {p}"
	grid_shape_: Coord = np.array(grid_shape)

	start_coord = _random_start_coord(grid_shape_, start_coord)

	# sample every connection independently: open with probability p
	connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p
	connection_list = _fill_edges_with_walls(connection_list)

	output: LatticeMaze = LatticeMaze(
		connection_list=connection_list,
		generation_meta=dict(
			func_name="gen_percolation",
			grid_shape=grid_shape_,
			percolation_p=p,
			start_coord=start_coord,
		),
	)

	# generation_meta is sometimes None, but not here since we just made it a dict above
	output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
		start_coord,
	)

	return output

generate a lattice maze using simple percolation

note that p in the range (0.4, 0.7) gives the most interesting mazes

Arguments

  • grid_shape: Coord: the shape of the grid
  • lattice_dim: int: the dimension of the lattice (default: 2)
  • p: float: the probability of each connection being open (default: 0.4)
  • start_coord: Coord | None: the starting coordinate for the connected component (default: None will give a random start)
@staticmethod
def gen_dfs_percolation( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], p: float = 0.4, lattice_dim: int = 2, accessible_cells: int | None = None, max_tree_depth: int | None = None, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
@staticmethod
def gen_dfs_percolation(
	grid_shape: Coord | CoordTup,
	p: float = 0.4,
	lattice_dim: int = 2,
	accessible_cells: float | None = None,
	max_tree_depth: float | None = None,
	start_coord: Coord | None = None,
) -> LatticeMaze:
	"""dfs and then percolation (adds cycles)

	First generates a maze with `gen_dfs`, then ORs in an independent percolation
	layer in which each connection is open with probability `p` (after
	`_fill_edges_with_walls` is applied to that layer). The extra connections can
	create cycles.

	# Arguments
	- `grid_shape: Coord`: the shape of the grid
	- `p: float`: the probability of each extra connection being open, must be in `[0, 1]` (default: `0.4`)
	- `lattice_dim: int`: the dimension of the lattice (default: `2`)
	- `accessible_cells: int | float | None`: passed through to `gen_dfs` (default: `None`)
	- `max_tree_depth: int | float | None`: passed through to `gen_dfs` (default: `None`)
	- `start_coord: Coord | None`: start of the DFS, and the coordinate from which
		`generation_meta["visited_cells"]` is recomputed (default: `None` will give a random start)
	"""
	# same check as gen_percolation, done before any random numbers are drawn
	assert 0 <= p <= 1, f"p must be between 0 and 1, got {p}"
	grid_shape_: Coord = np.array(grid_shape)
	# resolve the start coord here so the same one is used for dfs and for the component below
	start_coord = _random_start_coord(grid_shape_, start_coord)

	# generate initial maze via dfs
	maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
		grid_shape=grid_shape_,
		lattice_dim=lattice_dim,
		accessible_cells=accessible_cells,
		max_tree_depth=max_tree_depth,
		start_coord=start_coord,
	)

	# percolate: an independent random layer of connections with the same shape
	connection_list_perc: np.ndarray = (
		np.random.rand(*maze.connection_list.shape) < p
	)
	connection_list_perc = _fill_edges_with_walls(connection_list_perc)

	# written through `__dict__`, presumably because `LatticeMaze` does not allow
	# normal attribute assignment (e.g. a frozen dataclass) -- TODO confirm
	maze.__dict__["connection_list"] = np.logical_or(
		maze.connection_list,
		connection_list_perc,
	)

	# generation_meta was populated as a dict by gen_dfs, so it is not None here
	maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
	maze.generation_meta["percolation_p"] = p  # type: ignore[index]
	# NOTE(review): other keys inherited from gen_dfs (e.g. `fully_connected`,
	# `n_accessible_cells`) still describe the pre-percolation maze; only
	# `visited_cells` is recomputed here
	maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
		start_coord,
	)

	return maze

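dfs and then percolation (adds cycles): generates a maze with gen_dfs, then opens additional random connections, each with probability p, which can introduce cycles.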

GENERATORS_MAP = {'gen_dfs': <function LatticeMazeGenerators.gen_dfs>, 'gen_wilson': <function LatticeMazeGenerators.gen_wilson>, 'gen_percolation': <function LatticeMazeGenerators.gen_percolation>, 'gen_dfs_percolation': <function LatticeMazeGenerators.gen_dfs_percolation>, 'gen_prim': <function LatticeMazeGenerators.gen_prim>}
def get_maze_with_solution( gen_name: str, grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], maze_ctor_kwargs: dict | None = None) -> maze_dataset.SolvedMaze:
def get_maze_with_solution(
	gen_name: str,
	grid_shape: Coord | CoordTup,
	maze_ctor_kwargs: dict | None = None,
) -> SolvedMaze:
	"""helper function to get a maze already with a solution

	# Arguments
	- `gen_name: str`: key into `GENERATORS_MAP` selecting the generation function
	- `grid_shape: Coord | CoordTup`: shape of the grid, passed positionally to the generator
	- `maze_ctor_kwargs: dict | None`: extra keyword arguments forwarded to the generator
		(default: `None`, meaning no extra arguments)

	# Returns
	- `SolvedMaze`: the generated maze, with `maze.generate_random_path()` as its solution

	# Raises
	- `KeyError`: if `gen_name` is not a key of `GENERATORS_MAP`; the message lists the valid names
	"""
	try:
		generator = GENERATORS_MAP[gen_name]
	except KeyError as err:
		msg: str = (
			f"unknown generator {gen_name!r}, expected one of {list(GENERATORS_MAP)}"
		)
		raise KeyError(msg) from err

	if maze_ctor_kwargs is None:
		maze_ctor_kwargs = {}
	# TYPING: error: Too few arguments  [call-arg]
	# not sure why this is happening -- doesnt recognize the kwargs?
	maze: LatticeMaze = generator(grid_shape, **maze_ctor_kwargs)  # type: ignore[call-arg]
	solution: CoordArray = np.array(maze.generate_random_path())
	return SolvedMaze.from_lattice_maze(lattice_maze=maze, solution=solution)

helper function to get a maze already with a solution

numpy_rng = Generator(PCG64) at 0x7D1A85F4DB60