Coverage for maze_dataset/generation/generators.py: 77%

132 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-20 17:51 -0600

1"""generation functions have signature `(grid_shape: Coord, **kwargs) -> LatticeMaze` and are methods in `LatticeMazeGenerators`""" 

2 

3import random 

4import warnings 

5from typing import Any, Callable 

6 

7import numpy as np 

8from jaxtyping import Bool 

9 

10from maze_dataset.constants import CoordArray, CoordTup 

11from maze_dataset.generation.seed import GLOBAL_SEED 

12from maze_dataset.maze import ConnectionList, Coord, LatticeMaze, SolvedMaze 

13from maze_dataset.maze.lattice_maze import NEIGHBORS_MASK, _fill_edges_with_walls 

14 

# seed the stdlib `random` module; `gen_dfs` draws its stack pops and neighbor choices from it
random.seed(GLOBAL_SEED)
# module-level numpy Generator seeded with the same global seed
# NOTE(review): none of the generators in this module draw from `numpy_rng` -- they call the
# legacy global `np.random.*` functions (randint / choice / rand), whose state is not seeded
# here. confirm whether that is intended.
numpy_rng = np.random.default_rng(GLOBAL_SEED)

17 

18 

def _random_start_coord(
    grid_shape: Coord,
    start_coord: Coord | CoordTup | None,
) -> Coord:
    """pick a random start coord within the bounds of `grid_shape` if none is provided

    # Parameters:
     - `grid_shape : Coord`
        shape of the grid, one entry per dimension (a tuple also works)
     - `start_coord : Coord | CoordTup | None`
        explicit start coordinate. if `None`, one is sampled uniformly over all cells

    # Returns:
     - `Coord`
        the start coordinate as a numpy array
    """
    if start_coord is not None:
        return np.array(start_coord)

    grid_shape_: np.ndarray = np.asarray(grid_shape)
    # `np.random.randint` treats `high` as EXCLUSIVE, so the shape itself is the correct upper
    # bound: it yields every index in `[0, grid_shape)`. using `grid_shape - 1` here would make
    # the last row/column unreachable. `np.maximum(..., 1)` keeps the range non-empty so
    # `randint` does not raise on a zero-length axis.
    return np.random.randint(
        0,  # lower bound (inclusive)
        np.maximum(grid_shape_, 1),  # upper bound (exclusive, at least 1)
        size=len(grid_shape_),  # dimensionality
    )

35 

36 

def get_neighbors_in_bounds(
    coord: Coord,
    grid_shape: Coord,
) -> CoordArray:
    """return the lattice neighbors of `coord` that lie inside a grid of shape `grid_shape`

    candidates are `coord + NEIGHBORS_MASK`. a candidate is kept only when every one of its
    components is non-negative and strictly less than the matching entry of `grid_shape`.
    """
    candidates: CoordArray = coord + NEIGHBORS_MASK
    # one boolean per candidate row: all components inside [0, grid_shape)
    inside_grid: np.ndarray = np.all(
        (candidates >= 0) & (candidates < grid_shape),
        axis=1,
    )
    return candidates[inside_grid]

51 

52 

class LatticeMazeGenerators:
    """namespace for lattice maze generation algorithms

    every generator is a `staticmethod` whose first argument is `grid_shape` and which returns a
    `LatticeMaze` with a `generation_meta` dict describing how it was produced
    """

    @staticmethod
    def gen_dfs(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
        accessible_cells: float | None = None,
        max_tree_depth: float | None = None,
        do_forks: bool = True,
        randomized_stack: bool = False,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using depth first search, iterative

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `lattice_dim: int`: the dimension of the lattice
            (default: `2`)
        - `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is in `[0, 1]` and treats it as a proportion of **total cells**
            (default: `None`)
        - `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_total_cells` (i.e. effectively unlimited). if a float, asserts it is in `[0, 1]` and treats it as a proportion of the **sum of the grid shape**
            (default: `None`)
        - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have no forks and will be a simple hallway.
            (default: `True`)
        - `randomized_stack: bool`: if `True`, pop a uniformly random element of the stack instead of the top one (this is what `gen_prim` uses)
            (default: `False`)
        - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.

        # Returns
        - `LatticeMaze` whose `generation_meta` records `func_name`, `grid_shape`, `start_coord`,
          `n_accessible_cells`, `max_tree_depth`, `fully_connected` (`True` only if every cell of
          the grid was visited) and `visited_cells`

        # Raises
        - `AssertionError` if a float `accessible_cells` / `max_tree_depth` is outside `[0, 1]`,
          or if `accessible_cells` is neither `None`, a float, nor an integer

        # algorithm
        1. Choose the initial cell, mark it as visited and push it to the stack
        2. While the stack is not empty
            1. Pop a cell from the stack and make it a current cell
            2. If the current cell has any neighbours which have not been visited
                1. Push the current cell to the stack
                2. Choose one of the unvisited neighbours
                3. Remove the wall between the current cell and the chosen cell
                4. Mark the chosen cell as visited and push it to the stack
        """
        # Default values if no constraints have been passed
        grid_shape_: Coord = np.array(grid_shape)
        n_total_cells: int = int(np.prod(grid_shape_))

        n_accessible_cells: int
        if accessible_cells is None:
            n_accessible_cells = n_total_cells
        elif isinstance(accessible_cells, (float, np.floating)):
            assert 0 <= accessible_cells <= 1, (
                f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
            )
            n_accessible_cells = int(accessible_cells * n_total_cells)
        else:
            # numpy integer scalars are accepted alongside builtin ints
            assert isinstance(accessible_cells, (int, np.integer)), (
                f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells!r}"
            )
            n_accessible_cells = int(accessible_cells)

        if max_tree_depth is None:
            # We define max tree depth counting from the start coord in two directions.
            # Therefore we divide by two in the if clause for neighboring sites later and
            # multiply by two here.
            max_tree_depth = 2 * n_total_cells
        elif isinstance(max_tree_depth, (float, np.floating)):
            assert 0 <= max_tree_depth <= 1, (
                f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
            )
            max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))

        # choose a random start coord
        start_coord = _random_start_coord(grid_shape_, start_coord)

        # initialize the maze with no connections
        connection_list: ConnectionList = np.zeros(
            (lattice_dim, grid_shape_[0], grid_shape_[1]),
            dtype=np.bool_,
        )

        # initialize the stack with the start coord
        visited_cells: set[tuple[int, int]] = set()
        visited_cells.add(tuple(start_coord))
        stack: list[Coord] = [start_coord]

        # initialize tree_depth_counter
        current_tree_depth: int = 1

        # loop until the stack is empty or n_accessible_cells is reached
        while stack and (len(visited_cells) < n_accessible_cells):
            # get the current coord from the stack
            current_coord: Coord
            if randomized_stack:
                # we dont care about S311 because this isnt security related
                current_coord = stack.pop(random.randint(0, len(stack) - 1))  # noqa: S311
            else:
                current_coord = stack.pop()

            # filter neighbors by being within grid bounds and being unvisited
            unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
                (neighbor, delta)
                for neighbor, delta in zip(
                    current_coord + NEIGHBORS_MASK,
                    NEIGHBORS_MASK,
                    strict=False,
                )
                if (
                    (tuple(neighbor) not in visited_cells)
                    and (0 <= neighbor[0] < grid_shape_[0])
                    and (0 <= neighbor[1] < grid_shape_[1])
                )
            ]

            # don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)
            if unvisited_neighbors_deltas and (
                current_tree_depth <= max_tree_depth / 2
            ):
                # keep the current coord on the stack only if it will still have an unvisited
                # neighbor after this step (with exactly one, that one is visited right now).
                # without forks it is never pushed back, so the walk becomes a single hallway.
                if do_forks and (len(unvisited_neighbors_deltas) > 1):
                    stack.append(current_coord)

                # choose one of the unvisited neighbors
                # we dont care about S311 because this isn't security related
                chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)  # noqa: S311

                # add connection
                dim: int = int(np.argmax(np.abs(delta)))
                # if positive, down/right from current coord
                # if negative, up/left from current coord (down/right from neighbor)
                clist_node: Coord = (
                    current_coord if (delta.sum() > 0) else chosen_neighbor
                )
                connection_list[dim, clist_node[0], clist_node[1]] = True

                # add to visited cells and stack
                visited_cells.add(tuple(chosen_neighbor))
                stack.append(chosen_neighbor)

                # Update current tree depth
                current_tree_depth += 1
            else:
                current_tree_depth -= 1

        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_dfs",
                grid_shape=grid_shape_,
                start_coord=start_coord,
                n_accessible_cells=int(n_accessible_cells),
                max_tree_depth=int(max_tree_depth),
                # compare against the TOTAL cell count, not `n_accessible_cells`: a maze that
                # reached its accessible-cell budget is not fully connected unless that budget
                # covers the whole grid, and treating it as connected breaks maze solving
                fully_connected=bool(len(visited_cells) == n_total_cells),
                visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
            ),
        )

    @staticmethod
    def gen_prim(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
        accessible_cells: float | None = None,
        max_tree_depth: float | None = None,
        do_forks: bool = True,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """(broken!) generate a lattice maze using Prim's algorithm

        currently this just emits a warning and delegates to `gen_dfs` with
        `randomized_stack=True`; all other arguments are forwarded unchanged
        """
        warnings.warn(
            "gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
            stacklevel=2,  # attribute the warning to the caller, not to this wrapper
        )
        return LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            do_forks=do_forks,
            start_coord=start_coord,
            randomized_stack=True,
        )

    @staticmethod
    def gen_wilson(
        grid_shape: Coord | CoordTup,
        **kwargs,
    ) -> LatticeMaze:
        """Generate a lattice maze using Wilson's algorithm.

        # Algorithm
        Wilson's algorithm generates an unbiased (random) maze
        sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
        acyclic and all cells are part of a unique connected space.
        https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm

        # Arguments
        - `grid_shape: Coord`: the shape of the (2D) grid
        - `**kwargs`: must be empty -- no additional arguments are accepted

        # Raises
        - `AssertionError` if any keyword arguments are passed
        """
        assert not kwargs, (
            f"gen_wilson does not take any additional arguments, got {kwargs = }"
        )

        grid_shape_: Coord = np.array(grid_shape)

        # Initialize grid and visited cells
        connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
        visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)

        # Choose a random cell and mark it as visited
        start_coord: Coord = _random_start_coord(grid_shape_, None)
        visited[start_coord[0], start_coord[1]] = True
        del start_coord

        while not visited.all():
            # Perform loop-erased random walk from another random cell

            # Choose walk_start only from unvisited cells
            unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
            walk_start: Coord = unvisited_coords[
                np.random.choice(unvisited_coords.shape[0])
            ]

            # Perform the random walk
            path: list[Coord] = [walk_start]
            # position of every cell currently on `path`. loop erasure keeps path cells unique,
            # so this lookup is equivalent to scanning the path, but O(1) per step
            path_index: dict[tuple[int, int], int] = {
                (int(walk_start[0]), int(walk_start[1])): 0,
            }
            current: Coord = walk_start

            # exit the loop once the current path hits a visited cell
            while not visited[current[0], current[1]]:
                # pick a random in-bounds neighbor of the current cell
                neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_)
                next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]
                next_key: tuple[int, int] = (int(next_cell[0]), int(next_cell[1]))

                loop_start: int | None = path_index.get(next_key)
                if loop_start is not None:
                    # erase the loop: drop every cell AFTER the earlier occurrence of
                    # `next_cell`, keeping that occurrence as the new end of the path
                    for erased in path[loop_start + 1 :]:
                        del path_index[(int(erased[0]), int(erased[1]))]
                    path = path[: loop_start + 1]
                    current = path[-1]
                else:
                    path_index[next_key] = len(path)
                    path.append(next_cell)
                    current = next_cell

            # Add the path to the maze
            for c_1, c_2 in zip(path[:-1], path[1:], strict=True):
                # find the dimension of the connection
                delta: Coord = c_2 - c_1
                dim: int = int(np.argmax(np.abs(delta)))

                # if positive, down/right from current coord
                # if negative, up/left from current coord (down/right from neighbor)
                clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
                connection_list[dim, clist_node[0], clist_node[1]] = True
                visited[c_1[0], c_1[1]] = True
                # we dont add c_2 because the last c_2 will have already been visited

        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_wilson",
                grid_shape=grid_shape_,
                fully_connected=True,
            ),
        )

    @staticmethod
    def gen_percolation(
        grid_shape: Coord | CoordTup,
        p: float = 0.4,
        lattice_dim: int = 2,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using simple percolation

        note that p in the range (0.4, 0.7) gives the most interesting mazes

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `p: float`: probability that each entry of the connection list is open, drawn independently (default: `0.4`)
        - `lattice_dim: int`: the dimension of the lattice (default: `2`)
        - `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)

        # Returns
        - `LatticeMaze` whose `generation_meta` records `func_name`, `grid_shape`,
          `percolation_p`, `start_coord`, and `visited_cells` (the connected component
          reachable from `start_coord`)

        # Raises
        - `AssertionError` if `p` is outside `[0, 1]`
        """
        assert 0 <= p <= 1, f"p must be between 0 and 1, got {p}"
        grid_shape_: Coord = np.array(grid_shape)

        start_coord = _random_start_coord(grid_shape_, start_coord)

        # each connection is open independently with probability p
        connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p

        connection_list = _fill_edges_with_walls(connection_list)

        output: LatticeMaze = LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_percolation",
                grid_shape=grid_shape_,
                percolation_p=p,
                start_coord=start_coord,
            ),
        )

        # generation_meta is sometimes None, but not here since we just made it a dict above
        output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
            start_coord,
        )

        return output

    @staticmethod
    def gen_dfs_percolation(
        grid_shape: Coord | CoordTup,
        p: float = 0.4,
        lattice_dim: int = 2,
        accessible_cells: int | None = None,
        max_tree_depth: int | None = None,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """dfs and then percolation (adds cycles)

        generates a maze with `gen_dfs`, then ORs its connection list with an independent random
        mask (each entry open with probability `p`, passed through `_fill_edges_with_walls`).
        the extra connections can introduce cycles.

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `p: float`: probability that each entry of the percolation mask is open (default: `0.4`)
        - `lattice_dim: int`: the dimension of the lattice (default: `2`)
        - `accessible_cells`, `max_tree_depth`: forwarded to `gen_dfs`
        - `start_coord: Coord | None`: start of the dfs and of the connected component stored in
          `generation_meta["visited_cells"]` (default: `None` will give a random start)
        """
        grid_shape_: Coord = np.array(grid_shape)
        start_coord = _random_start_coord(grid_shape_, start_coord)

        # generate initial maze via dfs
        maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape_,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            start_coord=start_coord,
        )

        # percolate
        connection_list_perc: np.ndarray = (
            np.random.rand(*maze.connection_list.shape) < p
        )
        connection_list_perc = _fill_edges_with_walls(connection_list_perc)

        # written through `__dict__` rather than attribute assignment
        # NOTE(review): presumably because `LatticeMaze` does not allow normal attribute
        # assignment -- confirm
        maze.__dict__["connection_list"] = np.logical_or(
            maze.connection_list,
            connection_list_perc,
        )

        # generation_meta is sometimes None, but not here since gen_dfs always sets it to a dict
        # NOTE(review): `fully_connected` is inherited unchanged from gen_dfs even though
        # percolation may connect additional cells
        maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
        maze.generation_meta["percolation_p"] = p  # type: ignore[index]
        maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
            start_coord,
        )

        return maze

401 

402 

# can't automatically populate this because it messes with pickling :(
# NOTE(review): the value annotation `Callable[[Coord | CoordTup, Any], "LatticeMaze"]` describes a
# function taking exactly two positional arguments, while the generators have differing
# signatures and are called as `f(grid_shape, **kwargs)` (see `get_maze_with_solution`); this is
# why the `type: ignore` comments are needed. `Callable[..., LatticeMaze]` may be more accurate.
GENERATORS_MAP: dict[str, Callable[[Coord | CoordTup, Any], "LatticeMaze"]] = {
    "gen_dfs": LatticeMazeGenerators.gen_dfs,
    # TYPING: error: Dict entry 1 has incompatible type
    # "str": "Callable[[ndarray[Any, Any] | tuple[int, int], KwArg(Any)], LatticeMaze]";
    # expected "str": "Callable[[ndarray[Any, Any] | tuple[int, int], Any], LatticeMaze]" [dict-item]
    # gen_wilson takes no kwargs and we check that the kwargs are empty
    # but mypy doesn't like this, `Any` != `KwArg(Any)`
    "gen_wilson": LatticeMazeGenerators.gen_wilson,  # type: ignore[dict-item]
    "gen_percolation": LatticeMazeGenerators.gen_percolation,
    "gen_dfs_percolation": LatticeMazeGenerators.gen_dfs_percolation,
    "gen_prim": LatticeMazeGenerators.gen_prim,
}
"mapping of generator names to generator functions, useful for loading `MazeDatasetConfig`"

# names of the generators that apply random percolation to the connection list
_GENERATORS_PERCOLATED: list[str] = [
    "gen_percolation",
    "gen_dfs_percolation",
]
"""list of generator names that generate percolated mazes
we use this to figure out the expected success rate, since depending on the endpoint kwargs this might fail
this variable is primarily used in `MazeDatasetConfig._to_ps_array` and `MazeDatasetConfig._from_ps_array`
"""

426 

427 

def get_maze_with_solution(
    gen_name: str,
    grid_shape: Coord | CoordTup,
    maze_ctor_kwargs: dict | None = None,
) -> SolvedMaze:
    """helper function to get a maze already with a solution

    generates a maze with the generator registered under `gen_name`, then uses a random path
    through it (from `maze.generate_random_path()`) as the solution.

    # Parameters:
     - `gen_name : str`
        key into `GENERATORS_MAP`, e.g. `"gen_dfs"`
     - `grid_shape : Coord | CoordTup`
        shape of the grid, passed positionally to the generator
     - `maze_ctor_kwargs : dict | None`
        extra keyword arguments forwarded to the generator (default: `None`, meaning none)

    # Returns:
     - `SolvedMaze`
        the generated maze together with the random path as its solution

    # Raises:
     - `KeyError` if `gen_name` is not a key of `GENERATORS_MAP`
    """
    if maze_ctor_kwargs is None:
        maze_ctor_kwargs = {}

    try:
        gen_func = GENERATORS_MAP[gen_name]
    except KeyError as e:
        # still a KeyError for existing callers, but name the valid options
        err_msg: str = (
            f"unknown generator {gen_name!r}, expected one of {list(GENERATORS_MAP)}"
        )
        raise KeyError(err_msg) from e

    # TYPING: error: Too few arguments [call-arg]
    # the map's value annotation only allows two positional args, so mypy rejects the `**kwargs`
    maze: LatticeMaze = gen_func(grid_shape, **maze_ctor_kwargs)  # type: ignore[call-arg]
    solution: CoordArray = np.array(maze.generate_random_path())
    return SolvedMaze.from_lattice_maze(lattice_maze=maze, solution=solution)