maze_dataset.plotting.plot_maze
provides `MazePlot`, which has many tools for plotting mazes with multiple paths, colored nodes, and more
1"""provides `MazePlot`, which has many tools for plotting mazes with multiple paths, colored nodes, and more""" 2 3from __future__ import annotations # for type hinting self as return value 4 5import warnings 6from copy import deepcopy 7from dataclasses import dataclass 8from typing import Sequence 9 10import matplotlib as mpl 11import matplotlib.pyplot as plt 12import numpy as np 13from jaxtyping import Bool, Float 14 15from maze_dataset.constants import Coord, CoordArray, CoordList 16from maze_dataset.maze import ( 17 LatticeMaze, 18 SolvedMaze, 19 TargetedLatticeMaze, 20) 21 22LARGE_NEGATIVE_NUMBER: float = -1e10 23 24 25@dataclass(kw_only=True) 26class PathFormat: 27 """formatting options for path plot""" 28 29 label: str | None = None 30 fmt: str = "o" 31 color: str | None = None 32 cmap: str | None = None 33 line_width: float | None = None 34 quiver_kwargs: dict | None = None 35 36 def combine(self, other: PathFormat) -> PathFormat: 37 """combine with other PathFormat object, overwriting attributes with non-None values. 38 39 returns a modified copy of self. 40 """ 41 output: PathFormat = deepcopy(self) 42 for key, value in other.__dict__.items(): 43 if key == "path": 44 err_msg: str = f"Cannot overwrite path attribute! 
{self = }, {other = }" 45 raise ValueError( 46 err_msg, 47 ) 48 if value is not None: 49 setattr(output, key, value) 50 51 return output 52 53 54# styled path 55@dataclass 56class StyledPath(PathFormat): 57 "a `StyledPath` is a `PathFormat` with a specific path" 58 59 path: CoordArray 60 61 62DEFAULT_FORMATS: dict[str, PathFormat] = { 63 "true": PathFormat( 64 label="true path", 65 fmt="--", 66 color="red", 67 line_width=2.5, 68 quiver_kwargs=None, 69 ), 70 "predicted": PathFormat( 71 label=None, 72 fmt=":", 73 color=None, 74 line_width=2, 75 quiver_kwargs={"width": 0.015}, 76 ), 77} 78 79 80def process_path_input( 81 path: CoordList | CoordArray | StyledPath, 82 _default_key: str, 83 path_fmt: PathFormat | None = None, 84 **kwargs, 85) -> StyledPath: 86 "convert a path, which might be a list or array of coords, into a `StyledPath`" 87 styled_path: StyledPath 88 if isinstance(path, StyledPath): 89 styled_path = path 90 elif isinstance(path, np.ndarray): 91 styled_path = StyledPath(path=path) 92 # add default formatting 93 styled_path = styled_path.combine(DEFAULT_FORMATS[_default_key]) 94 elif isinstance(path, list): 95 styled_path = StyledPath(path=np.array(path)) 96 # add default formatting 97 styled_path = styled_path.combine(DEFAULT_FORMATS[_default_key]) 98 else: 99 err_msg: str = ( 100 f"Expected CoordList, CoordArray or StyledPath, got {type(path)}: {path}" 101 ) 102 raise TypeError( 103 err_msg, 104 ) 105 106 # add formatting from path_fmt 107 if path_fmt is not None: 108 styled_path = styled_path.combine(path_fmt) 109 110 # add formatting from kwargs 111 for key, value in kwargs.items(): 112 setattr(styled_path, key, value) 113 114 return styled_path 115 116 117DEFAULT_PREDICTED_PATH_COLORS: list[str] = [ 118 "tab:orange", 119 "tab:olive", 120 "sienna", 121 "mediumseagreen", 122 "tab:purple", 123 "slategrey", 124] 125 126 127class MazePlot: 128 """Class for displaying mazes and paths""" 129 130 def __init__(self, maze: LatticeMaze, unit_length: int = 14) 
-> None: 131 """UNIT_LENGTH: Set ratio between node size and wall thickness in image. 132 133 Wall thickness is fixed to 1px 134 A "unit" consists of a single node and the right and lower connection/wall. 135 Example: ul = 14 yields 13:1 ratio between node size and wall thickness 136 """ 137 self.unit_length: int = unit_length 138 self.maze: LatticeMaze = maze 139 self.true_path: StyledPath | None = None 140 self.predicted_paths: list[StyledPath] = [] 141 self.node_values: Float[np.ndarray, "grid_n grid_n"] = None 142 self.custom_node_value_flag: bool = False 143 self.node_color_map: str = "Blues" 144 self.target_token_coord: Coord = None 145 self.preceding_tokens_coords: CoordArray = None 146 self.colormap_center: float | None = None 147 self.cbar_ax = None 148 self.marked_coords: list[tuple[Coord, dict]] = list() 149 150 self.marker_kwargs_current: dict = dict( 151 marker="s", 152 color="green", 153 ms=12, 154 ) 155 self.marker_kwargs_next: dict = dict( 156 marker="P", 157 color="green", 158 ms=12, 159 ) 160 161 if isinstance(maze, SolvedMaze): 162 self.add_true_path(maze.solution) 163 else: 164 if isinstance(maze, TargetedLatticeMaze): 165 self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution) 166 167 @property 168 def solved_maze(self) -> SolvedMaze: 169 "get the underlying `SolvedMaze` object" 170 if self.true_path is None: 171 raise ValueError( 172 "Cannot return SolvedMaze object without true path. 
Add true path with add_true_path method.", 173 ) 174 return SolvedMaze.from_lattice_maze( 175 lattice_maze=self.maze, 176 solution=self.true_path.path, 177 ) 178 179 def add_true_path( 180 self, 181 path: CoordList | CoordArray | StyledPath, 182 path_fmt: PathFormat | None = None, 183 **kwargs, 184 ) -> MazePlot: 185 "add a true path to the maze with optional formatting" 186 self.true_path = process_path_input( 187 path=path, 188 _default_key="true", 189 path_fmt=path_fmt, 190 **kwargs, 191 ) 192 193 return self 194 195 def add_predicted_path( 196 self, 197 path: CoordList | CoordArray | StyledPath, 198 path_fmt: PathFormat | None = None, 199 **kwargs, 200 ) -> MazePlot: 201 """Recieve predicted path and formatting preferences from input and save in predicted_path list. 202 203 Default formatting depends on nuber of paths already saved in predicted path list. 204 """ 205 styled_path: StyledPath = process_path_input( 206 path=path, 207 _default_key="predicted", 208 path_fmt=path_fmt, 209 **kwargs, 210 ) 211 212 # set default label and color if not specified 213 if styled_path.label is None: 214 styled_path.label = f"predicted path {len(self.predicted_paths) + 1}" 215 216 if styled_path.color is None: 217 color_num: int = len(self.predicted_paths) % len( 218 DEFAULT_PREDICTED_PATH_COLORS, 219 ) 220 styled_path.color = DEFAULT_PREDICTED_PATH_COLORS[color_num] 221 222 self.predicted_paths.append(styled_path) 223 return self 224 225 def add_multiple_paths( 226 self, 227 path_list: Sequence[CoordList | CoordArray | StyledPath], 228 ) -> MazePlot: 229 """Function for adding multiple paths to MazePlot at once. 230 231 > DOCS: what are the two ways? 232 This can be done in two ways: 233 1. 
Passing a list of 234 """ 235 for path in path_list: 236 self.add_predicted_path(path) 237 return self 238 239 def add_node_values( 240 self, 241 node_values: Float[np.ndarray, "grid_n grid_n"], 242 color_map: str = "Blues", 243 target_token_coord: Coord | None = None, 244 preceeding_tokens_coords: CoordArray = None, 245 colormap_center: float | None = None, 246 colormap_max: float | None = None, 247 hide_colorbar: bool = False, 248 ) -> MazePlot: 249 """add node values to the maze for visualization as a heatmap 250 251 > DOCS: what are these arguments? 252 253 # Parameters: 254 - `node_values : Float[np.ndarray, "grid_n grid_n"]` 255 - `color_map : str` 256 (defaults to `"Blues"`) 257 - `target_token_coord : Coord | None` 258 (defaults to `None`) 259 - `preceeding_tokens_coords : CoordArray` 260 (defaults to `None`) 261 - `colormap_center : float | None` 262 (defaults to `None`) 263 - `colormap_max : float | None` 264 (defaults to `None`) 265 - `hide_colorbar : bool` 266 (defaults to `False`) 267 268 # Returns: 269 - `MazePlot` 270 """ 271 assert node_values.shape == self.maze.grid_shape, ( 272 "Please pass node values of the same sape as LatticeMaze.grid_shape" 273 ) 274 # assert np.min(node_values) >= 0, "Please pass non-negative node values only." 
275 276 self.node_values = node_values 277 # Set flag for choosing cmap while plotting maze 278 self.custom_node_value_flag = True 279 # Retrieve Max node value for plotting, +1e-10 to avoid division by zero 280 self.node_color_map = color_map 281 self.colormap_center = colormap_center 282 self.colormap_max = colormap_max 283 self.hide_colorbar = hide_colorbar 284 285 if target_token_coord is not None: 286 self.marked_coords.append((target_token_coord, self.marker_kwargs_next)) 287 if preceeding_tokens_coords is not None: 288 for coord in preceeding_tokens_coords: 289 self.marked_coords.append((coord, self.marker_kwargs_current)) 290 return self 291 292 def plot( 293 self, 294 dpi: int = 100, 295 title: str = "", 296 fig_ax: tuple | None = None, 297 plain: bool = False, 298 ) -> MazePlot: 299 """Plot the maze and paths.""" 300 # set up figure 301 if fig_ax is None: 302 self.fig = plt.figure(dpi=dpi) 303 self.ax = self.fig.add_subplot(1, 1, 1) 304 else: 305 self.fig, self.ax = fig_ax 306 307 # plot maze 308 self._plot_maze() 309 310 # Plot labels 311 if not plain: 312 tick_arr = np.arange(self.maze.grid_shape[0]) 313 self.ax.set_xticks(self.unit_length * (tick_arr + 0.5), tick_arr) 314 self.ax.set_yticks(self.unit_length * (tick_arr + 0.5), tick_arr) 315 self.ax.set_xlabel("col") 316 self.ax.set_ylabel("row") 317 self.ax.set_title(title) 318 319 # plot paths 320 if self.true_path is not None: 321 self._plot_path(self.true_path) 322 for path in self.predicted_paths: 323 self._plot_path(path) 324 325 # plot markers 326 for coord, kwargs in self.marked_coords: 327 self._place_marked_coords([coord], **kwargs) 328 329 return self 330 331 def _rowcol_to_coord(self, point: Coord) -> np.ndarray: 332 """Transform Point from MazeTransformer (row, column) notation to matplotlib default (x, y) notation where x is the horizontal axis.""" 333 point = np.array([point[1], point[0]]) 334 return self.unit_length * (point + 0.5) 335 336 def mark_coords(self, coords: CoordArray | 
list[Coord], **kwargs) -> MazePlot: 337 """Mark coordinates on the maze with a marker. 338 339 default marker is a blue "+": 340 `dict(marker="+", color="blue")` 341 """ 342 kwargs = { 343 **dict(marker="+", color="blue"), 344 **kwargs, 345 } 346 for coord in coords: 347 self.marked_coords.append((coord, kwargs)) 348 349 return self 350 351 def _place_marked_coords( 352 self, 353 coords: CoordArray | list[Coord], 354 **kwargs, 355 ) -> MazePlot: 356 coords_tp = np.array([self._rowcol_to_coord(coord) for coord in coords]) 357 self.ax.plot(coords_tp[:, 0], coords_tp[:, 1], **kwargs) 358 359 return self 360 361 def _plot_maze(self) -> None: # noqa: C901, PLR0912 362 """Define Colormap and plot maze. 363 364 Colormap: x is -inf: black 365 else: use colormap 366 """ 367 img = self._lattice_maze_to_img() 368 369 # if no node_values have been passed (no colormap) 370 if self.custom_node_value_flag is False: 371 self.ax.imshow(img, cmap="gray", vmin=-1, vmax=1) 372 373 else: 374 assert self.node_values is not None, "Please pass node values." 375 assert not np.isnan(self.node_values).any(), ( 376 "Please pass node values, they cannot be nan." 
377 ) 378 379 vals_min: float = np.nanmin(self.node_values) 380 vals_max: float = np.nanmax(self.node_values) 381 # if both are negative or both are positive, set max/min to 0 382 if vals_max < 0.0: 383 vals_max = 0.0 384 elif vals_min > 0.0: 385 vals_min = 0.0 386 387 # adjust vals_max, in case you need consistent colorbar across multiple plots 388 vals_max = self.colormap_max or vals_max 389 390 # create colormap 391 cmap = mpl.colormaps[self.node_color_map] 392 # TODO: this is a hack, we make the walls black (while still allowing negative values) by setting the nan color to black 393 cmap.set_bad(color="black") 394 395 if self.colormap_center is not None: 396 if not (vals_min < self.colormap_center < vals_max): 397 if vals_min == self.colormap_center: 398 vals_min -= 1e-10 399 elif vals_max == self.colormap_center: 400 vals_max += 1e-10 401 else: 402 err_msg: str = f"Please pass colormap_center value between {vals_min} and {vals_max}" 403 raise ValueError( 404 err_msg, 405 ) 406 407 norm = mpl.colors.TwoSlopeNorm( 408 vmin=vals_min, 409 vcenter=self.colormap_center, 410 vmax=vals_max, 411 ) 412 _plotted = self.ax.imshow(img, cmap=cmap, norm=norm) 413 else: 414 _plotted = self.ax.imshow(img, cmap=cmap, vmin=vals_min, vmax=vals_max) 415 416 # Add colorbar based on the condition of self.hide_colorbar 417 if not self.hide_colorbar: 418 ticks = np.linspace(vals_min, vals_max, 5) 419 420 if (vals_min < 0.0 < vals_max) and (0.0 not in ticks): 421 ticks = np.insert(ticks, np.searchsorted(ticks, 0.0), 0.0) 422 423 if ( 424 self.colormap_center is not None 425 and self.colormap_center not in ticks 426 and vals_min < self.colormap_center < vals_max 427 ): 428 ticks = np.insert( 429 ticks, 430 np.searchsorted(ticks, self.colormap_center), 431 self.colormap_center, 432 ) 433 434 cbar = plt.colorbar( 435 _plotted, 436 ticks=ticks, 437 ax=self.ax, 438 cax=self.cbar_ax, 439 ) 440 self.cbar_ax = cbar.ax 441 442 # make the boundaries of the image thicker (walls look weird without 
this) 443 for axis in ["top", "bottom", "left", "right"]: 444 self.ax.spines[axis].set_linewidth(2) 445 446 def _lattice_maze_to_img( 447 self, 448 connection_val_scale: float = 0.93, 449 ) -> Bool[np.ndarray, "row col"]: 450 """Build an image to visualise the maze. 451 452 Each "unit" consists of a node and the right and lower adjacent wall/connection. Its area is ul * ul. 453 - Nodes have area: (ul-1) * (ul-1) and value 1 by default 454 - take node_value if passed via .add_node_values() 455 - Walls have area: 1 * (ul-1) and value -1 456 - Connections have area: 1 * (ul-1); color and value 0.93 by default 457 - take node_value if passed via .add_node_values() 458 459 Axes definition: 460 (0,0) col 461 ----|-----------> 462 | 463 row | 464 | 465 v 466 467 Returns a matrix of side length (ul) * n + 1 where n is the number of nodes. 468 """ 469 # TODO: this is a hack, but if you add 1 always then non-node valued plots have their walls dissapear. if you dont add 1, you get ugly colors between nodes when they are colored 470 node_bdry_hack: int 471 connection_list_processed: Float[np.ndarray, "dim row col"] 472 # Set node and connection values 473 if self.node_values is None: 474 scaled_node_values = np.ones(self.maze.grid_shape) 475 connection_values = scaled_node_values * connection_val_scale 476 node_bdry_hack = 0 477 # TODO: hack 478 # invert connection list 479 connection_list_processed = np.logical_not(self.maze.connection_list) 480 else: 481 # TODO: hack 482 scaled_node_values = self.node_values 483 # connection_values = scaled_node_values 484 connection_values = np.full_like(scaled_node_values, np.nan) 485 node_bdry_hack = 1 486 connection_list_processed = self.maze.connection_list 487 488 # Create background image (all pixels set to -1, walls everywhere) 489 img: Float[np.ndarray, "row col"] = -np.ones( 490 ( 491 self.maze.grid_shape[0] * self.unit_length + 1, 492 self.maze.grid_shape[1] * self.unit_length + 1, 493 ), 494 dtype=float, 495 ) 496 497 # Draw 
nodes and connections by iterating through lattice 498 for row in range(self.maze.grid_shape[0]): 499 for col in range(self.maze.grid_shape[1]): 500 # Draw node 501 img[ 502 row * self.unit_length + 1 : (row + 1) * self.unit_length 503 + node_bdry_hack, 504 col * self.unit_length + 1 : (col + 1) * self.unit_length 505 + node_bdry_hack, 506 ] = scaled_node_values[row, col] 507 508 # Down connection 509 if not connection_list_processed[0, row, col]: 510 img[ 511 (row + 1) * self.unit_length, 512 col * self.unit_length + 1 : (col + 1) * self.unit_length, 513 ] = connection_values[row, col] 514 515 # Right connection 516 if not connection_list_processed[1, row, col]: 517 img[ 518 row * self.unit_length + 1 : (row + 1) * self.unit_length, 519 (col + 1) * self.unit_length, 520 ] = connection_values[row, col] 521 522 return img 523 524 def _plot_path(self, path_format: PathFormat) -> None: 525 if len(path_format.path) == 0: 526 warnings.warn(f"Empty path, skipping plotting\n{path_format = }") 527 return 528 p_transformed = np.array( 529 [self._rowcol_to_coord(coord) for coord in path_format.path], 530 ) 531 if path_format.quiver_kwargs is not None: 532 try: 533 x: np.ndarray = p_transformed[:, 0] 534 y: np.ndarray = p_transformed[:, 1] 535 except Exception as e: 536 err_msg: str = f"Error in plotting quiver path:\n{path_format = }\n{p_transformed = }\n{e}" 537 raise ValueError( 538 err_msg, 539 ) from e 540 541 # Generate colors from the colormap 542 if path_format.cmap is not None: 543 n = len(x) - 1 # Number of arrows 544 cmap = plt.get_cmap(path_format.cmap) 545 colors = [cmap(i / n) for i in range(n)] 546 else: 547 colors = path_format.color 548 549 self.ax.quiver( 550 x[:-1], 551 y[:-1], 552 x[1:] - x[:-1], 553 y[1:] - y[:-1], 554 scale_units="xy", 555 angles="xy", 556 scale=1, 557 color=colors, 558 **path_format.quiver_kwargs, 559 ) 560 else: 561 self.ax.plot( 562 p_transformed[:, 0], 563 p_transformed[:, 1], 564 path_format.fmt, 565 lw=path_format.line_width, 566 
color=path_format.color, 567 label=path_format.label, 568 ) 569 # mark endpoints 570 self.ax.plot( 571 [p_transformed[0][0]], 572 [p_transformed[0][1]], 573 "o", 574 color=path_format.color, 575 ms=10, 576 ) 577 self.ax.plot( 578 [p_transformed[-1][0]], 579 [p_transformed[-1][1]], 580 "x", 581 color=path_format.color, 582 ms=10, 583 ) 584 585 def to_ascii( 586 self, 587 show_endpoints: bool = True, 588 show_solution: bool = True, 589 ) -> str: 590 "wrapper for `self.solved_maze.as_ascii()`, shows the path if we have `self.true_path`" 591 if self.true_path: 592 return self.solved_maze.as_ascii( 593 show_endpoints=show_endpoints, 594 show_solution=show_solution, 595 ) 596 else: 597 return self.maze.as_ascii(show_endpoints=show_endpoints)
26@dataclass(kw_only=True) 27class PathFormat: 28 """formatting options for path plot""" 29 30 label: str | None = None 31 fmt: str = "o" 32 color: str | None = None 33 cmap: str | None = None 34 line_width: float | None = None 35 quiver_kwargs: dict | None = None 36 37 def combine(self, other: PathFormat) -> PathFormat: 38 """combine with other PathFormat object, overwriting attributes with non-None values. 39 40 returns a modified copy of self. 41 """ 42 output: PathFormat = deepcopy(self) 43 for key, value in other.__dict__.items(): 44 if key == "path": 45 err_msg: str = f"Cannot overwrite path attribute! {self = }, {other = }" 46 raise ValueError( 47 err_msg, 48 ) 49 if value is not None: 50 setattr(output, key, value) 51 52 return output
formatting options for path plot
def combine(self, other: PathFormat) -> PathFormat:
    """combine with other PathFormat object, overwriting attributes with non-None values.

    returns a modified copy of self. neither `self` nor `other` is modified, and
    values taken from `other` are deep-copied so the result never shares mutable
    state (e.g. a `quiver_kwargs` dict) with `other`.

    # Raises:
    - `ValueError` : if `other` has a `path` attribute (e.g. it is a `StyledPath`)
    """
    output: PathFormat = deepcopy(self)
    for key, value in other.__dict__.items():
        if key == "path":
            err_msg: str = f"Cannot overwrite path attribute! {self = }, {other = }"
            raise ValueError(
                err_msg,
            )
        if value is not None:
            # copy so that mutating the result cannot leak back into `other`,
            # which is often one of the shared default formats
            setattr(output, key, deepcopy(value))

    return output
combine with other PathFormat object, overwriting attributes with non-None values.
returns a modified copy of self.
@dataclass
class StyledPath(PathFormat):
    """a `StyledPath` is a `PathFormat` with a specific path

    `path` holds the coordinates to draw, in (row, col) notation. All formatting
    fields are inherited from `PathFormat` and stay keyword-only.
    """

    path: CoordArray  # the coordinates making up the path
a `StyledPath` is a `PathFormat` with a specific path
Inherited Members
def process_path_input(
    path: CoordList | CoordArray | StyledPath,
    _default_key: str,
    path_fmt: PathFormat | None = None,
    **kwargs,
) -> StyledPath:
    """convert a path, which might be a list or array of coords, into a `StyledPath`

    formatting is applied in increasing priority:
    1. `DEFAULT_FORMATS[_default_key]` (only for list / array input)
    2. non-None attributes of `path_fmt`
    3. `kwargs`, set directly as attributes

    a `StyledPath` that is passed in is copied, never modified in place.

    # Raises:
    - `TypeError` : if `path` is not a list, numpy array, or `StyledPath`
    """
    from copy import deepcopy

    styled_path: StyledPath
    if isinstance(path, StyledPath):
        # copy, so that `kwargs` below and callers filling in defaults
        # (e.g. `MazePlot.add_predicted_path`) don't mutate the caller's object
        styled_path = deepcopy(path)
    elif isinstance(path, np.ndarray):
        styled_path = StyledPath(path=path)
        # add default formatting
        styled_path = styled_path.combine(DEFAULT_FORMATS[_default_key])
    elif isinstance(path, list):
        styled_path = StyledPath(path=np.array(path))
        # add default formatting
        styled_path = styled_path.combine(DEFAULT_FORMATS[_default_key])
    else:
        err_msg: str = (
            f"Expected CoordList, CoordArray or StyledPath, got {type(path)}: {path}"
        )
        raise TypeError(
            err_msg,
        )

    # add formatting from path_fmt
    if path_fmt is not None:
        styled_path = styled_path.combine(path_fmt)

    # add formatting from kwargs
    for key, value in kwargs.items():
        setattr(styled_path, key, value)

    return styled_path
convert a path, which might be a list or array of coords, into a `StyledPath`
class MazePlot:
    """Class for displaying mazes and paths"""

    def __init__(self, maze: LatticeMaze, unit_length: int = 14) -> None:
        """UNIT_LENGTH: Set ratio between node size and wall thickness in image.

        Wall thickness is fixed to 1px
        A "unit" consists of a single node and the right and lower connection/wall.
        Example: ul = 14 yields 13:1 ratio between node size and wall thickness

        if `maze` is a `SolvedMaze`, its solution is added as the true path; if it is
        a `TargetedLatticeMaze`, it is solved and that solution is added instead.
        """
        self.unit_length: int = unit_length
        self.maze: LatticeMaze = maze
        self.true_path: StyledPath | None = None
        self.predicted_paths: list[StyledPath] = []
        self.node_values: Float[np.ndarray, "grid_n grid_n"] | None = None
        self.custom_node_value_flag: bool = False
        self.node_color_map: str = "Blues"
        self.target_token_coord: Coord | None = None
        self.preceding_tokens_coords: CoordArray | None = None
        self.colormap_center: float | None = None
        # set here as well as in `add_node_values`, so every attribute that
        # `_plot_maze` reads always exists
        self.colormap_max: float | None = None
        self.hide_colorbar: bool = False
        self.cbar_ax = None
        self.marked_coords: list[tuple[Coord, dict]] = list()

        self.marker_kwargs_current: dict = dict(
            marker="s",
            color="green",
            ms=12,
        )
        self.marker_kwargs_next: dict = dict(
            marker="P",
            color="green",
            ms=12,
        )

        # a `SolvedMaze` carries its own solution; a `TargetedLatticeMaze` is solved here
        if isinstance(maze, SolvedMaze):
            self.add_true_path(maze.solution)
        elif isinstance(maze, TargetedLatticeMaze):
            self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution)

    @property
    def solved_maze(self) -> SolvedMaze:
        """get the underlying `SolvedMaze` object

        # Raises:
        - `ValueError` : if no true path has been added
        """
        if self.true_path is None:
            raise ValueError(
                "Cannot return SolvedMaze object without true path. Add true path with add_true_path method.",
            )
        return SolvedMaze.from_lattice_maze(
            lattice_maze=self.maze,
            solution=self.true_path.path,
        )

    def add_true_path(
        self,
        path: CoordList | CoordArray | StyledPath,
        path_fmt: PathFormat | None = None,
        **kwargs,
    ) -> MazePlot:
        "add a true path to the maze with optional formatting (replaces any existing true path)"
        self.true_path = process_path_input(
            path=path,
            _default_key="true",
            path_fmt=path_fmt,
            **kwargs,
        )

        return self

    def add_predicted_path(
        self,
        path: CoordList | CoordArray | StyledPath,
        path_fmt: PathFormat | None = None,
        **kwargs,
    ) -> MazePlot:
        """Receive predicted path and formatting preferences from input and save in predicted_path list.

        Default formatting depends on number of paths already saved in predicted path list:
        a missing label becomes `"predicted path {i}"` (1-indexed), and a missing color is
        taken from `DEFAULT_PREDICTED_PATH_COLORS`, cycling through it.
        """
        styled_path: StyledPath = process_path_input(
            path=path,
            _default_key="predicted",
            path_fmt=path_fmt,
            **kwargs,
        )

        # set default label and color if not specified
        if styled_path.label is None:
            styled_path.label = f"predicted path {len(self.predicted_paths) + 1}"

        if styled_path.color is None:
            color_num: int = len(self.predicted_paths) % len(
                DEFAULT_PREDICTED_PATH_COLORS,
            )
            styled_path.color = DEFAULT_PREDICTED_PATH_COLORS[color_num]

        self.predicted_paths.append(styled_path)
        return self

    def add_multiple_paths(
        self,
        path_list: Sequence[CoordList | CoordArray | StyledPath],
    ) -> MazePlot:
        """Function for adding multiple paths to MazePlot at once.

        Each element of `path_list` is added via `add_predicted_path`. Elements can be given in two ways:
        1. as raw coordinates (a `CoordList` or `CoordArray`), which get the default
           predicted-path format plus an auto-numbered label and a cycling color
        2. as a `StyledPath`, which keeps its own formatting; only a missing label or
           color is filled in
        """
        for path in path_list:
            self.add_predicted_path(path)
        return self

    def add_node_values(
        self,
        node_values: Float[np.ndarray, "grid_n grid_n"],
        color_map: str = "Blues",
        target_token_coord: Coord | None = None,
        preceeding_tokens_coords: CoordArray = None,
        colormap_center: float | None = None,
        colormap_max: float | None = None,
        hide_colorbar: bool = False,
    ) -> MazePlot:
        """add node values to the maze for visualization as a heatmap

        # Parameters:
        - `node_values : Float[np.ndarray, "grid_n grid_n"]`
            one value per node; shape must equal `self.maze.grid_shape`, and (checked
            when plotting) no value may be NaN
        - `color_map : str`
            name of a matplotlib colormap, looked up in `mpl.colormaps` when plotting
            (defaults to `"Blues"`)
        - `target_token_coord : Coord | None`
            if given, marked with `self.marker_kwargs_next` (a green "P")
            (defaults to `None`)
        - `preceeding_tokens_coords : CoordArray`
            if given, each coord is marked with `self.marker_kwargs_current` (a green square)
            (defaults to `None`)
        - `colormap_center : float | None`
            if given, a `TwoSlopeNorm` centered on this value is used; it must lie within
            (or on the boundary of) the plotted value range
            (defaults to `None`)
        - `colormap_max : float | None`
            if given, overrides the upper limit of the color scale, e.g. to get a
            consistent colorbar across multiple plots
            (defaults to `None`)
        - `hide_colorbar : bool`
            if `True`, no colorbar is drawn
            (defaults to `False`)

        Markers accumulate: calling this again appends further marked coords.

        # Returns:
        - `MazePlot` : `self`, for chaining
        """
        assert node_values.shape == self.maze.grid_shape, (
            "Please pass node values of the same sape as LatticeMaze.grid_shape"
        )
        # assert np.min(node_values) >= 0, "Please pass non-negative node values only."

        self.node_values = node_values
        # Set flag for choosing cmap while plotting maze
        self.custom_node_value_flag = True
        # colormap settings, consumed by `_plot_maze`
        self.node_color_map = color_map
        self.colormap_center = colormap_center
        self.colormap_max = colormap_max
        self.hide_colorbar = hide_colorbar

        if target_token_coord is not None:
            self.marked_coords.append((target_token_coord, self.marker_kwargs_next))
        if preceeding_tokens_coords is not None:
            for coord in preceeding_tokens_coords:
                self.marked_coords.append((coord, self.marker_kwargs_current))
        return self

    def plot(
        self,
        dpi: int = 100,
        title: str = "",
        fig_ax: tuple | None = None,
        plain: bool = False,
    ) -> MazePlot:
        """Plot the maze and paths.

        creates a new figure unless `fig_ax=(fig, ax)` is given; `plain=True` skips
        ticks, axis labels and title.
        """
        # set up figure
        if fig_ax is None:
            self.fig = plt.figure(dpi=dpi)
            self.ax = self.fig.add_subplot(1, 1, 1)
        else:
            self.fig, self.ax = fig_ax

        # plot maze
        self._plot_maze()

        # Plot labels
        if not plain:
            # x is the column axis and y the row axis (see `_rowcol_to_coord`), so each
            # axis gets ticks for its own dimension -- matters for non-square mazes
            col_ticks = np.arange(self.maze.grid_shape[1])
            row_ticks = np.arange(self.maze.grid_shape[0])
            self.ax.set_xticks(self.unit_length * (col_ticks + 0.5), col_ticks)
            self.ax.set_yticks(self.unit_length * (row_ticks + 0.5), row_ticks)
            self.ax.set_xlabel("col")
            self.ax.set_ylabel("row")
            self.ax.set_title(title)

        # plot paths
        if self.true_path is not None:
            self._plot_path(self.true_path)
        for path in self.predicted_paths:
            self._plot_path(path)

        # plot markers
        for coord, kwargs in self.marked_coords:
            self._place_marked_coords([coord], **kwargs)

        return self

    def _rowcol_to_coord(self, point: Coord) -> np.ndarray:
        """Transform Point from MazeTransformer (row, column) notation to matplotlib default (x, y) notation where x is the horizontal axis."""
        point = np.array([point[1], point[0]])
        return self.unit_length * (point + 0.5)

    def mark_coords(self, coords: CoordArray | list[Coord], **kwargs) -> MazePlot:
        """Mark coordinates on the maze with a marker.

        default marker is a blue "+":
        `dict(marker="+", color="blue")`
        """
        kwargs = {
            **dict(marker="+", color="blue"),
            **kwargs,
        }
        for coord in coords:
            self.marked_coords.append((coord, kwargs))

        return self

    def _place_marked_coords(
        self,
        coords: CoordArray | list[Coord],
        **kwargs,
    ) -> MazePlot:
        "draw `coords` (in (row, col) notation) on `self.ax` with `ax.plot(**kwargs)`"
        coords_tp = np.array([self._rowcol_to_coord(coord) for coord in coords])
        self.ax.plot(coords_tp[:, 0], coords_tp[:, 1], **kwargs)

        return self

    def _plot_maze(self) -> None:  # noqa: C901, PLR0912
        """Define Colormap and plot maze.

        without node values: the image is drawn in grayscale with `vmin=-1, vmax=1`.

        with node values: the image is drawn with `self.node_color_map`; walls are NaN
        and rendered black via the colormap's "bad" color. Unless `self.hide_colorbar`
        is set, a colorbar is drawn into `self.cbar_ax`, which is reused on later calls.

        # Raises:
        - `ValueError` : if `colormap_center` lies outside the plotted value range
        """
        img = self._lattice_maze_to_img()

        # if no node_values have been passed (no colormap)
        if self.custom_node_value_flag is False:
            self.ax.imshow(img, cmap="gray", vmin=-1, vmax=1)

        else:
            assert self.node_values is not None, "Please pass node values."
            assert not np.isnan(self.node_values).any(), (
                "Please pass node values, they cannot be nan."
            )

            vals_min: float = np.nanmin(self.node_values)
            vals_max: float = np.nanmax(self.node_values)
            # if both are negative or both are positive, set max/min to 0
            if vals_max < 0.0:
                vals_max = 0.0
            elif vals_min > 0.0:
                vals_min = 0.0

            # adjust vals_max, in case you need consistent colorbar across multiple plots.
            # compare against None so that an explicit `colormap_max=0.0` is honored
            if self.colormap_max is not None:
                vals_max = self.colormap_max

            # create colormap
            cmap = mpl.colormaps[self.node_color_map]
            # TODO: this is a hack, we make the walls black (while still allowing negative values) by setting the nan color to black
            cmap.set_bad(color="black")

            if self.colormap_center is not None:
                if not (vals_min < self.colormap_center < vals_max):
                    # a center exactly on a boundary is nudged so TwoSlopeNorm accepts it
                    if vals_min == self.colormap_center:
                        vals_min -= 1e-10
                    elif vals_max == self.colormap_center:
                        vals_max += 1e-10
                    else:
                        err_msg: str = f"Please pass colormap_center value between {vals_min} and {vals_max}"
                        raise ValueError(
                            err_msg,
                        )

                norm = mpl.colors.TwoSlopeNorm(
                    vmin=vals_min,
                    vcenter=self.colormap_center,
                    vmax=vals_max,
                )
                _plotted = self.ax.imshow(img, cmap=cmap, norm=norm)
            else:
                _plotted = self.ax.imshow(img, cmap=cmap, vmin=vals_min, vmax=vals_max)

            # Add colorbar based on the condition of self.hide_colorbar
            if not self.hide_colorbar:
                ticks = np.linspace(vals_min, vals_max, 5)

                # always show 0 and the colormap center as ticks when in range
                if (vals_min < 0.0 < vals_max) and (0.0 not in ticks):
                    ticks = np.insert(ticks, np.searchsorted(ticks, 0.0), 0.0)

                if (
                    self.colormap_center is not None
                    and self.colormap_center not in ticks
                    and vals_min < self.colormap_center < vals_max
                ):
                    ticks = np.insert(
                        ticks,
                        np.searchsorted(ticks, self.colormap_center),
                        self.colormap_center,
                    )

                cbar = plt.colorbar(
                    _plotted,
                    ticks=ticks,
                    ax=self.ax,
                    cax=self.cbar_ax,
                )
                self.cbar_ax = cbar.ax

        # make the boundaries of the image thicker (walls look weird without this)
        for axis in ["top", "bottom", "left", "right"]:
            self.ax.spines[axis].set_linewidth(2)

    def _lattice_maze_to_img(
        self,
        connection_val_scale: float = 0.93,
    ) -> Float[np.ndarray, "row col"]:
        """Build an image to visualise the maze.

        Each "unit" consists of a node and the right and lower adjacent wall/connection. Its area is ul * ul.
        - Nodes have area: (ul-1) * (ul-1) and value 1 by default
            - take node_value if passed via .add_node_values()
        - Walls have area: 1 * (ul-1) and value -1
            - with node values, walls between nodes are NaN instead (the top/left border stays -1)
        - Connections have area: 1 * (ul-1); color and value 0.93 by default
            - with node values, each node is drawn 1px larger so connections take its node_value

        Axes definition:
        (0,0)     col
        ----|----------->
            |
        row |
            |
            v

        Returns a float matrix of shape (ul * n_rows + 1, ul * n_cols + 1).
        """
        # TODO: this is a hack, but if you add 1 always then non-node valued plots have their walls dissapear. if you dont add 1, you get ugly colors between nodes when they are colored
        node_bdry_hack: int
        connection_list_processed: Float[np.ndarray, "dim row col"]
        # Set node and connection values
        if self.node_values is None:
            scaled_node_values = np.ones(self.maze.grid_shape)
            connection_values = scaled_node_values * connection_val_scale
            node_bdry_hack = 0
            # TODO: hack
            # invert connection list, so the `not` below draws where there IS a connection
            connection_list_processed = np.logical_not(self.maze.connection_list)
        else:
            # TODO: hack
            scaled_node_values = self.node_values
            # connection_values = scaled_node_values
            # NaN is drawn where there is NO connection (walls); rendered black by `_plot_maze`
            connection_values = np.full_like(scaled_node_values, np.nan)
            node_bdry_hack = 1
            connection_list_processed = self.maze.connection_list

        # Create background image (all pixels set to -1, walls everywhere)
        img: Float[np.ndarray, "row col"] = -np.ones(
            (
                self.maze.grid_shape[0] * self.unit_length + 1,
                self.maze.grid_shape[1] * self.unit_length + 1,
            ),
            dtype=float,
        )

        # Draw nodes and connections by iterating through lattice
        for row in range(self.maze.grid_shape[0]):
            for col in range(self.maze.grid_shape[1]):
                # Draw node
                img[
                    row * self.unit_length + 1 : (row + 1) * self.unit_length
                    + node_bdry_hack,
                    col * self.unit_length + 1 : (col + 1) * self.unit_length
                    + node_bdry_hack,
                ] = scaled_node_values[row, col]

                # Down connection
                if not connection_list_processed[0, row, col]:
                    img[
                        (row + 1) * self.unit_length,
                        col * self.unit_length + 1 : (col + 1) * self.unit_length,
                    ] = connection_values[row, col]

                # Right connection
                if not connection_list_processed[1, row, col]:
                    img[
                        row * self.unit_length + 1 : (row + 1) * self.unit_length,
                        (col + 1) * self.unit_length,
                    ] = connection_values[row, col]

        return img

    def _plot_path(self, path_format: StyledPath) -> None:
        """draw one styled path, plus markers at its endpoints

        paths with `quiver_kwargs` are drawn as arrows between consecutive coords
        (colored along `cmap` if given, otherwise with `color`); other paths use
        `ax.plot` with `fmt`, `line_width`, `color` and `label`. The start is marked
        with "o" and the end with "x". Empty paths are skipped with a warning.
        """
        if len(path_format.path) == 0:
            warnings.warn(f"Empty path, skipping plotting\n{path_format = }")
            return
        p_transformed = np.array(
            [self._rowcol_to_coord(coord) for coord in path_format.path],
        )
        if path_format.quiver_kwargs is not None:
            try:
                x: np.ndarray = p_transformed[:, 0]
                y: np.ndarray = p_transformed[:, 1]
            except Exception as e:
                err_msg: str = f"Error in plotting quiver path:\n{path_format = }\n{p_transformed = }\n{e}"
                raise ValueError(
                    err_msg,
                ) from e

            # Generate colors from the colormap
            if path_format.cmap is not None:
                n = len(x) - 1  # Number of arrows
                cmap = plt.get_cmap(path_format.cmap)
                colors = [cmap(i / n) for i in range(n)]
            else:
                colors = path_format.color

            self.ax.quiver(
                x[:-1],
                y[:-1],
                x[1:] - x[:-1],
                y[1:] - y[:-1],
                scale_units="xy",
                angles="xy",
                scale=1,
                color=colors,
                **path_format.quiver_kwargs,
            )
        else:
            self.ax.plot(
                p_transformed[:, 0],
                p_transformed[:, 1],
                path_format.fmt,
                lw=path_format.line_width,
                color=path_format.color,
                label=path_format.label,
            )
        # mark endpoints
        self.ax.plot(
            [p_transformed[0][0]],
            [p_transformed[0][1]],
            "o",
            color=path_format.color,
            ms=10,
        )
        self.ax.plot(
            [p_transformed[-1][0]],
            [p_transformed[-1][1]],
            "x",
            color=path_format.color,
            ms=10,
        )

    def to_ascii(
        self,
        show_endpoints: bool = True,
        show_solution: bool = True,
    ) -> str:
        "wrapper for `self.solved_maze.as_ascii()`, shows the path if we have `self.true_path`"
        if self.true_path is not None:
            return self.solved_maze.as_ascii(
                show_endpoints=show_endpoints,
                show_solution=show_solution,
            )
        # no true path: there is no solution to show, so `show_solution` is ignored
        return self.maze.as_ascii(show_endpoints=show_endpoints)
Class for displaying mazes and paths
def __init__(self, maze: LatticeMaze, unit_length: int = 14) -> None:
    """Set up a plot for `maze`; paths, node values and markers are added later.

    `unit_length` sets the ratio between node size and wall thickness in the image.
    Wall thickness is fixed to 1px. A "unit" consists of a single node plus the
    connection/wall to its right and the one below it.
    Example: `unit_length = 14` yields a 13:1 ratio between node size and wall thickness.

    If `maze` is a `SolvedMaze`, its solution is added as the true path. Otherwise,
    if it is a `TargetedLatticeMaze`, it is solved via
    `SolvedMaze.from_targeted_lattice_maze` and that solution is added as the true path.
    """
    self.unit_length: int = unit_length
    self.maze: LatticeMaze = maze
    self.true_path: StyledPath | None = None
    self.predicted_paths: list[StyledPath] = []

    # heatmap state, populated by `add_node_values`
    self.node_values: Float[np.ndarray, "grid_n grid_n"] | None = None
    self.custom_node_value_flag: bool = False
    self.node_color_map: str = "Blues"
    self.colormap_center: float | None = None
    # created here with the same defaults as in `add_node_values`, so these
    # attributes exist even if node values are assigned some other way
    self.colormap_max: float | None = None
    self.hide_colorbar: bool = False

    self.target_token_coord: Coord | None = None
    self.preceding_tokens_coords: CoordArray | None = None
    self.cbar_ax = None
    # (coord, marker kwargs) pairs; `plot` passes each to `_place_marked_coords`
    self.marked_coords: list[tuple[Coord, dict]] = []

    # marker styles used by `add_node_values`
    self.marker_kwargs_current: dict = {"marker": "s", "color": "green", "ms": 12}
    self.marker_kwargs_next: dict = {"marker": "P", "color": "green", "ms": 12}

    if isinstance(maze, SolvedMaze):
        self.add_true_path(maze.solution)
    elif isinstance(maze, TargetedLatticeMaze):
        self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution)
unit_length: sets the ratio between node size and wall thickness in the image.
Wall thickness is fixed to 1px. A "unit" consists of a single node and the connection/wall to its right and below it. Example: ul = 14 yields a 13:1 ratio between node size and wall thickness.
@property
def solved_maze(self) -> SolvedMaze:
    """The plotted maze combined with its true path, as a `SolvedMaze`.

    Raises `ValueError` when no true path has been added yet.
    """
    true_path: StyledPath | None = self.true_path
    if true_path is not None:
        return SolvedMaze.from_lattice_maze(
            lattice_maze=self.maze,
            solution=true_path.path,
        )

    err_msg: str = "Cannot return SolvedMaze object without true path. Add true path with add_true_path method."
    raise ValueError(err_msg)
get the underlying `SolvedMaze` object
def add_true_path(
    self,
    path: CoordList | CoordArray | StyledPath,
    path_fmt: PathFormat | None = None,
    **kwargs,
) -> MazePlot:
    """Set the true path drawn on the maze, with optional formatting.

    `path`, `path_fmt` and any `**kwargs` are converted by `process_path_input`
    using the `"true"` default format. Returns `self` so calls can be chained.
    """
    styled: StyledPath = process_path_input(
        path=path,
        _default_key="true",
        path_fmt=path_fmt,
        **kwargs,
    )
    self.true_path = styled
    return self
add a true path to the maze with optional formatting
def add_predicted_path(
    self,
    path: CoordList | CoordArray | StyledPath,
    path_fmt: PathFormat | None = None,
    **kwargs,
) -> MazePlot:
    """Receive a predicted path and formatting preferences and append it to `self.predicted_paths`.

    Formatting is applied in layers by `process_path_input`: the `"predicted"`
    defaults (only for raw coordinate input), then `path_fmt`, then `**kwargs`.
    If the result still has no label or color, the defaults depend on the number
    of paths already saved: the label becomes `"predicted path {k}"` and the color
    cycles through `DEFAULT_PREDICTED_PATH_COLORS`.

    NOTE: when `path` is a `StyledPath` and no `path_fmt` is given, the same object
    is stored and its label/color may be filled in place.

    Returns `self` so calls can be chained.
    """
    styled_path: StyledPath = process_path_input(
        path=path,
        _default_key="predicted",
        path_fmt=path_fmt,
        **kwargs,
    )

    n_existing: int = len(self.predicted_paths)

    # set default label and color if not specified
    if styled_path.label is None:
        styled_path.label = f"predicted path {n_existing + 1}"

    if styled_path.color is None:
        color_num: int = n_existing % len(DEFAULT_PREDICTED_PATH_COLORS)
        styled_path.color = DEFAULT_PREDICTED_PATH_COLORS[color_num]

    self.predicted_paths.append(styled_path)
    return self
Receive a predicted path and formatting preferences from input and save it in the predicted_paths list.
Default formatting depends on the number of paths already saved in the predicted_paths list.
def add_multiple_paths(
    self,
    path_list: Sequence[CoordList | CoordArray | StyledPath],
) -> MazePlot:
    """Add several predicted paths to the `MazePlot` at once.

    Each element of `path_list` is passed to `add_predicted_path`, so it can be
    given in one of two ways:
    1. as raw coordinates (a `CoordList` or `CoordArray`), which receive the
       default predicted-path formatting; or
    2. as a `StyledPath`, which carries its own formatting.

    Missing labels/colors are filled in per path, numbered in the order given.
    Returns `self` so calls can be chained.
    """
    for path in path_list:
        self.add_predicted_path(path)
    return self
Function for adding multiple paths to MazePlot at once.
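Each element is added via add_predicted_path and can be given in one of two ways:
- as raw coordinates (a CoordList or CoordArray), which receive the default predicted-path formatting; or
- as a StyledPath, which carries its own formatting.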
def add_node_values(
    self,
    node_values: Float[np.ndarray, "grid_n grid_n"],
    color_map: str = "Blues",
    target_token_coord: Coord | None = None,
    preceeding_tokens_coords: CoordArray | None = None,
    colormap_center: float | None = None,
    colormap_max: float | None = None,
    hide_colorbar: bool = False,
) -> MazePlot:
    """add node values to the maze for visualization as a heatmap

    The values are stored on the plot and rendered when `plot` is called; the
    optional arguments control the colormap, the colorbar, and coordinate markers.

    # Parameters:
    - `node_values : Float[np.ndarray, "grid_n grid_n"]`
        one value per node; must have the same shape as `self.maze.grid_shape`
    - `color_map : str`
        name of the matplotlib colormap used for the heatmap
        (defaults to `"Blues"`)
    - `target_token_coord : Coord | None`
        if given, marked with `self.marker_kwargs_next` (a green "P" unless changed)
        (defaults to `None`)
    - `preceeding_tokens_coords : CoordArray | None`
        if given, each coord is marked with `self.marker_kwargs_current`
        (a green square unless changed)
        (defaults to `None`)
    - `colormap_center : float | None`
        if given, the color scale is centered on this value; it should lie within
        the range of plotted values
        (defaults to `None`)
    - `colormap_max : float | None`
        if given, overrides the upper end of the color scale, e.g. to keep colorbars
        consistent across several plots (note: a value of `0` is ignored)
        (defaults to `None`)
    - `hide_colorbar : bool`
        if `True`, no colorbar is drawn
        (defaults to `False`)

    # Returns:
    - `MazePlot`
        `self`, so calls can be chained

    # Raises:
    - `AssertionError` if `node_values.shape` differs from the maze's grid shape
    """
    assert node_values.shape == self.maze.grid_shape, (
        "Please pass node values of the same shape as LatticeMaze.grid_shape"
    )

    self.node_values = node_values
    # set flag for choosing cmap while plotting maze
    self.custom_node_value_flag = True
    self.node_color_map = color_map
    self.colormap_center = colormap_center
    self.colormap_max = colormap_max
    self.hide_colorbar = hide_colorbar

    # markers are queued in `marked_coords` and drawn by `plot`
    if target_token_coord is not None:
        self.marked_coords.append((target_token_coord, self.marker_kwargs_next))
    if preceeding_tokens_coords is not None:
        self.marked_coords.extend(
            (coord, self.marker_kwargs_current) for coord in preceeding_tokens_coords
        )
    return self
add node values to the maze for visualization as a heatmap
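The node values are drawn as a heatmap when the maze is plotted; the remaining arguments control the colormap, the colorbar, and coordinate markers.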
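Parameters:
- node_values : Float[np.ndarray, "grid_n grid_n"] — one value per node; must match the maze's grid_shape
- color_map : str — name of the matplotlib colormap used for the heatmap (defaults to "Blues")
- target_token_coord : Coord | None — coordinate marked with the "next" marker (defaults to None)
- preceeding_tokens_coords : CoordArray — coordinates marked with the "current" marker (defaults to None)
- colormap_center : float | None — value on which the color scale is centered (defaults to None)
- colormap_max : float | None — overrides the upper end of the color scale, e.g. for consistent colorbars across plots (defaults to None)
- hide_colorbar : bool — if True, no colorbar is drawn (defaults to False)
Returns:
- MazePlot — self, so calls can be chained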
def plot(
    self,
    dpi: int = 100,
    title: str = "",
    fig_ax: tuple | None = None,
    plain: bool = False,
) -> MazePlot:
    """Plot the maze and paths.

    # Parameters:
    - `dpi : int`
        resolution of the newly created figure; unused when `fig_ax` is given
        (defaults to `100`)
    - `title : str`
        axes title, only set when `plain` is `False`
        (defaults to `""`)
    - `fig_ax : tuple | None`
        an existing `(figure, axes)` pair to draw into; if `None`, a new figure
        with a single subplot is created
        (defaults to `None`)
    - `plain : bool`
        if `True`, skip tick labels, axis labels and title
        (defaults to `False`)

    # Returns:
    - `MazePlot`
        `self`, so calls can be chained; the figure and axes are stored in
        `self.fig` and `self.ax`
    """
    # set up figure
    if fig_ax is None:
        self.fig = plt.figure(dpi=dpi)
        self.ax = self.fig.add_subplot(1, 1, 1)
    else:
        self.fig, self.ax = fig_ax

    # plot maze
    self._plot_maze()

    # Plot labels
    if not plain:
        # ticks sit at node centers; x runs over columns and y over rows,
        # which have different counts for non-square mazes
        row_ticks = np.arange(self.maze.grid_shape[0])
        col_ticks = np.arange(self.maze.grid_shape[1])
        self.ax.set_xticks(self.unit_length * (col_ticks + 0.5), col_ticks)
        self.ax.set_yticks(self.unit_length * (row_ticks + 0.5), row_ticks)
        self.ax.set_xlabel("col")
        self.ax.set_ylabel("row")
        self.ax.set_title(title)

    # plot paths
    if self.true_path is not None:
        self._plot_path(self.true_path)
    for path in self.predicted_paths:
        self._plot_path(path)

    # plot markers
    for coord, kwargs in self.marked_coords:
        self._place_marked_coords([coord], **kwargs)

    return self
Plot the maze and paths.
def mark_coords(self, coords: CoordArray | list[Coord], **kwargs) -> MazePlot:
    """Queue a marker for each coordinate; markers are drawn when `plot` runs.

    `kwargs` are marker options that override the default blue "+":
    `dict(marker="+", color="blue")`. All coordinates from one call share the
    same options dict. Returns `self` so calls can be chained.
    """
    marker_style: dict = {"marker": "+", "color": "blue"} | kwargs
    self.marked_coords.extend((coord, marker_style) for coord in coords)
    return self
Mark coordinates on the maze with a marker.
default marker is a blue "+":
dict(marker="+", color="blue")
def to_ascii(
    self,
    show_endpoints: bool = True,
    show_solution: bool = True,
) -> str:
    """Render the maze as ASCII text.

    With a true path this wraps `self.solved_maze.as_ascii()`, so the solution
    can be shown; otherwise it falls back to `self.maze.as_ascii()` (which only
    takes `show_endpoints`).
    """
    if not self.true_path:
        return self.maze.as_ascii(show_endpoints=show_endpoints)
    return self.solved_maze.as_ascii(
        show_endpoints=show_endpoints,
        show_solution=show_solution,
    )
wrapper for `self.solved_maze.as_ascii()`; shows the path if we have `self.true_path`, otherwise falls back to `self.maze.as_ascii()`