[autograd][docs] Add more details on why save_for_backward is important in extending autograd note (#153005)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153005
Approved by: https://github.com/albanD
soulitzer
2025-05-07 14:10:25 -04:00
committed by PyTorch MergeBot
parent 50657120a0
commit 9d00f2b375
2 changed files with 20 additions and 5 deletions


@@ -107,11 +107,25 @@ Take the following steps:
properly in order to ensure that the new :class:`Function` works properly with
the autograd engine.
- :meth:`~torch.autograd.function.FunctionCtx.save_for_backward` must be
used to save any tensors to be used in the backward pass. Non-tensors should
be stored directly on `ctx`. If tensors that are neither input nor output
are saved for backward your :class:`~Function` may not support double backward
(see step 3).
- :meth:`~torch.autograd.function.FunctionCtx.save_for_backward` should be
used to save any tensors needed for the backward pass (as opposed to storing
them directly on ``ctx``). You cannot use ``save_for_backward`` for
non-tensors; those should be stored directly on ``ctx``.
Saving tensors via ``save_for_backward``:
1. Allows the autograd engine to clear them as soon as the backward
computation of the ``autograd.Function`` completes. (If a tensor is stored
directly on ``ctx``, it unnecessarily remains alive for the lifetime of the
autograd graph -- typically until the end of the iteration.)
2. Helps avoid certain reference cycles (e.g., the tensor output of the
``autograd.Function`` itself keeps a reference to ``ctx``).
3. Is important for compatibility with features like activation
checkpointing and offloading that rely on
:class:`torch.autograd.graph.saved_tensors_hooks`.
If tensors that are neither inputs nor outputs are saved for backward, your
:class:`~Function` may not support double backward (see step 3).
- :meth:`~torch.autograd.function.FunctionCtx.mark_dirty` must be used to
mark any input that is modified inplace by the forward function.
- :meth:`~torch.autograd.function.FunctionCtx.mark_non_differentiable` must

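For illustration, here is a minimal sketch (an assumed example, not part of this PR) of the pattern the updated note describes: tensors needed by backward go through ``ctx.save_for_backward``, while non-tensor values are stored as plain attributes on ``ctx``. The ``ScaledExp`` name and its ``scale`` argument are hypothetical:

    import torch

    class ScaledExp(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, scale):
            out = torch.exp(x) * scale
            # Tensor needed by backward: save it via save_for_backward so the
            # engine can free it as soon as this node's backward has run, and
            # so saved-tensor hooks (checkpointing, offloading) can see it.
            ctx.save_for_backward(out)
            # Non-tensor values cannot go through save_for_backward; store
            # them directly on ctx instead.
            ctx.scale = scale
            return out

        @staticmethod
        def backward(ctx, grad_out):
            (out,) = ctx.saved_tensors
            # d/dx (scale * exp(x)) = scale * exp(x) = out
            return grad_out * out, None  # no gradient for the non-tensor scale

    x = torch.randn(3, requires_grad=True)
    y = ScaledExp.apply(x, 2.0)
    y.sum().backward()

Because ``out`` is an output of ``forward`` rather than an intermediary tensor, saving it also avoids the double-backward caveat mentioned at the end of the bullet.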

@@ -42,6 +42,7 @@ class FunctionCtx:
with ``save_for_backward`` (as opposed to directly on ``ctx``) to prevent
incorrect gradients and memory leaks, and enable the application of saved
tensor hooks. See :class:`torch.autograd.graph.saved_tensors_hooks`.
See :ref:`extending-autograd` for more details.
Note that if intermediary tensors, tensors that are neither inputs
nor outputs of :func:`forward`, are saved for backward, your custom Function
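As a rough sketch of why the saved-tensor-hooks point matters (again an assumed example, not from the diff): :class:`torch.autograd.graph.saved_tensors_hooks` only intercepts tensors that are registered for backward (e.g., via ``save_for_backward``), so a tensor stashed directly on ``ctx`` bypasses any checkpointing or offloading machinery built on these hooks. The ``pack``/``unpack`` hooks below just log:

    import torch

    def pack(t):
        # A real hook might offload t to CPU or record it for recomputation.
        print("pack:", tuple(t.shape))
        return t

    def unpack(t):
        print("unpack:", tuple(t.shape))
        return t

    x = torch.randn(4, requires_grad=True)
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        y = torch.exp(x).sum()  # exp saves its output for backward -> pack runs
    y.backward()                # the saved tensor is retrieved -> unpack runs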