pytorch/docs/source/checkpoint.md
windsonsea fbd88ae2b5 Convert to markdown: checkpoint.rst (#156009)
Related to #155014

Use two commits to have a try.
```bash
git mv docs/source/checkpoint.rst docs/source/checkpoint.md
git commit -m "[Docs] Rename checkpoint.rst"
git push origin ckpoint

# update the markdown file
git add .
git commit -m "modify checkpoint.md"
git push origin ckpoint
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156009
Approved by: https://github.com/svekars
2025-06-16 17:48:23 +00:00


torch.utils.checkpoint

Checkpointing is implemented by rerunning the forward pass of each
checkpointed segment during backward propagation.  This can leave persistent
state, such as the RNG state, more advanced than it would be without
checkpointing.  By default, checkpointing stashes and restores the RNG
state so that checkpointed passes that use RNG (through dropout, for
example) produce deterministic output compared to non-checkpointed
passes.  Stashing and restoring the RNG state can incur a moderate
performance hit, depending on the runtime of the checkpointed operations.
If deterministic output compared to non-checkpointed passes is not
required, pass `preserve_rng_state=False` to `checkpoint` or
`checkpoint_sequential` to skip stashing and restoring the RNG state
during each checkpoint.
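
The effect of `preserve_rng_state` can be checked with a small CPU-only experiment. The following is a minimal sketch, not part of the original documentation: `dropout_block` and `input_grad` are hypothetical helper names, and the non-reentrant variant (`use_reentrant=False`) is chosen here only for illustration.

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


def dropout_block(x):
    # Dropout draws a random mask, so the recomputed forward pass must
    # see the same RNG state to reproduce the same mask.
    return F.dropout(x * 2.0, p=0.5, training=True)


def input_grad(use_checkpoint, preserve_rng_state=True):
    # Hypothetical helper: returns d(sum(out))/dx for one seeded run.
    torch.manual_seed(0)  # identical starting RNG state for every run
    x = torch.ones(1000, requires_grad=True)
    if use_checkpoint:
        out = checkpoint(
            dropout_block,
            x,
            use_reentrant=False,
            preserve_rng_state=preserve_rng_state,
        )
    else:
        out = dropout_block(x)
    out.sum().backward()
    return x.grad.clone()


reference = input_grad(use_checkpoint=False)
preserved = input_grad(use_checkpoint=True, preserve_rng_state=True)
unpreserved = input_grad(use_checkpoint=True, preserve_rng_state=False)

# With the RNG state preserved, the checkpointed gradient matches the
# non-checkpointed one exactly.
print(torch.equal(reference, preserved))
# Without it, the recomputed dropout mask is drawn from an advanced RNG
# state, so the gradient will typically differ from the reference.
print(torch.equal(reference, unpreserved))
```

With `preserve_rng_state=False` the run is cheaper, but the gradient reflects a dropout mask different from the one applied in the original forward pass, which is the trade-off the paragraph above describes.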

The stashing logic saves and restores the RNG state for the CPU and for one
other device type, which `_infer_device_type` infers from the non-CPU Tensor
arguments passed to `run_fn`. If the arguments span multiple device types,
device state is saved only for devices of a single device type, and the
remaining devices are ignored. Consequently, if any checkpointed
functions involve randomness, this may result in incorrect gradients. (Note
that if CUDA devices are among the devices detected, CUDA is prioritized;
otherwise, the first device type encountered is selected.) If none of the
Tensor arguments is on a non-CPU device, the state of the default device type
is saved and restored (the default is `cuda`, and it can be changed through
`DefaultDeviceType`). However, the logic cannot anticipate whether the user
will move Tensors to a new device within `run_fn` itself.  Therefore, if you
move Tensors to a new device ("new" meaning not belonging to the set of
[current device + devices of Tensor arguments]) within `run_fn`, deterministic
output compared to non-checkpointed passes is never guaranteed.
.. currentmodule:: torch.utils.checkpoint
.. autofunction:: checkpoint
.. autofunction:: checkpoint_sequential
.. autofunction:: set_checkpoint_debug_enabled
.. autoclass:: CheckpointPolicy
.. autoclass:: SelectiveCheckpointContext
.. autofunction:: create_selective_checkpoint_contexts