Coverage for maze_dataset/plotting/plot_maze.py: 83%

217 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-24 00:33 -0600

1"""provides `MazePlot`, which has many tools for plotting mazes with multiple paths, colored nodes, and more""" 

2 

3from __future__ import annotations # for type hinting self as return value 

4 

5import warnings 

6from copy import deepcopy 

7from dataclasses import dataclass 

8from typing import Sequence 

9 

10import matplotlib as mpl 

11import matplotlib.pyplot as plt 

12import numpy as np 

13from jaxtyping import Bool, Float 

14 

15from maze_dataset.constants import Coord, CoordArray, CoordList 

16from maze_dataset.maze import ( 

17 LatticeMaze, 

18 SolvedMaze, 

19 TargetedLatticeMaze, 

20) 

21 

# large-magnitude negative float constant; it is not referenced anywhere else
# in this module -- presumably imported by other code (TODO confirm before removing)
LARGE_NEGATIVE_NUMBER: float = -(10.0**10)

23 

24 

25@dataclass(kw_only=True) 

26class PathFormat: 

27 """formatting options for path plot""" 

28 

29 label: str | None = None 

30 fmt: str = "o" 

31 color: str | None = None 

32 cmap: str | None = None 

33 line_width: float | None = None 

34 quiver_kwargs: dict | None = None 

35 

36 def combine(self, other: PathFormat) -> PathFormat: 

37 """combine with other PathFormat object, overwriting attributes with non-None values. 

38 

39 returns a modified copy of self. 

40 """ 

41 output: PathFormat = deepcopy(self) 

42 for key, value in other.__dict__.items(): 

43 if key == "path": 

44 err_msg: str = f"Cannot overwrite path attribute! {self = }, {other = }" 

45 raise ValueError( 

46 err_msg, 

47 ) 

48 if value is not None: 

49 setattr(output, key, value) 

50 

51 return output 

52 

53 

@dataclass
class StyledPath(PathFormat):
    """a `PathFormat` bundled together with the path it formats

    `path` is the only field without a default; the formatting fields
    inherited from `PathFormat` are keyword-only and default as there.
    """

    path: CoordArray

60 

61 

# default formatting, looked up by the `_default_key` given to
# `process_path_input`: "true" for the solution, "predicted" for predictions
DEFAULT_FORMATS: dict[str, PathFormat] = {
    # solid red dashed line
    "true": PathFormat(
        color="red",
        fmt="--",
        label="true path",
        line_width=2.5,
    ),
    # label and color stay None here; `MazePlot.add_predicted_path` fills them
    # in per path. setting `quiver_kwargs` makes these paths draw as arrows.
    "predicted": PathFormat(
        fmt=":",
        line_width=2,
        quiver_kwargs={"width": 0.015},
    ),
}

78 

79 

def process_path_input(
    path: CoordList | CoordArray | StyledPath,
    _default_key: str,
    path_fmt: PathFormat | None = None,
    **kwargs,
) -> StyledPath:
    """convert a path, which might be a list or array of coords, into a `StyledPath`

    formatting is applied in layers, later layers overriding earlier ones:
    1. `DEFAULT_FORMATS[_default_key]` -- only for raw (list/array) paths; a
       `StyledPath` input keeps its own formatting
    2. the non-`None` attributes of `path_fmt`
    3. `kwargs`, each set directly as an attribute

    the result is always a new object: a `StyledPath` input is copied, so
    neither this function nor callers that further modify the result (e.g.
    filling in a default label or color) mutate the object passed in.

    # Parameters:
     - `path : CoordList | CoordArray | StyledPath`
     - `_default_key : str`
        key into `DEFAULT_FORMATS`, e.g. `"true"` or `"predicted"`
     - `path_fmt : PathFormat | None`
        extra formatting layered on top of the defaults
        (defaults to `None`)
     - `**kwargs`
        attribute overrides, applied last

    # Returns:
     - `StyledPath`

    # Raises:
     - `TypeError` : if `path` is not a `StyledPath`, `np.ndarray`, or `list`
    """
    styled_path: StyledPath
    if isinstance(path, StyledPath):
        # copy so the setattr calls below never touch the caller's object
        styled_path = deepcopy(path)
    elif isinstance(path, (np.ndarray, list)):
        # `np.asarray` keeps an ndarray as-is and converts a list of coords
        styled_path = StyledPath(path=np.asarray(path))
        # add default formatting
        styled_path = styled_path.combine(DEFAULT_FORMATS[_default_key])
    else:
        err_msg: str = (
            f"Expected CoordList, CoordArray or StyledPath, got {type(path)}: {path}"
        )
        raise TypeError(err_msg)

    # add formatting from path_fmt
    if path_fmt is not None:
        styled_path = styled_path.combine(path_fmt)

    # add formatting from kwargs
    for key, value in kwargs.items():
        setattr(styled_path, key, value)

    return styled_path

115 

116 

# colors assigned, in order and wrapping around, to predicted paths that
# do not specify a color of their own
DEFAULT_PREDICTED_PATH_COLORS: list[str] = list(
    (
        "tab:orange", "tab:olive", "sienna",
        "mediumseagreen", "tab:purple", "slategrey",
    ),
)

125 

126 

class MazePlot:
    """Class for displaying mazes and paths

    the `add_*` methods and `mark_coords` only record what should be drawn and
    return `self` (so calls can be chained); `plot` does the drawing with
    matplotlib.
    """

    def __init__(self, maze: LatticeMaze, unit_length: int = 14) -> None:
        """UNIT_LENGTH: Set ratio between node size and wall thickness in image.

        Wall thickness is fixed to 1px
        A "unit" consists of a single node and the right and lower connection/wall.
        Example: ul = 14 yields 13:1 ratio between node size and wall thickness

        if `maze` is a `SolvedMaze`, its `solution` is added as the true path;
        otherwise, if it is a `TargetedLatticeMaze`, it is solved with
        `SolvedMaze.from_targeted_lattice_maze` and that solution is added.
        """
        self.unit_length: int = unit_length
        self.maze: LatticeMaze = maze
        self.true_path: StyledPath | None = None
        self.predicted_paths: list[StyledPath] = []
        self.node_values: Float[np.ndarray, "grid_n grid_n"] | None = None
        self.custom_node_value_flag: bool = False
        self.node_color_map: str = "Blues"
        # NOTE: not written by any other method of this class; `add_node_values`
        # records these coords in `marked_coords` instead
        self.target_token_coord: Coord | None = None
        self.preceding_tokens_coords: CoordArray | None = None
        self.colormap_center: float | None = None
        # read by `_plot_maze`; initialized here so the attributes always exist,
        # overwritten by `add_node_values`
        self.colormap_max: float | None = None
        self.hide_colorbar: bool = False
        # colorbar axes, reused as `cax` by later calls to `_plot_maze`
        self.cbar_ax = None
        # (coord, kwargs for `ax.plot`) pairs, drawn by `plot`
        self.marked_coords: list[tuple[Coord, dict]] = list()

        self.marker_kwargs_current: dict = dict(
            marker="s",
            color="green",
            ms=12,
        )
        self.marker_kwargs_next: dict = dict(
            marker="P",
            color="green",
            ms=12,
        )

        if isinstance(maze, SolvedMaze):
            self.add_true_path(maze.solution)
        elif isinstance(maze, TargetedLatticeMaze):
            self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution)

    @property
    def solved_maze(self) -> SolvedMaze:
        """get the underlying `SolvedMaze` object

        # Raises:
         - `ValueError` : if no true path has been added
        """
        if self.true_path is None:
            raise ValueError(
                "Cannot return SolvedMaze object without true path. Add true path with add_true_path method.",
            )
        return SolvedMaze.from_lattice_maze(
            lattice_maze=self.maze,
            solution=self.true_path.path,
        )

    def add_true_path(
        self,
        path: CoordList | CoordArray | StyledPath,
        path_fmt: PathFormat | None = None,
        **kwargs,
    ) -> MazePlot:
        """add a true path to the maze with optional formatting

        replaces any previously set true path. formatting is resolved by
        `process_path_input` using the `"true"` defaults.
        """
        self.true_path = process_path_input(
            path=path,
            _default_key="true",
            path_fmt=path_fmt,
            **kwargs,
        )

        return self

    def add_predicted_path(
        self,
        path: CoordList | CoordArray | StyledPath,
        path_fmt: PathFormat | None = None,
        **kwargs,
    ) -> MazePlot:
        """Receive a predicted path and formatting preferences and append it to `predicted_paths`.

        formatting is resolved by `process_path_input` using the `"predicted"`
        defaults. If no label or color results from that, the default label and
        color depend on the number of paths already in `predicted_paths`: the
        label is `"predicted path {k}"` and the color cycles through
        `DEFAULT_PREDICTED_PATH_COLORS`.
        """
        styled_path: StyledPath = process_path_input(
            path=path,
            _default_key="predicted",
            path_fmt=path_fmt,
            **kwargs,
        )

        # set default label and color if not specified
        if styled_path.label is None:
            styled_path.label = f"predicted path {len(self.predicted_paths) + 1}"

        if styled_path.color is None:
            color_num: int = len(self.predicted_paths) % len(
                DEFAULT_PREDICTED_PATH_COLORS,
            )
            styled_path.color = DEFAULT_PREDICTED_PATH_COLORS[color_num]

        self.predicted_paths.append(styled_path)
        return self

    def add_multiple_paths(
        self,
        path_list: Sequence[CoordList | CoordArray | StyledPath],
    ) -> MazePlot:
        """add every path in `path_list` as a predicted path

        each element may be a `CoordList`, a `CoordArray`, or a `StyledPath`.
        Each is passed to `add_predicted_path` without extra formatting, so
        default labels and colors are assigned in list order.
        """
        for path in path_list:
            self.add_predicted_path(path)
        return self

    def add_node_values(
        self,
        node_values: Float[np.ndarray, "grid_n grid_n"],
        color_map: str = "Blues",
        target_token_coord: Coord | None = None,
        preceeding_tokens_coords: CoordArray = None,
        colormap_center: float | None = None,
        colormap_max: float | None = None,
        hide_colorbar: bool = False,
    ) -> MazePlot:
        """add node values to the maze for visualization as a heatmap

        # Parameters:
         - `node_values : Float[np.ndarray, "grid_n grid_n"]`
            one value per node; shape must equal `self.maze.grid_shape`.
            NaN values are rejected when plotting.
         - `color_map : str`
            name of the matplotlib colormap used for the heatmap
            (defaults to `"Blues"`)
         - `target_token_coord : Coord | None`
            if given, marked with `marker_kwargs_next`
            (defaults to `None`)
         - `preceeding_tokens_coords : CoordArray`
            if given, each coord is marked with `marker_kwargs_current`
            (defaults to `None`)
         - `colormap_center : float | None`
            if given, the colormap is centered on this value (`TwoSlopeNorm`);
            plotting raises `ValueError` if it lies outside the value range
            (defaults to `None`)
         - `colormap_max : float | None`
            if given, overrides the upper end of the color scale, e.g. for a
            consistent colorbar across multiple plots
            (defaults to `None`)
         - `hide_colorbar : bool`
            skip drawing the colorbar
            (defaults to `False`)

        # Returns:
         - `MazePlot` (self)
        """
        assert node_values.shape == self.maze.grid_shape, (
            "Please pass node values of the same shape as LatticeMaze.grid_shape"
        )

        self.node_values = node_values
        # Set flag for choosing cmap while plotting maze
        self.custom_node_value_flag = True
        self.node_color_map = color_map
        self.colormap_center = colormap_center
        self.colormap_max = colormap_max
        self.hide_colorbar = hide_colorbar

        if target_token_coord is not None:
            self.marked_coords.append((target_token_coord, self.marker_kwargs_next))
        if preceeding_tokens_coords is not None:
            for coord in preceeding_tokens_coords:
                self.marked_coords.append((coord, self.marker_kwargs_current))
        return self

    def plot(
        self,
        dpi: int = 100,
        title: str = "",
        fig_ax: tuple | None = None,
        plain: bool = False,
    ) -> MazePlot:
        """Plot the maze and paths.

        # Parameters:
         - `dpi : int`
            only used when a new figure is created
            (defaults to `100`)
         - `title : str`
            axes title; not drawn when `plain` is set
            (defaults to `""`)
         - `fig_ax : tuple | None`
            existing `(fig, ax)` to draw into; a new figure is created if `None`
            (defaults to `None`)
         - `plain : bool`
            if `True`, skip ticks, axis labels and title
            (defaults to `False`)

        # Returns:
         - `MazePlot` (self)
        """
        # set up figure
        if fig_ax is None:
            self.fig = plt.figure(dpi=dpi)
            self.ax = self.fig.add_subplot(1, 1, 1)
        else:
            self.fig, self.ax = fig_ax

        # plot maze
        self._plot_maze()

        # Plot labels
        if not plain:
            # x runs along columns and y along rows (see `_rowcol_to_coord`),
            # so each axis needs its own tick range for non-square mazes
            row_ticks = np.arange(self.maze.grid_shape[0])
            col_ticks = np.arange(self.maze.grid_shape[1])
            self.ax.set_xticks(self.unit_length * (col_ticks + 0.5), col_ticks)
            self.ax.set_yticks(self.unit_length * (row_ticks + 0.5), row_ticks)
            self.ax.set_xlabel("col")
            self.ax.set_ylabel("row")
            self.ax.set_title(title)

        # plot paths
        if self.true_path is not None:
            self._plot_path(self.true_path)
        for path in self.predicted_paths:
            self._plot_path(path)

        # plot markers
        for coord, kwargs in self.marked_coords:
            self._place_marked_coords([coord], **kwargs)

        return self

    def _rowcol_to_coord(self, point: Coord) -> np.ndarray:
        """Transform Point from MazeTransformer (row, column) notation to matplotlib default (x, y) notation where x is the horizontal axis.

        the result is the pixel position of the node's center in the image.
        """
        point = np.array([point[1], point[0]])
        return self.unit_length * (point + 0.5)

    def mark_coords(self, coords: CoordArray | list[Coord], **kwargs) -> MazePlot:
        """Mark coordinates on the maze with a marker.

        default marker is a blue "+":
        `dict(marker="+", color="blue")`

        `kwargs` override those defaults and are passed to `ax.plot` when the
        maze is plotted. Returns self.
        """
        kwargs = {
            **dict(marker="+", color="blue"),
            **kwargs,
        }
        for coord in coords:
            self.marked_coords.append((coord, kwargs))

        return self

    def _place_marked_coords(
        self,
        coords: CoordArray | list[Coord],
        **kwargs,
    ) -> MazePlot:
        """draw `coords` (given as (row, col)) onto `self.ax` via `ax.plot(x, y, **kwargs)`"""
        coords_tp = np.array([self._rowcol_to_coord(coord) for coord in coords])
        self.ax.plot(coords_tp[:, 0], coords_tp[:, 1], **kwargs)

        return self

    @staticmethod
    def _colorbar_ticks(
        vals_min: float,
        vals_max: float,
        colormap_center: float | None,
    ) -> np.ndarray:
        """colorbar ticks: 5 evenly spaced values from `vals_min` to `vals_max`

        0 and `colormap_center` are inserted (keeping the ticks sorted) when they
        lie strictly inside the range and are not already ticks.
        """
        ticks = np.linspace(vals_min, vals_max, 5)

        if (vals_min < 0.0 < vals_max) and (0.0 not in ticks):
            ticks = np.insert(ticks, np.searchsorted(ticks, 0.0), 0.0)

        if (
            colormap_center is not None
            and colormap_center not in ticks
            and vals_min < colormap_center < vals_max
        ):
            ticks = np.insert(
                ticks,
                np.searchsorted(ticks, colormap_center),
                colormap_center,
            )

        return ticks

    def _plot_maze(self) -> None:  # noqa: C901, PLR0912
        """Define Colormap and plot maze.

        - without node values: the image from `_lattice_maze_to_img` is drawn in
          grayscale over [-1, 1]
        - with node values: it is drawn with `self.node_color_map`. NaN pixels
          (walls, in that mode) are shown black. The scale is optionally centered
          with `TwoSlopeNorm`, and a colorbar is added unless `self.hide_colorbar`.

        # Raises:
         - `ValueError` : if `colormap_center` lies outside the value range
        """
        img = self._lattice_maze_to_img()

        # if no node_values have been passed (no colormap)
        if self.custom_node_value_flag is False:
            self.ax.imshow(img, cmap="gray", vmin=-1, vmax=1)

        else:
            assert self.node_values is not None, "Please pass node values."
            assert not np.isnan(self.node_values).any(), (
                "Please pass node values, they cannot be nan."
            )

            vals_min: float = np.nanmin(self.node_values)
            vals_max: float = np.nanmax(self.node_values)
            # if both are negative or both are positive, set max/min to 0
            if vals_max < 0.0:
                vals_max = 0.0
            elif vals_min > 0.0:
                vals_min = 0.0

            # adjust vals_max, in case you need consistent colorbar across multiple plots.
            # explicit `is not None`: `colormap_max or vals_max` would ignore 0.0
            if self.colormap_max is not None:
                vals_max = self.colormap_max

            # create colormap
            cmap = mpl.colormaps[self.node_color_map]
            # TODO: this is a hack, we make the walls black (while still allowing negative values) by setting the nan color to black
            cmap.set_bad(color="black")

            if self.colormap_center is not None:
                # TwoSlopeNorm needs vmin < vcenter < vmax strictly; nudge the
                # range when the center sits exactly on one end
                if not (vals_min < self.colormap_center < vals_max):
                    if vals_min == self.colormap_center:
                        vals_min -= 1e-10
                    elif vals_max == self.colormap_center:
                        vals_max += 1e-10
                    else:
                        err_msg: str = f"Please pass colormap_center value between {vals_min} and {vals_max}"
                        raise ValueError(err_msg)

                norm = mpl.colors.TwoSlopeNorm(
                    vmin=vals_min,
                    vcenter=self.colormap_center,
                    vmax=vals_max,
                )
                _plotted = self.ax.imshow(img, cmap=cmap, norm=norm)
            else:
                _plotted = self.ax.imshow(img, cmap=cmap, vmin=vals_min, vmax=vals_max)

            # Add colorbar based on the condition of self.hide_colorbar
            if not self.hide_colorbar:
                cbar = plt.colorbar(
                    _plotted,
                    ticks=self._colorbar_ticks(
                        vals_min,
                        vals_max,
                        self.colormap_center,
                    ),
                    ax=self.ax,
                    cax=self.cbar_ax,
                )
                self.cbar_ax = cbar.ax

        # make the boundaries of the image thicker (walls look weird without this)
        for axis in ["top", "bottom", "left", "right"]:
            self.ax.spines[axis].set_linewidth(2)

    def _lattice_maze_to_img(
        self,
        connection_val_scale: float = 0.93,
    ) -> Float[np.ndarray, "row col"]:
        """Build an image to visualise the maze.

        Each "unit" consists of a node and the right and lower adjacent wall/connection. Its area is ul * ul.
        All pixels start at -1; pixels not overwritten below (e.g. walls) stay -1.

        - without node values:
            - nodes: area (ul-1) * (ul-1), value 1
            - connections: area 1 * (ul-1), value `connection_val_scale` (0.93 by default)
        - with node values (via `.add_node_values()`):
            - each node is filled with its node value, extended by one pixel to the
              right and down so it also covers its right/lower connection pixels
            - where there is no connection (a wall), those pixels are set to NaN

        Axes definition:
        (0,0)     col
        ----|----------->
            |
        row |
            |
            v

        Returns a float array of shape
        `(grid_shape[0] * ul + 1, grid_shape[1] * ul + 1)`.
        """
        # TODO: this is a hack, but if you add 1 always then non-node valued plots have their walls dissapear. if you dont add 1, you get ugly colors between nodes when they are colored
        node_bdry_hack: int
        connection_list_processed: Float[np.ndarray, "dim row col"]
        # Set node and connection values
        if self.node_values is None:
            scaled_node_values = np.ones(self.maze.grid_shape)
            connection_values = scaled_node_values * connection_val_scale
            node_bdry_hack = 0
            # TODO: hack
            # invert connection list, so that the "draw where False" test below
            # draws the existing connections
            connection_list_processed = np.logical_not(self.maze.connection_list)
        else:
            # TODO: hack
            scaled_node_values = self.node_values
            # NaN is rendered black by `_plot_maze` (`cmap.set_bad`)
            connection_values = np.full_like(scaled_node_values, np.nan)
            node_bdry_hack = 1
            # not inverted: the "draw where False" test below draws the walls
            connection_list_processed = self.maze.connection_list

        # Create background image (all pixels set to -1, walls everywhere)
        img: Float[np.ndarray, "row col"] = -np.ones(
            (
                self.maze.grid_shape[0] * self.unit_length + 1,
                self.maze.grid_shape[1] * self.unit_length + 1,
            ),
            dtype=float,
        )

        # Draw nodes and connections by iterating through lattice
        for row in range(self.maze.grid_shape[0]):
            for col in range(self.maze.grid_shape[1]):
                # Draw node
                img[
                    row * self.unit_length + 1 : (row + 1) * self.unit_length
                    + node_bdry_hack,
                    col * self.unit_length + 1 : (col + 1) * self.unit_length
                    + node_bdry_hack,
                ] = scaled_node_values[row, col]

                # Down connection
                if not connection_list_processed[0, row, col]:
                    img[
                        (row + 1) * self.unit_length,
                        col * self.unit_length + 1 : (col + 1) * self.unit_length,
                    ] = connection_values[row, col]

                # Right connection
                if not connection_list_processed[1, row, col]:
                    img[
                        row * self.unit_length + 1 : (row + 1) * self.unit_length,
                        (col + 1) * self.unit_length,
                    ] = connection_values[row, col]

        return img

    def _plot_path(self, path_format: StyledPath) -> None:
        """draw one styled path onto `self.ax`

        with `quiver_kwargs` set, the path is drawn as arrows between consecutive
        coords (colored along `cmap` if given); otherwise as a line using
        `fmt`/`line_width`/`color`/`label`. In both cases the start is marked
        with an "o" and the end with an "x". Empty paths are skipped with a warning.

        # Raises:
         - `ValueError` : if the transformed path cannot be split into x/y for quiver
        """
        if len(path_format.path) == 0:
            warnings.warn(f"Empty path, skipping plotting\n{path_format = }")
            return
        p_transformed = np.array(
            [self._rowcol_to_coord(coord) for coord in path_format.path],
        )
        if path_format.quiver_kwargs is not None:
            try:
                x: np.ndarray = p_transformed[:, 0]
                y: np.ndarray = p_transformed[:, 1]
            except Exception as e:
                err_msg: str = f"Error in plotting quiver path:\n{path_format = }\n{p_transformed = }\n{e}"
                raise ValueError(err_msg) from e

            # Generate colors from the colormap
            if path_format.cmap is not None:
                n = len(x) - 1  # Number of arrows
                cmap = plt.get_cmap(path_format.cmap)
                colors = [cmap(i / n) for i in range(n)]
            else:
                colors = path_format.color

            self.ax.quiver(
                x[:-1],
                y[:-1],
                x[1:] - x[:-1],
                y[1:] - y[:-1],
                scale_units="xy",
                angles="xy",
                scale=1,
                color=colors,
                **path_format.quiver_kwargs,
            )
        else:
            self.ax.plot(
                p_transformed[:, 0],
                p_transformed[:, 1],
                path_format.fmt,
                lw=path_format.line_width,
                color=path_format.color,
                label=path_format.label,
            )
        # mark endpoints
        self.ax.plot(
            [p_transformed[0][0]],
            [p_transformed[0][1]],
            "o",
            color=path_format.color,
            ms=10,
        )
        self.ax.plot(
            [p_transformed[-1][0]],
            [p_transformed[-1][1]],
            "x",
            color=path_format.color,
            ms=10,
        )

    def to_ascii(
        self,
        show_endpoints: bool = True,
        show_solution: bool = True,
    ) -> str:
        "wrapper for `self.solved_maze.as_ascii()`, shows the path if we have `self.true_path`"
        if self.true_path is not None:
            return self.solved_maze.as_ascii(
                show_endpoints=show_endpoints,
                show_solution=show_solution,
            )
        return self.maze.as_ascii(show_endpoints=show_endpoints)