# muutils/tensor_info.py
# (recovered from a coverage.py v7.6.1 report: 88% coverage, 200 statements,
#  created at 2025-04-04 03:35 -0600)

1import numpy as np 

2from typing import Union, Any, Literal, List, Dict, overload 

3 

# Global color definitions
# Maps a color scheme ("latex", "terminal", "none") to the per-field style
# prefix used by `array_summary`: LaTeX `\textcolor` commands, terminal ANSI
# escape codes, or empty strings for plain output. "reset" is the suffix that
# clears the style (only meaningful for the terminal scheme).
COLORS: Dict[str, Dict[str, str]] = {
    "latex": {
        "range": r"\textcolor{purple}",
        "mean": r"\textcolor{teal}",
        "std": r"\textcolor{orange}",
        "median": r"\textcolor{green}",
        "warning": r"\textcolor{red}",
        "shape": r"\textcolor{magenta}",
        "dtype": r"\textcolor{gray}",
        "device": r"\textcolor{gray}",
        "requires_grad": r"\textcolor{gray}",
        "sparkline": r"\textcolor{blue}",
        "reset": "",
    },
    "terminal": {
        "range": "\033[35m",  # purple
        "mean": "\033[36m",  # cyan/teal
        "std": "\033[33m",  # yellow/orange
        "median": "\033[32m",  # green
        "warning": "\033[31m",  # red
        "shape": "\033[95m",  # bright magenta
        "dtype": "\033[90m",  # gray
        "device": "\033[90m",  # gray
        "requires_grad": "\033[90m",  # gray
        "sparkline": "\033[34m",  # blue
        "reset": "\033[0m",
    },
    "none": {
        "range": "",
        "mean": "",
        "std": "",
        "median": "",
        "warning": "",
        "shape": "",
        "dtype": "",
        "device": "",
        "requires_grad": "",
        "sparkline": "",
        "reset": "",
    },
}

46 

# Output formats supported by `array_summary` / `generate_sparkline`.
OutputFormat = Literal["unicode", "latex", "ascii"]

# Per-format display symbol for each summary field (e.g. mean is rendered as
# `\mu` in latex, `μ` in unicode, and the literal word `mean` in ascii).
SYMBOLS: Dict[OutputFormat, Dict[str, str]] = {
    "latex": {
        "range": r"\mathcal{R}",
        "mean": r"\mu",
        "std": r"\sigma",
        "median": r"\tilde{x}",
        "distribution": r"\mathbb{P}",
        "nan_values": r"\texttt{NANvals}",
        "warning": "!!!",
        "requires_grad": r"\nabla",
    },
    "unicode": {
        "range": "R",
        "mean": "μ",
        "std": "σ",
        "median": "x̃",
        "distribution": "ℙ",
        "nan_values": "NANvals",
        "warning": "🚨",
        "requires_grad": "∇",
    },
    "ascii": {
        "range": "range",
        "mean": "mean",
        "std": "std",
        "median": "med",
        "distribution": "dist",
        "nan_values": "NANvals",
        "warning": "!!!",
        "requires_grad": "requires_grad",
    },
}
"Symbols for different formats"

82 

# Ramp of characters (lowest to highest bar) used to draw histogram sparklines;
# index 0 represents an empty bin. The latex set reuses the unicode blocks.
SPARK_CHARS: Dict[OutputFormat, List[str]] = {
    "unicode": list(" ▁▂▃▄▅▆▇█"),
    "ascii": list(" _.-~=#"),
    "latex": list(" ▁▂▃▄▅▆▇█"),
    # "latex": [r"\textbf{.}", r"\textbf{-}", r"\textbf{=}", r"\textbf{+}", r"\textbf{*}", r"\textbf{\\#}"],
}
"characters for sparklines in different formats"

90 

91 

def array_info(
    A: Any,
    hist_bins: int = 5,
) -> Dict[str, Any]:
    """Extract statistical information from an array-like object.

    Works on numpy arrays, torch tensors (detected by class name, so torch is
    never imported), and anything `np.asarray` accepts. All failures are
    absorbed: the function always returns a dict, with `status` describing
    any problem ("ok", "empty array", "all NaN", or "error: ...").

    # Parameters:
    - `A : array-like`
        Array to analyze (numpy array or torch tensor)
    - `hist_bins : int`
        Number of bins for the histogram used by sparklines (defaults to `5`)

    # Returns:
    - `Dict[str, Any]`
        Dictionary containing raw statistical information with numeric values
    """
    result: Dict[str, Any] = {
        "is_tensor": None,
        "device": None,
        "requires_grad": None,
        "shape": None,
        "dtype": None,
        "size": None,
        "has_nans": None,
        "nan_count": None,
        "nan_percent": None,
        "min": None,
        "max": None,
        "range": None,
        "mean": None,
        "std": None,
        "median": None,
        "histogram": None,
        "bins": None,
        "status": None,
    }

    # Check if it's a tensor by looking at its class name.
    # This avoids importing torch directly.
    A_type: str = type(A).__name__
    result["is_tensor"] = A_type == "Tensor"

    # Try to get device information if it's a tensor
    if result["is_tensor"]:
        try:
            result["device"] = str(getattr(A, "device", None))
        except Exception:
            pass

    # Convert to numpy array for calculations
    try:
        # For PyTorch tensors
        if result["is_tensor"]:
            # Check if tensor is on GPU
            is_cuda: bool = False
            try:
                is_cuda = bool(getattr(A, "is_cuda", False))
            except Exception:
                pass

            if is_cuda:
                try:
                    # Try to get CPU tensor first
                    cpu_tensor = getattr(A, "cpu", lambda: A)()
                except Exception:
                    A_np = np.array([])
            else:
                cpu_tensor = A
            try:
                # For CPU tensor, just detach and convert
                # (if .cpu() failed above, the NameError here is caught too)
                detached = getattr(cpu_tensor, "detach", lambda: cpu_tensor)()
                A_np = getattr(detached, "numpy", lambda: np.array([]))()
            except Exception:
                A_np = np.array([])
        else:
            # For numpy arrays and other array-like objects
            A_np = np.asarray(A)
    except Exception:
        A_np = np.array([])

    # Get basic information
    try:
        result["shape"] = A_np.shape
        # prefer the tensor's own dtype string (e.g. "torch.float32")
        result["dtype"] = str(A.dtype if result["is_tensor"] else A_np.dtype)
        result["size"] = A_np.size
        result["requires_grad"] = getattr(A, "requires_grad", None)
    except Exception:
        pass

    # If array is empty, return early
    if result["size"] == 0:
        result["status"] = "empty array"
        return result

    # Flatten array for statistics if it's multi-dimensional
    try:
        if len(A_np.shape) > 1:
            A_flat = A_np.flatten()
        else:
            A_flat = A_np
    except Exception:
        A_flat = A_np

    # Check for NaN values (np.isnan raises for non-float dtypes such as
    # object/str arrays; in that case the NaN fields stay None)
    try:
        nan_mask = np.isnan(A_flat)
        # cast to plain Python types so the result dict holds no numpy scalars
        result["nan_count"] = int(np.sum(nan_mask))
        result["has_nans"] = result["nan_count"] > 0
        if result["size"] > 0:
            result["nan_percent"] = float(result["nan_count"] / result["size"]) * 100
    except Exception:
        pass

    # If all values are NaN, return early
    if result["has_nans"] and result["nan_count"] == result["size"]:
        result["status"] = "all NaN"
        return result

    # Calculate statistics
    try:
        if result["has_nans"]:
            result["min"] = float(np.nanmin(A_flat))
            result["max"] = float(np.nanmax(A_flat))
            result["mean"] = float(np.nanmean(A_flat))
            result["std"] = float(np.nanstd(A_flat))
            result["median"] = float(np.nanmedian(A_flat))
            result["range"] = (result["min"], result["max"])

            # Remove NaNs for histogram
            A_hist = A_flat[~nan_mask]
        else:
            result["min"] = float(np.min(A_flat))
            result["max"] = float(np.max(A_flat))
            result["mean"] = float(np.mean(A_flat))
            result["std"] = float(np.std(A_flat))
            result["median"] = float(np.median(A_flat))
            result["range"] = (result["min"], result["max"])

            A_hist = A_flat

        # Calculate histogram data for sparklines
        if A_hist.size > 0:
            try:
                hist, bins = np.histogram(A_hist, bins=hist_bins)
                result["histogram"] = hist
                result["bins"] = bins
            except Exception:
                pass

        result["status"] = "ok"
    except Exception as e:
        result["status"] = f"error: {str(e)}"

    return result

244 

245 

def generate_sparkline(
    histogram: np.ndarray,
    format: Literal["unicode", "latex", "ascii"] = "unicode",
    log_y: bool = False,
) -> str:
    """Generate a sparkline visualization of the histogram.

    # Parameters:
    - `histogram : np.ndarray`
        Histogram data
    - `format : Literal["unicode", "latex", "ascii"]`
        Output format (defaults to `"unicode"`)
    - `log_y : bool`
        Whether to use logarithmic y-scale (defaults to `False`)

    # Returns:
    - `str`
        Sparkline visualization
    """
    if histogram is None or len(histogram) == 0:
        return ""

    # Pick the character ramp for the requested format, falling back to ascii
    chars = SPARK_CHARS[format] if format in SPARK_CHARS else SPARK_CHARS["ascii"]

    # log1p instead of log so that empty (zero) bins stay finite
    values = np.log1p(histogram) if log_y else histogram

    # Scale bin counts onto the index range of the character ramp
    peak = values.max()
    if peak > 0:
        levels = values / peak * (len(chars) - 1)
    else:
        levels = np.zeros_like(values)

    # One character per bin, truncating each level to an integer index
    return "".join(chars[int(level)] for level in levels)

294 

295 

# Fallback value for each `array_summary` keyword argument; arguments left at
# the sentinel default are resolved from this mapping at call time.
DEFAULT_SETTINGS: Dict[str, Any] = dict(
    fmt="unicode",
    precision=2,
    stats=True,
    shape=True,
    dtype=True,
    device=True,
    requires_grad=True,
    sparkline=False,
    sparkline_bins=5,
    sparkline_logy=False,
    colored=False,
    as_list=False,
    eq_char="=",
)

311 

312 

class _UseDefaultType:
    """Sentinel type marking an `array_summary` argument as "not provided".

    Used instead of `None` so that `None` remains a passable value; compared
    with `is` against the singleton `_USE_DEFAULT` below.
    """

    pass


# Singleton sentinel instance; arguments equal (by identity) to this are
# replaced with the corresponding `DEFAULT_SETTINGS` entry.
_USE_DEFAULT = _UseDefaultType()

318 

319 

# Overloads: the return type depends on `as_list` — a list of parts when
# True, a single joined string when False/omitted. `array` must come first
# to match the implementation signature below.
@overload
def array_summary(
    array: Any,
    *,
    as_list: Literal[True],
    **kwargs: Any,
) -> List[str]: ...
@overload
def array_summary(
    array: Any,
    *,
    as_list: Literal[False] = ...,
    **kwargs: Any,
) -> str: ...

def array_summary(  # type: ignore[misc]
    array,
    fmt: OutputFormat = _USE_DEFAULT,  # type: ignore[assignment]
    precision: int = _USE_DEFAULT,  # type: ignore[assignment]
    stats: bool = _USE_DEFAULT,  # type: ignore[assignment]
    shape: bool = _USE_DEFAULT,  # type: ignore[assignment]
    dtype: bool = _USE_DEFAULT,  # type: ignore[assignment]
    device: bool = _USE_DEFAULT,  # type: ignore[assignment]
    requires_grad: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_bins: int = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_logy: bool = _USE_DEFAULT,  # type: ignore[assignment]
    colored: bool = _USE_DEFAULT,  # type: ignore[assignment]
    eq_char: str = _USE_DEFAULT,  # type: ignore[assignment]
    as_list: bool = _USE_DEFAULT,  # type: ignore[assignment]
) -> Union[str, List[str]]:
    """Format array information into a readable summary.

    # Parameters:
    - `array`
        array-like object (numpy array or torch tensor)
    - `fmt : Literal["unicode", "latex", "ascii"]`
        Output format (defaults to `"unicode"`)
    - `precision : int`
        Decimal places (defaults to `2`)
    - `stats : bool`
        Whether to include statistical info (μ, σ, x̃) (defaults to `True`)
    - `shape : bool`
        Whether to include shape info (defaults to `True`)
    - `dtype : bool`
        Whether to include dtype info (defaults to `True`)
    - `device : bool`
        Whether to include device info for torch tensors (defaults to `True`)
    - `requires_grad : bool`
        Whether to include requires_grad info for torch tensors (defaults to `True`)
    - `sparkline : bool`
        Whether to include a sparkline visualization (defaults to `False`)
    - `sparkline_bins : int`
        Number of histogram bins for the sparkline (defaults to `5`)
    - `sparkline_logy : bool`
        Whether to use logarithmic y-scale for sparkline (defaults to `False`)
    - `colored : bool`
        Whether to add color to output (defaults to `False`)
    - `eq_char : str`
        Character used between a field name and its value (defaults to `"="`)
    - `as_list : bool`
        Whether to return as list of strings instead of joined string (defaults to `False`)

    # Returns:
    - `Union[str, List[str]]`
        Formatted statistical summary, either as string or list of strings
    """
    # Resolve every argument still at the sentinel from DEFAULT_SETTINGS
    if fmt is _USE_DEFAULT:
        fmt = DEFAULT_SETTINGS["fmt"]
    if precision is _USE_DEFAULT:
        precision = DEFAULT_SETTINGS["precision"]
    if stats is _USE_DEFAULT:
        stats = DEFAULT_SETTINGS["stats"]
    if shape is _USE_DEFAULT:
        shape = DEFAULT_SETTINGS["shape"]
    if dtype is _USE_DEFAULT:
        dtype = DEFAULT_SETTINGS["dtype"]
    if device is _USE_DEFAULT:
        device = DEFAULT_SETTINGS["device"]
    if requires_grad is _USE_DEFAULT:
        requires_grad = DEFAULT_SETTINGS["requires_grad"]
    if sparkline is _USE_DEFAULT:
        sparkline = DEFAULT_SETTINGS["sparkline"]
    if sparkline_bins is _USE_DEFAULT:
        sparkline_bins = DEFAULT_SETTINGS["sparkline_bins"]
    if sparkline_logy is _USE_DEFAULT:
        sparkline_logy = DEFAULT_SETTINGS["sparkline_logy"]
    if colored is _USE_DEFAULT:
        colored = DEFAULT_SETTINGS["colored"]
    if as_list is _USE_DEFAULT:
        as_list = DEFAULT_SETTINGS["as_list"]
    if eq_char is _USE_DEFAULT:
        eq_char = DEFAULT_SETTINGS["eq_char"]

    array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins)
    result_parts: List[str] = []
    using_tex: bool = fmt == "latex"

    # Set color scheme based on format and colored flag
    colors: Dict[str, str]
    if colored:
        colors = COLORS["latex"] if using_tex else COLORS["terminal"]
    else:
        colors = COLORS["none"]

    # Get symbols for the current format
    symbols: Dict[str, str] = SYMBOLS[fmt]

    # Helper function to colorize text
    def colorize(text: str, color_key: str) -> str:
        if using_tex:
            # LaTeX colors are commands taking the text as a braced argument,
            # e.g. `\textcolor{red}{text}` (NOT a Python set literal)
            return f"{colors[color_key]}{{{text}}}" if colors[color_key] else text
        else:
            # terminal colors are ANSI prefixes that need a reset suffix
            return (
                f"{colors[color_key]}{text}{colors['reset']}"
                if colors[color_key]
                else text
            )

    # Format string for numbers
    float_fmt: str = f".{precision}f"

    # Handle error status or empty array
    if (
        array_data["status"] in ["empty array", "all NaN", "unknown"]
        or array_data["size"] == 0
    ):
        status = array_data["status"]
        result_parts.append(colorize(symbols["warning"] + " " + status, "warning"))
    else:
        # Add NaN warning at the beginning if there are NaNs
        if array_data["has_nans"]:
            _percent: str = "\\%" if using_tex else "%"
            nan_str: str = f"{symbols['warning']} {symbols['nan_values']}{eq_char}{array_data['nan_count']} ({array_data['nan_percent']:.1f}{_percent})"
            result_parts.append(colorize(nan_str, "warning"))

        # Statistics
        if stats:
            for stat_key in ["mean", "std", "median"]:
                if array_data[stat_key] is not None:
                    stat_str: str = f"{array_data[stat_key]:{float_fmt}}"
                    stat_colored: str = colorize(stat_str, stat_key)
                    result_parts.append(f"{symbols[stat_key]}={stat_colored}")

        # Range (min, max)
        if array_data["range"] is not None:
            min_val, max_val = array_data["range"]
            min_str: str = f"{min_val:{float_fmt}}"
            max_str: str = f"{max_val:{float_fmt}}"
            min_colored: str = colorize(min_str, "range")
            max_colored: str = colorize(max_str, "range")
            range_str: str = f"{symbols['range']}=[{min_colored},{max_colored}]"
            result_parts.append(range_str)

        # Add sparkline if requested
        if sparkline and array_data["histogram"] is not None:
            spark = generate_sparkline(
                array_data["histogram"], format=fmt, log_y=sparkline_logy
            )
            if spark:
                spark_colored = colorize(spark, "sparkline")
                result_parts.append(
                    f"{symbols['distribution']}{eq_char}|{spark_colored}|"
                )

    # Add shape if requested
    # NOTE(review): placed at function level so empty/all-NaN arrays still
    # report shape/dtype/device — confirm against the original indentation
    if shape and array_data["shape"]:
        shape_val = array_data["shape"]
        if len(shape_val) == 1:
            shape_str = str(shape_val[0])
        else:
            shape_str = (
                "(" + ",".join(colorize(str(dim), "shape") for dim in shape_val) + ")"
            )
        result_parts.append(f"shape{eq_char}{shape_str}")

    # Add dtype if requested
    if dtype and array_data["dtype"]:
        result_parts.append(colorize(f"dtype={array_data['dtype']}", "dtype"))

    # Add device if requested and it's a tensor with device info
    if device and array_data["is_tensor"] and array_data["device"]:
        result_parts.append(
            colorize(f"device{eq_char}{array_data['device']}", "device")
        )

    # Add gradient info
    if requires_grad and array_data["is_tensor"] and array_data["requires_grad"]:
        result_parts.append(colorize(symbols["requires_grad"], "requires_grad"))

    # Return as list if requested, otherwise join with spaces
    if as_list:
        return result_parts
    else:
        joinchar: str = r" \quad " if using_tex else " "
        return joinchar.join(result_parts)