[pytree] add tree_iter function (#123913)

- Add a new `tree_iter` function.
- Bump `optree` version to `0.11.0` for C++ version of `tree_iter`.

This PR is split from #120300.

- #120300

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123913
Approved by: https://github.com/zou3519
This commit is contained in:
Xuehai Pan
2024-04-13 13:22:48 +00:00
committed by PyTorch MergeBot
parent 0eab740db3
commit 2e48f7b044
8 changed files with 66 additions and 32 deletions

View File

@ -56,6 +56,7 @@ __all__ = [
"tree_flatten",
"tree_flatten_with_path",
"tree_unflatten",
"tree_iter",
"tree_leaves",
"tree_leaves_with_path",
"tree_structure",
@ -321,6 +322,41 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
def tree_iter(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Iterable[Any]:
"""Get an iterator over the leaves of a pytree.
See also :func:`tree_flatten`.
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> list(tree_iter(tree))
[1, 2, 3, 4, None, 5]
>>> list(tree_iter(1))
[1]
>>> list(tree_iter(None))
[None]
Args:
tree (pytree): A pytree to flatten.
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
flattening step. The function should have a single argument with signature
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
leaf or not. If the function is not specified, the default pytree registry will be used.
Returns:
An iterator over the leaf values.
"""
return optree.tree_iter(
tree,
is_leaf=is_leaf,
none_is_leaf=True,
namespace="torch",
)
def tree_leaves(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
@ -670,7 +706,7 @@ def tree_all(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return all(map(pred, flat_args))
@ -679,7 +715,7 @@ def tree_any(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return any(map(pred, flat_args))
@ -719,7 +755,7 @@ def tree_all_only(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return all(pred(x) for x in flat_args if isinstance(x, __type_or_types))
@ -759,7 +795,7 @@ def tree_any_only(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return any(pred(x) for x in flat_args if isinstance(x, __type_or_types))