Coverage for muutils/dbg.py: 91%

58 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

1""" 

2 

3this code is based on an implementation of the Rust builtin `dbg!` for Python, originally from 

4https://github.com/tylerwince/pydbg/blob/master/pydbg.py 

5although it has been significantly modified 

6 

7licensed under MIT: 

8 

9Copyright (c) 2019 Tyler Wince 

10 

11Permission is hereby granted, free of charge, to any person obtaining a copy 

12of this software and associated documentation files (the "Software"), to deal 

13in the Software without restriction, including without limitation the rights 

14to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

15copies of the Software, and to permit persons to whom the Software is 

16furnished to do so, subject to the following conditions: 

17 

18The above copyright notice and this permission notice shall be included in 

19all copies or substantial portions of the Software. 

20 

21THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

22IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

23FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

24AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

25LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

26OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 

27THE SOFTWARE. 

28 

29""" 

30 

31from __future__ import annotations 

32 

33import os 

34import inspect 

35import sys 

36import typing 

37from pathlib import Path 

38import functools 

39 

# type defs
# generic type of the value handed to `dbg`; `dbg` returns it unchanged
_ExpType = typing.TypeVar("_ExpType")


class _NoExpPassedSentinel:
    """Marker type meaning "`dbg` was called without an expression".

    A dedicated class (rather than ``None``) lets ``dbg(None)`` be told apart
    from a bare ``dbg()`` call via an identity check on the single instance.
    """


# the one and only sentinel instance; it is the default for `dbg`'s `exp`
_NoExpPassed = _NoExpPassedSentinel()

52 

# global variables
# working directory, captured once when this module is imported
_CWD: Path = Path.cwd().absolute()
# number of bare `dbg()` calls (no expression) made so far
_COUNTER: int = 0

# configuration
# "relative": files under `_CWD` are shown relative to it, others absolutely
# "absolute": every file is shown as an absolute POSIX path
PATH_MODE: typing.Literal["relative", "absolute"] = "relative"


# path processing
def _process_path(path: Path) -> str:
    """Format ``path`` for display according to the global ``PATH_MODE``.

    # Parameters:
     - `path : Path`
        path of a source file (may be relative; it is made absolute first)

    # Returns:
     - `str`
        POSIX-style path: relative to `_CWD` when ``PATH_MODE`` is
        ``"relative"`` and the file lies inside `_CWD`, absolute otherwise

    # Raises:
     - `ValueError` : if ``PATH_MODE`` is neither ``"relative"`` nor ``"absolute"``
    """
    path_abs: Path = path.absolute()
    if PATH_MODE == "absolute":
        return path_abs.as_posix()
    if PATH_MODE == "relative":
        try:
            # only paths inside the working directory are shortened; stripping
            # some other shared ancestor would produce a "relative" path that
            # does not actually resolve from `_CWD`
            return path_abs.relative_to(_CWD).as_posix()
        except ValueError:
            # outside `_CWD` (or on a different drive): show it absolutely
            return path_abs.as_posix()
    raise ValueError(
        f"PATH_MODE must be either 'relative' or 'absolute', got {PATH_MODE!r}"
    )

77 

78 

def _extract_call_expression(line: str) -> str:
    """Extract the argument text of the ``dbg``-style call on a source line.

    The call is opened by the first ``(`` after the first occurrence of
    ``"dbg"`` on the line (so both ``dbg(`` and ``dbg_tensor(`` work) and
    closed by its matching ``)``. Parentheses inside simple string literals
    and anything after a ``#`` comment marker are ignored. If the call is not
    closed on this line (a multi-line call), the rest of the line is returned
    with trailing whitespace removed.

    # Parameters:
     - `line : str`
        a single line of source code, possibly ending in a newline

    # Returns:
     - `str`
        the text between the call's parentheses (may be empty)
    """
    anchor: int = max(line.find("dbg"), 0)
    open_idx: int = line.find("(", anchor)
    if open_idx == -1:
        # no call parenthesis on this line: show the whole line
        return line.strip()

    start: int = open_idx + 1
    depth: int = 0
    quote: typing.Optional[str] = None
    idx: int = open_idx
    while idx < len(line):
        char: str = line[idx]
        if quote is not None:
            if char == "\\":
                idx += 1  # skip the escaped character
            elif char == quote:
                quote = None
        elif char in "'\"":
            quote = char
        elif char == "#":
            break  # the remainder of the line is a comment
        elif char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                return line[start:idx]
        idx += 1

    # the call continues on following line(s): keep what is on this one
    return line[start:idx].rstrip()


# actual dbg function
@typing.overload
def dbg() -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _NoExpPassedSentinel,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = " = ",
) -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _ExpType,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = " = ",
) -> _ExpType: ...
def dbg(
    exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = " = ",
) -> typing.Union[_ExpType, _NoExpPassedSentinel]:
    """Call dbg with any variable or expression.

    Calling dbg will print to stderr the current filename and lineno,
    as well as the passed expression and what the expression evaluates to:

        from muutils.dbg import dbg

        a = 2
        b = 5

        dbg(a+b)

        def square(x: int) -> int:
            return x * x

        dbg(square(a))

    # Parameters:
     - `exp : _ExpType`
        value to show; when omitted, only the location and a running call
        counter are printed
     - `formatter : Callable[[Any], str] | None`
        converts the value to text (defaults to `repr`)
     - `val_joiner : str`
        separator between the expression text and the formatted value
        (defaults to `" = "`)

    # Returns:
     - `exp`, unchanged (the sentinel when called without an argument), so a
       call can be wrapped around any expression inline

    The location and expression are reported as ``"unknown"`` if no calling
    frame whose source line mentions ``dbg`` can be found.
    """
    global _COUNTER

    # get the context: walk outward from our caller to the first frame whose
    # current source line mentions "dbg". Frames are visited lazily rather
    # than via `inspect.stack()`, which reads source context for every frame.
    fname: str = "unknown"
    line_exp: str = "unknown"
    this_frame = inspect.currentframe()
    frame = this_frame.f_back if this_frame is not None else None
    try:
        while frame is not None:
            info = inspect.getframeinfo(frame, context=1)
            # `code_context` may be None or empty when no source is available
            if info.code_context:
                line: str = info.code_context[0]
                if "dbg" in line:
                    fname = f"{_process_path(Path(info.filename))}:{info.lineno}"
                    line_exp = _extract_call_expression(line)
                    break
            frame = frame.f_back
    finally:
        # drop frame references to avoid reference cycles, as recommended
        # by the `inspect` documentation
        del this_frame, frame

    # assemble the message
    msg: str
    if exp is _NoExpPassed:
        # no expression passed: show location and counter value
        msg = f"[ {fname} ] (dbg {_COUNTER})"
        _COUNTER += 1
    else:
        # expression passed: show location, expression text, and its value
        exp_val: str = formatter(exp) if formatter is not None else repr(exp)
        msg = f"[ {fname} ] {line_exp}{val_joiner}{exp_val}"

    print(msg, file=sys.stderr)

    # return the expression itself so dbg can be used inline
    return exp

156 

157 

# formatted `dbg_*` functions with their helpers

# keyword arguments that `tensor_info` forwards to `array_summary`; since the
# dict is read at call time, mutating it changes how `dbg_tensor` renders values
DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS: typing.Dict[str, typing.Union[bool, int, str]] = {
    "fmt": "unicode",
    "precision": 2,
    "stats": True,
    "shape": True,
    "dtype": True,
    "device": True,
    "requires_grad": True,
    "sparkline": True,
    "sparkline_bins": 7,
    "sparkline_logy": False,
    "colored": True,
    "eq_char": "=",
}

176 

177 

def tensor_info(tensor: typing.Any) -> str:
    """Render ``tensor`` as text via `muutils.tensor_info.array_summary`.

    The keyword arguments come from ``DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS``,
    looked up on every call. This is the formatter used by `dbg_tensor`.
    """
    # imported inside the function, presumably to keep importing this module
    # cheap — NOTE(review): confirm that is the intent
    from muutils.tensor_info import array_summary

    summary_kwargs = DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS
    summary: str = array_summary(tensor, **summary_kwargs)
    return summary

182 

183 

# `dbg` preconfigured for arrays/tensors: values are rendered by `tensor_info`
# and joined to the expression text with ": " instead of the default " = "
dbg_tensor = functools.partial(
    dbg,
    formatter=tensor_info,
    val_joiner=": ",
)