mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[pytree] add tree_iter
function (#123913)
- Add a new `tree_iter` function. - Bump `optree` version to `0.11.0` for C++ version of `tree_iter`. This PR is split from #120300. - #120300 Pull Request resolved: https://github.com/pytorch/pytorch/pull/123913 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
0eab740db3
commit
2e48f7b044
@ -56,6 +56,7 @@ __all__ = [
|
||||
"tree_flatten",
|
||||
"tree_flatten_with_path",
|
||||
"tree_unflatten",
|
||||
"tree_iter",
|
||||
"tree_leaves",
|
||||
"tree_leaves_with_path",
|
||||
"tree_structure",
|
||||
@ -321,6 +322,41 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
|
||||
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def tree_iter(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> Iterable[Any]:
|
||||
"""Get an iterator over the leaves of a pytree.
|
||||
|
||||
See also :func:`tree_flatten`.
|
||||
|
||||
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
|
||||
>>> list(tree_iter(tree))
|
||||
[1, 2, 3, 4, None, 5]
|
||||
>>> list(tree_iter(1))
|
||||
[1]
|
||||
>>> list(tree_iter(None))
|
||||
[None]
|
||||
|
||||
Args:
|
||||
tree (pytree): A pytree to flatten.
|
||||
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
|
||||
flattening step. The function should have a single argument with signature
|
||||
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
|
||||
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
|
||||
leaf or not. If the function is not specified, the default pytree registry will be used.
|
||||
|
||||
Returns:
|
||||
An iterator over the leaf values.
|
||||
"""
|
||||
return optree.tree_iter(
|
||||
tree,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=True,
|
||||
namespace="torch",
|
||||
)
|
||||
|
||||
|
||||
def tree_leaves(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
@ -670,7 +706,7 @@ def tree_all(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
flat_args = tree_leaves(tree, is_leaf=is_leaf)
|
||||
flat_args = tree_iter(tree, is_leaf=is_leaf)
|
||||
return all(map(pred, flat_args))
|
||||
|
||||
|
||||
@ -679,7 +715,7 @@ def tree_any(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
flat_args = tree_leaves(tree, is_leaf=is_leaf)
|
||||
flat_args = tree_iter(tree, is_leaf=is_leaf)
|
||||
return any(map(pred, flat_args))
|
||||
|
||||
|
||||
@ -719,7 +755,7 @@ def tree_all_only(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
flat_args = tree_leaves(tree, is_leaf=is_leaf)
|
||||
flat_args = tree_iter(tree, is_leaf=is_leaf)
|
||||
return all(pred(x) for x in flat_args if isinstance(x, __type_or_types))
|
||||
|
||||
|
||||
@ -759,7 +795,7 @@ def tree_any_only(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
flat_args = tree_leaves(tree, is_leaf=is_leaf)
|
||||
flat_args = tree_iter(tree, is_leaf=is_leaf)
|
||||
return any(pred(x) for x in flat_args if isinstance(x, __type_or_types))
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user