docs for muutils v0.8.10
View Source on GitHub

muutils.tensor_info


  1import numpy as np
  2from typing import Union, Any, Literal, List, Dict, overload
  3
  4# Global color definitions
  5COLORS: Dict[str, Dict[str, str]] = {
  6    "latex": {
  7        "range": r"\textcolor{purple}",
  8        "mean": r"\textcolor{teal}",
  9        "std": r"\textcolor{orange}",
 10        "median": r"\textcolor{green}",
 11        "warning": r"\textcolor{red}",
 12        "shape": r"\textcolor{magenta}",
 13        "dtype": r"\textcolor{gray}",
 14        "device": r"\textcolor{gray}",
 15        "requires_grad": r"\textcolor{gray}",
 16        "sparkline": r"\textcolor{blue}",
 17        "reset": "",
 18    },
 19    "terminal": {
 20        "range": "\033[35m",  # purple
 21        "mean": "\033[36m",  # cyan/teal
 22        "std": "\033[33m",  # yellow/orange
 23        "median": "\033[32m",  # green
 24        "warning": "\033[31m",  # red
 25        "shape": "\033[95m",  # bright magenta
 26        "dtype": "\033[90m",  # gray
 27        "device": "\033[90m",  # gray
 28        "requires_grad": "\033[90m",  # gray
 29        "sparkline": "\033[34m",  # blue
 30        "reset": "\033[0m",
 31    },
 32    "none": {
 33        "range": "",
 34        "mean": "",
 35        "std": "",
 36        "median": "",
 37        "warning": "",
 38        "shape": "",
 39        "dtype": "",
 40        "device": "",
 41        "requires_grad": "",
 42        "sparkline": "",
 43        "reset": "",
 44    },
 45}
 46
 47OutputFormat = Literal["unicode", "latex", "ascii"]
 48
 49SYMBOLS: Dict[OutputFormat, Dict[str, str]] = {
 50    "latex": {
 51        "range": r"\mathcal{R}",
 52        "mean": r"\mu",
 53        "std": r"\sigma",
 54        "median": r"\tilde{x}",
 55        "distribution": r"\mathbb{P}",
 56        "nan_values": r"\text{NANvals}",
 57        "warning": "!!!",
 58        "requires_grad": r"\nabla",
 59        "true": r"\checkmark",
 60        "false": r"\times",
 61    },
 62    "unicode": {
 63        "range": "R",
 64        "mean": "μ",
 65        "std": "σ",
 66        "median": "x̃",
 67        "distribution": "ℙ",
 68        "nan_values": "NANvals",
 69        "warning": "🚨",
 70        "requires_grad": "∇",
 71        "true": "✓",
 72        "false": "✗",
 73    },
 74    "ascii": {
 75        "range": "range",
 76        "mean": "mean",
 77        "std": "std",
 78        "median": "med",
 79        "distribution": "dist",
 80        "nan_values": "NANvals",
 81        "warning": "!!!",
 82        "requires_grad": "requires_grad",
 83        "true": "1",
 84        "false": "0",
 85    },
 86}
 87"Symbols for different formats"
 88
 89SPARK_CHARS: Dict[OutputFormat, List[str]] = {
 90    "unicode": list(" ▁▂▃▄▅▆▇█"),
 91    "ascii": list(" _.-~=#"),
 92    "latex": list(" ▁▂▃▄▅▆▇█"),
 93}
 94"characters for sparklines in different formats"
 95
 96
 97def array_info(
 98    A: Any,
 99    hist_bins: int = 5,
100) -> Dict[str, Any]:
101    """Extract statistical information from an array-like object.
102
103    # Parameters:
104     - `A : array-like`
105            Array to analyze (numpy array or torch tensor)
106
107    # Returns:
108     - `Dict[str, Any]`
109            Dictionary containing raw statistical information with numeric values
110    """
111    result: Dict[str, Any] = {
112        "is_tensor": None,
113        "device": None,
114        "requires_grad": None,
115        "shape": None,
116        "dtype": None,
117        "size": None,
118        "has_nans": None,
119        "nan_count": None,
120        "nan_percent": None,
121        "min": None,
122        "max": None,
123        "range": None,
124        "mean": None,
125        "std": None,
126        "median": None,
127        "histogram": None,
128        "bins": None,
129        "status": None,
130    }
131
132    # Check if it's a tensor by looking at its class name
133    # This avoids importing torch directly
134    A_type: str = type(A).__name__
135    result["is_tensor"] = A_type == "Tensor"
136
137    # Try to get device information if it's a tensor
138    if result["is_tensor"]:
139        try:
140            result["device"] = str(getattr(A, "device", None))
141        except:  # noqa: E722
142            pass
143
144    # Convert to numpy array for calculations
145    try:
146        # For PyTorch tensors
147        if result["is_tensor"]:
148            # Check if tensor is on GPU
149            is_cuda: bool = False
150            try:
151                is_cuda = bool(getattr(A, "is_cuda", False))
152            except:  # noqa: E722
153                pass
154
155            if is_cuda:
156                try:
157                    # Try to get CPU tensor first
158                    cpu_tensor = getattr(A, "cpu", lambda: A)()
159                except:  # noqa: E722
160                    A_np = np.array([])
161            else:
162                cpu_tensor = A
163            try:
164                # For CPU tensor, just detach and convert
165                detached = getattr(cpu_tensor, "detach", lambda: cpu_tensor)()
166                A_np = getattr(detached, "numpy", lambda: np.array([]))()
167            except:  # noqa: E722
168                A_np = np.array([])
169        else:
170            # For numpy arrays and other array-like objects
171            A_np = np.asarray(A)
172    except:  # noqa: E722
173        A_np = np.array([])
174
175    # Get basic information
176    try:
177        result["shape"] = A_np.shape
178        result["dtype"] = str(A.dtype if result["is_tensor"] else A_np.dtype)
179        result["size"] = A_np.size
180        result["requires_grad"] = getattr(A, "requires_grad", None)
181    except:  # noqa: E722
182        pass
183
184    # If array is empty, return early
185    if result["size"] == 0:
186        result["status"] = "empty array"
187        return result
188
189    # Flatten array for statistics if it's multi-dimensional
190    try:
191        if len(A_np.shape) > 1:
192            A_flat = A_np.flatten()
193        else:
194            A_flat = A_np
195    except:  # noqa: E722
196        A_flat = A_np
197
198    # Check for NaN values
199    try:
200        nan_mask = np.isnan(A_flat)
201        result["nan_count"] = np.sum(nan_mask)
202        result["has_nans"] = result["nan_count"] > 0
203        if result["size"] > 0:
204            result["nan_percent"] = (result["nan_count"] / result["size"]) * 100
205    except:  # noqa: E722
206        pass
207
208    # If all values are NaN, return early
209    if result["has_nans"] and result["nan_count"] == result["size"]:
210        result["status"] = "all NaN"
211        return result
212
213    # Calculate statistics
214    try:
215        if result["has_nans"]:
216            result["min"] = float(np.nanmin(A_flat))
217            result["max"] = float(np.nanmax(A_flat))
218            result["mean"] = float(np.nanmean(A_flat))
219            result["std"] = float(np.nanstd(A_flat))
220            result["median"] = float(np.nanmedian(A_flat))
221            result["range"] = (result["min"], result["max"])
222
223            # Remove NaNs for histogram
224            A_hist = A_flat[~nan_mask]
225        else:
226            result["min"] = float(np.min(A_flat))
227            result["max"] = float(np.max(A_flat))
228            result["mean"] = float(np.mean(A_flat))
229            result["std"] = float(np.std(A_flat))
230            result["median"] = float(np.median(A_flat))
231            result["range"] = (result["min"], result["max"])
232
233            A_hist = A_flat
234
235        # Calculate histogram data for sparklines
236        if A_hist.size > 0:
237            try:
238                hist, bins = np.histogram(A_hist, bins=hist_bins)
239                result["histogram"] = hist
240                result["bins"] = bins
241            except:  # noqa: E722
242                pass
243
244        result["status"] = "ok"
245    except Exception as e:
246        result["status"] = f"error: {str(e)}"
247
248    return result
249
250
def generate_sparkline(
    histogram: np.ndarray,
    format: Literal["unicode", "latex", "ascii"] = "unicode",
    log_y: bool = False,
) -> str:
    """Generate a sparkline visualization of the histogram.

    Each bin becomes one character of `SPARK_CHARS[format]` (unknown formats fall back
    to the `"ascii"` set). Heights are scaled so the tallest bin gets the last character
    and are then rounded down. As a result, bins that are small relative to the tallest
    render as the first character (a space), just like empty bins.

    # Parameters:
     - `histogram : np.ndarray`
            Histogram bin counts (any 1-D array-like of numbers)
     - `format : Literal["unicode", "latex", "ascii"]`
            Output format (defaults to `"unicode"`)
     - `log_y : bool`
            Whether to use logarithmic y-scale (defaults to `False`)

    # Returns:
     - `str`
            Sparkline visualization, one character per bin; `""` for a missing or
            empty histogram
    """
    if histogram is None or len(histogram) == 0:
        return ""

    chars: List[str] = SPARK_CHARS.get(format, SPARK_CHARS["ascii"])
    heights = np.asarray(histogram, dtype=float)

    if log_y:
        # log1p(x) = log(1 + x): avoids log(0) and keeps empty bins at height 0
        heights = np.log1p(heights)

    top = heights.max()
    if top > 0:
        levels = (heights / top * (len(chars) - 1)).astype(int)
    else:
        levels = np.zeros(len(heights), dtype=int)
    # negative inputs would otherwise index `chars` from the end
    levels = np.clip(levels, 0, len(chars) - 1)

    return "".join(chars[level] for level in levels)
299
300
# Fallback value for every `array_summary` option. `array_summary` looks an option up
# here whenever the caller leaves it as the `_USE_DEFAULT` sentinel, at call time, so
# mutating this dict changes the defaults of later calls.
DEFAULT_SETTINGS: Dict[str, Any] = {
    "fmt": "unicode",
    "precision": 2,
    "stats": True,
    "shape": True,
    "dtype": True,
    "device": True,
    "requires_grad": True,
    "sparkline": False,
    "sparkline_bins": 5,
    "sparkline_logy": False,
    "colored": False,
    "as_list": False,
    "eq_char": "=",
}


class _UseDefaultType:
    """Type of the `_USE_DEFAULT` sentinel, which marks an argument as "not given"."""


_USE_DEFAULT = _UseDefaultType()
323
324
@overload
def array_summary(
    array: Any,
    *args: Any,
    as_list: Literal[True],
    **kwargs: Any,
) -> List[str]: ...
@overload
def array_summary(
    array: Any,
    *args: Any,
    as_list: Literal[False] = ...,
    **kwargs: Any,
) -> str: ...
def array_summary(  # type: ignore[misc]
    array,
    fmt: OutputFormat = _USE_DEFAULT,  # type: ignore[assignment]
    precision: int = _USE_DEFAULT,  # type: ignore[assignment]
    stats: bool = _USE_DEFAULT,  # type: ignore[assignment]
    shape: bool = _USE_DEFAULT,  # type: ignore[assignment]
    dtype: bool = _USE_DEFAULT,  # type: ignore[assignment]
    device: bool = _USE_DEFAULT,  # type: ignore[assignment]
    requires_grad: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_bins: int = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_logy: bool = _USE_DEFAULT,  # type: ignore[assignment]
    colored: bool = _USE_DEFAULT,  # type: ignore[assignment]
    eq_char: str = _USE_DEFAULT,  # type: ignore[assignment]
    as_list: bool = _USE_DEFAULT,  # type: ignore[assignment]
) -> Union[str, List[str]]:
    """Format array information into a readable summary.

    Every option that is not passed explicitly is read from `DEFAULT_SETTINGS` at call
    time; the defaults listed below are the values shipped in that dict.

    # Parameters:
     - `array`
            array-like object (numpy array or torch tensor)
     - `fmt : OutputFormat`
            Output format, one of `"unicode"`, `"latex"`, `"ascii"` (defaults to `"unicode"`)
     - `precision : int`
            Decimal places for the statistics (defaults to `2`)
     - `stats : bool`
            Whether to include statistical info (μ, σ, x̃, range) (defaults to `True`)
     - `shape : bool`
            Whether to include shape info (defaults to `True`)
     - `dtype : bool`
            Whether to include dtype info (defaults to `True`)
     - `device : bool`
            Whether to include device info for torch tensors (defaults to `True`)
     - `requires_grad : bool`
            Whether to include requires_grad info for torch tensors (defaults to `True`)
     - `sparkline : bool`
            Whether to include a sparkline visualization (defaults to `False`)
     - `sparkline_bins : int`
            Number of histogram bins, i.e. characters, in the sparkline (defaults to `5`)
     - `sparkline_logy : bool`
            Whether to use logarithmic y-scale for sparkline (defaults to `False`)
     - `colored : bool`
            Whether to add color to output: ANSI escapes, or `textcolor` commands for
            LaTeX (defaults to `False`)
     - `eq_char : str`
            Separator placed between each label and its value (defaults to `"="`)
     - `as_list : bool`
            Whether to return as list of strings instead of joined string (defaults to `False`)

    # Returns:
     - `Union[str, List[str]]`
            Formatted statistical summary, either as a list of parts or as one string
            (parts joined with spaces, or with a LaTeX `quad` for `fmt="latex"`).
            When no statistics are available (empty, all-NaN, unconvertible input, or an
            error), the summary starts with a warning naming the reason.
    """

    def _resolve(value: Any, key: str) -> Any:
        # arguments left as the sentinel fall back to the module-level defaults
        return DEFAULT_SETTINGS[key] if value is _USE_DEFAULT else value

    fmt = _resolve(fmt, "fmt")
    precision = _resolve(precision, "precision")
    stats = _resolve(stats, "stats")
    shape = _resolve(shape, "shape")
    dtype = _resolve(dtype, "dtype")
    device = _resolve(device, "device")
    requires_grad = _resolve(requires_grad, "requires_grad")
    sparkline = _resolve(sparkline, "sparkline")
    sparkline_bins = _resolve(sparkline_bins, "sparkline_bins")
    sparkline_logy = _resolve(sparkline_logy, "sparkline_logy")
    colored = _resolve(colored, "colored")
    as_list = _resolve(as_list, "as_list")
    eq_char = _resolve(eq_char, "eq_char")

    array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins)
    result_parts: List[str] = []
    using_tex: bool = fmt == "latex"

    # Set color scheme based on format and colored flag
    colors: Dict[str, str]
    if colored:
        colors = COLORS["latex"] if using_tex else COLORS["terminal"]
    else:
        colors = COLORS["none"]

    symbols: Dict[str, str] = SYMBOLS[fmt]

    def colorize(text: str, color_key: str) -> str:
        color = colors[color_key]
        if not color:
            return text
        if using_tex:
            # LaTeX color commands take the colored text as a braced argument
            return f"{color}{{{text}}}"
        return f"{color}{text}{colors['reset']}"

    float_fmt: str = f".{precision}f"

    status = array_data["status"]
    if status != "ok":
        # "empty array", "all NaN", "unknown" or "error: ...": there are no statistics,
        # so surface the reason instead of silently omitting them.
        result_parts.append(colorize(f"{symbols['warning']} {status}", "warning"))
    else:
        # NaN warning comes first so it is hard to miss
        if array_data["has_nans"]:
            _percent: str = "\\%" if using_tex else "%"
            nan_str: str = f"{symbols['warning']} {symbols['nan_values']}{eq_char}{array_data['nan_count']} ({array_data['nan_percent']:.1f}{_percent})"
            result_parts.append(colorize(nan_str, "warning"))

        if stats:
            for stat_key in ("mean", "std", "median"):
                value = array_data[stat_key]
                if value is not None:
                    value_str: str = colorize(f"{value:{float_fmt}}", stat_key)
                    result_parts.append(f"{symbols[stat_key]}{eq_char}{value_str}")

            if array_data["range"] is not None:
                min_val, max_val = array_data["range"]
                min_colored: str = colorize(f"{min_val:{float_fmt}}", "range")
                max_colored: str = colorize(f"{max_val:{float_fmt}}", "range")
                result_parts.append(
                    f"{symbols['range']}{eq_char}[{min_colored},{max_colored}]"
                )

    if sparkline and array_data["histogram"] is not None:
        spark = generate_sparkline(
            array_data["histogram"], format=fmt, log_y=sparkline_logy
        )
        if spark:
            spark_colored = colorize(spark, "sparkline")
            result_parts.append(f"{symbols['distribution']}{eq_char}|{spark_colored}|")

    if shape and array_data["shape"]:
        dims: List[str] = [colorize(str(dim), "shape") for dim in array_data["shape"]]
        # 1-D shapes are shown as a bare length, higher ranks as a tuple
        shape_str = dims[0] if len(dims) == 1 else "(" + ",".join(dims) + ")"
        result_parts.append(f"shape{eq_char}{shape_str}")

    if dtype and array_data["dtype"]:
        result_parts.append(colorize(f"dtype{eq_char}{array_data['dtype']}", "dtype"))

    if device and array_data["is_tensor"] and array_data["device"]:
        result_parts.append(
            colorize(f"device{eq_char}{array_data['device']}", "device")
        )

    if requires_grad and array_data["is_tensor"]:
        bool_req_grad_symb: str = (
            symbols["true"] if array_data["requires_grad"] else symbols["false"]
        )
        result_parts.append(
            colorize(symbols["requires_grad"] + bool_req_grad_symb, "requires_grad")
        )

    if as_list:
        return result_parts
    joinchar: str = r" \quad " if using_tex else " "
    return joinchar.join(result_parts)

# Color markup per mode: "latex" (textcolor commands), "terminal" (ANSI escape codes, closed by "reset"), and "none" (empty strings, i.e. no coloring).
COLORS: Dict[str, Dict[str, str]] = {'latex': {'range': '\\textcolor{purple}', 'mean': '\\textcolor{teal}', 'std': '\\textcolor{orange}', 'median': '\\textcolor{green}', 'warning': '\\textcolor{red}', 'shape': '\\textcolor{magenta}', 'dtype': '\\textcolor{gray}', 'device': '\\textcolor{gray}', 'requires_grad': '\\textcolor{gray}', 'sparkline': '\\textcolor{blue}', 'reset': ''}, 'terminal': {'range': '\x1b[35m', 'mean': '\x1b[36m', 'std': '\x1b[33m', 'median': '\x1b[32m', 'warning': '\x1b[31m', 'shape': '\x1b[95m', 'dtype': '\x1b[90m', 'device': '\x1b[90m', 'requires_grad': '\x1b[90m', 'sparkline': '\x1b[34m', 'reset': '\x1b[0m'}, 'none': {'range': '', 'mean': '', 'std': '', 'median': '', 'warning': '', 'shape': '', 'dtype': '', 'device': '', 'requires_grad': '', 'sparkline': '', 'reset': ''}}
# The three supported output formats.
OutputFormat = typing.Literal['unicode', 'latex', 'ascii']
# Per-format labels for statistics, boolean flags, and warnings.
SYMBOLS: Dict[Literal['unicode', 'latex', 'ascii'], Dict[str, str]] = {'latex': {'range': '\\mathcal{R}', 'mean': '\\mu', 'std': '\\sigma', 'median': '\\tilde{x}', 'distribution': '\\mathbb{P}', 'nan_values': '\\text{NANvals}', 'warning': '!!!', 'requires_grad': '\\nabla', 'true': '\\checkmark', 'false': '\\times'}, 'unicode': {'range': 'R', 'mean': 'μ', 'std': 'σ', 'median': 'x̃', 'distribution': 'ℙ', 'nan_values': 'NANvals', 'warning': '🚨', 'requires_grad': '∇', 'true': '✓', 'false': '✗'}, 'ascii': {'range': 'range', 'mean': 'mean', 'std': 'std', 'median': 'med', 'distribution': 'dist', 'nan_values': 'NANvals', 'warning': '!!!', 'requires_grad': 'requires_grad', 'true': '1', 'false': '0'}}

Symbols (labels for statistics, boolean flags, and warnings) used in each output format.

# Per-format sparkline characters, ordered from empty (a space) to full.
SPARK_CHARS: Dict[Literal['unicode', 'latex', 'ascii'], List[str]] = {'unicode': [' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█'], 'ascii': [' ', '_', '.', '-', '~', '=', '#'], 'latex': [' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█']}

Characters used to draw sparklines in each output format, ordered from empty (a space) to full.

def array_info(A: Any, hist_bins: int = 5) -> Dict[str, Any]:
 98def array_info(
 99    A: Any,
100    hist_bins: int = 5,
101) -> Dict[str, Any]:
102    """Extract statistical information from an array-like object.
103
104    # Parameters:
105     - `A : array-like`
106            Array to analyze (numpy array or torch tensor)
107
108    # Returns:
109     - `Dict[str, Any]`
110            Dictionary containing raw statistical information with numeric values
111    """
112    result: Dict[str, Any] = {
113        "is_tensor": None,
114        "device": None,
115        "requires_grad": None,
116        "shape": None,
117        "dtype": None,
118        "size": None,
119        "has_nans": None,
120        "nan_count": None,
121        "nan_percent": None,
122        "min": None,
123        "max": None,
124        "range": None,
125        "mean": None,
126        "std": None,
127        "median": None,
128        "histogram": None,
129        "bins": None,
130        "status": None,
131    }
132
133    # Check if it's a tensor by looking at its class name
134    # This avoids importing torch directly
135    A_type: str = type(A).__name__
136    result["is_tensor"] = A_type == "Tensor"
137
138    # Try to get device information if it's a tensor
139    if result["is_tensor"]:
140        try:
141            result["device"] = str(getattr(A, "device", None))
142        except:  # noqa: E722
143            pass
144
145    # Convert to numpy array for calculations
146    try:
147        # For PyTorch tensors
148        if result["is_tensor"]:
149            # Check if tensor is on GPU
150            is_cuda: bool = False
151            try:
152                is_cuda = bool(getattr(A, "is_cuda", False))
153            except:  # noqa: E722
154                pass
155
156            if is_cuda:
157                try:
158                    # Try to get CPU tensor first
159                    cpu_tensor = getattr(A, "cpu", lambda: A)()
160                except:  # noqa: E722
161                    A_np = np.array([])
162            else:
163                cpu_tensor = A
164            try:
165                # For CPU tensor, just detach and convert
166                detached = getattr(cpu_tensor, "detach", lambda: cpu_tensor)()
167                A_np = getattr(detached, "numpy", lambda: np.array([]))()
168            except:  # noqa: E722
169                A_np = np.array([])
170        else:
171            # For numpy arrays and other array-like objects
172            A_np = np.asarray(A)
173    except:  # noqa: E722
174        A_np = np.array([])
175
176    # Get basic information
177    try:
178        result["shape"] = A_np.shape
179        result["dtype"] = str(A.dtype if result["is_tensor"] else A_np.dtype)
180        result["size"] = A_np.size
181        result["requires_grad"] = getattr(A, "requires_grad", None)
182    except:  # noqa: E722
183        pass
184
185    # If array is empty, return early
186    if result["size"] == 0:
187        result["status"] = "empty array"
188        return result
189
190    # Flatten array for statistics if it's multi-dimensional
191    try:
192        if len(A_np.shape) > 1:
193            A_flat = A_np.flatten()
194        else:
195            A_flat = A_np
196    except:  # noqa: E722
197        A_flat = A_np
198
199    # Check for NaN values
200    try:
201        nan_mask = np.isnan(A_flat)
202        result["nan_count"] = np.sum(nan_mask)
203        result["has_nans"] = result["nan_count"] > 0
204        if result["size"] > 0:
205            result["nan_percent"] = (result["nan_count"] / result["size"]) * 100
206    except:  # noqa: E722
207        pass
208
209    # If all values are NaN, return early
210    if result["has_nans"] and result["nan_count"] == result["size"]:
211        result["status"] = "all NaN"
212        return result
213
214    # Calculate statistics
215    try:
216        if result["has_nans"]:
217            result["min"] = float(np.nanmin(A_flat))
218            result["max"] = float(np.nanmax(A_flat))
219            result["mean"] = float(np.nanmean(A_flat))
220            result["std"] = float(np.nanstd(A_flat))
221            result["median"] = float(np.nanmedian(A_flat))
222            result["range"] = (result["min"], result["max"])
223
224            # Remove NaNs for histogram
225            A_hist = A_flat[~nan_mask]
226        else:
227            result["min"] = float(np.min(A_flat))
228            result["max"] = float(np.max(A_flat))
229            result["mean"] = float(np.mean(A_flat))
230            result["std"] = float(np.std(A_flat))
231            result["median"] = float(np.median(A_flat))
232            result["range"] = (result["min"], result["max"])
233
234            A_hist = A_flat
235
236        # Calculate histogram data for sparklines
237        if A_hist.size > 0:
238            try:
239                hist, bins = np.histogram(A_hist, bins=hist_bins)
240                result["histogram"] = hist
241                result["bins"] = bins
242            except:  # noqa: E722
243                pass
244
245        result["status"] = "ok"
246    except Exception as e:
247        result["status"] = f"error: {str(e)}"
248
249    return result

Extract statistical information from an array-like object.

Parameters:

  • A : array-like Array to analyze (numpy array or torch tensor)
  • hist_bins : int Number of histogram bins computed for the sparkline data (defaults to 5)

Returns:

  • Dict[str, Any] Dictionary containing raw statistical information with numeric values
def generate_sparkline( histogram: numpy.ndarray, format: Literal['unicode', 'latex', 'ascii'] = 'unicode', log_y: bool = False) -> str:
def generate_sparkline(
    histogram: np.ndarray,
    format: Literal["unicode", "latex", "ascii"] = "unicode",
    log_y: bool = False,
) -> str:
    """Generate a sparkline visualization of the histogram.

    Each bin becomes one character of `SPARK_CHARS[format]` (unknown formats fall back
    to the `"ascii"` set). Heights are scaled so the tallest bin gets the last character
    and are then rounded down. As a result, bins that are small relative to the tallest
    render as the first character (a space), just like empty bins.

    # Parameters:
     - `histogram : np.ndarray`
            Histogram bin counts (any 1-D array-like of numbers)
     - `format : Literal["unicode", "latex", "ascii"]`
            Output format (defaults to `"unicode"`)
     - `log_y : bool`
            Whether to use logarithmic y-scale (defaults to `False`)

    # Returns:
     - `str`
            Sparkline visualization, one character per bin; `""` for a missing or
            empty histogram
    """
    if histogram is None or len(histogram) == 0:
        return ""

    chars: List[str] = SPARK_CHARS.get(format, SPARK_CHARS["ascii"])
    heights = np.asarray(histogram, dtype=float)

    if log_y:
        # log1p(x) = log(1 + x): avoids log(0) and keeps empty bins at height 0
        heights = np.log1p(heights)

    top = heights.max()
    if top > 0:
        levels = (heights / top * (len(chars) - 1)).astype(int)
    else:
        levels = np.zeros(len(heights), dtype=int)
    # negative inputs would otherwise index `chars` from the end
    levels = np.clip(levels, 0, len(chars) - 1)

    return "".join(chars[level] for level in levels)

Generate a sparkline visualization of the histogram.

Parameters:

  • histogram : np.ndarray Histogram data
  • format : Literal["unicode", "latex", "ascii"] Output format (defaults to "unicode")
  • log_y : bool Whether to use logarithmic y-scale (defaults to False)

Returns:

  • str Sparkline visualization
# Fallback values for array_summary options that the caller does not pass explicitly.
DEFAULT_SETTINGS: Dict[str, Any] = {'fmt': 'unicode', 'precision': 2, 'stats': True, 'shape': True, 'dtype': True, 'device': True, 'requires_grad': True, 'sparkline': False, 'sparkline_bins': 5, 'sparkline_logy': False, 'colored': False, 'as_list': False, 'eq_char': '='}
def array_summary( array, fmt: Literal['unicode', 'latex', 'ascii'] = <muutils.tensor_info._UseDefaultType object>, precision: int = <muutils.tensor_info._UseDefaultType object>, stats: bool = <muutils.tensor_info._UseDefaultType object>, shape: bool = <muutils.tensor_info._UseDefaultType object>, dtype: bool = <muutils.tensor_info._UseDefaultType object>, device: bool = <muutils.tensor_info._UseDefaultType object>, requires_grad: bool = <muutils.tensor_info._UseDefaultType object>, sparkline: bool = <muutils.tensor_info._UseDefaultType object>, sparkline_bins: int = <muutils.tensor_info._UseDefaultType object>, sparkline_logy: bool = <muutils.tensor_info._UseDefaultType object>, colored: bool = <muutils.tensor_info._UseDefaultType object>, eq_char: str = <muutils.tensor_info._UseDefaultType object>, as_list: bool = <muutils.tensor_info._UseDefaultType object>) -> Union[str, List[str]]:
def array_summary(  # type: ignore[misc]
    array,
    fmt: OutputFormat = _USE_DEFAULT,  # type: ignore[assignment]
    precision: int = _USE_DEFAULT,  # type: ignore[assignment]
    stats: bool = _USE_DEFAULT,  # type: ignore[assignment]
    shape: bool = _USE_DEFAULT,  # type: ignore[assignment]
    dtype: bool = _USE_DEFAULT,  # type: ignore[assignment]
    device: bool = _USE_DEFAULT,  # type: ignore[assignment]
    requires_grad: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_bins: int = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_logy: bool = _USE_DEFAULT,  # type: ignore[assignment]
    colored: bool = _USE_DEFAULT,  # type: ignore[assignment]
    eq_char: str = _USE_DEFAULT,  # type: ignore[assignment]
    as_list: bool = _USE_DEFAULT,  # type: ignore[assignment]
) -> Union[str, List[str]]:
    """Format array information into a readable summary.

    Every argument left unspecified is read from the module-level
    `DEFAULT_SETTINGS` dict at call time; the defaults listed below are the
    values that dict ships with.

    # Parameters:
     - `array`
            array-like object (numpy array or torch tensor)
     - `fmt : Literal["unicode", "latex", "ascii"]`
            Output format (defaults to `"unicode"`)
     - `precision : int`
            Decimal places (defaults to `2`)
     - `stats : bool`
            Whether to include statistical info (μ, σ, x̃) (defaults to `True`)
     - `shape : bool`
            Whether to include shape info (defaults to `True`)
     - `dtype : bool`
            Whether to include dtype info (defaults to `True`)
     - `device : bool`
            Whether to include device info for torch tensors (defaults to `True`)
     - `requires_grad : bool`
            Whether to include requires_grad info for torch tensors (defaults to `True`)
     - `sparkline : bool`
            Whether to include a sparkline visualization (defaults to `False`)
     - `sparkline_bins : int`
            Number of histogram bins requested from `array_info`; the
            sparkline has one character per bin (defaults to `5`)
     - `sparkline_logy : bool`
            Whether to use logarithmic y-scale for sparkline (defaults to `False`)
     - `colored : bool`
            Whether to add color to output (defaults to `False`)
     - `eq_char : str`
            Separator placed between each label and its value (defaults to `"="`)
     - `as_list : bool`
            Whether to return as list of strings instead of joined string (defaults to `False`)

    # Returns:
     - `Union[str, List[str]]`
            Formatted statistical summary, either as string or list of strings
    """
    if fmt is _USE_DEFAULT:
        fmt = DEFAULT_SETTINGS["fmt"]
    if precision is _USE_DEFAULT:
        precision = DEFAULT_SETTINGS["precision"]
    if stats is _USE_DEFAULT:
        stats = DEFAULT_SETTINGS["stats"]
    if shape is _USE_DEFAULT:
        shape = DEFAULT_SETTINGS["shape"]
    if dtype is _USE_DEFAULT:
        dtype = DEFAULT_SETTINGS["dtype"]
    if device is _USE_DEFAULT:
        device = DEFAULT_SETTINGS["device"]
    if requires_grad is _USE_DEFAULT:
        requires_grad = DEFAULT_SETTINGS["requires_grad"]
    if sparkline is _USE_DEFAULT:
        sparkline = DEFAULT_SETTINGS["sparkline"]
    if sparkline_bins is _USE_DEFAULT:
        sparkline_bins = DEFAULT_SETTINGS["sparkline_bins"]
    if sparkline_logy is _USE_DEFAULT:
        sparkline_logy = DEFAULT_SETTINGS["sparkline_logy"]
    if colored is _USE_DEFAULT:
        colored = DEFAULT_SETTINGS["colored"]
    if as_list is _USE_DEFAULT:
        as_list = DEFAULT_SETTINGS["as_list"]
    if eq_char is _USE_DEFAULT:
        eq_char = DEFAULT_SETTINGS["eq_char"]

    array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins)
    result_parts: List[str] = []
    using_tex: bool = fmt == "latex"

    # Set color scheme based on format and colored flag
    colors: Dict[str, str]
    if colored:
        colors = COLORS["latex"] if using_tex else COLORS["terminal"]
    else:
        colors = COLORS["none"]

    # Get symbols for the current format (raises KeyError for an unknown fmt)
    symbols: Dict[str, str] = SYMBOLS[fmt]

    # Helper function to colorize text: LaTeX wraps the text in a
    # `\textcolor{..}{text}` group, terminals bracket it with ANSI codes.
    def colorize(text: str, color_key: str) -> str:
        if using_tex:
            return f"{colors[color_key]}{{{text}}}" if colors[color_key] else text
        else:
            return (
                f"{colors[color_key]}{text}{colors['reset']}"
                if colors[color_key]
                else text
            )

    # Format string for numbers
    float_fmt: str = f".{precision}f"

    # Handle error status or empty array
    if (
        array_data["status"] in ["empty array", "all NaN", "unknown"]
        or array_data["size"] == 0
    ):
        status = array_data["status"]
        result_parts.append(colorize(symbols["warning"] + " " + status, "warning"))
    else:
        # Add NaN warning at the beginning if there are NaNs
        if array_data["has_nans"]:
            # `%` starts a comment in LaTeX, so it must be escaped there
            _percent: str = "\\%" if using_tex else "%"
            nan_str: str = f"{symbols['warning']} {symbols['nan_values']}{eq_char}{array_data['nan_count']} ({array_data['nan_percent']:.1f}{_percent})"
            result_parts.append(colorize(nan_str, "warning"))

        # Statistics
        if stats:
            for stat_key in ["mean", "std", "median"]:
                if array_data[stat_key] is not None:
                    stat_str: str = f"{array_data[stat_key]:{float_fmt}}"
                    stat_colored: str = colorize(stat_str, stat_key)
                    # honour eq_char like every other labelled entry
                    result_parts.append(f"{symbols[stat_key]}{eq_char}{stat_colored}")

            # Range (min, max)
            if array_data["range"] is not None:
                min_val, max_val = array_data["range"]
                min_str: str = f"{min_val:{float_fmt}}"
                max_str: str = f"{max_val:{float_fmt}}"
                min_colored: str = colorize(min_str, "range")
                max_colored: str = colorize(max_str, "range")
                range_str: str = (
                    f"{symbols['range']}{eq_char}[{min_colored},{max_colored}]"
                )
                result_parts.append(range_str)

    # Add sparkline if requested
    if sparkline and array_data["histogram"] is not None:
        spark = generate_sparkline(
            array_data["histogram"], format=fmt, log_y=sparkline_logy
        )
        if spark:
            spark_colored = colorize(spark, "sparkline")
            result_parts.append(f"{symbols['distribution']}{eq_char}|{spark_colored}|")

    # Add shape if requested
    if shape and array_data["shape"]:
        shape_val = array_data["shape"]
        if len(shape_val) == 1:
            # 1-D shapes are shown bare (no parentheses) but still colorized,
            # matching the multi-dimensional case below
            shape_str = colorize(str(shape_val[0]), "shape")
        else:
            shape_str = (
                "(" + ",".join(colorize(str(dim), "shape") for dim in shape_val) + ")"
            )
        result_parts.append(f"shape{eq_char}{shape_str}")

    # Add dtype if requested
    if dtype and array_data["dtype"]:
        result_parts.append(colorize(f"dtype{eq_char}{array_data['dtype']}", "dtype"))

    # Add device if requested and it's a tensor with device info
    if device and array_data["is_tensor"] and array_data["device"]:
        result_parts.append(
            colorize(f"device{eq_char}{array_data['device']}", "device")
        )

    # Add gradient info
    if requires_grad and array_data["is_tensor"]:
        bool_req_grad_symb: str = (
            symbols["true"] if array_data["requires_grad"] else symbols["false"]
        )
        result_parts.append(
            colorize(symbols["requires_grad"] + bool_req_grad_symb, "requires_grad")
        )

    # Return as list if requested, otherwise join with spaces
    if as_list:
        return result_parts
    else:
        joinchar: str = r" \quad " if using_tex else " "
        return joinchar.join(result_parts)

Format array information into a readable summary.

Parameters:

  • array array-like object (numpy array or torch tensor)
  • precision : int Decimal places (defaults to 2)
  • fmt : Literal["unicode", "latex", "ascii"] Output format (defaults to "unicode")
  • stats : bool Whether to include statistical info (μ, σ, x̃) (defaults to True)
  • shape : bool Whether to include shape info (defaults to True)
  • dtype : bool Whether to include dtype info (defaults to True)
  • device : bool Whether to include device info for torch tensors (defaults to True)
  • requires_grad : bool Whether to include requires_grad info for torch tensors (defaults to True)
  • sparkline : bool Whether to include a sparkline visualization (defaults to False)
  • sparkline_bins : int Number of histogram bins used for the sparkline, one character per bin (defaults to 5)
  • sparkline_logy : bool Whether to use logarithmic y-scale for sparkline (defaults to False)
  • colored : bool Whether to add color to output (defaults to False)
  • eq_char : str Separator placed between each label and its value (defaults to "=")
  • as_list : bool Whether to return as list of strings instead of joined string (defaults to False)

Returns:

  • Union[str, List[str]] Formatted statistical summary, either as string or list of strings