.. _func-autograd-function:

Extending torch.func with autograd.Function
===========================================

.. currentmodule:: torch.autograd

So you'd like to use :class:`torch.autograd.Function` with the :mod:`torch.func`
transforms like :func:`torch.vmap`, :func:`torch.func.grad`, etc.

There are two main use cases:

- you wish to call code that does not contain PyTorch operations and
  have it work with function transforms. That is, the :class:`torch.autograd.Function`'s
  forward/backward/etc calls into functions from other systems like C++, CUDA, numpy.
- you wish to specify custom gradient rules, like
  JAX's `custom_vjp/custom_jvp <https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html>`_

PyTorch combines both of these concepts into :class:`torch.autograd.Function`.

Basic Usage
-----------

This guide assumes you are familiar with :ref:`extending-autograd`,
which explains how to use :class:`torch.autograd.Function`.

:class:`torch.autograd.Function` can either have a :meth:`~Function.forward` that accepts a ctx object,
or it can have separate :meth:`~Function.forward` (that does not accept ``ctx``) and a :meth:`~Function.setup_context`
staticmethod that modifies the ``ctx`` object.

Only the latter is supported with function transforms:

- :meth:`~Function.forward` is the code that performs the operation and it should not accept
  a ``ctx`` object.
- ``setup_context(ctx, inputs, output)`` is the code where you can
  call methods on ``ctx``. Here is where you should save Tensors for backward
  (by calling ``ctx.save_for_backward(*tensors)``), or save non-Tensors
  (by assigning them to the ``ctx`` object).

Because :meth:`~Function.setup_context` accepts only ``inputs`` and ``output``,
the only quantities that can be saved are either objects (such as Tensors) in
the inputs or outputs or quantities (like ``Tensor.shape``) derived from them.
If you wish to save a non-input intermediate activation from
:meth:`Function.forward` for backward, then you'll need to return it as an
output from :meth:`~Function.forward` so that it gets passed to
:meth:`~Function.setup_context`.
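
As a quick illustration, here is a minimal sketch of this structure. The
``MyMul`` Function below is hypothetical and exists only to show where each
piece goes::

    class MyMul(torch.autograd.Function):
        # forward performs the computation; note that it does not take ctx.
        @staticmethod
        def forward(x, y):
            return x * y

        # setup_context only inspects inputs/output and stores what backward needs.
        @staticmethod
        def setup_context(ctx, inputs, output):
            x, y = inputs
            ctx.save_for_backward(x, y)

        @staticmethod
        def backward(ctx, grad_output):
            x, y = ctx.saved_tensors
            return grad_output * y, grad_output * x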

Depending on the transform,

- to support reverse-mode AD (:func:`torch.func.grad`, :func:`torch.func.vjp`),
  the :class:`torch.autograd.Function` needs a :meth:`~Function.backward` staticmethod.
- to support :func:`torch.vmap`, the :class:`torch.autograd.Function` needs a :meth:`~Function.vmap` staticmethod.
- to support :func:`torch.func.jvp`, the :class:`torch.autograd.Function` needs a :meth:`~Function.jvp` staticmethod.
- to support compositions of transforms (like :func:`torch.func.jacrev`,
  :func:`torch.func.jacfwd`, :func:`torch.func.hessian`) -- you may need multiple
  of the above.

In order for the :class:`torch.autograd.Function` to be arbitrarily composable with function
transforms, we recommend that all staticmethods other than :meth:`~Function.forward` and
:meth:`~Function.setup_context` be transformable: that is, they must consist only of PyTorch
operators or call into other :class:`torch.autograd.Function` (which may themselves call into C++/CUDA/etc).

Let's go over some examples of common use cases.

Example 1: autograd.Function calls into another system
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

A common case is a :class:`torch.autograd.Function` with both forward() and backward() calling
into another system (like C++, CUDA, numpy, triton).

::

    import torch
    import numpy as np

    def to_numpy(tensor):
        return tensor.cpu().numpy()

    class NumpySort(torch.autograd.Function):
        # Note that forward does not take ctx
        @staticmethod
        def forward(x, dim):
            device = x.device
            x = to_numpy(x)
            ind = np.argsort(x, axis=dim)
            ind_inv = np.argsort(ind, axis=dim)
            result = np.take_along_axis(x, ind, axis=dim)
            # Any intermediates to be saved in backward must be returned as
            # outputs.
            return (
                # The desired output
                torch.tensor(result, device=device),
                # intermediate to save for backward
                torch.tensor(ind, device=device),
                # intermediate to save for backward
                torch.tensor(ind_inv, device=device),
            )

        # setup_context is responsible for calling methods and/or assigning to
        # the ctx object. Please do not do additional compute (e.g. add
        # Tensors together) in setup_context.
        @staticmethod
        def setup_context(ctx, inputs, output):
            x, dim = inputs
            # Note that output is whatever you returned from forward.
            # If you returned multiple values, then output is a Tuple of multiple values.
            # If you returned a single Tensor, then output is a Tensor.
            # If you returned a Tuple with a single Tensor, then output is a
            # Tuple with a single Tensor.
            _, ind, ind_inv = output
            ctx.mark_non_differentiable(ind, ind_inv)
            # Tensors must be saved via ctx.save_for_backward. Please do not
            # assign them directly onto the ctx object.
            ctx.save_for_backward(ind, ind_inv)
            # Non-tensors may be saved by assigning them as attributes on the ctx object.
            ctx.dim = dim

        @staticmethod
        def backward(ctx, grad_output, _0, _1):
            # For the autograd.Function to be arbitrarily composable with function
            # transforms, all staticmethods other than forward and setup_context
            # must be implemented in a "transformable" way; that is, they must
            # only consist of PyTorch operations or autograd.Function.
            #
            # For example, this allows us to do double backwards and/or compute
            # second order gradients.
            #
            # We've written the backward pass of NumpySort in terms of another
            # autograd.Function, NumpyTake.
            ind, ind_inv = ctx.saved_tensors
            return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None

    class NumpyTake(torch.autograd.Function):
        @staticmethod
        def forward(x, ind, ind_inv, dim):
            device = x.device
            x = to_numpy(x)
            ind = to_numpy(ind)
            return torch.tensor(np.take_along_axis(x, ind, dim), device=device)

        @staticmethod
        def setup_context(ctx, inputs, output):
            x, ind, ind_inv, dim = inputs
            ctx.save_for_backward(ind, ind_inv)
            ctx.dim = dim

        @staticmethod
        def backward(ctx, grad_output):
            ind, ind_inv = ctx.saved_tensors
            result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
            return result, None, None, None

Now, to make it easier to use ``NumpySort`` (to hide away the intermediates we
returned as outputs, as well as allow default args and kwargs), we create a new
function that invokes it::

    def numpy_sort(x, dim=-1):
        result, _, _ = NumpySort.apply(x, dim)
        return result

And here's a sanity check::

    x = torch.randn(2, 3)
    grad_x = torch.func.grad(lambda x: numpy_sort(x).sum())(x)
    assert torch.allclose(grad_x, torch.ones_like(x))

Example 2: autograd.Function specifies custom gradient rules
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Another common case is a :class:`torch.autograd.Function` that is implemented with PyTorch
operations. PyTorch is able to compute gradients for PyTorch operations automatically,
but perhaps we wish to customize how the gradients are computed. Some reasons why
we may want a custom backward different from the one PyTorch gives us are:

- improving numeric stability
- changing the performance characteristics of the backward
- changing how edge cases are handled (e.g. nans, inf)
- modifying the gradient (e.g. gradient clipping)

Here's an example of a :class:`torch.autograd.Function` for the function ``y = x ** 3`` where we
change the performance characteristics (some computation that would normally happen
during the backward pass, computing dx, happens in the forward pass).

::

    class MyCube(torch.autograd.Function):
        @staticmethod
        def forward(x):
            result = x ** 3
            # In regular PyTorch, if we had just run y = x ** 3, then the backward
            # pass computes dx = 3 * x ** 2. In this autograd.Function, we've done
            # that computation here in the forward pass instead.
            dx = 3 * x ** 2
            return result, dx

        @staticmethod
        def setup_context(ctx, inputs, output):
            x, = inputs
            result, dx = output
            ctx.save_for_backward(x, dx)

        @staticmethod
        def backward(ctx, grad_output, grad_dx):
            x, dx = ctx.saved_tensors
            # In order for the autograd.Function to work with higher-order
            # gradients, we must add the gradient contribution of `dx`.
            result = grad_output * dx + grad_dx * 6 * x
            return result

Now, to make it easier to use ``MyCube`` (and hide away the intermediates we
returned as outputs) we create a new function that invokes it::

    def my_cube(x):
        result, _ = MyCube.apply(x)
        return result

Here's a sanity check computing the second-order gradients::

    x = torch.randn([])
    ggx = torch.func.grad(torch.func.grad(my_cube))(x)
    assert torch.allclose(ggx, 6 * x)

Limitations and gotchas
^^^^^^^^^^^^^^^^^^^^^^^

.. warning::

    Please read these limitations of :class:`torch.autograd.Function` with torch.func transforms
    carefully. We are not able to catch many of these situations and error out
    gracefully so they will lead to undefined behavior.

Please do not capture Tensors that are being transformed over, that have
requires_grad=True, or that are dual tensors into the methods of the
:class:`torch.autograd.Function`. The way to be completely safe is to ensure that the only
Tensors used inside any method of the :class:`torch.autograd.Function` are those passed
directly as inputs (or via the ctx object), rather than Tensors that come from outside
the :class:`torch.autograd.Function`.
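
For example, here is a sketch of the safe pattern; the hypothetical ``Scale``
Function and its ``scale`` argument stand in for any Tensor that is being
transformed over or that requires grad::

    class Scale(torch.autograd.Function):
        @staticmethod
        def forward(x, scale):
            # Good: `scale` is an explicit input.
            # Bad would be: `return x * some_captured_scale`, where
            # `some_captured_scale` is a Tensor from the enclosing scope.
            return x * scale

        @staticmethod
        def setup_context(ctx, inputs, output):
            x, scale = inputs
            ctx.save_for_backward(x, scale)

        @staticmethod
        def backward(ctx, grad_output):
            x, scale = ctx.saved_tensors
            return grad_output * scale, grad_output * x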

:class:`torch.autograd.Function` does not handle Tensors in pytrees (arbitrary nested
Python data structures that may or may not contain Tensors). For
those Tensors to be tracked by autograd, they must be passed directly as
an argument to :class:`torch.autograd.Function`. This is in contrast to
jax.{custom_vjp, custom_jvp}, which do accept pytrees.
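
One way to work around this is to flatten the pytree at the call site and pass
each Tensor as an explicit argument; the ``WeightedSum`` Function and
``weighted_sum`` wrapper below are a hypothetical sketch of this pattern::

    class WeightedSum(torch.autograd.Function):
        @staticmethod
        def forward(a, b):
            # Each Tensor is a separate argument, so autograd tracks both.
            return a.sum() + 2 * b.sum()

        @staticmethod
        def setup_context(ctx, inputs, output):
            a, b = inputs
            ctx.save_for_backward(a, b)

        @staticmethod
        def backward(ctx, grad_output):
            a, b = ctx.saved_tensors
            return grad_output * torch.ones_like(a), 2 * grad_output * torch.ones_like(b)

    def weighted_sum(tensor_dict):
        # Unpack the pytree (here, a dict) before calling into the Function.
        return WeightedSum.apply(tensor_dict["a"], tensor_dict["b"])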

Please only use :meth:`~torch.autograd.function.FunctionCtx.save_for_backward` or
:meth:`~torch.autograd.function.FunctionCtx.save_for_forward` to save Tensors.
Please do not assign Tensors or collections of Tensors directly onto the ctx object -
these Tensors will not get tracked by autograd.
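
Concretely, inside ``setup_context`` this looks like the following minimal sketch::

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, weight = inputs
        # Good: Tensors go through save_for_backward.
        ctx.save_for_backward(x, weight)
        # Bad: `ctx.x = x` or `ctx.tensors = [x, weight]` would hide these
        # Tensors from autograd and from the torch.func transforms.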

:func:`torch.vmap` Support
--------------------------

To use a :class:`torch.autograd.Function` with :func:`torch.vmap`, you must either:

- provide a :meth:`~Function.vmap` staticmethod that tells us the behavior of the :class:`torch.autograd.Function`
  under :func:`torch.vmap`
- ask us to autogenerate it by setting ``generate_vmap_rule=True``.

Automatically generate a vmap rule
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

If your :class:`torch.autograd.Function` fulfills the following additional constraints, then we
are able to generate a vmap rule for it. If it doesn't fulfill the constraints or if you
want custom behavior under vmap, please manually define a vmap staticmethod (see next section).

.. warning::

    We are not easily able to check for the following constraints and error
    out gracefully. Violation of the constraints may lead to undefined
    behavior.

- The :class:`torch.autograd.Function`'s :meth:`~Function.forward`, :meth:`~Function.backward` (if it exists) and :meth:`~Function.jvp`
  (if it exists) staticmethods must be transformable via :func:`torch.vmap`. That
  is, they must consist of only PyTorch operations (as opposed to e.g. NumPy or custom
  CUDA kernels).

Example::

    class MyCube(torch.autograd.Function):
        # Set generate_vmap_rule to True to ask PyTorch to automatically generate
        # a vmap rule.
        generate_vmap_rule = True

        @staticmethod
        def forward(x):
            result = x ** 3
            dx = 3 * x ** 2
            return result, dx

        @staticmethod
        def setup_context(ctx, inputs, output):
            x, = inputs
            result, dx = output
            ctx.save_for_backward(x, dx)

        @staticmethod
        def backward(ctx, grad_output, grad_dx):
            x, dx = ctx.saved_tensors
            result = grad_output * dx + grad_dx * 6 * x
            return result

    def my_cube(x):
        result, dx = MyCube.apply(x)
        return result

    x = torch.randn(3)
    result = torch.vmap(my_cube)(x)
    assert torch.allclose(result, x ** 3)

Defining the vmap staticmethod
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

If your :class:`torch.autograd.Function` calls into another system (like NumPy, C++, CUDA, triton),
then to get it to work with :func:`torch.vmap` or transforms that use it, you'll
need to manually define a :meth:`~Function.vmap` staticmethod.

Depending on what transforms you want to use and your use case, you may not need
to add a :meth:`~Function.vmap` staticmethod to all of your :class:`torch.autograd.Function`:

- For example, :func:`torch.func.jacrev` performs :func:`~torch.vmap` over the backward pass.
  So if you're only interested in using :func:`torch.func.jacrev`, only
  the :meth:`~Function.backward` staticmethod needs to be vmappable.

We do recommend ensuring all of your :class:`torch.autograd.Function` have support for
:func:`torch.vmap` though, especially if you are writing a third-party library and you want your
:class:`torch.autograd.Function` to work with all combinations of :func:`torch.func` transforms.

Conceptually, the vmap staticmethod is responsible for defining how the :meth:`~Function.forward`
should behave under :func:`torch.vmap`. That is, it defines how to transform
the :meth:`~Function.forward` to run over inputs with an additional dimension (the dimension
being vmapped over). This is similar to how :func:`torch.vmap` is implemented over
PyTorch operations: for each operation, we define a vmap rule (sometimes also
referred to as a "batching rule").

Here's how to define the :meth:`~Function.vmap` staticmethod:

- the signature is ``vmap(info, in_dims: Tuple[Optional[int]], *args)``, where
  ``*args`` is the same as the args to :meth:`~Function.forward`.
- The vmap staticmethod is responsible for defining how the :meth:`~Function.forward` should behave
  under :func:`torch.vmap`. That is, given inputs with an additional dimension
  (specified by ``in_dims``), how do we compute the batched version of :meth:`~Function.forward`?
- For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``.
  It is ``None`` if the arg is not a Tensor or if the arg is not being vmapped over,
  otherwise, it is an integer specifying what dimension of the Tensor is being vmapped
  over.
- ``info`` is a collection of additional metadata that may be helpful:
  ``info.batch_size`` specifies the size of the dimension being vmapped over, while
  ``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`.
- The return of the vmap staticmethod is a tuple of ``(output, out_dims)``. Similar
  to ``in_dims``, ``out_dims`` should be of the same structure as ``output`` and contain
  one ``out_dim`` per output that specifies if the output has the vmapped
  dimension and what index it is in.

Example::

    def to_numpy(tensor):
        return tensor.cpu().numpy()

    class NumpySort(torch.autograd.Function):
        @staticmethod
        def forward(x, dim):
            device = x.device
            x = to_numpy(x)
            ind = np.argsort(x, axis=dim)
            ind_inv = np.argsort(ind, axis=dim)
            result = np.take_along_axis(x, ind, axis=dim)
            return (
                torch.tensor(result, device=device),
                torch.tensor(ind, device=device),
                torch.tensor(ind_inv, device=device),
            )

        @staticmethod
        def setup_context(ctx, inputs, output):
            x, dim = inputs
            _, ind, ind_inv = output
            ctx.mark_non_differentiable(ind, ind_inv)
            ctx.save_for_backward(ind, ind_inv)
            ctx.dim = dim

        @staticmethod
        def backward(ctx, grad_output, _0, _1):
            ind, ind_inv = ctx.saved_tensors
            return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None

        # The signature of the vmap staticmethod is:
        # vmap(info, in_dims: Tuple[Optional[int]], *args)
        # where *args is the same as the arguments to `forward`.
        @staticmethod
        def vmap(info, in_dims, x, dim):
            # For every input (x and dim), in_dims stores an Optional[int]
            # that is:
            # - None if the input is not being vmapped over or if the input
            #   is not a Tensor
            # - an integer if the input is being vmapped over that represents
            #   the index of the dimension being vmapped over.
            x_bdim, _ = in_dims

            # A "vmap rule" is the logic of how to perform the operation given
            # inputs with one additional dimension. In NumpySort, x has an
            # additional dimension (x_bdim). The vmap rule is simply
            # to call NumpySort again but pass it a different `dim`.
            x = x.movedim(x_bdim, 0)
            # Handle negative dims correctly
            dim = dim if dim >= 0 else dim + x.dim() - 1

            result = NumpySort.apply(x, dim + 1)

            # The vmap rule must return a tuple of two things
            # 1. the output. Should be the same amount of things
            #    as returned by the forward().
            # 2. one Optional[int] for each output specifying if each output
            #    is being vmapped over, and if so, the index of the
            #    dimension being vmapped over.
            #
            # NumpySort.forward returns a Tuple of 3 Tensors. Since we moved the
            # dimension being vmapped over to the front of `x`, that appears at
            # dimension 0 of all outputs.
            # The return is (output, out_dims) -- output is a tuple of 3 Tensors
            # and out_dims is a Tuple of 3 Optional[int]
            return result, (0, 0, 0)

    class NumpyTake(torch.autograd.Function):
        @staticmethod
        def forward(x, ind, ind_inv, dim):
            device = x.device
            x = to_numpy(x)
            ind = to_numpy(ind)
            return torch.tensor(np.take_along_axis(x, ind, dim), device=device)

        @staticmethod
        def setup_context(ctx, inputs, output):
            x, ind, ind_inv, dim = inputs
            ctx.save_for_backward(ind, ind_inv)
            ctx.dim = dim

        @staticmethod
        def backward(ctx, grad_output):
            ind, ind_inv = ctx.saved_tensors
            result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
            return result, None, None, None

        @staticmethod
        def vmap(info, in_dims, x, ind, ind_inv, dim):
            x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims

            # The strategy is: expand {x, ind, ind_inv} to all have the dimension
            # being vmapped over.
            # Then, call back into NumpyTake(expanded_x, expanded_ind, expanded_ind_inv, new_dim).

            # Handle negative dims by wrapping them to be positive
            logical_dim = x.dim() if x_bdim is None else x.dim() - 1
            dim = dim if dim >= 0 else dim + logical_dim

            def maybe_expand_bdim_at_front(x, x_bdim):
                if x_bdim is None:
                    return x.expand(info.batch_size, *x.shape)
                return x.movedim(x_bdim, 0)

            # If the Tensor doesn't have the dimension being vmapped over,
            # expand it out. Otherwise, move it to the front of the Tensor
            x = maybe_expand_bdim_at_front(x, x_bdim)
            ind = maybe_expand_bdim_at_front(ind, ind_bdim)
            ind_inv = maybe_expand_bdim_at_front(ind_inv, ind_inv_bdim)

            # The return is a tuple (output, out_dims). Since output is a Tensor,
            # then out_dims is an Optional[int] (instead of being a Tuple).
            return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0

    def numpy_sort(x, dim=-1):
        result, _, _ = NumpySort.apply(x, dim)
        return result

    x = torch.randn(2, 3)
    result = torch.vmap(numpy_sort)(x)
    assert torch.allclose(result, numpy_sort(x, 1))

.. note::

    The vmap staticmethod should aim to preserve the semantics of the
    entire :class:`~torch.autograd.Function`. That is, (pseudocode) ``grad(vmap(MyFunc))``
    should be replaceable with ``grad(map(MyFunc))``.

    If your autograd.Function has any custom behavior in the backward pass, please
    keep this in mind.

.. note::

    It is a legitimate use case to write a custom vmap staticmethod for a
    :class:`~torch.autograd.Function` that PyTorch is able to generate a vmap
    rule for via ``generate_vmap_rule=True``. You may wish to do this if the
    generated vmap rule doesn't have the semantics you're looking for.

:func:`torch.func.jvp` Support
------------------------------

To support forward-mode AD, a :class:`torch.autograd.Function` must have a :meth:`~Function.jvp` staticmethod.
Please see :ref:`forward-ad-autograd-function` for details.
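
As a rough sketch of what this looks like (reusing ``MyCube`` from Example 2, and
assuming its ``setup_context`` also calls ``ctx.save_for_forward`` so that the saved
Tensors are available in :meth:`~Function.jvp`), the jvp staticmethod receives one
tangent per input and returns one tangent per output; see the linked page for the
exact requirements::

    class MyCube(torch.autograd.Function):
        @staticmethod
        def forward(x):
            result = x ** 3
            dx = 3 * x ** 2
            return result, dx

        @staticmethod
        def setup_context(ctx, inputs, output):
            x, = inputs
            result, dx = output
            ctx.save_for_backward(x, dx)
            # Tensors needed in jvp must also be saved via save_for_forward.
            ctx.save_for_forward(x, dx)

        @staticmethod
        def backward(ctx, grad_output, grad_dx):
            x, dx = ctx.saved_tensors
            return grad_output * dx + grad_dx * 6 * x

        # jvp receives the tangent of each input (here, just x) and returns
        # the tangent of each output (here, result and dx).
        @staticmethod
        def jvp(ctx, x_t):
            x, dx = ctx.saved_tensors
            # Tangents of (result, dx) = (x ** 3, 3 * x ** 2)
            return dx * x_t, 6 * x * x_t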