[torch.func] alias torch.func.vmap as torch.vmap (#91026)

This PR also redirects torch.vmap to torch.func.vmap instead of to the old
vmap prototype (torch._vmap_internals.vmap).
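
As a quick illustration (not part of the diff below, and assuming a build that
includes this commit), the alias means both names resolve to the same function
object:

    import torch

    # torch.vmap is re-exported from torch.func, so the two names are
    # the same object, not separate implementations.
    assert torch.vmap is torch.func.vmap

    # Smoke test: a batched dot product through the new alias.
    x, y = torch.randn(2, 5), torch.randn(2, 5)
    assert torch.vmap(torch.dot)(x, y).shape == (2,)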

Test Plan:
- tests
- view docs preview
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91026
Approved by: https://github.com/albanD, https://github.com/samdow
Author: Richard Zou
Date: 2022-12-21 12:37:37 -05:00
Committed by: PyTorch MergeBot
Commit: fb2e1878cb (parent e803d336eb)

8 changed files with 27 additions and 123 deletions


@@ -72,10 +72,11 @@ static void warnFallback(const c10::FunctionSchema& schema) {
   }
   TORCH_WARN("There is a performance drop because we have not yet implemented ",
              "the batching rule for ", schema.operator_name(), ". ",
-             "We've moved development of vmap to functorch "
-             "(https://github.com/pytorch/functorch), please try functorch.vmap "
-             "instead and/or file ",
-             "an issue on GitHub so that we can prioritize its implementation.");
+             "You are using the legacy vmap prototype (torch._vmap_internals.vmap). ",
+             "If you are using torch.autograd.functional.{jacobian, hessian} ",
+             "or torch._vmap_internals.vmap: please switch to using ",
+             "torch.func.{jacrev, jacfwd, hessian} and/or torch.vmap instead ",
+             "for better operator coverage and performance improvements.");
 }
 
 // The general flow of the algorithm is as follows.
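
The rewritten warning points users at the torch.func APIs. A hedged sketch of
the suggested migration (the function f and input x are illustrative):

    import torch

    def f(x):
        return x.sin().sum()

    x = torch.randn(3)

    # Legacy path: vectorize=True routes through the old vmap prototype and
    # can hit the slow fallback that emits the warning above.
    jac_legacy = torch.autograd.functional.jacobian(f, x, vectorize=True)

    # Replacement suggested by the new message.
    jac_new = torch.func.jacrev(f)(x)
    assert torch.allclose(jac_legacy, jac_new)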


@@ -3,7 +3,8 @@
 from torch.testing._internal.common_utils import TestCase, run_tests
 import torch
 import torch.nn.functional as F
-from torch import Tensor, vmap
+from torch import Tensor
+from torch._vmap_internals import vmap
 import functools
 import itertools
 import warnings
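
With torch.vmap now pointing at torch.func.vmap, this legacy test has to
import the prototype explicitly. A sketch of the distinction (legacy_vmap is
an illustrative alias, not a name from the diff):

    import torch
    from torch._vmap_internals import vmap as legacy_vmap

    assert torch.vmap is torch.func.vmap   # the new public alias
    assert legacy_vmap is not torch.vmap   # the old prototype under test here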


@@ -48,7 +48,7 @@ __all__ = [
     'set_deterministic_debug_mode', 'get_deterministic_debug_mode',
     'set_float32_matmul_precision', 'get_float32_matmul_precision',
     'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
-    'compile',
+    'compile', 'vmap',
 ]
 
 ################################################################################
@@ -1116,8 +1116,6 @@ del register_after_fork
 # torch.jit.script as a decorator, for instance):
 from ._lobpcg import lobpcg as lobpcg
-from ._vmap_internals import vmap as vmap
 # These were previously defined in native_functions.yaml and appeared on the
 # `torch` namespace, but we moved them to c10 dispatch to facilitate custom
 # class usage. We add these lines here to preserve backward compatibility.
@@ -1245,3 +1243,4 @@ if 'TORCH_CUDA_SANITIZER' in os.environ:
 import torch.fx.experimental.symbolic_shapes
 from torch import func as func
+from torch.func import vmap


@@ -1259,8 +1259,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
     When composed with ``vmap``, ``grad`` can be used to compute per-sample-gradients:
 
-        >>> from torch.func import grad
-        >>> from torch.func import vmap
+        >>> from torch.func import grad, vmap
         >>> batch_size, feature_size = 3, 5
         >>>
        >>> def model(weights, feature_vec):
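
The doctest is cut off here; a self-contained sketch of the full
per-sample-gradient pattern it introduces (loss_fn and targets are
illustrative additions):

    import torch
    from torch.func import grad, vmap

    batch_size, feature_size = 3, 5
    weights = torch.randn(feature_size)

    def model(weights, feature_vec):
        # Very simple linear model with activation.
        return feature_vec.dot(weights).relu()

    def loss_fn(weights, feature_vec, target):
        return (model(weights, feature_vec) - target) ** 2

    examples = torch.randn(batch_size, feature_size)
    targets = torch.randn(batch_size)

    # grad differentiates w.r.t. weights (argnums=0 by default); vmap maps over
    # the batch of (example, target) pairs while weights stay unbatched.
    per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(weights, examples, targets)
    assert per_sample_grads.shape == (batch_size, feature_size)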


@@ -255,6 +255,10 @@ def vmap(
     take batches of examples with ``vmap(func)``. vmap can also be used to
     compute batched gradients when composed with autograd.
 
+    .. note::
+        :func:`torch.vmap` is aliased to :func:`torch.func.vmap` for
+        convenience. Use whichever one you'd like.
+
     Args:
         func (function): A Python function that takes one or more arguments.
             Must return one or more Tensors.
@@ -308,7 +312,7 @@ def vmap(
         >>>     return feature_vec.dot(weights).relu()
         >>>
         >>> examples = torch.randn(batch_size, feature_size)
-        >>> result = torch.func.vmap(model)(examples)
+        >>> result = torch.vmap(model)(examples)
 
     :func:`vmap` can also help vectorize computations that were previously difficult
     or impossible to batch. One example is higher-order gradient computation.
@@ -333,12 +337,12 @@ def vmap(
         >>> # vectorized gradient computation
         >>> def get_vjp(v):
         >>>     return torch.autograd.grad(y, x, v)
-        >>> jacobian = torch.func.vmap(get_vjp)(I_N)
+        >>> jacobian = torch.vmap(get_vjp)(I_N)
 
     :func:`vmap` can also be nested, producing an output with multiple batched dimensions
 
         >>> torch.dot # [D], [D] -> []
-        >>> batched_dot = torch.func.vmap(torch.func.vmap(torch.dot)) # [N1, N0, D], [N1, N0, D] -> [N1, N0]
+        >>> batched_dot = torch.vmap(torch.vmap(torch.dot)) # [N1, N0, D], [N1, N0, D] -> [N1, N0]
         >>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5)
         >>> batched_dot(x, y) # tensor of size [2, 3]
@@ -346,7 +350,7 @@ def vmap(
     the dimension that each input is batched along as
 
         >>> torch.dot # [N], [N] -> []
-        >>> batched_dot = torch.func.vmap(torch.dot, in_dims=1) # [N, D], [N, D] -> [D]
+        >>> batched_dot = torch.vmap(torch.dot, in_dims=1) # [N, D], [N, D] -> [D]
         >>> x, y = torch.randn(2, 5), torch.randn(2, 5)
         >>> batched_dot(x, y) # output is [5] instead of [2] if batched along the 0th dimension
@@ -354,7 +358,7 @@ def vmap(
     ``in_dims`` must be a tuple with the batch dimension for each input as
 
         >>> torch.dot # [D], [D] -> []
-        >>> batched_dot = torch.func.vmap(torch.dot, in_dims=(0, None)) # [N, D], [D] -> [N]
+        >>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None)) # [N, D], [D] -> [N]
         >>> x, y = torch.randn(2, 5), torch.randn(5)
         >>> batched_dot(x, y) # second arg doesn't have a batch dim because in_dim[1] was None
@@ -364,7 +368,7 @@ def vmap(
         >>> f = lambda dict: torch.dot(dict['x'], dict['y'])
         >>> x, y = torch.randn(2, 5), torch.randn(5)
         >>> input = {'x': x, 'y': y}
-        >>> batched_dot = torch.func.vmap(f, in_dims=({'x': 0, 'y': None},))
+        >>> batched_dot = torch.vmap(f, in_dims=({'x': 0, 'y': None},))
         >>> batched_dot(input)
 
     By default, the output is batched along the first dimension. However, it can be batched
@@ -372,17 +376,17 @@ def vmap(
         >>> f = lambda x: x ** 2
         >>> x = torch.randn(2, 5)
-        >>> batched_pow = torch.func.vmap(f, out_dims=1)
+        >>> batched_pow = torch.vmap(f, out_dims=1)
         >>> batched_pow(x) # [5, 2]
 
     For any function that uses kwargs, the returned function will not batch the kwargs but will
     accept kwargs
 
         >>> x = torch.randn([2, 5])
-        >>> def f(x, scale=4.):
+        >>> def fn(x, scale=4.):
         >>>     return x * scale
         >>>
-        >>> batched_pow = torch.func.vmap(f)
+        >>> batched_pow = torch.vmap(fn)
         >>> assert torch.allclose(batched_pow(x), x * 4)
         >>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5]
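
A runnable version of the vmap-of-vjp Jacobian doctest above; the final diag
check is an added sanity assertion for f(x) = x ** 2, not part of the diff:

    import torch

    N = 5
    f = lambda x: x ** 2
    x = torch.randn(N, requires_grad=True)
    y = f(x)
    I_N = torch.eye(N)

    def get_vjp(v):
        # Each vjp against a one-hot vector recovers one row of the Jacobian.
        return torch.autograd.grad(y, x, v)

    # torch.vmap turns N separate autograd.grad calls into one batched call.
    jacobian, = torch.vmap(get_vjp)(I_N)
    assert torch.allclose(jacobian, torch.diag(2 * x))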


@@ -191,111 +191,10 @@ def _get_name(func: Callable):
 # on BatchedTensors perform the batched operations that the user is asking for.
 def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
     """
-    vmap is the vectorizing map. Returns a new function that maps `func` over some
-    dimension of the inputs. Semantically, vmap pushes the map into PyTorch
-    operations called by `func`, effectively vectorizing those operations.
-
-    vmap is useful for handling batch dimensions: one can write a function `func`
-    that runs on examples and then lift it to a function that can take batches of
-    examples with `vmap(func)`. vmap can also be used to compute batched
-    gradients when composed with autograd.
-
-    .. note::
-        We have moved development of vmap to
-        `functorch <https://github.com/pytorch/functorch>`_. functorch's
-        vmap is able to arbitrarily compose with gradient computation
-        and contains significant performance improvements.
-        Please give that a try if that is what you're looking for.
-
-        Furthermore, if you're interested in using vmap for your use case,
-        please `contact us! <https://github.com/pytorch/pytorch/issues/42368>`_
-        We're interested in gathering feedback from early adopters to inform
-        the design.
-
-    .. warning::
-        torch.vmap is an experimental prototype that is subject to
-        change and/or deletion. Please use at your own risk.
-
-    Args:
-        func (function): A Python function that takes one or more arguments.
-            Must return one or more Tensors.
-        in_dims (int or nested structure): Specifies which dimension of the
-            inputs should be mapped over. `in_dims` should have a structure
-            like the inputs. If the `in_dim` for a particular input is None,
-            then that indicates there is no map dimension. Default: 0.
-        out_dims (int or Tuple[int]): Specifies where the mapped dimension
-            should appear in the outputs. If `out_dims` is a Tuple, then it should
-            have one element per output. Default: 0.
-
-    Returns:
-        Returns a new "batched" function. It takes the same inputs as `func`,
-        except each input has an extra dimension at the index specified by `in_dims`.
-        It returns the same outputs as `func`, except each output has
-        an extra dimension at the index specified by `out_dims`.
-
-    .. warning::
-        vmap works best with functional-style code. Please do not perform any
-        side-effects in `func`, with the exception of in-place PyTorch operations.
-        Examples of side-effects include mutating Python data structures and
-        assigning values to variables not captured in `func`.
-
-    One example of using `vmap` is to compute batched dot products. PyTorch
-    doesn't provide a batched `torch.dot` API; instead of unsuccessfully
-    rummaging through docs, use `vmap` to construct a new function.
-
-        >>> torch.dot # [D], [D] -> []
-        >>> batched_dot = torch.vmap(torch.dot) # [N, D], [N, D] -> [N]
-        >>> x, y = torch.randn(2, 5), torch.randn(2, 5)
-        >>> batched_dot(x, y)
-
-    `vmap` can be helpful in hiding batch dimensions, leading to a simpler
-    model authoring experience.
-
-        >>> batch_size, feature_size = 3, 5
-        >>> weights = torch.randn(feature_size, requires_grad=True)
-        >>>
-        >>> def model(feature_vec):
-        >>>     # Very simple linear model with activation
-        >>>     return feature_vec.dot(weights).relu()
-        >>>
-        >>> examples = torch.randn(batch_size, feature_size)
-        >>> result = torch.vmap(model)(examples)
-
-    `vmap` can also help vectorize computations that were previously difficult
-    or impossible to batch. One example is higher-order gradient computation.
-    The PyTorch autograd engine computes vjps (vector-Jacobian products).
-    Computing a full Jacobian matrix for some function f: R^N -> R^N usually
-    requires N calls to `autograd.grad`, one per Jacobian row. Using `vmap`,
-    we can vectorize the whole computation, computing the Jacobian in a single
-    call to `autograd.grad`.
-
-        >>> # Setup
-        >>> N = 5
-        >>> f = lambda x: x ** 2
-        >>> x = torch.randn(N, requires_grad=True)
-        >>> y = f(x)
-        >>> I_N = torch.eye(N)
-        >>>
-        >>> # Sequential approach
-        >>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
-        >>>                  for v in I_N.unbind()]
-        >>> jacobian = torch.stack(jacobian_rows)
-        >>>
-        >>> # vectorized gradient computation
-        >>> def get_vjp(v):
-        >>>     return torch.autograd.grad(y, x, v)
-        >>> jacobian = torch.vmap(get_vjp)(I_N)
-
-    .. note::
-        vmap does not provide general autobatching or handle variable-length
-        sequences out of the box.
+    Please use torch.vmap instead of this API.
     """
     warnings.warn(
-        "Please use functorch.vmap instead of torch.vmap "
-        "(https://github.com/pytorch/functorch). "
-        "We've moved development on torch.vmap over to functorch; "
-        "functorch's vmap has a multitude of significant performance and "
-        "functionality improvements.",
+        "Please use torch.vmap instead of torch._vmap_internals.vmap.",
         stacklevel=2,
     )
     return _vmap(func, in_dims, out_dims)
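
A sketch of the resulting deprecation behavior (the warning capture is for
illustration only, not part of the diff):

    import warnings
    import torch
    from torch._vmap_internals import vmap as legacy_vmap

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        batched_dot = legacy_vmap(torch.dot)  # emits the UserWarning above
    assert any("Please use torch.vmap" in str(w.message) for w in caught)

    # The returned function still works; only the entry point is deprecated.
    x, y = torch.randn(2, 5), torch.randn(2, 5)
    assert batched_dot(x, y).shape == (2,)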


@@ -870,7 +870,7 @@ def _test_batched_grad(input, output, output_idx) -> bool:
     # NB: this doesn't work for CUDA tests: https://github.com/pytorch/pytorch/issues/50209
     with warnings.catch_warnings():
         warnings.filterwarnings("ignore", message="There is a performance drop")
-        warnings.filterwarnings("ignore", message="Please use functorch.vmap")
+        warnings.filterwarnings("ignore", message="Please use torch.vmap")
         try:
             result = vmap(vjp)(torch.stack(grad_outputs))
         except RuntimeError as ex:


@@ -238,6 +238,7 @@ def get_ignored_functions() -> Set[Callable]:
         torch.vitals_enabled,
         torch.set_vital,
         torch.read_vitals,
+        torch.vmap,
         torch.frombuffer,
         torch.asarray,
         Tensor.__delitem__,
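
This adds the new alias to the set of callables exempt from __torch_function__
dispatch. A quick check, as a sketch:

    import torch
    from torch.overrides import get_ignored_functions

    # torch.vmap is listed as ignored, i.e. it is not overridable
    # via __torch_function__.
    assert torch.vmap in get_ignored_functions()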