docs for muutils v0.8.5
View Source on GitHub

muutils.tensor_info


  1import numpy as np
  2from typing import Union, Any, Literal, List, Dict, overload
  3
  4# Global color definitions
  5COLORS: Dict[str, Dict[str, str]] = {
  6    "latex": {
  7        "range": r"\textcolor{purple}",
  8        "mean": r"\textcolor{teal}",
  9        "std": r"\textcolor{orange}",
 10        "median": r"\textcolor{green}",
 11        "warning": r"\textcolor{red}",
 12        "shape": r"\textcolor{magenta}",
 13        "dtype": r"\textcolor{gray}",
 14        "device": r"\textcolor{gray}",
 15        "requires_grad": r"\textcolor{gray}",
 16        "sparkline": r"\textcolor{blue}",
 17        "reset": "",
 18    },
 19    "terminal": {
 20        "range": "\033[35m",  # purple
 21        "mean": "\033[36m",  # cyan/teal
 22        "std": "\033[33m",  # yellow/orange
 23        "median": "\033[32m",  # green
 24        "warning": "\033[31m",  # red
 25        "shape": "\033[95m",  # bright magenta
 26        "dtype": "\033[90m",  # gray
 27        "device": "\033[90m",  # gray
 28        "requires_grad": "\033[90m",  # gray
 29        "sparkline": "\033[34m",  # blue
 30        "reset": "\033[0m",
 31    },
 32    "none": {
 33        "range": "",
 34        "mean": "",
 35        "std": "",
 36        "median": "",
 37        "warning": "",
 38        "shape": "",
 39        "dtype": "",
 40        "device": "",
 41        "requires_grad": "",
 42        "sparkline": "",
 43        "reset": "",
 44    },
 45}
 46
 47OutputFormat = Literal["unicode", "latex", "ascii"]
 48
 49SYMBOLS: Dict[OutputFormat, Dict[str, str]] = {
 50    "latex": {
 51        "range": r"\mathcal{R}",
 52        "mean": r"\mu",
 53        "std": r"\sigma",
 54        "median": r"\tilde{x}",
 55        "distribution": r"\mathbb{P}",
 56        "nan_values": r"\texttt{NANvals}",
 57        "warning": "!!!",
 58        "requires_grad": r"\nabla",
 59    },
 60    "unicode": {
 61        "range": "R",
 62        "mean": "μ",
 63        "std": "σ",
 64        "median": "x̃",
 65        "distribution": "ℙ",
 66        "nan_values": "NANvals",
 67        "warning": "🚨",
 68        "requires_grad": "∇",
 69    },
 70    "ascii": {
 71        "range": "range",
 72        "mean": "mean",
 73        "std": "std",
 74        "median": "med",
 75        "distribution": "dist",
 76        "nan_values": "NANvals",
 77        "warning": "!!!",
 78        "requires_grad": "requires_grad",
 79    },
 80}
 81"Symbols for different formats"
 82
 83SPARK_CHARS: Dict[OutputFormat, List[str]] = {
 84    "unicode": list(" ▁▂▃▄▅▆▇█"),
 85    "ascii": list(" _.-~=#"),
 86    "latex": list(" ▁▂▃▄▅▆▇█"),
 87    # "latex": [r"\textbf{.}", r"\textbf{-}", r"\textbf{=}", r"\textbf{+}", r"\textbf{*}", r"\textbf{\\#}"],
 88}
 89"characters for sparklines in different formats"
 90
 91
 92def array_info(
 93    A: Any,
 94    hist_bins: int = 5,
 95) -> Dict[str, Any]:
 96    """Extract statistical information from an array-like object.
 97
 98    # Parameters:
 99     - `A : array-like`
100            Array to analyze (numpy array or torch tensor)
101
102    # Returns:
103     - `Dict[str, Any]`
104            Dictionary containing raw statistical information with numeric values
105    """
106    result: Dict[str, Any] = {
107        "is_tensor": None,
108        "device": None,
109        "requires_grad": None,
110        "shape": None,
111        "dtype": None,
112        "size": None,
113        "has_nans": None,
114        "nan_count": None,
115        "nan_percent": None,
116        "min": None,
117        "max": None,
118        "range": None,
119        "mean": None,
120        "std": None,
121        "median": None,
122        "histogram": None,
123        "bins": None,
124        "status": None,
125    }
126
127    # Check if it's a tensor by looking at its class name
128    # This avoids importing torch directly
129    A_type: str = type(A).__name__
130    result["is_tensor"] = A_type == "Tensor"
131
132    # Try to get device information if it's a tensor
133    if result["is_tensor"]:
134        try:
135            result["device"] = str(getattr(A, "device", None))
136        except:  # noqa: E722
137            pass
138
139    # Convert to numpy array for calculations
140    try:
141        # For PyTorch tensors
142        if result["is_tensor"]:
143            # Check if tensor is on GPU
144            is_cuda: bool = False
145            try:
146                is_cuda = bool(getattr(A, "is_cuda", False))
147            except:  # noqa: E722
148                pass
149
150            if is_cuda:
151                try:
152                    # Try to get CPU tensor first
153                    cpu_tensor = getattr(A, "cpu", lambda: A)()
154                except:  # noqa: E722
155                    A_np = np.array([])
156            else:
157                cpu_tensor = A
158            try:
159                # For CPU tensor, just detach and convert
160                detached = getattr(cpu_tensor, "detach", lambda: cpu_tensor)()
161                A_np = getattr(detached, "numpy", lambda: np.array([]))()
162            except:  # noqa: E722
163                A_np = np.array([])
164        else:
165            # For numpy arrays and other array-like objects
166            A_np = np.asarray(A)
167    except:  # noqa: E722
168        A_np = np.array([])
169
170    # Get basic information
171    try:
172        result["shape"] = A_np.shape
173        result["dtype"] = str(A.dtype if result["is_tensor"] else A_np.dtype)
174        result["size"] = A_np.size
175        result["requires_grad"] = getattr(A, "requires_grad", None)
176    except:  # noqa: E722
177        pass
178
179    # If array is empty, return early
180    if result["size"] == 0:
181        result["status"] = "empty array"
182        return result
183
184    # Flatten array for statistics if it's multi-dimensional
185    try:
186        if len(A_np.shape) > 1:
187            A_flat = A_np.flatten()
188        else:
189            A_flat = A_np
190    except:  # noqa: E722
191        A_flat = A_np
192
193    # Check for NaN values
194    try:
195        nan_mask = np.isnan(A_flat)
196        result["nan_count"] = np.sum(nan_mask)
197        result["has_nans"] = result["nan_count"] > 0
198        if result["size"] > 0:
199            result["nan_percent"] = (result["nan_count"] / result["size"]) * 100
200    except:  # noqa: E722
201        pass
202
203    # If all values are NaN, return early
204    if result["has_nans"] and result["nan_count"] == result["size"]:
205        result["status"] = "all NaN"
206        return result
207
208    # Calculate statistics
209    try:
210        if result["has_nans"]:
211            result["min"] = float(np.nanmin(A_flat))
212            result["max"] = float(np.nanmax(A_flat))
213            result["mean"] = float(np.nanmean(A_flat))
214            result["std"] = float(np.nanstd(A_flat))
215            result["median"] = float(np.nanmedian(A_flat))
216            result["range"] = (result["min"], result["max"])
217
218            # Remove NaNs for histogram
219            A_hist = A_flat[~nan_mask]
220        else:
221            result["min"] = float(np.min(A_flat))
222            result["max"] = float(np.max(A_flat))
223            result["mean"] = float(np.mean(A_flat))
224            result["std"] = float(np.std(A_flat))
225            result["median"] = float(np.median(A_flat))
226            result["range"] = (result["min"], result["max"])
227
228            A_hist = A_flat
229
230        # Calculate histogram data for sparklines
231        if A_hist.size > 0:
232            try:
233                hist, bins = np.histogram(A_hist, bins=hist_bins)
234                result["histogram"] = hist
235                result["bins"] = bins
236            except:  # noqa: E722
237                pass
238
239        result["status"] = "ok"
240    except Exception as e:
241        result["status"] = f"error: {str(e)}"
242
243    return result
244
245
def generate_sparkline(
    histogram: np.ndarray,
    format: Literal["unicode", "latex", "ascii"] = "unicode",
    log_y: bool = False,
) -> str:
    """Generate a sparkline visualization of the histogram.

    Each bin becomes one character of `SPARK_CHARS[format]`. Heights are
    scaled linearly (or via `log1p` when `log_y` is set) so the tallest bin
    maps to the last character; fractional levels are truncated toward zero.

    # Parameters:
     - `histogram : np.ndarray`
            Histogram counts (any sequence of non-negative numbers is accepted,
            including a plain list; it is flattened)
     - `format : Literal["unicode", "latex", "ascii"]`
            Output format (defaults to `"unicode"`); unknown formats fall back
            to the `"ascii"` character set
     - `log_y : bool`
            Whether to use logarithmic y-scale (defaults to `False`)

    # Returns:
     - `str`
            Sparkline visualization, one character per bin (empty string for
            `None` or empty input)
    """
    if histogram is None:
        return ""
    counts = np.asarray(histogram, dtype=float).ravel()
    if counts.size == 0:
        return ""

    chars = SPARK_CHARS.get(format, SPARK_CHARS["ascii"])
    top_level = len(chars) - 1

    if log_y:
        # log1p(x) = log(1 + x): compresses large counts while keeping 0 -> 0
        counts = np.log1p(counts)

    peak = counts.max()
    if peak > 0:
        # clip guards against out-of-range indices from unexpected negatives
        levels = np.clip((counts / peak * top_level).astype(int), 0, top_level)
    else:
        levels = np.zeros(counts.shape, dtype=int)

    return "".join(chars[level] for level in levels)
294
295
# Fallback values for every `array_summary` option that is not passed
# explicitly. `array_summary` reads this dict at call time, so mutating it
# changes the defaults of subsequent calls.
DEFAULT_SETTINGS: Dict[str, Any] = dict(
    fmt="unicode",
    precision=2,
    stats=True,
    shape=True,
    dtype=True,
    device=True,
    requires_grad=True,
    sparkline=False,
    sparkline_bins=5,
    sparkline_logy=False,
    colored=False,
    as_list=False,
    eq_char="=",
)


class _UseDefaultType:
    """Sentinel type marking an `array_summary` argument as "not passed"."""

    __slots__ = ()

    def __repr__(self) -> str:
        # Shown by `help()` and rendered signatures instead of the default
        # `<..._UseDefaultType object at 0x...>`
        return "_USE_DEFAULT"


_USE_DEFAULT = _UseDefaultType()
318
319
@overload
def array_summary(
    array: Any,
    *,
    as_list: Literal[True],
    **kwargs: Any,
) -> List[str]: ...
@overload
def array_summary(
    array: Any,
    *,
    as_list: Literal[False],
    **kwargs: Any,
) -> str: ...
def array_summary(  # type: ignore[misc]
    array,
    fmt: OutputFormat = _USE_DEFAULT,  # type: ignore[assignment]
    precision: int = _USE_DEFAULT,  # type: ignore[assignment]
    stats: bool = _USE_DEFAULT,  # type: ignore[assignment]
    shape: bool = _USE_DEFAULT,  # type: ignore[assignment]
    dtype: bool = _USE_DEFAULT,  # type: ignore[assignment]
    device: bool = _USE_DEFAULT,  # type: ignore[assignment]
    requires_grad: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_bins: int = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_logy: bool = _USE_DEFAULT,  # type: ignore[assignment]
    colored: bool = _USE_DEFAULT,  # type: ignore[assignment]
    eq_char: str = _USE_DEFAULT,  # type: ignore[assignment]
    as_list: bool = _USE_DEFAULT,  # type: ignore[assignment]
) -> Union[str, List[str]]:
    """Format array information into a readable summary.

    Every option left unspecified is looked up in `DEFAULT_SETTINGS` at call
    time; the defaults listed below are that dict's initial values.

    # Parameters:
     - `array`
            array-like object (numpy array or torch tensor)
     - `fmt : Literal["unicode", "latex", "ascii"]`
            Output format (defaults to `"unicode"`)
     - `precision : int`
            Decimal places for statistics (defaults to `2`)
     - `stats : bool`
            Whether to include statistical info (mean, std, median, range)
            (defaults to `True`)
     - `shape : bool`
            Whether to include shape info (defaults to `True`)
     - `dtype : bool`
            Whether to include dtype info (defaults to `True`)
     - `device : bool`
            Whether to include device info for torch tensors (defaults to `True`)
     - `requires_grad : bool`
            Whether to include requires_grad info for torch tensors
            (defaults to `True`)
     - `sparkline : bool`
            Whether to include a sparkline visualization (defaults to `False`)
     - `sparkline_bins : int`
            Number of histogram bins, i.e. sparkline characters (defaults to `5`)
     - `sparkline_logy : bool`
            Whether to use logarithmic y-scale for sparkline (defaults to `False`)
     - `colored : bool`
            Whether to add color to output (defaults to `False`)
     - `eq_char : str`
            Separator placed between each label and its value (defaults to `"="`)
     - `as_list : bool`
            Whether to return as list of strings instead of joined string
            (defaults to `False`)

    # Returns:
     - `Union[str, List[str]]`
            Formatted statistical summary, either as string or list of strings.
            Parts are joined with a space, or with `\\quad` for LaTeX.
    """

    def _setting(value: Any, key: str) -> Any:
        # `_USE_DEFAULT` means "not passed": read DEFAULT_SETTINGS now so that
        # changes to that dict apply without redefining the function
        return DEFAULT_SETTINGS[key] if value is _USE_DEFAULT else value

    fmt = _setting(fmt, "fmt")
    precision = _setting(precision, "precision")
    stats = _setting(stats, "stats")
    shape = _setting(shape, "shape")
    dtype = _setting(dtype, "dtype")
    device = _setting(device, "device")
    requires_grad = _setting(requires_grad, "requires_grad")
    sparkline = _setting(sparkline, "sparkline")
    sparkline_bins = _setting(sparkline_bins, "sparkline_bins")
    sparkline_logy = _setting(sparkline_logy, "sparkline_logy")
    colored = _setting(colored, "colored")
    as_list = _setting(as_list, "as_list")
    eq_char = _setting(eq_char, "eq_char")

    array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins)
    result_parts: List[str] = []
    using_tex: bool = fmt == "latex"

    # Color scheme: LaTeX \textcolor prefixes, ANSI codes, or nothing at all
    colors: Dict[str, str]
    if colored:
        colors = COLORS["latex"] if using_tex else COLORS["terminal"]
    else:
        colors = COLORS["none"]

    symbols: Dict[str, str] = SYMBOLS[fmt]

    def colorize(text: str, color_key: str) -> str:
        """Wrap `text` in the color for `color_key` (no-op when uncolored)."""
        color = colors[color_key]
        if not color:
            return text
        if using_tex:
            # \textcolor{c} takes the colored text as a brace group
            return f"{color}{{{text}}}"
        return f"{color}{text}{colors['reset']}"

    float_fmt: str = f".{precision}f"

    status = array_data["status"]
    if status != "ok":
        # empty array, all-NaN array, or an error while computing statistics
        result_parts.append(colorize(f"{symbols['warning']} {status}", "warning"))
    else:
        # NaN warning goes first so it is not missed
        if array_data["has_nans"]:
            percent_sign: str = "\\%" if using_tex else "%"
            nan_str: str = (
                f"{symbols['warning']} {symbols['nan_values']}{eq_char}"
                f"{array_data['nan_count']} "
                f"({array_data['nan_percent']:.1f}{percent_sign})"
            )
            result_parts.append(colorize(nan_str, "warning"))

        if stats:
            for stat_key in ("mean", "std", "median"):
                value = array_data[stat_key]
                if value is not None:
                    value_str = colorize(format(value, float_fmt), stat_key)
                    result_parts.append(f"{symbols[stat_key]}{eq_char}{value_str}")

            if array_data["range"] is not None:
                min_str, max_str = (
                    colorize(format(bound, float_fmt), "range")
                    for bound in array_data["range"]
                )
                result_parts.append(
                    f"{symbols['range']}{eq_char}[{min_str},{max_str}]"
                )

    if sparkline and array_data["histogram"] is not None:
        spark = generate_sparkline(
            array_data["histogram"], format=fmt, log_y=sparkline_logy
        )
        if spark:
            spark_colored = colorize(spark, "sparkline")
            result_parts.append(f"{symbols['distribution']}{eq_char}|{spark_colored}|")

    if shape and array_data["shape"]:
        dims = [colorize(str(dim), "shape") for dim in array_data["shape"]]
        # 1-D shapes are shown as a bare length, others as a tuple
        shape_str = dims[0] if len(dims) == 1 else "(" + ",".join(dims) + ")"
        result_parts.append(f"shape{eq_char}{shape_str}")

    if dtype and array_data["dtype"]:
        result_parts.append(colorize(f"dtype{eq_char}{array_data['dtype']}", "dtype"))

    if device and array_data["is_tensor"] and array_data["device"]:
        result_parts.append(
            colorize(f"device{eq_char}{array_data['device']}", "device")
        )

    if requires_grad and array_data["is_tensor"] and array_data["requires_grad"]:
        result_parts.append(colorize(symbols["requires_grad"], "requires_grad"))

    if as_list:
        return result_parts
    joinchar: str = r" \quad " if using_tex else " "
    return joinchar.join(result_parts)

# Rendered value of `COLORS`: color prefix per target ("latex", "terminal", "none") and summary element.
COLORS: Dict[str, Dict[str, str]] = {'latex': {'range': '\\textcolor{purple}', 'mean': '\\textcolor{teal}', 'std': '\\textcolor{orange}', 'median': '\\textcolor{green}', 'warning': '\\textcolor{red}', 'shape': '\\textcolor{magenta}', 'dtype': '\\textcolor{gray}', 'device': '\\textcolor{gray}', 'requires_grad': '\\textcolor{gray}', 'sparkline': '\\textcolor{blue}', 'reset': ''}, 'terminal': {'range': '\x1b[35m', 'mean': '\x1b[36m', 'std': '\x1b[33m', 'median': '\x1b[32m', 'warning': '\x1b[31m', 'shape': '\x1b[95m', 'dtype': '\x1b[90m', 'device': '\x1b[90m', 'requires_grad': '\x1b[90m', 'sparkline': '\x1b[34m', 'reset': '\x1b[0m'}, 'none': {'range': '', 'mean': '', 'std': '', 'median': '', 'warning': '', 'shape': '', 'dtype': '', 'device': '', 'requires_grad': '', 'sparkline': '', 'reset': ''}}
# Rendered value of `OutputFormat`: the accepted output formats.
OutputFormat = typing.Literal['unicode', 'latex', 'ascii']
# Rendered value of `SYMBOLS`: label text for each statistic, per output format.
SYMBOLS: Dict[Literal['unicode', 'latex', 'ascii'], Dict[str, str]] = {'latex': {'range': '\\mathcal{R}', 'mean': '\\mu', 'std': '\\sigma', 'median': '\\tilde{x}', 'distribution': '\\mathbb{P}', 'nan_values': '\\texttt{NANvals}', 'warning': '!!!', 'requires_grad': '\\nabla'}, 'unicode': {'range': 'R', 'mean': 'μ', 'std': 'σ', 'median': 'x̃', 'distribution': 'ℙ', 'nan_values': 'NANvals', 'warning': '🚨', 'requires_grad': '∇'}, 'ascii': {'range': 'range', 'mean': 'mean', 'std': 'std', 'median': 'med', 'distribution': 'dist', 'nan_values': 'NANvals', 'warning': '!!!', 'requires_grad': 'requires_grad'}}

Symbols for different formats

# Rendered value of `SPARK_CHARS`: sparkline glyphs per format, ordered from empty bin (index 0) to tallest bin.
SPARK_CHARS: Dict[Literal['unicode', 'latex', 'ascii'], List[str]] = {'unicode': [' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█'], 'ascii': [' ', '_', '.', '-', '~', '=', '#'], 'latex': [' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█']}

characters for sparklines in different formats

def array_info(
    A: Any,
    hist_bins: int = 5,
) -> Dict[str, Any]:
    """Extract statistical information from an array-like object.

    Best-effort: failures are not raised. Any field that cannot be computed
    stays `None`, and the overall outcome is recorded under `"status"`.

    # Parameters:
     - `A : array-like`
            Array to analyze (numpy array, torch tensor, or anything accepted
            by `np.asarray`). Tensors are recognized by their class name being
            `"Tensor"`, so torch is never imported here.
     - `hist_bins : int`
            Number of bins passed to `np.histogram` (defaults to `5`)

    # Returns:
     - `Dict[str, Any]`
            Dictionary containing raw statistical information with numeric
            values. `"status"` is `"ok"`, `"empty array"`, `"all NaN"`, or
            `"error: <message>"`. If `A` cannot be converted to a numpy array
            at all, it is treated as an empty array.
    """
    result: Dict[str, Any] = {
        "is_tensor": None,
        "device": None,
        "requires_grad": None,
        "shape": None,
        "dtype": None,
        "size": None,
        "has_nans": None,
        "nan_count": None,
        "nan_percent": None,
        "min": None,
        "max": None,
        "range": None,
        "mean": None,
        "std": None,
        "median": None,
        "histogram": None,
        "bins": None,
        "status": None,
    }

    # Detect tensors by class name so torch never has to be imported here
    result["is_tensor"] = type(A).__name__ == "Tensor"

    if result["is_tensor"]:
        try:
            result["device"] = str(getattr(A, "device", None))
        except Exception:
            pass

    # Convert to a numpy array; on any failure fall back to an empty array
    A_np: np.ndarray
    try:
        if result["is_tensor"]:
            tensor = A
            try:
                if getattr(A, "is_cuda", False):
                    # numpy conversion needs the data in host memory
                    tensor = A.cpu()
            except Exception:
                pass  # the conversion below then fails and is caught
            detached = getattr(tensor, "detach", lambda: tensor)()
            try:
                A_np = np.asarray(detached.numpy())
            except TypeError:
                # numpy has no counterpart for some tensor dtypes (e.g.
                # bfloat16); upcast to float32 and retry
                A_np = np.asarray(detached.float().numpy())
        else:
            A_np = np.asarray(A)
    except Exception:
        A_np = np.array([])

    # Basic metadata: shape and size always come from the converted array,
    # so a failing dtype/requires_grad lookup cannot leave them unset
    result["shape"] = A_np.shape
    result["size"] = A_np.size
    try:
        # keep the framework's own dtype name for tensors (e.g. "torch.float32")
        result["dtype"] = str(A.dtype if result["is_tensor"] else A_np.dtype)
    except Exception:
        result["dtype"] = str(A_np.dtype)
    try:
        result["requires_grad"] = getattr(A, "requires_grad", None)
    except Exception:
        pass

    if result["size"] == 0:
        result["status"] = "empty array"
        return result

    # Statistics are computed over all elements, whatever the shape
    A_flat: np.ndarray = A_np.ravel()

    # NaN bookkeeping; dtypes without NaN support (strings, objects) make
    # np.isnan raise TypeError and leave these fields as None
    nan_mask: Any = None
    try:
        nan_mask = np.isnan(A_flat)
    except TypeError:
        pass
    else:
        nan_count = int(np.count_nonzero(nan_mask))
        result["nan_count"] = nan_count
        result["has_nans"] = nan_count > 0
        result["nan_percent"] = nan_count / result["size"] * 100

    if result["has_nans"] and result["nan_count"] == result["size"]:
        result["status"] = "all NaN"
        return result

    try:
        reducers: Dict[str, Any]
        if result["has_nans"]:
            reducers = {
                "min": np.nanmin,
                "max": np.nanmax,
                "mean": np.nanmean,
                "std": np.nanstd,
                "median": np.nanmedian,
            }
            # np.histogram cannot bin NaN values, so drop them
            A_hist = A_flat[~nan_mask]
        else:
            reducers = {
                "min": np.min,
                "max": np.max,
                "mean": np.mean,
                "std": np.std,
                "median": np.median,
            }
            A_hist = A_flat
        for key, reduce_fn in reducers.items():
            result[key] = float(reduce_fn(A_flat))
        result["range"] = (result["min"], result["max"])

        # Histogram data for sparklines
        if A_hist.size > 0:
            try:
                hist, bins = np.histogram(A_hist, bins=hist_bins)
            except Exception:
                # e.g. infinite values give a non-finite autodetected range,
                # or the dtype cannot be binned; the histogram is optional
                pass
            else:
                result["histogram"] = hist
                result["bins"] = bins

        result["status"] = "ok"
    except Exception as e:
        result["status"] = f"error: {e}"

    return result

Extract statistical information from an array-like object.

Parameters:

  • A : array-like Array to analyze (numpy array or torch tensor)
  • hist_bins : int Number of bins used for the histogram data (defaults to 5)

Returns:

  • Dict[str, Any] Dictionary containing raw statistical information with numeric values
def generate_sparkline(
    histogram: np.ndarray,
    format: Literal["unicode", "latex", "ascii"] = "unicode",
    log_y: bool = False,
) -> str:
    """Generate a sparkline visualization of the histogram.

    Each bin becomes one character of `SPARK_CHARS[format]`. Heights are
    scaled linearly (or via `log1p` when `log_y` is set) so the tallest bin
    maps to the last character; fractional levels are truncated toward zero.

    # Parameters:
     - `histogram : np.ndarray`
            Histogram counts (any sequence of non-negative numbers is accepted,
            including a plain list; it is flattened)
     - `format : Literal["unicode", "latex", "ascii"]`
            Output format (defaults to `"unicode"`); unknown formats fall back
            to the `"ascii"` character set
     - `log_y : bool`
            Whether to use logarithmic y-scale (defaults to `False`)

    # Returns:
     - `str`
            Sparkline visualization, one character per bin (empty string for
            `None` or empty input)
    """
    if histogram is None:
        return ""
    counts = np.asarray(histogram, dtype=float).ravel()
    if counts.size == 0:
        return ""

    chars = SPARK_CHARS.get(format, SPARK_CHARS["ascii"])
    top_level = len(chars) - 1

    if log_y:
        # log1p(x) = log(1 + x): compresses large counts while keeping 0 -> 0
        counts = np.log1p(counts)

    peak = counts.max()
    if peak > 0:
        # clip guards against out-of-range indices from unexpected negatives
        levels = np.clip((counts / peak * top_level).astype(int), 0, top_level)
    else:
        levels = np.zeros(counts.shape, dtype=int)

    return "".join(chars[level] for level in levels)

Generate a sparkline visualization of the histogram.

Parameters:

  • histogram : np.ndarray Histogram data
  • format : Literal["unicode", "latex", "ascii"] Output format (defaults to "unicode")
  • log_y : bool Whether to use logarithmic y-scale (defaults to False)

Returns:

  • str Sparkline visualization
# Rendered value of `DEFAULT_SETTINGS`: fallbacks `array_summary` reads for any option that is not passed.
DEFAULT_SETTINGS: Dict[str, Any] = {'fmt': 'unicode', 'precision': 2, 'stats': True, 'shape': True, 'dtype': True, 'device': True, 'requires_grad': True, 'sparkline': False, 'sparkline_bins': 5, 'sparkline_logy': False, 'colored': False, 'as_list': False, 'eq_char': '='}
def array_summary(  # type: ignore[misc]
    array,
    fmt: OutputFormat = _USE_DEFAULT,  # type: ignore[assignment]
    precision: int = _USE_DEFAULT,  # type: ignore[assignment]
    stats: bool = _USE_DEFAULT,  # type: ignore[assignment]
    shape: bool = _USE_DEFAULT,  # type: ignore[assignment]
    dtype: bool = _USE_DEFAULT,  # type: ignore[assignment]
    device: bool = _USE_DEFAULT,  # type: ignore[assignment]
    requires_grad: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_bins: int = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_logy: bool = _USE_DEFAULT,  # type: ignore[assignment]
    colored: bool = _USE_DEFAULT,  # type: ignore[assignment]
    eq_char: str = _USE_DEFAULT,  # type: ignore[assignment]
    as_list: bool = _USE_DEFAULT,  # type: ignore[assignment]
) -> Union[str, List[str]]:
    """Format array information into a readable summary.

    Any option left at its default is looked up in `DEFAULT_SETTINGS` at call
    time; the defaults listed below are the values shipped in that dict.

    # Parameters:
     - `array`
            array-like object (numpy array or torch tensor)
     - `fmt : Literal["unicode", "latex", "ascii"]`
            Output format (defaults to `"unicode"`)
     - `precision : int`
            Decimal places (defaults to `2`)
     - `stats : bool`
            Whether to include statistical info (μ, σ, x̃) and the value range
            (defaults to `True`)
     - `shape : bool`
            Whether to include shape info (defaults to `True`)
     - `dtype : bool`
            Whether to include dtype info (defaults to `True`)
     - `device : bool`
            Whether to include device info for torch tensors (defaults to `True`)
     - `requires_grad : bool`
            Whether to include requires_grad info for torch tensors (defaults to `True`)
     - `sparkline : bool`
            Whether to include a sparkline visualization (defaults to `False`)
     - `sparkline_bins : int`
            Number of histogram bins used for the sparkline (defaults to `5`)
     - `sparkline_logy : bool`
            Whether to use logarithmic y-scale for sparkline (defaults to `False`)
     - `colored : bool`
            Whether to add color to output (defaults to `False`); LaTeX output
            uses `\\textcolor`, other formats use ANSI terminal escapes
     - `eq_char : str`
            Separator placed between each field's label and its value
            (defaults to `"="`)
     - `as_list : bool`
            Whether to return as list of strings instead of joined string (defaults to `False`)

    # Returns:
     - `Union[str, List[str]]`
            Formatted statistical summary, either as string or list of strings.
            Joined with `" \\quad "` for LaTeX output and `" "` otherwise.
    """

    def _setting(value: Any, key: str) -> Any:
        # replace the `_USE_DEFAULT` sentinel with the current global default
        return DEFAULT_SETTINGS[key] if value is _USE_DEFAULT else value

    fmt = _setting(fmt, "fmt")
    precision = _setting(precision, "precision")
    stats = _setting(stats, "stats")
    shape = _setting(shape, "shape")
    dtype = _setting(dtype, "dtype")
    device = _setting(device, "device")
    requires_grad = _setting(requires_grad, "requires_grad")
    sparkline = _setting(sparkline, "sparkline")
    sparkline_bins = _setting(sparkline_bins, "sparkline_bins")
    sparkline_logy = _setting(sparkline_logy, "sparkline_logy")
    colored = _setting(colored, "colored")
    as_list = _setting(as_list, "as_list")
    eq_char = _setting(eq_char, "eq_char")

    array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins)
    result_parts: List[str] = []
    using_tex: bool = fmt == "latex"

    # Set color scheme based on format and colored flag
    colors: Dict[str, str]
    if colored:
        colors = COLORS["latex"] if using_tex else COLORS["terminal"]
    else:
        colors = COLORS["none"]

    # Get symbols for the current format
    symbols: Dict[str, str] = SYMBOLS[fmt]

    def colorize(text: str, color_key: str) -> str:
        """Wrap `text` in the markup for `color_key`; no-op when that color is empty."""
        color: str = colors[color_key]
        if not color:
            return text
        if using_tex:
            # LaTeX: \textcolor{name}{text}
            return f"{color}{{{text}}}"
        # terminal: ANSI color code, text, then reset code
        return f"{color}{text}{colors['reset']}"

    # Format string for numbers
    float_fmt: str = f".{precision}f"

    # Handle error status or empty array
    if (
        array_data["status"] in ("empty array", "all NaN", "unknown")
        or array_data["size"] == 0
    ):
        status = array_data["status"]
        result_parts.append(colorize(symbols["warning"] + " " + status, "warning"))
    else:
        # Add NaN warning at the beginning if there are NaNs
        if array_data["has_nans"]:
            # `%` starts a comment in LaTeX, so it must be escaped there
            _percent: str = "\\%" if using_tex else "%"
            nan_str: str = f"{symbols['warning']} {symbols['nan_values']}{eq_char}{array_data['nan_count']} ({array_data['nan_percent']:.1f}{_percent})"
            result_parts.append(colorize(nan_str, "warning"))

        # Statistics
        if stats:
            for stat_key in ("mean", "std", "median"):
                if array_data[stat_key] is not None:
                    stat_str: str = f"{array_data[stat_key]:{float_fmt}}"
                    stat_colored: str = colorize(stat_str, stat_key)
                    result_parts.append(f"{symbols[stat_key]}{eq_char}{stat_colored}")

            # Range (min, max)
            if array_data["range"] is not None:
                min_val, max_val = array_data["range"]
                min_colored: str = colorize(f"{min_val:{float_fmt}}", "range")
                max_colored: str = colorize(f"{max_val:{float_fmt}}", "range")
                result_parts.append(
                    f"{symbols['range']}{eq_char}[{min_colored},{max_colored}]"
                )

    # Add sparkline if requested
    if sparkline and array_data["histogram"] is not None:
        spark = generate_sparkline(
            array_data["histogram"], format=fmt, log_y=sparkline_logy
        )
        if spark:
            spark_colored = colorize(spark, "sparkline")
            result_parts.append(f"{symbols['distribution']}{eq_char}|{spark_colored}|")

    # Add shape if requested; every dimension is colorized, including 1-D shapes
    if shape and array_data["shape"]:
        shape_val = array_data["shape"]
        if len(shape_val) == 1:
            shape_str = colorize(str(shape_val[0]), "shape")
        else:
            shape_str = (
                "(" + ",".join(colorize(str(dim), "shape") for dim in shape_val) + ")"
            )
        result_parts.append(f"shape{eq_char}{shape_str}")

    # Add dtype if requested
    if dtype and array_data["dtype"]:
        result_parts.append(colorize(f"dtype{eq_char}{array_data['dtype']}", "dtype"))

    # Add device if requested and it's a tensor with device info
    if device and array_data["is_tensor"] and array_data["device"]:
        result_parts.append(
            colorize(f"device{eq_char}{array_data['device']}", "device")
        )

    # Add gradient info
    if requires_grad and array_data["is_tensor"] and array_data["requires_grad"]:
        result_parts.append(colorize(symbols["requires_grad"], "requires_grad"))

    # Return as list if requested, otherwise join with spaces
    if as_list:
        return result_parts
    joinchar: str = r" \quad " if using_tex else " "
    return joinchar.join(result_parts)

Format array information into a readable summary.

Parameters:

  • array array-like object (numpy array or torch tensor)
  • precision : int Decimal places (defaults to 2)
  • fmt : Literal["unicode", "latex", "ascii"] Output format (defaults to "unicode")
  • stats : bool Whether to include statistical info (μ, σ, x̃) (defaults to True)
  • shape : bool Whether to include shape info (defaults to True)
  • dtype : bool Whether to include dtype info (defaults to True)
  • device : bool Whether to include device info for torch tensors (defaults to True)
  • requires_grad : bool Whether to include requires_grad info for torch tensors (defaults to True)
  • sparkline : bool Whether to include a sparkline visualization (defaults to False)
  • sparkline_bins : int Number of histogram bins used for the sparkline (defaults to 5)
  • sparkline_logy : bool Whether to use logarithmic y-scale for sparkline (defaults to False)
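  • colored : bool Whether to add color to output (defaults to False)
  • eq_char : str Separator placed between a field's label and its value, e.g. shape=(3,4) (defaults to "=")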
  • as_list : bool Whether to return as list of strings instead of joined string (defaults to False)

Returns:

  • Union[str, List[str]] Formatted statistical summary, either as string or list of strings