docs for maze-dataset v1.3.0
View Source on GitHub

maze_dataset.plotting.plot_maze

provides MazePlot, which has many tools for plotting mazes with multiple paths, colored nodes, and more


  1"""provides `MazePlot`, which has many tools for plotting mazes with multiple paths, colored nodes, and more"""
  2
  3from __future__ import annotations  # for type hinting self as return value
  4
  5import warnings
  6from copy import deepcopy
  7from dataclasses import dataclass
  8from typing import Sequence
  9
 10import matplotlib as mpl
 11import matplotlib.pyplot as plt
 12import numpy as np
 13from jaxtyping import Bool, Float
 14
 15from maze_dataset.constants import Coord, CoordArray, CoordList
 16from maze_dataset.maze import (
 17	LatticeMaze,
 18	SolvedMaze,
 19	TargetedLatticeMaze,
 20)
 21
# large negative float constant; not referenced by any code visible in this module
LARGE_NEGATIVE_NUMBER: float = -1e10
 23
 24
 25@dataclass(kw_only=True)
 26class PathFormat:
 27	"""formatting options for path plot"""
 28
 29	label: str | None = None
 30	fmt: str = "o"
 31	color: str | None = None
 32	cmap: str | None = None
 33	line_width: float | None = None
 34	quiver_kwargs: dict | None = None
 35
 36	def combine(self, other: PathFormat) -> PathFormat:
 37		"""combine with other PathFormat object, overwriting attributes with non-None values.
 38
 39		returns a modified copy of self.
 40		"""
 41		output: PathFormat = deepcopy(self)
 42		for key, value in other.__dict__.items():
 43			if key == "path":
 44				err_msg: str = f"Cannot overwrite path attribute! {self = }, {other = }"
 45				raise ValueError(
 46					err_msg,
 47				)
 48			if value is not None:
 49				setattr(output, key, value)
 50
 51		return output
 52
 53
 54# styled path
@dataclass
class StyledPath(PathFormat):
	"""A `PathFormat` bundled with the path it styles.

	`path` is the only positional field; the formatting fields inherited from
	`PathFormat` stay keyword-only.
	"""

	# coordinates of the path; `MazePlot._rowcol_to_coord` reads each entry as (row, col)
	path: CoordArray
 60
 61
# default formatting, looked up by the `_default_key` that `process_path_input` receives
DEFAULT_FORMATS: dict[str, PathFormat] = {
	# true path: `quiver_kwargs` is None, so `MazePlot._plot_path` draws it as a dashed red line
	"true": PathFormat(
		label="true path",
		fmt="--",
		color="red",
		line_width=2.5,
		quiver_kwargs=None,
	),
	# predicted path: `quiver_kwargs` is set, so `MazePlot._plot_path` draws arrows;
	# label and color stay None and are filled in by `MazePlot.add_predicted_path`
	"predicted": PathFormat(
		label=None,
		fmt=":",
		color=None,
		line_width=2,
		quiver_kwargs={"width": 0.015},
	),
}
 78
 79
 80def process_path_input(
 81	path: CoordList | CoordArray | StyledPath,
 82	_default_key: str,
 83	path_fmt: PathFormat | None = None,
 84	**kwargs,
 85) -> StyledPath:
 86	"convert a path, which might be a list or array of coords, into a `StyledPath`"
 87	styled_path: StyledPath
 88	if isinstance(path, StyledPath):
 89		styled_path = path
 90	elif isinstance(path, np.ndarray):
 91		styled_path = StyledPath(path=path)
 92		# add default formatting
 93		styled_path = styled_path.combine(DEFAULT_FORMATS[_default_key])
 94	elif isinstance(path, list):
 95		styled_path = StyledPath(path=np.array(path))
 96		# add default formatting
 97		styled_path = styled_path.combine(DEFAULT_FORMATS[_default_key])
 98	else:
 99		err_msg: str = (
100			f"Expected CoordList, CoordArray or StyledPath, got {type(path)}: {path}"
101		)
102		raise TypeError(
103			err_msg,
104		)
105
106	# add formatting from path_fmt
107	if path_fmt is not None:
108		styled_path = styled_path.combine(path_fmt)
109
110	# add formatting from kwargs
111	for key, value in kwargs.items():
112		setattr(styled_path, key, value)
113
114	return styled_path
115
116
# colors for predicted paths that have no color of their own;
# `MazePlot.add_predicted_path` cycles through them by the number of paths already stored
DEFAULT_PREDICTED_PATH_COLORS: list[str] = [
	"tab:orange",
	"tab:olive",
	"sienna",
	"mediumseagreen",
	"tab:purple",
	"slategrey",
]
125
126
127class MazePlot:
128	"""Class for displaying mazes and paths"""
129
130	def __init__(self, maze: LatticeMaze, unit_length: int = 14) -> None:
131		"""UNIT_LENGTH: Set ratio between node size and wall thickness in image.
132
133		Wall thickness is fixed to 1px
134		A "unit" consists of a single node and the right and lower connection/wall.
135		Example: ul = 14 yields 13:1 ratio between node size and wall thickness
136		"""
137		self.unit_length: int = unit_length
138		self.maze: LatticeMaze = maze
139		self.true_path: StyledPath | None = None
140		self.predicted_paths: list[StyledPath] = []
141		self.node_values: Float[np.ndarray, "grid_n grid_n"] = None
142		self.custom_node_value_flag: bool = False
143		self.node_color_map: str = "Blues"
144		self.target_token_coord: Coord = None
145		self.preceding_tokens_coords: CoordArray = None
146		self.colormap_center: float | None = None
147		self.cbar_ax = None
148		self.marked_coords: list[tuple[Coord, dict]] = list()
149
150		self.marker_kwargs_current: dict = dict(
151			marker="s",
152			color="green",
153			ms=12,
154		)
155		self.marker_kwargs_next: dict = dict(
156			marker="P",
157			color="green",
158			ms=12,
159		)
160
161		if isinstance(maze, SolvedMaze):
162			self.add_true_path(maze.solution)
163		else:
164			if isinstance(maze, TargetedLatticeMaze):
165				self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution)
166
167	@property
168	def solved_maze(self) -> SolvedMaze:
169		"get the underlying `SolvedMaze` object"
170		if self.true_path is None:
171			raise ValueError(
172				"Cannot return SolvedMaze object without true path. Add true path with add_true_path method.",
173			)
174		return SolvedMaze.from_lattice_maze(
175			lattice_maze=self.maze,
176			solution=self.true_path.path,
177		)
178
179	def add_true_path(
180		self,
181		path: CoordList | CoordArray | StyledPath,
182		path_fmt: PathFormat | None = None,
183		**kwargs,
184	) -> MazePlot:
185		"add a true path to the maze with optional formatting"
186		self.true_path = process_path_input(
187			path=path,
188			_default_key="true",
189			path_fmt=path_fmt,
190			**kwargs,
191		)
192
193		return self
194
195	def add_predicted_path(
196		self,
197		path: CoordList | CoordArray | StyledPath,
198		path_fmt: PathFormat | None = None,
199		**kwargs,
200	) -> MazePlot:
201		"""Recieve predicted path and formatting preferences from input and save in predicted_path list.
202
203		Default formatting depends on nuber of paths already saved in predicted path list.
204		"""
205		styled_path: StyledPath = process_path_input(
206			path=path,
207			_default_key="predicted",
208			path_fmt=path_fmt,
209			**kwargs,
210		)
211
212		# set default label and color if not specified
213		if styled_path.label is None:
214			styled_path.label = f"predicted path {len(self.predicted_paths) + 1}"
215
216		if styled_path.color is None:
217			color_num: int = len(self.predicted_paths) % len(
218				DEFAULT_PREDICTED_PATH_COLORS,
219			)
220			styled_path.color = DEFAULT_PREDICTED_PATH_COLORS[color_num]
221
222		self.predicted_paths.append(styled_path)
223		return self
224
225	def add_multiple_paths(
226		self,
227		path_list: Sequence[CoordList | CoordArray | StyledPath],
228	) -> MazePlot:
229		"""Function for adding multiple paths to MazePlot at once.
230
231		> DOCS: what are the two ways?
232		This can be done in two ways:
233		1. Passing a list of
234		"""
235		for path in path_list:
236			self.add_predicted_path(path)
237		return self
238
239	def add_node_values(
240		self,
241		node_values: Float[np.ndarray, "grid_n grid_n"],
242		color_map: str = "Blues",
243		target_token_coord: Coord | None = None,
244		preceeding_tokens_coords: CoordArray = None,
245		colormap_center: float | None = None,
246		colormap_max: float | None = None,
247		hide_colorbar: bool = False,
248	) -> MazePlot:
249		"""add node values to the maze for visualization as a heatmap
250
251		> DOCS: what are these arguments?
252
253		# Parameters:
254		- `node_values : Float[np.ndarray, "grid_n grid_n"]`
255		- `color_map : str`
256			(defaults to `"Blues"`)
257		- `target_token_coord : Coord | None`
258			(defaults to `None`)
259		- `preceeding_tokens_coords : CoordArray`
260			(defaults to `None`)
261		- `colormap_center : float | None`
262			(defaults to `None`)
263		- `colormap_max : float | None`
264			(defaults to `None`)
265		- `hide_colorbar : bool`
266			(defaults to `False`)
267
268		# Returns:
269		- `MazePlot`
270		"""
271		assert node_values.shape == self.maze.grid_shape, (
272			"Please pass node values of the same sape as LatticeMaze.grid_shape"
273		)
274		# assert np.min(node_values) >= 0, "Please pass non-negative node values only."
275
276		self.node_values = node_values
277		# Set flag for choosing cmap while plotting maze
278		self.custom_node_value_flag = True
279		# Retrieve Max node value for plotting, +1e-10 to avoid division by zero
280		self.node_color_map = color_map
281		self.colormap_center = colormap_center
282		self.colormap_max = colormap_max
283		self.hide_colorbar = hide_colorbar
284
285		if target_token_coord is not None:
286			self.marked_coords.append((target_token_coord, self.marker_kwargs_next))
287		if preceeding_tokens_coords is not None:
288			for coord in preceeding_tokens_coords:
289				self.marked_coords.append((coord, self.marker_kwargs_current))
290		return self
291
292	def plot(
293		self,
294		dpi: int = 100,
295		title: str = "",
296		fig_ax: tuple | None = None,
297		plain: bool = False,
298	) -> MazePlot:
299		"""Plot the maze and paths."""
300		# set up figure
301		if fig_ax is None:
302			self.fig = plt.figure(dpi=dpi)
303			self.ax = self.fig.add_subplot(1, 1, 1)
304		else:
305			self.fig, self.ax = fig_ax
306
307		# plot maze
308		self._plot_maze()
309
310		# Plot labels
311		if not plain:
312			tick_arr = np.arange(self.maze.grid_shape[0])
313			self.ax.set_xticks(self.unit_length * (tick_arr + 0.5), tick_arr)
314			self.ax.set_yticks(self.unit_length * (tick_arr + 0.5), tick_arr)
315			self.ax.set_xlabel("col")
316			self.ax.set_ylabel("row")
317			self.ax.set_title(title)
318
319		# plot paths
320		if self.true_path is not None:
321			self._plot_path(self.true_path)
322		for path in self.predicted_paths:
323			self._plot_path(path)
324
325		# plot markers
326		for coord, kwargs in self.marked_coords:
327			self._place_marked_coords([coord], **kwargs)
328
329		return self
330
331	def _rowcol_to_coord(self, point: Coord) -> np.ndarray:
332		"""Transform Point from MazeTransformer (row, column) notation to matplotlib default (x, y) notation where x is the horizontal axis."""
333		point = np.array([point[1], point[0]])
334		return self.unit_length * (point + 0.5)
335
336	def mark_coords(self, coords: CoordArray | list[Coord], **kwargs) -> MazePlot:
337		"""Mark coordinates on the maze with a marker.
338
339		default marker is a blue "+":
340		`dict(marker="+", color="blue")`
341		"""
342		kwargs = {
343			**dict(marker="+", color="blue"),
344			**kwargs,
345		}
346		for coord in coords:
347			self.marked_coords.append((coord, kwargs))
348
349		return self
350
351	def _place_marked_coords(
352		self,
353		coords: CoordArray | list[Coord],
354		**kwargs,
355	) -> MazePlot:
356		coords_tp = np.array([self._rowcol_to_coord(coord) for coord in coords])
357		self.ax.plot(coords_tp[:, 0], coords_tp[:, 1], **kwargs)
358
359		return self
360
361	def _plot_maze(self) -> None:  # noqa: C901, PLR0912
362		"""Define Colormap and plot maze.
363
364		Colormap: x is -inf: black
365		else: use colormap
366		"""
367		img = self._lattice_maze_to_img()
368
369		# if no node_values have been passed (no colormap)
370		if self.custom_node_value_flag is False:
371			self.ax.imshow(img, cmap="gray", vmin=-1, vmax=1)
372
373		else:
374			assert self.node_values is not None, "Please pass node values."
375			assert not np.isnan(self.node_values).any(), (
376				"Please pass node values, they cannot be nan."
377			)
378
379			vals_min: float = np.nanmin(self.node_values)
380			vals_max: float = np.nanmax(self.node_values)
381			# if both are negative or both are positive, set max/min to 0
382			if vals_max < 0.0:
383				vals_max = 0.0
384			elif vals_min > 0.0:
385				vals_min = 0.0
386
387			# adjust vals_max, in case you need consistent colorbar across multiple plots
388			vals_max = self.colormap_max or vals_max
389
390			# create colormap
391			cmap = mpl.colormaps[self.node_color_map]
392			# TODO: this is a hack, we make the walls black (while still allowing negative values) by setting the nan color to black
393			cmap.set_bad(color="black")
394
395			if self.colormap_center is not None:
396				if not (vals_min < self.colormap_center < vals_max):
397					if vals_min == self.colormap_center:
398						vals_min -= 1e-10
399					elif vals_max == self.colormap_center:
400						vals_max += 1e-10
401					else:
402						err_msg: str = f"Please pass colormap_center value between {vals_min} and {vals_max}"
403						raise ValueError(
404							err_msg,
405						)
406
407				norm = mpl.colors.TwoSlopeNorm(
408					vmin=vals_min,
409					vcenter=self.colormap_center,
410					vmax=vals_max,
411				)
412				_plotted = self.ax.imshow(img, cmap=cmap, norm=norm)
413			else:
414				_plotted = self.ax.imshow(img, cmap=cmap, vmin=vals_min, vmax=vals_max)
415
416			# Add colorbar based on the condition of self.hide_colorbar
417			if not self.hide_colorbar:
418				ticks = np.linspace(vals_min, vals_max, 5)
419
420				if (vals_min < 0.0 < vals_max) and (0.0 not in ticks):
421					ticks = np.insert(ticks, np.searchsorted(ticks, 0.0), 0.0)
422
423				if (
424					self.colormap_center is not None
425					and self.colormap_center not in ticks
426					and vals_min < self.colormap_center < vals_max
427				):
428					ticks = np.insert(
429						ticks,
430						np.searchsorted(ticks, self.colormap_center),
431						self.colormap_center,
432					)
433
434				cbar = plt.colorbar(
435					_plotted,
436					ticks=ticks,
437					ax=self.ax,
438					cax=self.cbar_ax,
439				)
440				self.cbar_ax = cbar.ax
441
442		# make the boundaries of the image thicker (walls look weird without this)
443		for axis in ["top", "bottom", "left", "right"]:
444			self.ax.spines[axis].set_linewidth(2)
445
	def _lattice_maze_to_img(
		self,
		connection_val_scale: float = 0.93,
	) -> Float[np.ndarray, "row col"]:
		"""Build an image to visualise the maze.

		Each "unit" consists of a node and the right and lower adjacent wall/connection. Its area is ul * ul.
		- Nodes have area: (ul-1) * (ul-1) and value 1 by default
			- take node_value if passed via .add_node_values()
		- Walls have area: 1 * (ul-1) and value -1
		- Connections have area: 1 * (ul-1); color and value 0.93 by default
			- take node_value if passed via .add_node_values()

		Axes definition:
		(0,0)     col
		----|----------->
			|
		row |
			|
			v

		Returns a float matrix of shape (ul * n_rows + 1, ul * n_cols + 1), where
		(n_rows, n_cols) is `self.maze.grid_shape`.

		`connection_val_scale` is the value given to open connections when no
		node values are set.
		"""
		# TODO: this is a hack, but if you add 1 always then non-node valued plots have their walls disappear. if you dont add 1, you get ugly colors between nodes when they are colored
		node_bdry_hack: int
		connection_list_processed: Float[np.ndarray, "dim row col"]
		# Set node and connection values
		if self.node_values is None:
			# plain mode: nodes are 1, open connections are `connection_val_scale`
			scaled_node_values = np.ones(self.maze.grid_shape)
			connection_values = scaled_node_values * connection_val_scale
			node_bdry_hack = 0
			# TODO: hack
			# invert connection list, so the `if not ...` checks below fire where a connection exists
			connection_list_processed = np.logical_not(self.maze.connection_list)
		else:
			# TODO: hack
			# heatmap mode: node squares grow by one pixel (node_bdry_hack) so they also
			# cover their right/lower strip; the `if not ...` checks below then fire where
			# there is NO connection and overwrite that strip with nan
			scaled_node_values = self.node_values
			# connection_values = scaled_node_values
			connection_values = np.full_like(scaled_node_values, np.nan)
			node_bdry_hack = 1
			connection_list_processed = self.maze.connection_list

		# Create background image (all pixels set to -1, walls everywhere)
		img: Float[np.ndarray, "row col"] = -np.ones(
			(
				self.maze.grid_shape[0] * self.unit_length + 1,
				self.maze.grid_shape[1] * self.unit_length + 1,
			),
			dtype=float,
		)

		# Draw nodes and connections by iterating through lattice
		for row in range(self.maze.grid_shape[0]):
			for col in range(self.maze.grid_shape[1]):
				# Draw node
				img[
					row * self.unit_length + 1 : (row + 1) * self.unit_length
					+ node_bdry_hack,
					col * self.unit_length + 1 : (col + 1) * self.unit_length
					+ node_bdry_hack,
				] = scaled_node_values[row, col]

				# Down connection (index 0 of the connection list): the one-pixel row below the node
				if not connection_list_processed[0, row, col]:
					img[
						(row + 1) * self.unit_length,
						col * self.unit_length + 1 : (col + 1) * self.unit_length,
					] = connection_values[row, col]

				# Right connection (index 1 of the connection list): the one-pixel column right of the node
				if not connection_list_processed[1, row, col]:
					img[
						row * self.unit_length + 1 : (row + 1) * self.unit_length,
						(col + 1) * self.unit_length,
					] = connection_values[row, col]

		return img
523
524	def _plot_path(self, path_format: PathFormat) -> None:
525		if len(path_format.path) == 0:
526			warnings.warn(f"Empty path, skipping plotting\n{path_format = }")
527			return
528		p_transformed = np.array(
529			[self._rowcol_to_coord(coord) for coord in path_format.path],
530		)
531		if path_format.quiver_kwargs is not None:
532			try:
533				x: np.ndarray = p_transformed[:, 0]
534				y: np.ndarray = p_transformed[:, 1]
535			except Exception as e:
536				err_msg: str = f"Error in plotting quiver path:\n{path_format = }\n{p_transformed = }\n{e}"
537				raise ValueError(
538					err_msg,
539				) from e
540
541			# Generate colors from the colormap
542			if path_format.cmap is not None:
543				n = len(x) - 1  # Number of arrows
544				cmap = plt.get_cmap(path_format.cmap)
545				colors = [cmap(i / n) for i in range(n)]
546			else:
547				colors = path_format.color
548
549			self.ax.quiver(
550				x[:-1],
551				y[:-1],
552				x[1:] - x[:-1],
553				y[1:] - y[:-1],
554				scale_units="xy",
555				angles="xy",
556				scale=1,
557				color=colors,
558				**path_format.quiver_kwargs,
559			)
560		else:
561			self.ax.plot(
562				p_transformed[:, 0],
563				p_transformed[:, 1],
564				path_format.fmt,
565				lw=path_format.line_width,
566				color=path_format.color,
567				label=path_format.label,
568			)
569		# mark endpoints
570		self.ax.plot(
571			[p_transformed[0][0]],
572			[p_transformed[0][1]],
573			"o",
574			color=path_format.color,
575			ms=10,
576		)
577		self.ax.plot(
578			[p_transformed[-1][0]],
579			[p_transformed[-1][1]],
580			"x",
581			color=path_format.color,
582			ms=10,
583		)
584
585	def to_ascii(
586		self,
587		show_endpoints: bool = True,
588		show_solution: bool = True,
589	) -> str:
590		"wrapper for `self.solved_maze.as_ascii()`, shows the path if we have `self.true_path`"
591		if self.true_path:
592			return self.solved_maze.as_ascii(
593				show_endpoints=show_endpoints,
594				show_solution=show_solution,
595			)
596		else:
597			return self.maze.as_ascii(show_endpoints=show_endpoints)

LARGE_NEGATIVE_NUMBER: float = -10000000000.0
@dataclass(kw_only=True)
class PathFormat:
26@dataclass(kw_only=True)
27class PathFormat:
28	"""formatting options for path plot"""
29
30	label: str | None = None
31	fmt: str = "o"
32	color: str | None = None
33	cmap: str | None = None
34	line_width: float | None = None
35	quiver_kwargs: dict | None = None
36
37	def combine(self, other: PathFormat) -> PathFormat:
38		"""combine with other PathFormat object, overwriting attributes with non-None values.
39
40		returns a modified copy of self.
41		"""
42		output: PathFormat = deepcopy(self)
43		for key, value in other.__dict__.items():
44			if key == "path":
45				err_msg: str = f"Cannot overwrite path attribute! {self = }, {other = }"
46				raise ValueError(
47					err_msg,
48				)
49			if value is not None:
50				setattr(output, key, value)
51
52		return output

formatting options for path plot

PathFormat( *, label: str | None = None, fmt: str = 'o', color: str | None = None, cmap: str | None = None, line_width: float | None = None, quiver_kwargs: dict | None = None)
label: str | None = None
fmt: str = 'o'
color: str | None = None
cmap: str | None = None
line_width: float | None = None
quiver_kwargs: dict | None = None
def combine( self, other: PathFormat) -> PathFormat:
37	def combine(self, other: PathFormat) -> PathFormat:
38		"""combine with other PathFormat object, overwriting attributes with non-None values.
39
40		returns a modified copy of self.
41		"""
42		output: PathFormat = deepcopy(self)
43		for key, value in other.__dict__.items():
44			if key == "path":
45				err_msg: str = f"Cannot overwrite path attribute! {self = }, {other = }"
46				raise ValueError(
47					err_msg,
48				)
49			if value is not None:
50				setattr(output, key, value)
51
52		return output

combine with other PathFormat object, overwriting attributes with non-None values.

returns a modified copy of self.

@dataclass
class StyledPath(PathFormat):
@dataclass
class StyledPath(PathFormat):
	"""A `PathFormat` bundled with the path it styles.

	`path` is the only positional field; the formatting fields inherited from
	`PathFormat` stay keyword-only.
	"""

	# coordinates of the path to draw
	path: CoordArray

a StyledPath is a PathFormat with a specific path

StyledPath( path: jaxtyping.Int8[ndarray, 'coord row_col=2'], *, label: str | None = None, fmt: str = 'o', color: str | None = None, cmap: str | None = None, line_width: float | None = None, quiver_kwargs: dict | None = None)
path: jaxtyping.Int8[ndarray, 'coord row_col=2']
DEFAULT_FORMATS: dict[str, PathFormat] = {'true': PathFormat(label='true path', fmt='--', color='red', cmap=None, line_width=2.5, quiver_kwargs=None), 'predicted': PathFormat(label=None, fmt=':', color=None, cmap=None, line_width=2, quiver_kwargs={'width': 0.015})}
def process_path_input( path: list[tuple[int, int]] | jaxtyping.Int8[ndarray, 'coord row_col=2'] | StyledPath, _default_key: str, path_fmt: PathFormat | None = None, **kwargs) -> StyledPath:
 81def process_path_input(
 82	path: CoordList | CoordArray | StyledPath,
 83	_default_key: str,
 84	path_fmt: PathFormat | None = None,
 85	**kwargs,
 86) -> StyledPath:
 87	"convert a path, which might be a list or array of coords, into a `StyledPath`"
 88	styled_path: StyledPath
 89	if isinstance(path, StyledPath):
 90		styled_path = path
 91	elif isinstance(path, np.ndarray):
 92		styled_path = StyledPath(path=path)
 93		# add default formatting
 94		styled_path = styled_path.combine(DEFAULT_FORMATS[_default_key])
 95	elif isinstance(path, list):
 96		styled_path = StyledPath(path=np.array(path))
 97		# add default formatting
 98		styled_path = styled_path.combine(DEFAULT_FORMATS[_default_key])
 99	else:
100		err_msg: str = (
101			f"Expected CoordList, CoordArray or StyledPath, got {type(path)}: {path}"
102		)
103		raise TypeError(
104			err_msg,
105		)
106
107	# add formatting from path_fmt
108	if path_fmt is not None:
109		styled_path = styled_path.combine(path_fmt)
110
111	# add formatting from kwargs
112	for key, value in kwargs.items():
113		setattr(styled_path, key, value)
114
115	return styled_path

convert a path, which might be a list or array of coords, into a StyledPath

DEFAULT_PREDICTED_PATH_COLORS: list[str] = ['tab:orange', 'tab:olive', 'sienna', 'mediumseagreen', 'tab:purple', 'slategrey']
class MazePlot:
128class MazePlot:
129	"""Class for displaying mazes and paths"""
130
131	def __init__(self, maze: LatticeMaze, unit_length: int = 14) -> None:
132		"""UNIT_LENGTH: Set ratio between node size and wall thickness in image.
133
134		Wall thickness is fixed to 1px
135		A "unit" consists of a single node and the right and lower connection/wall.
136		Example: ul = 14 yields 13:1 ratio between node size and wall thickness
137		"""
138		self.unit_length: int = unit_length
139		self.maze: LatticeMaze = maze
140		self.true_path: StyledPath | None = None
141		self.predicted_paths: list[StyledPath] = []
142		self.node_values: Float[np.ndarray, "grid_n grid_n"] = None
143		self.custom_node_value_flag: bool = False
144		self.node_color_map: str = "Blues"
145		self.target_token_coord: Coord = None
146		self.preceding_tokens_coords: CoordArray = None
147		self.colormap_center: float | None = None
148		self.cbar_ax = None
149		self.marked_coords: list[tuple[Coord, dict]] = list()
150
151		self.marker_kwargs_current: dict = dict(
152			marker="s",
153			color="green",
154			ms=12,
155		)
156		self.marker_kwargs_next: dict = dict(
157			marker="P",
158			color="green",
159			ms=12,
160		)
161
162		if isinstance(maze, SolvedMaze):
163			self.add_true_path(maze.solution)
164		else:
165			if isinstance(maze, TargetedLatticeMaze):
166				self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution)
167
168	@property
169	def solved_maze(self) -> SolvedMaze:
170		"get the underlying `SolvedMaze` object"
171		if self.true_path is None:
172			raise ValueError(
173				"Cannot return SolvedMaze object without true path. Add true path with add_true_path method.",
174			)
175		return SolvedMaze.from_lattice_maze(
176			lattice_maze=self.maze,
177			solution=self.true_path.path,
178		)
179
180	def add_true_path(
181		self,
182		path: CoordList | CoordArray | StyledPath,
183		path_fmt: PathFormat | None = None,
184		**kwargs,
185	) -> MazePlot:
186		"add a true path to the maze with optional formatting"
187		self.true_path = process_path_input(
188			path=path,
189			_default_key="true",
190			path_fmt=path_fmt,
191			**kwargs,
192		)
193
194		return self
195
196	def add_predicted_path(
197		self,
198		path: CoordList | CoordArray | StyledPath,
199		path_fmt: PathFormat | None = None,
200		**kwargs,
201	) -> MazePlot:
202		"""Recieve predicted path and formatting preferences from input and save in predicted_path list.
203
204		Default formatting depends on nuber of paths already saved in predicted path list.
205		"""
206		styled_path: StyledPath = process_path_input(
207			path=path,
208			_default_key="predicted",
209			path_fmt=path_fmt,
210			**kwargs,
211		)
212
213		# set default label and color if not specified
214		if styled_path.label is None:
215			styled_path.label = f"predicted path {len(self.predicted_paths) + 1}"
216
217		if styled_path.color is None:
218			color_num: int = len(self.predicted_paths) % len(
219				DEFAULT_PREDICTED_PATH_COLORS,
220			)
221			styled_path.color = DEFAULT_PREDICTED_PATH_COLORS[color_num]
222
223		self.predicted_paths.append(styled_path)
224		return self
225
226	def add_multiple_paths(
227		self,
228		path_list: Sequence[CoordList | CoordArray | StyledPath],
229	) -> MazePlot:
230		"""Function for adding multiple paths to MazePlot at once.
231
232		> DOCS: what are the two ways?
233		This can be done in two ways:
234		1. Passing a list of
235		"""
236		for path in path_list:
237			self.add_predicted_path(path)
238		return self
239
240	def add_node_values(
241		self,
242		node_values: Float[np.ndarray, "grid_n grid_n"],
243		color_map: str = "Blues",
244		target_token_coord: Coord | None = None,
245		preceeding_tokens_coords: CoordArray = None,
246		colormap_center: float | None = None,
247		colormap_max: float | None = None,
248		hide_colorbar: bool = False,
249	) -> MazePlot:
250		"""add node values to the maze for visualization as a heatmap
251
252		> DOCS: what are these arguments?
253
254		# Parameters:
255		- `node_values : Float[np.ndarray, "grid_n grid_n"]`
256		- `color_map : str`
257			(defaults to `"Blues"`)
258		- `target_token_coord : Coord | None`
259			(defaults to `None`)
260		- `preceeding_tokens_coords : CoordArray`
261			(defaults to `None`)
262		- `colormap_center : float | None`
263			(defaults to `None`)
264		- `colormap_max : float | None`
265			(defaults to `None`)
266		- `hide_colorbar : bool`
267			(defaults to `False`)
268
269		# Returns:
270		- `MazePlot`
271		"""
272		assert node_values.shape == self.maze.grid_shape, (
273			"Please pass node values of the same sape as LatticeMaze.grid_shape"
274		)
275		# assert np.min(node_values) >= 0, "Please pass non-negative node values only."
276
277		self.node_values = node_values
278		# Set flag for choosing cmap while plotting maze
279		self.custom_node_value_flag = True
280		# Retrieve Max node value for plotting, +1e-10 to avoid division by zero
281		self.node_color_map = color_map
282		self.colormap_center = colormap_center
283		self.colormap_max = colormap_max
284		self.hide_colorbar = hide_colorbar
285
286		if target_token_coord is not None:
287			self.marked_coords.append((target_token_coord, self.marker_kwargs_next))
288		if preceeding_tokens_coords is not None:
289			for coord in preceeding_tokens_coords:
290				self.marked_coords.append((coord, self.marker_kwargs_current))
291		return self
292
293	def plot(
294		self,
295		dpi: int = 100,
296		title: str = "",
297		fig_ax: tuple | None = None,
298		plain: bool = False,
299	) -> MazePlot:
300		"""Plot the maze and paths."""
301		# set up figure
302		if fig_ax is None:
303			self.fig = plt.figure(dpi=dpi)
304			self.ax = self.fig.add_subplot(1, 1, 1)
305		else:
306			self.fig, self.ax = fig_ax
307
308		# plot maze
309		self._plot_maze()
310
311		# Plot labels
312		if not plain:
313			tick_arr = np.arange(self.maze.grid_shape[0])
314			self.ax.set_xticks(self.unit_length * (tick_arr + 0.5), tick_arr)
315			self.ax.set_yticks(self.unit_length * (tick_arr + 0.5), tick_arr)
316			self.ax.set_xlabel("col")
317			self.ax.set_ylabel("row")
318			self.ax.set_title(title)
319
320		# plot paths
321		if self.true_path is not None:
322			self._plot_path(self.true_path)
323		for path in self.predicted_paths:
324			self._plot_path(path)
325
326		# plot markers
327		for coord, kwargs in self.marked_coords:
328			self._place_marked_coords([coord], **kwargs)
329
330		return self
331
332	def _rowcol_to_coord(self, point: Coord) -> np.ndarray:
333		"""Transform Point from MazeTransformer (row, column) notation to matplotlib default (x, y) notation where x is the horizontal axis."""
334		point = np.array([point[1], point[0]])
335		return self.unit_length * (point + 0.5)
336
337	def mark_coords(self, coords: CoordArray | list[Coord], **kwargs) -> MazePlot:
338		"""Mark coordinates on the maze with a marker.
339
340		default marker is a blue "+":
341		`dict(marker="+", color="blue")`
342		"""
343		kwargs = {
344			**dict(marker="+", color="blue"),
345			**kwargs,
346		}
347		for coord in coords:
348			self.marked_coords.append((coord, kwargs))
349
350		return self
351
352	def _place_marked_coords(
353		self,
354		coords: CoordArray | list[Coord],
355		**kwargs,
356	) -> MazePlot:
357		coords_tp = np.array([self._rowcol_to_coord(coord) for coord in coords])
358		self.ax.plot(coords_tp[:, 0], coords_tp[:, 1], **kwargs)
359
360		return self
361
362	def _plot_maze(self) -> None:  # noqa: C901, PLR0912
363		"""Define Colormap and plot maze.
364
365		Colormap: x is -inf: black
366		else: use colormap
367		"""
368		img = self._lattice_maze_to_img()
369
370		# if no node_values have been passed (no colormap)
371		if self.custom_node_value_flag is False:
372			self.ax.imshow(img, cmap="gray", vmin=-1, vmax=1)
373
374		else:
375			assert self.node_values is not None, "Please pass node values."
376			assert not np.isnan(self.node_values).any(), (
377				"Please pass node values, they cannot be nan."
378			)
379
380			vals_min: float = np.nanmin(self.node_values)
381			vals_max: float = np.nanmax(self.node_values)
382			# if both are negative or both are positive, set max/min to 0
383			if vals_max < 0.0:
384				vals_max = 0.0
385			elif vals_min > 0.0:
386				vals_min = 0.0
387
388			# adjust vals_max, in case you need consistent colorbar across multiple plots
389			vals_max = self.colormap_max or vals_max
390
391			# create colormap
392			cmap = mpl.colormaps[self.node_color_map]
393			# TODO: this is a hack, we make the walls black (while still allowing negative values) by setting the nan color to black
394			cmap.set_bad(color="black")
395
396			if self.colormap_center is not None:
397				if not (vals_min < self.colormap_center < vals_max):
398					if vals_min == self.colormap_center:
399						vals_min -= 1e-10
400					elif vals_max == self.colormap_center:
401						vals_max += 1e-10
402					else:
403						err_msg: str = f"Please pass colormap_center value between {vals_min} and {vals_max}"
404						raise ValueError(
405							err_msg,
406						)
407
408				norm = mpl.colors.TwoSlopeNorm(
409					vmin=vals_min,
410					vcenter=self.colormap_center,
411					vmax=vals_max,
412				)
413				_plotted = self.ax.imshow(img, cmap=cmap, norm=norm)
414			else:
415				_plotted = self.ax.imshow(img, cmap=cmap, vmin=vals_min, vmax=vals_max)
416
417			# Add colorbar based on the condition of self.hide_colorbar
418			if not self.hide_colorbar:
419				ticks = np.linspace(vals_min, vals_max, 5)
420
421				if (vals_min < 0.0 < vals_max) and (0.0 not in ticks):
422					ticks = np.insert(ticks, np.searchsorted(ticks, 0.0), 0.0)
423
424				if (
425					self.colormap_center is not None
426					and self.colormap_center not in ticks
427					and vals_min < self.colormap_center < vals_max
428				):
429					ticks = np.insert(
430						ticks,
431						np.searchsorted(ticks, self.colormap_center),
432						self.colormap_center,
433					)
434
435				cbar = plt.colorbar(
436					_plotted,
437					ticks=ticks,
438					ax=self.ax,
439					cax=self.cbar_ax,
440				)
441				self.cbar_ax = cbar.ax
442
443		# make the boundaries of the image thicker (walls look weird without this)
444		for axis in ["top", "bottom", "left", "right"]:
445			self.ax.spines[axis].set_linewidth(2)
446
447	def _lattice_maze_to_img(
448		self,
449		connection_val_scale: float = 0.93,
450	) -> Bool[np.ndarray, "row col"]:
451		"""Build an image to visualise the maze.
452
453		Each "unit" consists of a node and the right and lower adjacent wall/connection. Its area is ul * ul.
454		- Nodes have area: (ul-1) * (ul-1) and value 1 by default
455			- take node_value if passed via .add_node_values()
456		- Walls have area: 1 * (ul-1) and value -1
457		- Connections have area: 1 * (ul-1); color and value 0.93 by default
458			- take node_value if passed via .add_node_values()
459
460		Axes definition:
461		(0,0)     col
462		----|----------->
463			|
464		row |
465			|
466			v
467
468		Returns a matrix of side length (ul) * n + 1 where n is the number of nodes.
469		"""
470		# TODO: this is a hack, but if you add 1 always then non-node valued plots have their walls dissapear. if you dont add 1, you get ugly colors between nodes when they are colored
471		node_bdry_hack: int
472		connection_list_processed: Float[np.ndarray, "dim row col"]
473		# Set node and connection values
474		if self.node_values is None:
475			scaled_node_values = np.ones(self.maze.grid_shape)
476			connection_values = scaled_node_values * connection_val_scale
477			node_bdry_hack = 0
478			# TODO: hack
479			# invert connection list
480			connection_list_processed = np.logical_not(self.maze.connection_list)
481		else:
482			# TODO: hack
483			scaled_node_values = self.node_values
484			# connection_values = scaled_node_values
485			connection_values = np.full_like(scaled_node_values, np.nan)
486			node_bdry_hack = 1
487			connection_list_processed = self.maze.connection_list
488
489		# Create background image (all pixels set to -1, walls everywhere)
490		img: Float[np.ndarray, "row col"] = -np.ones(
491			(
492				self.maze.grid_shape[0] * self.unit_length + 1,
493				self.maze.grid_shape[1] * self.unit_length + 1,
494			),
495			dtype=float,
496		)
497
498		# Draw nodes and connections by iterating through lattice
499		for row in range(self.maze.grid_shape[0]):
500			for col in range(self.maze.grid_shape[1]):
501				# Draw node
502				img[
503					row * self.unit_length + 1 : (row + 1) * self.unit_length
504					+ node_bdry_hack,
505					col * self.unit_length + 1 : (col + 1) * self.unit_length
506					+ node_bdry_hack,
507				] = scaled_node_values[row, col]
508
509				# Down connection
510				if not connection_list_processed[0, row, col]:
511					img[
512						(row + 1) * self.unit_length,
513						col * self.unit_length + 1 : (col + 1) * self.unit_length,
514					] = connection_values[row, col]
515
516				# Right connection
517				if not connection_list_processed[1, row, col]:
518					img[
519						row * self.unit_length + 1 : (row + 1) * self.unit_length,
520						(col + 1) * self.unit_length,
521					] = connection_values[row, col]
522
523		return img
524
525	def _plot_path(self, path_format: PathFormat) -> None:
526		if len(path_format.path) == 0:
527			warnings.warn(f"Empty path, skipping plotting\n{path_format = }")
528			return
529		p_transformed = np.array(
530			[self._rowcol_to_coord(coord) for coord in path_format.path],
531		)
532		if path_format.quiver_kwargs is not None:
533			try:
534				x: np.ndarray = p_transformed[:, 0]
535				y: np.ndarray = p_transformed[:, 1]
536			except Exception as e:
537				err_msg: str = f"Error in plotting quiver path:\n{path_format = }\n{p_transformed = }\n{e}"
538				raise ValueError(
539					err_msg,
540				) from e
541
542			# Generate colors from the colormap
543			if path_format.cmap is not None:
544				n = len(x) - 1  # Number of arrows
545				cmap = plt.get_cmap(path_format.cmap)
546				colors = [cmap(i / n) for i in range(n)]
547			else:
548				colors = path_format.color
549
550			self.ax.quiver(
551				x[:-1],
552				y[:-1],
553				x[1:] - x[:-1],
554				y[1:] - y[:-1],
555				scale_units="xy",
556				angles="xy",
557				scale=1,
558				color=colors,
559				**path_format.quiver_kwargs,
560			)
561		else:
562			self.ax.plot(
563				p_transformed[:, 0],
564				p_transformed[:, 1],
565				path_format.fmt,
566				lw=path_format.line_width,
567				color=path_format.color,
568				label=path_format.label,
569			)
570		# mark endpoints
571		self.ax.plot(
572			[p_transformed[0][0]],
573			[p_transformed[0][1]],
574			"o",
575			color=path_format.color,
576			ms=10,
577		)
578		self.ax.plot(
579			[p_transformed[-1][0]],
580			[p_transformed[-1][1]],
581			"x",
582			color=path_format.color,
583			ms=10,
584		)
585
586	def to_ascii(
587		self,
588		show_endpoints: bool = True,
589		show_solution: bool = True,
590	) -> str:
591		"wrapper for `self.solved_maze.as_ascii()`, shows the path if we have `self.true_path`"
592		if self.true_path:
593			return self.solved_maze.as_ascii(
594				show_endpoints=show_endpoints,
595				show_solution=show_solution,
596			)
597		else:
598			return self.maze.as_ascii(show_endpoints=show_endpoints)

Class for displaying mazes and paths

MazePlot( maze: maze_dataset.LatticeMaze, unit_length: int = 14)
131	def __init__(self, maze: LatticeMaze, unit_length: int = 14) -> None:
132		"""UNIT_LENGTH: Set ratio between node size and wall thickness in image.
133
134		Wall thickness is fixed to 1px
135		A "unit" consists of a single node and the right and lower connection/wall.
136		Example: ul = 14 yields 13:1 ratio between node size and wall thickness
137		"""
138		self.unit_length: int = unit_length
139		self.maze: LatticeMaze = maze
140		self.true_path: StyledPath | None = None
141		self.predicted_paths: list[StyledPath] = []
142		self.node_values: Float[np.ndarray, "grid_n grid_n"] = None
143		self.custom_node_value_flag: bool = False
144		self.node_color_map: str = "Blues"
145		self.target_token_coord: Coord = None
146		self.preceding_tokens_coords: CoordArray = None
147		self.colormap_center: float | None = None
148		self.cbar_ax = None
149		self.marked_coords: list[tuple[Coord, dict]] = list()
150
151		self.marker_kwargs_current: dict = dict(
152			marker="s",
153			color="green",
154			ms=12,
155		)
156		self.marker_kwargs_next: dict = dict(
157			marker="P",
158			color="green",
159			ms=12,
160		)
161
162		if isinstance(maze, SolvedMaze):
163			self.add_true_path(maze.solution)
164		else:
165			if isinstance(maze, TargetedLatticeMaze):
166				self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution)

UNIT_LENGTH: Set ratio between node size and wall thickness in image.

Wall thickness is fixed to 1px. A "unit" consists of a single node and the right and lower connection/wall. Example: ul = 14 yields a 13:1 ratio between node size and wall thickness.

unit_length: int
true_path: StyledPath | None
predicted_paths: list[StyledPath]
node_values: jaxtyping.Float[ndarray, 'grid_n grid_n']
custom_node_value_flag: bool
node_color_map: str
target_token_coord: jaxtyping.Int8[ndarray, 'row_col=2']
preceding_tokens_coords: jaxtyping.Int8[ndarray, 'coord row_col=2']
colormap_center: float | None
cbar_ax
marked_coords: list[tuple[jaxtyping.Int8[ndarray, 'row_col=2'], dict]]
marker_kwargs_current: dict
marker_kwargs_next: dict
solved_maze: maze_dataset.SolvedMaze
168	@property
169	def solved_maze(self) -> SolvedMaze:
170		"get the underlying `SolvedMaze` object"
171		if self.true_path is None:
172			raise ValueError(
173				"Cannot return SolvedMaze object without true path. Add true path with add_true_path method.",
174			)
175		return SolvedMaze.from_lattice_maze(
176			lattice_maze=self.maze,
177			solution=self.true_path.path,
178		)

get the underlying SolvedMaze object

def add_true_path( self, path: list[tuple[int, int]] | jaxtyping.Int8[ndarray, 'coord row_col=2'] | StyledPath, path_fmt: PathFormat | None = None, **kwargs) -> MazePlot:
180	def add_true_path(
181		self,
182		path: CoordList | CoordArray | StyledPath,
183		path_fmt: PathFormat | None = None,
184		**kwargs,
185	) -> MazePlot:
186		"add a true path to the maze with optional formatting"
187		self.true_path = process_path_input(
188			path=path,
189			_default_key="true",
190			path_fmt=path_fmt,
191			**kwargs,
192		)
193
194		return self

add a true path to the maze with optional formatting

def add_predicted_path( self, path: list[tuple[int, int]] | jaxtyping.Int8[ndarray, 'coord row_col=2'] | StyledPath, path_fmt: PathFormat | None = None, **kwargs) -> MazePlot:
196	def add_predicted_path(
197		self,
198		path: CoordList | CoordArray | StyledPath,
199		path_fmt: PathFormat | None = None,
200		**kwargs,
201	) -> MazePlot:
202		"""Recieve predicted path and formatting preferences from input and save in predicted_path list.
203
204		Default formatting depends on nuber of paths already saved in predicted path list.
205		"""
206		styled_path: StyledPath = process_path_input(
207			path=path,
208			_default_key="predicted",
209			path_fmt=path_fmt,
210			**kwargs,
211		)
212
213		# set default label and color if not specified
214		if styled_path.label is None:
215			styled_path.label = f"predicted path {len(self.predicted_paths) + 1}"
216
217		if styled_path.color is None:
218			color_num: int = len(self.predicted_paths) % len(
219				DEFAULT_PREDICTED_PATH_COLORS,
220			)
221			styled_path.color = DEFAULT_PREDICTED_PATH_COLORS[color_num]
222
223		self.predicted_paths.append(styled_path)
224		return self

Receive predicted path and formatting preferences from input and save in predicted_path list.

Default formatting depends on number of paths already saved in predicted path list.

def add_multiple_paths( self, path_list: Sequence[list[tuple[int, int]] | jaxtyping.Int8[ndarray, 'coord row_col=2'] | StyledPath]) -> MazePlot:
226	def add_multiple_paths(
227		self,
228		path_list: Sequence[CoordList | CoordArray | StyledPath],
229	) -> MazePlot:
230		"""Function for adding multiple paths to MazePlot at once.
231
232		> DOCS: what are the two ways?
233		This can be done in two ways:
234		1. Passing a list of
235		"""
236		for path in path_list:
237			self.add_predicted_path(path)
238		return self

Function for adding multiple paths to MazePlot at once.

DOCS: what are the two ways? This can be done in two ways:

  1. Passing a list of
def add_node_values( self, node_values: jaxtyping.Float[ndarray, 'grid_n grid_n'], color_map: str = 'Blues', target_token_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None, preceeding_tokens_coords: jaxtyping.Int8[ndarray, 'coord row_col=2'] = None, colormap_center: float | None = None, colormap_max: float | None = None, hide_colorbar: bool = False) -> MazePlot:
240	def add_node_values(
241		self,
242		node_values: Float[np.ndarray, "grid_n grid_n"],
243		color_map: str = "Blues",
244		target_token_coord: Coord | None = None,
245		preceeding_tokens_coords: CoordArray = None,
246		colormap_center: float | None = None,
247		colormap_max: float | None = None,
248		hide_colorbar: bool = False,
249	) -> MazePlot:
250		"""add node values to the maze for visualization as a heatmap
251
252		> DOCS: what are these arguments?
253
254		# Parameters:
255		- `node_values : Float[np.ndarray, &quot;grid_n grid_n&quot;]`
256		- `color_map : str`
257			(defaults to `"Blues"`)
258		- `target_token_coord : Coord | None`
259			(defaults to `None`)
260		- `preceeding_tokens_coords : CoordArray`
261			(defaults to `None`)
262		- `colormap_center : float | None`
263			(defaults to `None`)
264		- `colormap_max : float | None`
265			(defaults to `None`)
266		- `hide_colorbar : bool`
267			(defaults to `False`)
268
269		# Returns:
270		- `MazePlot`
271		"""
272		assert node_values.shape == self.maze.grid_shape, (
273			"Please pass node values of the same sape as LatticeMaze.grid_shape"
274		)
275		# assert np.min(node_values) >= 0, "Please pass non-negative node values only."
276
277		self.node_values = node_values
278		# Set flag for choosing cmap while plotting maze
279		self.custom_node_value_flag = True
280		# Retrieve Max node value for plotting, +1e-10 to avoid division by zero
281		self.node_color_map = color_map
282		self.colormap_center = colormap_center
283		self.colormap_max = colormap_max
284		self.hide_colorbar = hide_colorbar
285
286		if target_token_coord is not None:
287			self.marked_coords.append((target_token_coord, self.marker_kwargs_next))
288		if preceeding_tokens_coords is not None:
289			for coord in preceeding_tokens_coords:
290				self.marked_coords.append((coord, self.marker_kwargs_current))
291		return self

add node values to the maze for visualization as a heatmap

DOCS: what are these arguments?

Parameters:

  • node_values : Float[np.ndarray, "grid_n grid_n"]
  • color_map : str (defaults to "Blues")
  • target_token_coord : Coord | None (defaults to None)
  • preceeding_tokens_coords : CoordArray (defaults to None)
  • colormap_center : float | None (defaults to None)
  • colormap_max : float | None (defaults to None)
  • hide_colorbar : bool (defaults to False)

Returns:

def plot( self, dpi: int = 100, title: str = '', fig_ax: tuple | None = None, plain: bool = False) -> MazePlot:
293	def plot(
294		self,
295		dpi: int = 100,
296		title: str = "",
297		fig_ax: tuple | None = None,
298		plain: bool = False,
299	) -> MazePlot:
300		"""Plot the maze and paths."""
301		# set up figure
302		if fig_ax is None:
303			self.fig = plt.figure(dpi=dpi)
304			self.ax = self.fig.add_subplot(1, 1, 1)
305		else:
306			self.fig, self.ax = fig_ax
307
308		# plot maze
309		self._plot_maze()
310
311		# Plot labels
312		if not plain:
313			tick_arr = np.arange(self.maze.grid_shape[0])
314			self.ax.set_xticks(self.unit_length * (tick_arr + 0.5), tick_arr)
315			self.ax.set_yticks(self.unit_length * (tick_arr + 0.5), tick_arr)
316			self.ax.set_xlabel("col")
317			self.ax.set_ylabel("row")
318			self.ax.set_title(title)
319
320		# plot paths
321		if self.true_path is not None:
322			self._plot_path(self.true_path)
323		for path in self.predicted_paths:
324			self._plot_path(path)
325
326		# plot markers
327		for coord, kwargs in self.marked_coords:
328			self._place_marked_coords([coord], **kwargs)
329
330		return self

Plot the maze and paths.

def mark_coords( self, coords: jaxtyping.Int8[ndarray, 'coord row_col=2'] | list[jaxtyping.Int8[ndarray, 'row_col=2']], **kwargs) -> MazePlot:
337	def mark_coords(self, coords: CoordArray | list[Coord], **kwargs) -> MazePlot:
338		"""Mark coordinates on the maze with a marker.
339
340		default marker is a blue "+":
341		`dict(marker="+", color="blue")`
342		"""
343		kwargs = {
344			**dict(marker="+", color="blue"),
345			**kwargs,
346		}
347		for coord in coords:
348			self.marked_coords.append((coord, kwargs))
349
350		return self

Mark coordinates on the maze with a marker.

default marker is a blue "+": dict(marker="+", color="blue")

def to_ascii(self, show_endpoints: bool = True, show_solution: bool = True) -> str:
586	def to_ascii(
587		self,
588		show_endpoints: bool = True,
589		show_solution: bool = True,
590	) -> str:
591		"wrapper for `self.solved_maze.as_ascii()`, shows the path if we have `self.true_path`"
592		if self.true_path:
593			return self.solved_maze.as_ascii(
594				show_endpoints=show_endpoints,
595				show_solution=show_solution,
596			)
597		else:
598			return self.maze.as_ascii(show_endpoints=show_endpoints)

wrapper for self.solved_maze.as_ascii(), shows the path if we have self.true_path