muutils.tensor_info
import numpy as np
from typing import Union, Any, Literal, List, Dict, overload

# Global color definitions
COLORS: Dict[str, Dict[str, str]] = {
    "latex": {
        "range": r"\textcolor{purple}",
        "mean": r"\textcolor{teal}",
        "std": r"\textcolor{orange}",
        "median": r"\textcolor{green}",
        "warning": r"\textcolor{red}",
        "shape": r"\textcolor{magenta}",
        "dtype": r"\textcolor{gray}",
        "device": r"\textcolor{gray}",
        "requires_grad": r"\textcolor{gray}",
        "sparkline": r"\textcolor{blue}",
        "reset": "",
    },
    "terminal": {
        "range": "\033[35m",  # purple
        "mean": "\033[36m",  # cyan/teal
        "std": "\033[33m",  # yellow/orange
        "median": "\033[32m",  # green
        "warning": "\033[31m",  # red
        "shape": "\033[95m",  # bright magenta
        "dtype": "\033[90m",  # gray
        "device": "\033[90m",  # gray
        "requires_grad": "\033[90m",  # gray
        "sparkline": "\033[34m",  # blue
        "reset": "\033[0m",
    },
    "none": {
        "range": "",
        "mean": "",
        "std": "",
        "median": "",
        "warning": "",
        "shape": "",
        "dtype": "",
        "device": "",
        "requires_grad": "",
        "sparkline": "",
        "reset": "",
    },
}

OutputFormat = Literal["unicode", "latex", "ascii"]

SYMBOLS: Dict[OutputFormat, Dict[str, str]] = {
    "latex": {
        "range": r"\mathcal{R}",
        "mean": r"\mu",
        "std": r"\sigma",
        "median": r"\tilde{x}",
        "distribution": r"\mathbb{P}",
        "nan_values": r"\texttt{NANvals}",
        "warning": "!!!",
        "requires_grad": r"\nabla",
    },
    "unicode": {
        "range": "R",
        "mean": "μ",
        "std": "σ",
        "median": "x̃",
        "distribution": "ℙ",
        "nan_values": "NANvals",
        "warning": "🚨",
        "requires_grad": "∇",
    },
    "ascii": {
        "range": "range",
        "mean": "mean",
        "std": "std",
        "median": "med",
        "distribution": "dist",
        "nan_values": "NANvals",
        "warning": "!!!",
        "requires_grad": "requires_grad",
    },
}
"Symbols for different formats"

SPARK_CHARS: Dict[OutputFormat, List[str]] = {
    "unicode": list(" ▁▂▃▄▅▆▇█"),
    "ascii": list(" _.-~=#"),
    "latex": list(" ▁▂▃▄▅▆▇█"),
    # "latex": [r"\textbf{.}", r"\textbf{-}", r"\textbf{=}", r"\textbf{+}", r"\textbf{*}", r"\textbf{\\#}"],
}
"characters for sparklines in different formats"


def array_info(
    A: Any,
    hist_bins: int = 5,
) -> Dict[str, Any]:
    """Extract statistical information from an array-like object.

    # Parameters:
    - `A : array-like`
        Array to analyze (numpy array or torch tensor)

    # Returns:
    - `Dict[str, Any]`
        Dictionary containing raw statistical information with numeric values
    """
    result: Dict[str, Any] = {
        "is_tensor": None,
        "device": None,
        "requires_grad": None,
        "shape": None,
        "dtype": None,
        "size": None,
        "has_nans": None,
        "nan_count": None,
        "nan_percent": None,
        "min": None,
        "max": None,
        "range": None,
        "mean": None,
        "std": None,
        "median": None,
        "histogram": None,
        "bins": None,
        "status": None,
    }

    # Check if it's a tensor by looking at its class name
    # This avoids importing torch directly
    A_type: str = type(A).__name__
    result["is_tensor"] = A_type == "Tensor"

    # Try to get device information if it's a tensor
    if result["is_tensor"]:
        try:
            result["device"] = str(getattr(A, "device", None))
        except:  # noqa: E722
            pass

    # Convert to numpy array for calculations
    try:
        # For PyTorch tensors
        if result["is_tensor"]:
            # Check if tensor is on GPU
            is_cuda: bool = False
            try:
                is_cuda = bool(getattr(A, "is_cuda", False))
            except:  # noqa: E722
                pass

            if is_cuda:
                try:
                    # Try to get CPU tensor first
                    cpu_tensor = getattr(A, "cpu", lambda: A)()
                except:  # noqa: E722
                    A_np = np.array([])
            else:
                cpu_tensor = A
            try:
                # For CPU tensor, just detach and convert
                detached = getattr(cpu_tensor, "detach", lambda: cpu_tensor)()
                A_np = getattr(detached, "numpy", lambda: np.array([]))()
            except:  # noqa: E722
                A_np = np.array([])
        else:
            # For numpy arrays and other array-like objects
            A_np = np.asarray(A)
    except:  # noqa: E722
        A_np = np.array([])

    # Get basic information
    try:
        result["shape"] = A_np.shape
        result["dtype"] = str(A.dtype if result["is_tensor"] else A_np.dtype)
        result["size"] = A_np.size
        result["requires_grad"] = getattr(A, "requires_grad", None)
    except:  # noqa: E722
        pass

    # If array is empty, return early
    if result["size"] == 0:
        result["status"] = "empty array"
        return result

    # Flatten array for statistics if it's multi-dimensional
    try:
        if len(A_np.shape) > 1:
            A_flat = A_np.flatten()
        else:
            A_flat = A_np
    except:  # noqa: E722
        A_flat = A_np

    # Check for NaN values
    try:
        nan_mask = np.isnan(A_flat)
        result["nan_count"] = np.sum(nan_mask)
        result["has_nans"] = result["nan_count"] > 0
        if result["size"] > 0:
            result["nan_percent"] = (result["nan_count"] / result["size"]) * 100
    except:  # noqa: E722
        pass

    # If all values are NaN, return early
    if result["has_nans"] and result["nan_count"] == result["size"]:
        result["status"] = "all NaN"
        return result

    # Calculate statistics
    try:
        if result["has_nans"]:
            result["min"] = float(np.nanmin(A_flat))
            result["max"] = float(np.nanmax(A_flat))
            result["mean"] = float(np.nanmean(A_flat))
            result["std"] = float(np.nanstd(A_flat))
            result["median"] = float(np.nanmedian(A_flat))
            result["range"] = (result["min"], result["max"])

            # Remove NaNs for histogram
            A_hist = A_flat[~nan_mask]
        else:
            result["min"] = float(np.min(A_flat))
            result["max"] = float(np.max(A_flat))
            result["mean"] = float(np.mean(A_flat))
            result["std"] = float(np.std(A_flat))
            result["median"] = float(np.median(A_flat))
            result["range"] = (result["min"], result["max"])

            A_hist = A_flat

        # Calculate histogram data for sparklines
        if A_hist.size > 0:
            try:
                hist, bins = np.histogram(A_hist, bins=hist_bins)
                result["histogram"] = hist
                result["bins"] = bins
            except:  # noqa: E722
                pass

        result["status"] = "ok"
    except Exception as e:
        result["status"] = f"error: {str(e)}"

    return result


def generate_sparkline(
    histogram: np.ndarray,
    format: Literal["unicode", "latex", "ascii"] = "unicode",
    log_y: bool = False,
) -> str:
    """Generate a sparkline visualization of the histogram.

    # Parameters:
    - `histogram : np.ndarray`
        Histogram data
    - `format : Literal["unicode", "latex", "ascii"]`
        Output format (defaults to `"unicode"`)
    - `log_y : bool`
        Whether to use logarithmic y-scale (defaults to `False`)

    # Returns:
    - `str`
        Sparkline visualization
    """
    if histogram is None or len(histogram) == 0:
        return ""

    # Get the appropriate character set
    if format in SPARK_CHARS:
        chars = SPARK_CHARS[format]
    else:
        chars = SPARK_CHARS["ascii"]

    # Handle log scale
    if log_y:
        # Add small value to avoid log(0)
        hist_data = np.log1p(histogram)
    else:
        hist_data = histogram

    # Normalize to character set range
    if hist_data.max() > 0:
        normalized = hist_data / hist_data.max() * (len(chars) - 1)
    else:
        normalized = np.zeros_like(hist_data)

    # Convert to characters
    spark = ""
    for val in normalized:
        idx = int(val)
        spark += chars[idx]

    return spark


DEFAULT_SETTINGS: Dict[str, Any] = dict(
    fmt="unicode",
    precision=2,
    stats=True,
    shape=True,
    dtype=True,
    device=True,
    requires_grad=True,
    sparkline=False,
    sparkline_bins=5,
    sparkline_logy=False,
    colored=False,
    as_list=False,
    eq_char="=",
)


class _UseDefaultType:
    pass


_USE_DEFAULT = _UseDefaultType()


@overload
def array_summary(
    as_list: Literal[True],
    **kwargs,
) -> List[str]: ...
@overload
def array_summary(
    as_list: Literal[False],
    **kwargs,
) -> str: ...
def array_summary(  # type: ignore[misc]
    array,
    fmt: OutputFormat = _USE_DEFAULT,  # type: ignore[assignment]
    precision: int = _USE_DEFAULT,  # type: ignore[assignment]
    stats: bool = _USE_DEFAULT,  # type: ignore[assignment]
    shape: bool = _USE_DEFAULT,  # type: ignore[assignment]
    dtype: bool = _USE_DEFAULT,  # type: ignore[assignment]
    device: bool = _USE_DEFAULT,  # type: ignore[assignment]
    requires_grad: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_bins: int = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_logy: bool = _USE_DEFAULT,  # type: ignore[assignment]
    colored: bool = _USE_DEFAULT,  # type: ignore[assignment]
    eq_char: str = _USE_DEFAULT,  # type: ignore[assignment]
    as_list: bool = _USE_DEFAULT,  # type: ignore[assignment]
) -> Union[str, List[str]]:
    """Format array information into a readable summary.

    # Parameters:
    - `array`
        array-like object (numpy array or torch tensor)
    - `precision : int`
        Decimal places (defaults to `2`)
    - `fmt : Literal["unicode", "latex", "ascii"]`
        Output format (defaults to `"unicode"`)
    - `stats : bool`
        Whether to include statistical info (μ, σ, x̃) (defaults to `True`)
    - `shape : bool`
        Whether to include shape info (defaults to `True`)
    - `dtype : bool`
        Whether to include dtype info (defaults to `True`)
    - `device : bool`
        Whether to include device info for torch tensors (defaults to `True`)
    - `requires_grad : bool`
        Whether to include requires_grad info for torch tensors (defaults to `True`)
    - `sparkline : bool`
        Whether to include a sparkline visualization (defaults to `False`)
    - `sparkline_bins : int`
        Number of histogram bins for the sparkline (defaults to `5`)
    - `sparkline_logy : bool`
        Whether to use logarithmic y-scale for sparkline (defaults to `False`)
    - `colored : bool`
        Whether to add color to output (defaults to `False`)
    - `eq_char : str`
        Character used between labels and values (defaults to `"="`)
    - `as_list : bool`
        Whether to return as list of strings instead of joined string (defaults to `False`)

    # Returns:
    - `Union[str, List[str]]`
        Formatted statistical summary, either as string or list of strings
    """
    if fmt is _USE_DEFAULT:
        fmt = DEFAULT_SETTINGS["fmt"]
    if precision is _USE_DEFAULT:
        precision = DEFAULT_SETTINGS["precision"]
    if stats is _USE_DEFAULT:
        stats = DEFAULT_SETTINGS["stats"]
    if shape is _USE_DEFAULT:
        shape = DEFAULT_SETTINGS["shape"]
    if dtype is _USE_DEFAULT:
        dtype = DEFAULT_SETTINGS["dtype"]
    if device is _USE_DEFAULT:
        device = DEFAULT_SETTINGS["device"]
    if requires_grad is _USE_DEFAULT:
        requires_grad = DEFAULT_SETTINGS["requires_grad"]
    if sparkline is _USE_DEFAULT:
        sparkline = DEFAULT_SETTINGS["sparkline"]
    if sparkline_bins is _USE_DEFAULT:
        sparkline_bins = DEFAULT_SETTINGS["sparkline_bins"]
    if sparkline_logy is _USE_DEFAULT:
        sparkline_logy = DEFAULT_SETTINGS["sparkline_logy"]
    if colored is _USE_DEFAULT:
        colored = DEFAULT_SETTINGS["colored"]
    if as_list is _USE_DEFAULT:
        as_list = DEFAULT_SETTINGS["as_list"]
    if eq_char is _USE_DEFAULT:
        eq_char = DEFAULT_SETTINGS["eq_char"]

    array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins)
    result_parts: List[str] = []
    using_tex: bool = fmt == "latex"

    # Set color scheme based on format and colored flag
    colors: Dict[str, str]
    if colored:
        colors = COLORS["latex"] if using_tex else COLORS["terminal"]
    else:
        colors = COLORS["none"]

    # Get symbols for the current format
    symbols: Dict[str, str] = SYMBOLS[fmt]

    # Helper function to colorize text
    def colorize(text: str, color_key: str) -> str:
        if using_tex:
            return f"{colors[color_key]}{{{text}}}" if colors[color_key] else text
        else:
            return (
                f"{colors[color_key]}{text}{colors['reset']}"
                if colors[color_key]
                else text
            )

    # Format string for numbers
    float_fmt: str = f".{precision}f"

    # Handle error status or empty array
    if (
        array_data["status"] in ["empty array", "all NaN", "unknown"]
        or array_data["size"] == 0
    ):
        status = array_data["status"]
        result_parts.append(colorize(symbols["warning"] + " " + status, "warning"))
    else:
        # Add NaN warning at the beginning if there are NaNs
        if array_data["has_nans"]:
            _percent: str = "\\%" if using_tex else "%"
            nan_str: str = f"{symbols['warning']} {symbols['nan_values']}{eq_char}{array_data['nan_count']} ({array_data['nan_percent']:.1f}{_percent})"
            result_parts.append(colorize(nan_str, "warning"))

        # Statistics
        if stats:
            for stat_key in ["mean", "std", "median"]:
                if array_data[stat_key] is not None:
                    stat_str: str = f"{array_data[stat_key]:{float_fmt}}"
                    stat_colored: str = colorize(stat_str, stat_key)
                    result_parts.append(f"{symbols[stat_key]}={stat_colored}")

            # Range (min, max)
            if array_data["range"] is not None:
                min_val, max_val = array_data["range"]
                min_str: str = f"{min_val:{float_fmt}}"
                max_str: str = f"{max_val:{float_fmt}}"
                min_colored: str = colorize(min_str, "range")
                max_colored: str = colorize(max_str, "range")
                range_str: str = f"{symbols['range']}=[{min_colored},{max_colored}]"
                result_parts.append(range_str)

    # Add sparkline if requested
    if sparkline and array_data["histogram"] is not None:
        spark = generate_sparkline(
            array_data["histogram"], format=fmt, log_y=sparkline_logy
        )
        if spark:
            spark_colored = colorize(spark, "sparkline")
            result_parts.append(f"{symbols['distribution']}{eq_char}|{spark_colored}|")

    # Add shape if requested
    if shape and array_data["shape"]:
        shape_val = array_data["shape"]
        if len(shape_val) == 1:
            shape_str = str(shape_val[0])
        else:
            shape_str = (
                "(" + ",".join(colorize(str(dim), "shape") for dim in shape_val) + ")"
            )
        result_parts.append(f"shape{eq_char}{shape_str}")

    # Add dtype if requested
    if dtype and array_data["dtype"]:
        result_parts.append(colorize(f"dtype={array_data['dtype']}", "dtype"))

    # Add device if requested and it's a tensor with device info
    if device and array_data["is_tensor"] and array_data["device"]:
        result_parts.append(
            colorize(f"device{eq_char}{array_data['device']}", "device")
        )

    # Add gradient info
    if requires_grad and array_data["is_tensor"] and array_data["requires_grad"]:
        result_parts.append(colorize(symbols["requires_grad"], "requires_grad"))

    # Return as list if requested, otherwise join with spaces
    if as_list:
        return result_parts
    else:
        joinchar: str = r" \quad " if using_tex else " "
        return joinchar.join(result_parts)
COLORS: Dict[str, Dict[str, str]] =
{'latex': {'range': '\\textcolor{purple}', 'mean': '\\textcolor{teal}', 'std': '\\textcolor{orange}', 'median': '\\textcolor{green}', 'warning': '\\textcolor{red}', 'shape': '\\textcolor{magenta}', 'dtype': '\\textcolor{gray}', 'device': '\\textcolor{gray}', 'requires_grad': '\\textcolor{gray}', 'sparkline': '\\textcolor{blue}', 'reset': ''}, 'terminal': {'range': '\x1b[35m', 'mean': '\x1b[36m', 'std': '\x1b[33m', 'median': '\x1b[32m', 'warning': '\x1b[31m', 'shape': '\x1b[95m', 'dtype': '\x1b[90m', 'device': '\x1b[90m', 'requires_grad': '\x1b[90m', 'sparkline': '\x1b[34m', 'reset': '\x1b[0m'}, 'none': {'range': '', 'mean': '', 'std': '', 'median': '', 'warning': '', 'shape': '', 'dtype': '', 'device': '', 'requires_grad': '', 'sparkline': '', 'reset': ''}}
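The terminal entries are plain ANSI escape sequences, so a value can be wrapped manually the same way the module's internal colorize helper does. A minimal sketch (the variable name is illustrative):

    from muutils.tensor_info import COLORS

    # wrap a formatted mean value in the "mean" terminal color, then reset
    colored_mean = COLORS["terminal"]["mean"] + "0.50" + COLORS["terminal"]["reset"]
    print(colored_mean)  # prints 0.50 in cyan on ANSI-capable terminals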
OutputFormat =
typing.Literal['unicode', 'latex', 'ascii']
SYMBOLS: Dict[Literal['unicode', 'latex', 'ascii'], Dict[str, str]] =
{'latex': {'range': '\\mathcal{R}', 'mean': '\\mu', 'std': '\\sigma', 'median': '\\tilde{x}', 'distribution': '\\mathbb{P}', 'nan_values': '\\texttt{NANvals}', 'warning': '!!!', 'requires_grad': '\\nabla'}, 'unicode': {'range': 'R', 'mean': 'μ', 'std': 'σ', 'median': 'x̃', 'distribution': 'ℙ', 'nan_values': 'NANvals', 'warning': '🚨', 'requires_grad': '∇'}, 'ascii': {'range': 'range', 'mean': 'mean', 'std': 'std', 'median': 'med', 'distribution': 'dist', 'nan_values': 'NANvals', 'warning': '!!!', 'requires_grad': 'requires_grad'}}
Symbols for different formats
SPARK_CHARS: Dict[Literal['unicode', 'latex', 'ascii'], List[str]] =
{'unicode': [' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█'], 'ascii': [' ', '_', '.', '-', '~', '=', '#'], 'latex': [' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█']}
characters for sparklines in different formats
def array_info(A: Any, hist_bins: int = 5) -> Dict[str, Any]:
Extract statistical information from an array-like object.
Parameters:
A : array-like
Array to analyze (numpy array or torch tensor)
Returns:
Dict[str, Any]
Dictionary containing raw statistical information with numeric values
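A short usage sketch (assuming numpy is installed; the exact statistics depend on the input array):

    import numpy as np
    from muutils.tensor_info import array_info

    info = array_info(np.array([[1.0, 2.0], [3.0, np.nan]]), hist_bins=5)
    print(info["shape"], info["dtype"])         # (2, 2) float64
    print(info["has_nans"], info["nan_count"])  # True 1
    print(info["mean"], info["range"])          # NaN-aware mean and (min, max)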
def generate_sparkline(histogram: numpy.ndarray, format: Literal['unicode', 'latex', 'ascii'] = 'unicode', log_y: bool = False) -> str:
Generate a sparkline visualization of the histogram.
Parameters:
histogram : np.ndarray
Histogram data
format : Literal["unicode", "latex", "ascii"]
Output format (defaults to "unicode")
log_y : bool
Whether to use logarithmic y-scale (defaults to False)
Returns:
str
Sparkline visualization
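A minimal sketch of how histogram counts map onto the character set (assuming numpy; the exact glyphs depend on the data and the chosen format):

    import numpy as np
    from muutils.tensor_info import generate_sparkline

    counts, _bins = np.histogram(np.random.randn(1000), bins=10)
    print(generate_sparkline(counts, format="unicode"))           # roughly bell-shaped, e.g. " ▁▃▆█▆▃▁  "
    print(generate_sparkline(counts, format="ascii", log_y=True))  # log-scaled, ascii characters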
DEFAULT_SETTINGS: Dict[str, Any] =
{'fmt': 'unicode', 'precision': 2, 'stats': True, 'shape': True, 'dtype': True, 'device': True, 'requires_grad': True, 'sparkline': False, 'sparkline_bins': 5, 'sparkline_logy': False, 'colored': False, 'as_list': False, 'eq_char': '='}
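Because array_summary looks these defaults up at call time, they can be overridden globally by mutating the dict. A small sketch (the particular keys changed here are just an example):

    from muutils.tensor_info import DEFAULT_SETTINGS

    DEFAULT_SETTINGS["fmt"] = "ascii"      # subsequent summaries use ascii symbols
    DEFAULT_SETTINGS["sparkline"] = True   # and include a sparkline by default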
def array_summary(array, fmt: Literal['unicode', 'latex', 'ascii'] = _USE_DEFAULT, precision: int = _USE_DEFAULT, stats: bool = _USE_DEFAULT, shape: bool = _USE_DEFAULT, dtype: bool = _USE_DEFAULT, device: bool = _USE_DEFAULT, requires_grad: bool = _USE_DEFAULT, sparkline: bool = _USE_DEFAULT, sparkline_bins: int = _USE_DEFAULT, sparkline_logy: bool = _USE_DEFAULT, colored: bool = _USE_DEFAULT, eq_char: str = _USE_DEFAULT, as_list: bool = _USE_DEFAULT) -> Union[str, List[str]]:
Format array information into a readable summary.
Parameters:
array
array-like object (numpy array or torch tensor)
precision : int
Decimal places (defaults to 2)
fmt : Literal["unicode", "latex", "ascii"]
Output format (defaults to "unicode")
stats : bool
Whether to include statistical info (μ, σ, x̃) (defaults to True)
shape : bool
Whether to include shape info (defaults to True)
dtype : bool
Whether to include dtype info (defaults to True)
device : bool
Whether to include device info for torch tensors (defaults to True)
requires_grad : bool
Whether to include requires_grad info for torch tensors (defaults to True)
sparkline : bool
Whether to include a sparkline visualization (defaults to False)
sparkline_bins : int
Number of histogram bins for the sparkline (defaults to 5)
sparkline_logy : bool
Whether to use logarithmic y-scale for sparkline (defaults to False)
colored : bool
Whether to add color to output (defaults to False)
eq_char : str
Character used between labels and values (defaults to "=")
as_list : bool
Whether to return as list of strings instead of joined string (defaults to False)
Returns:
Union[str, List[str]]
Formatted statistical summary, either as string or list of strings
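A usage sketch tying it together (assuming numpy; the numbers in the comment are illustrative, not exact output):

    import numpy as np
    from muutils.tensor_info import array_summary

    x = np.random.randn(100, 10)
    print(array_summary(x))
    # roughly: μ=-0.02 σ=1.01 x̃=-0.01 R=[-3.21,3.05] shape=(100,10) dtype=float64

    # as a list of parts, with a sparkline of the value distribution
    parts = array_summary(x, as_list=True, sparkline=True, sparkline_bins=10)
    print(parts)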