[torch.func] alias torch.func.vmap as torch.vmap (#91026)
This PR also redirects torch.vmap to torch.func.vmap instead of the old vmap prototype.

Test Plan:
- tests
- view docs preview

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91026
Approved by: https://github.com/albanD, https://github.com/samdow
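For illustration, here is a minimal sketch of what the aliasing means in practice. It assumes a PyTorch build that includes this change; the dot-product example and shapes are arbitrary.

    import torch

    # With this change, torch.vmap and torch.func.vmap name the same vectorizing map,
    # so either spelling can be used, e.g. to batch torch.dot over a leading dimension.
    x, y = torch.randn(2, 5), torch.randn(2, 5)

    batched_dot = torch.vmap(torch.dot)        # [N, D], [N, D] -> [N]
    reference = torch.func.vmap(torch.dot)     # same function, torch.func spelling

    assert torch.allclose(batched_dot(x, y), reference(x, y))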
Committed by: PyTorch MergeBot
Parent: e803d336eb
Commit: fb2e1878cb
@@ -72,10 +72,11 @@ static void warnFallback(const c10::FunctionSchema& schema) {
   }
   TORCH_WARN("There is a performance drop because we have not yet implemented ",
              "the batching rule for ", schema.operator_name(), ". ",
-             "We've moved development of vmap to to functorch "
-             "(https://github.com/pytorch/functorch), please try functorch.vmap "
-             "instead and/or file ",
-             " an issue on GitHub so that we can prioritize its implementation.");
+             "You are using the legacy vmap prototype (torch._vmap_internals.vmap). ",
+             "If you are using torch.autograd.functional.{jacobian, hessian} ",
+             "or torch._vmap_internals.vmap: please switch to using ",
+             "torch.func.{jacrev, jacfwd, hessian} and/or torch.vmap instead ",
+             "for better operator coverage and performance improvements .");
 }

 // The general flow of the algorithm is as follows.
@@ -3,7 +3,8 @@
 from torch.testing._internal.common_utils import TestCase, run_tests
 import torch
 import torch.nn.functional as F
-from torch import Tensor, vmap
+from torch import Tensor
+from torch._vmap_internals import vmap
 import functools
 import itertools
 import warnings
@@ -48,7 +48,7 @@ __all__ = [
     'set_deterministic_debug_mode', 'get_deterministic_debug_mode',
     'set_float32_matmul_precision', 'get_float32_matmul_precision',
     'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
-    'compile',
+    'compile', 'vmap',
 ]

 ################################################################################
@@ -1116,8 +1116,6 @@ del register_after_fork
 # torch.jit.script as a decorator, for instance):
 from ._lobpcg import lobpcg as lobpcg

-from ._vmap_internals import vmap as vmap
-
 # These were previously defined in native_functions.yaml and appeared on the
 # `torch` namespace, but we moved them to c10 dispatch to facilitate custom
 # class usage. We add these lines here to preserve backward compatibility.
@@ -1245,3 +1243,4 @@ if 'TORCH_CUDA_SANITIZER' in os.environ:
     import torch.fx.experimental.symbolic_shapes

 from torch import func as func
+from torch.func import vmap
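Together with the `__all__` change above, the import added here makes `vmap` reachable from the top-level namespace. A small sketch of the expected effect, assuming a build with this change:

    import torch
    from torch import vmap  # resolves via the `from torch.func import vmap` added above

    assert vmap is torch.vmap
    assert vmap is torch.func.vmap
    assert 'vmap' in torch.__all__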
@@ -1259,8 +1259,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:

     When composed with ``vmap``, ``grad`` can be used to compute per-sample-gradients:

-        >>> from torch.func import grad
-        >>> from torch.func import vmap
+        >>> from torch.func import grad, vmap
         >>> batch_size, feature_size = 3, 5
         >>>
         >>> def model(weights, feature_vec):
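The hunk above shows only the first lines of the per-sample-gradient example. For context, a self-contained sketch of the same grad-plus-vmap pattern; the model and the squared-error loss are chosen here purely for illustration:

    import torch
    from torch.func import grad, vmap  # torch.vmap works equally well after this PR

    batch_size, feature_size = 3, 5
    weights = torch.randn(feature_size)

    def model(weights, feature_vec):
        # Very simple linear model with activation.
        return feature_vec.dot(weights).relu()

    def compute_loss(weights, example, target):
        y = model(weights, example)
        return ((y - target) ** 2).mean()

    examples = torch.randn(batch_size, feature_size)
    targets = torch.randn(batch_size)

    # grad differentiates w.r.t. weights (argnums=0 by default); vmap maps that
    # computation over the batch of (example, target) pairs.
    per_sample_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, examples, targets)
    assert per_sample_grads.shape == (batch_size, feature_size)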
@@ -255,6 +255,10 @@ def vmap(
     take batches of examples with ``vmap(func)``. vmap can also be used to
     compute batched gradients when composed with autograd.

+    .. note::
+        :func:`torch.vmap` is aliased to :func:`torch.func.vmap` for
+        convenience. Use whichever one you'd like.
+
     Args:
         func (function): A Python function that takes one or more arguments.
             Must return one or more Tensors.
@@ -308,7 +312,7 @@ def vmap(
         >>>     return feature_vec.dot(weights).relu()
         >>>
         >>> examples = torch.randn(batch_size, feature_size)
-        >>> result = torch.func.vmap(model)(examples)
+        >>> result = torch.vmap(model)(examples)

     :func:`vmap` can also help vectorize computations that were previously difficult
     or impossible to batch. One example is higher-order gradient computation.
@@ -333,12 +337,12 @@ def vmap(
         >>> # vectorized gradient computation
         >>> def get_vjp(v):
         >>>     return torch.autograd.grad(y, x, v)
-        >>> jacobian = torch.func.vmap(get_vjp)(I_N)
+        >>> jacobian = torch.vmap(get_vjp)(I_N)

     :func:`vmap` can also be nested, producing an output with multiple batched dimensions

         >>> torch.dot # [D], [D] -> []
-        >>> batched_dot = torch.func.vmap(torch.func.vmap(torch.dot)) # [N1, N0, D], [N1, N0, D] -> [N1, N0]
+        >>> batched_dot = torch.vmap(torch.vmap(torch.dot)) # [N1, N0, D], [N1, N0, D] -> [N1, N0]
         >>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5)
         >>> batched_dot(x, y) # tensor of size [2, 3]

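The doctest fragments above come from a longer Jacobian example. For context, a self-contained sketch of the full pattern, with N and f picked arbitrarily:

    import torch

    N = 5
    f = lambda x: x ** 2
    x = torch.randn(N, requires_grad=True)
    y = f(x)
    I_N = torch.eye(N)

    # Sequential approach: one autograd.grad call per Jacobian row.
    jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
                     for v in I_N.unbind()]
    jacobian_sequential = torch.stack(jacobian_rows)

    # Vectorized approach: a single vmapped vjp over all basis vectors.
    def get_vjp(v):
        return torch.autograd.grad(y, x, v, retain_graph=True)[0]

    jacobian_vectorized = torch.vmap(get_vjp)(I_N)
    assert torch.allclose(jacobian_sequential, jacobian_vectorized)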
@@ -346,7 +350,7 @@ def vmap(
     the dimension that each inputs are batched along as

         >>> torch.dot # [N], [N] -> []
-        >>> batched_dot = torch.func.vmap(torch.dot, in_dims=1) # [N, D], [N, D] -> [D]
+        >>> batched_dot = torch.vmap(torch.dot, in_dims=1) # [N, D], [N, D] -> [D]
         >>> x, y = torch.randn(2, 5), torch.randn(2, 5)
         >>> batched_dot(x, y) # output is [5] instead of [2] if batched along the 0th dimension

@@ -354,7 +358,7 @@ def vmap(
     ``in_dims`` must be a tuple with the batch dimension for each input as

         >>> torch.dot # [D], [D] -> []
-        >>> batched_dot = torch.func.vmap(torch.dot, in_dims=(0, None)) # [N, D], [D] -> [N]
+        >>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None)) # [N, D], [D] -> [N]
         >>> x, y = torch.randn(2, 5), torch.randn(5)
         >>> batched_dot(x, y) # second arg doesn't have a batch dim because in_dim[1] was None

@@ -364,7 +368,7 @@ def vmap(
         >>> f = lambda dict: torch.dot(dict['x'], dict['y'])
         >>> x, y = torch.randn(2, 5), torch.randn(5)
         >>> input = {'x': x, 'y': y}
-        >>> batched_dot = torch.func.vmap(f, in_dims=({'x': 0, 'y': None},))
+        >>> batched_dot = torch.vmap(f, in_dims=({'x': 0, 'y': None},))
         >>> batched_dot(input)

     By default, the output is batched along the first dimension. However, it can be batched
@@ -372,17 +376,17 @@ def vmap(

         >>> f = lambda x: x ** 2
         >>> x = torch.randn(2, 5)
-        >>> batched_pow = torch.func.vmap(f, out_dims=1)
+        >>> batched_pow = torch.vmap(f, out_dims=1)
         >>> batched_pow(x) # [5, 2]

     For any function that uses kwargs, the returned function will not batch the kwargs but will
     accept kwargs

         >>> x = torch.randn([2, 5])
-        >>> def f(x, scale=4.):
+        >>> def fn(x, scale=4.):
         >>>     return x * scale
         >>>
-        >>> batched_pow = torch.func.vmap(f)
+        >>> batched_pow = torch.vmap(fn)
         >>> assert torch.allclose(batched_pow(x), x * 4)
         >>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5]

@@ -191,111 +191,10 @@ def _get_name(func: Callable):
 # on BatchedTensors perform the batched operations that the user is asking for.
 def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
     """
-    vmap is the vectorizing map. Returns a new function that maps `func` over some
-    dimension of the inputs. Semantically, vmap pushes the map into PyTorch
-    operations called by `func`, effectively vectorizing those operations.
-
-    vmap is useful for handling batch dimensions: one can write a function `func`
-    that runs on examples and then lift it to a function that can take batches of
-    examples with `vmap(func)`. vmap can also be used to compute batched
-    gradients when composed with autograd.
-
-    .. note::
-        We have moved development of vmap to
-        `functorch. <https://github.com/pytorch/functorch>`_ functorch's
-        vmap is able to arbitrarily compose with gradient computation
-        and contains significant performance improvements.
-        Please give that a try if that is what you're looking for.
-
-        Furthermore, if you're interested in using vmap for your use case,
-        please `contact us! <https://github.com/pytorch/pytorch/issues/42368>`_
-        We're interested in gathering feedback from early adopters to inform
-        the design.
-
-    .. warning::
-        torch.vmap is an experimental prototype that is subject to
-        change and/or deletion. Please use at your own risk.
-
-    Args:
-        func (function): A Python function that takes one or more arguments.
-            Must return one or more Tensors.
-        in_dims (int or nested structure): Specifies which dimension of the
-            inputs should be mapped over. `in_dims` should have a structure
-            like the inputs. If the `in_dim` for a particular input is None,
-            then that indicates there is no map dimension. Default: 0.
-        out_dims (int or Tuple[int]): Specifies where the mapped dimension
-            should appear in the outputs. If `out_dims` is a Tuple, then it should
-            have one element per output. Default: 0.
-
-    Returns:
-        Returns a new "batched" function. It takes the same inputs as `func`,
-        except each input has an extra dimension at the index specified by `in_dims`.
-        It takes returns the same outputs as `func`, except each output has
-        an extra dimension at the index specified by `out_dims`.
-
-    .. warning:
-        vmap works best with functional-style code. Please do not perform any
-        side-effects in `func`, with the exception of in-place PyTorch operations.
-        Examples of side-effects include mutating Python data structures and
-        assigning values to variables not captured in `func`.
-
-    One example of using `vmap` is to compute batched dot products. PyTorch
-    doesn't provide a batched `torch.dot` API; instead of unsuccessfully
-    rummaging through docs, use `vmap` to construct a new function.
-
-        >>> torch.dot # [D], [D] -> []
-        >>> batched_dot = torch.vmap(torch.dot) # [N, D], [N, D] -> [N]
-        >>> x, y = torch.randn(2, 5), torch.randn(2, 5)
-        >>> batched_dot(x, y)
-
-    `vmap` can be helpful in hiding batch dimensions, leading to a simpler
-    model authoring experience.
-
-        >>> batch_size, feature_size = 3, 5
-        >>> weights = torch.randn(feature_size, requires_grad=True)
-        >>>
-        >>> def model(feature_vec):
-        >>>     # Very simple linear model with activation
-        >>>     return feature_vec.dot(weights).relu()
-        >>>
-        >>> examples = torch.randn(batch_size, feature_size)
-        >>> result = torch.vmap(model)(examples)
-
-    `vmap` can also help vectorize computations that were previously difficult
-    or impossible to batch. One example is higher-order gradient computation.
-    The PyTorch autograd engine computes vjps (vector-Jacobian products).
-    Computing a full Jacobian matrix for some function f: R^N -> R^N usually
-    requires N calls to `autograd.grad`, one per Jacobian row. Using `vmap`,
-    we can vectorize the whole computation, computing the Jacobian in a single
-    call to `autograd.grad`.
-
-        >>> # Setup
-        >>> N = 5
-        >>> f = lambda x: x ** 2
-        >>> x = torch.randn(N, requires_grad=True)
-        >>> y = f(x)
-        >>> I_N = torch.eye(N)
-        >>>
-        >>> # Sequential approach
-        >>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
-        >>>                  for v in I_N.unbind()]
-        >>> jacobian = torch.stack(jacobian_rows)
-        >>>
-        >>> # vectorized gradient computation
-        >>> def get_vjp(v):
-        >>>     return torch.autograd.grad(y, x, v)
-        >>> jacobian = torch.vmap(get_vjp)(I_N)
-
-    .. note::
-        vmap does not provide general autobatching or handle variable-length
-        sequences out of the box.
+    Please use torch.vmap instead of this API.
     """
     warnings.warn(
-        "Please use functorch.vmap instead of torch.vmap "
-        "(https://github.com/pytorch/functorch). "
-        "We've moved development on torch.vmap over to functorch; "
-        "functorch's vmap has a multitude of significant performance and "
-        "functionality improvements.",
+        "Please use torch.vmap instead of torch._vmap_internals.vmap. ",
         stacklevel=2,
     )
     return _vmap(func, in_dims, out_dims)
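For callers of the legacy prototype, the change above amounts to a softer deprecation: the old entry point keeps working but now points at torch.vmap. A migration sketch, assuming a build that still ships torch._vmap_internals (later releases may drop it):

    import warnings
    import torch
    from torch._vmap_internals import vmap as legacy_vmap  # legacy prototype, per the diff above

    x = torch.randn(3, 5)

    # The legacy entry point still works but emits the warning added above.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        legacy_result = legacy_vmap(torch.sum)(x)
    assert any("Please use torch.vmap" in str(w.message) for w in caught)

    # Preferred going forward: the torch.vmap alias (same function as torch.func.vmap).
    assert torch.allclose(legacy_result, torch.vmap(torch.sum)(x))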
@@ -870,7 +870,7 @@ def _test_batched_grad(input, output, output_idx) -> bool:
     # NB: this doesn't work for CUDA tests: https://github.com/pytorch/pytorch/issues/50209
     with warnings.catch_warnings():
         warnings.filterwarnings("ignore", message="There is a performance drop")
-        warnings.filterwarnings("ignore", message="Please use functorch.vmap")
+        warnings.filterwarnings("ignore", message="Please use torch.vmap")
         try:
             result = vmap(vjp)(torch.stack(grad_outputs))
         except RuntimeError as ex:
@@ -238,6 +238,7 @@ def get_ignored_functions() -> Set[Callable]:
         torch.vitals_enabled,
         torch.set_vital,
         torch.read_vitals,
+        torch.vmap,
         torch.frombuffer,
         torch.asarray,
         Tensor.__delitem__,