pytorch/docs/source/checkpoint.rst
Michael Carilli 5d3a347685 Stashing checkpointing RNG states based on devices of arg tensors (#14518)
Summary:
This PR intends to address apaszke's concerns in https://github.com/pytorch/pytorch/pull/14253#issuecomment-441740016.  Preserving the rng state is now controlled by a kwarg rather than a global state, hopefully in a python 2.7-compatible way.

Additionally, the checkpointing function stashes and restores the RNG states of
1. devices associated with all input tensor args to run_fn as well as
2. the current device.

I could easily change this to only save and restore the RNG states associated with 1. alone.  This would simplify the logic to create a [deduplicated, ordered](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R37) list of devices considered active.

I'm wondering if the [get_device_states](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R32) and [set_device_states](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R47) functions are general enough to reside elsewhere (presumably torch/random.py).  I'm also wondering if the check on [torch.cuda._initialized](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R47) would be better placed within `get_device_states`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14518

Differential Revision: D13356210

Pulled By: ezyang

fbshipit-source-id: afa4cc21ce7862142d5cb1dec3750018df222039
2018-12-11 09:48:45 -08:00


torch.utils.checkpoint
======================
.. note::
    Checkpointing is implemented by rerunning a forward-pass segment for
    each checkpointed segment during backward. This can cause persistent
    states like the RNG state to be more advanced than they would be without
    checkpointing. By default, checkpointing includes logic to juggle
    the RNG state such that checkpointed passes making use of RNG
    (through dropout, for example) have deterministic output as
    compared to non-checkpointed passes. The logic to stash and restore
    RNG states can incur a moderate performance hit depending on the runtime
    of checkpointed operations. If deterministic output compared to
    non-checkpointed passes is not required, supply ``preserve_rng_state=False``
    to ``checkpoint`` or ``checkpoint_sequential`` to omit stashing and
    restoring the RNG state during each checkpoint.

    The stashing logic saves and restores the RNG state for the current device
    and the devices of all CUDA Tensor arguments to the ``run_fn``.
    However, the logic has no way to anticipate whether the user will move
    Tensors to a new device within the ``run_fn`` itself. Therefore, if you move
    Tensors to a new device ("new" meaning not belonging to the set of
    [current device + devices of Tensor arguments]) within ``run_fn``, deterministic
    output compared to non-checkpointed passes is never guaranteed.

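
The stash-and-restore idea described in the note can be illustrated with a
small conceptual sketch. It uses Python's standard ``random`` module rather
than PyTorch, and the helper names (``noisy_segment``, ``run_with_stash``,
``recompute``) are hypothetical. It is not the code used by
``torch.utils.checkpoint``; it only shows why restoring a saved RNG state lets
a rerun segment reproduce its original output.

```python
import random


def noisy_segment(x, rng):
    # Stand-in for a forward segment that consumes random numbers,
    # loosely analogous to dropout: each element is kept or zeroed.
    mask = [1 if rng.random() < 0.5 else 0 for _ in x]
    return [xi * mi for xi, mi in zip(x, mask)]


def run_with_stash(segment, x, rng):
    # Stash the RNG state before the first forward run of the segment.
    stashed = rng.getstate()
    return segment(x, rng), stashed


def recompute(segment, x, rng, stashed):
    # Rerun the segment under the stashed state, then put the RNG back
    # where it was so later code is not affected by the rerun.
    current = rng.getstate()
    rng.setstate(stashed)
    try:
        return segment(x, rng)
    finally:
        rng.setstate(current)


rng = random.Random(0)
x = [1.0, 2.0, 3.0, 4.0]
out, stashed = run_with_stash(noisy_segment, x, rng)
rng.random()  # unrelated work advances the RNG in the meantime
rerun = recompute(noisy_segment, x, rng, stashed)
print(rerun == out)
```

Without the ``setstate`` call in ``recompute``, the rerun would draw from an
RNG that has already advanced, so its mask could differ from the original one.
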
.. currentmodule:: torch.utils.checkpoint
.. autofunction:: checkpoint
.. autofunction:: checkpoint_sequential
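
As a usage illustration, the following minimal sketch compares a checkpointed
and a non-checkpointed pass through a segment that uses dropout, on the CPU.
The names ``run_fn`` and ``out_and_grad`` are chosen here for illustration
only. It assumes a PyTorch version in which ``checkpoint`` accepts the
``preserve_rng_state`` keyword; recent releases may additionally warn unless
``use_reentrant`` is passed explicitly.

```python
import torch
from torch.utils.checkpoint import checkpoint


def run_fn(x):
    # Dropout consumes random numbers, so the rerun during backward must
    # see the same RNG state as the original forward pass.
    return torch.nn.functional.dropout(x * 2.0, p=0.5, training=True)


def out_and_grad(use_checkpoint, seed=0):
    # Hypothetical helper: run forward and backward once from a fixed seed.
    torch.manual_seed(seed)
    x = torch.ones(8, requires_grad=True)
    if use_checkpoint:
        # preserve_rng_state=True is the default; it is spelled out here.
        y = checkpoint(run_fn, x, preserve_rng_state=True)
    else:
        y = run_fn(x)
    y.sum().backward()
    return y.detach(), x.grad


out_plain, grad_plain = out_and_grad(use_checkpoint=False)
out_ckpt, grad_ckpt = out_and_grad(use_checkpoint=True)

# With the RNG state preserved, both passes should use the same dropout mask.
print(torch.equal(out_plain, out_ckpt), torch.equal(grad_plain, grad_ckpt))
```

Passing ``preserve_rng_state=False`` instead skips the stash and restore, in
which case the mask used during the rerun is not guaranteed to match the one
from the original forward pass.
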