import torch
import warnings
import weakref
from typing import Any, Iterable, List, Tuple

__all__ = [
    "checkpoint", "checkpoint_sequential", "CheckpointFunction",
    "check_backward_validity", "detach_variable", "get_device_states",
    "set_device_states",
]
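
# Note on the helper below: detach_variable() detaches every tensor in the
# inputs tuple while copying over its requires_grad flag, so the recomputed
# forward pass in CheckpointFunction.backward() builds a fresh autograd graph
# rooted at these detached leaves.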


def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            if not isinstance(inp, torch.Tensor):
                out.append(inp)
                continue

            x = inp.detach()
            x.requires_grad = inp.requires_grad
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only tuple of tensors is supported. Got unsupported input type: ", type(inputs).__name__)


def check_backward_validity(inputs: Iterable[Any]) -> None:
    if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
        warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")


# We can't know if the run_fn will internally move some args to different devices,
# which would require logic to preserve rng states for those devices as well.
# We could paranoically stash and restore ALL the rng states for all visible devices,
# but that seems very wasteful for most cases. Compromise: Stash the RNG state for
# the device of all Tensor args.
#
# To consider: maybe get_device_states and set_device_states should reside in torch/random.py?
def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
    # the conditionals short-circuit.
    fwd_gpu_devices = list(set(arg.get_device() for arg in args
                               if isinstance(arg, torch.Tensor) and arg.is_cuda))

    fwd_gpu_states = []
    for device in fwd_gpu_devices:
        with torch.cuda.device(device):
            fwd_gpu_states.append(torch.cuda.get_rng_state())

    return fwd_gpu_devices, fwd_gpu_states


def set_device_states(devices, states) -> None:
    for device, state in zip(devices, states):
        with torch.cuda.device(device):
            torch.cuda.set_rng_state(state)
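
# A minimal sketch of how the two helpers above pair up (illustrative only;
# ``x`` and ``y`` stand for hypothetical CUDA tensors, not part of this module):
#
#     devices, states = get_device_states(x, y)  # snapshot RNG state of each CUDA device used
#     ...                                         # run code that consumes CUDA randomness
#     set_device_states(devices, states)          # restore the snapshot before replaying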


def _get_autocast_kwargs():
    gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                           "dtype": torch.get_autocast_gpu_dtype(),
                           "cache_enabled": torch.is_autocast_cache_enabled()}

    cpu_autocast_kwargs = {"enabled": torch.is_autocast_cpu_enabled(),
                           "dtype": torch.get_autocast_cpu_dtype(),
                           "cache_enabled": torch.is_autocast_cache_enabled()}

    return gpu_autocast_kwargs, cpu_autocast_kwargs
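
# CheckpointFunction is the re-entrant implementation behind
# ``checkpoint(..., use_reentrant=True)``: it is invoked as
# ``CheckpointFunction.apply(function, preserve_rng_state, *args)`` from
# ``checkpoint`` below rather than being called directly.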


class CheckpointFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
        ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs()
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                " argument.")
        # Copy the list to avoid modifying the original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors

        # Fill in inputs with appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward. Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            detached_inputs = detach_variable(tuple(inputs))
            with torch.enable_grad(), \
                 torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
                 torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # Run backward() with only the tensors that require grad.
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of the outputs has requires_grad=True,"
                " this checkpoint() is not necessary")
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in detached_inputs)

        return (None, None) + grads


def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
    r"""Checkpoint a model or part of the model

    Checkpointing works by trading compute for memory. Rather than storing all
    intermediate activations of the entire computation graph for computing
    backward, the checkpointed part does **not** save intermediate activations,
    and instead recomputes them in the backward pass. It can be applied to any
    part of a model.

    Specifically, in the forward pass, :attr:`function` will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. Instead, the forward pass saves the inputs tuple and the
    :attr:`function` parameter. In the backward pass, the saved inputs and
    :attr:`function` are retrieved, and the forward pass is computed on
    :attr:`function` again, now tracking the intermediate activations, and then
    the gradients are calculated using these activation values.

    The output of :attr:`function` can contain non-Tensor values and gradient
    recording is only performed for the Tensor values. Note that if the output
    consists of nested structures (ex: custom objects, lists, dicts, etc.)
    consisting of Tensors, these Tensors nested in custom structures will not
    be considered as part of autograd.

    .. warning::
        If the :attr:`function` invocation during backward does anything different
        than the one during forward, e.g., due to some global variable, the
        checkpointed version won't be equivalent, and unfortunately it can't be
        detected.

    .. warning::
        If ``use_reentrant=True`` is specified, then if the checkpointed segment
        contains tensors detached from the computational graph by ``detach()`` or
        ``torch.no_grad()``, the backward pass will raise an error. This is
        because ``checkpoint`` makes all the outputs require gradients, which
        causes issues when a tensor is defined to have no gradient in the model.
        To circumvent this, detach the tensors outside of the ``checkpoint``
        function. Note that the checkpointed segment can contain tensors
        detached from the computational graph if ``use_reentrant=False`` is
        specified.

    .. warning::
        If ``use_reentrant=True`` is specified, at least one of the inputs needs
        to have :code:`requires_grad=True` if grads are needed for model inputs,
        otherwise the checkpointed part of the model won't have gradients. At
        least one of the outputs needs to have :code:`requires_grad=True` as
        well. Note that this does not apply if ``use_reentrant=False`` is
        specified.

    .. warning::
        If ``use_reentrant=True`` is specified, checkpointing currently only
        supports :func:`torch.autograd.backward` and only if its `inputs`
        argument is not passed. :func:`torch.autograd.grad`
        is not supported. If ``use_reentrant=False`` is specified, checkpointing
        will work with :func:`torch.autograd.grad`.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if the user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): If ``False``, omit stashing and
            restoring the RNG state during each checkpoint.
            Default: ``True``
        use_reentrant(bool, optional): Use checkpointing
            implementation that requires re-entrant autograd.
            If ``use_reentrant=False`` is specified, ``checkpoint`` will use an
            implementation that does not require re-entrant autograd. This
            allows ``checkpoint`` to support additional functionality, such as
            working as expected with ``torch.autograd.grad`` and support for
            keyword arguments input into the checkpointed function. Note that future
            versions of PyTorch will default to ``use_reentrant=False``.
            Default: ``True``
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`
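
    Example:
        A minimal usage sketch; the module and tensor shapes below are
        illustrative only. Activations inside ``block`` are recomputed during
        the backward pass instead of being stored during the forward pass.

        >>> # xdoctest: +SKIP("illustrative")
        >>> import torch
        >>> import torch.nn as nn
        >>> block = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 128))
        >>> x = torch.randn(4, 128, requires_grad=True)
        >>> y = checkpoint(block, x)
        >>> y.sum().backward()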
"""
|
|
# Hack to mix *args with **kwargs in a python 2.7-compliant way
|
|
preserve = kwargs.pop('preserve_rng_state', True)
|
|
if kwargs and use_reentrant:
|
|
raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
|
|
|
|
if use_reentrant:
|
|
return CheckpointFunction.apply(function, preserve, *args)
|
|
else:
|
|
return _checkpoint_without_reentrant(
|
|
function,
|
|
preserve,
|
|
*args,
|
|
**kwargs,
|
|
)
|
|
|
|
|
|


def checkpoint_sequential(functions, segments, input, use_reentrant=True, **kwargs):
    r"""A helper function for checkpointing sequential models.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a model into various segments
    and checkpoint each segment. All segments except the last will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. The inputs of each checkpointed segment will be saved for
    re-running the segment in the backward pass.

    See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.

    .. warning::
        Checkpointing currently only supports :func:`torch.autograd.backward`
        and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
        is not supported.

    .. warning::
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients.

    .. warning::
        Since PyTorch 1.4, ``checkpoint_sequential`` allows only one Tensor as
        the input and intermediate outputs, just like :class:`torch.nn.Sequential`.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or
            functions (comprising the model) to run sequentially.
        segments: Number of chunks to create in the model
        input: A Tensor that is input to :attr:`functions`
        preserve_rng_state(bool, optional): If ``False``, omit stashing and
            restoring the RNG state during each checkpoint.
            Default: ``True``
        use_reentrant(bool, optional): Use checkpointing
            implementation that requires re-entrant autograd.
            If ``use_reentrant=False`` is specified, ``checkpoint`` will use an
            implementation that does not require re-entrant autograd. This
            allows ``checkpoint`` to support additional functionality, such as
            working as expected with ``torch.autograd.grad`` and support for
            keyword arguments input into the checkpointed function.
            Default: ``True``

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`input`

    Example:
        >>> # xdoctest: +SKIP("stub")
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_sequential(model, chunks, input_var)
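
        A more concrete sketch (the layer sizes and segment count are
        illustrative assumptions, not part of the API):

        >>> # xdoctest: +SKIP("illustrative")
        >>> model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU())
        >>> input_var = torch.randn(8, 64, requires_grad=True)
        >>> out = checkpoint_sequential(model, 2, input_var)
        >>> out.sum().backward()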
"""
|
|
# Hack for keyword-only parameter in a python 2.7-compliant way
|
|
preserve = kwargs.pop('preserve_rng_state', True)
|
|
if kwargs:
|
|
raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
|
|
|
|
def run_function(start, end, functions):
|
|
def forward(input):
|
|
for j in range(start, end + 1):
|
|
input = functions[j](input)
|
|
return input
|
|
return forward
|
|
|
|
if isinstance(functions, torch.nn.Sequential):
|
|
functions = list(functions.children())
|
|
|
|
segment_size = len(functions) // segments
|
|
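
    # Worked example of the segmentation below (illustrative numbers): with
    # 6 functions and segments=3, segment_size is 6 // 3 == 2, so the loop
    # checkpoints functions[0:2] and functions[2:4], and the final call after
    # the loop runs functions[4:6] normally, without checkpointing.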
    # the last chunk has to be run without checkpointing (i.e., with grad enabled)
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        input = checkpoint(
            run_function(start, end, functions),
            input,
            use_reentrant=use_reentrant,
            preserve_rng_state=preserve
        )
    return run_function(end + 1, len(functions) - 1, functions)(input)


def _checkpoint_without_reentrant(function, preserve_rng_state=True, *args, **kwargs):
    """Checkpointing without re-entrant autograd

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if the user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): If ``False``, omit stashing and
            restoring the RNG state during each checkpoint.
            Default: ``True``
        *args: Arguments to pass in to the given ``function``.
        **kwargs: Keyword arguments to pass into the given ``function``.
    """
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    gpu_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs()

    if preserve_rng_state:
        fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
        # we have no way to anticipate this will happen before we run the function.
        # If they do so, we raise an error.)
        had_cuda_in_fwd = False
        if torch.cuda._initialized:
            had_cuda_in_fwd = True
            fwd_gpu_devices, fwd_gpu_states = get_device_states(*args)

    # Custom class to be able to take weak references
    class Holder():
        pass

    # The Holder object for each of the saved objects is saved directly on the
    # SavedVariable and is cleared when reset_data() is called on it. We MUST make
    # sure that this is the only object having an owning reference to ensure that
    # the Tensor stored in storage is deleted as soon as the corresponding SavedVariable
    # data is cleared.
    storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
    weak_holder_list = []
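
    # High-level flow of the two hooks defined below: during the initial
    # forward, pack() is called once for every tensor autograd would normally
    # save and returns a lightweight Holder placeholder instead of the tensor.
    # The first time unpack() runs in backward, the forward is replayed under
    # saved_tensors_hooks(inner_pack, inner_unpack) so that ``storage`` is
    # repopulated with the detached recomputed tensors, which are then handed
    # back to autograd.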

    def pack(x):
        # TODO(varal7): Instead of returning abstract object, we can return things metadata (such as
        # size, device, ...) to catch certain cases of nondeterministic behavior of the forward
        res = Holder()
        weak_holder_list.append(weakref.ref(res))
        return res

    def unpack(x):
        unpack_counter = 0
        if len(storage) == 0:
            def inner_pack(inner):
                nonlocal unpack_counter
                unpack_counter += 1
                # If the holder went out of scope, the SavedVariable is dead and so
                # the value will never be read from the storage. Skip filling it.
                if weak_holder_list[unpack_counter - 1]() is None:
                    return
                # Use detach here to ensure we don't keep the temporary autograd
                # graph created during the second forward
                storage[weak_holder_list[unpack_counter - 1]()] = inner.detach()
                return

            def inner_unpack(packed):
                raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.")

            # Stash the surrounding rng state, and mimic the state that was
            # present at this time during forward. Restore the surrounding state
            # when we're done.
            rng_devices = []
            if preserve_rng_state and had_cuda_in_fwd:
                rng_devices = fwd_gpu_devices
            with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
                if preserve_rng_state:
                    torch.set_rng_state(fwd_cpu_state)
                    if had_cuda_in_fwd:
                        set_device_states(fwd_gpu_devices, fwd_gpu_states)

                with torch.enable_grad(), \
                     torch.cuda.amp.autocast(**gpu_autocast_kwargs), \
                     torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
                     torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
                    _unused = function(*args, **kwargs)

        if x not in storage:
            raise RuntimeError(
                "Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
                " recomputation being triggered in between, this is not currently supported. Please"
                " open an issue with details on your use case so that we can prioritize adding this."
            )

        return storage[x]

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        output = function(*args, **kwargs)
        if torch.cuda._initialized and preserve_rng_state and not had_cuda_in_fwd:
            # Cuda was not initialized before running the forward, so we didn't
            # stash the CUDA state.
            raise RuntimeError(
                "PyTorch's CUDA state was initialized in the forward pass "
                "of a Checkpoint, which is not allowed. Please open an issue "
                "if you need this feature.")

    return output