Coverage for muutils/tensor_info.py: 88%
200 statements
coverage.py v7.6.1, created at 2025-04-04 03:35 -0600
import numpy as np
from typing import Union, Any, Literal, List, Dict, overload

# Global color definitions
COLORS: Dict[str, Dict[str, str]] = {
    "latex": {
        "range": r"\textcolor{purple}",
        "mean": r"\textcolor{teal}",
        "std": r"\textcolor{orange}",
        "median": r"\textcolor{green}",
        "warning": r"\textcolor{red}",
        "shape": r"\textcolor{magenta}",
        "dtype": r"\textcolor{gray}",
        "device": r"\textcolor{gray}",
        "requires_grad": r"\textcolor{gray}",
        "sparkline": r"\textcolor{blue}",
        "reset": "",
    },
    "terminal": {
        "range": "\033[35m",  # purple
        "mean": "\033[36m",  # cyan/teal
        "std": "\033[33m",  # yellow/orange
        "median": "\033[32m",  # green
        "warning": "\033[31m",  # red
        "shape": "\033[95m",  # bright magenta
        "dtype": "\033[90m",  # gray
        "device": "\033[90m",  # gray
        "requires_grad": "\033[90m",  # gray
        "sparkline": "\033[34m",  # blue
        "reset": "\033[0m",
    },
    "none": {
        "range": "",
        "mean": "",
        "std": "",
        "median": "",
        "warning": "",
        "shape": "",
        "dtype": "",
        "device": "",
        "requires_grad": "",
        "sparkline": "",
        "reset": "",
    },
}
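
# Illustrative sketch (not part of the module): the "terminal" palette holds raw
# ANSI escape codes, so colorizing a string is plain concatenation followed by a
# reset. The message text here is made up for the example.
#
#     print(COLORS["terminal"]["warning"] + "NaNs present" + COLORS["terminal"]["reset"])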

OutputFormat = Literal["unicode", "latex", "ascii"]

SYMBOLS: Dict[OutputFormat, Dict[str, str]] = {
    "latex": {
        "range": r"\mathcal{R}",
        "mean": r"\mu",
        "std": r"\sigma",
        "median": r"\tilde{x}",
        "distribution": r"\mathbb{P}",
        "nan_values": r"\texttt{NANvals}",
        "warning": "!!!",
        "requires_grad": r"\nabla",
    },
    "unicode": {
        "range": "R",
        "mean": "μ",
        "std": "σ",
        "median": "x̃",
        "distribution": "ℙ",
        "nan_values": "NANvals",
        "warning": "🚨",
        "requires_grad": "∇",
    },
    "ascii": {
        "range": "range",
        "mean": "mean",
        "std": "std",
        "median": "med",
        "distribution": "dist",
        "nan_values": "NANvals",
        "warning": "!!!",
        "requires_grad": "requires_grad",
    },
}
"Symbols for different formats"

SPARK_CHARS: Dict[OutputFormat, List[str]] = {
    "unicode": list(" ▁▂▃▄▅▆▇█"),
    "ascii": list(" _.-~=#"),
    "latex": list(" ▁▂▃▄▅▆▇█"),
    # "latex": [r"\textbf{.}", r"\textbf{-}", r"\textbf{=}", r"\textbf{+}", r"\textbf{*}", r"\textbf{\\#}"],
}
"characters for sparklines in different formats"


def array_info(
    A: Any,
    hist_bins: int = 5,
) -> Dict[str, Any]:
    """Extract statistical information from an array-like object.

    # Parameters:
    - `A : array-like`
        Array to analyze (numpy array or torch tensor)
    - `hist_bins : int`
        Number of histogram bins to compute (defaults to `5`)

    # Returns:
    - `Dict[str, Any]`
        Dictionary containing raw statistical information with numeric values
    """
    result: Dict[str, Any] = {
        "is_tensor": None,
        "device": None,
        "requires_grad": None,
        "shape": None,
        "dtype": None,
        "size": None,
        "has_nans": None,
        "nan_count": None,
        "nan_percent": None,
        "min": None,
        "max": None,
        "range": None,
        "mean": None,
        "std": None,
        "median": None,
        "histogram": None,
        "bins": None,
        "status": None,
    }

    # Check if it's a tensor by looking at its class name
    # This avoids importing torch directly
    A_type: str = type(A).__name__
    result["is_tensor"] = A_type == "Tensor"

    # Try to get device information if it's a tensor
    if result["is_tensor"]:
        try:
            result["device"] = str(getattr(A, "device", None))
        except:  # noqa: E722
            pass

    # Convert to numpy array for calculations
    try:
        # For PyTorch tensors
        if result["is_tensor"]:
            # Check if tensor is on GPU
            is_cuda: bool = False
            try:
                is_cuda = bool(getattr(A, "is_cuda", False))
            except:  # noqa: E722
                pass

            if is_cuda:
                try:
                    # Try to get CPU tensor first
                    cpu_tensor = getattr(A, "cpu", lambda: A)()
                except:  # noqa: E722
                    A_np = np.array([])
            else:
                cpu_tensor = A
            try:
                # For CPU tensor, just detach and convert
                detached = getattr(cpu_tensor, "detach", lambda: cpu_tensor)()
                A_np = getattr(detached, "numpy", lambda: np.array([]))()
            except:  # noqa: E722
                A_np = np.array([])
        else:
            # For numpy arrays and other array-like objects
            A_np = np.asarray(A)
    except:  # noqa: E722
        A_np = np.array([])

    # Get basic information
    try:
        result["shape"] = A_np.shape
        result["dtype"] = str(A.dtype if result["is_tensor"] else A_np.dtype)
        result["size"] = A_np.size
        result["requires_grad"] = getattr(A, "requires_grad", None)
    except:  # noqa: E722
        pass

    # If array is empty, return early
    if result["size"] == 0:
        result["status"] = "empty array"
        return result

    # Flatten array for statistics if it's multi-dimensional
    try:
        if len(A_np.shape) > 1:
            A_flat = A_np.flatten()
        else:
            A_flat = A_np
    except:  # noqa: E722
        A_flat = A_np

    # Check for NaN values
    try:
        nan_mask = np.isnan(A_flat)
        result["nan_count"] = np.sum(nan_mask)
        result["has_nans"] = result["nan_count"] > 0
        if result["size"] > 0:
            result["nan_percent"] = (result["nan_count"] / result["size"]) * 100
    except:  # noqa: E722
        pass

    # If all values are NaN, return early
    if result["has_nans"] and result["nan_count"] == result["size"]:
        result["status"] = "all NaN"
        return result

    # Calculate statistics
    try:
        if result["has_nans"]:
            result["min"] = float(np.nanmin(A_flat))
            result["max"] = float(np.nanmax(A_flat))
            result["mean"] = float(np.nanmean(A_flat))
            result["std"] = float(np.nanstd(A_flat))
            result["median"] = float(np.nanmedian(A_flat))
            result["range"] = (result["min"], result["max"])

            # Remove NaNs for histogram
            A_hist = A_flat[~nan_mask]
        else:
            result["min"] = float(np.min(A_flat))
            result["max"] = float(np.max(A_flat))
            result["mean"] = float(np.mean(A_flat))
            result["std"] = float(np.std(A_flat))
            result["median"] = float(np.median(A_flat))
            result["range"] = (result["min"], result["max"])

            A_hist = A_flat

        # Calculate histogram data for sparklines
        if A_hist.size > 0:
            try:
                hist, bins = np.histogram(A_hist, bins=hist_bins)
                result["histogram"] = hist
                result["bins"] = bins
            except:  # noqa: E722
                pass

        result["status"] = "ok"
    except Exception as e:
        result["status"] = f"error: {str(e)}"

    return result
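
# Example usage (illustrative sketch, not part of the module; `np.arange` just
# builds a small demo array):
#
#     info = array_info(np.arange(12).reshape(3, 4))
#     info["status"]  # -> "ok"
#     info["shape"]   # -> (3, 4)
#     info["mean"]    # -> 5.5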


def generate_sparkline(
    histogram: np.ndarray,
    format: Literal["unicode", "latex", "ascii"] = "unicode",
    log_y: bool = False,
) -> str:
    """Generate a sparkline visualization of the histogram.

    # Parameters:
    - `histogram : np.ndarray`
        Histogram data
    - `format : Literal["unicode", "latex", "ascii"]`
        Output format (defaults to `"unicode"`)
    - `log_y : bool`
        Whether to use logarithmic y-scale (defaults to `False`)

    # Returns:
    - `str`
        Sparkline visualization
    """
    if histogram is None or len(histogram) == 0:
        return ""

    # Get the appropriate character set
    if format in SPARK_CHARS:
        chars = SPARK_CHARS[format]
    else:
        chars = SPARK_CHARS["ascii"]

    # Handle log scale
    if log_y:
        # log1p computes log(1 + x), which avoids log(0)
        hist_data = np.log1p(histogram)
    else:
        hist_data = histogram

    # Normalize to character set range
    if hist_data.max() > 0:
        normalized = hist_data / hist_data.max() * (len(chars) - 1)
    else:
        normalized = np.zeros_like(hist_data)

    # Convert to characters
    spark = ""
    for val in normalized:
        idx = int(val)
        spark += chars[idx]

    return spark
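
# Example (illustrative, not part of the module): a histogram that ramps up to
# its maximum maps onto successively denser characters.
#
#     generate_sparkline(np.array([0, 1, 2, 4, 8]), format="unicode")
#     # -> ' ▁▂▄█'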


DEFAULT_SETTINGS: Dict[str, Any] = dict(
    fmt="unicode",
    precision=2,
    stats=True,
    shape=True,
    dtype=True,
    device=True,
    requires_grad=True,
    sparkline=False,
    sparkline_bins=5,
    sparkline_logy=False,
    colored=False,
    as_list=False,
    eq_char="=",
)
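
# Illustrative note (not part of the module): array_summary falls back to this
# dict for every argument left at _USE_DEFAULT, so defaults can be changed
# globally by mutating it, e.g.:
#
#     DEFAULT_SETTINGS["fmt"] = "ascii"
#     DEFAULT_SETTINGS["sparkline"] = True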


class _UseDefaultType:
    pass


_USE_DEFAULT = _UseDefaultType()


@overload
def array_summary(
    as_list: Literal[True],
    **kwargs,
) -> List[str]: ...
@overload
def array_summary(
    as_list: Literal[False],
    **kwargs,
) -> str: ...
def array_summary(  # type: ignore[misc]
    array,
    fmt: OutputFormat = _USE_DEFAULT,  # type: ignore[assignment]
    precision: int = _USE_DEFAULT,  # type: ignore[assignment]
    stats: bool = _USE_DEFAULT,  # type: ignore[assignment]
    shape: bool = _USE_DEFAULT,  # type: ignore[assignment]
    dtype: bool = _USE_DEFAULT,  # type: ignore[assignment]
    device: bool = _USE_DEFAULT,  # type: ignore[assignment]
    requires_grad: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_bins: int = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_logy: bool = _USE_DEFAULT,  # type: ignore[assignment]
    colored: bool = _USE_DEFAULT,  # type: ignore[assignment]
    eq_char: str = _USE_DEFAULT,  # type: ignore[assignment]
    as_list: bool = _USE_DEFAULT,  # type: ignore[assignment]
) -> Union[str, List[str]]:
346 """Format array information into a readable summary.
348 # Parameters:
349 - `array`
350 array-like object (numpy array or torch tensor)
351 - `precision : int`
352 Decimal places (defaults to `2`)
353 - `format : Literal["unicode", "latex", "ascii"]`
354 Output format (defaults to `{default_fmt}`)
355 - `stats : bool`
356 Whether to include statistical info (μ, σ, x̃) (defaults to `True`)
357 - `shape : bool`
358 Whether to include shape info (defaults to `True`)
359 - `dtype : bool`
360 Whether to include dtype info (defaults to `True`)
361 - `device : bool`
362 Whether to include device info for torch tensors (defaults to `True`)
363 - `requires_grad : bool`
364 Whether to include requires_grad info for torch tensors (defaults to `True`)
365 - `sparkline : bool`
366 Whether to include a sparkline visualization (defaults to `False`)
367 - `sparkline_width : int`
368 Width of the sparkline (defaults to `20`)
369 - `sparkline_logy : bool`
370 Whether to use logarithmic y-scale for sparkline (defaults to `False`)
371 - `colored : bool`
372 Whether to add color to output (defaults to `False`)
373 - `as_list : bool`
374 Whether to return as list of strings instead of joined string (defaults to `False`)
376 # Returns:
377 - `Union[str, List[str]]`
378 Formatted statistical summary, either as string or list of strings
379 """
    if fmt is _USE_DEFAULT:
        fmt = DEFAULT_SETTINGS["fmt"]
    if precision is _USE_DEFAULT:
        precision = DEFAULT_SETTINGS["precision"]
    if stats is _USE_DEFAULT:
        stats = DEFAULT_SETTINGS["stats"]
    if shape is _USE_DEFAULT:
        shape = DEFAULT_SETTINGS["shape"]
    if dtype is _USE_DEFAULT:
        dtype = DEFAULT_SETTINGS["dtype"]
    if device is _USE_DEFAULT:
        device = DEFAULT_SETTINGS["device"]
    if requires_grad is _USE_DEFAULT:
        requires_grad = DEFAULT_SETTINGS["requires_grad"]
    if sparkline is _USE_DEFAULT:
        sparkline = DEFAULT_SETTINGS["sparkline"]
    if sparkline_bins is _USE_DEFAULT:
        sparkline_bins = DEFAULT_SETTINGS["sparkline_bins"]
    if sparkline_logy is _USE_DEFAULT:
        sparkline_logy = DEFAULT_SETTINGS["sparkline_logy"]
    if colored is _USE_DEFAULT:
        colored = DEFAULT_SETTINGS["colored"]
    if as_list is _USE_DEFAULT:
        as_list = DEFAULT_SETTINGS["as_list"]
    if eq_char is _USE_DEFAULT:
        eq_char = DEFAULT_SETTINGS["eq_char"]

    array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins)
    result_parts: List[str] = []
    using_tex: bool = fmt == "latex"

    # Set color scheme based on format and colored flag
    colors: Dict[str, str]
    if colored:
        colors = COLORS["latex"] if using_tex else COLORS["terminal"]
    else:
        colors = COLORS["none"]

    # Get symbols for the current format
    symbols: Dict[str, str] = SYMBOLS[fmt]

    # Helper function to colorize text
    def colorize(text: str, color_key: str) -> str:
        if using_tex:
            # e.g. \textcolor{purple} wraps the text as \textcolor{purple}{text}
            return f"{colors[color_key]}{{{text}}}" if colors[color_key] else text
        else:
            return (
                f"{colors[color_key]}{text}{colors['reset']}"
                if colors[color_key]
                else text
            )

    # Format string for numbers
    float_fmt: str = f".{precision}f"

    # Handle error status or empty array
    if (
        array_data["status"] in ["empty array", "all NaN", "unknown"]
        or array_data["size"] == 0
    ):
        status = array_data["status"]
        result_parts.append(colorize(symbols["warning"] + " " + status, "warning"))
    else:
        # Add NaN warning at the beginning if there are NaNs
        if array_data["has_nans"]:
            _percent: str = "\\%" if using_tex else "%"
            nan_str: str = f"{symbols['warning']} {symbols['nan_values']}{eq_char}{array_data['nan_count']} ({array_data['nan_percent']:.1f}{_percent})"
            result_parts.append(colorize(nan_str, "warning"))

        # Statistics
        if stats:
            for stat_key in ["mean", "std", "median"]:
                if array_data[stat_key] is not None:
                    stat_str: str = f"{array_data[stat_key]:{float_fmt}}"
                    stat_colored: str = colorize(stat_str, stat_key)
                    result_parts.append(f"{symbols[stat_key]}={stat_colored}")

        # Range (min, max)
        if array_data["range"] is not None:
            min_val, max_val = array_data["range"]
            min_str: str = f"{min_val:{float_fmt}}"
            max_str: str = f"{max_val:{float_fmt}}"
            min_colored: str = colorize(min_str, "range")
            max_colored: str = colorize(max_str, "range")
            range_str: str = f"{symbols['range']}=[{min_colored},{max_colored}]"
            result_parts.append(range_str)

    # Add sparkline if requested
    if sparkline and array_data["histogram"] is not None:
        print(array_data["histogram"])
        print(array_data["bins"])
        spark = generate_sparkline(
            array_data["histogram"], format=fmt, log_y=sparkline_logy
        )
        if spark:
            spark_colored = colorize(spark, "sparkline")
            result_parts.append(f"{symbols['distribution']}{eq_char}|{spark_colored}|")

    # Add shape if requested
    if shape and array_data["shape"]:
        shape_val = array_data["shape"]
        if len(shape_val) == 1:
            shape_str = str(shape_val[0])
        else:
            shape_str = (
                "(" + ",".join(colorize(str(dim), "shape") for dim in shape_val) + ")"
            )
        result_parts.append(f"shape{eq_char}{shape_str}")

    # Add dtype if requested
    if dtype and array_data["dtype"]:
        result_parts.append(colorize(f"dtype={array_data['dtype']}", "dtype"))

    # Add device if requested and it's a tensor with device info
    if device and array_data["is_tensor"] and array_data["device"]:
        result_parts.append(
            colorize(f"device{eq_char}{array_data['device']}", "device")
        )

    # Add gradient info
    if requires_grad and array_data["is_tensor"] and array_data["requires_grad"]:
        result_parts.append(colorize(symbols["requires_grad"], "requires_grad"))

    # Return as list if requested, otherwise join with spaces
    if as_list:
        return result_parts
    else:
        joinchar: str = r" \quad " if using_tex else " "
        return joinchar.join(result_parts)
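
# Example usage (illustrative sketch, not part of the module; the exact numbers
# depend on the input array, so the output shown is only schematic):
#
#     x = np.random.default_rng(0).normal(size=1000)
#     print(array_summary(x, fmt="ascii"))
#     # e.g. "mean=0.01 std=0.98 med=0.02 range=[-3.00,3.00] shape=1000 dtype=float64"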