mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo][pytree][3/N] make CXX pytree traceable: tree_map
/ tree_map_
(#137399)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137399 Approved by: https://github.com/jansel ghstack dependencies: #137398
This commit is contained in:
committed by
PyTorch MergeBot
parent
7edeb1005a
commit
d47a80246a
@ -10153,6 +10153,8 @@ def ___make_guard_fn():
|
||||
|
||||
def test_pytree_tree_map(self):
|
||||
implemtations = [("python", python_pytree)]
|
||||
if cxx_pytree is not None:
|
||||
implemtations.append(("cxx", cxx_pytree))
|
||||
|
||||
for name, module in implemtations:
|
||||
with self.subTest(f"pytree implement: {name}"):
|
||||
|
@ -4,12 +4,13 @@ Python polyfills for torch.utils.pytree
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Iterable, Literal, TYPE_CHECKING
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
import torch.utils._pytree as python_pytree
|
||||
from torch.utils._pytree import BUILTIN_TYPES
|
||||
from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES
|
||||
|
||||
from ..decorators import substitute_in_graph
|
||||
|
||||
@ -200,6 +201,95 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||
def entry(self, index: int) -> Any:
|
||||
return self._entries[index]
|
||||
|
||||
def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
|
||||
def helper(
|
||||
treespec: PyTreeSpec,
|
||||
node: PyTree,
|
||||
subtrees: list[PyTree],
|
||||
) -> None:
|
||||
if treespec.is_leaf():
|
||||
subtrees.append(node)
|
||||
return
|
||||
|
||||
node_type = type(node)
|
||||
if treespec.type not in BUILTIN_TYPES:
|
||||
# Always require custom node types to match exactly
|
||||
if node_type != treespec.type:
|
||||
raise ValueError(
|
||||
f"Type mismatch; "
|
||||
f"expected {treespec.type!r}, but got {node_type!r}.",
|
||||
)
|
||||
|
||||
children, metadata, *_ = optree.tree_flatten_one_level(
|
||||
node,
|
||||
none_is_leaf=True,
|
||||
namespace="torch",
|
||||
)
|
||||
if len(children) != treespec.num_children:
|
||||
raise ValueError(
|
||||
f"Node arity mismatch; "
|
||||
f"expected {treespec.num_children}, but got {len(children)}.",
|
||||
)
|
||||
if metadata != treespec._metadata:
|
||||
raise ValueError(
|
||||
f"Node context mismatch for custom node type {treespec.type!r}.",
|
||||
)
|
||||
else:
|
||||
# For builtin dictionary types, we allow some flexibility
|
||||
# Otherwise, we require exact matches
|
||||
both_standard_dict = (
|
||||
treespec.type in STANDARD_DICT_TYPES
|
||||
and node_type in STANDARD_DICT_TYPES
|
||||
)
|
||||
if not both_standard_dict and node_type != treespec.type:
|
||||
raise ValueError(
|
||||
f"Node type mismatch; "
|
||||
f"expected {treespec.type!r}, but got {node_type!r}.",
|
||||
)
|
||||
if len(node) != treespec.num_children:
|
||||
raise ValueError(
|
||||
f"Node arity mismatch; "
|
||||
f"expected {treespec.num_children}, but got {len(node)}.",
|
||||
)
|
||||
|
||||
if both_standard_dict:
|
||||
# dictionary types are compatible with each other
|
||||
expected_keys = treespec.entries()
|
||||
got_key_set = set(node)
|
||||
expected_key_set = set(expected_keys)
|
||||
if got_key_set != expected_key_set:
|
||||
missing_keys = expected_key_set.difference(got_key_set)
|
||||
extra_keys = got_key_set.difference(expected_key_set)
|
||||
message = ""
|
||||
if missing_keys:
|
||||
message += f"; missing key(s): {missing_keys}"
|
||||
if extra_keys:
|
||||
message += f"; extra key(s): {extra_keys}"
|
||||
raise ValueError(f"Node keys mismatch{message}.")
|
||||
children = [node[key] for key in expected_keys]
|
||||
else:
|
||||
# node_type is treespec.type
|
||||
children, metadata, *_ = optree.tree_flatten_one_level(
|
||||
node,
|
||||
none_is_leaf=True,
|
||||
namespace="torch",
|
||||
)
|
||||
if (
|
||||
node_type
|
||||
is not deque # ignore mismatch of `maxlen` for deque
|
||||
) and metadata != treespec._metadata:
|
||||
raise ValueError(
|
||||
f"Node metadata mismatch for node type {treespec.type!r}; "
|
||||
f"expected {treespec._metadata!r}, but got {metadata!r}.", # namedtuple type mismatch
|
||||
)
|
||||
|
||||
for subtree, subspec in zip(children, treespec._children):
|
||||
helper(subspec, subtree, subtrees)
|
||||
|
||||
subtrees: list[PyTree] = []
|
||||
helper(self, tree, subtrees)
|
||||
return subtrees
|
||||
|
||||
def unflatten(self, leaves: Iterable[Any]) -> PyTree:
|
||||
if not isinstance(leaves, (list, tuple)):
|
||||
leaves = list(leaves)
|
||||
@ -295,3 +385,30 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||
return treespec.unflatten(leaves)
|
||||
|
||||
__all__ += ["tree_unflatten"]
|
||||
|
||||
@substitute_in_graph(cxx_pytree.tree_map, can_constant_fold_through=True)
|
||||
def tree_map(
|
||||
func: Callable[..., Any],
|
||||
tree: PyTree,
|
||||
*rests: PyTree,
|
||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||
) -> PyTree:
|
||||
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
|
||||
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
|
||||
return treespec.unflatten(map(func, *flat_args))
|
||||
|
||||
__all__ += ["tree_map"]
|
||||
|
||||
@substitute_in_graph(cxx_pytree.tree_map_, can_constant_fold_through=True)
|
||||
def tree_map_(
|
||||
func: Callable[..., Any],
|
||||
tree: PyTree,
|
||||
*rests: PyTree,
|
||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||
) -> PyTree:
|
||||
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
|
||||
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
|
||||
deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable
|
||||
return tree
|
||||
|
||||
__all__ += ["tree_map_"]
|
||||
|
@ -293,7 +293,7 @@ def tree_flatten(
|
||||
The flattening order (i.e., the order of elements in the output list) is deterministic,
|
||||
corresponding to a left-to-right depth-first tree traversal.
|
||||
|
||||
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
|
||||
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
|
||||
>>> tree_flatten(tree)
|
||||
([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf))
|
||||
>>> tree_flatten(1)
|
||||
@ -306,7 +306,7 @@ def tree_flatten(
|
||||
if you want to keep the keys in the insertion order.
|
||||
|
||||
>>> from collections import OrderedDict
|
||||
>>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)])
|
||||
>>> tree = OrderedDict([("b", (2, [3, 4])), ("a", 1), ("c", None), ("d", 5)])
|
||||
>>> tree_flatten(tree)
|
||||
([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf))
|
||||
|
||||
@ -335,7 +335,7 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
|
||||
|
||||
The inverse of :func:`tree_flatten`.
|
||||
|
||||
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
|
||||
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
|
||||
>>> leaves, treespec = tree_flatten(tree)
|
||||
>>> tree == tree_unflatten(leaves, treespec)
|
||||
True
|
||||
@ -365,7 +365,7 @@ def tree_iter(
|
||||
|
||||
See also :func:`tree_flatten`.
|
||||
|
||||
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
|
||||
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
|
||||
>>> list(tree_iter(tree))
|
||||
[1, 2, 3, 4, None, 5]
|
||||
>>> list(tree_iter(1))
|
||||
@ -400,7 +400,7 @@ def tree_leaves(
|
||||
|
||||
See also :func:`tree_flatten`.
|
||||
|
||||
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
|
||||
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
|
||||
>>> tree_leaves(tree)
|
||||
[1, 2, 3, 4, None, 5]
|
||||
>>> tree_leaves(1)
|
||||
@ -435,7 +435,7 @@ def tree_structure(
|
||||
|
||||
See also :func:`tree_flatten`.
|
||||
|
||||
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
|
||||
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
|
||||
>>> tree_structure(tree)
|
||||
PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
|
||||
>>> tree_structure(1)
|
||||
@ -472,9 +472,9 @@ def tree_map(
|
||||
|
||||
See also :func:`tree_map_`.
|
||||
|
||||
>>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
|
||||
>>> tree_map(lambda x: x + 1, {"x": 7, "y": (42, 64)})
|
||||
{'x': 8, 'y': (43, 65)}
|
||||
>>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
|
||||
>>> tree_map(lambda x: x is None, {"x": 7, "y": (42, 64), "z": None})
|
||||
{'x': False, 'y': (False, False), 'z': True}
|
||||
|
||||
If multiple inputs are given, the structure of the tree is taken from the first input;
|
||||
@ -572,7 +572,9 @@ def map_only(__type_or_types_or_pred: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
|
||||
|
||||
|
||||
@overload
|
||||
def map_only(__type_or_types_or_pred: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]:
|
||||
def map_only(
|
||||
__type_or_types_or_pred: Type3[T, S, U],
|
||||
) -> MapOnlyFn[Fn3[T, S, U, Any]]:
|
||||
...
|
||||
|
||||
|
||||
@ -588,12 +590,14 @@ def map_only(__type_or_types_or_pred: TypeAny) -> MapOnlyFn[FnAny[Any]]:
|
||||
|
||||
|
||||
@overload
|
||||
def map_only(__type_or_types_or_pred: Callable[[Any], bool]) -> MapOnlyFn[FnAny[Any]]:
|
||||
def map_only(
|
||||
__type_or_types_or_pred: Callable[[Any], bool],
|
||||
) -> MapOnlyFn[FnAny[Any]]:
|
||||
...
|
||||
|
||||
|
||||
def map_only(
|
||||
__type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]]
|
||||
__type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
|
||||
) -> MapOnlyFn[FnAny[Any]]:
|
||||
"""
|
||||
Suppose you are writing a tree_map over tensors, leaving everything
|
||||
@ -858,7 +862,7 @@ def broadcast_prefix(
|
||||
ValueError: list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].
|
||||
>>> broadcast_prefix([1, 2, 3], [1, 2, (3, 4)])
|
||||
[1, 2, 3, 3]
|
||||
>>> broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}])
|
||||
>>> broadcast_prefix([1, 2, 3], [1, 2, {"a": 3, "b": 4, "c": (None, 5)}])
|
||||
[1, 2, 3, 3, 3, 3]
|
||||
|
||||
Args:
|
||||
@ -873,13 +877,19 @@ def broadcast_prefix(
|
||||
Returns:
|
||||
A list of leaves in ``prefix_tree`` broadcasted to match the number of leaves in ``full_tree``.
|
||||
"""
|
||||
return optree.broadcast_prefix(
|
||||
result: List[Any] = []
|
||||
|
||||
def add_leaves(x: Any, subtree: PyTree) -> None:
|
||||
subtreespec = tree_structure(subtree, is_leaf=is_leaf)
|
||||
result.extend([x] * subtreespec.num_leaves)
|
||||
|
||||
tree_map_(
|
||||
add_leaves,
|
||||
prefix_tree,
|
||||
full_tree,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=True,
|
||||
namespace="torch",
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
# Broadcasts a pytree to the provided TreeSpec and returns the flattened
|
||||
|
@ -744,85 +744,87 @@ class TreeSpec:
|
||||
def is_leaf(self) -> bool:
|
||||
return self.num_nodes == 1 and self.num_leaves == 1
|
||||
|
||||
def _flatten_up_to_helper(self, tree: PyTree, subtrees: List[PyTree]) -> None:
|
||||
if self.is_leaf():
|
||||
subtrees.append(tree)
|
||||
return
|
||||
def flatten_up_to(self, tree: PyTree) -> List[PyTree]:
|
||||
def helper(treespec: TreeSpec, tree: PyTree, subtrees: List[PyTree]) -> None:
|
||||
if treespec.is_leaf():
|
||||
subtrees.append(tree)
|
||||
return
|
||||
|
||||
node_type = _get_node_type(tree)
|
||||
if self.type not in BUILTIN_TYPES:
|
||||
# Always require custom node types to match exactly
|
||||
if node_type != self.type:
|
||||
raise ValueError(
|
||||
f"Type mismatch; "
|
||||
f"expected {self.type!r}, but got {node_type!r}.",
|
||||
)
|
||||
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
|
||||
child_pytrees, context = flatten_fn(tree)
|
||||
if len(child_pytrees) != self.num_children:
|
||||
raise ValueError(
|
||||
f"Node arity mismatch; "
|
||||
f"expected {self.num_children}, but got {len(child_pytrees)}.",
|
||||
)
|
||||
if context != self.context:
|
||||
raise ValueError(
|
||||
f"Node context mismatch for custom node type {self.type!r}.",
|
||||
)
|
||||
else:
|
||||
# For builtin dictionary types, we allow some flexibility
|
||||
# Otherwise, we require exact matches
|
||||
both_standard_dict = (
|
||||
self.type in STANDARD_DICT_TYPES and node_type in STANDARD_DICT_TYPES
|
||||
)
|
||||
if node_type != self.type and not both_standard_dict:
|
||||
raise ValueError(
|
||||
f"Node type mismatch; "
|
||||
f"expected {self.type!r}, but got {node_type!r}.",
|
||||
)
|
||||
if len(tree) != self.num_children:
|
||||
raise ValueError(
|
||||
f"Node arity mismatch; "
|
||||
f"expected {self.num_children}, but got {len(tree)}.",
|
||||
)
|
||||
|
||||
if both_standard_dict: # dictionary types are compatible with each other
|
||||
dict_context = (
|
||||
self.context
|
||||
if self.type is not defaultdict
|
||||
# ignore mismatch of `default_factory` for defaultdict
|
||||
else self.context[1]
|
||||
)
|
||||
expected_keys = dict_context
|
||||
got_key_set = set(tree)
|
||||
expected_key_set = set(expected_keys)
|
||||
if got_key_set != expected_key_set:
|
||||
missing_keys = expected_key_set.difference(got_key_set)
|
||||
extra_keys = got_key_set.difference(expected_key_set)
|
||||
message = ""
|
||||
if missing_keys:
|
||||
message += f"; missing key(s): {missing_keys}"
|
||||
if extra_keys:
|
||||
message += f"; extra key(s): {extra_keys}"
|
||||
raise ValueError(f"Node keys mismatch{message}.")
|
||||
child_pytrees = [tree[key] for key in expected_keys]
|
||||
else:
|
||||
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
|
||||
child_pytrees, context = flatten_fn(tree)
|
||||
if (
|
||||
context != self.context
|
||||
and self.type is not deque # ignore mismatch of `maxlen` for deque
|
||||
):
|
||||
node_type = _get_node_type(tree)
|
||||
if treespec.type not in BUILTIN_TYPES:
|
||||
# Always require custom node types to match exactly
|
||||
if node_type != treespec.type:
|
||||
raise ValueError(
|
||||
f"Node context mismatch for node type {self.type!r}; "
|
||||
f"expected {self.context!r}, but got {context!r}.", # namedtuple type mismatch
|
||||
f"Type mismatch; "
|
||||
f"expected {treespec.type!r}, but got {node_type!r}.",
|
||||
)
|
||||
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
|
||||
children, context = flatten_fn(tree)
|
||||
if len(children) != treespec.num_children:
|
||||
raise ValueError(
|
||||
f"Node arity mismatch; "
|
||||
f"expected {treespec.num_children}, but got {len(children)}.",
|
||||
)
|
||||
if context != treespec.context:
|
||||
raise ValueError(
|
||||
f"Node context mismatch for custom node type {treespec.type!r}.",
|
||||
)
|
||||
else:
|
||||
# For builtin dictionary types, we allow some flexibility
|
||||
# Otherwise, we require exact matches
|
||||
both_standard_dict = (
|
||||
treespec.type in STANDARD_DICT_TYPES
|
||||
and node_type in STANDARD_DICT_TYPES
|
||||
)
|
||||
if not both_standard_dict and node_type != treespec.type:
|
||||
raise ValueError(
|
||||
f"Node type mismatch; "
|
||||
f"expected {treespec.type!r}, but got {node_type!r}.",
|
||||
)
|
||||
if len(tree) != treespec.num_children:
|
||||
raise ValueError(
|
||||
f"Node arity mismatch; "
|
||||
f"expected {treespec.num_children}, but got {len(tree)}.",
|
||||
)
|
||||
|
||||
for child_pytree, child_spec in zip(child_pytrees, self.children_specs):
|
||||
child_spec._flatten_up_to_helper(child_pytree, subtrees)
|
||||
if both_standard_dict:
|
||||
# dictionary types are compatible with each other
|
||||
dict_context = (
|
||||
treespec.context
|
||||
if treespec.type is not defaultdict
|
||||
# ignore mismatch of `default_factory` for defaultdict
|
||||
else treespec.context[1]
|
||||
)
|
||||
expected_keys = dict_context
|
||||
got_key_set = set(tree)
|
||||
expected_key_set = set(expected_keys)
|
||||
if got_key_set != expected_key_set:
|
||||
missing_keys = expected_key_set.difference(got_key_set)
|
||||
extra_keys = got_key_set.difference(expected_key_set)
|
||||
message = ""
|
||||
if missing_keys:
|
||||
message += f"; missing key(s): {missing_keys}"
|
||||
if extra_keys:
|
||||
message += f"; extra key(s): {extra_keys}"
|
||||
raise ValueError(f"Node keys mismatch{message}.")
|
||||
children = [tree[key] for key in expected_keys]
|
||||
else:
|
||||
# node_type is treespec.type
|
||||
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
|
||||
children, context = flatten_fn(tree)
|
||||
if (
|
||||
node_type is not deque # ignore mismatch of `maxlen` for deque
|
||||
) and context != treespec.context:
|
||||
raise ValueError(
|
||||
f"Node context mismatch for node type {treespec.type!r}; "
|
||||
f"expected {treespec.context!r}, but got {context!r}.", # namedtuple type mismatch
|
||||
)
|
||||
|
||||
for subtree, subspec in zip(children, treespec.children_specs):
|
||||
helper(subspec, subtree, subtrees)
|
||||
|
||||
def flatten_up_to(self, tree: PyTree) -> List[PyTree]:
|
||||
subtrees: List[PyTree] = []
|
||||
self._flatten_up_to_helper(tree, subtrees)
|
||||
helper(self, tree, subtrees)
|
||||
return subtrees
|
||||
|
||||
def unflatten(self, leaves: Iterable[Any]) -> PyTree:
|
||||
|
Reference in New Issue
Block a user