[pytree] expand tree_map to accept multi-inputs (#115642)

Fixes #115419
Fixes #91323
Closes #115549

- #115419
- #91323

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115642
Approved by: https://github.com/vmoens, https://github.com/zou3519
This commit is contained in:
Xuehai Pan
2023-12-14 06:16:39 +00:00
committed by PyTorch MergeBot
parent 7e1542b938
commit 36c6c0c7dc
3 changed files with 241 additions and 47 deletions

View File

@ -345,8 +345,8 @@ def tree_structure(tree: PyTree) -> TreeSpec:
)
def tree_map(func: Callable[..., Any], tree: PyTree) -> PyTree:
"""Map a function over leaves in a pytree to produce a new pytree.
def tree_map(func: Callable[..., Any], tree: PyTree, *rests: PyTree) -> PyTree:
"""Map a multi-input function over pytree args to produce a new pytree.
See also :func:`tree_map_`.
@ -355,43 +355,56 @@ def tree_map(func: Callable[..., Any], tree: PyTree) -> PyTree:
>>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
{'x': False, 'y': (False, False), 'z': True}
If multiple inputs are given, the structure of the tree is taken from the first input;
subsequent inputs need only have ``tree`` as a prefix:
>>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
[[5, 7, 9], [6, 1, 2]]
Args:
func (callable): A function that takes a single argument, to be applied at the corresponding
leaves of the pytree.
tree (pytree): A pytree to be mapped over, with each leaf providing the argument to function
``func``.
func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
corresponding leaves of the pytrees.
tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
argument to function ``func``.
rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
``tree`` or has ``tree`` as a prefix.
Returns:
A new pytree with the same structure as ``tree`` but with the value at each leaf given by
``func(x)`` where ``x`` is the value at the corresponding leaf in ``tree``.
``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs``
is the tuple of values at corresponding nodes in ``rests``.
"""
return optree.tree_map(
func,
tree,
*rests,
none_is_leaf=True,
namespace="torch",
)
def tree_map_(func: Callable[..., Any], tree: PyTree) -> PyTree:
def tree_map_(func: Callable[..., Any], tree: PyTree, *rests: PyTree) -> PyTree:
"""Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree.
See also :func:`tree_map`.
Args:
func (callable): A function that takes a single argument, to be applied at the corresponding
leaves of the pytree.
tree (pytree): A pytree to be mapped over, with each leaf providing the argument to function
``func``.
func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
corresponding leaves of the pytrees.
tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
argument to function ``func``.
rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
``tree`` or has ``tree`` as a prefix.
Returns:
The original ``tree`` with the value at each leaf is given by the side-effect of function
``func(x)`` (not the return value) where ``x`` is the value at the corresponding leaf in
``tree``.
``func(x, *xs)`` (not the return value) where ``x`` is the value at the corresponding leaf
in ``tree`` and ``xs`` is the tuple of values at values at corresponding nodes in ``rests``.
"""
return optree.tree_map_(
func,
tree,
*rests,
none_is_leaf=True,
namespace="torch",
)