Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[autograd][docs] Add more details on why save_for_backward is important in extending autograd note (#153005)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153005 Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: 50657120a0
Commit: 9d00f2b375
@@ -107,11 +107,25 @@ Take the following steps:
 properly in order to ensure that the new :class:`Function` works properly with
 the autograd engine.

-- :meth:`~torch.autograd.function.FunctionCtx.save_for_backward` must be
-  used to save any tensors to be used in the backward pass. Non-tensors should
-  be stored directly on `ctx`. If tensors that are neither input nor output
-  are saved for backward your :class:`~Function` may not support double backward
-  (see step 3).
+- :meth:`~torch.autograd.function.FunctionCtx.save_for_backward` should be
+  used to save any tensors needed for the backward pass (as opposed to
+  directly on ``ctx``). You cannot use ``save_for_backward`` for non-tensors;
+  you should store those directly on ``ctx``.
+
+  Saving tensors via ``save_for_backward``:
+
+  1. Allows the autograd engine to clear
+     them as soon as the backward computation of the ``autograd.Function`` completes.
+     (If a tensor is stored directly on ``ctx``
+     it will unnecessarily remain alive for the lifetime of the autograd graph --
+     typically until the end of the iteration.)
+  2. Helps avoid certain reference cycles, (e.g., since the tensor
+     output of the ``autograd.Function`` itself keeps a reference to the ctx).
+  3. Is important for compatibility with
+     features like activation checkpointing and offloading that rely on
+     :class:`torch.autograd.graph.saved_tensors_hooks`.
+
+  If tensors that are neither inputs nor outputs are saved for backward your
+  :class:`~Function` may not support double backward (see step 3).
 - :meth:`~torch.autograd.function.FunctionCtx.mark_dirty` must be used to
   mark any input that is modified inplace by the forward function.
 - :meth:`~torch.autograd.function.FunctionCtx.mark_non_differentiable` must
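For illustration only (not part of this commit), here is a minimal sketch of a custom ``autograd.Function`` that follows the rules described in the hunk above: the input tensor is routed through ``save_for_backward``, while a non-tensor scale factor is stored directly on ``ctx``. The ``Cube`` class and its names are hypothetical.

    import torch

    class Cube(torch.autograd.Function):
        # Hypothetical example: y = scale * x ** 3, where `scale` is a plain float.

        @staticmethod
        def forward(ctx, x, scale):
            # Tensor needed by backward: route it through save_for_backward so the
            # engine can free it right after backward and saved-tensor hooks see it.
            ctx.save_for_backward(x)
            # Non-tensor (a Python float): store it directly on ctx.
            ctx.scale = scale
            return ctx.scale * x ** 3

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            # d/dx (scale * x**3) = 3 * scale * x**2; `scale` gets no gradient.
            return grad_output * 3 * ctx.scale * x ** 2, None

    x = torch.randn(4, requires_grad=True)
    y = Cube.apply(x, 2.0)
    y.sum().backward()
    print(torch.allclose(x.grad, 6 * x.detach() ** 2))  # True

Because ``x`` goes through ``save_for_backward``, the engine can release it as soon as ``backward`` has run, and saved-tensor hooks (see the next hunk) can intercept it.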
@@ -42,6 +42,7 @@ class FunctionCtx:
 with ``save_for_backward`` (as opposed to directly on ``ctx``) to prevent
 incorrect gradients and memory leaks, and enable the application of saved
 tensor hooks. See :class:`torch.autograd.graph.saved_tensors_hooks`.
+See :ref:`extending-autograd` for more details.

 Note that if intermediary tensors, tensors that are neither inputs
 nor outputs of :func:`forward`, are saved for backward, your custom Function
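The docstring above points at :class:`torch.autograd.graph.saved_tensors_hooks`. As a hypothetical sketch (again not part of this commit) of why this matters: only tensors that pass through ``save_for_backward`` are handed to the pack/unpack hooks, so checkpointing or offloading built on top of them never sees tensors stashed directly on ``ctx``.

    import torch

    def pack(t):
        # A real checkpointing/offloading hook could move `t` to CPU here.
        print("saved for backward:", tuple(t.shape))
        return t

    def unpack(t):
        print("retrieved during backward:", tuple(t.shape))
        return t

    x = torch.randn(3, requires_grad=True)
    # Only tensors routed through save_for_backward (built-in ops use the same
    # mechanism internally) are handed to these hooks; tensors stored directly
    # on ctx in a custom Function bypass them entirely.
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        y = x.pow(3)        # pow saves its input for the backward formula
    y.sum().backward()      # unpack fires when the saved tensor is needed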