[pytree] Register normal class to register_dataclass (#147752)

Fixes https://github.com/pytorch/pytorch/pull/147532#discussion_r1964365330

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147752
Approved by: https://github.com/zou3519
This commit is contained in:
angelayi
2025-04-01 23:28:20 +00:00
committed by PyTorch MergeBot
parent 203a27e0ce
commit 60fe0922f6
4 changed files with 138 additions and 27 deletions

View File

@ -3859,7 +3859,6 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
if node.op == "placeholder":
self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)")
@testing.expectedFailureRetraceability
def test_dynamic_shapes_builder_pytree(self):
torch.export.register_dataclass(
Inp1,
@ -5097,7 +5096,6 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
):
self.assertTrue("source_fn_stack" in node.meta)
@testing.expectedFailureRetraceability
def test_dynamic_shapes_dataclass(self):
torch.export.register_dataclass(
Inp2,
@ -7144,7 +7142,6 @@ def forward(self, b_a_buffer, x):
ep = export(m, ())
self.assertEqual(ep.graph_signature.lifted_tensor_constants, ["x"])
@testing.expectedFailureRetraceability
def test_preserve_shape_dynamism_for_unused_inputs(self):
torch.export.register_dataclass(
Inp3,

View File

@ -9,9 +9,9 @@ import sys
import time
import unittest
from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import auto
from typing import Any, NamedTuple
from typing import Any, NamedTuple, Optional
import torch
import torch.utils._pytree as py_pytree
@ -1297,16 +1297,55 @@ if "optree" in sys.modules:
def test_dataclass(self):
@dataclass
class Point:
x: torch.Tensor
y: torch.Tensor
class Data:
a: torch.Tensor
b: str = "moo"
c: Optional[str] = None
d: str = field(init=False, default="")
py_pytree.register_dataclass(Point)
py_pytree.register_dataclass(Data)
old_data = Data(torch.tensor(3), "b", "c")
old_data.d = "d"
new_data = py_pytree.tree_unflatten(*py_pytree.tree_flatten(old_data))
self.assertEqual(new_data.a, torch.tensor(3))
self.assertEqual(new_data.b, "b")
self.assertEqual(new_data.c, "c")
self.assertEqual(new_data.d, "")
py_pytree._deregister_pytree_node(Data)
point = Point(torch.tensor(0), torch.tensor(1))
point = py_pytree.tree_map(lambda x: x + 1, point)
self.assertEqual(point.x, torch.tensor(1))
self.assertEqual(point.y, torch.tensor(2))
with self.assertRaisesRegex(ValueError, "Missing fields"):
py_pytree.register_dataclass(Data, field_names=["a", "b"])
with self.assertRaisesRegex(ValueError, "Unexpected fields"):
py_pytree.register_dataclass(Data, field_names=["a", "b", "e"])
with self.assertRaisesRegex(ValueError, "Unexpected fields"):
py_pytree.register_dataclass(Data, field_names=["a", "b", "c", "d"])
py_pytree.register_dataclass(
Data, field_names=["a"], drop_field_names=["b", "c"]
)
old_data = Data(torch.tensor(3), "b", "c")
new_data = py_pytree.tree_unflatten(*py_pytree.tree_flatten(old_data))
self.assertEqual(new_data.a, torch.tensor(3))
self.assertEqual(new_data.b, "moo")
self.assertEqual(new_data.c, None)
py_pytree._deregister_pytree_node(Data)
def test_register_dataclass_class(self):
class CustomClass:
def __init__(self, x, y):
self.x = x
self.y = y
with self.assertRaisesRegex(ValueError, "field_names must be specified"):
py_pytree.register_dataclass(CustomClass)
py_pytree.register_dataclass(CustomClass, field_names=["x", "y"])
c = CustomClass(torch.tensor(0), torch.tensor(1))
mapped = py_pytree.tree_map(lambda x: x + 1, c)
self.assertEqual(mapped.x, torch.tensor(1))
self.assertEqual(mapped.y, torch.tensor(2))
def test_constant(self):
# Either use `frozen=True` or `unsafe_hash=True` so we have a

View File

@ -523,9 +523,4 @@ def register_dataclass(
print(ep)
"""
from torch._export.utils import register_dataclass_as_pytree_node
return register_dataclass_as_pytree_node(
cls, serialized_type_name=serialized_type_name
)
pytree.register_dataclass(cls, serialized_type_name=serialized_type_name)

View File

@ -205,6 +205,10 @@ def register_pytree_node(
) -> None:
"""Register a container-like type as pytree node.
Note:
:func:`register_dataclass` is a simpler way of registering a container-like
type as a pytree node.
Args:
cls: the type to register
flatten_fn: A callable that takes a pytree and returns a flattened
@ -265,14 +269,34 @@ def register_pytree_node(
_cxx_pytree_pending_imports.append((args, kwargs))
def register_dataclass(cls: type[Any]) -> None:
"""Registers a ``dataclasses.dataclass`` type as a pytree node.
def register_dataclass(
cls: type[Any],
*,
field_names: Optional[list[str]] = None,
drop_field_names: Optional[list[str]] = None,
serialized_type_name: Optional[str] = None,
) -> None:
"""
Registers a type that has the semantics of a ``dataclasses.dataclass`` type
as a pytree node.
This is a simpler API than :func:`register_pytree_node` for registering
a dataclass.
a dataclass or a custom class with the semantics of a dataclass.
Args:
cls: the dataclass type to register
cls: The python type to register. The class must have the semantics of a
dataclass; in particular, it must be constructed by passing the fields
in.
field_names (Optional[List[str]]): A list of field names that correspond
to the **non-constant data** in this class. This list must contain
all the fields that are used to initialize the class. This argument
is optional if ``cls`` is a dataclass, in which case the fields will
be taken from ``dataclasses.fields()``.
drop_field_names (Optional[List[str]]): A list of field names that
should not be included in the pytree.
serialized_type_name: A keyword argument used to specify the fully
qualified name used when serializing the tree spec. This is only
needed for serializing the treespec in torch.export.
Example:
@ -293,11 +317,67 @@ def register_dataclass(cls: type[Any]) -> None:
>>> assert torch.allclose(point.y, torch.tensor(2))
"""
import torch.export
drop_field_names = drop_field_names or []
# Eventually we should move the export code here. It is not specific to export,
# aside from the serialization pieces.
torch.export.register_dataclass(cls)
if not dataclasses.is_dataclass(cls):
if field_names is None:
raise ValueError(
"field_names must be specified with a list of all fields used to "
f"initialize {cls}, as it is not a dataclass."
)
elif field_names is None:
field_names = [f.name for f in dataclasses.fields(cls) if f.init]
else:
dataclass_init_fields = {f.name for f in dataclasses.fields(cls) if f.init}
dataclass_init_fields.difference_update(drop_field_names)
if dataclass_init_fields != set(field_names):
error_msg = "field_names does not include all dataclass fields.\n"
if missing := dataclass_init_fields - set(field_names):
error_msg += (
f"Missing fields in `field_names`: {missing}. If you want "
"to include these fields in the pytree, please add them "
"to `field_names`, otherwise please add them to "
"`drop_field_names`.\n"
)
if unexpected := set(field_names) - dataclass_init_fields:
error_msg += (
f"Unexpected fields in `field_names`: {unexpected}. "
"Please remove these fields, or add them to `drop_field_names`.\n"
)
raise ValueError(error_msg)
def _flatten_fn(obj: Any) -> tuple[list[Any], Context]:
flattened = []
flat_names = []
none_names = []
for name in field_names:
val = getattr(obj, name)
if val is not None:
flattened.append(val)
flat_names.append(name)
else:
none_names.append(name)
return flattened, [flat_names, none_names]
def _unflatten_fn(values: Iterable[Any], context: Context) -> Any:
flat_names, none_names = context
return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names))
def _flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]:
flattened, (flat_names, _none_names) = _flatten_fn(obj) # type: ignore[misc]
return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], flat_names
_private_register_pytree_node(
cls,
_flatten_fn,
_unflatten_fn,
serialized_type_name=serialized_type_name,
flatten_with_keys_fn=_flatten_fn_with_keys,
)
CONSTANT_NODES: set[type] = set()