[dynamo][pytree][3/N] make CXX pytree traceable: tree_map / tree_map_ (#137399)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137399
Approved by: https://github.com/jansel
ghstack dependencies: #137398
This commit is contained in:
Xuehai Pan
2024-12-09 18:24:39 +08:00
committed by PyTorch MergeBot
parent 7edeb1005a
commit d47a80246a
4 changed files with 220 additions and 89 deletions

View File

@ -10153,6 +10153,8 @@ def ___make_guard_fn():
def test_pytree_tree_map(self):
implemtations = [("python", python_pytree)]
if cxx_pytree is not None:
implemtations.append(("cxx", cxx_pytree))
for name, module in implemtations:
with self.subTest(f"pytree implement: {name}"):

View File

@ -4,12 +4,13 @@ Python polyfills for torch.utils.pytree
from __future__ import annotations
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Callable, Iterable, Literal, TYPE_CHECKING
from typing_extensions import TypeIs
import torch.utils._pytree as python_pytree
from torch.utils._pytree import BUILTIN_TYPES
from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES
from ..decorators import substitute_in_graph
@ -200,6 +201,95 @@ if python_pytree._cxx_pytree_dynamo_traceable:
def entry(self, index: int) -> Any:
return self._entries[index]
def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
def helper(
treespec: PyTreeSpec,
node: PyTree,
subtrees: list[PyTree],
) -> None:
if treespec.is_leaf():
subtrees.append(node)
return
node_type = type(node)
if treespec.type not in BUILTIN_TYPES:
# Always require custom node types to match exactly
if node_type != treespec.type:
raise ValueError(
f"Type mismatch; "
f"expected {treespec.type!r}, but got {node_type!r}.",
)
children, metadata, *_ = optree.tree_flatten_one_level(
node,
none_is_leaf=True,
namespace="torch",
)
if len(children) != treespec.num_children:
raise ValueError(
f"Node arity mismatch; "
f"expected {treespec.num_children}, but got {len(children)}.",
)
if metadata != treespec._metadata:
raise ValueError(
f"Node context mismatch for custom node type {treespec.type!r}.",
)
else:
# For builtin dictionary types, we allow some flexibility
# Otherwise, we require exact matches
both_standard_dict = (
treespec.type in STANDARD_DICT_TYPES
and node_type in STANDARD_DICT_TYPES
)
if not both_standard_dict and node_type != treespec.type:
raise ValueError(
f"Node type mismatch; "
f"expected {treespec.type!r}, but got {node_type!r}.",
)
if len(node) != treespec.num_children:
raise ValueError(
f"Node arity mismatch; "
f"expected {treespec.num_children}, but got {len(node)}.",
)
if both_standard_dict:
# dictionary types are compatible with each other
expected_keys = treespec.entries()
got_key_set = set(node)
expected_key_set = set(expected_keys)
if got_key_set != expected_key_set:
missing_keys = expected_key_set.difference(got_key_set)
extra_keys = got_key_set.difference(expected_key_set)
message = ""
if missing_keys:
message += f"; missing key(s): {missing_keys}"
if extra_keys:
message += f"; extra key(s): {extra_keys}"
raise ValueError(f"Node keys mismatch{message}.")
children = [node[key] for key in expected_keys]
else:
# node_type is treespec.type
children, metadata, *_ = optree.tree_flatten_one_level(
node,
none_is_leaf=True,
namespace="torch",
)
if (
node_type
is not deque # ignore mismatch of `maxlen` for deque
) and metadata != treespec._metadata:
raise ValueError(
f"Node metadata mismatch for node type {treespec.type!r}; "
f"expected {treespec._metadata!r}, but got {metadata!r}.", # namedtuple type mismatch
)
for subtree, subspec in zip(children, treespec._children):
helper(subspec, subtree, subtrees)
subtrees: list[PyTree] = []
helper(self, tree, subtrees)
return subtrees
def unflatten(self, leaves: Iterable[Any]) -> PyTree:
if not isinstance(leaves, (list, tuple)):
leaves = list(leaves)
@ -295,3 +385,30 @@ if python_pytree._cxx_pytree_dynamo_traceable:
return treespec.unflatten(leaves)
__all__ += ["tree_unflatten"]
@substitute_in_graph(cxx_pytree.tree_map, can_constant_fold_through=True)
def tree_map(
func: Callable[..., Any],
tree: PyTree,
*rests: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
return treespec.unflatten(map(func, *flat_args))
__all__ += ["tree_map"]
@substitute_in_graph(cxx_pytree.tree_map_, can_constant_fold_through=True)
def tree_map_(
func: Callable[..., Any],
tree: PyTree,
*rests: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable
return tree
__all__ += ["tree_map_"]

View File

@ -293,7 +293,7 @@ def tree_flatten(
The flattening order (i.e., the order of elements in the output list) is deterministic,
corresponding to a left-to-right depth-first tree traversal.
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
>>> tree_flatten(tree)
([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf))
>>> tree_flatten(1)
@ -306,7 +306,7 @@ def tree_flatten(
if you want to keep the keys in the insertion order.
>>> from collections import OrderedDict
>>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)])
>>> tree = OrderedDict([("b", (2, [3, 4])), ("a", 1), ("c", None), ("d", 5)])
>>> tree_flatten(tree)
([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf))
@ -335,7 +335,7 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
The inverse of :func:`tree_flatten`.
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
>>> leaves, treespec = tree_flatten(tree)
>>> tree == tree_unflatten(leaves, treespec)
True
@ -365,7 +365,7 @@ def tree_iter(
See also :func:`tree_flatten`.
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
>>> list(tree_iter(tree))
[1, 2, 3, 4, None, 5]
>>> list(tree_iter(1))
@ -400,7 +400,7 @@ def tree_leaves(
See also :func:`tree_flatten`.
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
>>> tree_leaves(tree)
[1, 2, 3, 4, None, 5]
>>> tree_leaves(1)
@ -435,7 +435,7 @@ def tree_structure(
See also :func:`tree_flatten`.
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
>>> tree_structure(tree)
PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
>>> tree_structure(1)
@ -472,9 +472,9 @@ def tree_map(
See also :func:`tree_map_`.
>>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
>>> tree_map(lambda x: x + 1, {"x": 7, "y": (42, 64)})
{'x': 8, 'y': (43, 65)}
>>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
>>> tree_map(lambda x: x is None, {"x": 7, "y": (42, 64), "z": None})
{'x': False, 'y': (False, False), 'z': True}
If multiple inputs are given, the structure of the tree is taken from the first input;
@ -572,7 +572,9 @@ def map_only(__type_or_types_or_pred: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
@overload
def map_only(__type_or_types_or_pred: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]:
def map_only(
__type_or_types_or_pred: Type3[T, S, U],
) -> MapOnlyFn[Fn3[T, S, U, Any]]:
...
@ -588,12 +590,14 @@ def map_only(__type_or_types_or_pred: TypeAny) -> MapOnlyFn[FnAny[Any]]:
@overload
def map_only(__type_or_types_or_pred: Callable[[Any], bool]) -> MapOnlyFn[FnAny[Any]]:
def map_only(
__type_or_types_or_pred: Callable[[Any], bool],
) -> MapOnlyFn[FnAny[Any]]:
...
def map_only(
__type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]]
__type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
) -> MapOnlyFn[FnAny[Any]]:
"""
Suppose you are writing a tree_map over tensors, leaving everything
@ -858,7 +862,7 @@ def broadcast_prefix(
ValueError: list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].
>>> broadcast_prefix([1, 2, 3], [1, 2, (3, 4)])
[1, 2, 3, 3]
>>> broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}])
>>> broadcast_prefix([1, 2, 3], [1, 2, {"a": 3, "b": 4, "c": (None, 5)}])
[1, 2, 3, 3, 3, 3]
Args:
@ -873,13 +877,19 @@ def broadcast_prefix(
Returns:
A list of leaves in ``prefix_tree`` broadcasted to match the number of leaves in ``full_tree``.
"""
return optree.broadcast_prefix(
result: List[Any] = []
def add_leaves(x: Any, subtree: PyTree) -> None:
subtreespec = tree_structure(subtree, is_leaf=is_leaf)
result.extend([x] * subtreespec.num_leaves)
tree_map_(
add_leaves,
prefix_tree,
full_tree,
is_leaf=is_leaf,
none_is_leaf=True,
namespace="torch",
)
return result
# Broadcasts a pytree to the provided TreeSpec and returns the flattened

View File

@ -744,85 +744,87 @@ class TreeSpec:
def is_leaf(self) -> bool:
return self.num_nodes == 1 and self.num_leaves == 1
def _flatten_up_to_helper(self, tree: PyTree, subtrees: List[PyTree]) -> None:
if self.is_leaf():
subtrees.append(tree)
return
def flatten_up_to(self, tree: PyTree) -> List[PyTree]:
def helper(treespec: TreeSpec, tree: PyTree, subtrees: List[PyTree]) -> None:
if treespec.is_leaf():
subtrees.append(tree)
return
node_type = _get_node_type(tree)
if self.type not in BUILTIN_TYPES:
# Always require custom node types to match exactly
if node_type != self.type:
raise ValueError(
f"Type mismatch; "
f"expected {self.type!r}, but got {node_type!r}.",
)
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
child_pytrees, context = flatten_fn(tree)
if len(child_pytrees) != self.num_children:
raise ValueError(
f"Node arity mismatch; "
f"expected {self.num_children}, but got {len(child_pytrees)}.",
)
if context != self.context:
raise ValueError(
f"Node context mismatch for custom node type {self.type!r}.",
)
else:
# For builtin dictionary types, we allow some flexibility
# Otherwise, we require exact matches
both_standard_dict = (
self.type in STANDARD_DICT_TYPES and node_type in STANDARD_DICT_TYPES
)
if node_type != self.type and not both_standard_dict:
raise ValueError(
f"Node type mismatch; "
f"expected {self.type!r}, but got {node_type!r}.",
)
if len(tree) != self.num_children:
raise ValueError(
f"Node arity mismatch; "
f"expected {self.num_children}, but got {len(tree)}.",
)
if both_standard_dict: # dictionary types are compatible with each other
dict_context = (
self.context
if self.type is not defaultdict
# ignore mismatch of `default_factory` for defaultdict
else self.context[1]
)
expected_keys = dict_context
got_key_set = set(tree)
expected_key_set = set(expected_keys)
if got_key_set != expected_key_set:
missing_keys = expected_key_set.difference(got_key_set)
extra_keys = got_key_set.difference(expected_key_set)
message = ""
if missing_keys:
message += f"; missing key(s): {missing_keys}"
if extra_keys:
message += f"; extra key(s): {extra_keys}"
raise ValueError(f"Node keys mismatch{message}.")
child_pytrees = [tree[key] for key in expected_keys]
else:
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
child_pytrees, context = flatten_fn(tree)
if (
context != self.context
and self.type is not deque # ignore mismatch of `maxlen` for deque
):
node_type = _get_node_type(tree)
if treespec.type not in BUILTIN_TYPES:
# Always require custom node types to match exactly
if node_type != treespec.type:
raise ValueError(
f"Node context mismatch for node type {self.type!r}; "
f"expected {self.context!r}, but got {context!r}.", # namedtuple type mismatch
f"Type mismatch; "
f"expected {treespec.type!r}, but got {node_type!r}.",
)
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
children, context = flatten_fn(tree)
if len(children) != treespec.num_children:
raise ValueError(
f"Node arity mismatch; "
f"expected {treespec.num_children}, but got {len(children)}.",
)
if context != treespec.context:
raise ValueError(
f"Node context mismatch for custom node type {treespec.type!r}.",
)
else:
# For builtin dictionary types, we allow some flexibility
# Otherwise, we require exact matches
both_standard_dict = (
treespec.type in STANDARD_DICT_TYPES
and node_type in STANDARD_DICT_TYPES
)
if not both_standard_dict and node_type != treespec.type:
raise ValueError(
f"Node type mismatch; "
f"expected {treespec.type!r}, but got {node_type!r}.",
)
if len(tree) != treespec.num_children:
raise ValueError(
f"Node arity mismatch; "
f"expected {treespec.num_children}, but got {len(tree)}.",
)
for child_pytree, child_spec in zip(child_pytrees, self.children_specs):
child_spec._flatten_up_to_helper(child_pytree, subtrees)
if both_standard_dict:
# dictionary types are compatible with each other
dict_context = (
treespec.context
if treespec.type is not defaultdict
# ignore mismatch of `default_factory` for defaultdict
else treespec.context[1]
)
expected_keys = dict_context
got_key_set = set(tree)
expected_key_set = set(expected_keys)
if got_key_set != expected_key_set:
missing_keys = expected_key_set.difference(got_key_set)
extra_keys = got_key_set.difference(expected_key_set)
message = ""
if missing_keys:
message += f"; missing key(s): {missing_keys}"
if extra_keys:
message += f"; extra key(s): {extra_keys}"
raise ValueError(f"Node keys mismatch{message}.")
children = [tree[key] for key in expected_keys]
else:
# node_type is treespec.type
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
children, context = flatten_fn(tree)
if (
node_type is not deque # ignore mismatch of `maxlen` for deque
) and context != treespec.context:
raise ValueError(
f"Node context mismatch for node type {treespec.type!r}; "
f"expected {treespec.context!r}, but got {context!r}.", # namedtuple type mismatch
)
for subtree, subspec in zip(children, treespec.children_specs):
helper(subspec, subtree, subtrees)
def flatten_up_to(self, tree: PyTree) -> List[PyTree]:
subtrees: List[PyTree] = []
self._flatten_up_to_helper(tree, subtrees)
helper(self, tree, subtrees)
return subtrees
def unflatten(self, leaves: Iterable[Any]) -> PyTree: