# muutils.json_serialize

Submodule for serializing things to JSON in a recoverable way.

You can throw *any* object into `muutils.json_serialize.json_serialize` and it will return a `JSONitem`, meaning a `bool`, `int`, `float`, `str`, `None`, a list of `JSONitem`s, or a dict mapping to `JSONitem`s.

The goal of this is: if you just want to be able to store something as relatively human-readable JSON, and don't care as much about recovering it, you can throw it into `json_serialize` and it will just work. If you want to do so in a recoverable way, check out [`ZANJ`](https://github.com/mivanit/ZANJ).
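Conceptually, the default behavior resembles the following stdlib-only sketch (a hypothetical stand-in, not the library's actual handler chain):

```python
import dataclasses
import json
from typing import Any

def naive_json_serialize(obj: Any) -> Any:
    """hypothetical mini-version of the fallback logic behind `json_serialize`"""
    if isinstance(obj, (bool, int, float, str)) or obj is None:
        return obj  # already a valid JSONitem
    if isinstance(obj, dict):
        return {str(k): naive_json_serialize(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple, set)):
        return [naive_json_serialize(x) for x in obj]
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        return naive_json_serialize(dataclasses.asdict(obj))
    return repr(obj)  # last resort: a human-readable string

# anything goes in, something `json.dumps` accepts comes out
print(json.dumps(naive_json_serialize({"xs": (1, 2), "obj": object()})))
```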
It does this by looking through `DEFAULT_HANDLERS`: keep the object as-is if it's already valid JSON, then try to find a `.serialize()` method on the object, and then fall back to a number of special cases. You can add handlers by initializing a `JsonSerializer` object and passing a sequence of them to `handlers_pre`.
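The handler-chain dispatch can be sketched like this; `Handler`, `DEFAULT`, and `dispatch` are hypothetical stand-ins for the real `SerializerHandler` machinery (which also receives the serializer instance and an object path):

```python
from typing import Any, Callable, NamedTuple

class Handler(NamedTuple):
    check: Callable[[Any], bool]     # does this handler apply?
    serialize: Callable[[Any], Any]  # how to serialize if it does

DEFAULT: tuple[Handler, ...] = (
    # keep valid JSON primitives as-is
    Handler(lambda o: isinstance(o, (bool, int, float, str, type(None))), lambda o: o),
    # prefer an object's own `.serialize()` method
    Handler(lambda o: hasattr(o, "serialize"), lambda o: o.serialize()),
    # recurse into sequences
    Handler(lambda o: isinstance(o, (list, tuple)), lambda o: [dispatch(x) for x in o]),
)

def dispatch(obj: Any, handlers_pre: tuple[Handler, ...] = ()) -> Any:
    # handlers are tried in order, custom ones first
    for h in (*handlers_pre, *DEFAULT):
        if h.check(obj):
            return h.serialize(obj)
    raise ValueError(f"no handler found for object with {type(obj) = }")

# a custom pre-handler takes priority over the defaults:
complex_handler = Handler(
    lambda o: isinstance(o, complex),
    lambda o: {"re": o.real, "im": o.imag},
)
print(dispatch(3 + 4j, handlers_pre=(complex_handler,)))
```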
Additionally, `SerializableDataclass` is a special kind of dataclass where you specify how to serialize each field, and a `.serialize()` method is automatically added to the class. This is done by using the `serializable_dataclass` decorator, inheriting from `SerializableDataclass`, and using `serializable_field` in place of `dataclasses.field` when defining non-standard fields.
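The core of what this adds can be sketched with plain `dataclasses` (a minimal hypothetical stand-in; the real decorator also validates types, honors per-field serialization functions, and records a format key):

```python
import dataclasses
from typing import Any, Type, TypeVar

T = TypeVar("T")

def mini_serializable(cls: Type[T]) -> Type[T]:
    """hypothetical stand-in for `serializable_dataclass`: adds serialize/load"""
    cls = dataclasses.dataclass(cls)

    def serialize(self) -> dict[str, Any]:
        # dump each declared field into a plain dict
        return {f.name: getattr(self, f.name) for f in dataclasses.fields(self)}

    @classmethod
    def load(cls, data: dict[str, Any]) -> T:
        # pass only known field names to the constructor
        names = {f.name for f in dataclasses.fields(cls)}
        return cls(**{k: v for k, v in data.items() if k in names})

    cls.serialize = serialize
    cls.load = load
    return cls

@mini_serializable
class Point:
    x: int
    y: int

assert Point.load(Point(1, 2).serialize()) == Point(1, 2)
```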
This module plays nicely with and is a dependency of the [`ZANJ`](https://github.com/mivanit/ZANJ) library, which extends it to support saving things to disk more efficiently than plain JSON (arrays are saved as `.npy` files, for example) and to automatically detect how to load saved objects into their original classes.
```python
"""submodule for serializing things to json in a recoverable way

you can throw *any* object into `muutils.json_serialize.json_serialize`
and it will return a `JSONitem`, meaning a bool, int, float, str, None, list of `JSONitem`s, or a dict mapping to `JSONitem`.

The goal of this is if you want to just be able to store something as relatively human-readable JSON, and don't care as much about recovering it, you can throw it into `json_serialize` and it will just work. If you want to do so in a recoverable way, check out [`ZANJ`](https://github.com/mivanit/ZANJ).

it will do so by looking in `DEFAULT_HANDLERS`, which will keep it as-is if it's already valid, then try to find a `.serialize()` method on the object, and then have a bunch of special cases. You can add handlers by initializing a `JsonSerializer` object and passing a sequence of them to `handlers_pre`

additionally, `SerializableDataclass` is a special kind of dataclass where you specify how to serialize each field, and a `.serialize()` method is automatically added to the class. This is done by using the `serializable_dataclass` decorator, inheriting from `SerializableDataclass`, and `serializable_field` in place of `dataclasses.field` when defining non-standard fields.

This module plays nicely with and is a dependency of the [`ZANJ`](https://github.com/mivanit/ZANJ) library, which extends this to support saving things to disk in a more efficient way than just plain json (arrays are saved as npy files, for example), and automatically detecting how to load saved objects into their original classes.
"""

from __future__ import annotations

from muutils.json_serialize.array import arr_metadata, load_array
from muutils.json_serialize.json_serialize import (
    BASE_HANDLERS,
    JsonSerializer,
    json_serialize,
)
from muutils.json_serialize.serializable_dataclass import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)
from muutils.json_serialize.util import try_catch, JSONitem, dc_eq

__all__ = [
    # submodules
    "array",
    "json_serialize",
    "serializable_dataclass",
    "serializable_field",
    "util",
    # imports
    "arr_metadata",
    "load_array",
    "BASE_HANDLERS",
    "JSONitem",
    "JsonSerializer",
    "json_serialize",
    "try_catch",
    "dc_eq",
    "serializable_dataclass",
    "serializable_field",
    "SerializableDataclass",
]
```
```python
def json_serialize(obj: Any, path: ObjectPath = tuple()) -> JSONitem:
    """serialize object to json-serializable object with default config"""
    return GLOBAL_JSON_SERIALIZER.json_serialize(obj, path=path)
```
serialize object to json-serializable object with default config
````python
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    methods_no_override: list[str] | None = None,
    **kwargs,
):
    """decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!**

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs. any kwargs not listed here are passed to `dataclasses.dataclass`

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:

    - `_cls : _type_`
        class to decorate. don't pass this arg, just use this as a decorator
        (defaults to `None`)
    - `init : bool`
        whether to add an `__init__` method
        *(passed to dataclasses.dataclass)*
        (defaults to `True`)
    - `repr : bool`
        whether to add a `__repr__` method
        *(passed to dataclasses.dataclass)*
        (defaults to `True`)
    - `order : bool`
        whether to add rich comparison methods
        *(passed to dataclasses.dataclass)*
        (defaults to `False`)
    - `unsafe_hash : bool`
        whether to add a `__hash__` method
        *(passed to dataclasses.dataclass)*
        (defaults to `False`)
    - `frozen : bool`
        whether to make the class frozen
        *(passed to dataclasses.dataclass)*
        (defaults to `False`)
    - `properties_to_serialize : Optional[list[str]]`
        which properties to add to the serialized data dict
        **SerializableDataclass only**
        (defaults to `None`)
    - `register_handler : bool`
        if true, register the class with ZANJ for loading
        **SerializableDataclass only**
        (defaults to `True`)
    - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
        **SerializableDataclass only**
    - `on_typecheck_mismatch : ErrorMode`
        what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
        **SerializableDataclass only**
    - `methods_no_override : list[str]|None`
        list of methods that should not be overridden by the decorator
        by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function,
        but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence
        **SerializableDataclass only**
        (defaults to `None`)
    - `**kwargs`
        *(passed to dataclasses.dataclass)*

    # Returns:

    - `_type_`
        the decorated class

    # Raises:

    - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
    - `NotSerializableFieldException` : if a field is not a `SerializableField`
    - `FieldSerializationError` : if there is an error serializing a field
    - `AttributeError` : if a property is not found on the class
    - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
        for field_name, field_type in cls.__annotations__.items():
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError(
                        "kw_only is not supported in python < 3.10, but if you pass a `False` value, it will be ignored"
                    )
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        if hasattr(value, "serialize") and callable(value.serialize):
                            value = value.serialize()
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        + f"but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> Type[T]:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output: cls = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, raise an error
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

        _methods_no_override: set[str]
        if methods_no_override is None:
            _methods_no_override = set()
        else:
            _methods_no_override = set(methods_no_override)

        if _methods_no_override - {
            "__eq__",
            "serialize",
            "load",
            "validate_fields_types",
        }:
            warnings.warn(
                f"Unknown methods in `methods_no_override`: {_methods_no_override = }"
            )

        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
        if "serialize" not in _methods_no_override:
            # type is `Callable[[T], dict]`
            cls.serialize = serialize  # type: ignore[attr-defined]
        if "load" not in _methods_no_override:
            # type is `Callable[[dict], T]`
            cls.load = load  # type: ignore[attr-defined]

        if "validate_fields_types" not in _methods_no_override:
            # type is `Callable[[T, ErrorMode], bool]`
            cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]

        if "__eq__" not in _methods_no_override:
            # type is `Callable[[T, T], bool]`
            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    if _cls is None:
        return wrap
    else:
        return wrap(_cls)
````
decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!**

types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`.

behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs. any kwargs not listed here are passed to `dataclasses.dataclass`.

Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

Examines PEP 526 `__annotations__` to determine fields.

If `init` is true, an `__init__()` method is added to the class. If `repr` is true, a `__repr__()` method is added. If `order` is true, rich comparison dunder methods are added. If `unsafe_hash` is true, a `__hash__()` method is added. If `frozen` is true, fields may not be assigned to after instance creation.
```python
@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
    a: int
    b: str
```
```python
>>> Myclass(a=1, b="q").serialize()
{_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
```
# Parameters:

- `_cls : _type_`
    class to decorate. don't pass this arg, just use this as a decorator
    (defaults to `None`)
- `init : bool`
    whether to add an `__init__` method
    *(passed to `dataclasses.dataclass`)*
    (defaults to `True`)
- `repr : bool`
    whether to add a `__repr__` method
    *(passed to `dataclasses.dataclass`)*
    (defaults to `True`)
- `order : bool`
    whether to add rich comparison methods
    *(passed to `dataclasses.dataclass`)*
    (defaults to `False`)
- `unsafe_hash : bool`
    whether to add a `__hash__` method
    *(passed to `dataclasses.dataclass`)*
    (defaults to `False`)
- `frozen : bool`
    whether to make the class frozen
    *(passed to `dataclasses.dataclass`)*
    (defaults to `False`)
- `properties_to_serialize : Optional[list[str]]`
    which properties to add to the serialized data dict
    **SerializableDataclass only**
    (defaults to `None`)
- `register_handler : bool`
    if true, register the class with ZANJ for loading
    **SerializableDataclass only**
    (defaults to `True`)
- `on_typecheck_error : ErrorMode`
    what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
    **SerializableDataclass only**
- `on_typecheck_mismatch : ErrorMode`
    what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
    **SerializableDataclass only**
- `methods_no_override : list[str]|None`
    list of methods that should not be overridden by the decorator
    by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function,
    but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence
    **SerializableDataclass only**
    (defaults to `None`)
- `**kwargs`
    *(passed to `dataclasses.dataclass`)*

# Returns:

- `_type_`
    the decorated class

# Raises:

- `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
- `NotSerializableFieldException` : if a field is not a `SerializableField`
- `FieldSerializationError` : if there is an error serializing a field
- `AttributeError` : if a property is not found on the class
- `FieldLoadingError` : if there is an error loading a field
````python
def serializable_field(
    *_args,
    default: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    default_factory: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    metadata: Optional[types.MappingProxyType] = None,
    kw_only: Union[bool, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    **kwargs: Any,
) -> Any:
    """Create a new `SerializableField`

    ```
    default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    default_factory: Callable[[], Sfield_T]
        | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    metadata: types.MappingProxyType | None = None,
    kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    # ----------------------------------------------------------------------
    # new in `SerializableField`, not in `dataclasses.Field`
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    loading_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    ```

    # new Parameters:
    - `serialize`: whether to serialize this field when serializing the class
    - `serialization_fn`: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the `SerializerHandler`s defined in `muutils.json_serialize.json_serialize`
    - `loading_fn`: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
    - `deserialize_fn`: new alternative to `loading_fn`. takes only the field's value, not the whole class. if both `loading_fn` and `deserialize_fn` are provided, an error will be raised.
    - `assert_type`: whether to assert the type of the field when loading. if `False`, will not check the type of the field.
    - `custom_typecheck_fn`: function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking.

    # Gotchas:
    - `loading_fn` takes the dict of the **class**, not the field. if you wanted a `loading_fn` that does nothing, you'd write:

    ```python
    class MyClass:
        my_field: int = serializable_field(
            serialization_fn=lambda x: str(x),
            loading_fn=lambda x: int(x["my_field"]),
        )
    ```

    using `deserialize_fn` instead:

    ```python
    class MyClass:
        my_field: int = serializable_field(
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: int(x),
        )
    ```

    In the above code, `my_field` is an int but will be serialized as a string.

    note that if not using ZANJ, and you have a class inside a container, you MUST provide
    `serialization_fn` and `loading_fn` to serialize and load the container.
    ZANJ will automatically do this for you.

    # TODO: `custom_value_check_fn`: function taking the value of the field and returning whether the value itself is valid. if not provided, any value is valid as long as it passes the type test
    """
    assert len(_args) == 0, f"unexpected positional arguments: {_args}"
    return SerializableField(
        default=default,
        default_factory=default_factory,
        init=init,
        repr=repr,
        hash=hash,
        compare=compare,
        metadata=metadata,
        kw_only=kw_only,
        serialize=serialize,
        serialization_fn=serialization_fn,
        deserialize_fn=deserialize_fn,
        assert_type=assert_type,
        custom_typecheck_fn=custom_typecheck_fn,
        **kwargs,
    )
````
Create a new `SerializableField`

```
default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
default_factory: Callable[[], Sfield_T]
    | dataclasses._MISSING_TYPE = dataclasses.MISSING,
init: bool = True,
repr: bool = True,
hash: Optional[bool] = None,
compare: bool = True,
metadata: types.MappingProxyType | None = None,
kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
# ----------------------------------------------------------------------
# new in `SerializableField`, not in `dataclasses.Field`
serialize: bool = True,
serialization_fn: Optional[Callable[[Any], Any]] = None,
loading_fn: Optional[Callable[[Any], Any]] = None,
deserialize_fn: Optional[Callable[[Any], Any]] = None,
assert_type: bool = True,
custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
```

# new Parameters:

- `serialize`: whether to serialize this field when serializing the class
- `serialization_fn`: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the `SerializerHandler`s defined in `muutils.json_serialize.json_serialize`
- `loading_fn`: function taking the serialized object and returning the instance of the field. If not provided, will take the object as-is.
- `deserialize_fn`: new alternative to `loading_fn`. takes only the field's value, not the whole class. if both `loading_fn` and `deserialize_fn` are provided, an error will be raised.
- `assert_type`: whether to assert the type of the field when loading. if `False`, will not check the type of the field.
- `custom_typecheck_fn`: function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking.

# Gotchas:

`loading_fn` takes the dict of the **class**, not the field. if you wanted a `loading_fn` that does nothing, you'd write:

```python
class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        loading_fn=lambda x: int(x["my_field"]),
    )
```

using `deserialize_fn` instead:

```python
class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: int(x),
    )
```

In the above code, `my_field` is an int but will be serialized as a string.

note that if not using ZANJ, and you have a class inside a container, you MUST provide `serialization_fn` and `loading_fn` to serialize and load the container. ZANJ will automatically do this for you.

# TODO: `custom_value_check_fn`: function taking the value of the field and returning whether the value itself is valid. if not provided, any value is valid as long as it passes the type test
```python
def arr_metadata(arr) -> dict[str, list[int] | str | int]:
    """get metadata for a numpy array"""
    return {
        "shape": list(arr.shape),
        "dtype": (
            arr.dtype.__name__ if hasattr(arr.dtype, "__name__") else str(arr.dtype)
        ),
        "n_elements": array_n_elements(arr),
    }
```
get metadata for a numpy array
```python
def load_array(arr: JSONitem, array_mode: Optional[ArrayMode] = None) -> Any:
    """load a json-serialized array, infer the mode if not specified"""
    # return arr if its already a numpy array
    if isinstance(arr, np.ndarray) and array_mode is None:
        return arr

    # try to infer the array_mode
    array_mode_inferred: ArrayMode = infer_array_mode(arr)
    if array_mode is None:
        array_mode = array_mode_inferred
    elif array_mode != array_mode_inferred:
        warnings.warn(
            f"array_mode {array_mode} does not match inferred array_mode {array_mode_inferred}"
        )

    # actually load the array
    if array_mode == "array_list_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"
        data = np.array(arr["data"], dtype=arr["dtype"])  # type: ignore
        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
            raise ValueError(f"invalid shape: {arr}")
        return data

    elif array_mode == "array_hex_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"
        data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"])  # type: ignore
        return data.reshape(arr["shape"])  # type: ignore

    elif array_mode == "array_b64_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"
        data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"])  # type: ignore
        return data.reshape(arr["shape"])  # type: ignore

    elif array_mode == "list":
        assert isinstance(
            arr, typing.Sequence
        ), f"invalid list format: {type(arr) = }\n{arr = }"
        return np.array(arr)  # type: ignore
    elif array_mode == "external":
        # assume ZANJ has taken care of it
        assert isinstance(arr, typing.Mapping)
        if "data" not in arr:
            raise KeyError(
                f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}"
            )
        return arr["data"]
    elif array_mode == "zero_dim":
        assert isinstance(arr, typing.Mapping)
        data = np.array(arr["data"])
        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
            raise ValueError(f"invalid shape: {arr}")
        return data
    else:
        raise ValueError(f"invalid array_mode: {array_mode}")
```
load a json-serialized array, infer the mode if not specified
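For reference, an `"array_list_meta"` payload pairs a nested list with its dtype and shape, and `load_array` rejects payloads whose declared shape disagrees with the data. The payload below is illustrative (the real one also carries a format-key entry), and `nested_shape` is a hypothetical stdlib stand-in for checking against `np.array(...).shape`:

```python
def nested_shape(data) -> tuple:
    """infer the shape of a nested list; stdlib stand-in for `np.array(data).shape`"""
    shape = []
    while isinstance(data, list):
        shape.append(len(data))
        data = data[0] if data else None
    return tuple(shape)

payload = {
    "dtype": "int64",
    "shape": [2, 3],
    "data": [[1, 2, 3], [4, 5, 6]],
}

# the shape consistency check that `load_array` performs for "array_list_meta":
assert nested_shape(payload["data"]) == tuple(payload["shape"])
```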
```python
class JsonSerializer:
    """Json serialization class (holds configs)

    # Parameters:
    - `array_mode : ArrayMode`
        how to write arrays
        (defaults to `"array_list_meta"`)
    - `error_mode : ErrorMode`
        what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn")
        (defaults to `"except"`)
    - `handlers_pre : MonoTuple[SerializerHandler]`
        handlers to use before the default handlers
        (defaults to `tuple()`)
    - `handlers_default : MonoTuple[SerializerHandler]`
        default handlers to use
        (defaults to `DEFAULT_HANDLERS`)
    - `write_only_format : bool`
        changes _FORMAT_KEY keys in output to "__write_format__" (when you want to serialize something in a way that zanj won't try to recover the object when loading)
        (defaults to `False`)

    # Raises:
    - `ValueError`: on init, if `args` is not empty
    - `SerializationException`: on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT`
    """

    def __init__(
        self,
        *args,
        array_mode: ArrayMode = "array_list_meta",
        error_mode: ErrorMode = ErrorMode.EXCEPT,
        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
        write_only_format: bool = False,
    ):
        if len(args) > 0:
            raise ValueError(
                f"JsonSerializer takes no positional arguments!\n{args = }"
            )

        self.array_mode: ArrayMode = array_mode
        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
        self.write_only_format: bool = write_only_format
        # join up the handlers
        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
            handlers_default
        )

    def json_serialize(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
    ) -> JSONitem:
        try:
            for handler in self.handlers:
                if handler.check(self, obj, path):
                    output: JSONitem = handler.serialize_func(self, obj, path)
                    if self.write_only_format:
                        if isinstance(output, dict) and _FORMAT_KEY in output:
                            new_fmt: JSONitem = output.pop(_FORMAT_KEY)
                            output["__write_format__"] = new_fmt
                    return output

            raise ValueError(f"no handler found for object with {type(obj) = }")

        except Exception as e:
            if self.error_mode == "except":
                obj_str: str = repr(obj)
                if len(obj_str) > 1000:
                    obj_str = obj_str[:1000] + "..."
                raise SerializationException(
                    f"error serializing at {path = } with last handler: '{handler.uid}'\nfrom: {e}\nobj: {obj_str}"
                ) from e
            elif self.error_mode == "warn":
                warnings.warn(
                    f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
                )

        return repr(obj)

    def hashify(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
        force: bool = True,
    ) -> Hashableitem:
        """try to turn any object into something hashable"""
        data = self.json_serialize(obj, path=path)

        # recursive hashify, turning dicts and lists into tuples
        return _recursive_hashify(data, force=force)
```
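The `write_only_format` key swap in `json_serialize` can be sketched in isolation. Note the assumptions: `apply_write_only_format` is a hypothetical helper, not part of muutils, and the concrete value of `_FORMAT_KEY` (`"__format__"` below) is an assumption for illustration:

```python
# hypothetical sketch of the write_only_format branch of json_serialize;
# the actual value of _FORMAT_KEY is assumed here
_FORMAT_KEY = "__format__"

def apply_write_only_format(output):
    # rename the format key so a loader (e.g. zanj) won't try to
    # recover the original object from this data
    if isinstance(output, dict) and _FORMAT_KEY in output:
        output["__write_format__"] = output.pop(_FORMAT_KEY)
    return output

apply_write_only_format({_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1})
# → {"a": 1, "__write_format__": "MyClass(SerializableDataclass)"}
```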
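A minimal sketch of the idea behind `hashify`: once `json_serialize` has reduced an object to plain JSON types, dicts and lists only need to become tuples to be hashable. `recursive_hashify` below is a hypothetical stand-in for the private `_recursive_hashify`, ignoring the `force` flag:

```python
from typing import Any

def recursive_hashify(data: Any) -> Any:
    # dicts become tuples of (key, value) pairs, lists/tuples become
    # tuples; bool/int/float/str/None are already hashable
    if isinstance(data, dict):
        return tuple((k, recursive_hashify(v)) for k, v in data.items())
    if isinstance(data, (list, tuple)):
        return tuple(recursive_hashify(x) for x in data)
    return data

recursive_hashify({"a": [1, 2], "b": {"c": 3.0}})
# → (("a", (1, 2)), ("b", (("c", 3.0),)))
```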
```python
def try_catch(func: Callable):
    """wraps the function to catch exceptions, returns serialized error message on exception

    returned func will return normal result on success, or error message on exception
    """

    @functools.wraps(func)
    def newfunc(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            return f"{e.__class__.__name__}: {e}"

    return newfunc
```
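Since the source of `try_catch` is short, a runnable usage sketch is easy; `divide` is just an illustrative function:

```python
import functools
from typing import Callable

def try_catch(func: Callable):
    # as in the source above: exceptions are returned as
    # "ExceptionName: message" strings instead of being raised
    @functools.wraps(func)
    def newfunc(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            return f"{e.__class__.__name__}: {e}"

    return newfunc

@try_catch
def divide(a: float, b: float) -> float:
    return a / b

divide(6, 3)  # → 2.0
divide(1, 0)  # → "ZeroDivisionError: division by zero"
```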
````python
def dc_eq(
    dc1,
    dc2,
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
    except_when_field_mismatch: bool = False,
) -> bool:
    """
    checks if two dataclasses which (might) hold numpy arrays are equal

    # Parameters:

    - `dc1`: the first dataclass
    - `dc2`: the second dataclass
    - `except_when_class_mismatch: bool`
        if `True`, will throw `TypeError` if the classes are different.
        if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False`
        (default: `False`)
    - `false_when_class_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False`.
        if `True`, will return `False` if the classes are different.
        if `False`, will attempt to compare the fields.
        (default: `True`)
    - `except_when_field_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`.
        if `True`, will throw `AttributeError` if the fields are different.
        (default: `False`)

    # Returns:
    - `bool`: True if the dataclasses are equal, False otherwise

    # Raises:
    - `TypeError`: if the dataclasses are of different classes
    - `AttributeError`: if the dataclasses have different fields

    # TODO: after "except when class mismatch" is False, shouldn't we then go to "field keys match"?
    ```
       [START]
          ▼
    ┌───────────┐  ┌─────────┐
    │dc1 is dc2?├─►│ classes │
    └──┬────────┘No│ match?  │
     ────          ├─────────┤
    (True)◄──┘Yes  │No    │Yes
     ────          ▼      ▼
    ┌────────────────┐  ┌────────────┐
    │  except when   │  │ fields keys│
    │ class mismatch?│  │   match?   │
    ├───────────┬────┘  ├───────┬────┘
    │Yes        │No     │No     │Yes
    ▼           ▼       ▼       ▼
    ───────────  ┌──────────┐  ┌────────┐
    {  raise  }  │  except  │  │ field  │
    {TypeError}  │   when   │  │ values │
    ───────────  │  field   │  │ match? │
                 │ mismatch?│  ├────┬───┘
                 ├───────┬──┘  │    │Yes
                 │Yes    │No   │No  ▼
                 ▼       ▼     │   ────
     ───────────────   ─────   │  (True)
     {    raise    }  (False)◄─┘   ────
     {AttributeError}  ─────
     ───────────────
    ```

    """
    if dc1 is dc2:
        return True

    if dc1.__class__ is not dc2.__class__:
        if except_when_class_mismatch:
            # if the classes don't match, raise an error
            raise TypeError(
                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"
            )
        if except_when_field_mismatch:
            dc1_fields: set = set([fld.name for fld in dataclasses.fields(dc1)])
            dc2_fields: set = set([fld.name for fld in dataclasses.fields(dc2)])
            fields_match: bool = set(dc1_fields) == set(dc2_fields)
            if not fields_match:
                # if the fields don't match, raise an error
                raise AttributeError(
                    f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`"
                )
        return False

    return all(
        array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name))
        for fld in dataclasses.fields(dc1)
        if fld.compare
    )
````
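To see why `dc_eq` needs `array_safe_eq` at all: the `__eq__` that `dataclasses` generates compares fields via `bool(a == b)`, which raises for numpy-style elementwise comparisons. The self-contained sketch below avoids numpy; `Elementwise`, `ArrayLike`, `dc_eq_sketch`, and this simplified `array_safe_eq` are all hypothetical stand-ins, and `dc_eq_sketch` covers only the default happy path (class mismatch simply returns `False`):

```python
import dataclasses

class Elementwise:
    # result of an elementwise comparison; like a numpy boolean array,
    # its truth value is ambiguous
    def __init__(self, flags):
        self.flags = flags
    def __bool__(self):
        raise ValueError("truth value of elementwise comparison is ambiguous")

class ArrayLike:
    # hypothetical stand-in for a numpy array
    def __init__(self, values):
        self.values = list(values)
    def __eq__(self, other):
        return Elementwise([a == b for a, b in zip(self.values, other.values)])

@dataclasses.dataclass
class Holder:
    arr: ArrayLike

def array_safe_eq(x, y) -> bool:
    # simplified stand-in for muutils' array_safe_eq: reduce elementwise
    # comparisons with all() instead of calling bool() on the whole result
    if isinstance(x, ArrayLike) and isinstance(y, ArrayLike):
        return len(x.values) == len(y.values) and all(
            a == b for a, b in zip(x.values, y.values)
        )
    return bool(x == y)

def dc_eq_sketch(dc1, dc2) -> bool:
    # happy path of dc_eq: identical object, or same class with all
    # compared fields array-safe-equal (class mismatch -> False by default)
    if dc1 is dc2:
        return True
    if dc1.__class__ is not dc2.__class__:
        return False
    return all(
        array_safe_eq(getattr(dc1, f.name), getattr(dc2, f.name))
        for f in dataclasses.fields(dc1)
        if f.compare
    )

a, b = Holder(ArrayLike([1, 2])), Holder(ArrayLike([1, 2]))
# `a == b` raises ValueError via the generated __eq__,
# but dc_eq_sketch(a, b) returns True
```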
````python
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

    >>> my_obj = MyClass(a=1, b="q")
    >>> s = json.dumps(my_obj.serialize())
    >>> s
    '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
    >>> read_obj = MyClass.load(json.loads(s))
    >>> read_obj == my_obj
    True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

    >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
    >>> s = json.dumps(nc.serialize())
    >>> s
    '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
    >>> read_nc = NestedClass.load(json.loads(s))
    >>> read_nc == nc
    True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> MyClass(a=1, b=2).diff(MyClass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=MyClass(a=1, b=2)).diff(NestedClass(x="q2", y=MyClass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
        - `other : SerializableDataclass`
            other instance to compare against
        - `of_serialized : bool`
            if true, compare serialized data and not raw values
            (defaults to `False`)

        # Returns:
        - `dict[str, Any]`

        # Raises:
        - `ValueError` : if the instances are not of the same type
        - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        try:
            if self == other:
                return diff_result
        except Exception:
            pass

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                self_value_s = ser_self[field_name] if of_serialized else self_value
                other_value_s = ser_other[field_name] if of_serialized else other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        # Parameters:
        - `nested_dict : dict[str, Any]`
            nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))
````
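The recursive structure of `diff` can be sketched for plain dataclasses; `dc_diff`, `Inner`, and `Outer` below are hypothetical, and the sketch skips the `of_serialized` and `array_safe_eq` details:

```python
import dataclasses
from typing import Any

def dc_diff(dc1: Any, dc2: Any) -> dict:
    # recurse into nested dataclass fields; for leaves that differ,
    # record {"self": ..., "other": ...} as in SerializableDataclass.diff
    result: dict = {}
    for f in dataclasses.fields(dc1):
        v1, v2 = getattr(dc1, f.name), getattr(dc2, f.name)
        if dataclasses.is_dataclass(v1) and dataclasses.is_dataclass(v2):
            nested = dc_diff(v1, v2)
            if nested:
                result[f.name] = nested
        elif v1 != v2:
            result[f.name] = {"self": v1, "other": v2}
    return result

@dataclasses.dataclass
class Inner:
    a: int
    b: int

@dataclasses.dataclass
class Outer:
    x: str
    y: Inner

dc_diff(Outer("q1", Inner(1, 2)), Outer("q2", Inner(1, 3)))
# → {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
```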
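The same pattern as `update_from_nested_dict` can be shown on plain dicts; `update_nested` is a hypothetical helper, not part of muutils:

```python
def update_nested(base: dict, updates: dict) -> dict:
    # recurse into sub-dicts, overwrite leaves -- the same shape of
    # update used for layering e.g. CLI args over a config
    for key, value in updates.items():
        if isinstance(base.get(key), dict) and isinstance(value, dict):
            update_nested(base[key], value)
        else:
            base[key] = value
    return base

cfg = {"model": {"layers": 4, "dim": 128}, "lr": 1e-3}
update_nested(cfg, {"model": {"dim": 256}})
# cfg is now {"model": {"layers": 4, "dim": 256}, "lr": 0.001}
```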