Summary: This PR intends to address apaszke's concerns in https://github.com/pytorch/pytorch/pull/14253#issuecomment-441740016. Preserving the RNG state is now controlled by a kwarg rather than by global state, hopefully in a Python 2.7-compatible way. Additionally, the checkpointing function stashes and restores the RNG states of 1. the devices associated with all input tensor args to run_fn, as well as 2. the current device. I could easily change this to only save and restore the RNG states associated with 1. alone. This would simplify the logic needed to create a [deduplicated, ordered](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R37) list of devices considered active. I'm wondering if the [get_device_states](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R32) and [set_device_states](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R47) functions are general enough to reside elsewhere (presumably torch/random.py). I'm also wondering if the check on [torch.cuda._initialized](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R47) would be better placed within `get_device_states`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/14518
Differential Revision: D13356210
Pulled By: ezyang
fbshipit-source-id: afa4cc21ce7862142d5cb1dec3750018df222039
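For reference, a rough sketch of what the two stash/restore helpers mentioned above could look like. The names `get_device_states` and `set_device_states` come from the linked diff, but the bodies below are an assumption for illustration rather than the merged code:

```python
import torch

def get_device_states(*args):
    # Deduplicated, ordered list of CUDA devices backing the Tensor args,
    # plus a snapshot of each device's RNG state.
    # (Sketch only; the actual helpers live in torch/utils/checkpoint.py.)
    fwd_gpu_devices = sorted({arg.get_device() for arg in args
                              if isinstance(arg, torch.Tensor) and arg.is_cuda})
    fwd_gpu_states = []
    for device in fwd_gpu_devices:
        with torch.cuda.device(device):
            fwd_gpu_states.append(torch.cuda.get_rng_state())
    return fwd_gpu_devices, fwd_gpu_states

def set_device_states(devices, states):
    # Restore the per-device RNG states captured by get_device_states.
    for device, state in zip(devices, states):
        with torch.cuda.device(device):
            torch.cuda.set_rng_state(state)
```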
torch.utils.checkpoint
======================

.. note::
    Checkpointing is implemented by rerunning a forward-pass segment for
    each checkpointed segment during backward.  This can cause persistent
    states like the RNG state to be more advanced than they would be without
    checkpointing.  By default, checkpointing includes logic to juggle
    the RNG state such that checkpointed passes making use of RNG
    (through dropout, for example) have deterministic output compared
    to non-checkpointed passes.  The logic to stash and restore
    RNG states can incur a moderate performance hit depending on the runtime
    of the checkpointed operations.  If deterministic output compared to
    non-checkpointed passes is not required, supply ``preserve_rng_state=False``
    to ``checkpoint`` or ``checkpoint_sequential`` to omit stashing and
    restoring the RNG state during each checkpoint.

    The stashing logic saves and restores the RNG state for the current device
    and the devices of all CUDA Tensor arguments to the ``run_fn``.
    However, the logic has no way to anticipate whether the user will move
    Tensors to a new device within the ``run_fn`` itself.  Therefore, if you move
    Tensors to a new device ("new" meaning not belonging to the set of
    [current device + devices of Tensor arguments]) within ``run_fn``, deterministic
    output compared to non-checkpointed passes is never guaranteed.
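
A minimal usage sketch of the trade-off described in the note; the ``nn.Sequential``
segment and tensor shapes below are arbitrary assumptions, not part of the API::

    import torch
    from torch.utils.checkpoint import checkpoint

    # A segment that consumes RNG (dropout), so the recomputed forward is
    # sensitive to the RNG state at backward time.
    segment = torch.nn.Sequential(
        torch.nn.Linear(128, 128),
        torch.nn.Dropout(p=0.5),
        torch.nn.ReLU(),
    )
    x = torch.randn(4, 128, requires_grad=True)

    # Default: the RNG state is stashed and restored, so dropout in the
    # recomputed forward matches the original forward pass.
    out = checkpoint(segment, x)

    # If bitwise agreement with a non-checkpointed pass is not required,
    # skip the RNG bookkeeping:
    out = checkpoint(segment, x, preserve_rng_state=False)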

.. currentmodule:: torch.utils.checkpoint
.. autofunction:: checkpoint
.. autofunction:: checkpoint_sequential