maze_dataset.generation.generators
generation functions have signature (grid_shape: Coord, **kwargs) -> LatticeMaze
and are methods in LatticeMazeGenerators
"""generation functions have signature `(grid_shape: Coord, **kwargs) -> LatticeMaze` and are methods in `LatticeMazeGenerators`"""

import random
import warnings
from typing import Any, Callable

import numpy as np
from jaxtyping import Bool

from maze_dataset.constants import CoordArray, CoordTup
from maze_dataset.generation.seed import GLOBAL_SEED
from maze_dataset.maze import ConnectionList, Coord, LatticeMaze, SolvedMaze
from maze_dataset.maze.lattice_maze import NEIGHBORS_MASK, _fill_edges_with_walls

numpy_rng = np.random.default_rng(GLOBAL_SEED)
random.seed(GLOBAL_SEED)


def _random_start_coord(
    grid_shape: Coord,
    start_coord: Coord | CoordTup | None,
) -> Coord:
    """picking a random start coord within the bounds of `grid_shape` if none is provided

    # Arguments
    - `grid_shape: Coord`: the shape of the grid, as an array
    - `start_coord: Coord | CoordTup | None`: if not `None`, it is converted to an array and returned as-is

    # Returns
    - `Coord`: the provided start coord as an array, or a randomly drawn one
    """
    start_coord_: Coord
    if start_coord is None:
        # NOTE(review): `np.random.randint` excludes its upper bound, so passing
        # `grid_shape - 1` means the last index along each axis is never drawn.
        # left untouched because changing it would alter seeded output -- confirm intent.
        # this also draws from the global `np.random` state, not from `numpy_rng`
        start_coord_ = np.random.randint(
            0,  # lower bound
            np.maximum(grid_shape - 1, 1),  # upper bound (at least 1)
            size=len(grid_shape),  # dimensionality
        )
    else:
        start_coord_ = np.array(start_coord)

    return start_coord_


def get_neighbors_in_bounds(
    coord: Coord,
    grid_shape: Coord,
) -> CoordArray:
    "get all neighbors of a coordinate that are within the bounds of the grid"
    # get all neighbors
    neighbors: CoordArray = coord + NEIGHBORS_MASK

    # filter neighbors by being within grid bounds
    neighbors_in_bounds: CoordArray = neighbors[
        (neighbors >= 0).all(axis=1) & (neighbors < grid_shape).all(axis=1)
    ]

    return neighbors_in_bounds


class LatticeMazeGenerators:
    """namespace for lattice maze generation algorithms"""

    @staticmethod
    def gen_dfs(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
        accessible_cells: float | None = None,
        max_tree_depth: float | None = None,
        do_forks: bool = True,
        randomized_stack: bool = False,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using depth first search, iterative

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `lattice_dim: int`: the dimension of the lattice
            (default: `2`)
        - `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is in `[0, 1]` and treats it as a proportion of **total cells**
            (default: `None`)
        - `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_total_cells` (twice the total number of cells in the grid). if a float, asserts it is in `[0, 1]` and treats it as a proportion of the **sum of the grid shape**
            (default: `None`)
        - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will be have no forks and will be a simple hallway.
            (default: `True`)
        - `randomized_stack: bool`: if `True`, pop a uniformly random element of the stack instead of the most recently pushed one
            (default: `False`)
        - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.

        # Returns
        - `LatticeMaze`: the maze, with `generation_meta` holding `func_name`, `grid_shape`, `start_coord`,
            `n_accessible_cells`, `max_tree_depth`, `fully_connected` and `visited_cells`

        # algorithm
        1. Choose the initial cell, mark it as visited and push it to the stack
        2. While the stack is not empty
            1. Pop a cell from the stack and make it a current cell
            2. If the current cell has any neighbours which have not been visited
                1. Push the current cell to the stack
                2. Choose one of the unvisited neighbours
                3. Remove the wall between the current cell and the chosen cell
                4. Mark the chosen cell as visited and push it to the stack
        """
        # Default values if no constraints have been passed
        grid_shape_: Coord = np.array(grid_shape)
        n_total_cells: int = int(np.prod(grid_shape_))

        n_accessible_cells: int
        if accessible_cells is None:
            n_accessible_cells = n_total_cells
        elif isinstance(accessible_cells, float):
            # enforce both ends of the range the message promises
            assert 0 <= accessible_cells <= 1, (
                f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
            )

            n_accessible_cells = int(accessible_cells * n_total_cells)
        else:
            assert isinstance(accessible_cells, int)
            n_accessible_cells = accessible_cells

        if max_tree_depth is None:
            # We define max tree depth counting from the start coord in two directions.
            # Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.
            max_tree_depth = 2 * n_total_cells
        elif isinstance(max_tree_depth, float):
            # enforce both ends of the range the message promises
            assert 0 <= max_tree_depth <= 1, (
                f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
            )

            max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))

        # choose a random start coord
        start_coord = _random_start_coord(grid_shape_, start_coord)

        # initialize the maze with no connections
        connection_list: ConnectionList = np.zeros(
            (lattice_dim, grid_shape_[0], grid_shape_[1]),
            dtype=np.bool_,
        )

        # initialize the stack with the start coord, which counts as visited from the beginning
        visited_cells: set[tuple[int, int]] = set()
        visited_cells.add(tuple(start_coord))
        stack: list[Coord] = [start_coord]

        # initialize tree_depth_counter
        current_tree_depth: int = 1

        # loop until the stack is empty or n_accessible_cells is reached
        while stack and (len(visited_cells) < n_accessible_cells):
            # get the current coord from the stack
            current_coord: Coord
            if randomized_stack:
                # we dont care about S311 because this isnt security related
                current_coord = stack.pop(random.randint(0, len(stack) - 1))  # noqa: S311
            else:
                current_coord = stack.pop()

            # filter neighbors by being within grid bounds and being unvisited
            unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
                (neighbor, delta)
                for neighbor, delta in zip(
                    current_coord + NEIGHBORS_MASK,
                    NEIGHBORS_MASK,
                    strict=False,
                )
                if (
                    (tuple(neighbor) not in visited_cells)
                    and (0 <= neighbor[0] < grid_shape_[0])
                    and (0 <= neighbor[1] < grid_shape_[1])
                )
            ]

            # don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)
            if unvisited_neighbors_deltas and (
                current_tree_depth <= max_tree_depth / 2
            ):
                # if we want a maze without forks, simply don't add the current coord back to the stack
                if do_forks and (len(unvisited_neighbors_deltas) > 1):
                    stack.append(current_coord)

                # choose one of the unvisited neighbors
                # we dont care about S311 because this isn't security related
                chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)  # noqa: S311

                # add connection
                dim: int = int(np.argmax(np.abs(delta)))
                # if positive, down/right from current coord
                # if negative, up/left from current coord (down/right from neighbor)
                clist_node: Coord = (
                    current_coord if (delta.sum() > 0) else chosen_neighbor
                )
                connection_list[dim, clist_node[0], clist_node[1]] = True

                # add to visited cells and stack
                visited_cells.add(tuple(chosen_neighbor))
                stack.append(chosen_neighbor)

                # Update current tree depth
                current_tree_depth += 1
            else:
                current_tree_depth -= 1

        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_dfs",
                grid_shape=grid_shape_,
                start_coord=start_coord,
                n_accessible_cells=int(n_accessible_cells),
                max_tree_depth=int(max_tree_depth),
                # oh my god this took so long to track down. its almost 5am and I've spent like 2 hours on this bug
                # it was checking that len(visited_cells) == n_accessible_cells, but this means that the maze is
                # treated as fully connected even when it is most certainly not, causing solving the maze to break
                fully_connected=bool(len(visited_cells) == n_total_cells),
                visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
            ),
        )

    @staticmethod
    def gen_prim(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
        accessible_cells: float | None = None,
        max_tree_depth: float | None = None,
        do_forks: bool = True,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """(broken!) generate a lattice maze using Prim's algorithm

        emits a warning, then delegates to `gen_dfs` with `randomized_stack=True`;
        all other arguments are forwarded to `gen_dfs` unchanged
        """
        warnings.warn(
            "gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
        )
        return LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            do_forks=do_forks,
            start_coord=start_coord,
            randomized_stack=True,
        )

    @staticmethod
    def gen_wilson(
        grid_shape: Coord | CoordTup,
        **kwargs,
    ) -> LatticeMaze:
        """Generate a lattice maze using Wilson's algorithm.

        # Algorithm
        Wilson's algorithm generates an unbiased (random) maze
        sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
        acyclic and all cells are part of a unique connected space.
        https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `**kwargs`: must be empty, asserted below
        """
        assert not kwargs, (
            f"gen_wilson does not take any additional arguments, got {kwargs = }"
        )

        grid_shape_: Coord = np.array(grid_shape)

        # Initialize grid and visited cells
        connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
        visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)

        # Choose a random cell and mark it as visited
        start_coord: Coord = _random_start_coord(grid_shape_, None)
        visited[start_coord[0], start_coord[1]] = True
        del start_coord

        while not visited.all():
            # Perform loop-erased random walk from another random cell

            # Choose walk_start only from unvisited cells
            unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
            walk_start: Coord = unvisited_coords[
                np.random.choice(unvisited_coords.shape[0])
            ]

            # Perform the random walk
            path: list[Coord] = [walk_start]
            current: Coord = walk_start

            # exit the loop once the current path hits a visited cell
            while not visited[current[0], current[1]]:
                # find a valid neighbor (one always exists on a lattice)
                neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_)
                next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]

                # Check for loop
                loop_exit: int | None = None
                for i, p in enumerate(path):
                    if np.array_equal(next_cell, p):
                        loop_exit = i
                        break

                # erase the loop, or continue the walk
                if loop_exit is not None:
                    # keep the path up to and including the revisited cell,
                    # dropping only the loop that was walked after it
                    path = path[: loop_exit + 1]
                    # reset current cell to end of path
                    current = path[-1]
                else:
                    path.append(next_cell)
                    current = next_cell

            # Add the path to the maze
            for i in range(len(path) - 1):
                c_1: Coord = path[i]
                c_2: Coord = path[i + 1]

                # find the dimension of the connection
                delta: Coord = c_2 - c_1
                dim: int = int(np.argmax(np.abs(delta)))

                # if positive, down/right from current coord
                # if negative, up/left from current coord (down/right from neighbor)
                clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
                connection_list[dim, clist_node[0], clist_node[1]] = True
                visited[c_1[0], c_1[1]] = True
                # we dont add c_2 because the last c_2 will have already been visited

        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_wilson",
                grid_shape=grid_shape_,
                fully_connected=True,
            ),
        )

    @staticmethod
    def gen_percolation(
        grid_shape: Coord | CoordTup,
        p: float = 0.4,
        lattice_dim: int = 2,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using simple percolation

        note that p in the range (0.4, 0.7) gives the most interesting mazes

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `lattice_dim: int`: the dimension of the lattice (default: `2`)
        - `p: float`: the probability of each connection being open, must be in `[0, 1]` (default: `0.4`)
        - `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)
        """
        assert p >= 0 and p <= 1, f"p must be between 0 and 1, got {p}"  # noqa: PT018
        grid_shape_: Coord = np.array(grid_shape)

        start_coord = _random_start_coord(grid_shape_, start_coord)

        # every entry of the connection list is independently True with probability p
        connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p

        # presumably closes the connections that would leave the grid -- see `_fill_edges_with_walls`
        connection_list = _fill_edges_with_walls(connection_list)

        output: LatticeMaze = LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_percolation",
                grid_shape=grid_shape_,
                percolation_p=p,
                start_coord=start_coord,
            ),
        )

        # generation_meta is sometimes None, but not here since we just made it a dict above
        output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
            start_coord,
        )

        return output

    @staticmethod
    def gen_dfs_percolation(
        grid_shape: Coord | CoordTup,
        p: float = 0.4,
        lattice_dim: int = 2,
        accessible_cells: float | None = None,
        max_tree_depth: float | None = None,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """dfs and then percolation (adds cycles)

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `p: float`: the probability of each extra connection being opened on top of the dfs maze (default: `0.4`)
        - `lattice_dim: int`: forwarded to `gen_dfs` (default: `2`)
        - `accessible_cells: int | float | None`: forwarded to `gen_dfs` (default: `None`)
        - `max_tree_depth: int | float | None`: forwarded to `gen_dfs` (default: `None`)
        - `start_coord: Coord | None`: start of the dfs and of the connected component stored in `visited_cells`
            (default: `None` will give a random start)
        """
        grid_shape_: Coord = np.array(grid_shape)
        start_coord = _random_start_coord(grid_shape_, start_coord)

        # generate initial maze via dfs
        maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape_,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            start_coord=start_coord,
        )

        # percolate
        connection_list_perc: np.ndarray = (
            np.random.rand(*maze.connection_list.shape) < p
        )
        connection_list_perc = _fill_edges_with_walls(connection_list_perc)

        # written through `__dict__` rather than by attribute assignment
        # (presumably because `LatticeMaze` does not allow reassigning the field -- TODO confirm)
        maze.__dict__["connection_list"] = np.logical_or(
            maze.connection_list,
            connection_list_perc,
        )

        # generation_meta is sometimes None, but not here since we just made it a dict above
        maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
        maze.generation_meta["percolation_p"] = p  # type: ignore[index]
        maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
            start_coord,
        )

        return maze


# cant automatically populate this because it messes with pickling :(
GENERATORS_MAP: dict[str, Callable[[Coord | CoordTup, Any], "LatticeMaze"]] = {
    "gen_dfs": LatticeMazeGenerators.gen_dfs,
    # TYPING: error: Dict entry 1 has incompatible type
    # "str": "Callable[[ndarray[Any, Any] | tuple[int, int], KwArg(Any)], LatticeMaze]";
    # expected "str": "Callable[[ndarray[Any, Any] | tuple[int, int], Any], LatticeMaze]"  [dict-item]
    # gen_wilson takes no kwargs and we check that the kwargs are empty
    # but mypy doesnt like this, `Any` != `KwArg(Any)`
    "gen_wilson": LatticeMazeGenerators.gen_wilson,  # type: ignore[dict-item]
    "gen_percolation": LatticeMazeGenerators.gen_percolation,
    "gen_dfs_percolation": LatticeMazeGenerators.gen_dfs_percolation,
    "gen_prim": LatticeMazeGenerators.gen_prim,
}
"mapping of generator names to generator functions, useful for loading `MazeDatasetConfig`"

_GENERATORS_PERCOLATED: list[str] = [
    "gen_percolation",
    "gen_dfs_percolation",
]
"""list of generator names that generate percolated mazes
we use this to figure out the expected success rate, since depending on the endpoint kwargs this might fail
this variable is primarily used in `MazeDatasetConfig._to_ps_array` and `MazeDatasetConfig._from_ps_array`
"""


def get_maze_with_solution(
    gen_name: str,
    grid_shape: Coord | CoordTup,
    maze_ctor_kwargs: dict | None = None,
) -> SolvedMaze:
    """helper function to get a maze already with a solution

    # Arguments
    - `gen_name: str`: key into `GENERATORS_MAP` selecting the generator
    - `grid_shape: Coord | CoordTup`: passed to the generator as its first argument
    - `maze_ctor_kwargs: dict | None`: extra keyword arguments for the generator (default: `None`, meaning none)

    # Returns
    - `SolvedMaze`: the generated maze together with the path from `maze.generate_random_path()`
    """
    if maze_ctor_kwargs is None:
        maze_ctor_kwargs = {}
    # TYPING: error: Too few arguments  [call-arg]
    # not sure why this is happening -- doesnt recognize the kwargs?
    maze: LatticeMaze = GENERATORS_MAP[gen_name](grid_shape, **maze_ctor_kwargs)  # type: ignore[call-arg]
    solution: CoordArray = np.array(maze.generate_random_path())
    return SolvedMaze.from_lattice_maze(lattice_maze=maze, solution=solution)
def get_neighbors_in_bounds(
    coord: Coord,
    grid_shape: Coord,
) -> CoordArray:
    """get all neighbors of a coordinate that are within the bounds of the grid

    # Arguments
    - `coord: Coord`: the coordinate whose neighbors are wanted
    - `grid_shape: Coord`: the shape of the grid, used as the exclusive upper bound per axis

    # Returns
    - `CoordArray`: the rows of `coord + NEIGHBORS_MASK` whose entries are all `>= 0` and `< grid_shape`
    """
    candidates: CoordArray = coord + NEIGHBORS_MASK

    # a candidate survives only if every one of its components lies inside the grid
    inside_grid = np.all((candidates >= 0) & (candidates < grid_shape), axis=1)

    return candidates[inside_grid]
get all neighbors of a coordinate that are within the bounds of the grid
class LatticeMazeGenerators:
    """namespace for lattice maze generation algorithms"""

    @staticmethod
    def gen_dfs(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
        accessible_cells: float | None = None,
        max_tree_depth: float | None = None,
        do_forks: bool = True,
        randomized_stack: bool = False,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using depth first search, iterative

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `lattice_dim: int`: the dimension of the lattice
            (default: `2`)
        - `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is in `[0, 1]` and treats it as a proportion of **total cells**
            (default: `None`)
        - `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_total_cells` (twice the total number of cells in the grid). if a float, asserts it is in `[0, 1]` and treats it as a proportion of the **sum of the grid shape**
            (default: `None`)
        - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will be have no forks and will be a simple hallway.
            (default: `True`)
        - `randomized_stack: bool`: if `True`, pop a uniformly random element of the stack instead of the most recently pushed one
            (default: `False`)
        - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.

        # Returns
        - `LatticeMaze`: the maze, with `generation_meta` holding `func_name`, `grid_shape`, `start_coord`,
            `n_accessible_cells`, `max_tree_depth`, `fully_connected` and `visited_cells`

        # algorithm
        1. Choose the initial cell, mark it as visited and push it to the stack
        2. While the stack is not empty
            1. Pop a cell from the stack and make it a current cell
            2. If the current cell has any neighbours which have not been visited
                1. Push the current cell to the stack
                2. Choose one of the unvisited neighbours
                3. Remove the wall between the current cell and the chosen cell
                4. Mark the chosen cell as visited and push it to the stack
        """
        # Default values if no constraints have been passed
        grid_shape_: Coord = np.array(grid_shape)
        n_total_cells: int = int(np.prod(grid_shape_))

        n_accessible_cells: int
        if accessible_cells is None:
            n_accessible_cells = n_total_cells
        elif isinstance(accessible_cells, float):
            # enforce both ends of the range the message promises
            assert 0 <= accessible_cells <= 1, (
                f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
            )

            n_accessible_cells = int(accessible_cells * n_total_cells)
        else:
            assert isinstance(accessible_cells, int)
            n_accessible_cells = accessible_cells

        if max_tree_depth is None:
            # We define max tree depth counting from the start coord in two directions.
            # Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.
            max_tree_depth = 2 * n_total_cells
        elif isinstance(max_tree_depth, float):
            # enforce both ends of the range the message promises
            assert 0 <= max_tree_depth <= 1, (
                f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
            )

            max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))

        # choose a random start coord
        start_coord = _random_start_coord(grid_shape_, start_coord)

        # initialize the maze with no connections
        connection_list: ConnectionList = np.zeros(
            (lattice_dim, grid_shape_[0], grid_shape_[1]),
            dtype=np.bool_,
        )

        # initialize the stack with the start coord, which counts as visited from the beginning
        visited_cells: set[tuple[int, int]] = set()
        visited_cells.add(tuple(start_coord))
        stack: list[Coord] = [start_coord]

        # initialize tree_depth_counter
        current_tree_depth: int = 1

        # loop until the stack is empty or n_accessible_cells is reached
        while stack and (len(visited_cells) < n_accessible_cells):
            # get the current coord from the stack
            current_coord: Coord
            if randomized_stack:
                # we dont care about S311 because this isnt security related
                current_coord = stack.pop(random.randint(0, len(stack) - 1))  # noqa: S311
            else:
                current_coord = stack.pop()

            # filter neighbors by being within grid bounds and being unvisited
            unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
                (neighbor, delta)
                for neighbor, delta in zip(
                    current_coord + NEIGHBORS_MASK,
                    NEIGHBORS_MASK,
                    strict=False,
                )
                if (
                    (tuple(neighbor) not in visited_cells)
                    and (0 <= neighbor[0] < grid_shape_[0])
                    and (0 <= neighbor[1] < grid_shape_[1])
                )
            ]

            # don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)
            if unvisited_neighbors_deltas and (
                current_tree_depth <= max_tree_depth / 2
            ):
                # if we want a maze without forks, simply don't add the current coord back to the stack
                if do_forks and (len(unvisited_neighbors_deltas) > 1):
                    stack.append(current_coord)

                # choose one of the unvisited neighbors
                # we dont care about S311 because this isn't security related
                chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)  # noqa: S311

                # add connection
                dim: int = int(np.argmax(np.abs(delta)))
                # if positive, down/right from current coord
                # if negative, up/left from current coord (down/right from neighbor)
                clist_node: Coord = (
                    current_coord if (delta.sum() > 0) else chosen_neighbor
                )
                connection_list[dim, clist_node[0], clist_node[1]] = True

                # add to visited cells and stack
                visited_cells.add(tuple(chosen_neighbor))
                stack.append(chosen_neighbor)

                # Update current tree depth
                current_tree_depth += 1
            else:
                current_tree_depth -= 1

        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_dfs",
                grid_shape=grid_shape_,
                start_coord=start_coord,
                n_accessible_cells=int(n_accessible_cells),
                max_tree_depth=int(max_tree_depth),
                # oh my god this took so long to track down. its almost 5am and I've spent like 2 hours on this bug
                # it was checking that len(visited_cells) == n_accessible_cells, but this means that the maze is
                # treated as fully connected even when it is most certainly not, causing solving the maze to break
                fully_connected=bool(len(visited_cells) == n_total_cells),
                visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
            ),
        )

    @staticmethod
    def gen_prim(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
        accessible_cells: float | None = None,
        max_tree_depth: float | None = None,
        do_forks: bool = True,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """(broken!) generate a lattice maze using Prim's algorithm

        emits a warning, then delegates to `gen_dfs` with `randomized_stack=True`;
        all other arguments are forwarded to `gen_dfs` unchanged
        """
        warnings.warn(
            "gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
        )
        return LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            do_forks=do_forks,
            start_coord=start_coord,
            randomized_stack=True,
        )

    @staticmethod
    def gen_wilson(
        grid_shape: Coord | CoordTup,
        **kwargs,
    ) -> LatticeMaze:
        """Generate a lattice maze using Wilson's algorithm.

        # Algorithm
        Wilson's algorithm generates an unbiased (random) maze
        sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
        acyclic and all cells are part of a unique connected space.
        https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `**kwargs`: must be empty, asserted below
        """
        assert not kwargs, (
            f"gen_wilson does not take any additional arguments, got {kwargs = }"
        )

        grid_shape_: Coord = np.array(grid_shape)

        # Initialize grid and visited cells
        connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
        visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)

        # Choose a random cell and mark it as visited
        start_coord: Coord = _random_start_coord(grid_shape_, None)
        visited[start_coord[0], start_coord[1]] = True
        del start_coord

        while not visited.all():
            # Perform loop-erased random walk from another random cell

            # Choose walk_start only from unvisited cells
            unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
            walk_start: Coord = unvisited_coords[
                np.random.choice(unvisited_coords.shape[0])
            ]

            # Perform the random walk
            path: list[Coord] = [walk_start]
            current: Coord = walk_start

            # exit the loop once the current path hits a visited cell
            while not visited[current[0], current[1]]:
                # find a valid neighbor (one always exists on a lattice)
                neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_)
                next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]

                # Check for loop
                loop_exit: int | None = None
                for i, p in enumerate(path):
                    if np.array_equal(next_cell, p):
                        loop_exit = i
                        break

                # erase the loop, or continue the walk
                if loop_exit is not None:
                    # keep the path up to and including the revisited cell,
                    # dropping only the loop that was walked after it
                    path = path[: loop_exit + 1]
                    # reset current cell to end of path
                    current = path[-1]
                else:
                    path.append(next_cell)
                    current = next_cell

            # Add the path to the maze
            for i in range(len(path) - 1):
                c_1: Coord = path[i]
                c_2: Coord = path[i + 1]

                # find the dimension of the connection
                delta: Coord = c_2 - c_1
                dim: int = int(np.argmax(np.abs(delta)))

                # if positive, down/right from current coord
                # if negative, up/left from current coord (down/right from neighbor)
                clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
                connection_list[dim, clist_node[0], clist_node[1]] = True
                visited[c_1[0], c_1[1]] = True
                # we dont add c_2 because the last c_2 will have already been visited

        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_wilson",
                grid_shape=grid_shape_,
                fully_connected=True,
            ),
        )

    @staticmethod
    def gen_percolation(
        grid_shape: Coord | CoordTup,
        p: float = 0.4,
        lattice_dim: int = 2,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using simple percolation

        note that p in the range (0.4, 0.7) gives the most interesting mazes

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `lattice_dim: int`: the dimension of the lattice (default: `2`)
        - `p: float`: the probability of each connection being open, must be in `[0, 1]` (default: `0.4`)
        - `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)
        """
        assert p >= 0 and p <= 1, f"p must be between 0 and 1, got {p}"  # noqa: PT018
        grid_shape_: Coord = np.array(grid_shape)

        start_coord = _random_start_coord(grid_shape_, start_coord)

        # every entry of the connection list is independently True with probability p
        connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p

        # presumably closes the connections that would leave the grid -- see `_fill_edges_with_walls`
        connection_list = _fill_edges_with_walls(connection_list)

        output: LatticeMaze = LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_percolation",
                grid_shape=grid_shape_,
                percolation_p=p,
                start_coord=start_coord,
            ),
        )

        # generation_meta is sometimes None, but not here since we just made it a dict above
        output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
            start_coord,
        )

        return output

    @staticmethod
    def gen_dfs_percolation(
        grid_shape: Coord | CoordTup,
        p: float = 0.4,
        lattice_dim: int = 2,
        accessible_cells: float | None = None,
        max_tree_depth: float | None = None,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """dfs and then percolation (adds cycles)

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `p: float`: the probability of each extra connection being opened on top of the dfs maze (default: `0.4`)
        - `lattice_dim: int`: forwarded to `gen_dfs` (default: `2`)
        - `accessible_cells: int | float | None`: forwarded to `gen_dfs` (default: `None`)
        - `max_tree_depth: int | float | None`: forwarded to `gen_dfs` (default: `None`)
        - `start_coord: Coord | None`: start of the dfs and of the connected component stored in `visited_cells`
            (default: `None` will give a random start)
        """
        grid_shape_: Coord = np.array(grid_shape)
        start_coord = _random_start_coord(grid_shape_, start_coord)

        # generate initial maze via dfs
        maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape_,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            start_coord=start_coord,
        )

        # percolate
        connection_list_perc: np.ndarray = (
            np.random.rand(*maze.connection_list.shape) < p
        )
        connection_list_perc = _fill_edges_with_walls(connection_list_perc)

        # written through `__dict__` rather than by attribute assignment
        # (presumably because `LatticeMaze` does not allow reassigning the field -- TODO confirm)
        maze.__dict__["connection_list"] = np.logical_or(
            maze.connection_list,
            connection_list_perc,
        )

        # generation_meta is sometimes None, but not here since we just made it a dict above
        maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
        maze.generation_meta["percolation_p"] = p  # type: ignore[index]
        maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
            start_coord,
        )

        return maze
namespace for lattice maze generation algorithms
@staticmethod
def gen_dfs(
    grid_shape: Coord | CoordTup,
    lattice_dim: int = 2,
    accessible_cells: float | None = None,
    max_tree_depth: float | None = None,
    do_forks: bool = True,
    randomized_stack: bool = False,
    start_coord: Coord | None = None,
) -> LatticeMaze:
    """generate a lattice maze using depth first search, iterative

    # Arguments
    - `grid_shape: Coord`: the shape of the grid
    - `lattice_dim: int`: the dimension of the lattice
      (default: `2`)
    - `accessible_cells: int | float | None`: the number of accessible cells in the maze.
      If `None`, defaults to the total number of cells in the grid.
      If a float, asserts it is in `[0, 1]` and treats it as a proportion of **total cells**
      (default: `None`)
    - `max_tree_depth: int | float | None`: the maximum depth of the tree.
      If `None`, defaults to `2 * n_total_cells` (twice the number of cells in the grid).
      If a float, asserts it is in `[0, 1]` and treats it as a proportion of the **sum of the grid shape**
      (default: `None`)
    - `do_forks: bool`: whether to allow forks in the maze.
      If `False`, the maze will have no forks and will be a simple hallway.
      (default: `True`)
    - `randomized_stack: bool`: if `True`, pop a uniformly random element of the stack
      instead of the most recently pushed one
      (default: `False`)
    - `start_coord: Coord | None`: the starting coordinate of the generation algorithm.
      If `None`, defaults to a random coordinate.

    # Returns
    a `LatticeMaze` whose `generation_meta` holds `func_name`, `grid_shape`, `start_coord`,
    `n_accessible_cells`, `max_tree_depth`, `fully_connected` and `visited_cells`

    # Raises
    - `AssertionError`: if a float `accessible_cells` or `max_tree_depth` is outside `[0, 1]`,
      or if `accessible_cells` is neither `None`, a float, nor an `int`

    # algorithm
    1. Choose the initial cell, mark it as visited and push it to the stack
    2. While the stack is not empty
        1. Pop a cell from the stack and make it a current cell
        2. If the current cell has any neighbours which have not been visited
            1. Push the current cell to the stack
            2. Choose one of the unvisited neighbours
            3. Remove the wall between the current cell and the chosen cell
            4. Mark the chosen cell as visited and push it to the stack
    """
    # Default values if no constraints have been passed
    grid_shape_: Coord = np.array(grid_shape)
    n_total_cells: int = int(np.prod(grid_shape_))

    n_accessible_cells: int
    if accessible_cells is None:
        n_accessible_cells = n_total_cells
    elif isinstance(accessible_cells, float):
        # the message promises [0, 1], so enforce the lower bound as well
        assert 0 <= accessible_cells <= 1, (
            f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
        )

        n_accessible_cells = int(accessible_cells * n_total_cells)
    else:
        assert isinstance(accessible_cells, int)
        n_accessible_cells = accessible_cells

    if max_tree_depth is None:
        # We define max tree depth counting from the start coord in two directions.
        # Therefore we divide by two in the if clause for neighboring sites later
        # and multiply by two here.
        max_tree_depth = 2 * n_total_cells
    elif isinstance(max_tree_depth, float):
        # the message promises [0, 1], so enforce the lower bound as well
        assert 0 <= max_tree_depth <= 1, (
            f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
        )

        max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))

    # choose a random start coord
    start_coord = _random_start_coord(grid_shape_, start_coord)

    # initialize the maze with no connections
    connection_list: ConnectionList = np.zeros(
        (lattice_dim, grid_shape_[0], grid_shape_[1]),
        dtype=np.bool_,
    )

    # initialize the stack with the start coord, which counts as visited
    visited_cells: set[tuple[int, int]] = set()
    visited_cells.add(tuple(start_coord))
    stack: list[Coord] = [start_coord]

    # initialize tree_depth_counter
    current_tree_depth: int = 1

    # loop until the stack is empty or n_accessible_cells is reached
    while stack and (len(visited_cells) < n_accessible_cells):
        # get the current coord from the stack
        current_coord: Coord
        if randomized_stack:
            # we dont care about S311 because this isnt security related
            current_coord = stack.pop(random.randint(0, len(stack) - 1))  # noqa: S311
        else:
            current_coord = stack.pop()

        # filter neighbors by being within grid bounds and being unvisited
        unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
            (neighbor, delta)
            for neighbor, delta in zip(
                current_coord + NEIGHBORS_MASK,
                NEIGHBORS_MASK,
                strict=False,
            )
            if (
                (tuple(neighbor) not in visited_cells)
                and (0 <= neighbor[0] < grid_shape_[0])
                and (0 <= neighbor[1] < grid_shape_[1])
            )
        ]

        # don't continue if max_tree_depth/2 is already reached
        # (divide by 2 because we can branch to multiple directions)
        if unvisited_neighbors_deltas and (
            current_tree_depth <= max_tree_depth / 2
        ):
            # if we want a maze without forks, simply don't add the current coord back to the stack
            if do_forks and (len(unvisited_neighbors_deltas) > 1):
                stack.append(current_coord)

            # choose one of the unvisited neighbors
            # we dont care about S311 because this isn't security related
            chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)  # noqa: S311

            # add connection
            dim: int = int(np.argmax(np.abs(delta)))
            # if positive, down/right from current coord
            # if negative, up/left from current coord (down/right from neighbor)
            clist_node: Coord = (
                current_coord if (delta.sum() > 0) else chosen_neighbor
            )
            connection_list[dim, clist_node[0], clist_node[1]] = True

            # add to visited cells and stack
            visited_cells.add(tuple(chosen_neighbor))
            stack.append(chosen_neighbor)

            # Update current tree depth
            current_tree_depth += 1
        else:
            current_tree_depth -= 1

    return LatticeMaze(
        connection_list=connection_list,
        generation_meta=dict(
            func_name="gen_dfs",
            grid_shape=grid_shape_,
            start_coord=start_coord,
            n_accessible_cells=int(n_accessible_cells),
            max_tree_depth=int(max_tree_depth),
            # compare against the TOTAL cell count, not n_accessible_cells:
            # a maze that stopped early because of the accessible-cells limit
            # is not fully connected, and treating it as such breaks solving
            fully_connected=bool(len(visited_cells) == n_total_cells),
            visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
        ),
    )
generate a lattice maze using depth first search, iterative
Arguments
- `grid_shape: Coord`: the shape of the grid
- `lattice_dim: int`: the dimension of the lattice (default: `2`)
- `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. If a float, asserts it is <= 1 and treats it as a proportion of total cells (default: `None`)
- `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to twice the total number of cells in the grid. If a float, asserts it is <= 1 and treats it as a proportion of the sum of the grid shape (default: `None`)
- `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have no forks and will be a simple hallway.
- `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.
algorithm
1. Choose the initial cell, mark it as visited and push it to the stack
2. While the stack is not empty
    1. Pop a cell from the stack and make it a current cell
    2. If the current cell has any neighbours which have not been visited
        1. Push the current cell to the stack
        2. Choose one of the unvisited neighbours
        3. Remove the wall between the current cell and the chosen cell
        4. Mark the chosen cell as visited and push it to the stack
@staticmethod
def gen_prim(
    grid_shape: Coord | CoordTup,
    lattice_dim: int = 2,
    accessible_cells: float | None = None,
    max_tree_depth: float | None = None,
    do_forks: bool = True,
    start_coord: Coord | None = None,
) -> LatticeMaze:
    """(broken!) generate a lattice maze using Prim's algorithm

    This does not correctly implement Prim's algorithm. It emits a warning and then
    delegates to `LatticeMazeGenerators.gen_dfs` with `randomized_stack=True`,
    forwarding every other argument unchanged.

    # Arguments
    all arguments are passed through to `gen_dfs`

    # Returns
    the `LatticeMaze` produced by `gen_dfs`
    """
    warnings.warn(
        "gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
        # attribute the warning to the caller rather than to this wrapper
        stacklevel=2,
    )
    return LatticeMazeGenerators.gen_dfs(
        grid_shape=grid_shape,
        lattice_dim=lattice_dim,
        accessible_cells=accessible_cells,
        max_tree_depth=max_tree_depth,
        do_forks=do_forks,
        start_coord=start_coord,
        randomized_stack=True,
    )
(broken!) generate a lattice maze using Prim's algorithm
@staticmethod
def gen_wilson(
    grid_shape: Coord | CoordTup,
    **kwargs,
) -> LatticeMaze:
    """Generate a lattice maze using Wilson's algorithm.

    # Algorithm
    Wilson's algorithm generates an unbiased (random) maze
    sampled from the uniform distribution over all mazes, using loop-erased random walks.
    The generated maze is acyclic and all cells are part of a unique connected space.
    https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm

    # Arguments
    - `grid_shape: Coord`: the shape of the grid
    - `**kwargs`: must be empty; anything passed here trips an assertion

    # Returns
    a `LatticeMaze` whose `generation_meta` records `func_name`, `grid_shape`
    and `fully_connected=True`
    """
    assert not kwargs, (
        f"gen_wilson does not take any additional arguments, got {kwargs = }"
    )

    shape: Coord = np.array(grid_shape)

    # no connections yet, and no cell belongs to the maze
    edges: ConnectionList = np.zeros((2, *shape), dtype=np.bool_)
    in_maze: Bool[np.ndarray, "x y"] = np.zeros(shape, dtype=np.bool_)

    # seed the maze with a single random cell
    seed_cell: Coord = _random_start_coord(shape, None)
    in_maze[seed_cell[0], seed_cell[1]] = True

    while not in_maze.all():
        # start a loop-erased random walk from a cell that is not yet in the maze
        candidates: CoordArray = np.column_stack(np.where(~in_maze))
        walk_start: Coord = candidates[np.random.choice(candidates.shape[0])]

        walk: list[Coord] = [walk_start]
        head: Coord = walk_start

        # keep walking until the head lands on a cell already in the maze
        while not in_maze[head[0], head[1]]:
            # a valid neighbor always exists on a lattice
            options: CoordArray = get_neighbors_in_bounds(head, shape)
            step: Coord = options[np.random.choice(options.shape[0])]

            # index of the first earlier occurrence of `step` in the walk, if any
            revisit_idx: int | None = next(
                (
                    idx
                    for idx, cell in enumerate(walk)
                    if np.array_equal(step, cell)
                ),
                None,
            )

            if revisit_idx is None:
                # no loop: extend the walk
                walk.append(step)
                head = step
            else:
                # loop: drop everything after the first occurrence and resume from there
                walk = walk[: revisit_idx + 1]
                head = walk[-1]

        # carve the finished walk into the maze, one consecutive pair at a time
        for src, dst in zip(walk[:-1], walk[1:], strict=True):
            step_delta: Coord = dst - src
            axis: int = int(np.argmax(np.abs(step_delta)))

            # a positive step is stored on the source cell (down/right from it),
            # a negative step on the destination cell (down/right from the neighbor)
            anchor: Coord = src if (step_delta.sum() > 0) else dst
            edges[axis, anchor[0], anchor[1]] = True
            # the final `dst` is not marked: it was already part of the maze
            in_maze[src[0], src[1]] = True

    return LatticeMaze(
        connection_list=edges,
        generation_meta={
            "func_name": "gen_wilson",
            "grid_shape": shape,
            "fully_connected": True,
        },
    )
Generate a lattice maze using Wilson's algorithm.
Algorithm
Wilson's algorithm generates an unbiased (random) maze sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is acyclic and all cells are part of a unique connected space. https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm
@staticmethod
def gen_percolation(
    grid_shape: Coord | CoordTup,
    p: float = 0.4,
    lattice_dim: int = 2,
    start_coord: Coord | None = None,
) -> LatticeMaze:
    """generate a lattice maze using simple percolation

    note that p in the range (0.4, 0.7) gives the most interesting mazes

    # Arguments
    - `grid_shape: Coord`: the shape of the grid
    - `p: float`: the probability that each entry of the connection list is open (default: `0.4`)
    - `lattice_dim: int`: the dimension of the lattice (default: `2`)
    - `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)

    # Returns
    a `LatticeMaze` whose `generation_meta` records `func_name`, `grid_shape`, `percolation_p`,
    `start_coord`, and `visited_cells` (the connected component reached from `start_coord`)

    # Raises
    - `AssertionError`: if `p` is outside `[0, 1]`
    """
    assert 0 <= p <= 1, f"p must be between 0 and 1, got {p}"
    shape: Coord = np.array(grid_shape)

    # the start coord is drawn before the connections
    origin = _random_start_coord(shape, start_coord)

    # open each connection independently with probability p, then wall off the grid edges
    open_edges: ConnectionList = _fill_edges_with_walls(
        np.random.rand(lattice_dim, *shape) < p,
    )

    maze: LatticeMaze = LatticeMaze(
        connection_list=open_edges,
        generation_meta={
            "func_name": "gen_percolation",
            "grid_shape": shape,
            "percolation_p": p,
            "start_coord": origin,
        },
    )

    # generation_meta is sometimes None, but not here since we just passed a dict above
    component = maze.gen_connected_component_from(origin)
    maze.generation_meta["visited_cells"] = component  # type: ignore[index]

    return maze
generate a lattice maze using simple percolation
note that p in the range (0.4, 0.7) gives the most interesting mazes
Arguments
- `grid_shape: Coord`: the shape of the grid
- `lattice_dim: int`: the dimension of the lattice (default: `2`)
- `p: float`: the probability that each connection is open (default: `0.4`)
- `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)
@staticmethod
def gen_dfs_percolation(
    grid_shape: Coord | CoordTup,
    p: float = 0.4,
    lattice_dim: int = 2,
    accessible_cells: int | None = None,
    max_tree_depth: int | None = None,
    start_coord: Coord | None = None,
) -> LatticeMaze:
    """dfs and then percolation (adds cycles)

    Builds a maze with `LatticeMazeGenerators.gen_dfs`, then ORs its connection list
    with an independent random percolation mask, so connections are only ever added.

    # Arguments
    - `grid_shape: Coord`: the shape of the grid
    - `p: float`: the probability that each entry of the percolation mask is open (default: `0.4`)
    - `lattice_dim: int`: passed through to `gen_dfs` (default: `2`)
    - `accessible_cells: int | None`: passed through to `gen_dfs` (default: `None`)
    - `max_tree_depth: int | None`: passed through to `gen_dfs` (default: `None`)
    - `start_coord: Coord | None`: start of the dfs and of the connected component (default: `None` will give a random start)

    # Returns
    the `LatticeMaze` from `gen_dfs` with the merged connection list; its `generation_meta`
    has `func_name`, `percolation_p` and `visited_cells` overwritten
    """
    shape: Coord = np.array(grid_shape)
    origin = _random_start_coord(shape, start_coord)

    # generate initial maze via dfs
    maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
        grid_shape=shape,
        lattice_dim=lattice_dim,
        accessible_cells=accessible_cells,
        max_tree_depth=max_tree_depth,
        start_coord=origin,
    )

    # percolate: random extra connections, with the grid edges walled off
    extra_edges: np.ndarray = _fill_edges_with_walls(
        np.random.rand(*maze.connection_list.shape) < p,
    )

    # write through __dict__ instead of normal attribute assignment
    # NOTE(review): presumably LatticeMaze is frozen -- confirm
    merged_edges = np.logical_or(maze.connection_list, extra_edges)
    maze.__dict__["connection_list"] = merged_edges

    # generation_meta is sometimes None, but not here since gen_dfs always sets a dict
    maze.generation_meta.update(  # type: ignore[union-attr]
        func_name="gen_dfs_percolation",
        percolation_p=p,
    )
    # computed only after the connection list has been replaced
    maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
        origin,
    )

    return maze
dfs and then percolation (adds cycles)
mapping of generator names to generator functions, useful for loading MazeDatasetConfig
def get_maze_with_solution(
    gen_name: str,
    grid_shape: Coord | CoordTup,
    maze_ctor_kwargs: dict | None = None,
) -> SolvedMaze:
    """helper function to get a maze already with a solution

    # Arguments
    - `gen_name: str`: key into `GENERATORS_MAP` selecting the generator function
    - `grid_shape: Coord | CoordTup`: the shape of the grid, passed to the generator
    - `maze_ctor_kwargs: dict | None`: extra keyword arguments for the generator
      (default: `None`, meaning no extra arguments)

    # Returns
    a `SolvedMaze` built from the generated maze and a path from `generate_random_path()`
    """
    ctor_kwargs: dict = {} if maze_ctor_kwargs is None else maze_ctor_kwargs
    generator = GENERATORS_MAP[gen_name]
    # TYPING: error: Too few arguments  [call-arg]
    # not sure why this is happening -- doesnt recognize the kwargs?
    maze: LatticeMaze = generator(grid_shape, **ctor_kwargs)  # type: ignore[call-arg]
    random_path = maze.generate_random_path()
    return SolvedMaze.from_lattice_maze(
        lattice_maze=maze,
        solution=np.array(random_path),
    )
helper function to get a maze already with a solution