Coverage for muutils/dbg.py: 91%
58 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
1"""
3this code is based on an implementation of the Rust builtin `dbg!` for Python, originally from
4https://github.com/tylerwince/pydbg/blob/master/pydbg.py
5although it has been significantly modified
7licensed under MIT:
9Copyright (c) 2019 Tyler Wince
11Permission is hereby granted, free of charge, to any person obtaining a copy
12of this software and associated documentation files (the "Software"), to deal
13in the Software without restriction, including without limitation the rights
14to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15copies of the Software, and to permit persons to whom the Software is
16furnished to do so, subject to the following conditions:
18The above copyright notice and this permission notice shall be included in
19all copies or substantial portions of the Software.
21THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27THE SOFTWARE.
29"""
31from __future__ import annotations
33import os
34import inspect
35import sys
36import typing
37from pathlib import Path
38import functools
# type defs
# Generic type variable for `dbg`: `dbg(exp)` returns `exp` unchanged, so the
# overloads below use `_ExpType` to tell type checkers that the return type is
# the same as the argument type.
_ExpType = typing.TypeVar("_ExpType")
44# Sentinel type for no expression passed
class _NoExpPassedSentinel:
    """Unique sentinel type used to indicate that no expression was passed to `dbg`.

    A dedicated type (rather than `None`) means `dbg(None)` still prints and
    returns `None`, while a bare `dbg()` can be told apart via an identity
    check against the single instance `_NoExpPassed`.
    """

    def __repr__(self) -> str:
        # `dbg()` returns this sentinel, so give it a readable repr (e.g. when the
        # result is echoed in a REPL) instead of `<... object at 0x...>`
        return "_NoExpPassed"


_NoExpPassed = _NoExpPassedSentinel()
# configuration: how `_process_path` renders file paths ("relative" or "absolute")
PATH_MODE: typing.Literal["relative", "absolute"] = "relative"

# module-level state
# working directory, captured once when this module is imported
_CWD: Path = Path(os.getcwd()).absolute()
# running count of `dbg` calls made without an expression
_COUNTER: int = 0
61# path processing
def _process_path(path: Path) -> str:
    """Format `path` for display in `dbg` output, according to `PATH_MODE`.

    - ``"absolute"``: the absolute path, in POSIX form.
    - ``"relative"``: the path relative to the deepest directory it shares with
      `_CWD`. NOTE: for a file outside `_CWD` this only strips the shared
      prefix (``/home/u/other/x.py`` with cwd ``/home/u/proj`` gives
      ``other/x.py``), so it is not a ``../``-style path. If this raises
      `ValueError` (e.g. `os.path.commonpath` on paths on different Windows
      drives), the absolute path is used instead.

    # Raises:
     - `ValueError`: if `PATH_MODE` is neither ``"relative"`` nor ``"absolute"``
    """
    path_abs: Path = path.absolute()

    if PATH_MODE == "absolute":
        return path_abs.as_posix()

    if PATH_MODE == "relative":
        try:
            return path_abs.relative_to(
                Path(os.path.commonpath([path_abs, _CWD]))
            ).as_posix()
        except ValueError:
            return path_abs.as_posix()

    raise ValueError(
        f"PATH_MODE must be either 'relative' or 'absolute', got {PATH_MODE!r}"
    )
# call-site introspection helpers used by `dbg`
def _extract_call_expression(line: str) -> str:
    """Return the argument source text of the first ``dbg*(...)`` call in `line`.

    The line is tokenized so that parentheses inside string literals are
    ignored and the *matching* closing parenthesis is found: for example
    ``print(dbg(f(x)))`` yields ``f(x)`` and ``dbg(x) or g(y)`` yields ``x``.
    A call is recognised as a name token containing ``"dbg"`` immediately
    followed by ``(`` (so ``dbg_tensor(...)`` and ``mod.dbg(...)`` match too).

    If tokenizing fails or no balanced call is found on this single line (e.g.
    the call continues onto the following lines), falls back to the text
    between the first ``(`` and the last ``)`` of the line, or up to the end of
    the line if it has no ``)``.
    """
    import io
    import tokenize

    open_col: typing.Optional[int] = None  # column just after the call's "("
    depth: int = 0  # parenthesis nesting depth inside the call
    prev_is_dbg_name: bool = False
    try:
        for tok in tokenize.generate_tokens(io.StringIO(line).readline):
            if open_col is None:
                if prev_is_dbg_name and tok.type == tokenize.OP and tok.string == "(":
                    open_col = tok.end[1]
                    depth = 1
                prev_is_dbg_name = tok.type == tokenize.NAME and "dbg" in tok.string
                continue
            if tok.type != tokenize.OP:
                continue
            if tok.string == "(":
                depth += 1
            elif tok.string == ")":
                depth -= 1
                if depth == 0:
                    return line[open_col : tok.start[1]]
    except (tokenize.TokenError, SyntaxError):
        # e.g. a call spanning several lines, or an unterminated string literal
        pass

    # fallback: first "(" to last ")" on the line (or to its end if there is no ")")
    start: int = line.find("(") + 1
    end: int = line.rfind(")")
    if end == -1:
        end = len(line)
    return line[start:end]


def _find_call_site() -> typing.Tuple[str, str]:
    """Locate the source line that called `dbg`.

    Walks outward from the caller of `dbg` one frame at a time and stops at the
    first frame whose currently executing source line contains ``"dbg"``.
    Only that frame's line is read, instead of building source context for the
    entire stack up front.

    Returns ``(location, expression)``. `location` is ``"<path>:<lineno>"``,
    with the path formatted by `_process_path`. `expression` is the argument
    text from `_extract_call_expression`. Both are ``"unknown"`` if no matching
    frame is found.
    """
    frame = inspect.currentframe()
    try:
        # skip this helper's own frame and the frame of `dbg` itself; neither
        # can be the call site
        for _ in range(2):
            if frame is not None:
                frame = frame.f_back

        while frame is not None:
            # context=1: fetch only the line currently executing in this frame
            info = inspect.getframeinfo(frame, context=1)
            if info.code_context and "dbg" in info.code_context[0]:
                return (
                    f"{_process_path(Path(info.filename))}:{info.lineno}",
                    _extract_call_expression(info.code_context[0]),
                )
            frame = frame.f_back
    finally:
        # drop the frame reference so frames are freed deterministically
        # instead of waiting for the cycle collector (see `inspect` docs)
        del frame
    return "unknown", "unknown"


# actual dbg function
@typing.overload
def dbg() -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _NoExpPassedSentinel,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = " = ",
) -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _ExpType,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = " = ",
) -> _ExpType: ...
def dbg(
    exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = " = ",
) -> typing.Union[_ExpType, _NoExpPassedSentinel]:
    """Call dbg with any variable or expression.

    Calling dbg will print to stderr the current filename and lineno,
    as well as the passed expression and what the expression evaluates to:

        from muutils.dbg import dbg

        a = 2
        b = 5

        dbg(a+b)

        def square(x: int) -> int:
            return x * x

        dbg(square(a))

    Called with no argument, ``dbg()`` prints the location followed by a
    running counter ``(dbg N)`` and returns the `_NoExpPassed` sentinel.

    # Parameters:
     - `exp`: the value to print and return (default: no expression)
     - `formatter`: converts the value to a string; `repr` is used if `None`
     - `val_joiner`: text placed between the expression source and its value

    # Returns:
     - `exp`, unchanged
    """
    global _COUNTER

    # where dbg was called from ("<path>:<lineno>") and the argument source text
    fname, line_exp = _find_call_site()

    msg: str
    if exp is _NoExpPassed:
        # no expression: show location and counter value
        msg = f"[ {fname} ] (dbg {_COUNTER})"
        _COUNTER += 1
    else:
        # expression passed: show location, expression text, and formatted value
        exp_val: str = formatter(exp) if formatter is not None else repr(exp)
        msg = f"[ {fname} ] {line_exp}{val_joiner}{exp_val}"

    print(msg, file=sys.stderr)

    # return the expression itself so dbg can wrap any sub-expression
    return exp
158# formatted `dbg_*` functions with their helpers
# keyword arguments that `tensor_info` forwards to `muutils.tensor_info.array_summary`
DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS: typing.Dict[str, typing.Union[bool, int, str]] = {
    "fmt": "unicode",
    "precision": 2,
    "stats": True,
    "shape": True,
    "dtype": True,
    "device": True,
    "requires_grad": True,
    "sparkline": True,
    "sparkline_bins": 7,
    "sparkline_logy": False,
    "colored": True,
    "eq_char": "=",
}
def tensor_info(tensor: typing.Any) -> str:
    """Summarize `tensor` as a string; used as the formatter of `dbg_tensor`.

    Delegates to `muutils.tensor_info.array_summary`, passing every entry of
    `DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS` as a keyword argument.
    """
    # deferred import: `muutils.tensor_info` is only loaded when a value is formatted
    from muutils.tensor_info import array_summary

    summary_options = dict(DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS)
    summary = array_summary(tensor, **summary_options)
    return summary
# `dbg` preconfigured for tensors/arrays: the value is rendered with
# `tensor_info` and joined to the expression text with ": " instead of " = ".
# NOTE: this being a `functools.partial` rather than a `def` wrapper matters.
# `dbg` reports the first stack frame whose source line contains "dbg", so a
# Python wrapper such as `return dbg(x, ...)` would be reported as the call site
# instead of the user's line.
dbg_tensor = functools.partial(dbg, formatter=tensor_info, val_joiner=": ")