Summary: Fixes https://github.com/pytorch/pytorch/issues/27365. This PR:
1. Makes the Context method docs available.
2. Links the [Extending torch autograd](https://pytorch.org/docs/stable/notes/extending.html#extending-torch-autograd) notes to the Context method docs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/28643
Differential Revision: D18170089
Pulled By: albanD
fbshipit-source-id: a1119ea8e2f8a71f0d1aadf416f2f98343aa9b7b
committed by Facebook Github Bot
parent 0e86c99bfb
commit 4230132baf
@@ -82,6 +82,13 @@ Tensor autograd functions

.. autoclass:: Function
    :members:

Context method mixins
^^^^^^^^^^^^^^^^^^^^^

When creating a new :class:`Function`, the following methods are available to `ctx`.

.. autoclass:: torch.autograd.function._ContextMethodMixin
    :members:

.. _grad-check:

Numerical gradient checking
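
As an editorial aside (not part of the commit), a minimal sketch of a custom
:class:`Function` shows the `ctx` helpers that the new autoclass entry documents;
the class ``Exp`` and the usage lines are illustrative::

    import torch
    from torch.autograd import Function

    class Exp(Function):
        @staticmethod
        def forward(ctx, input):
            result = input.exp()
            # one of the _ContextMethodMixin helpers documented above
            ctx.save_for_backward(result)
            return result

        @staticmethod
        def backward(ctx, grad_output):
            result, = ctx.saved_tensors  # retrieve what forward saved
            return grad_output * result

    x = torch.randn(3, requires_grad=True)
    y = Exp.apply(x)
    y.sum().backward()  # d/dx exp(x) = exp(x), computed via the saved result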
@@ -13,8 +13,7 @@ Extending :mod:`torch.autograd`

Adding operations to :mod:`~torch.autograd` requires implementing a new
:class:`Function` subclass for each operation. Recall that :class:`Function` s
are what :mod:`~torch.autograd` uses to compute the results and gradients, and
-encode the operation history. Every new function requires you to implement 2
-methods:
+encode the operation history. Every new function requires you to implement 2 methods:

- :meth:`~Function.forward` - the code that performs the operation. It can take
  as many arguments as you want, with some of them being optional, if you
@@ -39,6 +38,20 @@ methods:
  arguments to :meth:`~Function.forward` you can return more gradients than there
  were inputs, as long as they're all :any:`python:None`.
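
As an editorial sketch (not part of the commit), the rule above about
:any:`python:None` gradients is easiest to see with a :class:`Function` that
takes a non-Tensor argument; the name ``MulConstant`` is illustrative::

    import torch
    from torch.autograd import Function

    class MulConstant(Function):
        @staticmethod
        def forward(ctx, tensor, constant):
            # non-Tensor arguments can be stashed directly on ctx
            ctx.constant = constant
            return tensor * constant

        @staticmethod
        def backward(ctx, grad_output):
            # one gradient per forward argument; None for the non-Tensor one
            return grad_output * ctx.constant, None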
.. note::

    It's the user's responsibility to use the special functions in the forward's `ctx`
    properly in order to ensure that the new :class:`Function` works properly with
    the autograd engine.

    - :meth:`~torch.autograd.function._ContextMethodMixin.save_for_backward` must be
      used when saving inputs or outputs of the forward to be used later in the backward.
    - :meth:`~torch.autograd.function._ContextMethodMixin.mark_dirty` must be used to
      mark any input that is modified in place by the forward function.
    - :meth:`~torch.autograd.function._ContextMethodMixin.mark_non_differentiable` must
      be used to tell the engine if an output is not differentiable.
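
An editorial sketch (not part of the commit) of the last two helpers from the
list above; the class names ``AddInPlace`` and ``Sort`` are illustrative::

    import torch
    from torch.autograd import Function

    class AddInPlace(Function):
        @staticmethod
        def forward(ctx, x, y):
            x.add_(y)          # forward modifies x in place...
            ctx.mark_dirty(x)  # ...so x must be marked dirty
            return x

        @staticmethod
        def backward(ctx, grad_output):
            # d(x + y)/dx = d(x + y)/dy = 1
            return grad_output, grad_output

    class Sort(Function):
        @staticmethod
        def forward(ctx, input):  # 1-D input assumed, for simplicity
            values, indices = torch.sort(input)
            ctx.mark_non_differentiable(indices)  # integer output: no gradient
            ctx.save_for_backward(indices)
            return values, indices

        @staticmethod
        def backward(ctx, grad_values, grad_indices):
            # grad_indices is None because indices was marked non-differentiable
            indices, = ctx.saved_tensors
            grad_input = torch.empty_like(grad_values)
            grad_input[indices] = grad_values  # scatter gradients back
            return grad_input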
Below you can find code for a ``Linear`` function from :mod:`torch.nn`, with
additional comments::
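
    # [Editorial note: the code block is truncated in this extract; the sketch
    # below reconstructs the LinearFunction example along the lines of the
    # published extending notes, not verbatim from the commit.]
    import torch
    from torch.autograd import Function

    class LinearFunction(Function):

        # Note that both forward and backward are @staticmethods
        @staticmethod
        # bias is an optional argument
        def forward(ctx, input, weight, bias=None):
            ctx.save_for_backward(input, weight, bias)
            output = input.mm(weight.t())
            if bias is not None:
                output += bias.unsqueeze(0).expand_as(output)
            return output

        # This function has only a single output, so it gets only one gradient
        @staticmethod
        def backward(ctx, grad_output):
            # Unpack saved tensors and initialize all gradients w.r.t. inputs
            # to None; trailing Nones for optional inputs are ignored.
            input, weight, bias = ctx.saved_tensors
            grad_input = grad_weight = grad_bias = None

            # The needs_input_grad checks are optional and only improve
            # efficiency; returning gradients for inputs that don't require
            # them is not an error.
            if ctx.needs_input_grad[0]:
                grad_input = grad_output.mm(weight)
            if ctx.needs_input_grad[1]:
                grad_weight = grad_output.t().mm(input)
            if bias is not None and ctx.needs_input_grad[2]:
                grad_bias = grad_output.sum(0)

            return grad_input, grad_weight, grad_bias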