def test_dataclass(self):
    """Round-trip a registered dataclass and check ``field_names`` validation."""

    @dataclass
    class Data:
        a: torch.Tensor
        b: str = "moo"
        c: Optional[str] = None
        # Not an ``__init__`` argument, so it is never part of the pytree.
        d: str = field(init=False, default="")

    def roundtrip(obj):
        return py_pytree.tree_unflatten(*py_pytree.tree_flatten(obj))

    # Default registration: every ``init`` field is carried through.
    py_pytree.register_dataclass(Data)
    try:
        old_data = Data(torch.tensor(3), "b", "c")
        old_data.d = "d"
        new_data = roundtrip(old_data)
        self.assertEqual(new_data.a, torch.tensor(3))
        self.assertEqual(new_data.b, "b")
        self.assertEqual(new_data.c, "c")
        # ``d`` is not passed to the constructor, so it falls back to its default.
        self.assertEqual(new_data.d, "")
    finally:
        # The registry is global: always undo the registration, even when an
        # assertion above fails, so later tests are not affected.
        py_pytree._deregister_pytree_node(Data)

    # Invalid ``field_names`` must be rejected before anything is registered.
    with self.assertRaisesRegex(ValueError, "Missing fields"):
        py_pytree.register_dataclass(Data, field_names=["a", "b"])

    with self.assertRaisesRegex(ValueError, "Unexpected fields"):
        py_pytree.register_dataclass(Data, field_names=["a", "b", "e"])

    with self.assertRaisesRegex(ValueError, "Unexpected fields"):
        py_pytree.register_dataclass(Data, field_names=["a", "b", "c", "d"])

    # Dropped fields are not flattened and come back as their defaults.
    py_pytree.register_dataclass(
        Data, field_names=["a"], drop_field_names=["b", "c"]
    )
    try:
        old_data = Data(torch.tensor(3), "b", "c")
        new_data = roundtrip(old_data)
        self.assertEqual(new_data.a, torch.tensor(3))
        self.assertEqual(new_data.b, "moo")
        self.assertIsNone(new_data.c)
    finally:
        py_pytree._deregister_pytree_node(Data)


def test_register_dataclass_class(self):
    """A plain class can be registered once ``field_names`` is given explicitly."""

    class CustomClass:
        def __init__(self, x, y):
            self.x = x
            self.y = y

    # Without ``field_names`` there is no way to discover the constructor fields.
    with self.assertRaisesRegex(ValueError, "field_names must be specified"):
        py_pytree.register_dataclass(CustomClass)

    py_pytree.register_dataclass(CustomClass, field_names=["x", "y"])
    try:
        c = CustomClass(torch.tensor(0), torch.tensor(1))
        mapped = py_pytree.tree_map(lambda x: x + 1, c)
        self.assertEqual(mapped.x, torch.tensor(1))
        self.assertEqual(mapped.y, torch.tensor(2))
    finally:
        # Do not leave the local class behind in the global registry.
        py_pytree._deregister_pytree_node(CustomClass)
def register_dataclass( print(ep) """ - - from torch._export.utils import register_dataclass_as_pytree_node - - return register_dataclass_as_pytree_node( - cls, serialized_type_name=serialized_type_name - ) + pytree.register_dataclass(cls, serialized_type_name=serialized_type_name) diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 27941c68066b..9b5d472321e5 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -205,6 +205,10 @@ def register_pytree_node( ) -> None: """Register a container-like type as pytree node. + Note: + :func:`register_dataclass` is a simpler way of registering a container-like + type as a pytree node. + Args: cls: the type to register flatten_fn: A callable that takes a pytree and returns a flattened @@ -265,14 +269,34 @@ def register_pytree_node( _cxx_pytree_pending_imports.append((args, kwargs)) -def register_dataclass(cls: type[Any]) -> None: - """Registers a ``dataclasses.dataclass`` type as a pytree node. +def register_dataclass( + cls: type[Any], + *, + field_names: Optional[list[str]] = None, + drop_field_names: Optional[list[str]] = None, + serialized_type_name: Optional[str] = None, +) -> None: + """ + Registers a type that has the semantics of a ``dataclasses.dataclass`` type + as a pytree node. This is a simpler API than :func:`register_pytree_node` for registering - a dataclass. + a dataclass or a custom class with the semantics of a dataclass. Args: - cls: the dataclass type to register + cls: The python type to register. The class must have the semantics of a + dataclass; in particular, it must be constructed by passing the fields + in. + field_names (Optional[List[str]]): A list of field names that correspond + to the **non-constant data** in this class. This list must contain + all the fields that are used to initialize the class. This argument + is optional if ``cls`` is a dataclass, in which case the fields will + be taken from ``dataclasses.fields()``. 
+ drop_field_names (Optional[List[str]]): A list of field names that + should not be included in the pytree. + serialized_type_name: A keyword argument used to specify the fully + qualified name used when serializing the tree spec. This is only + needed for serializing the treespec in torch.export. Example: @@ -293,11 +317,67 @@ def register_dataclass(cls: type[Any]) -> None: >>> assert torch.allclose(point.y, torch.tensor(2)) """ - import torch.export + drop_field_names = drop_field_names or [] - # Eventually we should move the export code here. It is not specific to export, - # aside from the serialization pieces. - torch.export.register_dataclass(cls) + if not dataclasses.is_dataclass(cls): + if field_names is None: + raise ValueError( + "field_names must be specified with a list of all fields used to " + f"initialize {cls}, as it is not a dataclass." + ) + elif field_names is None: + field_names = [f.name for f in dataclasses.fields(cls) if f.init] + else: + dataclass_init_fields = {f.name for f in dataclasses.fields(cls) if f.init} + dataclass_init_fields.difference_update(drop_field_names) + + if dataclass_init_fields != set(field_names): + error_msg = "field_names does not include all dataclass fields.\n" + + if missing := dataclass_init_fields - set(field_names): + error_msg += ( + f"Missing fields in `field_names`: {missing}. If you want " + "to include these fields in the pytree, please add them " + "to `field_names`, otherwise please add them to " + "`drop_field_names`.\n" + ) + + if unexpected := set(field_names) - dataclass_init_fields: + error_msg += ( + f"Unexpected fields in `field_names`: {unexpected}. 
" + "Please remove these fields, or add them to `drop_field_names`.\n" + ) + + raise ValueError(error_msg) + + def _flatten_fn(obj: Any) -> tuple[list[Any], Context]: + flattened = [] + flat_names = [] + none_names = [] + for name in field_names: + val = getattr(obj, name) + if val is not None: + flattened.append(val) + flat_names.append(name) + else: + none_names.append(name) + return flattened, [flat_names, none_names] + + def _unflatten_fn(values: Iterable[Any], context: Context) -> Any: + flat_names, none_names = context + return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names)) + + def _flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]: + flattened, (flat_names, _none_names) = _flatten_fn(obj) # type: ignore[misc] + return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], flat_names + + _private_register_pytree_node( + cls, + _flatten_fn, + _unflatten_fn, + serialized_type_name=serialized_type_name, + flatten_with_keys_fn=_flatten_fn_with_keys, + ) CONSTANT_NODES: set[type] = set()