mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[pytree] add APIs to determine a class is a namedtuple or PyStructSequence (#113257)
Changes in this PR: 1. Add `is_structseq` and `is_structseq_class` functions to determine a object or a class is PyStructSequence. 2. Add a generic class `structseq` which can be used as the registration key for PyStructSequence types like `namedtuple` for Named Tuple types. 3. Change `is_namedtuple` to accept subclasses of namedtuple to be namedtuple. Before this PR, only namedtuple class directly created by `collections.namedtuple` or `typing.NamedTuple` were namedtuple classes while their subclasses were not. This PR makes `is_namedtuple` return true for subclasses of namedtuple class. Resolves #75982. New tests are included in this PR. - #75982 Pull Request resolved: https://github.com/pytorch/pytorch/pull/113257 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
48e9ffc873
commit
a10b765bf1
@ -1397,7 +1397,7 @@ class AOTInductorModelCache:
|
||||
# see https://github.com/pytorch/pytorch/issues/113029
|
||||
example_outputs = copy.deepcopy(model)(*example_args, **example_kwargs)
|
||||
|
||||
if pytree._is_namedtuple_instance(example_outputs):
|
||||
if pytree.is_namedtuple_instance(example_outputs):
|
||||
typ = type(example_outputs)
|
||||
pytree._register_namedtuple(
|
||||
typ,
|
||||
|
@ -6,6 +6,7 @@ import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import unittest
|
||||
from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict
|
||||
from dataclasses import dataclass
|
||||
@ -731,6 +732,133 @@ class TestGenericPytree(TestCase):
|
||||
with self.assertRaises(TypeError):
|
||||
pytree_impl.treespec_dumps("random_blurb")
|
||||
|
||||
@parametrize(
|
||||
"pytree",
|
||||
[
|
||||
subtest(py_pytree, name="py"),
|
||||
subtest(cxx_pytree, name="cxx"),
|
||||
],
|
||||
)
|
||||
def test_is_namedtuple(self, pytree):
|
||||
DirectNamedTuple1 = namedtuple("DirectNamedTuple1", ["x", "y"])
|
||||
|
||||
class DirectNamedTuple2(NamedTuple):
|
||||
x: int
|
||||
y: int
|
||||
|
||||
class IndirectNamedTuple1(DirectNamedTuple1):
|
||||
pass
|
||||
|
||||
class IndirectNamedTuple2(DirectNamedTuple2):
|
||||
pass
|
||||
|
||||
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple1(0, 1)))
|
||||
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple2(0, 1)))
|
||||
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple1(0, 1)))
|
||||
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple2(0, 1)))
|
||||
self.assertFalse(pytree.is_namedtuple(time.gmtime()))
|
||||
self.assertFalse(pytree.is_namedtuple((0, 1)))
|
||||
self.assertFalse(pytree.is_namedtuple([0, 1]))
|
||||
self.assertFalse(pytree.is_namedtuple({0: 1, 1: 2}))
|
||||
self.assertFalse(pytree.is_namedtuple({0, 1}))
|
||||
self.assertFalse(pytree.is_namedtuple(1))
|
||||
|
||||
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple1))
|
||||
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple2))
|
||||
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple1))
|
||||
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple2))
|
||||
self.assertFalse(pytree.is_namedtuple(time.struct_time))
|
||||
self.assertFalse(pytree.is_namedtuple(tuple))
|
||||
self.assertFalse(pytree.is_namedtuple(list))
|
||||
|
||||
self.assertTrue(pytree.is_namedtuple_class(DirectNamedTuple1))
|
||||
self.assertTrue(pytree.is_namedtuple_class(DirectNamedTuple2))
|
||||
self.assertTrue(pytree.is_namedtuple_class(IndirectNamedTuple1))
|
||||
self.assertTrue(pytree.is_namedtuple_class(IndirectNamedTuple2))
|
||||
self.assertFalse(pytree.is_namedtuple_class(time.struct_time))
|
||||
self.assertFalse(pytree.is_namedtuple_class(tuple))
|
||||
self.assertFalse(pytree.is_namedtuple_class(list))
|
||||
|
||||
@parametrize(
|
||||
"pytree",
|
||||
[
|
||||
subtest(py_pytree, name="py"),
|
||||
subtest(cxx_pytree, name="cxx"),
|
||||
],
|
||||
)
|
||||
def test_is_structseq(self, pytree):
|
||||
class FakeStructSeq(tuple):
|
||||
n_fields = 2
|
||||
n_sequence_fields = 2
|
||||
n_unnamed_fields = 0
|
||||
|
||||
__slots__ = ()
|
||||
__match_args__ = ("x", "y")
|
||||
|
||||
def __new__(cls, sequence):
|
||||
return super().__new__(cls, sequence)
|
||||
|
||||
@property
|
||||
def x(self):
|
||||
return self[0]
|
||||
|
||||
@property
|
||||
def y(self):
|
||||
return self[1]
|
||||
|
||||
DirectNamedTuple1 = namedtuple("DirectNamedTuple1", ["x", "y"])
|
||||
|
||||
class DirectNamedTuple2(NamedTuple):
|
||||
x: int
|
||||
y: int
|
||||
|
||||
self.assertFalse(pytree.is_structseq(FakeStructSeq((0, 1))))
|
||||
self.assertTrue(pytree.is_structseq(time.gmtime()))
|
||||
self.assertFalse(pytree.is_structseq(DirectNamedTuple1(0, 1)))
|
||||
self.assertFalse(pytree.is_structseq(DirectNamedTuple2(0, 1)))
|
||||
self.assertFalse(pytree.is_structseq((0, 1)))
|
||||
self.assertFalse(pytree.is_structseq([0, 1]))
|
||||
self.assertFalse(pytree.is_structseq({0: 1, 1: 2}))
|
||||
self.assertFalse(pytree.is_structseq({0, 1}))
|
||||
self.assertFalse(pytree.is_structseq(1))
|
||||
|
||||
self.assertFalse(pytree.is_structseq(FakeStructSeq))
|
||||
self.assertTrue(pytree.is_structseq(time.struct_time))
|
||||
self.assertFalse(pytree.is_structseq(DirectNamedTuple1))
|
||||
self.assertFalse(pytree.is_structseq(DirectNamedTuple2))
|
||||
self.assertFalse(pytree.is_structseq(tuple))
|
||||
self.assertFalse(pytree.is_structseq(list))
|
||||
|
||||
self.assertFalse(pytree.is_structseq_class(FakeStructSeq))
|
||||
self.assertTrue(
|
||||
pytree.is_structseq_class(time.struct_time),
|
||||
)
|
||||
self.assertFalse(pytree.is_structseq_class(DirectNamedTuple1))
|
||||
self.assertFalse(pytree.is_structseq_class(DirectNamedTuple2))
|
||||
self.assertFalse(pytree.is_structseq_class(tuple))
|
||||
self.assertFalse(pytree.is_structseq_class(list))
|
||||
|
||||
# torch.return_types.* are all PyStructSequence types
|
||||
for cls in vars(torch.return_types).values():
|
||||
if isinstance(cls, type) and issubclass(cls, tuple):
|
||||
self.assertTrue(pytree.is_structseq(cls))
|
||||
self.assertTrue(pytree.is_structseq_class(cls))
|
||||
self.assertFalse(pytree.is_namedtuple(cls))
|
||||
self.assertFalse(pytree.is_namedtuple_class(cls))
|
||||
|
||||
inst = cls(range(cls.n_sequence_fields))
|
||||
self.assertTrue(pytree.is_structseq(inst))
|
||||
self.assertTrue(pytree.is_structseq(type(inst)))
|
||||
self.assertFalse(pytree.is_structseq_class(inst))
|
||||
self.assertTrue(pytree.is_structseq_class(type(inst)))
|
||||
self.assertFalse(pytree.is_namedtuple(inst))
|
||||
self.assertFalse(pytree.is_namedtuple_class(inst))
|
||||
else:
|
||||
self.assertFalse(pytree.is_structseq(cls))
|
||||
self.assertFalse(pytree.is_structseq_class(cls))
|
||||
self.assertFalse(pytree.is_namedtuple(cls))
|
||||
self.assertFalse(pytree.is_namedtuple_class(cls))
|
||||
|
||||
|
||||
class TestPythonPytree(TestCase):
|
||||
def test_deprecated_register_pytree_node(self):
|
||||
@ -975,9 +1103,8 @@ if "optree" in sys.modules:
|
||||
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point1",
|
||||
)
|
||||
|
||||
spec = py_pytree.TreeSpec(
|
||||
namedtuple, Point1, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
|
||||
)
|
||||
spec = py_pytree.tree_structure(Point1(1, 2))
|
||||
self.assertIs(spec.type, namedtuple)
|
||||
roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
|
||||
self.assertEqual(spec, roundtrip_spec)
|
||||
|
||||
@ -990,18 +1117,28 @@ if "optree" in sys.modules:
|
||||
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point2",
|
||||
)
|
||||
|
||||
spec = py_pytree.TreeSpec(
|
||||
namedtuple, Point2, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
|
||||
spec = py_pytree.tree_structure(Point2(1, 2))
|
||||
self.assertIs(spec.type, namedtuple)
|
||||
roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
|
||||
self.assertEqual(spec, roundtrip_spec)
|
||||
|
||||
class Point3(Point2):
|
||||
pass
|
||||
|
||||
py_pytree._register_namedtuple(
|
||||
Point3,
|
||||
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point3",
|
||||
)
|
||||
|
||||
spec = py_pytree.tree_structure(Point3(1, 2))
|
||||
self.assertIs(spec.type, namedtuple)
|
||||
roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
|
||||
self.assertEqual(spec, roundtrip_spec)
|
||||
|
||||
def test_pytree_serialize_namedtuple_bad(self):
|
||||
DummyType = namedtuple("DummyType", ["x", "y"])
|
||||
|
||||
spec = py_pytree.TreeSpec(
|
||||
namedtuple, DummyType, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
|
||||
)
|
||||
spec = py_pytree.tree_structure(DummyType(1, 2))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError, "Please register using `_register_namedtuple`"
|
||||
@ -1020,9 +1157,7 @@ if "optree" in sys.modules:
|
||||
lambda xs, _: DummyType(*xs),
|
||||
)
|
||||
|
||||
spec = py_pytree.TreeSpec(
|
||||
DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
|
||||
)
|
||||
spec = py_pytree.tree_structure(DummyType(1, 2))
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError, "No registered serialization name"
|
||||
):
|
||||
@ -1042,9 +1177,7 @@ if "optree" in sys.modules:
|
||||
to_dumpable_context=lambda context: "moo",
|
||||
from_dumpable_context=lambda dumpable_context: None,
|
||||
)
|
||||
spec = py_pytree.TreeSpec(
|
||||
DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
|
||||
)
|
||||
spec = py_pytree.tree_structure(DummyType(1, 2))
|
||||
serialized_spec = py_pytree.treespec_dumps(spec, 1)
|
||||
self.assertIn("moo", serialized_spec)
|
||||
roundtrip_spec = py_pytree.treespec_loads(serialized_spec)
|
||||
@ -1082,9 +1215,7 @@ if "optree" in sys.modules:
|
||||
from_dumpable_context=lambda dumpable_context: None,
|
||||
)
|
||||
|
||||
spec = py_pytree.TreeSpec(
|
||||
DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
|
||||
)
|
||||
spec = py_pytree.tree_structure(DummyType(1, 2))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "Object of type type is not JSON serializable"
|
||||
@ -1095,9 +1226,7 @@ if "optree" in sys.modules:
|
||||
import json
|
||||
|
||||
Point = namedtuple("Point", ["x", "y"])
|
||||
spec = py_pytree.TreeSpec(
|
||||
namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
|
||||
)
|
||||
spec = py_pytree.tree_structure(Point(1, 2))
|
||||
py_pytree._register_namedtuple(
|
||||
Point,
|
||||
serialized_type_name="test_pytree.test_pytree_serialize_bad_protocol.Point",
|
||||
|
@ -56,9 +56,10 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||
"structseq_fields",
|
||||
):
|
||||
__func = getattr(optree, __name)
|
||||
substitute_in_graph(__func, can_constant_fold_through=True)(
|
||||
globals()[__name] = substitute_in_graph(__func, can_constant_fold_through=True)(
|
||||
__func.__python_implementation__
|
||||
)
|
||||
__all__ += [__name] # noqa: PLE0604
|
||||
del __func
|
||||
del __name
|
||||
|
||||
|
@ -1243,7 +1243,7 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
def store_namedtuple_fields(ts):
|
||||
if ts.type is None:
|
||||
return
|
||||
if ts.type == namedtuple:
|
||||
if ts.type is namedtuple or pytree.is_namedtuple_class(ts.type):
|
||||
serialized_type_name = pytree.SUPPORTED_SERIALIZED_TYPES[ts.context].serialized_type_name
|
||||
if serialized_type_name in self.treespec_namedtuple_fields:
|
||||
field_names = self.treespec_namedtuple_fields[serialized_type_name].field_names
|
||||
|
@ -1,7 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import os
|
||||
from collections import namedtuple
|
||||
from typing import Any
|
||||
from typing import Any, NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -129,16 +128,15 @@ def make_dual(tensor, tangent, *, level=None):
|
||||
return torch._VF._make_dual(tensor, tangent, level=level)
|
||||
|
||||
|
||||
_UnpackedDualTensor = namedtuple("_UnpackedDualTensor", ["primal", "tangent"])
|
||||
|
||||
|
||||
class UnpackedDualTensor(_UnpackedDualTensor):
|
||||
class UnpackedDualTensor(NamedTuple):
|
||||
r"""Namedtuple returned by :func:`unpack_dual` containing the primal and tangent components of the dual tensor.
|
||||
|
||||
See :func:`unpack_dual` for more details.
|
||||
|
||||
"""
|
||||
|
||||
primal: torch.Tensor
|
||||
tangent: Optional[torch.Tensor]
|
||||
|
||||
|
||||
def unpack_dual(tensor, *, level=None):
|
||||
r"""Unpack a "dual tensor" to get both its Tensor value and its forward AD gradient.
|
||||
|
@ -552,8 +552,16 @@ def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None,
|
||||
|
||||
expected = compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs)
|
||||
expected = tree_map(fwAD.unpack_dual, expected)
|
||||
expected_primals = tree_map(lambda x: x.primal, expected)
|
||||
expected_tangents = tree_map(lambda x: x.tangent, expected)
|
||||
expected_primals = tree_map(
|
||||
lambda x: x.primal,
|
||||
expected,
|
||||
is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor,
|
||||
)
|
||||
expected_tangents = tree_map(
|
||||
lambda x: x.tangent,
|
||||
expected,
|
||||
is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor,
|
||||
)
|
||||
|
||||
# Permutations of arg and kwargs in CCT.
|
||||
for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
|
||||
@ -586,7 +594,15 @@ def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None,
|
||||
return e.elem if isinstance(e, CCT) else e
|
||||
|
||||
actual = tree_map(fwAD.unpack_dual, actual)
|
||||
actual_primals = tree_map(lambda x: unwrap(x.primal), actual)
|
||||
actual_tangents = tree_map(lambda x: unwrap(x.tangent), actual)
|
||||
actual_primals = tree_map(
|
||||
lambda x: unwrap(x.primal),
|
||||
actual,
|
||||
is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor,
|
||||
)
|
||||
actual_tangents = tree_map(
|
||||
lambda x: unwrap(x.tangent),
|
||||
actual,
|
||||
is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor,
|
||||
)
|
||||
assert_equal_fn(actual_primals, expected_primals, equal_nan=True)
|
||||
assert_equal_fn(actual_tangents, expected_tangents, equal_nan=True)
|
||||
|
@ -23,7 +23,15 @@ import optree
|
||||
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
|
||||
|
||||
import torch.utils._pytree as python_pytree
|
||||
from torch.utils._pytree import KeyEntry as KeyEntry
|
||||
from torch.utils._pytree import (
|
||||
is_namedtuple as is_namedtuple,
|
||||
is_namedtuple_class as is_namedtuple_class,
|
||||
is_namedtuple_instance as is_namedtuple_instance,
|
||||
is_structseq as is_structseq,
|
||||
is_structseq_class as is_structseq_class,
|
||||
is_structseq_instance as is_structseq_instance,
|
||||
KeyEntry as KeyEntry,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
@ -39,6 +47,7 @@ __all__ = [
|
||||
"keystr",
|
||||
"key_get",
|
||||
"register_pytree_node",
|
||||
"tree_is_leaf",
|
||||
"tree_flatten",
|
||||
"tree_flatten_with_path",
|
||||
"tree_unflatten",
|
||||
@ -58,6 +67,12 @@ __all__ = [
|
||||
"treespec_dumps",
|
||||
"treespec_loads",
|
||||
"treespec_pprint",
|
||||
"is_namedtuple",
|
||||
"is_namedtuple_class",
|
||||
"is_namedtuple_instance",
|
||||
"is_structseq",
|
||||
"is_structseq_class",
|
||||
"is_structseq_instance",
|
||||
]
|
||||
|
||||
|
||||
|
@ -31,14 +31,17 @@ from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
ClassVar,
|
||||
Final,
|
||||
Generic,
|
||||
NoReturn,
|
||||
Optional,
|
||||
overload,
|
||||
Protocol,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing_extensions import deprecated, NamedTuple
|
||||
from typing_extensions import deprecated, NamedTuple, Self
|
||||
|
||||
|
||||
__all__ = [
|
||||
@ -54,6 +57,7 @@ __all__ = [
|
||||
"keystr",
|
||||
"key_get",
|
||||
"register_pytree_node",
|
||||
"tree_is_leaf",
|
||||
"tree_flatten",
|
||||
"tree_flatten_with_path",
|
||||
"tree_unflatten",
|
||||
@ -73,6 +77,12 @@ __all__ = [
|
||||
"treespec_dumps",
|
||||
"treespec_loads",
|
||||
"treespec_pprint",
|
||||
"is_namedtuple",
|
||||
"is_namedtuple_class",
|
||||
"is_namedtuple_instance",
|
||||
"is_structseq",
|
||||
"is_structseq_class",
|
||||
"is_structseq_instance",
|
||||
]
|
||||
|
||||
|
||||
@ -573,6 +583,90 @@ class GetAttrKey:
|
||||
return getattr(obj, self.name)
|
||||
|
||||
|
||||
# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
|
||||
def is_namedtuple(obj: Union[object, type]) -> bool:
|
||||
"""Return whether the object is an instance of namedtuple or a subclass of namedtuple."""
|
||||
cls = obj if isinstance(obj, type) else type(obj)
|
||||
return is_namedtuple_class(cls)
|
||||
|
||||
|
||||
# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
|
||||
def is_namedtuple_class(cls: type) -> bool:
|
||||
"""Return whether the class is a subclass of namedtuple."""
|
||||
return (
|
||||
isinstance(cls, type)
|
||||
and issubclass(cls, tuple)
|
||||
and isinstance(getattr(cls, "_fields", None), tuple)
|
||||
and all(type(field) is str for field in cls._fields) # type: ignore[attr-defined]
|
||||
and callable(getattr(cls, "_make", None))
|
||||
and callable(getattr(cls, "_asdict", None))
|
||||
)
|
||||
|
||||
|
||||
# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
|
||||
def is_namedtuple_instance(obj: object) -> bool:
|
||||
"""Return whether the object is an instance of namedtuple."""
|
||||
return is_namedtuple_class(type(obj))
|
||||
|
||||
|
||||
_T_co = TypeVar("_T_co", covariant=True)
|
||||
|
||||
|
||||
# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
|
||||
class structseq(tuple[_T_co, ...]):
|
||||
"""A generic type stub for CPython's ``PyStructSequence`` type."""
|
||||
|
||||
__slots__: ClassVar[tuple[()]] = ()
|
||||
|
||||
n_fields: Final[int] # type: ignore[misc]
|
||||
n_sequence_fields: Final[int] # type: ignore[misc]
|
||||
n_unnamed_fields: Final[int] # type: ignore[misc]
|
||||
|
||||
def __init_subclass__(cls) -> NoReturn:
|
||||
"""Prohibit subclassing."""
|
||||
raise TypeError("type 'structseq' is not an acceptable base type")
|
||||
|
||||
def __new__(
|
||||
cls: type[Self],
|
||||
sequence: Iterable[_T_co],
|
||||
dict: dict[str, Any] = ...,
|
||||
) -> Self:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
|
||||
def is_structseq(obj: Union[object, type]) -> bool:
|
||||
"""Return whether the object is an instance of PyStructSequence or a class of PyStructSequence."""
|
||||
cls = obj if isinstance(obj, type) else type(obj)
|
||||
return is_structseq_class(cls)
|
||||
|
||||
|
||||
# Set if the type allows subclassing (see CPython's Include/object.h)
|
||||
Py_TPFLAGS_BASETYPE: int = 1 << 10
|
||||
|
||||
|
||||
# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
|
||||
def is_structseq_class(cls: type) -> bool:
|
||||
"""Return whether the class is a class of PyStructSequence."""
|
||||
return (
|
||||
isinstance(cls, type)
|
||||
# Check direct inheritance from `tuple` rather than `issubclass(cls, tuple)`
|
||||
and cls.__bases__ == (tuple,)
|
||||
# Check PyStructSequence members
|
||||
and isinstance(getattr(cls, "n_fields", None), int)
|
||||
and isinstance(getattr(cls, "n_sequence_fields", None), int)
|
||||
and isinstance(getattr(cls, "n_unnamed_fields", None), int)
|
||||
# Check the type does not allow subclassing
|
||||
and not bool(cls.__flags__ & Py_TPFLAGS_BASETYPE) # only works for CPython
|
||||
)
|
||||
|
||||
|
||||
# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
|
||||
def is_structseq_instance(obj: object) -> bool:
|
||||
"""Return whether the object is an instance of PyStructSequence."""
|
||||
return is_structseq_class(type(obj))
|
||||
|
||||
|
||||
def _tuple_flatten(d: tuple[T, ...]) -> tuple[list[T], Context]:
|
||||
return list(d), None
|
||||
|
||||
@ -807,37 +901,72 @@ _private_register_pytree_node(
|
||||
)
|
||||
|
||||
|
||||
STANDARD_DICT_TYPES: frozenset[type] = frozenset(
|
||||
{dict, OrderedDict, defaultdict},
|
||||
)
|
||||
STANDARD_DICT_TYPES: frozenset[type] = frozenset({dict, OrderedDict, defaultdict})
|
||||
BUILTIN_TYPES: frozenset[type] = frozenset(
|
||||
{tuple, list, dict, namedtuple, OrderedDict, defaultdict, deque}, # type: ignore[arg-type]
|
||||
{
|
||||
tuple,
|
||||
list,
|
||||
dict,
|
||||
namedtuple, # type: ignore[arg-type]
|
||||
OrderedDict,
|
||||
defaultdict,
|
||||
deque,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
|
||||
@deprecated(
|
||||
"torch.utils._pytree._is_namedtuple_instance is private and will be removed in a future release. "
|
||||
"Please use torch.utils._pytree.is_namedtuple_instance instead.",
|
||||
category=FutureWarning,
|
||||
)
|
||||
def _is_namedtuple_instance(tree: Any) -> bool:
|
||||
typ = type(tree)
|
||||
bases = typ.__bases__
|
||||
if len(bases) != 1 or bases[0] != tuple:
|
||||
return False
|
||||
fields = getattr(typ, "_fields", None)
|
||||
if not isinstance(fields, tuple):
|
||||
return False
|
||||
return all(type(entry) == str for entry in fields)
|
||||
return is_namedtuple_instance(tree)
|
||||
|
||||
|
||||
def _get_node_type(tree: Any) -> Any:
|
||||
if _is_namedtuple_instance(tree):
|
||||
node_type = type(tree)
|
||||
# All namedtuple types are implicitly registered as pytree nodes.
|
||||
# XXX: Other parts of the codebase expect namedtuple types always return
|
||||
# `namedtuple` instead of the actual namedtuple type. Even if the type
|
||||
# is explicitly registered.
|
||||
if is_namedtuple_class(node_type):
|
||||
return namedtuple
|
||||
return type(tree)
|
||||
return node_type
|
||||
|
||||
|
||||
# A leaf is defined as anything that is not a Node.
|
||||
def tree_is_leaf(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
"""Check if a pytree is a leaf.
|
||||
|
||||
>>> tree_is_leaf(1)
|
||||
True
|
||||
>>> tree_is_leaf(None)
|
||||
True
|
||||
>>> tree_is_leaf([1, 2, 3])
|
||||
False
|
||||
>>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple))
|
||||
True
|
||||
>>> tree_is_leaf({'a': 1, 'b': 2, 'c': 3})
|
||||
False
|
||||
>>> tree_is_leaf({'a': 1, 'b': 2, 'c': None})
|
||||
False
|
||||
"""
|
||||
if is_leaf is not None and is_leaf(tree):
|
||||
return True
|
||||
return _get_node_type(tree) not in SUPPORTED_NODES
|
||||
|
||||
|
||||
@deprecated(
|
||||
"torch.utils._pytree._is_leaf is private and will be removed in a future release. "
|
||||
"Please use torch.utils._pytree.tree_is_leaf instead.",
|
||||
category=FutureWarning,
|
||||
)
|
||||
def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -> bool:
|
||||
return (is_leaf is not None and is_leaf(tree)) or _get_node_type(
|
||||
tree
|
||||
) not in SUPPORTED_NODES
|
||||
return tree_is_leaf(tree, is_leaf=is_leaf)
|
||||
|
||||
|
||||
# A TreeSpec represents the structure of a pytree. It holds:
|
||||
@ -1040,7 +1169,7 @@ def tree_flatten(
|
||||
"""
|
||||
|
||||
def helper(node: PyTree, leaves: list[Any]) -> TreeSpec:
|
||||
if _is_leaf(node, is_leaf=is_leaf):
|
||||
if tree_is_leaf(node, is_leaf=is_leaf):
|
||||
leaves.append(node)
|
||||
return _LEAF_SPEC
|
||||
|
||||
@ -1074,7 +1203,7 @@ def tree_iter(
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> Iterable[Any]:
|
||||
"""Get an iterator over the leaves of a pytree."""
|
||||
if _is_leaf(tree, is_leaf=is_leaf):
|
||||
if tree_is_leaf(tree, is_leaf=is_leaf):
|
||||
yield tree
|
||||
else:
|
||||
node_type = _get_node_type(tree)
|
||||
@ -1520,7 +1649,7 @@ def _broadcast_to_and_flatten(
|
||||
) -> Optional[list[Any]]:
|
||||
assert isinstance(treespec, TreeSpec)
|
||||
|
||||
if _is_leaf(tree, is_leaf=is_leaf):
|
||||
if tree_is_leaf(tree, is_leaf=is_leaf):
|
||||
return [tree] * treespec.num_leaves
|
||||
if treespec.is_leaf():
|
||||
return None
|
||||
|
Reference in New Issue
Block a user