Coverage for maze_dataset/generation/generators.py: 77%
132 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-20 17:51 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-20 17:51 -0600
1"""generation functions have signature `(grid_shape: Coord, **kwargs) -> LatticeMaze` and are methods in `LatticeMazeGenerators`"""
3import random
4import warnings
5from typing import Any, Callable
7import numpy as np
8from jaxtyping import Bool
10from maze_dataset.constants import CoordArray, CoordTup
11from maze_dataset.generation.seed import GLOBAL_SEED
12from maze_dataset.maze import ConnectionList, Coord, LatticeMaze, SolvedMaze
13from maze_dataset.maze.lattice_maze import NEIGHBORS_MASK, _fill_edges_with_walls
# seed python's global `random` module; `gen_dfs` draws its stack/neighbor choices from it
random.seed(GLOBAL_SEED)
# module-level numpy `Generator`, seeded with the same global seed
# NOTE(review): nothing in this file draws from `numpy_rng` -- the generators below call the
# global `np.random.*` functions, which this module does not seed. confirm this is intended.
numpy_rng = np.random.default_rng(GLOBAL_SEED)
19def _random_start_coord(
20 grid_shape: Coord,
21 start_coord: Coord | CoordTup | None,
22) -> Coord:
23 "picking a random start coord within the bounds of `grid_shape` if none is provided"
24 start_coord_: Coord
25 if start_coord is None:
26 start_coord_ = np.random.randint(
27 0, # lower bound
28 np.maximum(grid_shape - 1, 1), # upper bound (at least 1)
29 size=len(grid_shape), # dimensionality
30 )
31 else:
32 start_coord_ = np.array(start_coord)
34 return start_coord_
def get_neighbors_in_bounds(
    coord: Coord,
    grid_shape: Coord,
) -> CoordArray:
    """return the lattice neighbors of `coord` that lie inside a grid of shape `grid_shape`

    candidates are `coord + NEIGHBORS_MASK` (one row per neighbor); a candidate is kept
    only when every component is `>= 0` and strictly less than the matching entry of
    `grid_shape`
    """
    candidates: CoordArray = coord + NEIGHBORS_MASK
    above_lower = np.all(candidates >= 0, axis=1)
    below_upper = np.all(candidates < grid_shape, axis=1)
    return candidates[above_lower & below_upper]
class LatticeMazeGenerators:
    """namespace for lattice maze generation algorithms

    every generator is a `@staticmethod` that takes `grid_shape` as its first argument and
    returns a `LatticeMaze` whose `generation_meta` dict records how it was generated
    """

    @staticmethod
    def gen_dfs(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
        accessible_cells: float | None = None,
        max_tree_depth: float | None = None,
        do_forks: bool = True,
        randomized_stack: bool = False,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using depth first search, iterative

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `lattice_dim: int`: the dimension of the lattice
            (default: `2`)
        - `accessible_cells: int | float | None`: the number of accessible cells in the maze.
            If `None`, defaults to the total number of cells in the grid. If a float, asserts it
            is in `[0, 1]` and treats it as a proportion of **total cells**
            (default: `None`)
        - `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`,
            defaults to `2 * n_total_cells` (twice the number of cells in the grid, large enough
            that the depth limit never constrains generation). If a float, asserts it is in
            `[0, 1]` and treats it as a proportion of the **sum of the grid shape**
            (default: `None`)
        - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have
            no forks and will be a simple hallway.
            (default: `True`)
        - `randomized_stack: bool`: if `True`, pop a uniformly random element of the stack
            instead of the top one (this is what `gen_prim` uses)
            (default: `False`)
        - `start_coord: Coord | None`: the starting coordinate of the generation algorithm.
            If `None`, defaults to a random coordinate.
            (default: `None`)

        # Returns
        - `LatticeMaze` whose `generation_meta` holds `func_name`, `grid_shape`, `start_coord`,
            `n_accessible_cells`, `max_tree_depth`, `fully_connected` and `visited_cells`

        # algorithm
        1. Choose the initial cell, mark it as visited and push it to the stack
        2. While the stack is not empty
            1. Pop a cell from the stack and make it a current cell
            2. If the current cell has any neighbours which have not been visited
                1. Push the current cell to the stack
                2. Choose one of the unvisited neighbours
                3. Remove the wall between the current cell and the chosen cell
                4. Mark the chosen cell as visited and push it to the stack
        """
        # Default values if no constraints have been passed
        grid_shape_: Coord = np.array(grid_shape)
        n_total_cells: int = int(np.prod(grid_shape_))

        n_accessible_cells: int
        if accessible_cells is None:
            n_accessible_cells = n_total_cells
        elif isinstance(accessible_cells, float):
            assert 0 <= accessible_cells <= 1, (
                f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
            )
            n_accessible_cells = int(accessible_cells * n_total_cells)
        else:
            # accept numpy integer scalars as well as python ints
            assert isinstance(accessible_cells, (int, np.integer)), (
                f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells!r}"
            )
            n_accessible_cells = int(accessible_cells)

        if max_tree_depth is None:
            # depth is counted outward from the start coord in two directions, so we divide by
            # two in the neighbor check below and multiply by two here
            max_tree_depth = 2 * n_total_cells
        elif isinstance(max_tree_depth, float):
            assert 0 <= max_tree_depth <= 1, (
                f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
            )
            max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))

        # choose a random start coord (or use the provided one)
        start_coord = _random_start_coord(grid_shape_, start_coord)

        # initialize the maze with no connections
        connection_list: ConnectionList = np.zeros(
            (lattice_dim, grid_shape_[0], grid_shape_[1]),
            dtype=np.bool_,
        )

        # initialize the visited set and the stack with the start coord
        visited_cells: set[tuple[int, int]] = set()
        visited_cells.add(tuple(start_coord))
        stack: list[Coord] = [start_coord]

        current_tree_depth: int = 1

        # loop until the stack is empty or the accessible-cell budget is reached
        while stack and (len(visited_cells) < n_accessible_cells):
            current_coord: Coord
            if randomized_stack:
                # we dont care about S311 because this isnt security related
                current_coord = stack.pop(random.randint(0, len(stack) - 1))  # noqa: S311
            else:
                current_coord = stack.pop()

            # neighbors that are inside the grid and not yet visited, paired with their offset
            unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
                (neighbor, delta)
                for neighbor, delta in zip(
                    current_coord + NEIGHBORS_MASK,
                    NEIGHBORS_MASK,
                    strict=False,
                )
                if (
                    (tuple(neighbor) not in visited_cells)
                    and (0 <= neighbor[0] < grid_shape_[0])
                    and (0 <= neighbor[1] < grid_shape_[1])
                )
            ]

            # stop extending once max_tree_depth / 2 is reached (halved because the tree can
            # grow from the start coord in more than one direction)
            if unvisited_neighbors_deltas and (
                current_tree_depth <= max_tree_depth / 2
            ):
                # without forks, the current coord is never revisited, giving a single hallway;
                # with exactly one unvisited neighbor there is nothing left to fork to anyway
                if do_forks and (len(unvisited_neighbors_deltas) > 1):
                    stack.append(current_coord)

                # we dont care about S311 because this isn't security related
                chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)  # noqa: S311

                # connection_list[dim, r, c] stores the edge from (r, c) toward +dim, so a
                # negative offset is stored at the neighbor rather than at the current coord
                dim: int = int(np.argmax(np.abs(delta)))
                clist_node: Coord = (
                    current_coord if (delta.sum() > 0) else chosen_neighbor
                )
                connection_list[dim, clist_node[0], clist_node[1]] = True

                visited_cells.add(tuple(chosen_neighbor))
                stack.append(chosen_neighbor)

                current_tree_depth += 1
            else:
                current_tree_depth -= 1

        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_dfs",
                grid_shape=grid_shape_,
                start_coord=start_coord,
                n_accessible_cells=int(n_accessible_cells),
                max_tree_depth=int(max_tree_depth),
                # compare against the TOTAL cell count, not `n_accessible_cells`: a maze that
                # merely exhausted its accessible-cell budget is not fully connected, and
                # treating it as such breaks maze solving downstream
                fully_connected=bool(len(visited_cells) == n_total_cells),
                visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
            ),
        )

    @staticmethod
    def gen_prim(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
        accessible_cells: float | None = None,
        max_tree_depth: float | None = None,
        do_forks: bool = True,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """(broken!) generate a lattice maze using Prim's algorithm

        currently this just calls `gen_dfs` with `randomized_stack=True` and emits a warning;
        all arguments are forwarded unchanged
        """
        warnings.warn(
            "gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
            stacklevel=2,  # attribute the warning to the caller, not to this module
        )
        return LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            do_forks=do_forks,
            start_coord=start_coord,
            randomized_stack=True,
        )

    @staticmethod
    def gen_wilson(
        grid_shape: Coord | CoordTup,
        **kwargs,
    ) -> LatticeMaze:
        """Generate a lattice maze using Wilson's algorithm.

        # Algorithm
        Wilson's algorithm generates an unbiased (random) maze
        sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
        acyclic and all cells are part of a unique connected space.
        https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm

        # Raises
        - `AssertionError` if any keyword arguments are passed (none are accepted)
        """
        assert not kwargs, (
            f"gen_wilson does not take any additional arguments, got {kwargs = }"
        )

        grid_shape_: Coord = np.array(grid_shape)

        # Initialize grid and visited cells
        connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
        visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)

        # Choose a random cell and mark it as visited
        start_coord: Coord = _random_start_coord(grid_shape_, None)
        visited[start_coord[0], start_coord[1]] = True
        del start_coord

        while not visited.all():
            # Perform loop-erased random walk from a random unvisited cell
            unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
            walk_start: Coord = unvisited_coords[
                np.random.choice(unvisited_coords.shape[0])
            ]

            path: list[Coord] = [walk_start]
            current: Coord = walk_start

            # exit the loop once the current path hits a visited cell
            while not visited[current[0], current[1]]:
                # pick a random in-bounds neighbor (the outer loop only runs when an unvisited
                # cell exists, so the grid has more than one cell here)
                neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_)
                next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]

                # Check for loop
                loop_exit: int | None = None
                for i, p in enumerate(path):
                    if np.array_equal(next_cell, p):
                        loop_exit = i
                        break

                if loop_exit is not None:
                    # erase the loop: keep everything up to and including the revisited cell
                    path = path[: loop_exit + 1]
                    current = path[-1]
                else:
                    path.append(next_cell)
                    current = next_cell

            # Add the path to the maze
            for i in range(len(path) - 1):
                c_1: Coord = path[i]
                c_2: Coord = path[i + 1]

                # find the dimension of the connection
                delta: Coord = c_2 - c_1
                dim: int = int(np.argmax(np.abs(delta)))

                # a positive step is stored at c_1, a negative one at c_2 (down/right of c_2)
                clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
                connection_list[dim, clist_node[0], clist_node[1]] = True
                visited[c_1[0], c_1[1]] = True
                # c_2 is not marked: the final c_2 is the already-visited cell that ended the walk

        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_wilson",
                grid_shape=grid_shape_,
                fully_connected=True,
            ),
        )

    @staticmethod
    def gen_percolation(
        grid_shape: Coord | CoordTup,
        p: float = 0.4,
        lattice_dim: int = 2,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using simple percolation

        note that p in the range (0.4, 0.7) gives the most interesting mazes

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `p: float`: the probability of each connection (edge) in the lattice being open,
            drawn independently per edge (default: `0.4`)
        - `lattice_dim: int`: the dimension of the lattice (default: `2`)
        - `start_coord: Coord | None`: the starting coordinate for the connected component
            (default: `None` will give a random start)

        # Returns
        - `LatticeMaze` whose `generation_meta["visited_cells"]` is the connected component
            reachable from `start_coord`
        """
        assert p >= 0 and p <= 1, f"p must be between 0 and 1, got {p}"  # noqa: PT018
        grid_shape_: Coord = np.array(grid_shape)

        start_coord = _random_start_coord(grid_shape_, start_coord)

        # open each edge independently with probability p
        connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p

        connection_list = _fill_edges_with_walls(connection_list)

        output: LatticeMaze = LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_percolation",
                grid_shape=grid_shape_,
                percolation_p=p,
                start_coord=start_coord,
            ),
        )

        # generation_meta is sometimes None, but not here since we just made it a dict above
        output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
            start_coord,
        )

        return output

    @staticmethod
    def gen_dfs_percolation(
        grid_shape: Coord | CoordTup,
        p: float = 0.4,
        lattice_dim: int = 2,
        accessible_cells: float | None = None,
        max_tree_depth: float | None = None,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """dfs and then percolation (adds cycles)

        generates a maze with `gen_dfs` (same `accessible_cells` / `max_tree_depth` semantics),
        then ORs in an independent percolation layer in which each edge is open with
        probability `p`. `generation_meta["visited_cells"]` is recomputed as the component
        reachable from `start_coord` after percolation.
        """
        grid_shape_: Coord = np.array(grid_shape)
        start_coord = _random_start_coord(grid_shape_, start_coord)

        # generate initial maze via dfs
        maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape_,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            start_coord=start_coord,
        )

        # percolate
        connection_list_perc: np.ndarray = (
            np.random.rand(*maze.connection_list.shape) < p
        )
        connection_list_perc = _fill_edges_with_walls(connection_list_perc)

        # write through __dict__ to replace the connection list on the existing maze object
        maze.__dict__["connection_list"] = np.logical_or(
            maze.connection_list,
            connection_list_perc,
        )

        # generation_meta is sometimes None, but not here since gen_dfs always sets it
        maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
        maze.generation_meta["percolation_p"] = p  # type: ignore[index]
        maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
            start_coord,
        )

        return maze
# cant automatically populate this because it messes with pickling :(
# the value type is `Callable[..., LatticeMaze]`: every generator takes `grid_shape` first but
# the generators accept different keyword arguments (and `gen_wilson` takes `**kwargs`), so a
# fixed two-parameter `Callable[[Coord | CoordTup, Any], ...]` signature misdescribed all of them
GENERATORS_MAP: dict[str, Callable[..., LatticeMaze]] = {
    "gen_dfs": LatticeMazeGenerators.gen_dfs,
    "gen_wilson": LatticeMazeGenerators.gen_wilson,
    "gen_percolation": LatticeMazeGenerators.gen_percolation,
    "gen_dfs_percolation": LatticeMazeGenerators.gen_dfs_percolation,
    "gen_prim": LatticeMazeGenerators.gen_prim,
}
"mapping of generator names to generator functions, useful for loading `MazeDatasetConfig`"
418_GENERATORS_PERCOLATED: list[str] = [
419 "gen_percolation",
420 "gen_dfs_percolation",
421]
422"""list of generator names that generate percolated mazes
423we use this to figure out the expected success rate, since depending on the endpoint kwargs this might fail
424this variable is primarily used in `MazeDatasetConfig._to_ps_array` and `MazeDatasetConfig._from_ps_array`
425"""
def get_maze_with_solution(
    gen_name: str,
    grid_shape: Coord | CoordTup,
    maze_ctor_kwargs: dict | None = None,
) -> SolvedMaze:
    """helper: generate a maze with a named generator and attach a random path as its solution

    # Parameters:
     - `gen_name : str`
        key into `GENERATORS_MAP` (raises `KeyError` if unknown)
     - `grid_shape : Coord | CoordTup`
        passed positionally to the generator
     - `maze_ctor_kwargs : dict | None`
        extra keyword arguments forwarded to the generator (default: none)

    # Returns:
     - `SolvedMaze`
        built from the generated maze, using `maze.generate_random_path()` as the solution
    """
    generator = GENERATORS_MAP[gen_name]
    ctor_kwargs: dict = {} if maze_ctor_kwargs is None else maze_ctor_kwargs
    # TYPING: mypy reported "Too few arguments [call-arg]" here, apparently because the
    # declared value type of GENERATORS_MAP does not account for the keyword arguments
    lattice_maze: LatticeMaze = generator(grid_shape, **ctor_kwargs)  # type: ignore[call-arg]
    path: CoordArray = np.array(lattice_maze.generate_random_path())
    return SolvedMaze.from_lattice_maze(lattice_maze=lattice_maze, solution=path)