mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[pytree] expand tree_map
to accept multi-inputs (#115642)
Fixes #115419 Fixes #91323 Closes #115549 - #115419 - #91323 Pull Request resolved: https://github.com/pytorch/pytorch/pull/115642 Approved by: https://github.com/vmoens, https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
7e1542b938
commit
36c6c0c7dc
@ -345,8 +345,8 @@ def tree_structure(tree: PyTree) -> TreeSpec:
|
||||
)
|
||||
|
||||
|
||||
def tree_map(func: Callable[..., Any], tree: PyTree) -> PyTree:
|
||||
"""Map a function over leaves in a pytree to produce a new pytree.
|
||||
def tree_map(func: Callable[..., Any], tree: PyTree, *rests: PyTree) -> PyTree:
|
||||
"""Map a multi-input function over pytree args to produce a new pytree.
|
||||
|
||||
See also :func:`tree_map_`.
|
||||
|
||||
@ -355,43 +355,56 @@ def tree_map(func: Callable[..., Any], tree: PyTree) -> PyTree:
|
||||
>>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
|
||||
{'x': False, 'y': (False, False), 'z': True}
|
||||
|
||||
If multiple inputs are given, the structure of the tree is taken from the first input;
|
||||
subsequent inputs need only have ``tree`` as a prefix:
|
||||
|
||||
>>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
|
||||
[[5, 7, 9], [6, 1, 2]]
|
||||
|
||||
Args:
|
||||
func (callable): A function that takes a single argument, to be applied at the corresponding
|
||||
leaves of the pytree.
|
||||
tree (pytree): A pytree to be mapped over, with each leaf providing the argument to function
|
||||
``func``.
|
||||
func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
|
||||
corresponding leaves of the pytrees.
|
||||
tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
|
||||
argument to function ``func``.
|
||||
rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
|
||||
``tree`` or has ``tree`` as a prefix.
|
||||
|
||||
Returns:
|
||||
A new pytree with the same structure as ``tree`` but with the value at each leaf given by
|
||||
``func(x)`` where ``x`` is the value at the corresponding leaf in ``tree``.
|
||||
``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs``
|
||||
is the tuple of values at corresponding nodes in ``rests``.
|
||||
"""
|
||||
return optree.tree_map(
|
||||
func,
|
||||
tree,
|
||||
*rests,
|
||||
none_is_leaf=True,
|
||||
namespace="torch",
|
||||
)
|
||||
|
||||
|
||||
def tree_map_(func: Callable[..., Any], tree: PyTree) -> PyTree:
|
||||
def tree_map_(func: Callable[..., Any], tree: PyTree, *rests: PyTree) -> PyTree:
|
||||
"""Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree.
|
||||
|
||||
See also :func:`tree_map`.
|
||||
|
||||
Args:
|
||||
func (callable): A function that takes a single argument, to be applied at the corresponding
|
||||
leaves of the pytree.
|
||||
tree (pytree): A pytree to be mapped over, with each leaf providing the argument to function
|
||||
``func``.
|
||||
func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
|
||||
corresponding leaves of the pytrees.
|
||||
tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
|
||||
argument to function ``func``.
|
||||
rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
|
||||
``tree`` or has ``tree`` as a prefix.
|
||||
|
||||
Returns:
|
||||
The original ``tree`` with the value at each leaf is given by the side-effect of function
|
||||
``func(x)`` (not the return value) where ``x`` is the value at the corresponding leaf in
|
||||
``tree``.
|
||||
``func(x, *xs)`` (not the return value) where ``x`` is the value at the corresponding leaf
|
||||
in ``tree`` and ``xs`` is the tuple of values at values at corresponding nodes in ``rests``.
|
||||
"""
|
||||
return optree.tree_map_(
|
||||
func,
|
||||
tree,
|
||||
*rests,
|
||||
none_is_leaf=True,
|
||||
namespace="torch",
|
||||
)
|
||||
|
Reference in New Issue
Block a user