Added docs for context method mixins. Fixes issue #27365 (#28643)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/27365 .

This PR:
1. Makes Context method docs available.
2. Links [Extending torch autograd](https://pytorch.org/docs/stable/notes/extending.html#extending-torch-autograd) notes to Context method docs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28643

Differential Revision: D18170089

Pulled By: albanD

fbshipit-source-id: a1119ea8e2f8a71f0d1aadf416f2f98343aa9b7b
This commit is contained in:
Prasun Anand
2019-10-28 08:28:40 -07:00
committed by Facebook Github Bot
parent 0e86c99bfb
commit 4230132baf
2 changed files with 22 additions and 2 deletions

View File

@ -82,6 +82,13 @@ Tensor autograd functions
.. autoclass:: Function
:members:
Context method mixins
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
When creating a new :class:`Function`, the following methods are available to `ctx`.
.. autoclass:: torch.autograd.function._ContextMethodMixin
:members:
.. _grad-check:
Numerical gradient checking

View File

@ -13,8 +13,7 @@ Extending :mod:`torch.autograd`
Adding operations to :mod:`~torch.autograd` requires implementing a new
:class:`Function` subclass for each operation. Recall that :class:`Function` s
are what :mod:`~torch.autograd` uses to compute the results and gradients, and
encode the operation history. Every new function requires you to implement 2
methods:
encode the operation history. Every new function requires you to implement 2 methods:
- :meth:`~Function.forward` - the code that performs the operation. It can take
as many arguments as you want, with some of them being optional, if you
@ -39,6 +38,20 @@ methods:
arguments to :meth:`~Function.forward` you can return more gradients than there
were inputs, as long as they're all :any:`python:None`.
.. note::
It's the user's responsibility to use the special functions in the forward's `ctx`
properly in order to ensure that the new :class:`Function` works properly with
the autograd engine.
- :meth:`~torch.autograd.function._ContextMethodMixin.save_for_backward` must be
used when saving input or ouput of the forward to be used later in the backward.
- :meth:`~torch.autograd.function._ContextMethodMixin.mark_dirty` must be used to
marked any input that is modified inplace by the forward function.
- :meth:`~torch.autograd.function._ContextMethodMixin.mark_non_differentiable` must
be used to tell the engine if an output is not differentiable.
Below you can find code for a ``Linear`` function from :mod:`torch.nn`, with
additional comments::