mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
If optree is less than the minimum version, we should pretend it doesn't exist. The problem right now is: - Install optree==0.12.1 - `import torch._dynamo` - This raises an error "min optree version is 0.13.0" The fix is to pretend optree doesn't exist if it is less than the min version. There are ways to clean up this PR more (e.g. have a single source of truth for the version, some of the variables are redundant), but I am trying to reduce the risk as much as possible for this to go into 2.7. Test Plan: I verified the above problem was fixed. Also tried some other things, like the following, which now give the expected behavior. ```py >>> import torch >>> import optree >>> optree.__version__ '0.12.1' >>> import torch._dynamo >>> import torch._dynamo.polyfills.pytree >>> import torch.utils._pytree >>> import torch.utils._cxx_pytree ImportError: torch.utils._cxx_pytree depends on optree, which is an optional dependency of PyTorch. To use it, please upgrade your optree package to >= 0.13.0 ``` I also audited all non-test callsites of optree and torch.utils._cxx_pytree. Follow along with me: optree imports - torch.utils._cxx_pytree. This is fine. - [guarded by check]f76b7ef33c/torch/_dynamo/polyfills/pytree.py (L29-L31)
_cxx_pytree imports - [guarded by check] torch.utils._pytree (changed in this PR) - [guarded by check] torch/_dynamo/polyfills/pytree.py (changed in this PR) - [guarded by try-catch]f76b7ef33c/torch/distributed/_functional_collectives.py (L17)
- [guarded by try-catch]f76b7ef33c/torch/distributed/tensor/_op_schema.py (L15)
- [guarded by try-catch]f76b7ef33c/torch/distributed/tensor/_dispatch.py (L35)
- [guarded by try-catch]f76b7ef33c/torch/_dynamo/variables/user_defined.py (L94)
- [guarded by try-catch]f76b7ef33c/torch/distributed/tensor/experimental/_func_map.py (L14)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151257 Approved by: https://github.com/malfet, https://github.com/XuehaiPan
420 lines
15 KiB
Python
420 lines
15 KiB
Python
"""
Python polyfills for torch.utils.pytree
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections import deque
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Callable, Literal, TYPE_CHECKING
|
|
from typing_extensions import TypeIs
|
|
|
|
import torch.utils._pytree as python_pytree
|
|
from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES
|
|
|
|
from ..decorators import substitute_in_graph
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
import builtins
|
|
from collections.abc import Iterable
|
|
from typing_extensions import Self
|
|
|
|
|
|
__all__: list[str] = []
|
|
|
|
|
|
if python_pytree._cxx_pytree_dynamo_traceable:
|
|
import optree
|
|
import optree._C
|
|
|
|
import torch.utils._cxx_pytree as cxx_pytree
|
|
|
|
if TYPE_CHECKING:
|
|
from torch.utils._cxx_pytree import PyTree
|
|
|
|
@substitute_in_graph(
    optree._C.is_dict_insertion_ordered,
    can_constant_fold_through=True,
)
def _(*args: Any, **kwargs: Any) -> bool:
    """Placeholder polyfill for ``optree._C.is_dict_insertion_ordered``.

    The substitution is registered with ``can_constant_fold_through=True``.
    As the error message below states, the original function is expected to
    be invoked in the constant-fold path, so this body is not meant to run.

    Raises:
        ValueError: Always, if this placeholder is called directly.
    """
    # In namespace 'torch', the dictionary is always traversed in insertion order.
    # This function returns True.
    raise ValueError(
        "Should not be called directly "
        "because the original function will be called in the constant fold path."
    )
|
|
|
|
# Register polyfills for optree's namedtuple/structseq helpers.
# For each listed helper, the pure-Python implementation exposed on the optree
# function (``__python_implementation__``) is installed as the in-graph
# substitute, published as a module global, and appended to ``__all__``.
__name = ""
for __name in (
    "is_namedtuple",
    "is_namedtuple_class",
    "is_namedtuple_instance",
    "is_structseq",
    "is_structseq_class",
    "is_structseq_instance",
    "namedtuple_fields",
    "structseq_fields",
):
    __func = getattr(optree, __name)
    globals()[__name] = substitute_in_graph(__func, can_constant_fold_through=True)(
        __func.__python_implementation__
    )
    __all__ += [__name]  # noqa: PLE0604
    del __func
# Do not leak the loop variable into the module namespace.
del __name
|
|
|
|
@substitute_in_graph(cxx_pytree.tree_is_leaf, can_constant_fold_through=True)
def tree_is_leaf(
    tree: PyTree,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool:
    """Return whether ``tree`` should be treated as a leaf.

    Args:
        tree: The object to classify.
        is_leaf: Optional predicate; when it returns a truthy value for
            ``tree``, the object is treated as a leaf.

    Returns:
        ``True`` if ``tree`` is ``None``, if ``is_leaf`` accepts it, or if
        ``type(tree)`` has no entry in optree's ``"torch"`` namespace registry;
        ``False`` otherwise.
    """
    if tree is None or (is_leaf is not None and is_leaf(tree)):
        return True
    # An unregistered type cannot be flattened further, so it is a leaf.
    if optree.register_pytree_node.get(type(tree), namespace="torch") is None:  # type: ignore[attr-defined]
        return True
    return False
|
|
|
|
@substitute_in_graph(cxx_pytree.tree_iter, can_constant_fold_through=False)
def tree_iter(
    tree: PyTree,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> Iterable[Any]:
    """Lazily yield the leaves of ``tree``.

    Traversal uses an explicit stack instead of recursion. Children are pushed
    in reverse so that popping from the end of the stack visits them in the
    order returned by ``optree.tree_flatten_one_level``.

    Args:
        tree: The pytree to traverse.
        is_leaf: Optional predicate forwarded to ``tree_is_leaf`` and to
            ``optree.tree_flatten_one_level``.

    Yields:
        Each node for which ``tree_is_leaf`` returns ``True``.
    """
    stack = [tree]
    while stack:
        node = stack.pop()
        if tree_is_leaf(node, is_leaf=is_leaf):
            yield node
            continue

        children, *_ = optree.tree_flatten_one_level(
            node,
            is_leaf=is_leaf,
            none_is_leaf=True,
            namespace="torch",
        )
        stack.extend(reversed(children))


__all__ += ["tree_iter"]
|
|
|
|
@substitute_in_graph(cxx_pytree.tree_leaves, can_constant_fold_through=True)
def tree_leaves(
    tree: PyTree,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> list[Any]:
    """Return the leaves of ``tree`` as a list.

    Args:
        tree: The pytree to traverse.
        is_leaf: Optional leaf predicate forwarded to ``tree_iter``.

    Returns:
        A list holding every item produced by ``tree_iter(tree, is_leaf=is_leaf)``.
    """
    return list(tree_iter(tree, is_leaf=is_leaf))


__all__ += ["tree_leaves"]
|
|
|
|
class _Asterisk(str):
    """A ``str`` equal to ``"*"`` whose ``repr`` is the bare ``*`` (no quotes)."""

    __slots__ = ()

    def __new__(cls) -> Self:
        """Create the string value ``"*"``."""
        return super().__new__(cls, "*")

    def __repr__(self) -> str:
        """Return ``*`` without surrounding quotes."""
        return "*"  # no quotes


# Single shared instance; the class itself is removed from the module namespace.
_asterisk = _Asterisk()
del _Asterisk
|
|
|
|
@dataclass(frozen=True)
class PyTreeSpec:
    """Analog for :class:`optree.PyTreeSpec` in Python."""

    # Specs of the direct children; empty for a leaf.
    _children: tuple[PyTreeSpec, ...]
    # Container type of this node; ``None`` marks a leaf.
    _type: builtins.type | None
    # Metadata returned by ``optree.tree_flatten_one_level`` for this node.
    _metadata: Any
    # Entries (e.g. dict keys) returned by ``optree.tree_flatten_one_level``.
    _entries: tuple[Any, ...]
    # Callable ``(metadata, children) -> node`` used to rebuild the node; ``None`` for a leaf.
    _unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None

    # Derived fields, computed in ``__post_init__``.
    num_nodes: int = field(init=False)
    num_leaves: int = field(init=False)
    num_children: int = field(init=False)
    none_is_leaf: Literal[True] = field(init=False)
    namespace: Literal["torch"] = field(init=False)

    def __post_init__(self) -> None:
        """Validate leaf invariants and compute the derived counters."""
        if self._type is None:
            # Leaf spec: everything else must be empty.
            assert len(self._children) == 0
            assert self._metadata is None
            assert self._entries == ()
            assert self._unflatten_func is None
            num_nodes = 1
            num_leaves = 1
            num_children = 0
        else:
            assert callable(self._unflatten_func)
            # ``start=1`` counts this node itself.
            num_nodes = sum((spec.num_nodes for spec in self._children), start=1)
            num_leaves = sum(spec.num_leaves for spec in self._children)
            num_children = len(self._children)

        # The dataclass is frozen, so bypass its ``__setattr__``.
        object.__setattr__(self, "num_nodes", num_nodes)
        object.__setattr__(self, "num_leaves", num_leaves)
        object.__setattr__(self, "num_children", num_children)
        object.__setattr__(self, "none_is_leaf", True)
        object.__setattr__(self, "namespace", "torch")

    def __repr__(self) -> str:
        """Return ``PyTreeSpec(<structure>, NoneIsLeaf, namespace='torch')``."""

        def helper(treespec: PyTreeSpec) -> str:
            if treespec.is_leaf():
                assert treespec.type is None
                return _asterisk

            assert treespec.type is not None
            assert callable(treespec._unflatten_func)
            children_representations = [
                helper(subspec) for subspec in treespec._children
            ]
            if (
                treespec.type in BUILTIN_TYPES
                or optree.is_namedtuple_class(treespec.type)
                or optree.is_structseq_class(treespec.type)
            ):
                # Rebuild the container around the child representations so
                # that formatting it shows the structure.
                return treespec._unflatten_func(
                    treespec._metadata,
                    children_representations,
                )
            # A child that is a builtin container is represented by a
            # non-``str`` object (see the branch above), so convert each item
            # explicitly; ``str.join`` would otherwise raise ``TypeError``.
            return (
                f"CustomTreeNode({treespec.type.__name__}[{treespec._metadata!r}], "
                f"[{', '.join(str(rep) for rep in children_representations)}])"
            )

        return (
            f"PyTreeSpec({helper(self)}, NoneIsLeaf, namespace={self.namespace!r})"
        )

    def __len__(self) -> int:
        """Return the number of leaves."""
        return self.num_leaves

    @property
    def type(self) -> builtins.type | None:
        """The container type of this node, or ``None`` for a leaf."""
        return self._type

    def is_leaf(self) -> bool:
        """Return whether this spec consists of exactly one node that is a leaf."""
        return self.num_nodes == 1 and self.num_leaves == 1

    def children(self) -> list[PyTreeSpec]:
        """Return a new list of the direct child specs."""
        return list(self._children)

    def child(self, index: int) -> PyTreeSpec:
        """Return the child spec at ``index``."""
        return self._children[index]

    def entries(self) -> list[Any]:
        """Return a new list of this node's entries."""
        return list(self._entries)

    def entry(self, index: int) -> Any:
        """Return the entry at ``index``."""
        return self._entries[index]

    def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
        """Flatten ``tree`` only as deep as this spec describes.

        Args:
            tree: A pytree whose structure must be compatible with this spec.

        Returns:
            The subtrees of ``tree`` found at the positions of this spec's leaves.

        Raises:
            ValueError: If a node's type, arity, keys or metadata do not match
                the corresponding spec node.
        """

        def helper(
            treespec: PyTreeSpec,
            node: PyTree,
            subtrees: list[PyTree],
        ) -> None:
            if treespec.is_leaf():
                subtrees.append(node)
                return

            node_type = type(node)
            if treespec.type not in BUILTIN_TYPES:
                # Always require custom node types to match exactly
                if node_type != treespec.type:
                    raise ValueError(
                        f"Type mismatch; "
                        f"expected {treespec.type!r}, but got {node_type!r}.",
                    )

                children, metadata, *_ = optree.tree_flatten_one_level(
                    node,
                    none_is_leaf=True,
                    namespace="torch",
                )
                if len(children) != treespec.num_children:
                    raise ValueError(
                        f"Node arity mismatch; "
                        f"expected {treespec.num_children}, but got {len(children)}.",
                    )
                if metadata != treespec._metadata:
                    raise ValueError(
                        f"Node context mismatch for custom node type {treespec.type!r}.",
                    )
            else:
                # For builtin dictionary types, we allow some flexibility
                # Otherwise, we require exact matches
                both_standard_dict = (
                    treespec.type in STANDARD_DICT_TYPES
                    and node_type in STANDARD_DICT_TYPES
                )
                if not both_standard_dict and node_type != treespec.type:
                    raise ValueError(
                        f"Node type mismatch; "
                        f"expected {treespec.type!r}, but got {node_type!r}.",
                    )
                if len(node) != treespec.num_children:
                    raise ValueError(
                        f"Node arity mismatch; "
                        f"expected {treespec.num_children}, but got {len(node)}.",
                    )

                if both_standard_dict:
                    # dictionary types are compatible with each other
                    expected_keys = treespec.entries()
                    got_key_set = set(node)
                    expected_key_set = set(expected_keys)
                    if got_key_set != expected_key_set:
                        missing_keys = expected_key_set.difference(got_key_set)
                        extra_keys = got_key_set.difference(expected_key_set)
                        message = ""
                        if missing_keys:
                            message += f"; missing key(s): {missing_keys}"
                        if extra_keys:
                            message += f"; extra key(s): {extra_keys}"
                        raise ValueError(f"Node keys mismatch{message}.")
                    # Read children in the spec's key order, not the node's.
                    children = [node[key] for key in expected_keys]
                else:
                    # node_type is treespec.type
                    children, metadata, *_ = optree.tree_flatten_one_level(
                        node,
                        none_is_leaf=True,
                        namespace="torch",
                    )
                    if (
                        node_type
                        is not deque  # ignore mismatch of `maxlen` for deque
                    ) and metadata != treespec._metadata:
                        raise ValueError(
                            f"Node metadata mismatch for node type {treespec.type!r}; "
                            f"expected {treespec._metadata!r}, but got {metadata!r}.",  # namedtuple type mismatch
                        )

            for subtree, subspec in zip(children, treespec._children):
                helper(subspec, subtree, subtrees)

        subtrees: list[PyTree] = []
        helper(self, tree, subtrees)
        return subtrees

    def unflatten(self, leaves: Iterable[Any]) -> PyTree:
        """Rebuild a pytree with this structure from ``leaves``.

        Args:
            leaves: The leaf values; any iterable that is not a list or tuple
                is first materialized into a list.

        Returns:
            The reconstructed pytree, or the single leaf for a leaf spec.

        Raises:
            ValueError: If the number of leaves differs from ``num_leaves``.
        """
        if not isinstance(leaves, (list, tuple)):
            leaves = list(leaves)
        if len(leaves) != self.num_leaves:
            raise ValueError(
                f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} "
                f"but the spec refers to a pytree that holds {self.num_leaves} "
                f"items ({self}).",
            )
        if self.is_leaf():
            return leaves[0]

        # Recursively unflatten the children
        start = 0
        end = 0
        subtrees = []
        for subspec in self._children:
            # Each child consumes a contiguous slice of the leaves.
            end += subspec.num_leaves
            subtrees.append(subspec.unflatten(leaves[start:end]))
            start = end

        assert callable(self._unflatten_func)
        return self._unflatten_func(self._metadata, subtrees)
|
|
|
|
# Shared spec for a leaf: no children, no node type, no metadata, no entries, no unflatten function.
_LEAF_SPEC = PyTreeSpec((), None, None, (), None)
|
|
|
|
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
    """Return whether ``obj`` is an instance of the ``PyTreeSpec`` class of this module."""
    return isinstance(obj, PyTreeSpec)
|
|
|
|
@substitute_in_graph(  # type: ignore[arg-type]
    cxx_pytree.tree_flatten,
    # We need to disable constant folding here because we want the function to reference the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def tree_flatten(
    tree: PyTree,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> tuple[list[Any], PyTreeSpec]:
    """Flatten ``tree`` into its leaves and a spec describing its structure.

    Args:
        tree: The pytree to flatten.
        is_leaf: Optional leaf predicate forwarded to ``tree_is_leaf`` and to
            ``optree.tree_flatten_one_level``.

    Returns:
        A ``(leaves, treespec)`` pair: the leaves in traversal order and the
        ``PyTreeSpec`` built from the nodes encountered.
    """

    def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec:
        if tree_is_leaf(node, is_leaf=is_leaf):
            leaves.append(node)
            return _LEAF_SPEC

        (
            children,
            metadata,
            entries,
            unflatten_func,
        ) = optree.tree_flatten_one_level(
            node,
            is_leaf=is_leaf,
            none_is_leaf=True,
            namespace="torch",
        )

        # Recursively flatten the children
        subspecs = tuple(helper(child, leaves) for child in children)
        return PyTreeSpec(subspecs, type(node), metadata, entries, unflatten_func)  # type: ignore[arg-type]

    leaves: list[Any] = []
    treespec = helper(tree, leaves)
    return leaves, treespec


__all__ += ["tree_flatten"]
|
|
|
|
@substitute_in_graph(  # type: ignore[arg-type]
    cxx_pytree.tree_structure,
    # We need to disable constant folding here because we want the function to reference the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def tree_structure(
    tree: PyTree,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTreeSpec:
    """Return only the spec part of ``tree_flatten(tree, is_leaf=is_leaf)``.

    Args:
        tree: The pytree to inspect.
        is_leaf: Optional leaf predicate forwarded to ``tree_flatten``.

    Returns:
        The ``PyTreeSpec`` describing the structure of ``tree``.
    """
    return tree_flatten(tree, is_leaf=is_leaf)[1]  # type: ignore[return-value]


__all__ += ["tree_structure"]
|
|
|
|
@substitute_in_graph(  # type: ignore[arg-type]
    cxx_pytree.tree_unflatten,
    # We need to disable constant folding here because we want the function to reference the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
    """Rebuild a pytree from ``leaves`` according to ``treespec``.

    Args:
        leaves: The leaf values to place into the structure.
        treespec: The structure to rebuild; must be a ``PyTreeSpec`` of this module.

    Returns:
        The result of ``treespec.unflatten(leaves)``.

    Raises:
        TypeError: If ``treespec`` is not a ``PyTreeSpec`` instance.
    """
    if not _is_pytreespec_instance(treespec):
        raise TypeError(
            f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
            f"PyTreeSpec but got item of type {type(treespec)}."
        )
    return treespec.unflatten(leaves)


__all__ += ["tree_unflatten"]
|
|
|
|
@substitute_in_graph(cxx_pytree.tree_map, can_constant_fold_through=True)
def tree_map(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
    """Apply ``func`` leaf-wise and return a new pytree with the results.

    Args:
        func: Called with one leaf from ``tree`` followed by the corresponding
            subtree from each of ``rests``.
        tree: The pytree that defines the structure.
        *rests: Additional pytrees, each flattened up to the structure of ``tree``.
        is_leaf: Optional leaf predicate forwarded to ``tree_flatten``.

    Returns:
        A pytree with the structure of ``tree`` holding the values returned by ``func``.
    """
    leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
    flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
    return treespec.unflatten(map(func, *flat_args))


__all__ += ["tree_map"]
|
|
|
|
@substitute_in_graph(cxx_pytree.tree_map_, can_constant_fold_through=True)
def tree_map_(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
    """Call ``func`` leaf-wise for its side effects and return ``tree`` itself.

    Args:
        func: Called with one leaf from ``tree`` followed by the corresponding
            subtree from each of ``rests``; its return values are discarded.
        tree: The pytree that defines the structure.
        *rests: Additional pytrees, each flattened up to the structure of ``tree``.
        is_leaf: Optional leaf predicate forwarded to ``tree_flatten``.

    Returns:
        The original ``tree`` object.
    """
    leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
    flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
    deque(map(func, *flat_args), maxlen=0)  # consume and exhaust the iterable
    return tree


__all__ += ["tree_map_"]
|