[pytree] add tree_iter function (#123913)

- Add a new `tree_iter` function.
- Bump `optree` version to `0.11.0` for C++ version of `tree_iter`.

This PR is split from #120300.

- #120300

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123913
Approved by: https://github.com/zou3519
This commit is contained in:
Xuehai Pan
2024-04-13 13:22:48 +00:00
committed by PyTorch MergeBot
parent 0eab740db3
commit 2e48f7b044
8 changed files with 66 additions and 32 deletions

View File

@ -134,9 +134,9 @@ opt-einsum==3.3
#Pinned versions: 3.3
#test that import: test_linalg.py
optree==0.9.1
optree==0.11.0
#Description: A library for tree manipulation
#Pinned versions: 0.9.1
#Pinned versions: 0.11.0
#test that import: test_vmap.py, test_aotdispatch.py, test_dynamic_shapes.py,
#test_pytree.py, test_ops.py, test_control_flow.py, test_modules.py,
#common_utils.py, test_eager_transforms.py, test_python_dispatch.py,

View File

@ -1,4 +1,4 @@
# iOS simulator requirements
coremltools==5.0b5
protobuf==3.20.2
optree==0.9.1
optree==0.11.0

View File

@ -26,7 +26,7 @@ pytest-cpp==2.3.0
rockset==1.0.3
z3-solver==4.12.2.0
tensorboard==2.13.0
optree==0.9.1
optree==0.11.0
# NB: test_hparams_* from test_tensorboard is failing with protobuf 5.26.0 in
# which the stringify metadata is wrong when escaping double quote
protobuf==3.20.2

View File

@ -148,7 +148,7 @@ init_command = [
'junitparser==2.1.1',
'rich==10.9.0',
'pyyaml==6.0.1',
'optree==0.10.0',
'optree==0.11.0',
]
[[linter]]

View File

@ -17,4 +17,4 @@ fsspec
# setuptools was removed from default python install
setuptools ; python_version >= "3.12"
packaging
optree>=0.9.1
optree>=0.11.0

View File

@ -1169,7 +1169,7 @@ def main():
install_requires += extra_install_requires
extras_require = {
"optree": ["optree>=0.9.1"],
"optree": ["optree>=0.11.0"],
"opt-einsum": ["opt-einsum>=3.3"],
}

View File

@ -56,6 +56,7 @@ __all__ = [
"tree_flatten",
"tree_flatten_with_path",
"tree_unflatten",
"tree_iter",
"tree_leaves",
"tree_leaves_with_path",
"tree_structure",
@ -321,6 +322,41 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
def tree_iter(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Iterable[Any]:
"""Get an iterator over the leaves of a pytree.
See also :func:`tree_flatten`.
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> list(tree_iter(tree))
[1, 2, 3, 4, None, 5]
>>> list(tree_iter(1))
[1]
>>> list(tree_iter(None))
[None]
Args:
tree (pytree): A pytree to flatten.
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
flattening step. The function should have a single argument with signature
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
leaf or not. If the function is not specified, the default pytree registry will be used.
Returns:
An iterator over the leaf values.
"""
return optree.tree_iter(
tree,
is_leaf=is_leaf,
none_is_leaf=True,
namespace="torch",
)
def tree_leaves(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
@ -670,7 +706,7 @@ def tree_all(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return all(map(pred, flat_args))
@ -679,7 +715,7 @@ def tree_any(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return any(map(pred, flat_args))
@ -719,7 +755,7 @@ def tree_all_only(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return all(pred(x) for x in flat_args if isinstance(x, __type_or_types))
@ -759,7 +795,7 @@ def tree_any_only(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return any(pred(x) for x in flat_args if isinstance(x, __type_or_types))

View File

@ -66,6 +66,7 @@ __all__ = [
"tree_flatten",
"tree_flatten_with_path",
"tree_unflatten",
"tree_iter",
"tree_leaves",
"tree_leaves_with_path",
"tree_structure",
@ -865,22 +866,21 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
return treespec.unflatten(leaves)
def _tree_leaves_helper(
def tree_iter(
tree: PyTree,
leaves: List[Any],
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> None:
) -> Iterable[Any]:
"""Get an iterator over the leaves of a pytree."""
if _is_leaf(tree, is_leaf=is_leaf):
leaves.append(tree)
return
yield tree
else:
node_type = _get_node_type(tree)
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
child_pytrees, _ = flatten_fn(tree)
# Recursively flatten the children
for child in child_pytrees:
_tree_leaves_helper(child, leaves, is_leaf=is_leaf)
yield from tree_iter(child, is_leaf=is_leaf)
def tree_leaves(
@ -888,9 +888,7 @@ def tree_leaves(
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> List[Any]:
"""Get a list of leaves of a pytree."""
leaves: List[Any] = []
_tree_leaves_helper(tree, leaves, is_leaf=is_leaf)
return leaves
return list(tree_iter(tree, is_leaf=is_leaf))
def tree_structure(
@ -1171,7 +1169,7 @@ def tree_all(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return all(map(pred, flat_args))
@ -1180,7 +1178,7 @@ def tree_any(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return any(map(pred, flat_args))
@ -1220,7 +1218,7 @@ def tree_all_only(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return all(pred(x) for x in flat_args if isinstance(x, __type_or_types))
@ -1260,7 +1258,7 @@ def tree_any_only(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_leaves(tree, is_leaf=is_leaf)
flat_args = tree_iter(tree, is_leaf=is_leaf)
return any(pred(x) for x in flat_args if isinstance(x, __type_or_types))
@ -1468,9 +1466,9 @@ def arg_tree_leaves(*args: PyTree, **kwargs: PyTree) -> List[Any]:
"""
leaves: List[Any] = []
for a in args:
_tree_leaves_helper(a, leaves)
leaves.extend(tree_iter(a))
for a in kwargs.values():
_tree_leaves_helper(a, leaves)
leaves.extend(tree_iter(a))
return leaves