Revert "[pytree] add function tree_iter (#120155)"

This reverts commit 372d078f361e726bb4ac0884ac334b04c58179ef.

Reverted https://github.com/pytorch/pytorch/pull/120155 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/120155#issuecomment-1955479765))
This commit is contained in:
PyTorch MergeBot
2024-02-21 00:21:28 +00:00
parent 701f651f9c
commit a1fc29cd78
2 changed files with 18 additions and 55 deletions

View File

@ -54,7 +54,6 @@ __all__ = [
"tree_flatten",
"tree_flatten_with_path",
"tree_unflatten",
"tree_iter",
"tree_leaves",
"tree_leaves_with_path",
"tree_structure",
@ -320,43 +319,6 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
def tree_iter(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Iterable[Any]:
"""Get an iterator over the leaves of a pytree.
See also :func:`tree_flatten`.
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> list(tree_iter(tree))
[1, 2, 3, 4, None, 5]
>>> list(tree_iter(1))
[1]
>>> list(tree_iter(None))
[None]
Args:
tree (pytree): A pytree to flatten.
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
flattening step. The function should have a single argument with signature
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
leaf or not. If the function is not specified, the default pytree registry will be used.
Returns:
An iterator over the leaf values.
"""
return iter(
optree.tree_leaves(
tree,
is_leaf=is_leaf,
none_is_leaf=True,
namespace="torch",
),
)
def tree_leaves(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
@ -664,7 +626,7 @@ def tree_all(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_iter(tree, is_leaf=is_leaf)
flat_args = tree_leaves(tree, is_leaf=is_leaf)
return all(map(pred, flat_args))
@ -673,7 +635,7 @@ def tree_any(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_iter(tree, is_leaf=is_leaf)
flat_args = tree_leaves(tree, is_leaf=is_leaf)
return any(map(pred, flat_args))
@ -713,7 +675,7 @@ def tree_all_only(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_iter(tree, is_leaf=is_leaf)
flat_args = tree_leaves(tree, is_leaf=is_leaf)
return all(pred(x) for x in flat_args if isinstance(x, __type_or_types))
@ -753,7 +715,7 @@ def tree_any_only(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
flat_args = tree_iter(tree, is_leaf=is_leaf)
flat_args = tree_leaves(tree, is_leaf=is_leaf)
return any(pred(x) for x in flat_args if isinstance(x, __type_or_types))