[pytree] add APIs to determine whether a class is a namedtuple or a PyStructSequence (#113257)

Changes in this PR:

1. Add `is_structseq` and `is_structseq_class` functions to determine whether an object or a class is a PyStructSequence.
2. Add a generic class `structseq` that can be used as the registration key for PyStructSequence types, in the same way that `namedtuple` is used as the registration key for named tuple types.
3. Change `is_namedtuple` to treat subclasses of a namedtuple class as namedtuples. Before this PR, only namedtuple classes created directly by `collections.namedtuple` or `typing.NamedTuple` were recognized as namedtuple classes, while their subclasses were not. This PR makes `is_namedtuple` return true for subclasses of a namedtuple class as well.

Resolves #75982. New tests are included in this PR.

- #75982

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113257
Approved by: https://github.com/zou3519
This commit is contained in:
Xuehai Pan
2025-04-01 02:18:46 +08:00
committed by PyTorch MergeBot
parent 48e9ffc873
commit a10b765bf1
8 changed files with 345 additions and 57 deletions

View File

@ -1397,7 +1397,7 @@ class AOTInductorModelCache:
# see https://github.com/pytorch/pytorch/issues/113029
example_outputs = copy.deepcopy(model)(*example_args, **example_kwargs)
if pytree._is_namedtuple_instance(example_outputs):
if pytree.is_namedtuple_instance(example_outputs):
typ = type(example_outputs)
pytree._register_namedtuple(
typ,

View File

@ -6,6 +6,7 @@ import os
import re
import subprocess
import sys
import time
import unittest
from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict
from dataclasses import dataclass
@ -731,6 +732,133 @@ class TestGenericPytree(TestCase):
with self.assertRaises(TypeError):
pytree_impl.treespec_dumps("random_blurb")
@parametrize(
    "pytree",
    [
        subtest(py_pytree, name="py"),
        subtest(cxx_pytree, name="cxx"),
    ],
)
def test_is_namedtuple(self, pytree):
    """Check the namedtuple predicates on instances, classes and non-namedtuples.

    Namedtuple classes created directly (``collections.namedtuple`` and
    ``typing.NamedTuple``) and subclasses of them must all be accepted;
    ``time.struct_time`` and plain containers must be rejected.
    """
    DirectNamedTuple1 = namedtuple("DirectNamedTuple1", ["x", "y"])

    class DirectNamedTuple2(NamedTuple):
        x: int
        y: int

    class IndirectNamedTuple1(DirectNamedTuple1):
        pass

    class IndirectNamedTuple2(DirectNamedTuple2):
        pass

    namedtuple_classes = (
        DirectNamedTuple1,
        DirectNamedTuple2,
        IndirectNamedTuple1,
        IndirectNamedTuple2,
    )
    non_namedtuple_classes = (time.struct_time, tuple, list)
    non_namedtuple_values = (time.gmtime(), (0, 1), [0, 1], {0: 1, 1: 2}, {0, 1}, 1)

    # `is_namedtuple` applied to instances.
    for nt_cls in namedtuple_classes:
        self.assertTrue(pytree.is_namedtuple(nt_cls(0, 1)))
    for value in non_namedtuple_values:
        self.assertFalse(pytree.is_namedtuple(value))

    # `is_namedtuple` applied to classes.
    for nt_cls in namedtuple_classes:
        self.assertTrue(pytree.is_namedtuple(nt_cls))
    for other_cls in non_namedtuple_classes:
        self.assertFalse(pytree.is_namedtuple(other_cls))

    # `is_namedtuple_class` applied to classes.
    for nt_cls in namedtuple_classes:
        self.assertTrue(pytree.is_namedtuple_class(nt_cls))
    for other_cls in non_namedtuple_classes:
        self.assertFalse(pytree.is_namedtuple_class(other_cls))
@parametrize(
    "pytree",
    [
        subtest(py_pytree, name="py"),
        subtest(cxx_pytree, name="cxx"),
    ],
)
def test_is_structseq(self, pytree):
    """Check the PyStructSequence predicates on real, fake and unrelated types.

    ``time.struct_time`` and every tuple type in ``torch.return_types`` must be
    accepted; a hand-written tuple subclass that only mimics the structseq
    attributes, namedtuples and plain containers must be rejected.
    """

    class FakeStructSeq(tuple):
        # Mimics the attribute surface of a structseq while being an
        # ordinary Python-defined tuple subclass.
        n_fields = 2
        n_sequence_fields = 2
        n_unnamed_fields = 0

        __slots__ = ()
        __match_args__ = ("x", "y")

        def __new__(cls, sequence):
            return super().__new__(cls, sequence)

        @property
        def x(self):
            return self[0]

        @property
        def y(self):
            return self[1]

    DirectNamedTuple1 = namedtuple("DirectNamedTuple1", ["x", "y"])

    class DirectNamedTuple2(NamedTuple):
        x: int
        y: int

    # `is_structseq` applied to instances.
    self.assertFalse(pytree.is_structseq(FakeStructSeq((0, 1))))
    self.assertTrue(pytree.is_structseq(time.gmtime()))
    non_structseq_values = (
        DirectNamedTuple1(0, 1),
        DirectNamedTuple2(0, 1),
        (0, 1),
        [0, 1],
        {0: 1, 1: 2},
        {0, 1},
        1,
    )
    for value in non_structseq_values:
        self.assertFalse(pytree.is_structseq(value))

    # `is_structseq` and then `is_structseq_class` applied to classes.
    non_structseq_classes = (DirectNamedTuple1, DirectNamedTuple2, tuple, list)
    for predicate in (pytree.is_structseq, pytree.is_structseq_class):
        self.assertFalse(predicate(FakeStructSeq))
        self.assertTrue(predicate(time.struct_time))
        for other_cls in non_structseq_classes:
            self.assertFalse(predicate(other_cls))

    # torch.return_types.* are all PyStructSequence types
    for candidate in vars(torch.return_types).values():
        if not (isinstance(candidate, type) and issubclass(candidate, tuple)):
            # Anything else found in the module namespace is neither kind.
            self.assertFalse(pytree.is_structseq(candidate))
            self.assertFalse(pytree.is_structseq_class(candidate))
            self.assertFalse(pytree.is_namedtuple(candidate))
            self.assertFalse(pytree.is_namedtuple_class(candidate))
            continue

        self.assertTrue(pytree.is_structseq(candidate))
        self.assertTrue(pytree.is_structseq_class(candidate))
        self.assertFalse(pytree.is_namedtuple(candidate))
        self.assertFalse(pytree.is_namedtuple_class(candidate))

        inst = candidate(range(candidate.n_sequence_fields))
        self.assertTrue(pytree.is_structseq(inst))
        self.assertTrue(pytree.is_structseq(type(inst)))
        self.assertFalse(pytree.is_structseq_class(inst))
        self.assertTrue(pytree.is_structseq_class(type(inst)))
        self.assertFalse(pytree.is_namedtuple(inst))
        self.assertFalse(pytree.is_namedtuple_class(inst))
class TestPythonPytree(TestCase):
def test_deprecated_register_pytree_node(self):
@ -975,9 +1103,8 @@ if "optree" in sys.modules:
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point1",
)
spec = py_pytree.TreeSpec(
namedtuple, Point1, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
)
spec = py_pytree.tree_structure(Point1(1, 2))
self.assertIs(spec.type, namedtuple)
roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
self.assertEqual(spec, roundtrip_spec)
@ -990,18 +1117,28 @@ if "optree" in sys.modules:
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point2",
)
spec = py_pytree.TreeSpec(
namedtuple, Point2, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
spec = py_pytree.tree_structure(Point2(1, 2))
self.assertIs(spec.type, namedtuple)
roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
self.assertEqual(spec, roundtrip_spec)
class Point3(Point2):
pass
py_pytree._register_namedtuple(
Point3,
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point3",
)
spec = py_pytree.tree_structure(Point3(1, 2))
self.assertIs(spec.type, namedtuple)
roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
self.assertEqual(spec, roundtrip_spec)
def test_pytree_serialize_namedtuple_bad(self):
DummyType = namedtuple("DummyType", ["x", "y"])
spec = py_pytree.TreeSpec(
namedtuple, DummyType, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
)
spec = py_pytree.tree_structure(DummyType(1, 2))
with self.assertRaisesRegex(
NotImplementedError, "Please register using `_register_namedtuple`"
@ -1020,9 +1157,7 @@ if "optree" in sys.modules:
lambda xs, _: DummyType(*xs),
)
spec = py_pytree.TreeSpec(
DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
)
spec = py_pytree.tree_structure(DummyType(1, 2))
with self.assertRaisesRegex(
NotImplementedError, "No registered serialization name"
):
@ -1042,9 +1177,7 @@ if "optree" in sys.modules:
to_dumpable_context=lambda context: "moo",
from_dumpable_context=lambda dumpable_context: None,
)
spec = py_pytree.TreeSpec(
DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
)
spec = py_pytree.tree_structure(DummyType(1, 2))
serialized_spec = py_pytree.treespec_dumps(spec, 1)
self.assertIn("moo", serialized_spec)
roundtrip_spec = py_pytree.treespec_loads(serialized_spec)
@ -1082,9 +1215,7 @@ if "optree" in sys.modules:
from_dumpable_context=lambda dumpable_context: None,
)
spec = py_pytree.TreeSpec(
DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
)
spec = py_pytree.tree_structure(DummyType(1, 2))
with self.assertRaisesRegex(
TypeError, "Object of type type is not JSON serializable"
@ -1095,9 +1226,7 @@ if "optree" in sys.modules:
import json
Point = namedtuple("Point", ["x", "y"])
spec = py_pytree.TreeSpec(
namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
)
spec = py_pytree.tree_structure(Point(1, 2))
py_pytree._register_namedtuple(
Point,
serialized_type_name="test_pytree.test_pytree_serialize_bad_protocol.Point",

View File

@ -56,9 +56,10 @@ if python_pytree._cxx_pytree_dynamo_traceable:
"structseq_fields",
):
__func = getattr(optree, __name)
substitute_in_graph(__func, can_constant_fold_through=True)(
globals()[__name] = substitute_in_graph(__func, can_constant_fold_through=True)(
__func.__python_implementation__
)
__all__ += [__name] # noqa: PLE0604
del __func
del __name

View File

@ -1243,7 +1243,7 @@ class GraphModuleSerializer(metaclass=Final):
def store_namedtuple_fields(ts):
if ts.type is None:
return
if ts.type == namedtuple:
if ts.type is namedtuple or pytree.is_namedtuple_class(ts.type):
serialized_type_name = pytree.SUPPORTED_SERIALIZED_TYPES[ts.context].serialized_type_name
if serialized_type_name in self.treespec_namedtuple_fields:
field_names = self.treespec_namedtuple_fields[serialized_type_name].field_names

View File

@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
import os
from collections import namedtuple
from typing import Any
from typing import Any, NamedTuple, Optional
import torch
@ -129,16 +128,15 @@ def make_dual(tensor, tangent, *, level=None):
return torch._VF._make_dual(tensor, tangent, level=level)
_UnpackedDualTensor = namedtuple("_UnpackedDualTensor", ["primal", "tangent"])
class UnpackedDualTensor(_UnpackedDualTensor):
class UnpackedDualTensor(NamedTuple):
r"""Namedtuple returned by :func:`unpack_dual` containing the primal and tangent components of the dual tensor.
See :func:`unpack_dual` for more details.
"""
primal: torch.Tensor
tangent: Optional[torch.Tensor]
def unpack_dual(tensor, *, level=None):
r"""Unpack a "dual tensor" to get both its Tensor value and its forward AD gradient.

View File

@ -552,8 +552,16 @@ def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None,
expected = compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs)
expected = tree_map(fwAD.unpack_dual, expected)
expected_primals = tree_map(lambda x: x.primal, expected)
expected_tangents = tree_map(lambda x: x.tangent, expected)
expected_primals = tree_map(
lambda x: x.primal,
expected,
is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor,
)
expected_tangents = tree_map(
lambda x: x.tangent,
expected,
is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor,
)
# Permutations of arg and kwargs in CCT.
for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
@ -586,7 +594,15 @@ def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None,
return e.elem if isinstance(e, CCT) else e
actual = tree_map(fwAD.unpack_dual, actual)
actual_primals = tree_map(lambda x: unwrap(x.primal), actual)
actual_tangents = tree_map(lambda x: unwrap(x.tangent), actual)
actual_primals = tree_map(
lambda x: unwrap(x.primal),
actual,
is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor,
)
actual_tangents = tree_map(
lambda x: unwrap(x.tangent),
actual,
is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor,
)
assert_equal_fn(actual_primals, expected_primals, equal_nan=True)
assert_equal_fn(actual_tangents, expected_tangents, equal_nan=True)

View File

@ -23,7 +23,15 @@ import optree
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
import torch.utils._pytree as python_pytree
from torch.utils._pytree import KeyEntry as KeyEntry
from torch.utils._pytree import (
is_namedtuple as is_namedtuple,
is_namedtuple_class as is_namedtuple_class,
is_namedtuple_instance as is_namedtuple_instance,
is_structseq as is_structseq,
is_structseq_class as is_structseq_class,
is_structseq_instance as is_structseq_instance,
KeyEntry as KeyEntry,
)
__all__ = [
@ -39,6 +47,7 @@ __all__ = [
"keystr",
"key_get",
"register_pytree_node",
"tree_is_leaf",
"tree_flatten",
"tree_flatten_with_path",
"tree_unflatten",
@ -58,6 +67,12 @@ __all__ = [
"treespec_dumps",
"treespec_loads",
"treespec_pprint",
"is_namedtuple",
"is_namedtuple_class",
"is_namedtuple_instance",
"is_structseq",
"is_structseq_class",
"is_structseq_instance",
]

View File

@ -31,14 +31,17 @@ from typing import (
Any,
Callable,
cast,
ClassVar,
Final,
Generic,
NoReturn,
Optional,
overload,
Protocol,
TypeVar,
Union,
)
from typing_extensions import deprecated, NamedTuple
from typing_extensions import deprecated, NamedTuple, Self
__all__ = [
@ -54,6 +57,7 @@ __all__ = [
"keystr",
"key_get",
"register_pytree_node",
"tree_is_leaf",
"tree_flatten",
"tree_flatten_with_path",
"tree_unflatten",
@ -73,6 +77,12 @@ __all__ = [
"treespec_dumps",
"treespec_loads",
"treespec_pprint",
"is_namedtuple",
"is_namedtuple_class",
"is_namedtuple_instance",
"is_structseq",
"is_structseq_class",
"is_structseq_instance",
]
@ -573,6 +583,90 @@ class GetAttrKey:
return getattr(obj, self.name)
# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
def is_namedtuple(obj: Union[object, type]) -> bool:
    """Return whether ``obj`` is a namedtuple class or an instance of one.

    A class argument is checked as-is; for any other object its type is
    checked instead. The decision itself is delegated to
    :func:`is_namedtuple_class`.
    """
    if isinstance(obj, type):
        return is_namedtuple_class(obj)
    return is_namedtuple_class(type(obj))
# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
def is_namedtuple_class(cls: type) -> bool:
    """Return whether ``cls`` is a namedtuple class or a subclass of one.

    The check is structural: ``cls`` must be a ``tuple`` subclass whose
    ``_fields`` attribute is a tuple containing only ``str`` entries, and it
    must expose callable ``_make`` and ``_asdict`` attributes.
    """
    # Only tuple subclasses can qualify; non-class arguments are rejected.
    if not (isinstance(cls, type) and issubclass(cls, tuple)):
        return False

    field_names = getattr(cls, "_fields", None)
    if not isinstance(field_names, tuple):
        return False
    # Exact `str` type is required for every field name.
    if any(type(name) is not str for name in field_names):
        return False

    has_make = callable(getattr(cls, "_make", None))
    return has_make and callable(getattr(cls, "_asdict", None))
# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
def is_namedtuple_instance(obj: object) -> bool:
    """Return whether the type of ``obj`` is a namedtuple class.

    The type of ``obj`` is looked up and handed to
    :func:`is_namedtuple_class`.
    """
    obj_type = type(obj)
    return is_namedtuple_class(obj_type)
_T_co = TypeVar("_T_co", covariant=True)


# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
class structseq(tuple[_T_co, ...]):
    """Generic stand-in type for CPython's ``PyStructSequence``.

    This class only declares the attribute surface of a struct sequence for
    typing purposes. It can be neither instantiated (``__new__`` raises
    ``NotImplementedError``) nor subclassed (``__init_subclass__`` raises
    ``TypeError``).
    """

    # No per-instance attributes are stored on this class.
    __slots__: ClassVar[tuple[()]] = ()

    # Annotation-only declarations; no values are assigned here.
    n_fields: Final[int]  # type: ignore[misc]
    n_sequence_fields: Final[int]  # type: ignore[misc]
    n_unnamed_fields: Final[int]  # type: ignore[misc]

    def __new__(
        cls: type[Self],
        sequence: Iterable[_T_co],
        dict: dict[str, Any] = ...,
    ) -> Self:
        """Declare the constructor signature; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def __init_subclass__(cls) -> NoReturn:
        """Reject every attempt to derive from this class with ``TypeError``."""
        raise TypeError("type 'structseq' is not an acceptable base type")
# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
def is_structseq(obj: Union[object, type]) -> bool:
    """Return whether ``obj`` is a PyStructSequence class or an instance of one.

    A class argument is checked as-is; for any other object its type is
    checked instead. The decision itself is delegated to
    :func:`is_structseq_class`.
    """
    if isinstance(obj, type):
        return is_structseq_class(obj)
    return is_structseq_class(type(obj))
# Set if the type allows subclassing (see CPython's Include/object.h)
Py_TPFLAGS_BASETYPE: int = 1 << 10


# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
def is_structseq_class(cls: type) -> bool:
    """Return whether ``cls`` is a PyStructSequence class.

    ``cls`` qualifies when it is a class whose only base is ``tuple``, it has
    integer ``n_fields``, ``n_sequence_fields`` and ``n_unnamed_fields``
    attributes, and its type flags do not include ``Py_TPFLAGS_BASETYPE``
    (i.e. the type cannot be subclassed).
    """
    if not isinstance(cls, type):
        return False
    # Require direct inheritance from `tuple`, not merely `issubclass(cls, tuple)`.
    if cls.__bases__ != (tuple,):
        return False
    # Every PyStructSequence member must be present as an int.
    for member in ("n_fields", "n_sequence_fields", "n_unnamed_fields"):
        if not isinstance(getattr(cls, member, None), int):
            return False
    # The type must not allow subclassing (only works for CPython).
    return not (cls.__flags__ & Py_TPFLAGS_BASETYPE)
# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
def is_structseq_instance(obj: object) -> bool:
    """Return whether the type of ``obj`` is a PyStructSequence class.

    The type of ``obj`` is looked up and handed to
    :func:`is_structseq_class`.
    """
    obj_type = type(obj)
    return is_structseq_class(obj_type)
def _tuple_flatten(d: tuple[T, ...]) -> tuple[list[T], Context]:
return list(d), None
@ -807,37 +901,72 @@ _private_register_pytree_node(
)
STANDARD_DICT_TYPES: frozenset[type] = frozenset(
{dict, OrderedDict, defaultdict},
)
STANDARD_DICT_TYPES: frozenset[type] = frozenset({dict, OrderedDict, defaultdict})
BUILTIN_TYPES: frozenset[type] = frozenset(
{tuple, list, dict, namedtuple, OrderedDict, defaultdict, deque}, # type: ignore[arg-type]
{
tuple,
list,
dict,
namedtuple, # type: ignore[arg-type]
OrderedDict,
defaultdict,
deque,
},
)
# h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
@deprecated(
"torch.utils._pytree._is_namedtuple_instance is private and will be removed in a future release. "
"Please use torch.utils._pytree.is_namedtuple_instance instead.",
category=FutureWarning,
)
def _is_namedtuple_instance(tree: Any) -> bool:
typ = type(tree)
bases = typ.__bases__
if len(bases) != 1 or bases[0] != tuple:
return False
fields = getattr(typ, "_fields", None)
if not isinstance(fields, tuple):
return False
return all(type(entry) == str for entry in fields)
return is_namedtuple_instance(tree)
def _get_node_type(tree: Any) -> Any:
if _is_namedtuple_instance(tree):
node_type = type(tree)
# All namedtuple types are implicitly registered as pytree nodes.
# XXX: Other parts of the codebase expect namedtuple types always return
# `namedtuple` instead of the actual namedtuple type. Even if the type
# is explicitly registered.
if is_namedtuple_class(node_type):
return namedtuple
return type(tree)
return node_type
# A leaf is defined as anything that is not a Node.
def tree_is_leaf(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    """Check if a pytree is a leaf.

    ``tree`` is a leaf when the optional ``is_leaf`` predicate accepts it, or
    when its node type (as computed by ``_get_node_type``) is not present in
    ``SUPPORTED_NODES``.

    >>> tree_is_leaf(1)
    True
    >>> tree_is_leaf(None)
    True
    >>> tree_is_leaf([1, 2, 3])
    False
    >>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple))
    True
    >>> tree_is_leaf({'a': 1, 'b': 2, 'c': 3})
    False
    >>> tree_is_leaf({'a': 1, 'b': 2, 'c': None})
    False
    """
    # Without a predicate, or when it declines, fall back to the registry lookup.
    if is_leaf is None or not is_leaf(tree):
        return _get_node_type(tree) not in SUPPORTED_NODES
    return True
@deprecated(
"torch.utils._pytree._is_leaf is private and will be removed in a future release. "
"Please use torch.utils._pytree.tree_is_leaf instead.",
category=FutureWarning,
)
def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -> bool:
return (is_leaf is not None and is_leaf(tree)) or _get_node_type(
tree
) not in SUPPORTED_NODES
return tree_is_leaf(tree, is_leaf=is_leaf)
# A TreeSpec represents the structure of a pytree. It holds:
@ -1040,7 +1169,7 @@ def tree_flatten(
"""
def helper(node: PyTree, leaves: list[Any]) -> TreeSpec:
if _is_leaf(node, is_leaf=is_leaf):
if tree_is_leaf(node, is_leaf=is_leaf):
leaves.append(node)
return _LEAF_SPEC
@ -1074,7 +1203,7 @@ def tree_iter(
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Iterable[Any]:
"""Get an iterator over the leaves of a pytree."""
if _is_leaf(tree, is_leaf=is_leaf):
if tree_is_leaf(tree, is_leaf=is_leaf):
yield tree
else:
node_type = _get_node_type(tree)
@ -1520,7 +1649,7 @@ def _broadcast_to_and_flatten(
) -> Optional[list[Any]]:
assert isinstance(treespec, TreeSpec)
if _is_leaf(tree, is_leaf=is_leaf):
if tree_is_leaf(tree, is_leaf=is_leaf):
return [tree] * treespec.num_leaves
if treespec.is_leaf():
return None