Compare commits

...

65 Commits

Author SHA1 Message Date
c0d8211567 Update
[ghstack-poisoned]
2025-11-04 21:47:44 +08:00
35b0f1979a Update
[ghstack-poisoned]
2025-11-04 21:33:30 +08:00
7f254deda5 Update
[ghstack-poisoned]
2025-11-04 20:18:39 +08:00
89461902e7 Update
[ghstack-poisoned]
2025-11-04 17:07:28 +08:00
42b7121579 Update (base update)
[ghstack-poisoned]
2025-11-04 17:07:28 +08:00
8b6ea87350 Update
[ghstack-poisoned]
2025-11-04 13:02:07 +08:00
7e75d9c1af Update (base update)
[ghstack-poisoned]
2025-11-04 13:02:07 +08:00
b276759fa9 Update
[ghstack-poisoned]
2025-11-04 12:53:24 +08:00
77fdb8fc0b Update (base update)
[ghstack-poisoned]
2025-11-04 12:53:24 +08:00
982c3afc1a Update
[ghstack-poisoned]
2025-10-29 21:23:29 +08:00
98795824a3 Update (base update)
[ghstack-poisoned]
2025-10-29 21:23:29 +08:00
f5dc9ed4a7 Update
[ghstack-poisoned]
2025-10-29 21:02:52 +08:00
ca97452084 Update (base update)
[ghstack-poisoned]
2025-10-29 21:02:52 +08:00
642532d99a Update
[ghstack-poisoned]
2025-10-11 21:35:29 +08:00
13cff139b1 Update (base update)
[ghstack-poisoned]
2025-10-11 21:35:29 +08:00
6905e60009 Update
[ghstack-poisoned]
2025-10-08 22:52:12 +08:00
2a709d3c5c Update (base update)
[ghstack-poisoned]
2025-10-08 22:39:22 +08:00
146f74745c Update
[ghstack-poisoned]
2025-10-08 22:39:22 +08:00
3dcb697423 Update (base update)
[ghstack-poisoned]
2025-09-19 18:03:44 +08:00
8fc312080c Update
[ghstack-poisoned]
2025-09-19 18:03:44 +08:00
137bb17393 Update (base update)
[ghstack-poisoned]
2025-09-06 11:34:32 +08:00
506c102be6 Update
[ghstack-poisoned]
2025-09-06 11:34:32 +08:00
7bd0e7b13a Update (base update)
[ghstack-poisoned]
2025-08-17 16:23:37 +08:00
3f6d5c7882 Update
[ghstack-poisoned]
2025-08-17 16:23:37 +08:00
0a45a80a1e Update
[ghstack-poisoned]
2025-08-09 02:53:28 +08:00
20a2bf1848 Update (base update)
[ghstack-poisoned]
2025-08-09 02:51:17 +08:00
d2d7e2e007 Update
[ghstack-poisoned]
2025-08-09 02:51:17 +08:00
0c3b32cdfe Update (base update)
[ghstack-poisoned]
2025-07-31 15:19:07 +08:00
b831219577 Update
[ghstack-poisoned]
2025-07-31 15:19:07 +08:00
7a5415fa49 Update (base update)
[ghstack-poisoned]
2025-07-25 20:00:29 +08:00
3fcdb0331a Update
[ghstack-poisoned]
2025-07-25 20:00:29 +08:00
7c0318f29e Update (base update)
[ghstack-poisoned]
2025-07-17 15:02:01 +08:00
870921f43d Update
[ghstack-poisoned]
2025-07-17 15:02:01 +08:00
e094178954 Update (base update)
[ghstack-poisoned]
2025-07-09 19:01:32 +08:00
8629531760 Update
[ghstack-poisoned]
2025-07-09 19:01:32 +08:00
7834da3d0c Update (base update)
[ghstack-poisoned]
2025-07-03 16:24:21 +08:00
444425428a Update
[ghstack-poisoned]
2025-07-03 16:24:21 +08:00
51a17f064b Update (base update)
[ghstack-poisoned]
2025-06-28 20:59:43 +08:00
191402dffc Update
[ghstack-poisoned]
2025-06-28 20:59:43 +08:00
8775c579e1 Update (base update)
[ghstack-poisoned]
2025-06-27 21:27:42 +08:00
f4af4f029a Update
[ghstack-poisoned]
2025-06-27 21:27:42 +08:00
2e5aa1b8bc Update (base update)
[ghstack-poisoned]
2025-06-23 22:51:19 +08:00
654ec31cf8 Update
[ghstack-poisoned]
2025-06-23 22:51:19 +08:00
1b3f9fb167 Update (base update)
[ghstack-poisoned]
2025-06-18 23:17:45 +08:00
6d5ff9e966 Update
[ghstack-poisoned]
2025-06-18 23:17:45 +08:00
faaa219abc Update (base update)
[ghstack-poisoned]
2025-06-06 19:50:49 +08:00
7da62e7e6e Update
[ghstack-poisoned]
2025-06-06 19:50:49 +08:00
007973b088 Update (base update)
[ghstack-poisoned]
2025-05-31 21:59:56 +08:00
1f5c39f11f Update
[ghstack-poisoned]
2025-05-31 21:59:56 +08:00
0fca020022 Update (base update)
[ghstack-poisoned]
2025-05-28 20:43:30 +08:00
8263baabb1 Update
[ghstack-poisoned]
2025-05-28 20:43:30 +08:00
f0b966f925 Update (base update)
[ghstack-poisoned]
2025-05-16 11:37:30 +08:00
3ca2fb4e55 Update
[ghstack-poisoned]
2025-05-16 11:37:30 +08:00
ec2da971dc Update (base update)
[ghstack-poisoned]
2025-05-14 20:34:58 +08:00
e4ee616def Update
[ghstack-poisoned]
2025-05-14 20:34:58 +08:00
7f31805e72 Update (base update)
[ghstack-poisoned]
2025-05-08 21:19:05 +08:00
e3e16adff0 Update
[ghstack-poisoned]
2025-05-08 21:19:05 +08:00
c202d5b1fc Update
[ghstack-poisoned]
2025-05-03 02:34:21 +08:00
7e9405da0b Update
[ghstack-poisoned]
2025-05-03 01:14:42 +08:00
79768226c9 Update
[ghstack-poisoned]
2025-05-03 00:44:58 +08:00
08cf8c67c2 Update (base update)
[ghstack-poisoned]
2025-05-03 00:40:32 +08:00
2fac214138 Update
[ghstack-poisoned]
2025-05-03 00:40:32 +08:00
e830571b16 Update
[ghstack-poisoned]
2025-05-02 02:30:01 +08:00
8ca16467fc Update (base update)
[ghstack-poisoned]
2025-05-02 02:24:57 +08:00
864d750662 Update
[ghstack-poisoned]
2025-05-02 02:24:57 +08:00
4 changed files with 125 additions and 55 deletions

View File

@ -9,6 +9,7 @@ import sys
import time
import unittest
from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import auto
from typing import Any, NamedTuple, Optional
@ -22,6 +23,7 @@ from torch.testing._internal.common_utils import (
parametrize,
run_tests,
subtest,
TEST_WITH_TORCHDYNAMO,
TestCase,
)
@ -1472,8 +1474,23 @@ class TestCxxPytree(TestCase):
if IS_FBCODE:
raise unittest.SkipTest("C++ pytree tests are not supported in fbcode")
def assertEqualSpecs(
self,
spec1,
spec2,
msg: str | Callable[[str], str] | None = None,
):
if TEST_WITH_TORCHDYNAMO:
# The Dynamo polyfill returns a pure Python class for PyTreeSpec.
# So we compare the type names and reprs instead because the types
# themselves won't be equal.
self.assertEqual(type(spec1).__name__, type(spec2).__name__, msg=msg)
self.assertEqual(repr(spec1), repr(spec2), msg=msg)
else:
self.assertEqual(spec1, spec2, msg=msg)
def test_treespec_equality(self):
self.assertEqual(cxx_pytree.treespec_leaf(), cxx_pytree.treespec_leaf())
self.assertEqualSpecs(cxx_pytree.treespec_leaf(), cxx_pytree.treespec_leaf())
def test_treespec_repr(self):
# Check that it looks sane
@ -1503,16 +1520,11 @@ class TestCxxPytree(TestCase):
],
)
def test_pytree_serialize(self, spec):
self.assertEqual(
spec,
cxx_pytree.tree_structure(
cxx_pytree.tree_unflatten([0] * spec.num_leaves, spec)
),
)
serialized_spec = cxx_pytree.treespec_dumps(spec)
self.assertIsInstance(serialized_spec, str)
self.assertEqual(spec, cxx_pytree.treespec_loads(serialized_spec))
roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec)
self.assertEqualSpecs(roundtrip_spec, spec)
def test_pytree_serialize_namedtuple(self):
python_pytree._register_namedtuple(
@ -1538,7 +1550,7 @@ class TestCxxPytree(TestCase):
spec = cxx_pytree.tree_structure(GlobalDummyType(0, 1))
serialized_spec = cxx_pytree.treespec_dumps(spec)
roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec)
self.assertEqual(roundtrip_spec, spec)
self.assertEqualSpecs(roundtrip_spec, spec)
class LocalDummyType:
def __init__(self, x, y):
@ -1554,7 +1566,7 @@ class TestCxxPytree(TestCase):
spec = cxx_pytree.tree_structure(LocalDummyType(0, 1))
serialized_spec = cxx_pytree.treespec_dumps(spec)
roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec)
self.assertEqual(roundtrip_spec, spec)
self.assertEqualSpecs(roundtrip_spec, spec)
instantiate_parametrized_tests(TestGenericPytree)

View File

@ -346,8 +346,10 @@ if python_pytree._cxx_pytree_dynamo_traceable:
assert callable(self._unflatten_func)
return self._unflatten_func(self._metadata, subtrees)
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
return isinstance(obj, PyTreeSpec)
def _is_pytreespec_instance(
obj: Any, /
) -> TypeIs[PyTreeSpec | python_pytree.TreeSpec]:
return isinstance(obj, (PyTreeSpec, python_pytree.TreeSpec))
@substitute_in_graph( # type: ignore[arg-type]
optree.treespec_leaf,
@ -550,7 +552,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree:
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
f"Expected `treespec` to be an instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
return treespec.unflatten(leaves)

View File

@ -13,6 +13,7 @@ collection support for PyTorch APIs.
"""
import functools
import sys
import types
from collections.abc import Callable, Iterable, Mapping
from typing import Any, Optional, overload, TypeVar, Union
@ -266,8 +267,21 @@ def _private_register_pytree_node(
)
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[TreeSpec]:
return isinstance(obj, TreeSpec)
def _is_pytreespec_instance(
obj: Any,
/,
) -> TypeIs[Union[TreeSpec, python_pytree.PyTreeSpec]]:
if isinstance(obj, (TreeSpec, python_pytree.PyTreeSpec)):
return True
if "torch._dynamo.polyfills.pytree" in sys.modules:
# The PyTorch Dynamo pytree module is not always available, so we check if it is loaded.
# If the PyTorch Dynamo pytree module is loaded, we can check if the treespec
# is an instance of the PyTorch Dynamo TreeSpec class.
import torch._dynamo.polyfills.pytree as dynamo_pytree
if isinstance(obj, dynamo_pytree.PyTreeSpec):
return True
return False
def treespec_leaf() -> TreeSpec:
@ -394,7 +408,12 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
The reconstructed pytree, containing the ``leaves`` placed in the structure described by
``treespec``.
"""
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"Expected `treespec` to be an instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
return treespec.unflatten(leaves)
def tree_iter(
@ -973,7 +992,7 @@ def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
"""Serialize a treespec to a JSON string."""
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"treespec_dumps(treespec): Expected `treespec` to be instance of "
f"Expected `treespec` to be instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)

View File

@ -20,6 +20,7 @@ import functools
import importlib
import importlib.metadata
import json
import sys
import threading
import types
import warnings
@ -36,14 +37,19 @@ from typing import (
Optional,
overload,
Protocol,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing_extensions import deprecated, NamedTuple, Self, TypeAlias
from typing_extensions import deprecated, NamedTuple, Self, TypeAlias, TypeIs
from torch.torch_version import TorchVersion as _TorchVersion
if TYPE_CHECKING:
import torch.utils._cxx_pytree as cxx_pytree
__all__ = [
"PyTree",
"Context",
@ -249,9 +255,9 @@ def register_pytree_node(
return
if _cxx_pytree_imported:
from . import _cxx_pytree as cxx
import torch.utils._cxx_pytree as cxx_pytree
cxx._private_register_pytree_node(
cxx_pytree._private_register_pytree_node(
cls,
flatten_fn,
unflatten_fn,
@ -1363,6 +1369,45 @@ def treespec_dict(
return TreeSpec(dict, list(dct.keys()), list(dct.values()))
def _is_pytreespec_instance(
obj: Any,
) -> TypeIs[Union[TreeSpec, "cxx_pytree.PyTreeSpec"]]:
if isinstance(obj, TreeSpec):
return True
if "torch.utils._cxx_pytree" in sys.modules:
# The C++ pytree module is not always available, so we check if it is loaded.
# If the C++ pytree module is loaded, we can check if the treespec
# is an instance of the C++ TreeSpec class.
import torch.utils._cxx_pytree as cxx_pytree
if isinstance(obj, cxx_pytree.PyTreeSpec):
return True
if "torch._dynamo.polyfills.pytree" in sys.modules:
# The PyTorch Dynamo pytree module is not always available, so we check if it is loaded.
# If the PyTorch Dynamo pytree module is loaded, we can check if the treespec
# is an instance of the PyTorch Dynamo TreeSpec class.
import torch._dynamo.polyfills.pytree as dynamo_pytree
if isinstance(obj, dynamo_pytree.PyTreeSpec):
return True
return False
def _ensure_python_treespec_instance(
treespec: Union[TreeSpec, "cxx_pytree.PyTreeSpec"],
) -> TreeSpec:
if isinstance(treespec, TreeSpec):
return treespec
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"Expected `treespec` to be an instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
dummy_tree = treespec.unflatten([0] * treespec.num_leaves)
return tree_structure(dummy_tree)
def tree_flatten(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
@ -1393,10 +1438,10 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
"""Given a list of values and a TreeSpec, builds a pytree.
This is the inverse operation of `tree_flatten`.
"""
if not isinstance(treespec, TreeSpec):
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"tree_unflatten(leaves, treespec): Expected `treespec` to be "
f"instance of TreeSpec but got item of type {type(treespec)}.",
f"Expected `treespec` to be an instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
return treespec.unflatten(leaves)
@ -1827,34 +1872,30 @@ def _broadcast_to_and_flatten(
treespec: TreeSpec,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Optional[list[Any]]:
if not isinstance(treespec, TreeSpec):
raise AssertionError("treespec must be a TreeSpec")
def broadcast_prefix(
prefix_tree: PyTree,
full_tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> list[Any]:
result: list[Any] = []
if tree_is_leaf(tree, is_leaf=is_leaf):
return [tree] * treespec.num_leaves
if treespec.is_leaf():
def add_leaves(x: Any, subtree: PyTree) -> None:
subtreespec = tree_structure(subtree, is_leaf=is_leaf)
result.extend([x] * subtreespec.num_leaves)
tree_map_(
add_leaves,
prefix_tree,
full_tree,
is_leaf=is_leaf,
)
return result
full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
try:
return broadcast_prefix(tree, full_tree, is_leaf=is_leaf)
except ValueError:
return None
node_type = _get_node_type(tree)
if node_type != treespec.type:
return None
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
child_pytrees, context = flatten_fn(tree)
# Check if the Node is different from the spec
if len(child_pytrees) != treespec.num_children or context != treespec._context:
return None
# Recursively flatten the children
result: list[Any] = []
for child, child_spec in zip(child_pytrees, treespec._children, strict=True):
flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf)
if flat is not None:
result += flat
else:
return None
return result
@dataclasses.dataclass
@ -1968,11 +2009,7 @@ _SUPPORTED_PROTOCOLS[1] = _ProtocolFn(_treespec_to_json, _json_to_treespec)
def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
if not isinstance(treespec, TreeSpec):
raise TypeError(
f"treespec_dumps(treespec, protocol): Expected `treespec` to be instance of "
f"TreeSpec but got item of type {type(treespec)}.",
)
treespec = _ensure_python_treespec_instance(treespec)
if protocol is None:
protocol = DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL