maze_dataset.plotting
utilities for plotting mazes and printing tokens
- any `LatticeMaze` or `SolvedMaze` comes with an `as_pixels()` method that returns a 2D numpy array of pixel values, but this is somewhat limited
- `MazePlot` is a class that can be used to plot mazes and paths in a more customizable way
- `print_tokens` contains utilities for printing tokens, colored by their type, position, or some custom weights (e.g. attention weights)
"""utilities for plotting mazes and printing tokens

- any `LatticeMaze` or `SolvedMaze` comes with an `as_pixels()` method that returns
  a 2D numpy array of pixel values, but this is somewhat limited
- `MazePlot` is a class that can be used to plot mazes and paths in a more customizable way
- `print_tokens` contains utilities for printing tokens, colored by their type, position,
  or some custom weights (e.g. attention weights)
"""

# re-export the main plotting entry points at package level
from maze_dataset.plotting.plot_dataset import plot_dataset_mazes, print_dataset_mazes
from maze_dataset.plotting.plot_maze import DEFAULT_FORMATS, MazePlot, PathFormat
from maze_dataset.plotting.print_tokens import (
    color_maze_tokens_AOTP,
    color_tokens_cmap,
    color_tokens_rgb,
)

# public API exposed by `from maze_dataset.plotting import *`
__all__ = [
    # submodules
    "plot_dataset",
    "plot_maze",
    # NOTE(review): `plot_tokens` is not imported anywhere in this file -- confirm the
    # submodule exists, otherwise `from maze_dataset.plotting import *` will raise
    "plot_tokens",
    "print_tokens",
    # imports
    "plot_dataset_mazes",
    "print_dataset_mazes",
    "DEFAULT_FORMATS",
    "MazePlot",
    "PathFormat",
    "color_tokens_cmap",
    "color_maze_tokens_AOTP",
    "color_tokens_rgb",
]
def plot_dataset_mazes(
    ds: MazeDataset,
    count: int | None = None,
    figsize_mult: tuple[float, float] = (1.0, 2.0),
    title: bool | str = True,
) -> tuple | None:
    """plot `count` mazes from the dataset `ds` in a single figure using `SolvedMaze.as_pixels()`

    # Parameters:
    - `ds : MazeDataset`
        dataset to take the mazes from
    - `count : int | None`
        number of mazes to plot. `None` (or `0`) plots every maze in `ds`; values
        larger than `len(ds)` are clamped to `len(ds)`
        (defaults to `None`)
    - `figsize_mult : tuple[float, float]`
        the figure size is `(count * figsize_mult[0], figsize_mult[1])`
        (defaults to `(1.0, 2.0)`)
    - `title : bool | str`
        a string is used verbatim as the figure title; `True` builds a title from
        `ds.cfg`; a falsy value disables the title
        (defaults to `True`)

    # Returns:
    - `tuple | None`
        `(fig, axes)`, or `None` if there is nothing to plot
    """
    n_available: int = len(ds)
    # `count or ...` keeps the existing behaviour of treating `0` like `None`;
    # clamping avoids an IndexError when more mazes are requested than exist
    count = min(count or n_available, n_available)
    if count == 0:
        print("No mazes to plot for dataset")
        return None

    fig, axes = plt.subplots(
        1,
        count,
        figsize=(count * figsize_mult[0], figsize_mult[1]),
    )
    # with a single subplot `plt.subplots` returns a bare Axes, not an array
    if count == 1:
        axes = [axes]
    for i in range(count):
        axes[i].imshow(ds[i].as_pixels())
        # remove ticks
        axes[i].set_xticks([])
        axes[i].set_yticks([])

    # set title
    if title:
        if isinstance(title, str):
            fig.suptitle(title)
        else:
            ctor_kwargs: dict = {
                "grid_n": ds.cfg.grid_n,
                **ds.cfg.maze_ctor_kwargs,
            }
            ctor_args: str = ", ".join(f"{k}={v}" for k, v in ctor_kwargs.items())
            fig.suptitle(
                f"{ds.cfg.to_fname()}\n{ds.cfg.maze_ctor.__name__}({ctor_args})",
            )

    # tight layout
    fig.tight_layout()
    # remove whitespace between title and subplots
    fig.subplots_adjust(top=1.0)

    return fig, axes
plot `count` mazes from the dataset `ds` in a single figure using `SolvedMaze.as_pixels()`
def print_dataset_mazes(ds: MazeDataset, count: int | None = None) -> None:
    """print ascii representation of `count` mazes from the dataset `ds`

    # Parameters:
    - `ds : MazeDataset`
        dataset to take the mazes from
    - `count : int | None`
        number of mazes to print. `None` (or `0`) prints every maze in `ds`; values
        larger than `len(ds)` are clamped to `len(ds)`
        (defaults to `None`)
    """
    n_available: int = len(ds)
    # clamp so that asking for more mazes than exist does not raise IndexError
    count = min(count or n_available, n_available)
    if count == 0:
        print("No mazes to print for dataset")
        return
    for i in range(count):
        print(ds[i].as_ascii(), "\n\n-----\n")
print ascii representation of `count` mazes from the dataset `ds`
class MazePlot:
    """Class for displaying mazes and paths

    `add_true_path`, `add_predicted_path`, `add_multiple_paths`, `add_node_values`,
    `mark_coords` and `plot` all return `self`, so calls can be chained, e.g.
    `MazePlot(maze).add_predicted_path(path).plot()`.
    """

    def __init__(self, maze: LatticeMaze, unit_length: int = 14) -> None:
        """UNIT_LENGTH: Set ratio between node size and wall thickness in image.

        Wall thickness is fixed to 1px.
        A "unit" consists of a single node and the right and lower connection/wall.
        Example: ul = 14 yields a 13:1 ratio between node size and wall thickness.

        If `maze` is a `SolvedMaze`, its solution is added as the true path; if it is
        a `TargetedLatticeMaze`, it is solved and that solution is added instead.
        """
        self.unit_length: int = unit_length
        self.maze: LatticeMaze = maze
        self.true_path: StyledPath | None = None
        self.predicted_paths: list[StyledPath] = []
        self.node_values: Float[np.ndarray, "grid_n grid_n"] | None = None
        # switches `_plot_maze` from the grayscale image to the node-value heatmap
        self.custom_node_value_flag: bool = False
        self.node_color_map: str = "Blues"
        self.target_token_coord: Coord | None = None
        self.preceding_tokens_coords: CoordArray | None = None
        self.colormap_center: float | None = None
        # read by `_plot_maze`; previously only created by `add_node_values`
        self.colormap_max: float | None = None
        self.hide_colorbar: bool = False
        # axes of the last colorbar drawn, passed back as `cax` on the next draw
        self.cbar_ax = None
        # (coord, kwargs for `Axes.plot`) pairs drawn as markers by `plot`
        self.marked_coords: list[tuple[Coord, dict]] = []

        # marker styles used by `add_node_values` for preceding / target tokens
        self.marker_kwargs_current: dict = dict(
            marker="s",
            color="green",
            ms=12,
        )
        self.marker_kwargs_next: dict = dict(
            marker="P",
            color="green",
            ms=12,
        )

        if isinstance(maze, SolvedMaze):
            self.add_true_path(maze.solution)
        elif isinstance(maze, TargetedLatticeMaze):
            self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution)

    @property
    def solved_maze(self) -> SolvedMaze:
        """get the underlying `SolvedMaze` object

        Raises `ValueError` if no true path has been added.
        """
        if self.true_path is None:
            raise ValueError(
                "Cannot return SolvedMaze object without true path. Add true path with add_true_path method.",
            )
        return SolvedMaze.from_lattice_maze(
            lattice_maze=self.maze,
            solution=self.true_path.path,
        )

    def add_true_path(
        self,
        path: CoordList | CoordArray | StyledPath,
        path_fmt: PathFormat | None = None,
        **kwargs,
    ) -> MazePlot:
        """add a true path to the maze with optional formatting

        Replaces any previously stored true path. Returns `self`.
        """
        self.true_path = process_path_input(
            path=path,
            _default_key="true",
            path_fmt=path_fmt,
            **kwargs,
        )

        return self

    def add_predicted_path(
        self,
        path: CoordList | CoordArray | StyledPath,
        path_fmt: PathFormat | None = None,
        **kwargs,
    ) -> MazePlot:
        """Receive a predicted path and formatting preferences and append it to `self.predicted_paths`.

        Default formatting depends on the number of paths already saved in the
        predicted path list: an unlabelled path becomes `"predicted path {k}"`
        (k = 1-based position), and a path without a color takes the next entry of
        `DEFAULT_PREDICTED_PATH_COLORS`, cycling. Returns `self`.
        """
        styled_path: StyledPath = process_path_input(
            path=path,
            _default_key="predicted",
            path_fmt=path_fmt,
            **kwargs,
        )

        # set default label and color if not specified
        if styled_path.label is None:
            styled_path.label = f"predicted path {len(self.predicted_paths) + 1}"

        if styled_path.color is None:
            color_num: int = len(self.predicted_paths) % len(
                DEFAULT_PREDICTED_PATH_COLORS,
            )
            styled_path.color = DEFAULT_PREDICTED_PATH_COLORS[color_num]

        self.predicted_paths.append(styled_path)
        return self

    def add_multiple_paths(
        self,
        path_list: Sequence[CoordList | CoordArray | StyledPath],
    ) -> MazePlot:
        """Add several predicted paths to the MazePlot at once.

        Each element of `path_list` may be a coordinate list, a coordinate array, or
        a `StyledPath`; each is passed to `add_predicted_path` without extra
        formatting, so unset labels/colors get the next defaults. Returns `self`.
        """
        for path in path_list:
            self.add_predicted_path(path)
        return self

    def add_node_values(
        self,
        node_values: Float[np.ndarray, "grid_n grid_n"],
        color_map: str = "Blues",
        target_token_coord: Coord | None = None,
        preceeding_tokens_coords: CoordArray = None,
        colormap_center: float | None = None,
        colormap_max: float | None = None,
        hide_colorbar: bool = False,
    ) -> MazePlot:
        """add node values to the maze for visualization as a heatmap

        # Parameters:
        - `node_values : Float[np.ndarray, "grid_n grid_n"]`
            one value per node; must have shape `self.maze.grid_shape`
        - `color_map : str`
            name of the matplotlib colormap used for the heatmap
            (defaults to `"Blues"`)
        - `target_token_coord : Coord | None`
            if given, marked with `self.marker_kwargs_next`
            (defaults to `None`)
        - `preceeding_tokens_coords : CoordArray`
            if given, every coordinate is marked with `self.marker_kwargs_current`
            (defaults to `None`)
        - `colormap_center : float | None`
            if given, the colormap is centered on this value (`TwoSlopeNorm`)
            (defaults to `None`)
        - `colormap_max : float | None`
            if given, overrides the upper colormap limit, e.g. to keep the colorbar
            consistent across several plots
            (defaults to `None`)
        - `hide_colorbar : bool`
            if `True`, no colorbar is drawn
            (defaults to `False`)

        # Returns:
        - `MazePlot`
            `self`, for chaining

        # Raises:
        - `AssertionError` if `node_values.shape != self.maze.grid_shape`
        """
        assert node_values.shape == self.maze.grid_shape, (
            "Please pass node values of the same shape as LatticeMaze.grid_shape"
        )

        self.node_values = node_values
        # Set flag for choosing cmap while plotting maze
        self.custom_node_value_flag = True
        self.node_color_map = color_map
        self.colormap_center = colormap_center
        self.colormap_max = colormap_max
        self.hide_colorbar = hide_colorbar

        if target_token_coord is not None:
            self.marked_coords.append((target_token_coord, self.marker_kwargs_next))
        if preceeding_tokens_coords is not None:
            for coord in preceeding_tokens_coords:
                self.marked_coords.append((coord, self.marker_kwargs_current))
        return self

    def plot(
        self,
        dpi: int = 100,
        title: str = "",
        fig_ax: tuple | None = None,
        plain: bool = False,
    ) -> MazePlot:
        """Plot the maze and paths.

        # Parameters:
        - `dpi : int`
            resolution of the new figure; unused when `fig_ax` is given
            (defaults to `100`)
        - `title : str`
            axes title; not set when `plain` is `True`
            (defaults to `""`)
        - `fig_ax : tuple | None`
            existing `(figure, axes)` pair to draw into; if `None` a new figure with
            a single subplot is created
            (defaults to `None`)
        - `plain : bool`
            if `True`, skip tick labels, axis labels and title
            (defaults to `False`)

        # Returns:
        - `MazePlot`
            `self`, for chaining
        """
        # set up figure
        if fig_ax is None:
            self.fig = plt.figure(dpi=dpi)
            self.ax = self.fig.add_subplot(1, 1, 1)
        else:
            self.fig, self.ax = fig_ax

        # plot maze
        self._plot_maze()

        # Plot labels
        if not plain:
            # ticks sit at node centres; x runs over columns and y over rows, so the
            # two axes need separate ranges for non-square grids
            row_ticks = np.arange(self.maze.grid_shape[0])
            col_ticks = np.arange(self.maze.grid_shape[1])
            self.ax.set_xticks(self.unit_length * (col_ticks + 0.5), col_ticks)
            self.ax.set_yticks(self.unit_length * (row_ticks + 0.5), row_ticks)
            self.ax.set_xlabel("col")
            self.ax.set_ylabel("row")
            self.ax.set_title(title)

        # plot paths
        if self.true_path is not None:
            self._plot_path(self.true_path)
        for path in self.predicted_paths:
            self._plot_path(path)

        # plot markers
        for coord, kwargs in self.marked_coords:
            self._place_marked_coords([coord], **kwargs)

        return self

    def _rowcol_to_coord(self, point: Coord) -> np.ndarray:
        """Transform Point from MazeTransformer (row, column) notation to matplotlib default (x, y) notation where x is the horizontal axis."""
        point = np.array([point[1], point[0]])
        # shift to the centre of the node's unit
        return self.unit_length * (point + 0.5)

    def mark_coords(self, coords: CoordArray | list[Coord], **kwargs) -> MazePlot:
        """Mark coordinates on the maze with a marker.

        default marker is a blue "+":
        `dict(marker="+", color="blue")`
        """
        kwargs = {
            **dict(marker="+", color="blue"),
            **kwargs,
        }
        for coord in coords:
            self.marked_coords.append((coord, kwargs))

        return self

    def _place_marked_coords(
        self,
        coords: CoordArray | list[Coord],
        **kwargs,
    ) -> MazePlot:
        """Draw `coords` on `self.ax` via `Axes.plot` with the given marker kwargs."""
        coords_tp = np.array([self._rowcol_to_coord(coord) for coord in coords])
        self.ax.plot(coords_tp[:, 0], coords_tp[:, 1], **kwargs)

        return self

    def _plot_maze(self) -> None:  # noqa: C901, PLR0912
        """Define Colormap and plot maze.

        Without node values the image is drawn in grayscale. With node values the
        nodes are colored with `self.node_color_map`. Missing connections are stored
        as NaN by `_lattice_maze_to_img` and rendered with the colormap's "bad" color
        (black). A colorbar is drawn unless `self.hide_colorbar` is set.
        """
        img = self._lattice_maze_to_img()

        # if no node_values have been passed (no colormap)
        if self.custom_node_value_flag is False:
            self.ax.imshow(img, cmap="gray", vmin=-1, vmax=1)

        else:
            assert self.node_values is not None, "Please pass node values."
            assert not np.isnan(self.node_values).any(), (
                "Please pass node values, they cannot be nan."
            )

            vals_min: float = np.nanmin(self.node_values)
            vals_max: float = np.nanmax(self.node_values)
            # if both are negative or both are positive, set max/min to 0
            if vals_max < 0.0:
                vals_max = 0.0
            elif vals_min > 0.0:
                vals_min = 0.0

            # adjust vals_max, in case you need consistent colorbar across multiple plots.
            # compare against None so an explicit `colormap_max=0.0` is honoured
            if self.colormap_max is not None:
                vals_max = self.colormap_max

            # create colormap
            cmap = mpl.colormaps[self.node_color_map]
            # TODO: this is a hack, we make the walls black (while still allowing negative values) by setting the nan color to black
            cmap.set_bad(color="black")

            if self.colormap_center is not None:
                if not (vals_min < self.colormap_center < vals_max):
                    # `TwoSlopeNorm` needs vmin < vcenter < vmax strictly
                    if vals_min == self.colormap_center:
                        vals_min -= 1e-10
                    elif vals_max == self.colormap_center:
                        vals_max += 1e-10
                    else:
                        err_msg: str = f"Please pass colormap_center value between {vals_min} and {vals_max}"
                        raise ValueError(
                            err_msg,
                        )

                norm = mpl.colors.TwoSlopeNorm(
                    vmin=vals_min,
                    vcenter=self.colormap_center,
                    vmax=vals_max,
                )
                _plotted = self.ax.imshow(img, cmap=cmap, norm=norm)
            else:
                _plotted = self.ax.imshow(img, cmap=cmap, vmin=vals_min, vmax=vals_max)

            # Add colorbar based on the condition of self.hide_colorbar
            if not self.hide_colorbar:
                ticks = np.linspace(vals_min, vals_max, 5)

                if (vals_min < 0.0 < vals_max) and (0.0 not in ticks):
                    ticks = np.insert(ticks, np.searchsorted(ticks, 0.0), 0.0)

                if (
                    self.colormap_center is not None
                    and self.colormap_center not in ticks
                    and vals_min < self.colormap_center < vals_max
                ):
                    ticks = np.insert(
                        ticks,
                        np.searchsorted(ticks, self.colormap_center),
                        self.colormap_center,
                    )

                cbar = plt.colorbar(
                    _plotted,
                    ticks=ticks,
                    ax=self.ax,
                    cax=self.cbar_ax,
                )
                self.cbar_ax = cbar.ax

        # make the boundaries of the image thicker (walls look weird without this)
        for axis in ["top", "bottom", "left", "right"]:
            self.ax.spines[axis].set_linewidth(2)

    def _lattice_maze_to_img(
        self,
        connection_val_scale: float = 0.93,
    ) -> Float[np.ndarray, "row col"]:
        """Build an image to visualize the maze.

        Each "unit" consists of a node and the right and lower adjacent wall/connection. Its area is ul * ul.
        - Without node values:
            - nodes have area (ul-1) * (ul-1) and value 1
            - connections have area 1 * (ul-1) and value `connection_val_scale` (0.93 by default)
            - walls keep the background value -1
        - With node values (set via `.add_node_values()`):
            - nodes are extended by one pixel to the right and bottom and take their node value
            - missing connections (walls) are set to NaN

        Axes definition:
        (0,0)     col
        ----|----------->
            |
        row |
            |
            v

        Returns a matrix of side length (ul) * n + 1 where n is the number of nodes.
        """
        # TODO: this is a hack, but if you add 1 always then non-node valued plots have their walls disappear. if you don't add 1, you get ugly colors between nodes when they are colored
        node_bdry_hack: int
        connection_list_processed: Float[np.ndarray, "dim row col"]
        # Set node and connection values
        if self.node_values is None:
            scaled_node_values = np.ones(self.maze.grid_shape)
            connection_values = scaled_node_values * connection_val_scale
            node_bdry_hack = 0
            # TODO: hack
            # invert connection list, so that "not processed" below means "connected"
            connection_list_processed = np.logical_not(self.maze.connection_list)
        else:
            # TODO: hack
            scaled_node_values = self.node_values
            connection_values = np.full_like(scaled_node_values, np.nan)
            node_bdry_hack = 1
            # not inverted: "not processed" below means "wall", which gets NaN
            connection_list_processed = self.maze.connection_list

        # Create background image (all pixels set to -1, walls everywhere)
        img: Float[np.ndarray, "row col"] = -np.ones(
            (
                self.maze.grid_shape[0] * self.unit_length + 1,
                self.maze.grid_shape[1] * self.unit_length + 1,
            ),
            dtype=float,
        )

        # Draw nodes and connections by iterating through lattice
        for row in range(self.maze.grid_shape[0]):
            for col in range(self.maze.grid_shape[1]):
                # Draw node
                img[
                    row * self.unit_length + 1 : (row + 1) * self.unit_length
                    + node_bdry_hack,
                    col * self.unit_length + 1 : (col + 1) * self.unit_length
                    + node_bdry_hack,
                ] = scaled_node_values[row, col]

                # Down connection
                if not connection_list_processed[0, row, col]:
                    img[
                        (row + 1) * self.unit_length,
                        col * self.unit_length + 1 : (col + 1) * self.unit_length,
                    ] = connection_values[row, col]

                # Right connection
                if not connection_list_processed[1, row, col]:
                    img[
                        row * self.unit_length + 1 : (row + 1) * self.unit_length,
                        (col + 1) * self.unit_length,
                    ] = connection_values[row, col]

        return img

    def _plot_path(self, path_format: StyledPath) -> None:
        """Draw one path on `self.ax`.

        If `quiver_kwargs` is set the path is drawn as arrows (colored along the path
        when `cmap` is set). Otherwise it is drawn as a line, with an "o" at the
        start and an "x" at the end. Empty paths are skipped with a warning.
        """
        if len(path_format.path) == 0:
            warnings.warn(f"Empty path, skipping plotting\n{path_format = }")
            return
        p_transformed = np.array(
            [self._rowcol_to_coord(coord) for coord in path_format.path],
        )
        if path_format.quiver_kwargs is not None:
            try:
                x: np.ndarray = p_transformed[:, 0]
                y: np.ndarray = p_transformed[:, 1]
            except Exception as e:
                err_msg: str = f"Error in plotting quiver path:\n{path_format = }\n{p_transformed = }\n{e}"
                raise ValueError(
                    err_msg,
                ) from e

            # Generate colors from the colormap
            if path_format.cmap is not None:
                n = len(x) - 1  # Number of arrows
                cmap = plt.get_cmap(path_format.cmap)
                colors = [cmap(i / n) for i in range(n)]
            else:
                colors = path_format.color

            self.ax.quiver(
                x[:-1],
                y[:-1],
                x[1:] - x[:-1],
                y[1:] - y[:-1],
                scale_units="xy",
                angles="xy",
                scale=1,
                color=colors,
                **path_format.quiver_kwargs,
            )
        else:
            self.ax.plot(
                p_transformed[:, 0],
                p_transformed[:, 1],
                path_format.fmt,
                lw=path_format.line_width,
                color=path_format.color,
                label=path_format.label,
            )
        # mark endpoints
        self.ax.plot(
            [p_transformed[0][0]],
            [p_transformed[0][1]],
            "o",
            color=path_format.color,
            ms=10,
        )
        self.ax.plot(
            [p_transformed[-1][0]],
            [p_transformed[-1][1]],
            "x",
            color=path_format.color,
            ms=10,
        )

    def to_ascii(
        self,
        show_endpoints: bool = True,
        show_solution: bool = True,
    ) -> str:
        "wrapper for `self.solved_maze.as_ascii()`, shows the path if we have `self.true_path`"
        if self.true_path:
            return self.solved_maze.as_ascii(
                show_endpoints=show_endpoints,
                show_solution=show_solution,
            )
        else:
            return self.maze.as_ascii(show_endpoints=show_endpoints)
Class for displaying mazes and paths
def __init__(self, maze: LatticeMaze, unit_length: int = 14) -> None:
    """UNIT_LENGTH: Set ratio between node size and wall thickness in image.

    Wall thickness is fixed to 1px.
    A "unit" consists of a single node and the right and lower connection/wall.
    Example: ul = 14 yields a 13:1 ratio between node size and wall thickness.

    If `maze` is a `SolvedMaze`, its solution is added as the true path; if it is
    a `TargetedLatticeMaze`, it is solved and that solution is added instead.
    """
    self.unit_length: int = unit_length
    self.maze: LatticeMaze = maze
    self.true_path: StyledPath | None = None
    self.predicted_paths: list[StyledPath] = []
    self.node_values: Float[np.ndarray, "grid_n grid_n"] | None = None
    # switches `_plot_maze` from the grayscale image to the node-value heatmap
    self.custom_node_value_flag: bool = False
    self.node_color_map: str = "Blues"
    self.target_token_coord: Coord | None = None
    self.preceding_tokens_coords: CoordArray | None = None
    self.colormap_center: float | None = None
    # read by `_plot_maze`; previously only created by `add_node_values`
    self.colormap_max: float | None = None
    self.hide_colorbar: bool = False
    # axes of the last colorbar drawn, passed back as `cax` on the next draw
    self.cbar_ax = None
    # (coord, kwargs for `Axes.plot`) pairs drawn as markers by `plot`
    self.marked_coords: list[tuple[Coord, dict]] = []

    # marker styles used by `add_node_values` for preceding / target tokens
    self.marker_kwargs_current: dict = dict(
        marker="s",
        color="green",
        ms=12,
    )
    self.marker_kwargs_next: dict = dict(
        marker="P",
        color="green",
        ms=12,
    )

    if isinstance(maze, SolvedMaze):
        self.add_true_path(maze.solution)
    elif isinstance(maze, TargetedLatticeMaze):
        self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution)
UNIT_LENGTH: Set ratio between node size and wall thickness in image.
Wall thickness is fixed to 1px. A "unit" consists of a single node and the right and lower connection/wall. Example: ul = 14 yields a 13:1 ratio between node size and wall thickness.
@property
def solved_maze(self) -> SolvedMaze:
    """The plotted maze combined with its true path, as a `SolvedMaze`.

    Raises `ValueError` when no true path has been added yet.
    """
    stored_path = self.true_path
    if stored_path is not None:
        return SolvedMaze.from_lattice_maze(
            lattice_maze=self.maze,
            solution=stored_path.path,
        )
    raise ValueError(
        "Cannot return SolvedMaze object without true path. Add true path with add_true_path method.",
    )
get the underlying `SolvedMaze` object
def add_true_path(
    self,
    path: CoordList | CoordArray | StyledPath,
    path_fmt: PathFormat | None = None,
    **kwargs,
) -> MazePlot:
    """Set the true (reference) path, optionally with custom formatting.

    `path`, `path_fmt` and any extra keyword arguments are forwarded to
    `process_path_input` with `_default_key="true"`. The result replaces any
    previously stored true path. Returns `self` so calls can be chained.
    """
    styled: StyledPath = process_path_input(
        path=path,
        _default_key="true",
        path_fmt=path_fmt,
        **kwargs,
    )
    self.true_path = styled
    return self
add a true path to the maze with optional formatting
def add_predicted_path(
    self,
    path: CoordList | CoordArray | StyledPath,
    path_fmt: PathFormat | None = None,
    **kwargs,
) -> MazePlot:
    """Receive a predicted path and formatting preferences and append it to `self.predicted_paths`.

    `path`, `path_fmt` and any extra keyword arguments are forwarded to
    `process_path_input` with `_default_key="predicted"`.

    Default formatting depends on the number of paths already saved in the
    predicted path list:
    - an unlabelled path becomes `"predicted path {k}"` (k = 1-based position)
    - a path without a color takes the next entry of `DEFAULT_PREDICTED_PATH_COLORS`,
      cycling when there are more paths than colors

    # Returns:
    - `MazePlot`
        `self`, for chaining
    """
    styled_path: StyledPath = process_path_input(
        path=path,
        _default_key="predicted",
        path_fmt=path_fmt,
        **kwargs,
    )

    # set default label and color if not specified
    if styled_path.label is None:
        styled_path.label = f"predicted path {len(self.predicted_paths) + 1}"

    if styled_path.color is None:
        color_num: int = len(self.predicted_paths) % len(
            DEFAULT_PREDICTED_PATH_COLORS,
        )
        styled_path.color = DEFAULT_PREDICTED_PATH_COLORS[color_num]

    self.predicted_paths.append(styled_path)
    return self
Receive a predicted path and formatting preferences from input and save it in the predicted_path list.
Default formatting depends on the number of paths already saved in the predicted path list.
def add_multiple_paths(
    self,
    path_list: Sequence[CoordList | CoordArray | StyledPath],
) -> MazePlot:
    """Add several predicted paths to the MazePlot at once.

    Each element of `path_list` may be a coordinate list, a coordinate array, or a
    `StyledPath`. Each element is passed to `add_predicted_path` without extra
    formatting, so any path with no label/color gets the next default.

    # Returns:
    - `MazePlot`
        `self`, for chaining
    """
    for path in path_list:
        self.add_predicted_path(path)
    return self
Function for adding multiple paths to MazePlot at once.
Each element of `path_list` (a coordinate list, a coordinate array, or a `StyledPath`) is added as a predicted path via `add_predicted_path`, using default formatting.
def add_node_values(
    self,
    node_values: Float[np.ndarray, "grid_n grid_n"],
    color_map: str = "Blues",
    target_token_coord: Coord | None = None,
    preceeding_tokens_coords: CoordArray = None,
    colormap_center: float | None = None,
    colormap_max: float | None = None,
    hide_colorbar: bool = False,
) -> MazePlot:
    """add node values to the maze for visualization as a heatmap

    # Parameters:
    - `node_values : Float[np.ndarray, "grid_n grid_n"]`
        one value per node; must have shape `self.maze.grid_shape`
    - `color_map : str`
        name of the matplotlib colormap used for the heatmap
        (defaults to `"Blues"`)
    - `target_token_coord : Coord | None`
        if given, marked with `self.marker_kwargs_next`
        (defaults to `None`)
    - `preceeding_tokens_coords : CoordArray`
        if given, every coordinate is marked with `self.marker_kwargs_current`
        (defaults to `None`)
    - `colormap_center : float | None`
        if given, the colormap is centered on this value when plotting
        (defaults to `None`)
    - `colormap_max : float | None`
        if given, overrides the upper colormap limit, e.g. to keep the colorbar
        consistent across several plots
        (defaults to `None`)
    - `hide_colorbar : bool`
        if `True`, no colorbar is drawn
        (defaults to `False`)

    # Returns:
    - `MazePlot`
        `self`, for chaining

    # Raises:
    - `AssertionError` if `node_values.shape != self.maze.grid_shape`
    """
    assert node_values.shape == self.maze.grid_shape, (
        "Please pass node values of the same shape as LatticeMaze.grid_shape"
    )

    self.node_values = node_values
    # Set flag for choosing cmap while plotting maze
    self.custom_node_value_flag = True
    self.node_color_map = color_map
    self.colormap_center = colormap_center
    self.colormap_max = colormap_max
    self.hide_colorbar = hide_colorbar

    if target_token_coord is not None:
        self.marked_coords.append((target_token_coord, self.marker_kwargs_next))
    if preceeding_tokens_coords is not None:
        for coord in preceeding_tokens_coords:
            self.marked_coords.append((coord, self.marker_kwargs_current))
    return self
add node values to the maze for visualization as a heatmap
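Parameters:
- `node_values : Float[np.ndarray, "grid_n grid_n"]` — one value per node; must have the same shape as `LatticeMaze.grid_shape`
- `color_map : str` — name of the matplotlib colormap used for the heatmap (defaults to `"Blues"`)
- `target_token_coord : Coord | None` — if given, this coordinate is marked with the "next" marker (defaults to `None`)
- `preceeding_tokens_coords : CoordArray` — if given, each coordinate is marked with the "current" marker (defaults to `None`)
- `colormap_center : float | None` — if given, the colormap is centered on this value (defaults to `None`)
- `colormap_max : float | None` — if given, overrides the upper limit of the colormap, e.g. to keep the colorbar consistent across plots (defaults to `None`)
- `hide_colorbar : bool` — if `True`, no colorbar is drawn (defaults to `False`)

Returns:
- `MazePlot` — `self`, for chaining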
def plot(
    self,
    dpi: int = 100,
    title: str = "",
    fig_ax: tuple | None = None,
    plain: bool = False,
) -> MazePlot:
    """Plot the maze and paths.

    # Parameters:
    - `dpi : int`
        resolution of the new figure; unused when `fig_ax` is given
        (defaults to `100`)
    - `title : str`
        axes title; not set when `plain` is `True`
        (defaults to `""`)
    - `fig_ax : tuple | None`
        existing `(figure, axes)` pair to draw into; if `None` a new figure with a
        single subplot is created
        (defaults to `None`)
    - `plain : bool`
        if `True`, skip tick labels, axis labels and title
        (defaults to `False`)

    # Returns:
    - `MazePlot`
        `self`, for chaining
    """
    # set up figure
    if fig_ax is None:
        self.fig = plt.figure(dpi=dpi)
        self.ax = self.fig.add_subplot(1, 1, 1)
    else:
        self.fig, self.ax = fig_ax

    # plot maze
    self._plot_maze()

    # Plot labels
    if not plain:
        # ticks sit at node centres; x runs over columns and y over rows, so the two
        # axes need separate ranges for non-square grids
        row_ticks = np.arange(self.maze.grid_shape[0])
        col_ticks = np.arange(self.maze.grid_shape[1])
        self.ax.set_xticks(self.unit_length * (col_ticks + 0.5), col_ticks)
        self.ax.set_yticks(self.unit_length * (row_ticks + 0.5), row_ticks)
        self.ax.set_xlabel("col")
        self.ax.set_ylabel("row")
        self.ax.set_title(title)

    # plot paths
    if self.true_path is not None:
        self._plot_path(self.true_path)
    for path in self.predicted_paths:
        self._plot_path(path)

    # plot markers
    for coord, kwargs in self.marked_coords:
        self._place_marked_coords([coord], **kwargs)

    return self
Plot the maze and paths.
def mark_coords(self, coords: CoordArray | list[Coord], **kwargs) -> MazePlot:
    """Queue coordinates to be drawn as markers when the maze is plotted.

    Unless overridden through `kwargs`, the marker is a blue "+", i.e.
    `dict(marker="+", color="blue")`. All coordinates from one call share the
    same style dict. Returns `self` so calls can be chained.
    """
    style: dict = {"marker": "+", "color": "blue", **kwargs}
    self.marked_coords.extend((point, style) for point in coords)
    return self
Mark coordinates on the maze with a marker.
default marker is a blue "+":
dict(marker="+", color="blue")
def to_ascii(
    self,
    show_endpoints: bool = True,
    show_solution: bool = True,
) -> str:
    """ASCII rendering of the maze.

    Uses `self.solved_maze.as_ascii()` (which can draw the solution) when a true
    path is set; otherwise falls back to `self.maze.as_ascii()`, where
    `show_solution` does not apply.
    """
    if not self.true_path:
        return self.maze.as_ascii(show_endpoints=show_endpoints)
    return self.solved_maze.as_ascii(
        show_endpoints=show_endpoints,
        show_solution=show_solution,
    )
wrapper for `self.solved_maze.as_ascii()`; shows the path if we have `self.true_path`
26@dataclass(kw_only=True) 27class PathFormat: 28 """formatting options for path plot""" 29 30 label: str | None = None 31 fmt: str = "o" 32 color: str | None = None 33 cmap: str | None = None 34 line_width: float | None = None 35 quiver_kwargs: dict | None = None 36 37 def combine(self, other: PathFormat) -> PathFormat: 38 """combine with other PathFormat object, overwriting attributes with non-None values. 39 40 returns a modified copy of self. 41 """ 42 output: PathFormat = deepcopy(self) 43 for key, value in other.__dict__.items(): 44 if key == "path": 45 err_msg: str = f"Cannot overwrite path attribute! {self = }, {other = }" 46 raise ValueError( 47 err_msg, 48 ) 49 if value is not None: 50 setattr(output, key, value) 51 52 return output
formatting options for path plot
def combine(self, other: PathFormat) -> PathFormat:
    """Return a copy of `self` updated with every non-None attribute of `other`.

    `self` is left untouched. Raises `ValueError` if `other` carries a `path`
    attribute, because the path itself must never be overwritten.
    """
    merged: PathFormat = deepcopy(self)
    for attr_name, attr_value in vars(other).items():
        if attr_name == "path":
            err_msg: str = f"Cannot overwrite path attribute! {self = }, {other = }"
            raise ValueError(err_msg)
        if attr_value is None:
            continue
        setattr(merged, attr_name, attr_value)
    return merged
combine with other PathFormat object, overwriting attributes with non-None values.
returns a modified copy of self.
def color_tokens_cmap(
    tokens: list[str],
    weights: Sequence[float],
    cmap: str | matplotlib.colors.Colormap = "Blues",
    fmt: FormatType = "html",
    template: str | None = None,
    labels: bool = False,
) -> str:
    """color tokens given a list of weights and a colormap

    Weights are min-max normalized to [0, 1] with `matplotlib.colors.Normalize`
    and mapped through `cmap` to RGB colors, which are passed to `color_tokens_rgb`.

    # Parameters:
    - `tokens : list[str]`
        tokens to color
    - `weights : Sequence[float]`
        one weight per token
    - `cmap : str | matplotlib.colors.Colormap`
        colormap or colormap name
        (defaults to `"Blues"`)
    - `fmt : FormatType`
        output format, forwarded to `color_tokens_rgb`
        (defaults to `"html"`)
    - `template : str | None`
        forwarded to `color_tokens_rgb`
        (defaults to `None`)
    - `labels : bool`
        if `True`, append a line with each weight (one decimal place) aligned
        under its token; only supported for `fmt="terminal"`
        (defaults to `False`)

    # Raises:
    - `AssertionError` if `tokens` and `weights` differ in length
    - `NotImplementedError` if `labels` is requested for a format other than `"terminal"`
    """
    n_tok: int = len(tokens)
    assert n_tok == len(weights), f"'{len(tokens) = }' != '{len(weights) = }'"
    # validate options up front instead of after all the coloring work
    if labels and fmt != "terminal":
        raise NotImplementedError("labels only supported for terminal")

    weights_np: Float[np.ndarray, " n_tok"] = np.array(weights)
    # normalize weights to [0, 1]
    weights_norm = matplotlib.colors.Normalize()(weights_np)

    if isinstance(cmap, str):
        cmap = matplotlib.colormaps.get_cmap(cmap)

    # drop the alpha channel and scale to 0-255
    colors: RGBArray = cmap(weights_norm)[:, :3] * 255

    output: str = color_tokens_rgb(
        tokens=tokens,
        colors=colors,
        fmt=fmt,
        template=template,
    )

    if labels:
        # align labels with the tokens
        output += "\n"
        for tok, weight in zip(tokens, weights_np, strict=False):
            # 1 decimal place, left-aligned and padded with trailing spaces to the token length
            weight_str: str = f"{weight:.1f}"
            # omit if longer than token
            if len(weight_str) > len(tok):
                weight_str = " " * len(tok)
            else:
                weight_str = weight_str.ljust(len(tok))
            output += f"{weight_str} "

    return output
color tokens given a list of weights and a colormap
def color_maze_tokens_AOTP(
    tokens: list[str],
    fmt: FormatType = "html",
    template: str | None = None,
    **kwargs,
) -> str:
    """color tokens assuming AOTP format

    i.e.: adjacency list, origin, target, path

    For every `(start_tok, end_tok)` key of `_MAZE_TOKENS_DEFAULT_COLORS`, the span
    of `tokens` from `start_tok` to `end_tok` (both inclusive) is joined with spaces
    into one section. Each section is then colored with that key's RGB value by
    `color_tokens_rgb`, which also receives any extra `kwargs`.
    """
    sections: list[str] = []
    for start_tok, end_tok in _MAZE_TOKENS_DEFAULT_COLORS:
        section_tokens = tokens_between(
            tokens,
            start_tok,
            end_tok,
            include_start=True,
            include_end=True,
        )
        sections.append(" ".join(section_tokens))

    section_colors: RGBArray = np.array(
        list(_MAZE_TOKENS_DEFAULT_COLORS.values()),
        dtype=np.uint8,
    )

    return color_tokens_rgb(
        tokens=sections,
        colors=section_colors,
        fmt=fmt,
        template=template,
        **kwargs,
    )
color tokens assuming AOTP format
i.e.: adjacency list, origin, target, path
64def color_tokens_rgb( 65 tokens: list, 66 colors: Sequence[Sequence[int]] | Float[np.ndarray, "n 3"], 67 fmt: FormatType = "html", 68 template: str | None = None, 69 clr_join: str | None = None, 70 max_length: int | None = None, 71) -> str: 72 """color tokens from a list with an RGB color array 73 74 tokens will not be escaped if `fmt` is None 75 76 # Parameters: 77 - `max_length: int | None`: Max number of characters before triggering a line wrap, i.e., making a new colorbox. If `None`, no limit on max length. 78 """ 79 # process format 80 if fmt is None: 81 assert template is not None 82 assert clr_join is not None 83 else: 84 assert template is None 85 assert clr_join is None 86 template = TEMPLATES[fmt] 87 clr_join = _COLOR_JOIN[fmt] 88 89 if max_length is not None: 90 # TODO: why are we using a map here again? 91 # TYPING: this is missing a lot of type hints 92 wrapped: list = list( # noqa: C417 93 map( 94 lambda x: textwrap.wrap( 95 x, 96 width=max_length, 97 break_long_words=False, 98 break_on_hyphens=False, 99 ), 100 tokens, 101 ), 102 ) 103 colors = list( 104 flatten( 105 [[colors[i]] * len(wrapped[i]) for i in range(len(wrapped))], 106 levels_to_flatten=1, 107 ), 108 ) 109 wrapped = list(flatten(wrapped, levels_to_flatten=1)) 110 tokens = wrapped 111 112 # put everything together 113 output = [ 114 template.format( 115 clr=clr_join.join(map(str, map(int, clr))), 116 tok=_escape_tok(tok, fmt), 117 ) 118 for tok, clr in zip(tokens, colors, strict=False) 119 ] 120 121 return " ".join(output)
color tokens from a list with an RGB color array
tokens will not be escaped if `fmt` is None
Parameters:
- `max_length: int | None`: Max number of characters before triggering a line wrap, i.e., making a new colorbox. If `None`, no limit on max length.