docs for muutils v0.8.10
View Source on GitHub

muutils.dbg

This code is based on a Python implementation of Rust's built-in `dbg!` macro, originally from https://github.com/tylerwince/pydbg/blob/master/pydbg.py, although it has been significantly modified.

licensed under MIT:

Copyright (c) 2019 Tyler Wince

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


  1"""
  2
  3this code is based on an implementation of the Rust builtin `dbg!` for Python, originally from
  4https://github.com/tylerwince/pydbg/blob/master/pydbg.py
  5although it has been significantly modified
  6
  7licensed under MIT:
  8
  9Copyright (c) 2019 Tyler Wince
 10
 11Permission is hereby granted, free of charge, to any person obtaining a copy
 12of this software and associated documentation files (the "Software"), to deal
 13in the Software without restriction, including without limitation the rights
 14to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 15copies of the Software, and to permit persons to whom the Software is
 16furnished to do so, subject to the following conditions:
 17
 18The above copyright notice and this permission notice shall be included in
 19all copies or substantial portions of the Software.
 20
 21THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 22IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 23FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 24AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 25LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 26OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 27THE SOFTWARE.
 28
 29"""
 30
 31from __future__ import annotations
 32
 33import inspect
 34import sys
 35import typing
 36from pathlib import Path
 37import functools
 38
 39# type defs
 40_ExpType = typing.TypeVar("_ExpType")
 41
 42
 43# Sentinel type for no expression passed
 44class _NoExpPassedSentinel:
 45    """Unique sentinel type used to indicate that no expression was passed."""
 46
 47    pass
 48
 49
 50_NoExpPassed = _NoExpPassedSentinel()
 51
 52# global variables
 53_CWD: Path = Path.cwd().absolute()
 54_COUNTER: int = 0
 55
 56# configuration
 57PATH_MODE: typing.Literal["relative", "absolute"] = "relative"
 58DEFAULT_VAL_JOINER: str = " = "
 59
 60
 61# path processing
 62def _process_path(path: Path) -> str:
 63    path_abs: Path = path.absolute()
 64    fname: Path
 65    if PATH_MODE == "absolute":
 66        fname = path_abs
 67    elif PATH_MODE == "relative":
 68        try:
 69            # if it's inside the cwd, print the relative path
 70            fname = path.relative_to(_CWD)
 71        except ValueError:
 72            # if its not in the subpath, use the absolute path
 73            fname = path_abs
 74    else:
 75        raise ValueError("PATH_MODE must be either 'relative' or 'absolute")
 76
 77    return fname.as_posix()
 78
 79
 80# actual dbg function
 81@typing.overload
 82def dbg() -> _NoExpPassedSentinel: ...
 83@typing.overload
 84def dbg(
 85    exp: _NoExpPassedSentinel,
 86    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
 87    val_joiner: str = DEFAULT_VAL_JOINER,
 88) -> _NoExpPassedSentinel: ...
 89@typing.overload
 90def dbg(
 91    exp: _ExpType,
 92    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
 93    val_joiner: str = DEFAULT_VAL_JOINER,
 94) -> _ExpType: ...
 95def dbg(
 96    exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed,
 97    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
 98    val_joiner: str = DEFAULT_VAL_JOINER,
 99) -> typing.Union[_ExpType, _NoExpPassedSentinel]:
100    """Call dbg with any variable or expression.
101
102    Calling dbg will print to stderr the current filename and lineno,
103    as well as the passed expression and what the expression evaluates to:
104
105            from muutils.dbg import dbg
106
107            a = 2
108            b = 5
109
110            dbg(a+b)
111
112            def square(x: int) -> int:
113                    return x * x
114
115            dbg(square(a))
116
117    """
118    global _COUNTER
119
120    # get the context
121    line_exp: str = "unknown"
122    current_file: str = "unknown"
123    dbg_frame: typing.Optional[inspect.FrameInfo] = None
124    for frame in inspect.stack():
125        if frame.code_context is None:
126            continue
127        line: str = frame.code_context[0]
128        if "dbg" in line:
129            current_file = _process_path(Path(frame.filename))
130            dbg_frame = frame
131            start: int = line.find("(") + 1
132            end: int = line.rfind(")")
133            if end == -1:
134                end = len(line)
135            line_exp = line[start:end]
136            break
137
138    fname: str = "unknown"
139    if current_file.startswith("/tmp/ipykernel_"):
140        stack: list[inspect.FrameInfo] = inspect.stack()
141        filtered_functions: list[str] = []
142        # this loop will find, in this order:
143        # - the dbg function call
144        # - the functions we care about displaying
145        # - `<module>`
146        # - a bunch of jupyter internals we don't care about
147        for frame_info in stack:
148            if _process_path(Path(frame_info.filename)) != current_file:
149                continue
150            if frame_info.function == "<module>":
151                break
152            if frame_info.function.startswith("dbg"):
153                continue
154            filtered_functions.append(frame_info.function)
155        if dbg_frame is not None:
156            filtered_functions.append(f"<ipykernel>:{dbg_frame.lineno}")
157        else:
158            filtered_functions.append(current_file)
159        filtered_functions.reverse()
160        fname = " -> ".join(filtered_functions)
161    elif dbg_frame is not None:
162        fname = f"{current_file}:{dbg_frame.lineno}"
163
164    # assemble the message
165    msg: str
166    if exp is _NoExpPassed:
167        # if no expression is passed, just show location and counter value
168        msg = f"[ {fname} ] <dbg {_COUNTER}>"
169        _COUNTER += 1
170    else:
171        # if expression passed, format its value and show location, expr, and value
172        exp_val: str = formatter(exp) if formatter else repr(exp)
173        msg = f"[ {fname} ] {line_exp}{val_joiner}{exp_val}"
174
175    # print the message
176    print(
177        msg,
178        file=sys.stderr,
179    )
180
181    # return the expression itself
182    return exp
183
184
# formatted `dbg_*` functions with their helpers

# keyword arguments that `tensor_info` forwards to
# `muutils.tensor_info.array_summary`
DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS: typing.Dict[str, typing.Union[bool, int, str]] = {
    "fmt": "unicode",
    "precision": 2,
    "stats": True,
    "shape": True,
    "dtype": True,
    "device": True,
    "requires_grad": True,
    "sparkline": True,
    "sparkline_bins": 7,
    "sparkline_logy": False,
    "colored": True,
    "eq_char": "=",
}


# separator between the expression source and its summary, used by `dbg_tensor`
DBG_TENSOR_VAL_JOINER: str = ": "
206
207
def tensor_info(tensor: typing.Any) -> str:
    """Summarize `tensor` as a string, for use as `dbg_tensor`'s formatter.

    Delegates to `muutils.tensor_info.array_summary`, passing the keyword
    arguments in `DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS`. The import happens at
    call time, not at module import.
    """
    from muutils.tensor_info import array_summary as _summarize

    summary_options = DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS
    return _summarize(tensor, **summary_options)
212
213
# `dbg` preset: values are rendered by `tensor_info` and joined to the
# expression source with `DBG_TENSOR_VAL_JOINER`
dbg_tensor = functools.partial(
    dbg,
    formatter=tensor_info,
    val_joiner=DBG_TENSOR_VAL_JOINER,
)

PATH_MODE: Literal['relative', 'absolute'] = 'relative'
DEFAULT_VAL_JOINER: str = ' = '
def dbg( exp: Union[~_ExpType, muutils.dbg._NoExpPassedSentinel] = <muutils.dbg._NoExpPassedSentinel object>, formatter: Optional[Callable[[Any], str]] = None, val_joiner: str = ' = ') -> Union[~_ExpType, muutils.dbg._NoExpPassedSentinel]:
 96def dbg(
 97    exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed,
 98    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
 99    val_joiner: str = DEFAULT_VAL_JOINER,
100) -> typing.Union[_ExpType, _NoExpPassedSentinel]:
101    """Call dbg with any variable or expression.
102
103    Calling dbg will print to stderr the current filename and lineno,
104    as well as the passed expression and what the expression evaluates to:
105
106            from muutils.dbg import dbg
107
108            a = 2
109            b = 5
110
111            dbg(a+b)
112
113            def square(x: int) -> int:
114                    return x * x
115
116            dbg(square(a))
117
118    """
119    global _COUNTER
120
121    # get the context
122    line_exp: str = "unknown"
123    current_file: str = "unknown"
124    dbg_frame: typing.Optional[inspect.FrameInfo] = None
125    for frame in inspect.stack():
126        if frame.code_context is None:
127            continue
128        line: str = frame.code_context[0]
129        if "dbg" in line:
130            current_file = _process_path(Path(frame.filename))
131            dbg_frame = frame
132            start: int = line.find("(") + 1
133            end: int = line.rfind(")")
134            if end == -1:
135                end = len(line)
136            line_exp = line[start:end]
137            break
138
139    fname: str = "unknown"
140    if current_file.startswith("/tmp/ipykernel_"):
141        stack: list[inspect.FrameInfo] = inspect.stack()
142        filtered_functions: list[str] = []
143        # this loop will find, in this order:
144        # - the dbg function call
145        # - the functions we care about displaying
146        # - `<module>`
147        # - a bunch of jupyter internals we don't care about
148        for frame_info in stack:
149            if _process_path(Path(frame_info.filename)) != current_file:
150                continue
151            if frame_info.function == "<module>":
152                break
153            if frame_info.function.startswith("dbg"):
154                continue
155            filtered_functions.append(frame_info.function)
156        if dbg_frame is not None:
157            filtered_functions.append(f"<ipykernel>:{dbg_frame.lineno}")
158        else:
159            filtered_functions.append(current_file)
160        filtered_functions.reverse()
161        fname = " -> ".join(filtered_functions)
162    elif dbg_frame is not None:
163        fname = f"{current_file}:{dbg_frame.lineno}"
164
165    # assemble the message
166    msg: str
167    if exp is _NoExpPassed:
168        # if no expression is passed, just show location and counter value
169        msg = f"[ {fname} ] <dbg {_COUNTER}>"
170        _COUNTER += 1
171    else:
172        # if expression passed, format its value and show location, expr, and value
173        exp_val: str = formatter(exp) if formatter else repr(exp)
174        msg = f"[ {fname} ] {line_exp}{val_joiner}{exp_val}"
175
176    # print the message
177    print(
178        msg,
179        file=sys.stderr,
180    )
181
182    # return the expression itself
183    return exp

Call dbg with any variable or expression.

Calling dbg prints to stderr the current filename and line number, along with the source text of the passed expression and the value it evaluates to:

    from muutils.dbg import dbg

    a = 2
    b = 5

    dbg(a+b)

    def square(x: int) -> int:
            return x * x

    dbg(square(a))
DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS: Dict[str, Union[str, int, bool]] = {'fmt': 'unicode', 'precision': 2, 'stats': True, 'shape': True, 'dtype': True, 'device': True, 'requires_grad': True, 'sparkline': True, 'sparkline_bins': 7, 'sparkline_logy': False, 'colored': True, 'eq_char': '='}
DBG_TENSOR_VAL_JOINER: str = ': '
def tensor_info(tensor: Any) -> str:
def tensor_info(tensor: typing.Any) -> str:
    """Summarize `tensor` as a string, for use as `dbg_tensor`'s formatter.

    Delegates to `muutils.tensor_info.array_summary`, passing the keyword
    arguments in `DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS`. The import happens at
    call time, not at module import.
    """
    from muutils.tensor_info import array_summary as _summarize

    summary_options = DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS
    return _summarize(tensor, **summary_options)
dbg_tensor = functools.partial(<function dbg>, formatter=<function tensor_info>, val_joiner=': ')