docs for muutils v0.8.5
View Source on GitHub

muutils.dictmagic

making working with dictionaries easier

  • DefaulterDict: like a defaultdict, but default_factory is passed the key as an argument
  • various methods for working with dotlist-nested dicts, converting to and from them
  • condense_nested_dicts: condense a nested dict, by condensing numeric or matching keys with matching values to ranges
  • condense_tensor_dict: convert a dictionary of tensors to a dictionary of shapes
  • kwargs_to_nested_dict: given kwargs from fire, convert them to a nested dict

  1"""making working with dictionaries easier
  2
  3- `DefaulterDict`: like a defaultdict, but default_factory is passed the key as an argument
  4- various methods for working with dotlist-nested dicts, converting to and from them
  5- `condense_nested_dicts`: condense a nested dict, by condensing numeric or matching keys with matching values to ranges
  6- `condense_tensor_dict`: convert a dictionary of tensors to a dictionary of shapes
  7- `kwargs_to_nested_dict`: given kwargs from fire, convert them to a nested dict
  8"""
  9
 10from __future__ import annotations
 11
 12import typing
 13import warnings
 14from collections import defaultdict
 15from typing import (
 16    Any,
 17    Callable,
 18    Generic,
 19    Hashable,
 20    Iterable,
 21    Literal,
 22    Optional,
 23    TypeVar,
 24    Union,
 25)
 26
 27from muutils.errormode import ErrorMode
 28
 29_KT = TypeVar("_KT")
 30_VT = TypeVar("_VT")
 31
 32
 33class DefaulterDict(typing.Dict[_KT, _VT], Generic[_KT, _VT]):
 34    """like a defaultdict, but default_factory is passed the key as an argument"""
 35
 36    def __init__(self, default_factory: Callable[[_KT], _VT], *args, **kwargs):
 37        if args:
 38            raise TypeError(
 39                f"DefaulterDict does not support positional arguments: *args = {args}"
 40            )
 41        super().__init__(**kwargs)
 42        self.default_factory: Callable[[_KT], _VT] = default_factory
 43
 44    def __getitem__(self, k: _KT) -> _VT:
 45        if k in self:
 46            return dict.__getitem__(self, k)
 47        else:
 48            v: _VT = self.default_factory(k)
 49            dict.__setitem__(self, k, v)
 50            return v
 51
 52
 53def _recursive_defaultdict_ctor() -> defaultdict:
 54    return defaultdict(_recursive_defaultdict_ctor)
 55
 56
 57def defaultdict_to_dict_recursive(dd: Union[defaultdict, DefaulterDict]) -> dict:
 58    """Convert a defaultdict or DefaulterDict to a normal dict, recursively"""
 59    return {
 60        key: (
 61            defaultdict_to_dict_recursive(value)
 62            if isinstance(value, (defaultdict, DefaulterDict))
 63            else value
 64        )
 65        for key, value in dd.items()
 66    }
 67
 68
 69def dotlist_to_nested_dict(
 70    dot_dict: typing.Dict[str, Any], sep: str = "."
 71) -> typing.Dict[str, Any]:
 72    """Convert a dict with dot-separated keys to a nested dict
 73
 74    Example:
 75
 76        >>> dotlist_to_nested_dict({'a.b.c': 1, 'a.b.d': 2, 'a.e': 3})
 77        {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
 78    """
 79    nested_dict: defaultdict = _recursive_defaultdict_ctor()
 80    for key, value in dot_dict.items():
 81        if not isinstance(key, str):
 82            raise TypeError(f"key must be a string, got {type(key)}")
 83        keys: list[str] = key.split(sep)
 84        current: defaultdict = nested_dict
 85        # iterate over the keys except the last one
 86        for sub_key in keys[:-1]:
 87            current = current[sub_key]
 88        current[keys[-1]] = value
 89    return defaultdict_to_dict_recursive(nested_dict)
 90
 91
 92def nested_dict_to_dotlist(
 93    nested_dict: typing.Dict[str, Any],
 94    sep: str = ".",
 95    allow_lists: bool = False,
 96) -> dict[str, Any]:
 97    def _recurse(current: Any, parent_key: str = "") -> typing.Dict[str, Any]:
 98        items: dict = dict()
 99
100        new_key: str
101        if isinstance(current, dict):
102            # dict case
103            if not current and parent_key:
104                items[parent_key] = current
105            else:
106                for k, v in current.items():
107                    new_key = f"{parent_key}{sep}{k}" if parent_key else k
108                    items.update(_recurse(v, new_key))
109
110        elif allow_lists and isinstance(current, list):
111            # list case
112            for i, item in enumerate(current):
113                new_key = f"{parent_key}{sep}{i}" if parent_key else str(i)
114                items.update(_recurse(item, new_key))
115
116        else:
117            # anything else (write value)
118            items[parent_key] = current
119
120        return items
121
122    return _recurse(nested_dict)
123
124
def update_with_nested_dict(
    original: dict[str, Any],
    update: dict[str, Any],
) -> dict[str, Any]:
    """Recursively merge `update` into `original`

    where both sides hold a dict under the same key, the merge descends into
    that sub-dict; in every other case the value from `update` replaces (or
    adds) the entry in `original`.

    Example:
    >>> update_with_nested_dict({'a': {'b': 1}, "c": -1}, {'a': {"b": 2}})
    {'a': {'b': 2}, 'c': -1}

    # Arguments
    - `original: dict[str, Any]`
        the dict to update (modified in-place)
    - `update: dict[str, Any]`
        the dict whose entries are merged in; not modified by this call

    # Returns
    - `dict`
        `original` itself, after the update
    """
    for key, new_value in update.items():
        both_dicts: bool = (
            key in original
            and isinstance(original[key], dict)
            and isinstance(new_value, dict)
        )
        if both_dicts:
            update_with_nested_dict(original[key], new_value)
        else:
            original[key] = new_value

    return original
155
156
def kwargs_to_nested_dict(
    kwargs_dict: dict[str, Any],
    sep: str = ".",
    strip_prefix: Optional[str] = None,
    when_unknown_prefix: Union[ErrorMode, str] = ErrorMode.WARN,
    transform_key: Optional[Callable[[str], str]] = None,
) -> dict[str, Any]:
    """given kwargs from fire, convert them to a nested dict

    each key is optionally stripped of `strip_prefix`, then optionally passed
    through `transform_key`, and finally the flat `sep`-separated keys are
    nested via `dotlist_to_nested_dict`.

    a key that does not start with `strip_prefix` is reported through
    `when_unknown_prefix` (warn by default; can also be set to raise an error or
    ignore it). if reporting does not raise, the key is kept without stripping.

    Example:
    ```python
    def main(**kwargs):
        print(kwargs_to_nested_dict(kwargs))
    fire.Fire(main)
    ```
    running the above script will give:
    ```bash
    $ python test.py --a.b.c=1 --a.b.d=2 --a.e=3
    {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
    ```

    # Arguments
    - `kwargs_dict: dict[str, Any]`
        the kwargs dict to convert
    - `sep: str = "."`
        the separator to use for nested keys
    - `strip_prefix: Optional[str] = None`
        if not None, this prefix is expected on every key and removed from it
    - `when_unknown_prefix: ErrorMode = ErrorMode.WARN`
        what to do when a key lacks `strip_prefix` (an `ErrorMode` or its string name)
    - `transform_key: Callable[[str], str] | None = None`
        a function to apply to each key before adding it to the dict (applied after stripping the prefix)
    """
    mode = ErrorMode.from_any(when_unknown_prefix)
    flat: dict[str, Any] = {}
    for raw_key, value in kwargs_dict.items():
        key: str = raw_key
        if strip_prefix is not None:
            if key.startswith(strip_prefix):
                key = key[len(strip_prefix) :]
            else:
                mode.process(
                    f"key '{key}' does not start with '{strip_prefix}'",
                    except_cls=ValueError,
                )

        if transform_key is not None:
            key = transform_key(key)

        flat[key] = value

    return dotlist_to_nested_dict(flat, sep=sep)
212
213
def is_numeric_consecutive(lst: list[str]) -> bool:
    """Check if the list of keys is numeric and consecutive.

    returns `True` if every element is accepted by `int()` and the resulting
    integers, once sorted, form an unbroken run with no gaps or duplicates
    (e.g. `["3", "1", "2"]`). returns `False` for an empty list, or if any
    element cannot be converted with `int()`.
    """
    try:
        numbers: list[int] = [int(x) for x in lst]
    except (ValueError, TypeError):
        # non-numeric strings raise ValueError; objects `int()` does not accept
        # at all (e.g. None, tuples) raise TypeError
        return False
    if not numbers:
        return False
    return sorted(numbers) == list(range(min(numbers), max(numbers) + 1))
221
222
def condense_nested_dicts_numeric_keys(
    data: dict[str, Any],
) -> dict[str, Any]:
    """condense a nested dict, by condensing numeric keys with matching values to ranges

    recurses into nested dicts first. at each level, if *all* keys are accepted
    by `int()` and form a consecutive run, runs of adjacent keys with equal
    values are merged into a single `"[start-end]"` key. levels whose keys are
    not all consecutive integers are left as they are. non-dict input is
    returned unchanged.

    the input is not modified; a new dict is returned.

    # Examples:
    ```python
    >>> condense_nested_dicts_numeric_keys({'1': 1, '2': 1, '3': 1, '4': 2, '5': 2, '6': 2})
    {'[1-3]': 1, '[4-6]': 2}
    >>> condense_nested_dicts_numeric_keys({'1': {'1': 'a', '2': 'a'}, '2': 'b'})
    {'1': {'[1-2]': 'a'}, '2': 'b'}
    ```
    """
    if not isinstance(data, dict):
        return data

    # condense every sub-dict first, collecting the results into a new dict so
    # the caller's input (and its nested dicts) is never modified in place
    processed: dict[str, Any] = {
        key: condense_nested_dicts_numeric_keys(value) for key, value in data.items()
    }

    # only levels whose keys are all consecutive integers get condensed
    if not is_numeric_consecutive(list(processed.keys())):
        return processed

    keys: list[str] = sorted(processed.keys(), key=int)

    # output dict
    condensed_data: dict[str, Any] = {}

    # identify runs of adjacent keys with identical values and condense them
    i: int = 0
    while i < len(keys):
        j: int = i
        while j + 1 < len(keys) and processed[keys[j]] == processed[keys[j + 1]]:
            j += 1
        if j > i:  # found consecutive keys with identical values
            condensed_data[f"[{keys[i]}-{keys[j]}]"] = processed[keys[i]]
        else:
            condensed_data[keys[i]] = processed[keys[i]]
        i = j + 1

    return condensed_data
268
269
def condense_nested_dicts_matching_values(
    data: dict[str, Any],
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
) -> dict[str, Any]:
    """condense a nested dict, by condensing keys with matching values

    recurses into nested dicts first. at each level, keys whose (non-dict)
    values compare equal are merged into a single key of the form
    `"[key1, key2, ...]"`. nested dicts are never merged with each other.

    # Examples:
    ```python
    >>> condense_nested_dicts_matching_values({'a': 1, 'b': 1, 'c': 2})
    {'[a, b]': 1, 'c': 2}
    >>> condense_nested_dicts_matching_values({'x': {'a': 1, 'b': 1}, 'y': 2})
    {'x': {'[a, b]': 1}, 'y': 2}
    ```

    # Parameters:
     - `data : dict[str, Any]`
        data to process. non-dict input is returned unchanged
     - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        applied to values that are not hashable, to obtain a hashable
        representation used for grouping. note that the mapped value, not the
        original, is what appears in the output. if `None`, unhashable values
        are kept as-is and never merged
        (defaults to `None`)

    # Returns:
     - `dict[str, Any]`
        a new dict. entries that were not grouped (nested dicts, and unhashable
        values without a fallback) come first, followed by the grouped entries
    """
    if not isinstance(data, dict):
        return data

    # condense nested levels first
    data = {
        key: condense_nested_dicts_matching_values(value, val_condense_fallback_mapping)
        for key, value in data.items()
    }

    # group keys by (hashable) value; anything that can't be grouped is kept as-is
    values_grouped: defaultdict[Any, list[str]] = defaultdict(list)
    data_persist: dict[str, Any] = dict()
    for key, value in data.items():
        if isinstance(value, dict):
            data_persist[key] = value
            continue
        try:
            values_grouped[value].append(key)
        except TypeError:
            # unhashable value: use the fallback mapping to get a hashable representation
            if val_condense_fallback_mapping is not None:
                values_grouped[val_condense_fallback_mapping(value)].append(key)
            else:
                data_persist[key] = value

    condensed_data: dict[str, Any] = data_persist
    for value, keys in values_grouped.items():
        if len(keys) > 1:
            # keys are not guaranteed to be strings (e.g. int keys when numeric
            # condensing is disabled), so stringify them before joining
            merged_key: str = f"[{', '.join(str(k) for k in keys)}]"
            condensed_data[merged_key] = value
        else:
            condensed_data[keys[0]] = value

    return condensed_data
322
323
def condense_nested_dicts(
    data: dict[str, Any],
    condense_numeric_keys: bool = True,
    condense_matching_values: bool = True,
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
) -> dict[str, Any]:
    """condense a nested dict, by condensing numeric or matching keys with matching values to ranges

    runs `condense_nested_dicts_numeric_keys()` and then
    `condense_nested_dicts_matching_values()`, each only if enabled. with both
    passes disabled, `data` itself is returned.

    # NOTE: this process is not meant to be reversible, and is intended for pretty-printing and visualization purposes
    it's not reversible because types are lost to make the printing pretty

    # Parameters:
     - `data : dict[str, Any]`
        data to process
     - `condense_numeric_keys : bool`
        whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]")
        (defaults to `True`)
     - `condense_matching_values : bool`
        whether to condense keys with matching values
        (defaults to `True`)
     - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        passed to the matching-values pass, to map unhashable values to a hashable representation
        (defaults to `None`)
    """
    after_numeric: dict = (
        condense_nested_dicts_numeric_keys(data) if condense_numeric_keys else data
    )
    if not condense_matching_values:
        return after_numeric
    return condense_nested_dicts_matching_values(
        after_numeric, val_condense_fallback_mapping
    )
360
361
def tuple_dims_replace(
    t: tuple[int, ...], dims_names_map: Optional[dict[int, str]] = None
) -> tuple[Union[int, str], ...]:
    """replace dimension sizes in a shape tuple with names from `dims_names_map`

    each element of `t` that is a key of `dims_names_map` is swapped for the
    mapped name; other elements pass through. if `dims_names_map` is `None`,
    `t` itself is returned.
    """
    if dims_names_map is not None:
        lookup = dims_names_map.get
        return tuple(lookup(dim, dim) for dim in t)
    return t
369
370
# output formats accepted by `condense_tensor_dict`
TensorDictFormats = typing.Literal["dict", "json", "yaml", "yml"]
# name -> tensor mapping; torch/numpy appear only as string forward references,
# since this module does not import them
TensorDict = typing.Dict[str, "torch.Tensor|np.ndarray"]  # type: ignore[name-defined] # noqa: F821
# iterable of (name, tensor) pairs, e.g. from `.items()` or `named_parameters()`
TensorIterable = typing.Iterable[typing.Tuple[str, "torch.Tensor|np.ndarray"]]  # type: ignore[name-defined] # noqa: F821
374
375
376def _default_shapes_convert(x: tuple) -> str:
377    return str(x).replace('"', "").replace("'", "")
378
379
def condense_tensor_dict(
    data: TensorDict | TensorIterable,
    fmt: TensorDictFormats = "dict",
    *args,
    shapes_convert: Optional[Callable[[tuple], Any]] = _default_shapes_convert,
    drop_batch_dims: int = 0,
    sep: str = ".",
    dims_names_map: Optional[dict[int, str]] = None,
    condense_numeric_keys: bool = True,
    condense_matching_values: bool = True,
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
    return_format: Optional[TensorDictFormats] = None,
) -> Union[str, dict[str, str | tuple[int, ...]]]:
    """Convert a dictionary of tensors to a dictionary of shapes.

    by default, values are converted to strings of their shapes (for nice printing).
    If you want the actual shapes, set `shapes_convert = lambda x: x` or `shapes_convert = None`.

    the steps are: take `tuple(v.shape)` of every value, drop the first
    `drop_batch_dims` entries, optionally rename dimensions via `dims_names_map`,
    apply `shapes_convert`, nest the `sep`-separated keys, and finally condense
    the nested dict with `condense_nested_dicts`.

    # Parameters:
     - `data : dict[str, "torch.Tensor|np.ndarray"] | Iterable[tuple[str, "torch.Tensor|np.ndarray"]]`
        either a `TensorDict` mapping strings to tensors, or a `TensorIterable` of (key, tensor) pairs (like you might get from `dict().items()`). only the `.shape` attribute of each value is used
     - `fmt : TensorDictFormats`
        format to return the result in -- either a dict, or dump to json/yaml directly for pretty printing. case-insensitive. yaml output requires PyYAML
        (defaults to `'dict'`)
     - `shapes_convert : Callable[[tuple], Any] | None`
        conversion of a shape tuple to a string or other format. `None` means identity (keep the tuple)
        (defaults to turning it into a string and removing quotes: `lambda x: str(x).replace('"', '').replace("'", '')`)
     - `drop_batch_dims : int`
        number of leading dimensions to drop from the shape
        (defaults to `0`)
     - `sep : str`
        separator to use for nested keys
        (defaults to `'.'`)
     - `dims_names_map : dict[int, str] | None`
        convert certain dimension values in shape. not perfect, can be buggy
        (defaults to `None`)
     - `condense_numeric_keys : bool`
        whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]"), passed on to `condense_nested_dicts`
        (defaults to `True`)
     - `condense_matching_values : bool`
        whether to condense keys with matching values, passed on to `condense_nested_dicts`
        (defaults to `True`)
     - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        a function to apply to each value before adding it to the dict (if it's not hashable), passed on to `condense_nested_dicts`
        (defaults to `None`)
     - `return_format : TensorDictFormats | None`
        deprecated alias for `fmt`; if given, it overrides `fmt` and emits a `DeprecationWarning`

    # Returns:
     - `str|dict[str, str|tuple[int, ...]]`
        a dict if `fmt='dict'`, a string for `json` or `yaml`/`yml` output

    # Examples:
    ```python
    >>> model = transformer_lens.HookedTransformer.from_pretrained("gpt2")
    >>> print(condense_tensor_dict(model.named_parameters(), fmt='yaml'))
    ```
    ```yaml
    embed:
      W_E: (50257, 768)
    pos_embed:
      W_pos: (1024, 768)
    blocks:
      '[0-11]':
        attn:
          '[W_Q, W_K, W_V]': (12, 768, 64)
          W_O: (12, 64, 768)
          '[b_Q, b_K, b_V]': (12, 64)
          b_O: (768,)
        mlp:
          W_in: (768, 3072)
          b_in: (3072,)
          W_out: (3072, 768)
          b_out: (768,)
    unembed:
      W_U: (768, 50257)
      b_U: (50257,)
    ```

    # Raises:
     - `TypeError` : if extra positional arguments are passed after `fmt`
     - `ValueError` : if `fmt` is not one of 'dict', 'json', 'yaml', or 'yml', or if you try to use 'yaml' output without having PyYAML installed
    """

    # handle arg processing:
    # ----------------------------------------------------------------------
    # `*args` only exists to make every parameter after `fmt` keyword-only.
    # raise explicitly instead of `assert`, which is stripped under `python -O`
    if args:
        raise TypeError(f"unexpected positional args: {args}")

    # handle legacy return_format (it overrides `fmt`)
    if return_format is not None:
        warnings.warn(
            "return_format is deprecated, use fmt instead",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller, not this function
        )
        fmt = return_format

    # identity function for shapes_convert if not provided
    if shapes_convert is None:
        shapes_convert = lambda x: x  # noqa: E731

    # accept both mappings and iterables of (key, tensor) pairs
    data_items: "Iterable[tuple[str, Union[torch.Tensor,np.ndarray]]]" = (  # type: ignore # noqa: F821
        data.items() if hasattr(data, "items") and callable(data.items) else data  # type: ignore
    )

    # get shapes
    data_shapes: dict[str, Union[str, tuple[int, ...]]] = {
        k: shapes_convert(
            tuple_dims_replace(
                tuple(v.shape)[drop_batch_dims:],
                dims_names_map,
            )
        )
        for k, v in data_items
    }

    # nest the dict
    data_nested: dict[str, Any] = dotlist_to_nested_dict(data_shapes, sep=sep)

    # condense the nested dict
    data_condensed: dict[str, Union[str, tuple[int, ...]]] = condense_nested_dicts(
        data=data_nested,
        condense_numeric_keys=condense_numeric_keys,
        condense_matching_values=condense_matching_values,
        val_condense_fallback_mapping=val_condense_fallback_mapping,
    )

    # return in the specified format
    fmt_lower: str = fmt.lower()
    if fmt_lower == "dict":
        return data_condensed
    elif fmt_lower == "json":
        import json

        return json.dumps(data_condensed, indent=2)
    elif fmt_lower in ("yaml", "yml"):
        # only the import can raise ImportError; keep the try body minimal
        try:
            import yaml  # type: ignore[import-untyped]
        except ImportError as e:
            raise ValueError("PyYAML is required for YAML output") from e
        # keep the dict's insertion order instead of sorting keys
        return yaml.dump(data_condensed, sort_keys=False)
    else:
        raise ValueError(f"Invalid return format: {fmt}")

class DefaulterDict(typing.Dict[~_KT, ~_VT], typing.Generic[~_KT, ~_VT]):
class DefaulterDict(typing.Dict[_KT, _VT], Generic[_KT, _VT]):
    """a dict that fills in missing keys on lookup, like `collections.defaultdict`,
    except that `default_factory` receives the missing key as its argument

    only keyword arguments are accepted as initial contents; any positional
    argument after `default_factory` raises `TypeError`. only `d[k]` lookups
    trigger the factory (methods like `get` are inherited unchanged from `dict`).
    """

    def __init__(self, default_factory: Callable[[_KT], _VT], *args, **kwargs):
        if len(args) > 0:
            raise TypeError(
                f"DefaulterDict does not support positional arguments: *args = {args}"
            )
        # initial contents come only from keyword arguments
        super().__init__(**kwargs)
        self.default_factory: Callable[[_KT], _VT] = default_factory

    def __getitem__(self, k: _KT) -> _VT:
        # fast path: the key is already present
        try:
            return dict.__getitem__(self, k)
        except KeyError:
            pass
        # miss: build the value from the key, store it, and hand it back
        created: _VT = self.default_factory(k)
        dict.__setitem__(self, k, created)
        return created

like a defaultdict, but default_factory is passed the key as an argument

default_factory: Callable[[~_KT], ~_VT]
Inherited Members
builtins.dict
get
setdefault
pop
popitem
keys
items
values
update
fromkeys
clear
copy
def defaultdict_to_dict_recursive( dd: Union[collections.defaultdict, DefaulterDict]) -> dict:
def defaultdict_to_dict_recursive(dd: Union[defaultdict, DefaulterDict]) -> dict:
    """Convert a defaultdict or DefaulterDict to a normal dict, recursively

    values that are themselves `defaultdict` or `DefaulterDict` instances are
    converted too; every other value (including plain dicts) is placed into
    the result unchanged, by reference
    """
    result: dict = {}
    for key, value in dd.items():
        if isinstance(value, (defaultdict, DefaulterDict)):
            value = defaultdict_to_dict_recursive(value)
        result[key] = value
    return result

Convert a defaultdict or DefaulterDict to a normal dict, recursively

def dotlist_to_nested_dict(dot_dict: Dict[str, Any], sep: str = '.') -> Dict[str, Any]:
def dotlist_to_nested_dict(
    dot_dict: typing.Dict[str, Any], sep: str = "."
) -> typing.Dict[str, Any]:
    """Convert a dict with dot-separated keys to a nested dict

    each key is split on `sep`; all but the last component name intermediate
    dicts (created as needed), and the last component receives the value.

    Example:

        >>> dotlist_to_nested_dict({'a.b.c': 1, 'a.b.d': 2, 'a.e': 3})
        {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}

    # Raises:
     - `TypeError` : if any key is not a string
    """
    root: defaultdict = _recursive_defaultdict_ctor()
    for dotted_key, leaf_value in dot_dict.items():
        if not isinstance(dotted_key, str):
            raise TypeError(f"key must be a string, got {type(dotted_key)}")
        *parents, last = dotted_key.split(sep)
        node: defaultdict = root
        # walk (auto-creating) the intermediate levels
        for part in parents:
            node = node[part]
        node[last] = leaf_value
    return defaultdict_to_dict_recursive(root)

Convert a dict with dot-separated keys to a nested dict

Example:

>>> dotlist_to_nested_dict({'a.b.c': 1, 'a.b.d': 2, 'a.e': 3})
{'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
def nested_dict_to_dotlist( nested_dict: Dict[str, Any], sep: str = '.', allow_lists: bool = False) -> dict[str, typing.Any]:
 93def nested_dict_to_dotlist(
 94    nested_dict: typing.Dict[str, Any],
 95    sep: str = ".",
 96    allow_lists: bool = False,
 97) -> dict[str, Any]:
 98    def _recurse(current: Any, parent_key: str = "") -> typing.Dict[str, Any]:
 99        items: dict = dict()
100
101        new_key: str
102        if isinstance(current, dict):
103            # dict case
104            if not current and parent_key:
105                items[parent_key] = current
106            else:
107                for k, v in current.items():
108                    new_key = f"{parent_key}{sep}{k}" if parent_key else k
109                    items.update(_recurse(v, new_key))
110
111        elif allow_lists and isinstance(current, list):
112            # list case
113            for i, item in enumerate(current):
114                new_key = f"{parent_key}{sep}{i}" if parent_key else str(i)
115                items.update(_recurse(item, new_key))
116
117        else:
118            # anything else (write value)
119            items[parent_key] = current
120
121        return items
122
123    return _recurse(nested_dict)
def update_with_nested_dict( original: dict[str, typing.Any], update: dict[str, typing.Any]) -> dict[str, typing.Any]:
def update_with_nested_dict(
    original: dict[str, Any],
    update: dict[str, Any],
) -> dict[str, Any]:
    """Recursively merge `update` into `original`

    where both sides hold a dict under the same key, the merge descends into
    that sub-dict; in every other case the value from `update` replaces (or
    adds) the entry in `original`.

    Example:
    >>> update_with_nested_dict({'a': {'b': 1}, "c": -1}, {'a': {"b": 2}})
    {'a': {'b': 2}, 'c': -1}

    # Arguments
    - `original: dict[str, Any]`
        the dict to update (modified in-place)
    - `update: dict[str, Any]`
        the dict whose entries are merged in; not modified by this call

    # Returns
    - `dict`
        `original` itself, after the update
    """
    for key, new_value in update.items():
        both_dicts: bool = (
            key in original
            and isinstance(original[key], dict)
            and isinstance(new_value, dict)
        )
        if both_dicts:
            update_with_nested_dict(original[key], new_value)
        else:
            original[key] = new_value

    return original

Update a dict with a nested dict

Example:

>>> update_with_nested_dict({'a': {'b': 1}, "c": -1}, {'a': {"b": 2}})
{'a': {'b': 2}, 'c': -1}

Arguments

  • original: dict[str, Any] the dict to update (will be modified in-place)
  • update: dict[str, Any] the dict to update with

Returns

  • dict the updated dict
def kwargs_to_nested_dict( kwargs_dict: dict[str, typing.Any], sep: str = '.', strip_prefix: Optional[str] = None, when_unknown_prefix: Union[muutils.errormode.ErrorMode, str] = ErrorMode.Warn, transform_key: Optional[Callable[[str], str]] = None) -> dict[str, typing.Any]:
def kwargs_to_nested_dict(
    kwargs_dict: dict[str, Any],
    sep: str = ".",
    strip_prefix: Optional[str] = None,
    when_unknown_prefix: Union[ErrorMode, str] = ErrorMode.WARN,
    transform_key: Optional[Callable[[str], str]] = None,
) -> dict[str, Any]:
    """given kwargs from fire, convert them to a nested dict

    each key is optionally stripped of `strip_prefix`, then optionally passed
    through `transform_key`, and finally the flat `sep`-separated keys are
    nested via `dotlist_to_nested_dict`.

    a key that does not start with `strip_prefix` is reported through
    `when_unknown_prefix` (warn by default; can also be set to raise an error or
    ignore it). if reporting does not raise, the key is kept without stripping.

    Example:
    ```python
    def main(**kwargs):
        print(kwargs_to_nested_dict(kwargs))
    fire.Fire(main)
    ```
    running the above script will give:
    ```bash
    $ python test.py --a.b.c=1 --a.b.d=2 --a.e=3
    {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
    ```

    # Arguments
    - `kwargs_dict: dict[str, Any]`
        the kwargs dict to convert
    - `sep: str = "."`
        the separator to use for nested keys
    - `strip_prefix: Optional[str] = None`
        if not None, this prefix is expected on every key and removed from it
    - `when_unknown_prefix: ErrorMode = ErrorMode.WARN`
        what to do when a key lacks `strip_prefix` (an `ErrorMode` or its string name)
    - `transform_key: Callable[[str], str] | None = None`
        a function to apply to each key before adding it to the dict (applied after stripping the prefix)
    """
    mode = ErrorMode.from_any(when_unknown_prefix)
    flat: dict[str, Any] = {}
    for raw_key, value in kwargs_dict.items():
        key: str = raw_key
        if strip_prefix is not None:
            if key.startswith(strip_prefix):
                key = key[len(strip_prefix) :]
            else:
                mode.process(
                    f"key '{key}' does not start with '{strip_prefix}'",
                    except_cls=ValueError,
                )

        if transform_key is not None:
            key = transform_key(key)

        flat[key] = value

    return dotlist_to_nested_dict(flat, sep=sep)

given kwargs from fire, convert them to a nested dict

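if strip_prefix is not None, every key is expected to start with the prefix, which is then stripped. by default, a warning is issued for a key without the prefix (and, if that does not raise, the key is kept unstripped), but this can be set to raise an error or ignore it: when_unknown_prefix: ErrorMode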

Example:

def main(**kwargs):
    print(kwargs_to_nested_dict(kwargs))
fire.Fire(main)

running the above script will give:

$ python test.py --a.b.c=1 --a.b.d=2 --a.e=3
{'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}

Arguments

  • kwargs_dict: dict[str, Any] the kwargs dict to convert
  • sep: str = "." the separator to use for nested keys
  • strip_prefix: Optional[str] = None if not None, then all keys must start with this prefix
  • when_unknown_prefix: ErrorMode = ErrorMode.WARN what to do when an unknown prefix is found
  • transform_key: Callable[[str], str] | None = None a function to apply to each key before adding it to the dict (applied after stripping the prefix)
def is_numeric_consecutive(lst: list[str]) -> bool:
def is_numeric_consecutive(lst: list[str]) -> bool:
    """Check if the list of keys is numeric and consecutive.

    returns `True` if every element is accepted by `int()` and the resulting
    integers, once sorted, form an unbroken run with no gaps or duplicates
    (e.g. `["3", "1", "2"]`). returns `False` for an empty list, or if any
    element cannot be converted with `int()`.
    """
    try:
        numbers: list[int] = [int(x) for x in lst]
    except (ValueError, TypeError):
        # non-numeric strings raise ValueError; objects `int()` does not accept
        # at all (e.g. None, tuples) raise TypeError
        return False
    if not numbers:
        return False
    return sorted(numbers) == list(range(min(numbers), max(numbers) + 1))

Check if the list of keys is numeric and consecutive.

def condense_nested_dicts_numeric_keys(data: dict[str, typing.Any]) -> dict[str, typing.Any]:
def condense_nested_dicts_numeric_keys(
    data: dict[str, Any],
) -> dict[str, Any]:
    """condense a nested dict, by condensing numeric keys with matching values to ranges

    recurses into nested dicts first. at each level, if *all* keys are accepted
    by `int()` and form a consecutive run, runs of adjacent keys with equal
    values are merged into a single `"[start-end]"` key. levels whose keys are
    not all consecutive integers are left as they are. non-dict input is
    returned unchanged.

    the input is not modified; a new dict is returned.

    # Examples:
    ```python
    >>> condense_nested_dicts_numeric_keys({'1': 1, '2': 1, '3': 1, '4': 2, '5': 2, '6': 2})
    {'[1-3]': 1, '[4-6]': 2}
    >>> condense_nested_dicts_numeric_keys({'1': {'1': 'a', '2': 'a'}, '2': 'b'})
    {'1': {'[1-2]': 'a'}, '2': 'b'}
    ```
    """
    if not isinstance(data, dict):
        return data

    # condense every sub-dict first, collecting the results into a new dict so
    # the caller's input (and its nested dicts) is never modified in place
    processed: dict[str, Any] = {
        key: condense_nested_dicts_numeric_keys(value) for key, value in data.items()
    }

    # only levels whose keys are all consecutive integers get condensed
    if not is_numeric_consecutive(list(processed.keys())):
        return processed

    keys: list[str] = sorted(processed.keys(), key=int)

    # output dict
    condensed_data: dict[str, Any] = {}

    # identify runs of adjacent keys with identical values and condense them
    i: int = 0
    while i < len(keys):
        j: int = i
        while j + 1 < len(keys) and processed[keys[j]] == processed[keys[j + 1]]:
            j += 1
        if j > i:  # found consecutive keys with identical values
            condensed_data[f"[{keys[i]}-{keys[j]}]"] = processed[keys[i]]
        else:
            condensed_data[keys[i]] = processed[keys[i]]
        i = j + 1

    return condensed_data

condense a nested dict, by condensing numeric keys with matching values to ranges

Examples:

>>> condense_nested_dicts_numeric_keys({'1': 1, '2': 1, '3': 1, '4': 2, '5': 2, '6': 2})
{'[1-3]': 1, '[4-6]': 2}
>>> condense_nested_dicts_numeric_keys({'1': {'1': 'a', '2': 'a'}, '2': 'b'})
{'1': {'[1-2]': 'a'}, '2': 'b'}
def condense_nested_dicts_matching_values(
    data: dict[str, Any],
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
) -> dict[str, Any]:
    """condense a nested dict, by condensing keys with matching values

    sub-dicts are condensed recursively first. at each level, non-dict values
    that are equal (by hash/equality) are merged under a single key of the form
    `"[key_a, key_b, ...]"`. entries holding dicts, and unhashable values when no
    fallback mapping is given, are kept as-is and placed first in the output;
    the grouped entries follow in order of first appearance.

    # Examples:
    ```python
    >>> condense_nested_dicts_matching_values({'a': 1, 'b': 1, 'c': 2})
    {'[a, b]': 1, 'c': 2}
    >>> condense_nested_dicts_matching_values({'a': [1], 'b': [1]}, str)
    {'[a, b]': '[1]'}
    ```

    # Parameters:
     - `data : dict[str, Any]`
        data to process; non-dict input is returned unchanged
     - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        applied to unhashable values to obtain a hashable stand-in used for grouping;
        the stand-in is what is stored in the output
        (defaults to `None`, in which case unhashable values are kept under their own key)

    # Returns:
     - `dict[str, Any]`
        a new dict with matching values condensed
    """

    # leaves pass straight through -- this is where recursion bottoms out
    if not isinstance(data, dict):
        return data

    children: dict[str, Any] = {
        k: condense_nested_dicts_matching_values(v, val_condense_fallback_mapping)
        for k, v in data.items()
    }

    output: dict[str, Any] = {}
    buckets: defaultdict[Any, list[str]] = defaultdict(list)
    for k, v in children.items():
        if isinstance(v, dict):
            output[k] = v
            continue
        try:
            # raises TypeError before inserting anything if `v` is unhashable
            buckets[v].append(k)
        except TypeError:
            if val_condense_fallback_mapping is None:
                output[k] = v
            else:
                buckets[val_condense_fallback_mapping(v)].append(k)

    for v, ks in buckets.items():
        merged: str = ks[0] if len(ks) == 1 else f"[{', '.join(ks)}]"
        output[merged] = v

    return output

condense a nested dict, by condensing keys with matching values

Examples:

>>> condense_nested_dicts_matching_values({'a': 1, 'b': 1, 'c': 2})
{'[a, b]': 1, 'c': 2}

Parameters:

  • data : dict[str, Any] data to process
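  • val_condense_fallback_mapping : Callable[[Any], Hashable] | None a function applied to each unhashable value to get a hashable stand-in, which is used for grouping and stored in the output in place of the value; if None, unhashable values are kept unchanged under their own key (defaults to None)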
def condense_nested_dicts(
    data: dict[str, Any],
    condense_numeric_keys: bool = True,
    condense_matching_values: bool = True,
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
) -> dict[str, Any]:
    """condense a nested dict, by condensing numeric or matching keys with matching values to ranges

    runs `condense_nested_dicts_numeric_keys()` and then
    `condense_nested_dicts_matching_values()`, each only if enabled.

    # NOTE: this process is not meant to be reversible, and is intended for pretty-printing and visualization purposes
    it's not reversible because types are lost to make the printing pretty

    # Parameters:
     - `data : dict[str, Any]`
        data to process
     - `condense_numeric_keys : bool`
        whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]")
       (defaults to `True`)
     - `condense_matching_values : bool`
        whether to condense keys with matching values
       (defaults to `True`)
     - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        passed through to `condense_nested_dicts_matching_values()`; maps unhashable values to a hashable stand-in
       (defaults to `None`)

    # Returns:
     - `dict[str, Any]`
        the condensed dict (`data` itself if both passes are disabled)
    """

    # assemble the enabled passes in a fixed order: numeric ranges first,
    # then merging of keys that share a value
    passes: list[Callable[[Any], Any]] = []
    if condense_numeric_keys:
        passes.append(condense_nested_dicts_numeric_keys)
    if condense_matching_values:
        passes.append(
            lambda d: condense_nested_dicts_matching_values(
                d, val_condense_fallback_mapping
            )
        )

    result: dict = data
    for condense_pass in passes:
        result = condense_pass(result)
    return result

condense a nested dict, by condensing numeric or matching keys with matching values to ranges

combines the functionality of condense_nested_dicts_numeric_keys() and condense_nested_dicts_matching_values()

NOTE: this process is not meant to be reversible, and is intended for pretty-printing and visualization purposes

it's not reversible because types are lost to make the printing pretty

Parameters:

  • data : dict[str, Any] data to process
  • condense_numeric_keys : bool whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]") (defaults to True)
  • condense_matching_values : bool whether to condense keys with matching values (defaults to True)
  • val_condense_fallback_mapping : Callable[[Any], Hashable] | None a function to apply to each value before adding it to the dict (if it's not hashable) (defaults to None)
def tuple_dims_replace(
    t: tuple[int, ...], dims_names_map: Optional[dict[int, str]] = None
) -> tuple[Union[int, str], ...]:
    """replace dimension sizes in a shape tuple with names

    each element of `t` that appears as a key in `dims_names_map` is replaced by
    the mapped name; other elements are kept. note the lookup is by dimension
    *value*, not position, so every dimension of a given size gets the same name.

    # Parameters:
     - `t : tuple[int, ...]`
        shape tuple to process
     - `dims_names_map : dict[int, str] | None`
        mapping from dimension size to display name
        (defaults to `None`, in which case `t` is returned as-is)

    # Returns:
     - `tuple[int | str, ...]`
        `t` itself when no map is given, otherwise a new tuple
    """
    if dims_names_map is not None:
        lookup = dims_names_map.get
        return tuple(lookup(dim, dim) for dim in t)
    return t
# type aliases used by `condense_tensor_dict`. the tensor type is given as a
# string forward reference so neither torch nor numpy needs to be importable
# (typing wraps the string in a `ForwardRef`; no explicit import is required)
TensorDict = typing.Dict[str, "torch.Tensor|np.ndarray"]  # type: ignore[name-defined] # noqa: F821
TensorIterable = typing.Iterable[typing.Tuple[str, "torch.Tensor|np.ndarray"]]  # type: ignore[name-defined] # noqa: F821
TensorDictFormats = typing.Literal["dict", "json", "yaml", "yml"]
def condense_tensor_dict(
    data: TensorDict | TensorIterable,
    fmt: TensorDictFormats = "dict",
    *args,
    shapes_convert: Optional[Callable[[tuple], Any]] = _default_shapes_convert,
    drop_batch_dims: int = 0,
    sep: str = ".",
    dims_names_map: Optional[dict[int, str]] = None,
    condense_numeric_keys: bool = True,
    condense_matching_values: bool = True,
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
    return_format: Optional[TensorDictFormats] = None,
) -> Union[str, dict[str, str | tuple[int, ...]]]:
    """Convert a dictionary of tensors to a dictionary of shapes.

    by default, values are converted to strings of their shapes (for nice printing).
    If you want the actual shapes, set `shapes_convert = lambda x: x` or `shapes_convert = None`.

    # Parameters:
     - `data : dict[str, "torch.Tensor|np.ndarray"] | Iterable[tuple[str, "torch.Tensor|np.ndarray"]]`
        either a `TensorDict` (a dict from strings to tensors) or a `TensorIterable`
        (an iterable of (key, tensor) pairs, like you might get from `dict().items()`).
        only the `.shape` attribute of each value is used
     - `fmt : TensorDictFormats`
        format to return the result in -- either a dict, or dump to json/yaml directly for pretty printing.
        case-insensitive. yaml output requires PyYAML.
        (defaults to `'dict'`)
     - `shapes_convert : Callable[[tuple], Any] | None`
        conversion of a shape tuple to a string or other format; `None` means identity
        (defaults to `_default_shapes_convert`, i.e. `str(x).replace('"', '').replace("'", '')`)
     - `drop_batch_dims : int`
        number of leading dimensions to drop from the shape
        (defaults to `0`)
     - `sep : str`
        separator to use for nested keys
        (defaults to `'.'`)
     - `dims_names_map : dict[int, str] | None`
        convert certain dimension values in shape. not perfect, can be buggy
        (defaults to `None`)
     - `condense_numeric_keys : bool`
        whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]"), passed on to `condense_nested_dicts`
        (defaults to `True`)
     - `condense_matching_values : bool`
        whether to condense keys with matching values, passed on to `condense_nested_dicts`
        (defaults to `True`)
     - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        a function to apply to each value before adding it to the dict (if it's not hashable), passed on to `condense_nested_dicts`
        (defaults to `None`)
     - `return_format : TensorDictFormats | None`
        deprecated legacy alias for the `fmt` kwarg; emits a `DeprecationWarning` and overrides `fmt` when given

    # Returns:
     - `str|dict[str, str|tuple[int, ...]]`
        a dict if `fmt='dict'`, a string for `json` or `yaml` output

    # Examples:
    ```python
    >>> model = transformer_lens.HookedTransformer.from_pretrained("gpt2")
    >>> print(condense_tensor_dict(model.named_parameters(), fmt='yaml'))
    ```
    ```yaml
    embed:
      W_E: (50257, 768)
    pos_embed:
      W_pos: (1024, 768)
    blocks:
      '[0-11]':
        attn:
          '[W_Q, W_K, W_V]': (12, 768, 64)
          W_O: (12, 64, 768)
          '[b_Q, b_K, b_V]': (12, 64)
          b_O: (768,)
        mlp:
          W_in: (768, 3072)
          b_in: (3072,)
          W_out: (3072, 768)
          b_out: (768,)
    unembed:
      W_U: (768, 50257)
      b_U: (50257,)
    ```

    # Raises:
     - `TypeError` : if extra positional arguments are passed (everything after `fmt` is keyword-only)
     - `ValueError` : if `fmt` (or `return_format`) is not one of 'dict', 'json', 'yaml', or 'yml',
        or if you try to use yaml output without having PyYAML installed
    """

    # handle arg processing:
    # ----------------------------------------------------------------------
    # `*args` only exists to make every arg after `fmt` keyword-only. reject
    # stray positionals explicitly -- an `assert` would be stripped under `-O`
    if args:
        raise TypeError(f"unexpected positional args: {args}")

    # handle legacy return_format
    if return_format is not None:
        warnings.warn(
            "return_format is deprecated, use fmt instead",
            DeprecationWarning,
            stacklevel=2,  # point the warning at the caller, not at this module
        )
        fmt = return_format

    # validate the format up front, so a typo fails before a one-shot iterable
    # (e.g. `model.named_parameters()`) has been consumed
    fmt_lower: str = fmt.lower()
    if fmt_lower not in ("dict", "json", "yaml", "yml"):
        raise ValueError(f"Invalid return format: {fmt}")

    # identity function for shapes_convert if not provided
    if shapes_convert is None:
        shapes_convert = lambda x: x  # noqa: E731

    # accept either a mapping (use its `.items()`) or an iterable of (key, tensor) pairs
    data_items: "Iterable[tuple[str, Union[torch.Tensor,np.ndarray]]]" = (  # type: ignore # noqa: F821
        data.items() if hasattr(data, "items") and callable(data.items) else data  # type: ignore
    )

    # get shapes, dropping leading batch dims and naming known dimension sizes
    data_shapes: dict[str, Union[str, tuple[int, ...]]] = {
        k: shapes_convert(
            tuple_dims_replace(
                tuple(v.shape)[drop_batch_dims:],
                dims_names_map,
            )
        )
        for k, v in data_items
    }

    # nest the dict on `sep`, then condense it
    data_nested: dict[str, Any] = dotlist_to_nested_dict(data_shapes, sep=sep)
    data_condensed: dict[str, Union[str, tuple[int, ...]]] = condense_nested_dicts(
        data=data_nested,
        condense_numeric_keys=condense_numeric_keys,
        condense_matching_values=condense_matching_values,
        val_condense_fallback_mapping=val_condense_fallback_mapping,
    )

    # return in the specified format (already validated above)
    if fmt_lower == "dict":
        return data_condensed
    if fmt_lower == "json":
        import json

        return json.dumps(data_condensed, indent=2)

    # "yaml" / "yml": keep the `try` to the import only, so errors raised while
    # dumping are not misreported as a missing dependency
    try:
        import yaml  # type: ignore[import-untyped]
    except ImportError as e:
        raise ValueError("PyYAML is required for YAML output") from e
    return yaml.dump(data_condensed, sort_keys=False)

Convert a dictionary of tensors to a dictionary of shapes.

by default, values are converted to strings of their shapes (for nice printing). If you want the actual shapes, set shapes_convert = lambda x: x or shapes_convert = None.

Parameters:

  • data : dict[str, "torch.Tensor|np.ndarray"] | Iterable[tuple[str, "torch.Tensor|np.ndarray"]] either a TensorDict (a dict from strings to tensors) or a TensorIterable (an iterable of (key, tensor) pairs, like you might get from dict().items())
  • fmt : TensorDictFormats format to return the result in -- either a dict, or dump to json/yaml directly for pretty printing. will crash if yaml is not installed. (defaults to 'dict')
  • shapes_convert : Callable[[tuple], Any] conversion of a shape tuple to a string or other format (defaults to turning it into a string and removing quotes, i.e. lambda x: str(x).replace('"', '').replace("'", ''))
  • drop_batch_dims : int number of leading dimensions to drop from the shape (defaults to 0)
  • sep : str separator to use for nested keys (defaults to '.')
  • dims_names_map : dict[int, str] | None convert certain dimension values in shape. not perfect, can be buggy (defaults to None)
  • condense_numeric_keys : bool whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]"), passed on to condense_nested_dicts (defaults to True)
  • condense_matching_values : bool whether to condense keys with matching values, passed on to condense_nested_dicts (defaults to True)
  • val_condense_fallback_mapping : Callable[[Any], Hashable] | None a function to apply to each value before adding it to the dict (if it's not hashable), passed on to condense_nested_dicts (defaults to None)
  • return_format : TensorDictFormats | None deprecated legacy alias for the fmt kwarg; emits a DeprecationWarning and overrides fmt when given

Returns:

  • str|dict[str, str|tuple[int, ...]] a dict if fmt='dict', or a string for json or yaml output

Examples:

>>> model = transformer_lens.HookedTransformer.from_pretrained("gpt2")
>>> print(condense_tensor_dict(model.named_parameters(), fmt='yaml'))
embed:
  W_E: (50257, 768)
pos_embed:
  W_pos: (1024, 768)
blocks:
  '[0-11]':
    attn:
      '[W_Q, W_K, W_V]': (12, 768, 64)
      W_O: (12, 64, 768)
      '[b_Q, b_K, b_V]': (12, 64)
      b_O: (768,)
    mlp:
      W_in: (768, 3072)
      b_in: (3072,)
      W_out: (3072, 768)
      b_out: (768,)
unembed:
  W_U: (768, 50257)
  b_U: (50257,)

Raises:

  • ValueError : if fmt (or the legacy return_format) is not one of 'dict', 'json', 'yaml', or 'yml', or if you try to use yaml output without having PyYAML installed