mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[pytree] add function tree_iter
(#120155)
Fixes #119768 - #119768 This PR adds a new function `tree_iter` that lazily iterates over the tree leaves. It is different than the `tree_leaves` function while the latter traversal the whole tree first to build a list of leaves. ```python for leaf in tree_iter(tree): ... ``` is much more efficient than: ```python for leaf in tree_leaves(tree): ... ``` where `tree_leaves(tree)` is `list(tree_iter(tree))`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/120155 Approved by: https://github.com/vmoens
This commit is contained in:
committed by
PyTorch MergeBot
parent
61a3a7628c
commit
372d078f36
@ -54,6 +54,7 @@ __all__ = [
|
||||
"tree_flatten",
|
||||
"tree_flatten_with_path",
|
||||
"tree_unflatten",
|
||||
"tree_iter",
|
||||
"tree_leaves",
|
||||
"tree_leaves_with_path",
|
||||
"tree_structure",
|
||||
@ -319,6 +320,43 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
|
||||
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def tree_iter(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> Iterable[Any]:
|
||||
"""Get an iterator over the leaves of a pytree.
|
||||
|
||||
See also :func:`tree_flatten`.
|
||||
|
||||
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
|
||||
>>> list(tree_iter(tree))
|
||||
[1, 2, 3, 4, None, 5]
|
||||
>>> list(tree_iter(1))
|
||||
[1]
|
||||
>>> list(tree_iter(None))
|
||||
[None]
|
||||
|
||||
Args:
|
||||
tree (pytree): A pytree to flatten.
|
||||
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
|
||||
flattening step. The function should have a single argument with signature
|
||||
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
|
||||
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
|
||||
leaf or not. If the function is not specified, the default pytree registry will be used.
|
||||
|
||||
Returns:
|
||||
An iterator over the leaf values.
|
||||
"""
|
||||
return iter(
|
||||
optree.tree_leaves(
|
||||
tree,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=True,
|
||||
namespace="torch",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def tree_leaves(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
@ -626,7 +664,7 @@ def tree_all(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
flat_args = tree_leaves(tree, is_leaf=is_leaf)
|
||||
flat_args = tree_iter(tree, is_leaf=is_leaf)
|
||||
return all(map(pred, flat_args))
|
||||
|
||||
|
||||
@ -635,7 +673,7 @@ def tree_any(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
flat_args = tree_leaves(tree, is_leaf=is_leaf)
|
||||
flat_args = tree_iter(tree, is_leaf=is_leaf)
|
||||
return any(map(pred, flat_args))
|
||||
|
||||
|
||||
@ -675,7 +713,7 @@ def tree_all_only(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
flat_args = tree_leaves(tree, is_leaf=is_leaf)
|
||||
flat_args = tree_iter(tree, is_leaf=is_leaf)
|
||||
return all(pred(x) for x in flat_args if isinstance(x, __type_or_types))
|
||||
|
||||
|
||||
@ -715,7 +753,7 @@ def tree_any_only(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
flat_args = tree_leaves(tree, is_leaf=is_leaf)
|
||||
flat_args = tree_iter(tree, is_leaf=is_leaf)
|
||||
return any(pred(x) for x in flat_args if isinstance(x, __type_or_types))
|
||||
|
||||
|
||||
|
@ -63,6 +63,7 @@ __all__ = [
|
||||
"tree_flatten",
|
||||
"tree_flatten_with_path",
|
||||
"tree_unflatten",
|
||||
"tree_iter",
|
||||
"tree_leaves",
|
||||
"tree_leaves_with_path",
|
||||
"tree_structure",
|
||||
@ -818,13 +819,13 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
|
||||
return treespec.unflatten(leaves)
|
||||
|
||||
|
||||
def _tree_leaves_helper(
|
||||
def tree_iter(
|
||||
tree: PyTree,
|
||||
leaves: List[Any],
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> None:
|
||||
) -> Iterable[Any]:
|
||||
"""Get an iterator over the leaves of a pytree."""
|
||||
if _is_leaf(tree, is_leaf=is_leaf):
|
||||
leaves.append(tree)
|
||||
yield tree
|
||||
return
|
||||
|
||||
node_type = _get_node_type(tree)
|
||||
@ -833,7 +834,7 @@ def _tree_leaves_helper(
|
||||
|
||||
# Recursively flatten the children
|
||||
for child in child_pytrees:
|
||||
_tree_leaves_helper(child, leaves, is_leaf=is_leaf)
|
||||
yield from tree_iter(child, is_leaf=is_leaf)
|
||||
|
||||
|
||||
def tree_leaves(
|
||||
@ -841,9 +842,7 @@ def tree_leaves(
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> List[Any]:
|
||||
"""Get a list of leaves of a pytree."""
|
||||
leaves: List[Any] = []
|
||||
_tree_leaves_helper(tree, leaves, is_leaf=is_leaf)
|
||||
return leaves
|
||||
return list(tree_iter(tree, is_leaf=is_leaf))
|
||||
|
||||
|
||||
def tree_structure(
|
||||
@ -1082,7 +1081,7 @@ def tree_all(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
flat_args = tree_leaves(tree, is_leaf=is_leaf)
|
||||
flat_args = tree_iter(tree, is_leaf=is_leaf)
|
||||
return all(map(pred, flat_args))
|
||||
|
||||
|
||||
@ -1091,7 +1090,7 @@ def tree_any(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
flat_args = tree_leaves(tree, is_leaf=is_leaf)
|
||||
flat_args = tree_iter(tree, is_leaf=is_leaf)
|
||||
return any(map(pred, flat_args))
|
||||
|
||||
|
||||
@ -1131,7 +1130,7 @@ def tree_all_only(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
flat_args = tree_leaves(tree, is_leaf=is_leaf)
|
||||
flat_args = tree_iter(tree, is_leaf=is_leaf)
|
||||
return all(pred(x) for x in flat_args if isinstance(x, __type_or_types))
|
||||
|
||||
|
||||
@ -1171,7 +1170,7 @@ def tree_any_only(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
flat_args = tree_leaves(tree, is_leaf=is_leaf)
|
||||
flat_args = tree_iter(tree, is_leaf=is_leaf)
|
||||
return any(pred(x) for x in flat_args if isinstance(x, __type_or_types))
|
||||
|
||||
|
||||
@ -1379,9 +1378,9 @@ def arg_tree_leaves(*args: PyTree, **kwargs: PyTree) -> List[Any]:
|
||||
"""
|
||||
leaves: List[Any] = []
|
||||
for a in args:
|
||||
_tree_leaves_helper(a, leaves)
|
||||
leaves.extend(tree_iter(a))
|
||||
for a in kwargs.values():
|
||||
_tree_leaves_helper(a, leaves)
|
||||
leaves.extend(tree_iter(a))
|
||||
return leaves
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user