Coverage for maze_dataset/plotting/plot_maze.py: 83%
217 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-24 00:33 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-24 00:33 -0600
1"""provides `MazePlot`, which has many tools for plotting mazes with multiple paths, colored nodes, and more"""
3from __future__ import annotations # for type hinting self as return value
5import warnings
6from copy import deepcopy
7from dataclasses import dataclass
8from typing import Sequence
10import matplotlib as mpl
11import matplotlib.pyplot as plt
12import numpy as np
13from jaxtyping import Bool, Float
15from maze_dataset.constants import Coord, CoordArray, CoordList
16from maze_dataset.maze import (
17 LatticeMaze,
18 SolvedMaze,
19 TargetedLatticeMaze,
20)
# large-magnitude negative float constant
# NOTE(review): nothing in the visible part of this module references it -- confirm external users before removing
LARGE_NEGATIVE_NUMBER: float = -1e10
25@dataclass(kw_only=True)
26class PathFormat:
27 """formatting options for path plot"""
29 label: str | None = None
30 fmt: str = "o"
31 color: str | None = None
32 cmap: str | None = None
33 line_width: float | None = None
34 quiver_kwargs: dict | None = None
36 def combine(self, other: PathFormat) -> PathFormat:
37 """combine with other PathFormat object, overwriting attributes with non-None values.
39 returns a modified copy of self.
40 """
41 output: PathFormat = deepcopy(self)
42 for key, value in other.__dict__.items():
43 if key == "path":
44 err_msg: str = f"Cannot overwrite path attribute! {self = }, {other = }"
45 raise ValueError(
46 err_msg,
47 )
48 if value is not None:
49 setattr(output, key, value)
51 return output
# styled path
@dataclass
class StyledPath(PathFormat):
    """a `StyledPath` is a `PathFormat` with a specific path

    `path` is the only required field; all formatting fields are inherited from
    `PathFormat`, where they are keyword-only and have defaults.
    `PathFormat.combine` refuses to merge *from* an object that has a `path`, so a
    `StyledPath` can receive formatting but cannot be used as the formatting source.
    """

    # coordinates of the path; the plotting code in this module reads each entry as (row, col)
    path: CoordArray
# default formatting applied to raw (list / array) paths, looked up by the
# `_default_key` passed to `process_path_input`.
# fields that are not listed keep the `PathFormat` default of `None`
DEFAULT_FORMATS: dict[str, PathFormat] = dict(
    # dashed red line, drawn with `ax.plot` because `quiver_kwargs` stays `None`
    true=PathFormat(label="true path", fmt="--", color="red", line_width=2.5),
    # drawn as arrows because `quiver_kwargs` is set; label and color are left
    # unset so `MazePlot.add_predicted_path` can fill them in per path
    predicted=PathFormat(fmt=":", line_width=2, quiver_kwargs={"width": 0.015}),
)
def process_path_input(
    path: CoordList | CoordArray | StyledPath,
    _default_key: str,
    path_fmt: PathFormat | None = None,
    **kwargs,
) -> StyledPath:
    """convert a path, which might be a list or array of coords, into a `StyledPath`

    formatting is layered in this order, later layers winning:

    1. `DEFAULT_FORMATS[_default_key]` (only for raw list / array input)
    2. `path_fmt`, via `PathFormat.combine` (only its non-`None` attributes)
    3. `**kwargs`, set directly as attributes

    the returned object is always independent of the `path` argument: a `StyledPath`
    input is deep-copied, so neither the `kwargs` applied here nor later changes made
    by the caller of this function leak back into the object that was passed in.

    # Parameters:
     - `path : CoordList | CoordArray | StyledPath`
        the path itself, or an already styled path
     - `_default_key : str`
        key into `DEFAULT_FORMATS`, used only when `path` is a list or array
     - `path_fmt : PathFormat | None`
        extra formatting to combine on top
        (defaults to `None`)
     - `**kwargs`
        attribute overrides applied last. names are not validated, so a misspelled
        name creates a new attribute instead of raising

    # Returns:
     - `StyledPath`

    # Raises:
     - `TypeError` : if `path` is not a list, `np.ndarray` or `StyledPath`
     - `KeyError` : if `_default_key` is needed and is not in `DEFAULT_FORMATS`
    """
    styled_path: StyledPath
    if isinstance(path, StyledPath):
        # copy so the caller's object is never mutated. without this the result
        # aliased the input unless `path_fmt` was given (`combine` copies)
        styled_path = deepcopy(path)
    elif isinstance(path, (np.ndarray, list)):
        coords = path if isinstance(path, np.ndarray) else np.array(path)
        # add default formatting
        styled_path = StyledPath(path=coords).combine(DEFAULT_FORMATS[_default_key])
    else:
        err_msg: str = (
            f"Expected CoordList, CoordArray or StyledPath, got {type(path)}: {path}"
        )
        raise TypeError(
            err_msg,
        )

    # add formatting from path_fmt
    if path_fmt is not None:
        styled_path = styled_path.combine(path_fmt)

    # add formatting from kwargs
    for key, value in kwargs.items():
        setattr(styled_path, key, value)

    return styled_path
# colors handed out to predicted paths that do not specify one.
# `MazePlot.add_predicted_path` indexes this list with the number of predicted
# paths already added, modulo the list length, so the colors repeat after six paths
DEFAULT_PREDICTED_PATH_COLORS: list[str] = [
    "tab:orange", "tab:olive",
    "sienna", "mediumseagreen",
    "tab:purple", "slategrey",
]
class MazePlot:
    """Class for displaying mazes and paths

    intended use is a fluent chain: construct with a maze, optionally add paths,
    node values and marked coordinates, then call `plot()`. every public `add_*`
    method, `mark_coords` and `plot` return `self`.
    """

    def __init__(self, maze: LatticeMaze, unit_length: int = 14) -> None:
        """UNIT_LENGTH: Set ratio between node size and wall thickness in image.

        Wall thickness is fixed to 1px
        A "unit" consists of a single node and the right and lower connection/wall.
        Example: ul = 14 yields 13:1 ratio between node size and wall thickness

        if `maze` is a `SolvedMaze`, its `solution` is added as the true path.
        otherwise, if it is a `TargetedLatticeMaze`, the `solution` of
        `SolvedMaze.from_targeted_lattice_maze(maze)` is added as the true path.
        """
        self.unit_length: int = unit_length
        self.maze: LatticeMaze = maze
        self.true_path: StyledPath | None = None
        self.predicted_paths: list[StyledPath] = []
        self.node_values: Float[np.ndarray, "grid_n grid_n"] = None
        self.custom_node_value_flag: bool = False
        self.node_color_map: str = "Blues"
        self.target_token_coord: Coord = None
        self.preceding_tokens_coords: CoordArray = None
        self.colormap_center: float | None = None
        # these two are overwritten by `add_node_values`; they are initialised here
        # so that `_plot_maze` never reads a missing attribute
        self.colormap_max: float | None = None
        self.hide_colorbar: bool = False
        self.cbar_ax = None
        self.marked_coords: list[tuple[Coord, dict]] = []

        # marker styles used by `add_node_values` for the preceding tokens
        # ("current") and the target token ("next")
        self.marker_kwargs_current: dict = dict(
            marker="s",
            color="green",
            ms=12,
        )
        self.marker_kwargs_next: dict = dict(
            marker="P",
            color="green",
            ms=12,
        )

        if isinstance(maze, SolvedMaze):
            self.add_true_path(maze.solution)
        elif isinstance(maze, TargetedLatticeMaze):
            self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution)

    @property
    def solved_maze(self) -> SolvedMaze:
        """get the underlying `SolvedMaze` object

        built from `self.maze` and the coordinates of `self.true_path`.

        # Raises:
         - `ValueError` : if no true path has been added
        """
        if self.true_path is None:
            raise ValueError(
                "Cannot return SolvedMaze object without true path. Add true path with add_true_path method.",
            )
        return SolvedMaze.from_lattice_maze(
            lattice_maze=self.maze,
            solution=self.true_path.path,
        )

    def add_true_path(
        self,
        path: CoordList | CoordArray | StyledPath,
        path_fmt: PathFormat | None = None,
        **kwargs,
    ) -> MazePlot:
        """add a true path to the maze with optional formatting

        replaces any previously set true path. `path`, `path_fmt` and `kwargs` are
        forwarded to `process_path_input` with the `"true"` default format.
        """
        self.true_path = process_path_input(
            path=path,
            _default_key="true",
            path_fmt=path_fmt,
            **kwargs,
        )

        return self

    def add_predicted_path(
        self,
        path: CoordList | CoordArray | StyledPath,
        path_fmt: PathFormat | None = None,
        **kwargs,
    ) -> MazePlot:
        """Receive predicted path and formatting preferences from input and save in predicted_path list.

        Default formatting depends on number of paths already saved in predicted path list:
        a missing label becomes `"predicted path <k>"` and a missing color is taken
        from `DEFAULT_PREDICTED_PATH_COLORS`, cycling when the list is exhausted.
        """
        styled_path: StyledPath = process_path_input(
            path=path,
            _default_key="predicted",
            path_fmt=path_fmt,
            **kwargs,
        )

        # set default label and color if not specified
        n_existing: int = len(self.predicted_paths)
        if styled_path.label is None:
            styled_path.label = f"predicted path {n_existing + 1}"

        if styled_path.color is None:
            color_num: int = n_existing % len(DEFAULT_PREDICTED_PATH_COLORS)
            styled_path.color = DEFAULT_PREDICTED_PATH_COLORS[color_num]

        self.predicted_paths.append(styled_path)
        return self

    def add_multiple_paths(
        self,
        path_list: Sequence[CoordList | CoordArray | StyledPath],
    ) -> MazePlot:
        """Function for adding multiple paths to MazePlot at once.

        every element of `path_list` is passed to `add_predicted_path` on its own, in
        order. an element may be a raw list / array of coordinates, which gets the
        default predicted-path formatting, or a `StyledPath` that carries its own
        formatting.
        """
        for path in path_list:
            self.add_predicted_path(path)
        return self

    def add_node_values(
        self,
        node_values: Float[np.ndarray, "grid_n grid_n"],
        color_map: str = "Blues",
        target_token_coord: Coord | None = None,
        preceeding_tokens_coords: CoordArray = None,
        colormap_center: float | None = None,
        colormap_max: float | None = None,
        hide_colorbar: bool = False,
    ) -> MazePlot:
        """add node values to the maze for visualization as a heatmap

        # Parameters:
         - `node_values : Float[np.ndarray, "grid_n grid_n"]`
            one value per node; shape must equal `self.maze.grid_shape`
         - `color_map : str`
            name of the matplotlib colormap used for the nodes
            (defaults to `"Blues"`)
         - `target_token_coord : Coord | None`
            if given, this coordinate is marked using `self.marker_kwargs_next`
            (defaults to `None`)
         - `preceeding_tokens_coords : CoordArray`
            if given, each coordinate is marked using `self.marker_kwargs_current`
            (defaults to `None`)
         - `colormap_center : float | None`
            if given, the colormap is normalized with a `TwoSlopeNorm` centered here
            (defaults to `None`)
         - `colormap_max : float | None`
            if given, overrides the upper limit of the color scale
            (defaults to `None`)
         - `hide_colorbar : bool`
            if `True`, no colorbar is drawn
            (defaults to `False`)

        # Returns:
         - `MazePlot`
            `self`, for chaining

        note: markers are appended to `self.marked_coords`, so calling this method
        repeatedly with coordinates accumulates markers.
        """
        assert node_values.shape == self.maze.grid_shape, (
            "Please pass node values of the same sape as LatticeMaze.grid_shape"
        )
        # assert np.min(node_values) >= 0, "Please pass non-negative node values only."

        self.node_values = node_values
        # Set flag for choosing cmap while plotting maze
        self.custom_node_value_flag = True
        self.node_color_map = color_map
        self.colormap_center = colormap_center
        self.colormap_max = colormap_max
        self.hide_colorbar = hide_colorbar

        if target_token_coord is not None:
            self.marked_coords.append((target_token_coord, self.marker_kwargs_next))
        if preceeding_tokens_coords is not None:
            for coord in preceeding_tokens_coords:
                self.marked_coords.append((coord, self.marker_kwargs_current))
        return self

    def plot(
        self,
        dpi: int = 100,
        title: str = "",
        fig_ax: tuple | None = None,
        plain: bool = False,
    ) -> MazePlot:
        """Plot the maze and paths.

        # Parameters:
         - `dpi : int`
            resolution of the figure created here; ignored when `fig_ax` is given
            (defaults to `100`)
         - `title : str`
            axes title, only set when `plain` is `False`
            (defaults to `""`)
         - `fig_ax : tuple | None`
            an existing `(figure, axes)` pair to draw into instead of creating one
            (defaults to `None`)
         - `plain : bool`
            if `True`, skip setting ticks, axis labels and title
            (defaults to `False`)

        # Returns:
         - `MazePlot`
            `self`; the figure and axes are available as `self.fig` and `self.ax`
        """
        # set up figure
        if fig_ax is None:
            self.fig = plt.figure(dpi=dpi)
            self.ax = self.fig.add_subplot(1, 1, 1)
        else:
            self.fig, self.ax = fig_ax

        # plot maze
        self._plot_maze()

        # Plot labels
        if not plain:
            # x runs over columns and y over rows, so each axis gets its own tick
            # count; using the row count for both is wrong for non-square mazes
            row_ticks = np.arange(self.maze.grid_shape[0])
            col_ticks = np.arange(self.maze.grid_shape[1])
            self.ax.set_xticks(self.unit_length * (col_ticks + 0.5), col_ticks)
            self.ax.set_yticks(self.unit_length * (row_ticks + 0.5), row_ticks)
            self.ax.set_xlabel("col")
            self.ax.set_ylabel("row")
            self.ax.set_title(title)

        # plot paths
        if self.true_path is not None:
            self._plot_path(self.true_path)
        for path in self.predicted_paths:
            self._plot_path(path)

        # plot markers
        for coord, kwargs in self.marked_coords:
            self._place_marked_coords([coord], **kwargs)

        return self

    def _rowcol_to_coord(self, point: Coord) -> np.ndarray:
        """Transform Point from MazeTransformer (row, column) notation to matplotlib default (x, y) notation where x is the horizontal axis.

        the result is in image pixels and points at the center of the node's unit.
        """
        point = np.array([point[1], point[0]])
        return self.unit_length * (point + 0.5)

    def mark_coords(self, coords: CoordArray | list[Coord], **kwargs) -> MazePlot:
        """Mark coordinates on the maze with a marker.

        default marker is a blue "+":
        `dict(marker="+", color="blue")`

        `kwargs` override these defaults. the markers are drawn by the next call to `plot()`.
        """
        kwargs = {
            **dict(marker="+", color="blue"),
            **kwargs,
        }
        for coord in coords:
            self.marked_coords.append((coord, kwargs))

        return self

    def _place_marked_coords(
        self,
        coords: CoordArray | list[Coord],
        **kwargs,
    ) -> MazePlot:
        """draw `coords` (row, col) on `self.ax`, forwarding `kwargs` to `ax.plot`"""
        coords_tp = np.array([self._rowcol_to_coord(coord) for coord in coords])
        self.ax.plot(coords_tp[:, 0], coords_tp[:, 1], **kwargs)

        return self

    def _plot_maze(self) -> None:  # noqa: C901, PLR0912
        """Define Colormap and plot maze.

        without node values the image is shown with the `"gray"` colormap on a fixed
        [-1, 1] scale. with node values, the colormap named by `self.node_color_map`
        is used, NaN pixels are drawn black, and a colorbar is added unless
        `self.hide_colorbar` is set.

        # Raises:
         - `ValueError` : if `self.colormap_center` does not lie between the lower
            and upper limit of the color scale
        """
        img = self._lattice_maze_to_img()

        # if no node_values have been passed (no colormap)
        if self.custom_node_value_flag is False:
            self.ax.imshow(img, cmap="gray", vmin=-1, vmax=1)

        else:
            assert self.node_values is not None, "Please pass node values."
            assert not np.isnan(self.node_values).any(), (
                "Please pass node values, they cannot be nan."
            )

            vals_min: float = np.nanmin(self.node_values)
            vals_max: float = np.nanmax(self.node_values)
            # if both are negative or both are positive, set max/min to 0
            if vals_max < 0.0:
                vals_max = 0.0
            elif vals_min > 0.0:
                vals_min = 0.0

            # adjust vals_max, in case you need consistent colorbar across multiple plots
            # (explicit `None` check: `colormap_max=0.0` is a valid request)
            if self.colormap_max is not None:
                vals_max = self.colormap_max

            # create colormap
            cmap = mpl.colormaps[self.node_color_map]
            # TODO: this is a hack, we make the walls black (while still allowing negative values) by setting the nan color to black
            cmap.set_bad(color="black")

            center = self.colormap_center
            if center is not None:
                # `TwoSlopeNorm` needs vmin < vcenter < vmax strictly. a limit that
                # coincides with the center is nudged outward; both limits are
                # checked independently because they can both equal the center
                # (e.g. all-zero values with center 0)
                if vals_min == center:
                    vals_min -= 1e-10
                if vals_max == center:
                    vals_max += 1e-10
                if not (vals_min < center < vals_max):
                    err_msg: str = f"Please pass colormap_center value between {vals_min} and {vals_max}"
                    raise ValueError(
                        err_msg,
                    )

                norm = mpl.colors.TwoSlopeNorm(
                    vmin=vals_min,
                    vcenter=center,
                    vmax=vals_max,
                )
                _plotted = self.ax.imshow(img, cmap=cmap, norm=norm)
            else:
                _plotted = self.ax.imshow(img, cmap=cmap, vmin=vals_min, vmax=vals_max)

            # Add colorbar based on the condition of self.hide_colorbar
            if not self.hide_colorbar:
                ticks = np.linspace(vals_min, vals_max, 5)

                # make sure zero and the center each get a tick when they are inside the range
                if (vals_min < 0.0 < vals_max) and (0.0 not in ticks):
                    ticks = np.insert(ticks, np.searchsorted(ticks, 0.0), 0.0)

                if (
                    center is not None
                    and center not in ticks
                    and vals_min < center < vals_max
                ):
                    ticks = np.insert(
                        ticks,
                        np.searchsorted(ticks, center),
                        center,
                    )

                # passing the stored colorbar axes as `cax` reuses them on a re-plot
                cbar = plt.colorbar(
                    _plotted,
                    ticks=ticks,
                    ax=self.ax,
                    cax=self.cbar_ax,
                )
                self.cbar_ax = cbar.ax

        # make the boundaries of the image thicker (walls look weird without this)
        for axis in ["top", "bottom", "left", "right"]:
            self.ax.spines[axis].set_linewidth(2)

    def _lattice_maze_to_img(
        self,
        connection_val_scale: float = 0.93,
    ) -> Bool[np.ndarray, "row col"]:
        """Build an image to visualise the maze.

        Each "unit" consists of a node and the right and lower adjacent wall/connection. Its area is ul * ul.
        - Nodes have area: (ul-1) * (ul-1) and value 1 by default
            - take node_value if passed via .add_node_values()
        - Walls have area: 1 * (ul-1) and value -1
        - Connections have area: 1 * (ul-1); color and value 0.93 by default
            - take node_value if passed via .add_node_values()

        Axes definition:
        (0,0)     col
        ----|----------->
            |
        row |
            |
            v

        Returns a matrix of side length (ul) * n + 1 where n is the number of nodes.

        note: despite the `Bool` return annotation, the array is created with
        `dtype=float`, and it contains NaN for absent connections when node values are set.
        """
        # TODO: this is a hack, but if you add 1 always then non-node valued plots have their walls dissapear. if you dont add 1, you get ugly colors between nodes when they are colored
        node_bdry_hack: int
        connection_list_processed: Float[np.ndarray, "dim row col"]
        # Set node and connection values
        if self.node_values is None:
            # no node values: a strip is painted where a connection EXISTS
            # (the list is inverted so the `if not ...` tests below select them)
            scaled_node_values = np.ones(self.maze.grid_shape)
            connection_values = scaled_node_values * connection_val_scale
            node_bdry_hack = 0
            # TODO: hack
            # invert connection list
            connection_list_processed = np.logical_not(self.maze.connection_list)
        else:
            # node values: each node square is grown by one pixel over its right/lower
            # strips, then a strip is overwritten with NaN where a connection is ABSENT
            # TODO: hack
            scaled_node_values = self.node_values
            # connection_values = scaled_node_values
            connection_values = np.full_like(scaled_node_values, np.nan)
            node_bdry_hack = 1
            connection_list_processed = self.maze.connection_list

        ul: int = self.unit_length
        n_rows: int = self.maze.grid_shape[0]
        n_cols: int = self.maze.grid_shape[1]

        # Create background image (all pixels set to -1, walls everywhere)
        img: Float[np.ndarray, "row col"] = -np.ones(
            (n_rows * ul + 1, n_cols * ul + 1),
            dtype=float,
        )

        # Draw nodes and connections by iterating through lattice
        # (order matters within a cell: the node is drawn first, then its strips)
        for row in range(n_rows):
            for col in range(n_cols):
                # Draw node
                img[
                    row * ul + 1 : (row + 1) * ul + node_bdry_hack,
                    col * ul + 1 : (col + 1) * ul + node_bdry_hack,
                ] = scaled_node_values[row, col]

                # Down connection
                if not connection_list_processed[0, row, col]:
                    img[
                        (row + 1) * ul,
                        col * ul + 1 : (col + 1) * ul,
                    ] = connection_values[row, col]

                # Right connection
                if not connection_list_processed[1, row, col]:
                    img[
                        row * ul + 1 : (row + 1) * ul,
                        (col + 1) * ul,
                    ] = connection_values[row, col]

        return img

    def _plot_path(self, path_format: StyledPath) -> None:
        """draw one styled path on `self.ax`

        an empty path only emits a warning. if `quiver_kwargs` is set, the path is
        drawn as arrows between consecutive points (optionally colored along `cmap`);
        otherwise it is drawn as a line using `fmt`, `line_width`, `color` and `label`.
        the first point is then marked with "o" and the last with "x".
        """
        if len(path_format.path) == 0:
            warnings.warn(f"Empty path, skipping plotting\n{path_format = }")
            return
        p_transformed = np.array(
            [self._rowcol_to_coord(coord) for coord in path_format.path],
        )
        if path_format.quiver_kwargs is not None:
            try:
                x: np.ndarray = p_transformed[:, 0]
                y: np.ndarray = p_transformed[:, 1]
            except Exception as e:
                err_msg: str = f"Error in plotting quiver path:\n{path_format = }\n{p_transformed = }\n{e}"
                raise ValueError(
                    err_msg,
                ) from e

            # Generate colors from the colormap
            if path_format.cmap is not None:
                n = len(x) - 1  # Number of arrows
                cmap = plt.get_cmap(path_format.cmap)
                colors = [cmap(i / n) for i in range(n)]
            else:
                colors = path_format.color

            self.ax.quiver(
                x[:-1],
                y[:-1],
                x[1:] - x[:-1],
                y[1:] - y[:-1],
                scale_units="xy",
                angles="xy",
                scale=1,
                color=colors,
                **path_format.quiver_kwargs,
            )
        else:
            self.ax.plot(
                p_transformed[:, 0],
                p_transformed[:, 1],
                path_format.fmt,
                lw=path_format.line_width,
                color=path_format.color,
                label=path_format.label,
            )
        # mark endpoints
        self.ax.plot(
            [p_transformed[0][0]],
            [p_transformed[0][1]],
            "o",
            color=path_format.color,
            ms=10,
        )
        self.ax.plot(
            [p_transformed[-1][0]],
            [p_transformed[-1][1]],
            "x",
            color=path_format.color,
            ms=10,
        )

    def to_ascii(
        self,
        show_endpoints: bool = True,
        show_solution: bool = True,
    ) -> str:
        """wrapper for `self.solved_maze.as_ascii()`, shows the path if we have `self.true_path`

        without a true path this falls back to `self.maze.as_ascii()`, and
        `show_solution` is not used.
        """
        if self.true_path is not None:
            return self.solved_maze.as_ascii(
                show_endpoints=show_endpoints,
                show_solution=show_solution,
            )
        return self.maze.as_ascii(show_endpoints=show_endpoints)