[pytree] extend pytree operations with is_leaf prediction function (#116419)

Add an extra `is_leaf` prediction function to pytree operations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116419
Approved by: https://github.com/zou3519
This commit is contained in:
Xuehai Pan
2024-01-09 19:50:03 +00:00
committed by PyTorch MergeBot
parent 902807a86d
commit ab1ac43752
3 changed files with 260 additions and 59 deletions

View File

@ -231,7 +231,10 @@ def _private_register_pytree_node(
)
def tree_flatten(tree: PyTree) -> Tuple[List[Any], TreeSpec]:
def tree_flatten(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Tuple[List[Any], TreeSpec]:
"""Flatten a pytree.
See also :func:`tree_unflatten`.
@ -258,6 +261,11 @@ def tree_flatten(tree: PyTree) -> Tuple[List[Any], TreeSpec]:
Args:
tree (pytree): A pytree to flatten.
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
flattening step. The function should have a single argument with signature
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
leaf or not. If the function is not specified, the default pytree registry will be used.
Returns:
A pair ``(leaves, treespec)`` where the first element is a list of leaf values and the
@ -265,6 +273,7 @@ def tree_flatten(tree: PyTree) -> Tuple[List[Any], TreeSpec]:
"""
return optree.tree_flatten( # type: ignore[return-value]
tree,
is_leaf=is_leaf,
none_is_leaf=True,
namespace="torch",
)
@ -297,7 +306,10 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
def tree_leaves(tree: PyTree) -> List[Any]:
def tree_leaves(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> List[Any]:
"""Get the leaves of a pytree.
See also :func:`tree_flatten`.
@ -312,14 +324,27 @@ def tree_leaves(tree: PyTree) -> List[Any]:
Args:
tree (pytree): A pytree to flatten.
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
flattening step. The function should have a single argument with signature
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
leaf or not. If the function is not specified, the default pytree registry will be used.
Returns:
A list of leaf values.
"""
return optree.tree_leaves(tree, none_is_leaf=True, namespace="torch")
return optree.tree_leaves(
tree,
is_leaf=is_leaf,
none_is_leaf=True,
namespace="torch",
)
def tree_structure(tree: PyTree) -> TreeSpec:
def tree_structure(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> TreeSpec:
"""Get the treespec for a pytree.
See also :func:`tree_flatten`.
@ -334,18 +359,29 @@ def tree_structure(tree: PyTree) -> TreeSpec:
Args:
tree (pytree): A pytree to flatten.
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
flattening step. The function should have a single argument with signature
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
leaf or not. If the function is not specified, the default pytree registry will be used.
Returns:
A treespec object representing the structure of the pytree.
"""
return optree.tree_structure( # type: ignore[return-value]
tree,
is_leaf=is_leaf,
none_is_leaf=True,
namespace="torch",
)
def tree_map(func: Callable[..., Any], tree: PyTree, *rests: PyTree) -> PyTree:
def tree_map(
func: Callable[..., Any],
tree: PyTree,
*rests: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
"""Map a multi-input function over pytree args to produce a new pytree.
See also :func:`tree_map_`.
@ -368,6 +404,11 @@ def tree_map(func: Callable[..., Any], tree: PyTree, *rests: PyTree) -> PyTree:
argument to function ``func``.
rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
``tree`` or has ``tree`` as a prefix.
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
flattening step. The function should have a single argument with signature
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
leaf or not. If the function is not specified, the default pytree registry will be used.
Returns:
A new pytree with the same structure as ``tree`` but with the value at each leaf given by
@ -378,12 +419,18 @@ def tree_map(func: Callable[..., Any], tree: PyTree, *rests: PyTree) -> PyTree:
func,
tree,
*rests,
is_leaf=is_leaf,
none_is_leaf=True,
namespace="torch",
)
def tree_map_(func: Callable[..., Any], tree: PyTree, *rests: PyTree) -> PyTree:
def tree_map_(
func: Callable[..., Any],
tree: PyTree,
*rests: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
"""Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree.
See also :func:`tree_map`.
@ -395,6 +442,11 @@ def tree_map_(func: Callable[..., Any], tree: PyTree, *rests: PyTree) -> PyTree:
argument to function ``func``.
rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
``tree`` or has ``tree`` as a prefix.
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
flattening step. The function should have a single argument with signature
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
leaf or not. If the function is not specified, the default pytree registry will be used.
Returns:
The original ``tree`` with the value at each leaf is given by the side-effect of function
@ -405,6 +457,7 @@ def tree_map_(func: Callable[..., Any], tree: PyTree, *rests: PyTree) -> PyTree:
func,
tree,
*rests,
is_leaf=is_leaf,
none_is_leaf=True,
namespace="torch",
)
@ -482,6 +535,7 @@ def tree_map_only(
__type_or_types: Type[T],
func: Fn[T, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
@ -491,6 +545,7 @@ def tree_map_only(
__type_or_types: Type2[T, S],
func: Fn2[T, S, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
@ -500,6 +555,7 @@ def tree_map_only(
__type_or_types: Type3[T, S, U],
func: Fn3[T, S, U, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
@ -508,8 +564,9 @@ def tree_map_only(
__type_or_types: TypeAny,
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
return tree_map(map_only(__type_or_types)(func), tree)
return tree_map(map_only(__type_or_types)(func), tree, is_leaf=is_leaf)
@overload
@ -517,6 +574,7 @@ def tree_map_only_(
__type_or_types: Type[T],
func: Fn[T, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
@ -526,6 +584,7 @@ def tree_map_only_(
__type_or_types: Type2[T, S],
func: Fn2[T, S, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
@ -535,6 +594,7 @@ def tree_map_only_(
__type_or_types: Type3[T, S, U],
func: Fn3[T, S, U, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
...
@ -543,17 +603,26 @@ def tree_map_only_(
__type_or_types: TypeAny,
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
return tree_map_(map_only(__type_or_types)(func), tree)
return tree_map_(map_only(__type_or_types)(func), tree, is_leaf=is_leaf)
def tree_all(pred: Callable[[Any], bool], tree: PyTree) -> bool:
flat_args = tree_leaves(tree)
def tree_all(
pred: Callable[[Any], bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
return all(map(pred, flat_args))
def tree_any(pred: Callable[[Any], bool], tree: PyTree) -> bool:
flat_args = tree_leaves(tree)
def tree_any(
pred: Callable[[Any], bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
return any(map(pred, flat_args))
@ -562,6 +631,7 @@ def tree_all_only(
__type_or_types: Type[T],
pred: Fn[T, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
...
@ -571,6 +641,7 @@ def tree_all_only(
__type_or_types: Type2[T, S],
pred: Fn2[T, S, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
...
@ -580,6 +651,7 @@ def tree_all_only(
__type_or_types: Type3[T, S, U],
pred: Fn3[T, S, U, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
...
@ -588,8 +660,9 @@ def tree_all_only(
__type_or_types: TypeAny,
pred: FnAny[bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree)
flat_args = tree_leaves(tree, is_leaf=is_leaf)
return all(pred(x) for x in flat_args if isinstance(x, __type_or_types))
@ -598,6 +671,7 @@ def tree_any_only(
__type_or_types: Type[T],
pred: Fn[T, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
...
@ -607,6 +681,7 @@ def tree_any_only(
__type_or_types: Type2[T, S],
pred: Fn2[T, S, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
...
@ -616,6 +691,7 @@ def tree_any_only(
__type_or_types: Type3[T, S, U],
pred: Fn3[T, S, U, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
...
@ -624,12 +700,17 @@ def tree_any_only(
__type_or_types: TypeAny,
pred: FnAny[bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree)
flat_args = tree_leaves(tree, is_leaf=is_leaf)
return any(pred(x) for x in flat_args if isinstance(x, __type_or_types))
def broadcast_prefix(prefix_tree: PyTree, full_tree: PyTree) -> List[Any]:
def broadcast_prefix(
prefix_tree: PyTree,
full_tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> List[Any]:
"""Return a list of broadcasted leaves in ``prefix_tree`` to match the number of leaves in ``full_tree``.
If a ``prefix_tree`` is a prefix of a ``full_tree``, this means the ``full_tree`` can be
@ -655,10 +736,11 @@ def broadcast_prefix(prefix_tree: PyTree, full_tree: PyTree) -> List[Any]:
Args:
prefix_tree (pytree): A pytree with the same structure as a prefix of ``full_tree``.
full_tree (pytree): A pytree with the same structure as a suffix of ``prefix_tree``.
is_leaf (callable, optional): An optionally specified function that will be called at each
flattening step. It should return a boolean, with :data:`True` stopping the traversal
and the whole subtree being treated as a leaf, and :data:`False` indicating the
flattening should traverse the current object.
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
flattening step. The function should have a single argument with signature
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
leaf or not. If the function is not specified, the default pytree registry will be used.
Returns:
A list of leaves in ``prefix_tree`` broadcasted to match the number of leaves in ``full_tree``.
@ -666,6 +748,7 @@ def broadcast_prefix(prefix_tree: PyTree, full_tree: PyTree) -> List[Any]:
return optree.broadcast_prefix(
prefix_tree,
full_tree,
is_leaf=is_leaf,
none_is_leaf=True,
namespace="torch",
)
@ -679,11 +762,15 @@ def broadcast_prefix(prefix_tree: PyTree, full_tree: PyTree) -> List[Any]:
# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be
# broadcastable to the tree structure of `inputs` and we use
# _broadcast_to_and_flatten to check this.
def _broadcast_to_and_flatten(tree: PyTree, treespec: TreeSpec) -> Optional[List[Any]]:
def _broadcast_to_and_flatten(
tree: PyTree,
treespec: TreeSpec,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Optional[List[Any]]:
assert isinstance(treespec, TreeSpec)
full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
try:
return broadcast_prefix(tree, full_tree)
return broadcast_prefix(tree, full_tree, is_leaf=is_leaf)
except ValueError:
return None