[torch.func] alias torch.func.vmap as torch.vmap (#91026)

This PR also redirects torch.vmap to torch.func.vmap instead of the old
vmap prototype.

Test Plan:
- tests
- view docs preview
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91026
Approved by: https://github.com/albanD, https://github.com/samdow
This commit is contained in:
Richard Zou
2022-12-21 12:37:37 -05:00
committed by PyTorch MergeBot
parent e803d336eb
commit fb2e1878cb
8 changed files with 27 additions and 123 deletions

View File

@ -72,10 +72,11 @@ static void warnFallback(const c10::FunctionSchema& schema) {
}
TORCH_WARN("There is a performance drop because we have not yet implemented ",
"the batching rule for ", schema.operator_name(), ". ",
"We've moved development of vmap to to functorch "
"(https://github.com/pytorch/functorch), please try functorch.vmap "
"instead and/or file ",
" an issue on GitHub so that we can prioritize its implementation.");
"You are using the legacy vmap prototype (torch._vmap_internals.vmap). ",
"If you are using torch.autograd.functional.{jacobian, hessian} ",
"or torch._vmap_internals.vmap: please switch to using ",
"torch.func.{jacrev, jacfwd, hessian} and/or torch.vmap instead ",
"for better operator coverage and performance improvements .");
}
// The general flow of the algorithm is as follows.

View File

@ -3,7 +3,8 @@
from torch.testing._internal.common_utils import TestCase, run_tests
import torch
import torch.nn.functional as F
from torch import Tensor, vmap
from torch import Tensor
from torch._vmap_internals import vmap
import functools
import itertools
import warnings

View File

@ -48,7 +48,7 @@ __all__ = [
'set_deterministic_debug_mode', 'get_deterministic_debug_mode',
'set_float32_matmul_precision', 'get_float32_matmul_precision',
'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
'compile',
'compile', 'vmap',
]
################################################################################
@ -1116,8 +1116,6 @@ del register_after_fork
# torch.jit.script as a decorator, for instance):
from ._lobpcg import lobpcg as lobpcg
from ._vmap_internals import vmap as vmap
# These were previously defined in native_functions.yaml and appeared on the
# `torch` namespace, but we moved them to c10 dispatch to facilitate custom
# class usage. We add these lines here to preserve backward compatibility.
@ -1245,3 +1243,4 @@ if 'TORCH_CUDA_SANITIZER' in os.environ:
import torch.fx.experimental.symbolic_shapes
from torch import func as func
from torch.func import vmap

View File

@ -1259,8 +1259,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla
When composed with ``vmap``, ``grad`` can be used to compute per-sample-gradients:
>>> from torch.func import grad
>>> from torch.func import vmap
>>> from torch.func import grad, vmap
>>> batch_size, feature_size = 3, 5
>>>
>>> def model(weights, feature_vec):

View File

@ -255,6 +255,10 @@ def vmap(
take batches of examples with ``vmap(func)``. vmap can also be used to
compute batched gradients when composed with autograd.
.. note::
:func:`torch.vmap` is aliased to :func:`torch.func.vmap` for
convenience. Use whichever one you'd like.
Args:
func (function): A Python function that takes one or more arguments.
Must return one or more Tensors.
@ -308,7 +312,7 @@ def vmap(
>>> return feature_vec.dot(weights).relu()
>>>
>>> examples = torch.randn(batch_size, feature_size)
>>> result = torch.func.vmap(model)(examples)
>>> result = torch.vmap(model)(examples)
:func:`vmap` can also help vectorize computations that were previously difficult
or impossible to batch. One example is higher-order gradient computation.
@ -333,12 +337,12 @@ def vmap(
>>> # vectorized gradient computation
>>> def get_vjp(v):
>>> return torch.autograd.grad(y, x, v)
>>> jacobian = torch.func.vmap(get_vjp)(I_N)
>>> jacobian = torch.vmap(get_vjp)(I_N)
:func:`vmap` can also be nested, producing an output with multiple batched dimensions
>>> torch.dot # [D], [D] -> []
>>> batched_dot = torch.func.vmap(torch.func.vmap(torch.dot)) # [N1, N0, D], [N1, N0, D] -> [N1, N0]
>>> batched_dot = torch.vmap(torch.vmap(torch.dot)) # [N1, N0, D], [N1, N0, D] -> [N1, N0]
>>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5)
>>> batched_dot(x, y) # tensor of size [2, 3]
@ -346,7 +350,7 @@ def vmap(
the dimension that each inputs are batched along as
>>> torch.dot # [N], [N] -> []
>>> batched_dot = torch.func.vmap(torch.dot, in_dims=1) # [N, D], [N, D] -> [D]
>>> batched_dot = torch.vmap(torch.dot, in_dims=1) # [N, D], [N, D] -> [D]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y) # output is [5] instead of [2] if batched along the 0th dimension
@ -354,7 +358,7 @@ def vmap(
``in_dims`` must be a tuple with the batch dimension for each input as
>>> torch.dot # [D], [D] -> []
>>> batched_dot = torch.func.vmap(torch.dot, in_dims=(0, None)) # [N, D], [D] -> [N]
>>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None)) # [N, D], [D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> batched_dot(x, y) # second arg doesn't have a batch dim because in_dim[1] was None
@ -364,7 +368,7 @@ def vmap(
>>> f = lambda dict: torch.dot(dict['x'], dict['y'])
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> input = {'x': x, 'y': y}
>>> batched_dot = torch.func.vmap(f, in_dims=({'x': 0, 'y': None},))
>>> batched_dot = torch.vmap(f, in_dims=({'x': 0, 'y': None},))
>>> batched_dot(input)
By default, the output is batched along the first dimension. However, it can be batched
@ -372,17 +376,17 @@ def vmap(
>>> f = lambda x: x ** 2
>>> x = torch.randn(2, 5)
>>> batched_pow = torch.func.vmap(f, out_dims=1)
>>> batched_pow = torch.vmap(f, out_dims=1)
>>> batched_pow(x) # [5, 2]
For any function that uses kwargs, the returned function will not batch the kwargs but will
accept kwargs
>>> x = torch.randn([2, 5])
>>> def f(x, scale=4.):
>>> def fn(x, scale=4.):
>>> return x * scale
>>>
>>> batched_pow = torch.func.vmap(f)
>>> batched_pow = torch.vmap(fn)
>>> assert torch.allclose(batched_pow(x), x * 4)
>>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5]

View File

@ -191,111 +191,10 @@ def _get_name(func: Callable):
# on BatchedTensors perform the batched operations that the user is asking for.
def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
"""
vmap is the vectorizing map. Returns a new function that maps `func` over some
dimension of the inputs. Semantically, vmap pushes the map into PyTorch
operations called by `func`, effectively vectorizing those operations.
vmap is useful for handling batch dimensions: one can write a function `func`
that runs on examples and then lift it to a function that can take batches of
examples with `vmap(func)`. vmap can also be used to compute batched
gradients when composed with autograd.
.. note::
We have moved development of vmap to
`functorch. <https://github.com/pytorch/functorch>`_ functorch's
vmap is able to arbitrarily compose with gradient computation
and contains significant performance improvements.
Please give that a try if that is what you're looking for.
Furthermore, if you're interested in using vmap for your use case,
please `contact us! <https://github.com/pytorch/pytorch/issues/42368>`_
We're interested in gathering feedback from early adopters to inform
the design.
.. warning::
torch.vmap is an experimental prototype that is subject to
change and/or deletion. Please use at your own risk.
Args:
func (function): A Python function that takes one or more arguments.
Must return one or more Tensors.
in_dims (int or nested structure): Specifies which dimension of the
inputs should be mapped over. `in_dims` should have a structure
like the inputs. If the `in_dim` for a particular input is None,
then that indicates there is no map dimension. Default: 0.
out_dims (int or Tuple[int]): Specifies where the mapped dimension
should appear in the outputs. If `out_dims` is a Tuple, then it should
have one element per output. Default: 0.
Returns:
Returns a new "batched" function. It takes the same inputs as `func`,
except each input has an extra dimension at the index specified by `in_dims`.
It takes returns the same outputs as `func`, except each output has
an extra dimension at the index specified by `out_dims`.
.. warning:
vmap works best with functional-style code. Please do not perform any
side-effects in `func`, with the exception of in-place PyTorch operations.
Examples of side-effects include mutating Python data structures and
assigning values to variables not captured in `func`.
One example of using `vmap` is to compute batched dot products. PyTorch
doesn't provide a batched `torch.dot` API; instead of unsuccessfully
rummaging through docs, use `vmap` to construct a new function.
>>> torch.dot # [D], [D] -> []
>>> batched_dot = torch.vmap(torch.dot) # [N, D], [N, D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y)
`vmap` can be helpful in hiding batch dimensions, leading to a simpler
model authoring experience.
>>> batch_size, feature_size = 3, 5
>>> weights = torch.randn(feature_size, requires_grad=True)
>>>
>>> def model(feature_vec):
>>> # Very simple linear model with activation
>>> return feature_vec.dot(weights).relu()
>>>
>>> examples = torch.randn(batch_size, feature_size)
>>> result = torch.vmap(model)(examples)
`vmap` can also help vectorize computations that were previously difficult
or impossible to batch. One example is higher-order gradient computation.
The PyTorch autograd engine computes vjps (vector-Jacobian products).
Computing a full Jacobian matrix for some function f: R^N -> R^N usually
requires N calls to `autograd.grad`, one per Jacobian row. Using `vmap`,
we can vectorize the whole computation, computing the Jacobian in a single
call to `autograd.grad`.
>>> # Setup
>>> N = 5
>>> f = lambda x: x ** 2
>>> x = torch.randn(N, requires_grad=True)
>>> y = f(x)
>>> I_N = torch.eye(N)
>>>
>>> # Sequential approach
>>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
>>> for v in I_N.unbind()]
>>> jacobian = torch.stack(jacobian_rows)
>>>
>>> # vectorized gradient computation
>>> def get_vjp(v):
>>> return torch.autograd.grad(y, x, v)
>>> jacobian = torch.vmap(get_vjp)(I_N)
.. note::
vmap does not provide general autobatching or handle variable-length
sequences out of the box.
Please use torch.vmap instead of this API.
"""
warnings.warn(
"Please use functorch.vmap instead of torch.vmap "
"(https://github.com/pytorch/functorch). "
"We've moved development on torch.vmap over to functorch; "
"functorch's vmap has a multitude of significant performance and "
"functionality improvements.",
"Please use torch.vmap instead of torch._vmap_internals.vmap. ",
stacklevel=2,
)
return _vmap(func, in_dims, out_dims)

View File

@ -870,7 +870,7 @@ def _test_batched_grad(input, output, output_idx) -> bool:
# NB: this doesn't work for CUDA tests: https://github.com/pytorch/pytorch/issues/50209
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="There is a performance drop")
warnings.filterwarnings("ignore", message="Please use functorch.vmap")
warnings.filterwarnings("ignore", message="Please use torch.vmap")
try:
result = vmap(vjp)(torch.stack(grad_outputs))
except RuntimeError as ex:

View File

@ -238,6 +238,7 @@ def get_ignored_functions() -> Set[Callable]:
torch.vitals_enabled,
torch.set_vital,
torch.read_vitals,
torch.vmap,
torch.frombuffer,
torch.asarray,
Tensor.__delitem__,