mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[pytree] add function tree_iter
(#120155)"
This reverts commit 372d078f361e726bb4ac0884ac334b04c58179ef. Reverted https://github.com/pytorch/pytorch/pull/120155 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/120155#issuecomment-1955479765))
This commit is contained in:
@ -54,7 +54,6 @@ __all__ = [
|
||||
"tree_flatten",
|
||||
"tree_flatten_with_path",
|
||||
"tree_unflatten",
|
||||
"tree_iter",
|
||||
"tree_leaves",
|
||||
"tree_leaves_with_path",
|
||||
"tree_structure",
|
||||
@ -320,43 +319,6 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
|
||||
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def tree_iter(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> Iterable[Any]:
|
||||
"""Get an iterator over the leaves of a pytree.
|
||||
|
||||
See also :func:`tree_flatten`.
|
||||
|
||||
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
|
||||
>>> list(tree_iter(tree))
|
||||
[1, 2, 3, 4, None, 5]
|
||||
>>> list(tree_iter(1))
|
||||
[1]
|
||||
>>> list(tree_iter(None))
|
||||
[None]
|
||||
|
||||
Args:
|
||||
tree (pytree): A pytree to flatten.
|
||||
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
|
||||
flattening step. The function should have a single argument with signature
|
||||
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
|
||||
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
|
||||
leaf or not. If the function is not specified, the default pytree registry will be used.
|
||||
|
||||
Returns:
|
||||
An iterator over the leaf values.
|
||||
"""
|
||||
return iter(
|
||||
optree.tree_leaves(
|
||||
tree,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=True,
|
||||
namespace="torch",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def tree_leaves(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
@ -664,7 +626,7 @@ def tree_all(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
flat_args = tree_iter(tree, is_leaf=is_leaf)
|
||||
flat_args = tree_leaves(tree, is_leaf=is_leaf)
|
||||
return all(map(pred, flat_args))
|
||||
|
||||
|
||||
@ -673,7 +635,7 @@ def tree_any(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
flat_args = tree_iter(tree, is_leaf=is_leaf)
|
||||
flat_args = tree_leaves(tree, is_leaf=is_leaf)
|
||||
return any(map(pred, flat_args))
|
||||
|
||||
|
||||
@ -713,7 +675,7 @@ def tree_all_only(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
flat_args = tree_iter(tree, is_leaf=is_leaf)
|
||||
flat_args = tree_leaves(tree, is_leaf=is_leaf)
|
||||
return all(pred(x) for x in flat_args if isinstance(x, __type_or_types))
|
||||
|
||||
|
||||
@ -753,7 +715,7 @@ def tree_any_only(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> bool:
|
||||
flat_args = tree_iter(tree, is_leaf=is_leaf)
|
||||
flat_args = tree_leaves(tree, is_leaf=is_leaf)
|
||||
return any(pred(x) for x in flat_args if isinstance(x, __type_or_types))
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user