Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Beef up vmap docs and expose to master documentation (#44825)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44825

Test Plan: build and view docs locally.

Reviewed By: ezyang

Differential Revision: D23742727

Pulled By: zou3519

fbshipit-source-id: f62b7a76b5505d3387b7816c514c086c01089de0
Committed by: Facebook GitHub Bot
Parent: c2cf6efd96
Commit: 6d312132e1
@@ -161,9 +161,19 @@ def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
 operations called by `func`, effectively vectorizing those operations.
 
 vmap is useful for handling batch dimensions: one can write a function `func`
-that runs on examples and the lift it to a function that can take batches of
-examples with `vmap(func)`. Furthermore, it is possible to use vmap to obtain
-batched gradients when composed with autograd.
+that runs on examples and then lift it to a function that can take batches of
+examples with `vmap(func)`. vmap can also be used to compute batched
+gradients when composed with autograd.
+
+.. warning::
+    torch.vmap is an experimental prototype that is subject to
+    change and/or deletion. Please use at your own risk.
+
+.. note::
+    If you're interested in using vmap for your use case, please
+    `contact us! <https://github.com/pytorch/pytorch/issues/42368>`_
+    We're interested in gathering feedback from early adopters to inform
+    the design.
 
 Args:
     func (function): A Python function that takes one or more arguments.
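Note on the signature shown in the hunk header: the examples this commit adds
all use the default `in_dims=0` / `out_dims=0`. As a minimal sketch of the
non-default case, assuming the prototype accepts an integer or a per-input
tuple of integers for `in_dims` (as the `in_dims_t` annotation suggests) and an
integer for `out_dims`, a batched matrix-vector product might look like:

    >>> import torch
    >>> # Default: map over dim 0 of each input; batch dim 0 in the output.
    >>> A, v = torch.randn(3, 4, 5), torch.randn(3, 5)
    >>> torch.vmap(torch.mv)(A, v).shape     # [B, N, M], [B, M] -> [B, N]
    torch.Size([3, 4])
    >>> # in_dims picks which dim of each input is the batch dim; out_dims
    >>> # picks where the batch dim lands in the output.
    >>> A2, v2 = torch.randn(4, 3, 5), torch.randn(5, 3)
    >>> torch.vmap(torch.mv, in_dims=(1, 1), out_dims=1)(A2, v2).shape
    torch.Size([4, 3])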
@@ -188,9 +198,56 @@ def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
 Examples of side-effects include mutating Python data structures and
 assigning values to variables not captured in `func`.
 
-.. warning::
-    torch.vmap is an experimental prototype that is subject to
-    change and/or deletion. Please use at your own risk.
+One example of using `vmap` is to compute batched dot products. PyTorch
+doesn't provide a batched `torch.dot` API; instead of unsuccessfully
+rummaging through docs, use `vmap` to construct a new function.
+
+    >>> torch.dot                            # [D], [D] -> []
+    >>> batched_dot = torch.vmap(torch.dot)  # [N, D], [N, D] -> [N]
+    >>> x, y = torch.randn(2, 5), torch.randn(2, 5)
+    >>> batched_dot(x, y)
+
+`vmap` can be helpful in hiding batch dimensions, leading to a simpler
+model authoring experience.
+
+    >>> batch_size, feature_size = 3, 5
+    >>> weights = torch.randn(feature_size, requires_grad=True)
+    >>>
+    >>> def model(feature_vec):
+    >>>     # Very simple linear model with activation
+    >>>     return feature_vec.dot(weights).relu()
+    >>>
+    >>> examples = torch.randn(batch_size, feature_size)
+    >>> result = torch.vmap(model)(examples)
+
+`vmap` can also help vectorize computations that were previously difficult
+or impossible to batch. One example is higher-order gradient computation.
+The PyTorch autograd engine computes vjps (vector-Jacobian products).
+Computing a full Jacobian matrix for some function f: R^N -> R^N usually
+requires N calls to `autograd.grad`, one per Jacobian row. Using `vmap`,
+we can vectorize the whole computation, computing the Jacobian in a single
+call to `autograd.grad`.
+
+    >>> # Setup
+    >>> N = 5
+    >>> f = lambda x: x ** 2
+    >>> x = torch.randn(N, requires_grad=True)
+    >>> y = f(x)
+    >>> I_N = torch.eye(N)
+    >>>
+    >>> # Sequential approach
+    >>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
+    >>>                  for v in I_N.unbind()]
+    >>> jacobian = torch.stack(jacobian_rows)
+    >>>
+    >>> # vectorized gradient computation
+    >>> def get_vjp(v):
+    >>>     return torch.autograd.grad(y, x, v)
+    >>> jacobian = torch.vmap(get_vjp)(I_N)
+
+.. note::
+    vmap does not provide general autobatching or handle variable-length
+    sequences out of the box.
 """
 warnings.warn(
     'torch.vmap is an experimental prototype that is subject to '
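The model example above creates `weights` with `requires_grad=True` but never
differentiates through it, and the Jacobian example differentiates only with
respect to the input `x`. To make the first hunk's claim about "batched
gradients when composed with autograd" concrete, here is a hedged sketch of
per-sample gradients using the same vjp-over-basis-vectors trick; `outputs`,
`get_vjp`, and `per_sample_grads` are illustrative names, not part of the
commit:

    >>> import torch
    >>> batch_size, feature_size = 3, 5
    >>> weights = torch.randn(feature_size, requires_grad=True)
    >>> examples = torch.randn(batch_size, feature_size)
    >>> outputs = (examples @ weights).relu()  # [batch_size], one value per example
    >>>
    >>> # The vjp against the i-th standard basis vector is the gradient of
    >>> # outputs[i] with respect to weights, i.e. one per-sample gradient.
    >>> def get_vjp(v):
    >>>     return torch.autograd.grad(outputs, weights, v)[0]
    >>> per_sample_grads = torch.vmap(get_vjp)(torch.eye(batch_size))
    >>> per_sample_grads.shape                 # one gradient row per example
    torch.Size([3, 5])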