[pytree] add function tree_iter (#120155)

Fixes #119768

- #119768

This PR adds a new function `tree_iter` that lazily iterates over the tree leaves. It is different than the `tree_leaves` function while the latter traversal the whole tree first to build a list of leaves.

```python
for leaf in tree_iter(tree):
    ...
```

is much more efficient than:

```python
for leaf in tree_leaves(tree):
    ...
```

where `tree_leaves(tree)` is `list(tree_iter(tree))`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120155
Approved by: https://github.com/vmoens
This commit is contained in:
Xuehai Pan
2024-02-18 09:16:50 +00:00
committed by PyTorch MergeBot
parent 61a3a7628c
commit 372d078f36
2 changed files with 55 additions and 18 deletions

View File

@ -54,6 +54,7 @@ __all__ = [
"tree_flatten",
"tree_flatten_with_path",
"tree_unflatten",
"tree_iter",
"tree_leaves",
"tree_leaves_with_path",
"tree_structure",
@ -319,6 +320,43 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
def tree_iter(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Iterable[Any]:
"""Get an iterator over the leaves of a pytree.
See also :func:`tree_flatten`.
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> list(tree_iter(tree))
[1, 2, 3, 4, None, 5]
>>> list(tree_iter(1))
[1]
>>> list(tree_iter(None))
[None]
Args:
tree (pytree): A pytree to flatten.
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
flattening step. The function should have a single argument with signature
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
leaf or not. If the function is not specified, the default pytree registry will be used.
Returns:
An iterator over the leaf values.
"""
return iter(
optree.tree_leaves(
tree,
is_leaf=is_leaf,
none_is_leaf=True,
namespace="torch",
),
)
def tree_leaves(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
@ -626,7 +664,7 @@ def tree_all(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return all(map(pred, flat_args))
@ -635,7 +673,7 @@ def tree_any(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return any(map(pred, flat_args))
@ -675,7 +713,7 @@ def tree_all_only(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return all(pred(x) for x in flat_args if isinstance(x, __type_or_types))
@ -715,7 +753,7 @@ def tree_any_only(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return any(pred(x) for x in flat_args if isinstance(x, __type_or_types))

View File

@ -63,6 +63,7 @@ __all__ = [
"tree_flatten",
"tree_flatten_with_path",
"tree_unflatten",
"tree_iter",
"tree_leaves",
"tree_leaves_with_path",
"tree_structure",
@ -818,13 +819,13 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
return treespec.unflatten(leaves)
def _tree_leaves_helper(
def tree_iter(
tree: PyTree,
leaves: List[Any],
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> None:
) -> Iterable[Any]:
"""Get an iterator over the leaves of a pytree."""
if _is_leaf(tree, is_leaf=is_leaf):
leaves.append(tree)
yield tree
return
node_type = _get_node_type(tree)
@ -833,7 +834,7 @@ def _tree_leaves_helper(
# Recursively flatten the children
for child in child_pytrees:
_tree_leaves_helper(child, leaves, is_leaf=is_leaf)
yield from tree_iter(child, is_leaf=is_leaf)
def tree_leaves(
@ -841,9 +842,7 @@ def tree_leaves(
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> List[Any]:
"""Get a list of leaves of a pytree."""
leaves: List[Any] = []
_tree_leaves_helper(tree, leaves, is_leaf=is_leaf)
return leaves
return list(tree_iter(tree, is_leaf=is_leaf))
def tree_structure(
@ -1082,7 +1081,7 @@ def tree_all(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return all(map(pred, flat_args))
@ -1091,7 +1090,7 @@ def tree_any(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return any(map(pred, flat_args))
@ -1131,7 +1130,7 @@ def tree_all_only(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return all(pred(x) for x in flat_args if isinstance(x, __type_or_types))
@ -1171,7 +1170,7 @@ def tree_any_only(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return any(pred(x) for x in flat_args if isinstance(x, __type_or_types))
@ -1379,9 +1378,9 @@ def arg_tree_leaves(*args: PyTree, **kwargs: PyTree) -> List[Any]:
"""
leaves: List[Any] = []
for a in args:
_tree_leaves_helper(a, leaves)
leaves.extend(tree_iter(a))
for a in kwargs.values():
_tree_leaves_helper(a, leaves)
leaves.extend(tree_iter(a))
return leaves