diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 0b5a7d4ff1eb..a636d70e326e 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -134,9 +134,9 @@ opt-einsum==3.3 #Pinned versions: 3.3 #test that import: test_linalg.py -optree==0.9.1 +optree==0.11.0 #Description: A library for tree manipulation -#Pinned versions: 0.9.1 +#Pinned versions: 0.11.0 #test that import: test_vmap.py, test_aotdispatch.py, test_dynamic_shapes.py, #test_pytree.py, test_ops.py, test_control_flow.py, test_modules.py, #common_utils.py, test_eager_transforms.py, test_python_dispatch.py, diff --git a/.github/requirements/pip-requirements-iOS.txt b/.github/requirements/pip-requirements-iOS.txt index 30e67abc5c86..01290e4c7102 100644 --- a/.github/requirements/pip-requirements-iOS.txt +++ b/.github/requirements/pip-requirements-iOS.txt @@ -1,4 +1,4 @@ # iOS simulator requirements coremltools==5.0b5 protobuf==3.20.2 -optree==0.9.1 +optree==0.11.0 diff --git a/.github/requirements/pip-requirements-macOS.txt b/.github/requirements/pip-requirements-macOS.txt index 35d24ae34f82..f0e4890328b3 100644 --- a/.github/requirements/pip-requirements-macOS.txt +++ b/.github/requirements/pip-requirements-macOS.txt @@ -26,7 +26,7 @@ pytest-cpp==2.3.0 rockset==1.0.3 z3-solver==4.12.2.0 tensorboard==2.13.0 -optree==0.9.1 +optree==0.11.0 # NB: test_hparams_* from test_tensorboard is failing with protobuf 5.26.0 in # which the stringify metadata is wrong when escaping double quote protobuf==3.20.2 diff --git a/.lintrunner.toml b/.lintrunner.toml index f30812313cfc..9e83a8b96e19 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -148,7 +148,7 @@ init_command = [ 'junitparser==2.1.1', 'rich==10.9.0', 'pyyaml==6.0.1', - 'optree==0.10.0', + 'optree==0.11.0', ] [[linter]] diff --git a/requirements.txt b/requirements.txt index 51fd003805fa..a32fe66cb402 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,4 +17,4 @@ fsspec # setuptools was removed from default python install setuptools ; python_version >= "3.12" packaging -optree>=0.9.1 +optree>=0.11.0 diff --git a/setup.py b/setup.py index ce8f16df7704..d774446780b4 100644 --- a/setup.py +++ b/setup.py @@ -1169,7 +1169,7 @@ def main(): install_requires += extra_install_requires extras_require = { - "optree": ["optree>=0.9.1"], + "optree": ["optree>=0.11.0"], "opt-einsum": ["opt-einsum>=3.3"], } diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 93605d3b0ba8..aba15f1482f2 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -56,6 +56,7 @@ __all__ = [ "tree_flatten", "tree_flatten_with_path", "tree_unflatten", + "tree_iter", "tree_leaves", "tree_leaves_with_path", "tree_structure", @@ -321,6 +322,41 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree: return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type] +def tree_iter( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> Iterable[Any]: + """Get an iterator over the leaves of a pytree. + + See also :func:`tree_flatten`. + + >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5} + >>> list(tree_iter(tree)) + [1, 2, 3, 4, None, 5] + >>> list(tree_iter(1)) + [1] + >>> list(tree_iter(None)) + [None] + + Args: + tree (pytree): A pytree to flatten. + is_leaf (callable, optional): An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + + Returns: + An iterator over the leaf values. + """ + return optree.tree_iter( + tree, + is_leaf=is_leaf, + none_is_leaf=True, + namespace="torch", + ) + + def tree_leaves( tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, @@ -670,7 +706,7 @@ def tree_all( tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, ) -> bool: - flat_args = tree_leaves(tree, is_leaf=is_leaf) + flat_args = tree_iter(tree, is_leaf=is_leaf) return all(map(pred, flat_args)) @@ -679,7 +715,7 @@ def tree_any( tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, ) -> bool: - flat_args = tree_leaves(tree, is_leaf=is_leaf) + flat_args = tree_iter(tree, is_leaf=is_leaf) return any(map(pred, flat_args)) @@ -719,7 +755,7 @@ def tree_all_only( tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, ) -> bool: - flat_args = tree_leaves(tree, is_leaf=is_leaf) + flat_args = tree_iter(tree, is_leaf=is_leaf) return all(pred(x) for x in flat_args if isinstance(x, __type_or_types)) @@ -759,7 +795,7 @@ def tree_any_only( tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, ) -> bool: - flat_args = tree_leaves(tree, is_leaf=is_leaf) + flat_args = tree_iter(tree, is_leaf=is_leaf) return any(pred(x) for x in flat_args if isinstance(x, __type_or_types)) diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 77f93819b557..52f0d65ded0b 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -66,6 +66,7 @@ __all__ = [ "tree_flatten", "tree_flatten_with_path", "tree_unflatten", + "tree_iter", "tree_leaves", "tree_leaves_with_path", "tree_structure", @@ -865,22 +866,21 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree: return treespec.unflatten(leaves) -def _tree_leaves_helper( +def tree_iter( tree: PyTree, - leaves: List[Any], is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> None: +) -> Iterable[Any]: + """Get an iterator over the leaves of a pytree.""" if _is_leaf(tree, is_leaf=is_leaf): - leaves.append(tree) - return + yield tree + else: + node_type = _get_node_type(tree) + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, _ = flatten_fn(tree) - node_type = _get_node_type(tree) - flatten_fn = SUPPORTED_NODES[node_type].flatten_fn - child_pytrees, _ = flatten_fn(tree) - - # Recursively flatten the children - for child in child_pytrees: - _tree_leaves_helper(child, leaves, is_leaf=is_leaf) + # Recursively flatten the children + for child in child_pytrees: + yield from tree_iter(child, is_leaf=is_leaf) def tree_leaves( @@ -888,9 +888,7 @@ def tree_leaves( is_leaf: Optional[Callable[[PyTree], bool]] = None, ) -> List[Any]: """Get a list of leaves of a pytree.""" - leaves: List[Any] = [] - _tree_leaves_helper(tree, leaves, is_leaf=is_leaf) - return leaves + return list(tree_iter(tree, is_leaf=is_leaf)) def tree_structure( @@ -1171,7 +1169,7 @@ def tree_all( tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, ) -> bool: - flat_args = tree_leaves(tree, is_leaf=is_leaf) + flat_args = tree_iter(tree, is_leaf=is_leaf) return all(map(pred, flat_args)) @@ -1180,7 +1178,7 @@ def tree_any( tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, ) -> bool: - flat_args = tree_leaves(tree, is_leaf=is_leaf) + flat_args = tree_iter(tree, is_leaf=is_leaf) return any(map(pred, flat_args)) @@ -1220,7 +1218,7 @@ def tree_all_only( tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, ) -> bool: - flat_args = tree_leaves(tree, is_leaf=is_leaf) + flat_args = tree_iter(tree, is_leaf=is_leaf) return all(pred(x) for x in flat_args if isinstance(x, __type_or_types)) @@ -1260,7 +1258,7 @@ def tree_any_only( tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, ) -> bool: - flat_args = tree_leaves(tree, is_leaf=is_leaf) + flat_args = tree_iter(tree, is_leaf=is_leaf) return any(pred(x) for x in flat_args if isinstance(x, __type_or_types)) @@ -1468,9 +1466,9 @@ def arg_tree_leaves(*args: PyTree, **kwargs: PyTree) -> List[Any]: """ leaves: List[Any] = [] for a in args: - _tree_leaves_helper(a, leaves) + leaves.extend(tree_iter(a)) for a in kwargs.values(): - _tree_leaves_helper(a, leaves) + leaves.extend(tree_iter(a)) return leaves