mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[pytree] Register normal class to register_dataclass (#147752)
Fixes https://github.com/pytorch/pytorch/pull/147532#discussion_r1964365330

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147752
Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
203a27e0ce
commit
60fe0922f6
@ -3859,7 +3859,6 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
if node.op == "placeholder":
|
||||
self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)")
|
||||
|
||||
@testing.expectedFailureRetraceability
|
||||
def test_dynamic_shapes_builder_pytree(self):
|
||||
torch.export.register_dataclass(
|
||||
Inp1,
|
||||
@ -5097,7 +5096,6 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
):
|
||||
self.assertTrue("source_fn_stack" in node.meta)
|
||||
|
||||
@testing.expectedFailureRetraceability
|
||||
def test_dynamic_shapes_dataclass(self):
|
||||
torch.export.register_dataclass(
|
||||
Inp2,
|
||||
@ -7144,7 +7142,6 @@ def forward(self, b_a_buffer, x):
|
||||
ep = export(m, ())
|
||||
self.assertEqual(ep.graph_signature.lifted_tensor_constants, ["x"])
|
||||
|
||||
@testing.expectedFailureRetraceability
|
||||
def test_preserve_shape_dynamism_for_unused_inputs(self):
|
||||
torch.export.register_dataclass(
|
||||
Inp3,
|
||||
|
@ -9,9 +9,9 @@ import sys
|
||||
import time
|
||||
import unittest
|
||||
from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from enum import auto
|
||||
from typing import Any, NamedTuple
|
||||
from typing import Any, NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as py_pytree
|
||||
@ -1297,16 +1297,55 @@ if "optree" in sys.modules:
|
||||
|
||||
def test_dataclass(self):
|
||||
@dataclass
|
||||
class Point:
|
||||
x: torch.Tensor
|
||||
y: torch.Tensor
|
||||
class Data:
|
||||
a: torch.Tensor
|
||||
b: str = "moo"
|
||||
c: Optional[str] = None
|
||||
d: str = field(init=False, default="")
|
||||
|
||||
py_pytree.register_dataclass(Point)
|
||||
py_pytree.register_dataclass(Data)
|
||||
old_data = Data(torch.tensor(3), "b", "c")
|
||||
old_data.d = "d"
|
||||
new_data = py_pytree.tree_unflatten(*py_pytree.tree_flatten(old_data))
|
||||
self.assertEqual(new_data.a, torch.tensor(3))
|
||||
self.assertEqual(new_data.b, "b")
|
||||
self.assertEqual(new_data.c, "c")
|
||||
self.assertEqual(new_data.d, "")
|
||||
py_pytree._deregister_pytree_node(Data)
|
||||
|
||||
point = Point(torch.tensor(0), torch.tensor(1))
|
||||
point = py_pytree.tree_map(lambda x: x + 1, point)
|
||||
self.assertEqual(point.x, torch.tensor(1))
|
||||
self.assertEqual(point.y, torch.tensor(2))
|
||||
with self.assertRaisesRegex(ValueError, "Missing fields"):
|
||||
py_pytree.register_dataclass(Data, field_names=["a", "b"])
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Unexpected fields"):
|
||||
py_pytree.register_dataclass(Data, field_names=["a", "b", "e"])
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Unexpected fields"):
|
||||
py_pytree.register_dataclass(Data, field_names=["a", "b", "c", "d"])
|
||||
|
||||
py_pytree.register_dataclass(
|
||||
Data, field_names=["a"], drop_field_names=["b", "c"]
|
||||
)
|
||||
old_data = Data(torch.tensor(3), "b", "c")
|
||||
new_data = py_pytree.tree_unflatten(*py_pytree.tree_flatten(old_data))
|
||||
self.assertEqual(new_data.a, torch.tensor(3))
|
||||
self.assertEqual(new_data.b, "moo")
|
||||
self.assertEqual(new_data.c, None)
|
||||
py_pytree._deregister_pytree_node(Data)
|
||||
|
||||
def test_register_dataclass_class(self):
|
||||
class CustomClass:
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "field_names must be specified"):
|
||||
py_pytree.register_dataclass(CustomClass)
|
||||
|
||||
py_pytree.register_dataclass(CustomClass, field_names=["x", "y"])
|
||||
c = CustomClass(torch.tensor(0), torch.tensor(1))
|
||||
mapped = py_pytree.tree_map(lambda x: x + 1, c)
|
||||
self.assertEqual(mapped.x, torch.tensor(1))
|
||||
self.assertEqual(mapped.y, torch.tensor(2))
|
||||
|
||||
def test_constant(self):
|
||||
# Either use `frozen=True` or `unsafe_hash=True` so we have a
|
||||
|
@ -523,9 +523,4 @@ def register_dataclass(
|
||||
print(ep)
|
||||
|
||||
"""
|
||||
|
||||
from torch._export.utils import register_dataclass_as_pytree_node
|
||||
|
||||
return register_dataclass_as_pytree_node(
|
||||
cls, serialized_type_name=serialized_type_name
|
||||
)
|
||||
pytree.register_dataclass(cls, serialized_type_name=serialized_type_name)
|
||||
|
@ -205,6 +205,10 @@ def register_pytree_node(
|
||||
) -> None:
|
||||
"""Register a container-like type as pytree node.
|
||||
|
||||
Note:
|
||||
:func:`register_dataclass` is a simpler way of registering a container-like
|
||||
type as a pytree node.
|
||||
|
||||
Args:
|
||||
cls: the type to register
|
||||
flatten_fn: A callable that takes a pytree and returns a flattened
|
||||
@ -265,14 +269,34 @@ def register_pytree_node(
|
||||
_cxx_pytree_pending_imports.append((args, kwargs))
|
||||
|
||||
|
||||
def register_dataclass(cls: type[Any]) -> None:
|
||||
"""Registers a ``dataclasses.dataclass`` type as a pytree node.
|
||||
def register_dataclass(
|
||||
cls: type[Any],
|
||||
*,
|
||||
field_names: Optional[list[str]] = None,
|
||||
drop_field_names: Optional[list[str]] = None,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Registers a type that has the semantics of a ``dataclasses.dataclass`` type
|
||||
as a pytree node.
|
||||
|
||||
This is a simpler API than :func:`register_pytree_node` for registering
|
||||
a dataclass.
|
||||
a dataclass or a custom class with the semantics of a dataclass.
|
||||
|
||||
Args:
|
||||
cls: the dataclass type to register
|
||||
cls: The python type to register. The class must have the semantics of a
|
||||
dataclass; in particular, it must be constructed by passing the fields
|
||||
in.
|
||||
field_names (Optional[List[str]]): A list of field names that correspond
|
||||
to the **non-constant data** in this class. This list must contain
|
||||
all the fields that are used to initialize the class. This argument
|
||||
is optional if ``cls`` is a dataclass, in which case the fields will
|
||||
be taken from ``dataclasses.fields()``.
|
||||
drop_field_names (Optional[List[str]]): A list of field names that
|
||||
should not be included in the pytree.
|
||||
serialized_type_name: A keyword argument used to specify the fully
|
||||
qualified name used when serializing the tree spec. This is only
|
||||
needed for serializing the treespec in torch.export.
|
||||
|
||||
Example:
|
||||
|
||||
@ -293,11 +317,67 @@ def register_dataclass(cls: type[Any]) -> None:
|
||||
>>> assert torch.allclose(point.y, torch.tensor(2))
|
||||
|
||||
"""
|
||||
import torch.export
|
||||
drop_field_names = drop_field_names or []
|
||||
|
||||
# Eventually we should move the export code here. It is not specific to export,
|
||||
# aside from the serialization pieces.
|
||||
torch.export.register_dataclass(cls)
|
||||
if not dataclasses.is_dataclass(cls):
|
||||
if field_names is None:
|
||||
raise ValueError(
|
||||
"field_names must be specified with a list of all fields used to "
|
||||
f"initialize {cls}, as it is not a dataclass."
|
||||
)
|
||||
elif field_names is None:
|
||||
field_names = [f.name for f in dataclasses.fields(cls) if f.init]
|
||||
else:
|
||||
dataclass_init_fields = {f.name for f in dataclasses.fields(cls) if f.init}
|
||||
dataclass_init_fields.difference_update(drop_field_names)
|
||||
|
||||
if dataclass_init_fields != set(field_names):
|
||||
error_msg = "field_names does not include all dataclass fields.\n"
|
||||
|
||||
if missing := dataclass_init_fields - set(field_names):
|
||||
error_msg += (
|
||||
f"Missing fields in `field_names`: {missing}. If you want "
|
||||
"to include these fields in the pytree, please add them "
|
||||
"to `field_names`, otherwise please add them to "
|
||||
"`drop_field_names`.\n"
|
||||
)
|
||||
|
||||
if unexpected := set(field_names) - dataclass_init_fields:
|
||||
error_msg += (
|
||||
f"Unexpected fields in `field_names`: {unexpected}. "
|
||||
"Please remove these fields, or add them to `drop_field_names`.\n"
|
||||
)
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def _flatten_fn(obj: Any) -> tuple[list[Any], Context]:
|
||||
flattened = []
|
||||
flat_names = []
|
||||
none_names = []
|
||||
for name in field_names:
|
||||
val = getattr(obj, name)
|
||||
if val is not None:
|
||||
flattened.append(val)
|
||||
flat_names.append(name)
|
||||
else:
|
||||
none_names.append(name)
|
||||
return flattened, [flat_names, none_names]
|
||||
|
||||
def _unflatten_fn(values: Iterable[Any], context: Context) -> Any:
|
||||
flat_names, none_names = context
|
||||
return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names))
|
||||
|
||||
def _flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]:
|
||||
flattened, (flat_names, _none_names) = _flatten_fn(obj) # type: ignore[misc]
|
||||
return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], flat_names
|
||||
|
||||
_private_register_pytree_node(
|
||||
cls,
|
||||
_flatten_fn,
|
||||
_unflatten_fn,
|
||||
serialized_type_name=serialized_type_name,
|
||||
flatten_with_keys_fn=_flatten_fn_with_keys,
|
||||
)
|
||||
|
||||
|
||||
CONSTANT_NODES: set[type] = set()
|
||||
|
Reference in New Issue
Block a user