mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[pytree] extend pytree operations with is_leaf
prediction function (#116419)
Add an extra `is_leaf` prediction function to pytree operations. Pull Request resolved: https://github.com/pytorch/pytorch/pull/116419 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
902807a86d
commit
ab1ac43752
@ -231,7 +231,10 @@ def _private_register_pytree_node(
|
||||
)
|
||||
|
||||
|
||||
def tree_flatten(tree: PyTree) -> Tuple[List[Any], TreeSpec]:
|
||||
def tree_flatten(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> Tuple[List[Any], TreeSpec]:
|
||||
"""Flatten a pytree.
|
||||
|
||||
See also :func:`tree_unflatten`.
|
||||
@ -258,6 +261,11 @@ def tree_flatten(tree: PyTree) -> Tuple[List[Any], TreeSpec]:
|
||||
|
||||
Args:
|
||||
tree (pytree): A pytree to flatten.
|
||||
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
|
||||
flattening step. The function should have a single argument with signature
|
||||
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
|
||||
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
|
||||
leaf or not. If the function is not specified, the default pytree registry will be used.
|
||||
|
||||
Returns:
|
||||
A pair ``(leaves, treespec)`` where the first element is a list of leaf values and the
|
||||
@ -265,6 +273,7 @@ def tree_flatten(tree: PyTree) -> Tuple[List[Any], TreeSpec]:
|
||||
"""
|
||||
return optree.tree_flatten( # type: ignore[return-value]
|
||||
tree,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=True,
|
||||
namespace="torch",
|
||||
)
|
||||
@ -297,7 +306,10 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
|
||||
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def tree_leaves(tree: PyTree) -> List[Any]:
|
||||
def tree_leaves(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> List[Any]:
|
||||
"""Get the leaves of a pytree.
|
||||
|
||||
See also :func:`tree_flatten`.
|
||||
@ -312,14 +324,27 @@ def tree_leaves(tree: PyTree) -> List[Any]:
|
||||
|
||||
Args:
|
||||
tree (pytree): A pytree to flatten.
|
||||
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
|
||||
flattening step. The function should have a single argument with signature
|
||||
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
|
||||
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
|
||||
leaf or not. If the function is not specified, the default pytree registry will be used.
|
||||
|
||||
Returns:
|
||||
A list of leaf values.
|
||||
"""
|
||||
return optree.tree_leaves(tree, none_is_leaf=True, namespace="torch")
|
||||
return optree.tree_leaves(
|
||||
tree,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=True,
|
||||
namespace="torch",
|
||||
)
|
||||
|
||||
|
||||
def tree_structure(tree: PyTree) -> TreeSpec:
|
||||
def tree_structure(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> TreeSpec:
|
||||
"""Get the treespec for a pytree.
|
||||
|
||||
See also :func:`tree_flatten`.
|
||||
@ -334,18 +359,29 @@ def tree_structure(tree: PyTree) -> TreeSpec:
|
||||
|
||||
Args:
|
||||
tree (pytree): A pytree to flatten.
|
||||
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
|
||||
flattening step. The function should have a single argument with signature
|
||||
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
|
||||
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
|
||||
leaf or not. If the function is not specified, the default pytree registry will be used.
|
||||
|
||||
Returns:
|
||||
A treespec object representing the structure of the pytree.
|
||||
"""
|
||||
return optree.tree_structure( # type: ignore[return-value]
|
||||
tree,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=True,
|
||||
namespace="torch",
|
||||
)
|
||||
|
||||
|
||||
def tree_map(func: Callable[..., Any], tree: PyTree, *rests: PyTree) -> PyTree:
|
||||
def tree_map(
|
||||
func: Callable[..., Any],
|
||||
tree: PyTree,
|
||||
*rests: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
"""Map a multi-input function over pytree args to produce a new pytree.
|
||||
|
||||
See also :func:`tree_map_`.
|
||||
@ -368,6 +404,11 @@ def tree_map(func: Callable[..., Any], tree: PyTree, *rests: PyTree) -> PyTree:
|
||||
argument to function ``func``.
|
||||
rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
|
||||
``tree`` or has ``tree`` as a prefix.
|
||||
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
|
||||
flattening step. The function should have a single argument with signature
|
||||
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
|
||||
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
|
||||
leaf or not. If the function is not specified, the default pytree registry will be used.
|
||||
|
||||
Returns:
|
||||
A new pytree with the same structure as ``tree`` but with the value at each leaf given by
|
||||
@ -378,12 +419,18 @@ def tree_map(func: Callable[..., Any], tree: PyTree, *rests: PyTree) -> PyTree:
|
||||
func,
|
||||
tree,
|
||||
*rests,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=True,
|
||||
namespace="torch",
|
||||
)
|
||||
|
||||
|
||||
def tree_map_(func: Callable[..., Any], tree: PyTree, *rests: PyTree) -> PyTree:
|
||||
def tree_map_(
|
||||
func: Callable[..., Any],
|
||||
tree: PyTree,
|
||||
*rests: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
"""Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree.
|
||||
|
||||
See also :func:`tree_map`.
|
||||
@ -395,6 +442,11 @@ def tree_map_(func: Callable[..., Any], tree: PyTree, *rests: PyTree) -> PyTree:
|
||||
argument to function ``func``.
|
||||
rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
|
||||
``tree`` or has ``tree`` as a prefix.
|
||||
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
|
||||
flattening step. The function should have a single argument with signature
|
||||
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
|
||||
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
|
||||
leaf or not. If the function is not specified, the default pytree registry will be used.
|
||||
|
||||
Returns:
|
||||
The original ``tree`` with the value at each leaf is given by the side-effect of function
|
||||
@ -405,6 +457,7 @@ def tree_map_(func: Callable[..., Any], tree: PyTree, *rests: PyTree) -> PyTree:
|
||||
func,
|
||||
tree,
|
||||
*rests,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=True,
|
||||
namespace="torch",
|
||||
)
|
||||
@ -482,6 +535,7 @@ def tree_map_only(
|
||||
__type_or_types: Type[T],
|
||||
func: Fn[T, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
|
||||
@ -491,6 +545,7 @@ def tree_map_only(
|
||||
__type_or_types: Type2[T, S],
|
||||
func: Fn2[T, S, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
|
||||
@ -500,6 +555,7 @@ def tree_map_only(
|
||||
__type_or_types: Type3[T, S, U],
|
||||
func: Fn3[T, S, U, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
|
||||
@ -508,8 +564,9 @@ def tree_map_only(
|
||||
__type_or_types: TypeAny,
|
||||
func: FnAny[Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
return tree_map(map_only(__type_or_types)(func), tree)
|
||||
return tree_map(map_only(__type_or_types)(func), tree, is_leaf=is_leaf)
|
||||
|
||||
|
||||
@overload
|
||||
@ -517,6 +574,7 @@ def tree_map_only_(
|
||||
__type_or_types: Type[T],
|
||||
func: Fn[T, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
|
||||
@ -526,6 +584,7 @@ def tree_map_only_(
|
||||
__type_or_types: Type2[T, S],
|
||||
func: Fn2[T, S, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
|
||||
@ -535,6 +594,7 @@ def tree_map_only_(
|
||||
__type_or_types: Type3[T, S, U],
|
||||
func: Fn3[T, S, U, Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
...
|
||||
|
||||
@ -543,17 +603,26 @@ def tree_map_only_(
|
||||
__type_or_types: TypeAny,
|
||||
func: FnAny[Any],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> PyTree:
|
||||
return tree_map_(map_only(__type_or_types)(func), tree)
|
||||
return tree_map_(map_only(__type_or_types)(func), tree, is_leaf=is_leaf)
|
||||
|
||||
|
||||
def tree_all(pred: Callable[[Any], bool], tree: PyTree) -> bool:
|
||||
flat_args = tree_leaves(tree)
|
||||
def tree_all(
|
||||
pred: Callable[[Any], bool],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
flat_args = tree_leaves(tree, is_leaf=is_leaf)
|
||||
return all(map(pred, flat_args))
|
||||
|
||||
|
||||
def tree_any(pred: Callable[[Any], bool], tree: PyTree) -> bool:
|
||||
flat_args = tree_leaves(tree)
|
||||
def tree_any(
|
||||
pred: Callable[[Any], bool],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
flat_args = tree_leaves(tree, is_leaf=is_leaf)
|
||||
return any(map(pred, flat_args))
|
||||
|
||||
|
||||
@ -562,6 +631,7 @@ def tree_all_only(
|
||||
__type_or_types: Type[T],
|
||||
pred: Fn[T, bool],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
...
|
||||
|
||||
@ -571,6 +641,7 @@ def tree_all_only(
|
||||
__type_or_types: Type2[T, S],
|
||||
pred: Fn2[T, S, bool],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
...
|
||||
|
||||
@ -580,6 +651,7 @@ def tree_all_only(
|
||||
__type_or_types: Type3[T, S, U],
|
||||
pred: Fn3[T, S, U, bool],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
...
|
||||
|
||||
@ -588,8 +660,9 @@ def tree_all_only(
|
||||
__type_or_types: TypeAny,
|
||||
pred: FnAny[bool],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
flat_args = tree_leaves(tree)
|
||||
flat_args = tree_leaves(tree, is_leaf=is_leaf)
|
||||
return all(pred(x) for x in flat_args if isinstance(x, __type_or_types))
|
||||
|
||||
|
||||
@ -598,6 +671,7 @@ def tree_any_only(
|
||||
__type_or_types: Type[T],
|
||||
pred: Fn[T, bool],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
...
|
||||
|
||||
@ -607,6 +681,7 @@ def tree_any_only(
|
||||
__type_or_types: Type2[T, S],
|
||||
pred: Fn2[T, S, bool],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
...
|
||||
|
||||
@ -616,6 +691,7 @@ def tree_any_only(
|
||||
__type_or_types: Type3[T, S, U],
|
||||
pred: Fn3[T, S, U, bool],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
...
|
||||
|
||||
@ -624,12 +700,17 @@ def tree_any_only(
|
||||
__type_or_types: TypeAny,
|
||||
pred: FnAny[bool],
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
flat_args = tree_leaves(tree)
|
||||
flat_args = tree_leaves(tree, is_leaf=is_leaf)
|
||||
return any(pred(x) for x in flat_args if isinstance(x, __type_or_types))
|
||||
|
||||
|
||||
def broadcast_prefix(prefix_tree: PyTree, full_tree: PyTree) -> List[Any]:
|
||||
def broadcast_prefix(
|
||||
prefix_tree: PyTree,
|
||||
full_tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> List[Any]:
|
||||
"""Return a list of broadcasted leaves in ``prefix_tree`` to match the number of leaves in ``full_tree``.
|
||||
|
||||
If a ``prefix_tree`` is a prefix of a ``full_tree``, this means the ``full_tree`` can be
|
||||
@ -655,10 +736,11 @@ def broadcast_prefix(prefix_tree: PyTree, full_tree: PyTree) -> List[Any]:
|
||||
Args:
|
||||
prefix_tree (pytree): A pytree with the same structure as a prefix of ``full_tree``.
|
||||
full_tree (pytree): A pytree with the same structure as a suffix of ``prefix_tree``.
|
||||
is_leaf (callable, optional): An optionally specified function that will be called at each
|
||||
flattening step. It should return a boolean, with :data:`True` stopping the traversal
|
||||
and the whole subtree being treated as a leaf, and :data:`False` indicating the
|
||||
flattening should traverse the current object.
|
||||
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
|
||||
flattening step. The function should have a single argument with signature
|
||||
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
|
||||
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
|
||||
leaf or not. If the function is not specified, the default pytree registry will be used.
|
||||
|
||||
Returns:
|
||||
A list of leaves in ``prefix_tree`` broadcasted to match the number of leaves in ``full_tree``.
|
||||
@ -666,6 +748,7 @@ def broadcast_prefix(prefix_tree: PyTree, full_tree: PyTree) -> List[Any]:
|
||||
return optree.broadcast_prefix(
|
||||
prefix_tree,
|
||||
full_tree,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=True,
|
||||
namespace="torch",
|
||||
)
|
||||
@ -679,11 +762,15 @@ def broadcast_prefix(prefix_tree: PyTree, full_tree: PyTree) -> List[Any]:
|
||||
# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be
|
||||
# broadcastable to the tree structure of `inputs` and we use
|
||||
# _broadcast_to_and_flatten to check this.
|
||||
def _broadcast_to_and_flatten(tree: PyTree, treespec: TreeSpec) -> Optional[List[Any]]:
|
||||
def _broadcast_to_and_flatten(
|
||||
tree: PyTree,
|
||||
treespec: TreeSpec,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> Optional[List[Any]]:
|
||||
assert isinstance(treespec, TreeSpec)
|
||||
full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
|
||||
try:
|
||||
return broadcast_prefix(tree, full_tree)
|
||||
return broadcast_prefix(tree, full_tree, is_leaf=is_leaf)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
Reference in New Issue
Block a user