Files
pytorch/torch/_dynamo/polyfills/pytree.py
rzou f1f18c75c9 Gracefully handle optree less than minimum version, part 2 (#151257)
If optree is less than the minimum version, we should pretend it doesn't
exist.

The problem right now is:
- Install optree==0.12.1
- `import torch._dynamo`
- This raise an error "min optree version is 0.13.0"

The fix is to pretend optree doesn't exist if it is less than the min
version.

There are ways to clean up this PR more (e.g. have a single source of
truth for the version, some of the variables are redundant), but I am
trying to reduce the risk as much as possible for this to go into 2.7.

Test Plan:

I verified the above problem was fixed. Also tried some other things,
like the following, which now gives the expected behavior.
```py
>>> import torch
>>> import optree
>>> optree.__version__
'0.12.1'
>>> import torch._dynamo
>>> import torch._dynamo.polyfills.pytree
>>> import torch.utils._pytree
>>> import torch.utils._cxx_pytree
ImportError: torch.utils._cxx_pytree depends on optree, which is
an optional dependency of PyTorch. To u
se it, please upgrade your optree package to >= 0.13.0
```

I also audited all non-test callsites of optree and torch.utils._cxx_pytree.
Follow along with me:

optree imports
- torch.utils._cxx_pytree. This is fine.
- [guarded by check] f76b7ef33c/torch/_dynamo/polyfills/pytree.py (L29-L31)

_cxx_pytree imports
- [guarded by check] torch.utils._pytree (changed in this PR)
- [guarded by check] torch/_dynamo/polyfills/pytree.py (changed in this PR)
- [guarded by try-catch] f76b7ef33c/torch/distributed/_functional_collectives.py (L17)
- [guarded by try-catch] f76b7ef33c/torch/distributed/tensor/_op_schema.py (L15)
- [guarded by try-catch] f76b7ef33c/torch/distributed/tensor/_dispatch.py (L35)
- [guarded by try-catch] f76b7ef33c/torch/_dynamo/variables/user_defined.py (L94)
- [guarded by try-catch] f76b7ef33c/torch/distributed/tensor/experimental/_func_map.py (L14)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151257
Approved by: https://github.com/malfet, https://github.com/XuehaiPan
2025-04-15 13:08:26 +00:00

420 lines
15 KiB
Python

"""
Python polyfills for torch.utils.pytree
"""
from __future__ import annotations
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Callable, Literal, TYPE_CHECKING
from typing_extensions import TypeIs
import torch.utils._pytree as python_pytree
from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES
from ..decorators import substitute_in_graph
if TYPE_CHECKING:
import builtins
from collections.abc import Iterable
from typing_extensions import Self
__all__: list[str] = []
if python_pytree._cxx_pytree_dynamo_traceable:
import optree
import optree._C
import torch.utils._cxx_pytree as cxx_pytree
if TYPE_CHECKING:
from torch.utils._cxx_pytree import PyTree
@substitute_in_graph(
optree._C.is_dict_insertion_ordered,
can_constant_fold_through=True,
)
def _(*args: Any, **kwargs: Any) -> bool:
# In namespace 'torch', the dictionary is always traversed in insertion order.
# This function returns True.
raise ValueError(
"Should not be called directly "
"because the original function will be called in the constant fold path."
)
__name = ""
for __name in (
"is_namedtuple",
"is_namedtuple_class",
"is_namedtuple_instance",
"is_structseq",
"is_structseq_class",
"is_structseq_instance",
"namedtuple_fields",
"structseq_fields",
):
__func = getattr(optree, __name)
globals()[__name] = substitute_in_graph(__func, can_constant_fold_through=True)(
__func.__python_implementation__
)
__all__ += [__name] # noqa: PLE0604
del __func
del __name
@substitute_in_graph(cxx_pytree.tree_is_leaf, can_constant_fold_through=True)
def tree_is_leaf(
tree: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool:
if tree is None or (is_leaf is not None and is_leaf(tree)):
return True
if optree.register_pytree_node.get(type(tree), namespace="torch") is None: # type: ignore[attr-defined]
return True
return False
@substitute_in_graph(cxx_pytree.tree_iter, can_constant_fold_through=False)
def tree_iter(
tree: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> Iterable[Any]:
stack = [tree]
while stack:
node = stack.pop()
if tree_is_leaf(node, is_leaf=is_leaf):
yield node
continue
children, *_ = optree.tree_flatten_one_level(
node,
is_leaf=is_leaf,
none_is_leaf=True,
namespace="torch",
)
stack.extend(reversed(children))
__all__ += ["tree_iter"]
@substitute_in_graph(cxx_pytree.tree_leaves, can_constant_fold_through=True)
def tree_leaves(
tree: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> list[Any]:
return list(tree_iter(tree, is_leaf=is_leaf))
__all__ += ["tree_leaves"]
class _Asterisk(str):
__slots__ = ()
def __new__(cls) -> Self:
return super().__new__(cls, "*")
def __repr__(self) -> str:
return "*" # no quotes
_asterisk = _Asterisk()
del _Asterisk
@dataclass(frozen=True)
class PyTreeSpec:
"""Analog for :class:`optree.PyTreeSpec` in Python."""
_children: tuple[PyTreeSpec, ...]
_type: builtins.type | None
_metadata: Any
_entries: tuple[Any, ...]
_unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None
num_nodes: int = field(init=False)
num_leaves: int = field(init=False)
num_children: int = field(init=False)
none_is_leaf: Literal[True] = field(init=False)
namespace: Literal["torch"] = field(init=False)
def __post_init__(self) -> None:
if self._type is None:
assert len(self._children) == 0
assert self._metadata is None
assert self._entries == ()
assert self._unflatten_func is None
num_nodes = 1
num_leaves = 1
num_children = 0
else:
assert callable(self._unflatten_func)
num_nodes = sum((spec.num_nodes for spec in self._children), start=1)
num_leaves = sum(spec.num_leaves for spec in self._children)
num_children = len(self._children)
object.__setattr__(self, "num_nodes", num_nodes)
object.__setattr__(self, "num_leaves", num_leaves)
object.__setattr__(self, "num_children", num_children)
object.__setattr__(self, "none_is_leaf", True)
object.__setattr__(self, "namespace", "torch")
def __repr__(self) -> str:
def helper(treespec: PyTreeSpec) -> str:
if treespec.is_leaf():
assert treespec.type is None
return _asterisk
assert treespec.type is not None
assert callable(treespec._unflatten_func)
children_representations = [
helper(subspec) for subspec in treespec._children
]
if (
treespec.type in BUILTIN_TYPES
or optree.is_namedtuple_class(treespec.type)
or optree.is_structseq_class(treespec.type)
):
return treespec._unflatten_func(
treespec._metadata,
children_representations,
)
return (
f"CustomTreeNode({treespec.type.__name__}[{treespec._metadata!r}], "
f"[{', '.join(children_representations)}])"
)
return (
f"PyTreeSpec({helper(self)}, NoneIsLeaf, namespace={self.namespace!r})"
)
def __len__(self) -> int:
return self.num_leaves
@property
def type(self) -> builtins.type | None:
return self._type
def is_leaf(self) -> bool:
return self.num_nodes == 1 and self.num_leaves == 1
def children(self) -> list[PyTreeSpec]:
return list(self._children)
def child(self, index: int) -> PyTreeSpec:
return self._children[index]
def entries(self) -> list[Any]:
return list(self._entries)
def entry(self, index: int) -> Any:
return self._entries[index]
def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
def helper(
treespec: PyTreeSpec,
node: PyTree,
subtrees: list[PyTree],
) -> None:
if treespec.is_leaf():
subtrees.append(node)
return
node_type = type(node)
if treespec.type not in BUILTIN_TYPES:
# Always require custom node types to match exactly
if node_type != treespec.type:
raise ValueError(
f"Type mismatch; "
f"expected {treespec.type!r}, but got {node_type!r}.",
)
children, metadata, *_ = optree.tree_flatten_one_level(
node,
none_is_leaf=True,
namespace="torch",
)
if len(children) != treespec.num_children:
raise ValueError(
f"Node arity mismatch; "
f"expected {treespec.num_children}, but got {len(children)}.",
)
if metadata != treespec._metadata:
raise ValueError(
f"Node context mismatch for custom node type {treespec.type!r}.",
)
else:
# For builtin dictionary types, we allow some flexibility
# Otherwise, we require exact matches
both_standard_dict = (
treespec.type in STANDARD_DICT_TYPES
and node_type in STANDARD_DICT_TYPES
)
if not both_standard_dict and node_type != treespec.type:
raise ValueError(
f"Node type mismatch; "
f"expected {treespec.type!r}, but got {node_type!r}.",
)
if len(node) != treespec.num_children:
raise ValueError(
f"Node arity mismatch; "
f"expected {treespec.num_children}, but got {len(node)}.",
)
if both_standard_dict:
# dictionary types are compatible with each other
expected_keys = treespec.entries()
got_key_set = set(node)
expected_key_set = set(expected_keys)
if got_key_set != expected_key_set:
missing_keys = expected_key_set.difference(got_key_set)
extra_keys = got_key_set.difference(expected_key_set)
message = ""
if missing_keys:
message += f"; missing key(s): {missing_keys}"
if extra_keys:
message += f"; extra key(s): {extra_keys}"
raise ValueError(f"Node keys mismatch{message}.")
children = [node[key] for key in expected_keys]
else:
# node_type is treespec.type
children, metadata, *_ = optree.tree_flatten_one_level(
node,
none_is_leaf=True,
namespace="torch",
)
if (
node_type
is not deque # ignore mismatch of `maxlen` for deque
) and metadata != treespec._metadata:
raise ValueError(
f"Node metadata mismatch for node type {treespec.type!r}; "
f"expected {treespec._metadata!r}, but got {metadata!r}.", # namedtuple type mismatch
)
for subtree, subspec in zip(children, treespec._children):
helper(subspec, subtree, subtrees)
subtrees: list[PyTree] = []
helper(self, tree, subtrees)
return subtrees
def unflatten(self, leaves: Iterable[Any]) -> PyTree:
if not isinstance(leaves, (list, tuple)):
leaves = list(leaves)
if len(leaves) != self.num_leaves:
raise ValueError(
f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} "
f"but the spec refers to a pytree that holds {self.num_leaves} "
f"items ({self}).",
)
if self.is_leaf():
return leaves[0]
# Recursively unflatten the children
start = 0
end = 0
subtrees = []
for subspec in self._children:
end += subspec.num_leaves
subtrees.append(subspec.unflatten(leaves[start:end]))
start = end
assert callable(self._unflatten_func)
return self._unflatten_func(self._metadata, subtrees)
_LEAF_SPEC = PyTreeSpec((), None, None, (), None)
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
return isinstance(obj, PyTreeSpec)
@substitute_in_graph( # type: ignore[arg-type]
cxx_pytree.tree_flatten,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def tree_flatten(
tree: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> tuple[list[Any], PyTreeSpec]:
def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec:
if tree_is_leaf(node, is_leaf=is_leaf):
leaves.append(node)
return _LEAF_SPEC
(
children,
metadata,
entries,
unflatten_func,
) = optree.tree_flatten_one_level(
node,
is_leaf=is_leaf,
none_is_leaf=True,
namespace="torch",
)
# Recursively flatten the children
subspecs = tuple(helper(child, leaves) for child in children)
return PyTreeSpec(subspecs, type(node), metadata, entries, unflatten_func) # type: ignore[arg-type]
leaves: list[Any] = []
treespec = helper(tree, leaves)
return leaves, treespec
__all__ += ["tree_flatten"]
@substitute_in_graph( # type: ignore[arg-type]
cxx_pytree.tree_structure,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def tree_structure(
tree: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTreeSpec:
return tree_flatten(tree, is_leaf=is_leaf)[1] # type: ignore[return-value]
__all__ += ["tree_structure"]
@substitute_in_graph( # type: ignore[arg-type]
cxx_pytree.tree_unflatten,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
return treespec.unflatten(leaves)
__all__ += ["tree_unflatten"]
@substitute_in_graph(cxx_pytree.tree_map, can_constant_fold_through=True)
def tree_map(
func: Callable[..., Any],
tree: PyTree,
*rests: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
return treespec.unflatten(map(func, *flat_args))
__all__ += ["tree_map"]
@substitute_in_graph(cxx_pytree.tree_map_, can_constant_fold_through=True)
def tree_map_(
func: Callable[..., Any],
tree: PyTree,
*rests: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable
return tree
__all__ += ["tree_map_"]