import contextlib
import uuid
import warnings
import weakref
from collections import defaultdict
from typing import (
    Any,
    Callable,
    ContextManager,
    DefaultDict,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
)
from weakref import ReferenceType

import torch

__all__ = [
    "checkpoint",
    "checkpoint_sequential",
    "CheckpointFunction",
    "check_backward_validity",
    "detach_variable",
    "get_device_states",
    "set_device_states",
    "noop_context_fn",
    "set_checkpoint_early_stop",
    "DefaultDeviceType",
]


def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            if not isinstance(inp, torch.Tensor):
                out.append(inp)
                continue

            x = inp.detach()
            x.requires_grad = inp.requires_grad
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only tuple of tensors is supported. Got unsupported input type: ",
            type(inputs).__name__,
        )


def check_backward_validity(inputs: Iterable[Any]) -> None:
    if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
        warnings.warn(
            "None of the inputs have requires_grad=True. Gradients will be None"
        )


def _get_device_module(device="cuda"):
    device_module = getattr(torch, device)
    return device_module


class DefaultDeviceType(object):
    r"""
    A class that manages the default device type for checkpointing.
    If no non-CPU tensors are present, the default device type will
    be used. The default value is 'cuda'. The device type is used in
    the checkpointing process when determining which device states
    to save and restore for recomputation.
    """
    _default_device_type = "cuda"

    @staticmethod
    def set_device_type(device: str = "cuda"):
        """
        Set the default device type for checkpointing.

        Args:
            device (str): The device type to be set as default. Default is 'cuda'.
        """
        DefaultDeviceType._default_device_type = device

    @staticmethod
    def get_device_type() -> str:
        """
        Get the current default device type for checkpointing.

        Returns:
            str: The current default device type.
        """
        return DefaultDeviceType._default_device_type
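

# Illustrative sketch (not executed): selecting a non-default accelerator
# backend for checkpointing. "xpu" is only an assumed example of a device
# type that has a corresponding ``torch.<device>`` module.
#
#   DefaultDeviceType.set_device_type("xpu")
#   # When no non-CPU tensors are passed to checkpoint(), _infer_device_type
#   # below falls back to this default, so RNG and autocast state are taken
#   # from the ``torch.xpu`` module instead of ``torch.cuda``.
#   assert DefaultDeviceType.get_device_type() == "xpu"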


def _infer_device_type(*args):
    device_types = list(
        {
            arg.device.type
            for arg in args
            if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu"
        }
    )
    if len(device_types) > 1:
        warnings.warn(
            "Tensor arguments, excluding CPU tensors, are detected on at least two types of devices. "
            "Device state will only be saved for devices of a single device type, and the remaining "
            "devices will be ignored. Consequently, if any checkpointed functions involve randomness, "
            "this may result in incorrect gradients. (Note that if CUDA devices are among the devices "
            "detected, it will be prioritized; otherwise, the first device encountered will be selected.)"
        )
    if len(device_types) == 0:
        return DefaultDeviceType.get_device_type()
    elif "cuda" in device_types:
        return "cuda"
    else:
        return device_types[0]


# We can't know if the run_fn will internally move some args to different devices,
# which would require logic to preserve rng states for those devices as well.
# We could paranoically stash and restore ALL the rng states for all visible devices,
# but that seems very wasteful for most cases. Compromise: Stash the RNG state for
# the device of all Tensor args.
#
# To consider: maybe get_device_states and set_device_states should reside in torch/random.py?
def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
    # the conditionals short-circuit.
    fwd_device_ids = list(
        {
            arg.get_device()
            for arg in args
            if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu"
        }
    )

    fwd_device_states = []
    device_module = _get_device_module(_infer_device_type(*args))

    for device_id in fwd_device_ids:
        with device_module.device(device_id):
            fwd_device_states.append(device_module.get_rng_state())

    return fwd_device_ids, fwd_device_states


def set_device_states(devices, states) -> None:
    device_module = _get_device_module(_infer_device_type(*states))
    for device, state in zip(devices, states):
        with device_module.device(device):
            device_module.set_rng_state(state)
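

# Illustrative sketch (not executed) of the stash-and-restore pattern these two
# helpers implement; ``run_fn`` and its tensor ``args`` are assumed placeholders.
#
#   fwd_devices, fwd_states = get_device_states(*args)   # stash per-device RNG
#   ...                                                   # forward runs, RNG advances
#   with torch.random.fork_rng(devices=fwd_devices, enabled=True):
#       set_device_states(fwd_devices, fwd_states)        # replay the stashed RNG
#       recomputed = run_fn(*args)                        # recompute deterministically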


def _get_autocast_kwargs(device="cuda"):
    if device == "cuda":
        device_autocast_kwargs = {
            "enabled": torch.is_autocast_enabled(),
            "dtype": torch.get_autocast_gpu_dtype(),
            "cache_enabled": torch.is_autocast_cache_enabled(),
        }
    else:
        device_module = _get_device_module(device)
        device_autocast_kwargs = {
            "enabled": device_module.is_autocast_enabled(),
            "dtype": device_module.get_autocast_dtype(),
            "cache_enabled": torch.is_autocast_cache_enabled(),
        }

    cpu_autocast_kwargs = {
        "enabled": torch.is_autocast_cpu_enabled(),
        "dtype": torch.get_autocast_cpu_dtype(),
        "cache_enabled": torch.is_autocast_cache_enabled(),
    }

    return device_autocast_kwargs, cpu_autocast_kwargs
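

# Illustrative sketch (not executed): the captured kwargs are replayed during
# recomputation so that the recomputed forward sees the same autocast state as
# the original forward did (this mirrors the usage in CheckpointFunction.backward
# and recompute_fn below; ``run_function`` / ``detached_inputs`` are placeholders).
#
#   device_kwargs, cpu_kwargs = _get_autocast_kwargs("cuda")
#   with torch.cuda.amp.autocast(**device_kwargs), \
#           torch.cpu.amp.autocast(**cpu_kwargs):
#       outputs = run_function(*detached_inputs)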


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
        ctx.device = _infer_device_type(*args)
        ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(
            ctx.device
        )
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_device_in_fwd = False
            device_module = _get_device_module(ctx.device)
            if device_module._initialized:
                ctx.had_device_in_fwd = True
                ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args)

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                " argument."
            )
        # Copy the list to avoid modifying original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors
        device_module = _get_device_module(ctx.device)

        # Fill in inputs with appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward. Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_device_in_fwd:
            rng_devices = ctx.fwd_devices
        with torch.random.fork_rng(
            devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device
        ):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_device_in_fwd:
                    set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
            detached_inputs = detach_variable(tuple(inputs))
            with torch.enable_grad(), device_module.amp.autocast(
                **ctx.device_autocast_kwargs
            ), torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # run backward() with only the tensors that require grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of the outputs has requires_grad=True,"
                " this checkpoint() is not necessary"
            )
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None
            for inp in detached_inputs
        )

        return (None, None) + grads


def noop_context_fn():
    return contextlib.nullcontext(), contextlib.nullcontext()
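

# Illustrative sketch (not executed) of a user-defined ``context_fn`` for
# ``checkpoint(..., use_reentrant=False, context_fn=...)``. It must return two
# context managers: the first wraps the original forward, the second wraps the
# recomputation. ``profile_recompute_ctx`` and ``fn``/``x`` are hypothetical.
#
#   def profile_recompute_ctx():
#       # no-op around the original forward; label the recomputation pass
#       return contextlib.nullcontext(), torch.autograd.profiler.record_function(
#           "checkpoint_recompute"
#       )
#
#   out = checkpoint(fn, x, use_reentrant=False, context_fn=profile_recompute_ctx)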


def checkpoint(
    function,
    *args,
    use_reentrant: Optional[bool] = None,
    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
    **kwargs
):
    r"""Checkpoint a model or part of the model

    Checkpointing is a technique that trades compute for memory. Instead of
    storing all intermediate activations of the entire computation graph for
    the backward pass, the checkpointed part omits saving intermediate
    activations and recomputes them during the backward pass. This can be
    applied to any part of a model.

    There are currently two checkpointing implementations available, determined
    by the :attr:`use_reentrant` parameter. It is recommended that you use
    ``use_reentrant=False``. Please refer to the note below for a discussion of
    their differences.

    .. warning::

        If the :attr:`function` invocation during the backward pass differs
        from the forward pass, e.g., due to a global variable, the checkpointed
        version may not be equivalent, potentially causing an error to be
        raised or leading to silently incorrect gradients.

    .. warning::

        If you are using the ``use_reentrant=True`` variant (this is currently
        the default), please refer to the note below for important
        considerations and potential limitations.

    .. note::

        The reentrant variant of checkpoint (``use_reentrant=True``) and
        the non-reentrant variant of checkpoint (``use_reentrant=False``)
        differ in the following ways:

        * Non-reentrant checkpoint stops recomputation as soon as all needed
          intermediate activations have been recomputed. This feature is enabled
          by default, but can be disabled with :func:`set_checkpoint_early_stop`.
          Reentrant checkpoint always recomputes :attr:`function` in its
          entirety during the backward pass.

        * The reentrant variant does not record the autograd graph during the
          forward pass, as it runs with the forward pass under
          :func:`torch.no_grad`. The non-reentrant version does record the
          autograd graph, allowing one to perform backward on the graph within
          checkpointed regions.

        * The reentrant checkpoint only supports the
          :func:`torch.autograd.backward` API for the backward pass without its
          `inputs` argument, while the non-reentrant version supports all ways
          of performing the backward pass.

        * At least one input and output must have ``requires_grad=True`` for the
          reentrant variant. If this condition is unmet, the checkpointed part
          of the model will not have gradients. The non-reentrant version does
          not have this requirement.

        * The reentrant version does not consider tensors in nested structures
          (e.g., custom objects, lists, dicts, etc) as participating in
          autograd, while the non-reentrant version does.

        * The reentrant checkpoint does not support checkpointed regions with
          detached tensors from the computational graph, whereas the
          non-reentrant version does. For the reentrant variant, if the
          checkpointed segment contains tensors detached using ``detach()`` or
          with :func:`torch.no_grad`, the backward pass will raise an error.
          This is because ``checkpoint`` makes all the outputs require gradients
          and this causes issues when a tensor is defined to have no gradient in
          the model. To avoid this, detach the tensors outside of the
          ``checkpoint`` function.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if the user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): if ``False``, omit stashing and
            restoring the RNG state during each checkpoint.
            Default: ``True``
        use_reentrant(bool, optional): Use checkpointing
            implementation that requires re-entrant autograd.
            If ``use_reentrant=False`` is specified, ``checkpoint`` will use an
            implementation that does not require re-entrant autograd. This
            allows ``checkpoint`` to support additional functionality, such as
            working as expected with ``torch.autograd.grad`` and support for
            keyword arguments input into the checkpointed function. Note that future
            versions of PyTorch will default to ``use_reentrant=False``.
            Default: ``True``
        context_fn(Callable, optional): A callable returning a tuple of two
            context managers. The function and its recomputation will be run
            under the first and second context managers respectively.
            This argument is only supported if ``use_reentrant=False``.
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`
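
    Example:
        >>> # xdoctest: +SKIP("stub")
        >>> # A minimal usage sketch; ``model`` and ``input`` are assumed to be
        >>> # an ``nn.Sequential`` and a tensor with ``requires_grad=True``.
        >>> out = checkpoint(model, input, use_reentrant=False)
        >>> out.sum().backward()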
    """
    if use_reentrant is None:
        warnings.warn(
            "torch.utils.checkpoint: please pass in use_reentrant=True or "
            "use_reentrant=False explicitly. The default value of use_reentrant "
            "will be updated to be False in the future. To maintain current "
            "behavior, pass use_reentrant=True. It is recommended that you use "
            "use_reentrant=False. Refer to docs for more details on the "
            "differences between the two variants."
        )
        use_reentrant = True
    # Hack to mix *args with **kwargs in a python 2.7-compliant way
    preserve = kwargs.pop("preserve_rng_state", True)
    if kwargs and use_reentrant:
        raise ValueError(
            "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
        )

    if use_reentrant:
        if context_fn is not noop_context_fn:
            raise ValueError(
                "Passing context_fn is only supported when use_reentrant=False."
            )
        return CheckpointFunction.apply(function, preserve, *args)
    else:
        return _checkpoint_without_reentrant(
            function,
            preserve,
            context_fn,
            *args,
            **kwargs,
        )


def checkpoint_sequential(functions, segments, input, use_reentrant=True, **kwargs):
    r"""A helper function for checkpointing sequential models.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a model into various segments
    and checkpoint each segment. All segments except the last will not store
    the intermediate activations. The inputs of each checkpointed segment will
    be saved for re-running the segment in the backward pass.

    .. warning::
        If you are using the ``use_reentrant=True`` variant (this is the
        default), please see :func:`~torch.utils.checkpoint.checkpoint` for
        the important considerations and limitations of this variant. It is
        recommended that you use ``use_reentrant=False``.

    .. warning::
        Since PyTorch 1.4, it allows only one Tensor as the input and
        intermediate outputs, just like :class:`torch.nn.Sequential`.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or
            functions (comprising the model) to run sequentially.
        segments: Number of chunks to create in the model
        input: A Tensor that is input to :attr:`functions`
        preserve_rng_state(bool, optional): if ``False``, omit stashing and
            restoring the RNG state during each checkpoint.
            Default: ``True``
        use_reentrant(bool, optional): Use checkpointing
            implementation that requires re-entrant autograd.
            If ``use_reentrant=False`` is specified, ``checkpoint`` will use an
            implementation that does not require re-entrant autograd. This
            allows ``checkpoint`` to support additional functionality, such as
            working as expected with ``torch.autograd.grad`` and support for
            keyword arguments input into the checkpointed function.
            Default: ``True``

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`input`

    Example:
        >>> # xdoctest: +SKIP("stub")
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_sequential(model, chunks, input_var)
    """
    # Hack for keyword-only parameter in a python 2.7-compliant way
    preserve = kwargs.pop("preserve_rng_state", True)
    if kwargs:
        raise ValueError(
            "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
        )

    def run_function(start, end, functions):
        def forward(input):
            for j in range(start, end + 1):
                input = functions[j](input)
            return input

        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    segment_size = len(functions) // segments
    # the last chunk has to be non-volatile
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        input = checkpoint(
            run_function(start, end, functions),
            input,
            use_reentrant=use_reentrant,
            preserve_rng_state=preserve,
        )
    return run_function(end + 1, len(functions) - 1, functions)(input)
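

# Worked example of the segmentation arithmetic above (illustrative only):
# with 7 modules and segments=3, segment_size = 7 // 3 = 2, so the loop above
# checkpoints functions[0..1] and functions[2..3], while the remaining
# functions[4..6] form the last chunk and run outside of checkpoint (see the
# "non-volatile" comment above).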


def _internal_assert(cond):
    if not cond:
        raise AssertionError(
            "Something went unexpectedly wrong in activation checkpoint. "
            "Please report this bug by filing an issue to PyTorch."
        )


# NOTE [ Nestable Checkpoint ]
#
# The semantics of nested checkpoint can be defined by two basic rules.
# Following the two rules leads to an important implication that is central
# to motivating the design.
#
# Rule 1. Saved tensors are managed by inner-most checkpoint only and hidden
#         from any outer layers of checkpoint.
#
# Rule 2. The inputs of inner checkpoints are treated as tensors saved to their
#         parent checkpoint.
#
# Implication: To recompute any given saved tensor, we need to recompute all of
#              the checkpoints wrapping it.
#
# Why is this implied? To unpack a saved tensor X during backward we need to
# recompute the inner-most checkpoint (#1), and in order to recompute that
# checkpoint we need to have its inputs, which are managed by that checkpoint's
# parent (#2), which thus also needs to be recomputed first. Continue this line
# of reasoning and we realize that in order to unpack X, all checkpoints that
# were active at the time X was saved need to be recomputed (unless we have
# already done so in that backward for some other saved tensor).
#
# In practice, we use a noop autograd Function to save inputs as saved tensors.
# During unpack, calling ctx.saved_tensors triggers the parent checkpoint to
# recompute.
#
# Rule 3. We should start recomputation as if there are no checkpoints currently
#         active. Checkpoints encountered during recomputation are still
#         respected.
#
# When we start recomputation, we push the saved variable hook meant for
# recomputation on the stack. See examples in Rule 6 for more context.
#
# * * * *
#
# Beyond the basic semantics specific to nested checkpoint, we impose several
# more constraints that may apply to checkpointing in general.
#
# Rule 4. Lifetime of recomputed tensors
#
#         Recomputed tensors are considered specific to particular invocations
#         of backward and are always cleared immediately as they are unpacked.
#         In particular, we require this to happen even if retain_graph=True.
#
# [ Implementation details of Rule 4 ]
#
# If we were okay with recomputed tensors staying alive after backward is run
# with retain_graph=True, we would store recomputed variables as the values of a
# WeakKeyDictionary and pack strong references to the keys, so that as we
# backward, those packed keys would be cleared as long as retain_graph=False.
# Clearing the packed key clears the corresponding entry in the WKD.
#
# If we wish recomputed variables to be immediately cleared as we unpack them in
# the retain_graph=True case, we cannot rely on the packed keys to be cleared by
# backward automatically. Instead of packing the strong reference to the key
# directly, we pack a container object, which we manually clear as we unpack.
#
# An important detail is that if a second backward happens, the second
# recomputation needs to reset the container with a newly created key.
#
# Rule 5. Stop recomputation as soon as we've recomputed the saved tensors we
#         know we need.
#
# [ Implementation details of Rule 5 ]
#
# During recomputation, raise an exception if the number of recomputed tensors
# matches the number of tensors that we expected to recompute. We wrap the
# recomputation call with a try-catch to catch this specific exception. See
# Rule #6 below for some examples.
#
# Rule 6. We support doing backward inside checkpoint context
#
# [ retain_graph is True ]
#
# def fn(x):
#   y = x.sin()
#   z = y.cos()
#   gx, = torch.autograd.grad(z, x, retain_graph=True)
#   return gx, z
#
# out = checkpoint(fn, inp)
# out.backward()
#
# Because z is saved by cos while checkpoint is enabled, it would not be
# actually saved, and so the .grad() call inside must trigger a recomputation.
#
# During recomputation the "inner pack hook" has two responsibilities:
#
# 1) As usual, populating the WeakKeyDictionary storing recomputed tensors
# 2) Pack the actual tensor (detached) so that one may perform backward on the
#    recomputed graph. The tensors saved to this graph will live until the end
#    of recomputation, or die earlier if someone performs backward with
#    retain_graph=False.
#
# More generally, performing backward on the recomputed graph occurs in the
# following cases:
# - If backward is performed inside forward,
#   - During the original forward IF early-stop is disabled
#   - During the original backward
# - If there are multiple .grad()/.backward() calls, we would perform backward
#   on the recomputed graph even if early-stop is enabled (see the example below)
#
# [ retain_graph is False ]
#
# The example below shows what happens if during recomputation we find that some
# of the tensors we are trying to recompute have already been cleared.
#
# Spoiler: we don't do anything special, we just skip over them!
#
# def fn(x):
#   y = x.sin()                          # (1)
#   z = y.cos()                          # (2)
#   gx, = torch.autograd.grad(z, x)      # (3)
#   return x.cos() * gx                  # (4)
#
# out = checkpoint(fn, inp)
# out.backward()                         # (5)
#
# 1, 2. Don't save x and y since we are inside a checkpoint.
# 3. Trigger a recompute of fn since x and y weren't saved.
#    And depending on whether early stop is enabled, either stop at (2) or
#    continue running the function.
#    Because we are running backward with retain_graph=False, we clear x and y's
#    holders.
# 4. Don't save x since we are inside a checkpoint.
# 5. Calling backward triggers another recompute of fn. During recompute, we see
#    that x and y have already been cleared in the original graph as indicated
#    by holder=None. We skip over them. We still save x at (4) (since its holder
#    is still alive.)
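

# Illustrative nesting sketch (not executed; ``f`` and ``g`` are hypothetical):
#
#   def inner(x):
#       return g(x)
#
#   def outer(x):
#       y = f(x)
#       return checkpoint(inner, y, use_reentrant=False)
#
#   out = checkpoint(outer, inp, use_reentrant=False)
#
# Per Rule 2, ``y`` (the input to the inner checkpoint) counts as a tensor saved
# by the outer checkpoint, so unpacking anything saved inside ``inner`` first
# recomputes ``outer`` to reproduce ``y`` (this is the Implication above).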

_enable_checkpoint_early_stop = True


@contextlib.contextmanager
def set_checkpoint_early_stop(enable: bool):
    """Context manager that sets whether checkpoint should stop recomputation
    early.

    By default, non-reentrant checkpoint stops recomputation as soon as it
    has computed all needed Tensors. This context manager can be used to disable
    that feature if it is problematic for your specific application.

    This context manager only needs to be active when forward is run. It does
    not need to be active during backward.

    Example::

        >>> # xdoctest: +SKIP(failing)
        >>> message = "saved tensors default hooks are disabled"
        >>> with set_checkpoint_early_stop(False):
        ...     # Any checkpoint under this context manager will respect this
        ...     # context manager, even if its backward is performed outside.
        ...     out = checkpoint(fn, inputs)
        ...
        >>> out.backward()
    """
    global _enable_checkpoint_early_stop
    try:
        prev = _enable_checkpoint_early_stop
        _enable_checkpoint_early_stop = enable
        yield
    finally:
        _enable_checkpoint_early_stop = prev


class _Handle:
    pass


class _Holder:
    def __init__(self):
        self.handles: Dict[int, Optional[_Handle]] = dict()


class _NoopSaveInputs(torch.autograd.Function):
    @staticmethod
    def forward(*args):
        return torch.empty((0,))

    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
        # Only tensors can be saved with ctx.save_for_backward, everything else
        # is captured by get_args, which is saved directly on ctx
        tensor_indices, tensors = zip(
            *[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)]
        )
        idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)}
        # args but with tensors replaced with None as placeholders
        args = [None if isinstance(o, torch.Tensor) else o for o in inputs]

        def get_args(saved_tensors):
            # restore the placeholders with the original tensors grabbed from
            # ctx.saved_tensors (which may be saved on a parent checkpoint if
            # this checkpoint is nested, and that would trigger a recursive
            # unpack!)
            ret = [
                saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o
                for i, o in enumerate(args)
            ]
            # grab the tail since we also saved the dummy to avoid having to explicitly
            # handle the case where there are no tensor inputs
            return ret[1:]

        ctx.get_args = get_args
        ctx.save_for_backward(*tensors)

    @staticmethod
    def backward(ctx, *grad_outputs):
        raise AssertionError("Did not expect to backward on this graph")


class _CheckpointFrame:
    def __init__(self, recompute_fn, early_stop):
        self.recompute_fn = recompute_fn
        self.input_saver = None
        self.weak_holders: List[ReferenceType] = []
        # We store this as a weakkeydictionary so that in the case of a partial
        # backward, the entries in the dict are cleared alongside the Holder
        # which will be removed when the SavedVariable is cleared.
        self.recomputed: DefaultDict[
            int, weakref.WeakKeyDictionary[_Handle, torch.Tensor]
        ] = defaultdict(weakref.WeakKeyDictionary)
        # We need both recomp_counter and recomputed since they can diverge
        # https://github.com/pytorch/pytorch/pull/90105#discussion_r1135889885
        self.recomp_counter: DefaultDict[int, int] = defaultdict(int)
        self.is_recomputed: DefaultDict[int, bool] = defaultdict(bool)

        # See Rule 5
        self.early_stop = early_stop


# See Rule 5
class _StopRecomputationError(Exception):
    pass


class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
    def __init__(self, target_frame_ref: ReferenceType, gid: int):
        def pack_hook(x):
            target_frame = target_frame_ref()
            assert target_frame is not None  # appease mypy
            recomp_idx = target_frame.recomp_counter[gid]
            target_frame.recomp_counter[gid] += 1

            if recomp_idx >= len(target_frame.weak_holders):
                # We run into this case when early stop is not enabled and we
                # do grad within checkpoint.
                return x.detach()
            holder = target_frame.weak_holders[recomp_idx]()

            # This holder may have been cleared because someone may have called
            # backward within forward. If so, we don't need to save.
            if holder is not None:
                _internal_assert(holder.handles.get(gid, None) is None)
                holder.handles[gid] = _Handle()
                target_frame.recomputed[gid][holder.handles[gid]] = x.detach()

            if target_frame.early_stop and target_frame.recomp_counter[gid] == len(
                target_frame.weak_holders
            ):
                raise _StopRecomputationError()
            # See Rule 6: [ retain_graph is True ] above
            return x.detach()

        def unpack_hook(x):
            # See Rule 6: [ retain_graph is True ] above for an example of when
            # the graph created during recomputation could be backwarded.
            return x

        super().__init__(pack_hook, unpack_hook)


class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
    def __init__(self, frame):
        def pack_hook(_unused_x):
            # See Rule 4 above
            holder = _Holder()
            frame.weak_holders.append(weakref.ref(holder))
            return holder

        def unpack_hook(holder):
            gid = torch._C._current_graph_task_id()
            if gid == -1:
                # generate a temporary id if we trigger unpack outside of a backward call
                gid = int(uuid.uuid4())

            if not frame.is_recomputed[gid]:
                ctx = frame.input_saver.grad_fn
                args = ctx.get_args(ctx.saved_tensors)

                try:
                    with _recomputation_hook(
                        weakref.ref(frame), gid
                    ), torch.autograd.enable_grad():
                        frame.recompute_fn(*args)
                        if frame.early_stop:
                            raise AssertionError(
                                "if early stop is enabled, we don't expect to reach here"
                            )
                except _StopRecomputationError:
                    pass
                frame.is_recomputed[gid] = True

            if holder.handles[gid] is None:
                raise RuntimeError(
                    "torch.utils.checkpoint: unpack is being triggered for a tensor that was either "
                    "never recomputed, or already unpacked once. If you are calling ctx.saved_tensors "
                    "in backward, make sure to do so only once. Otherwise please open an issue with "
                    "details on your use case."
                )
            _internal_assert(holder.handles[gid] in frame.recomputed[gid])
            ret = frame.recomputed[gid][holder.handles[gid]]
            holder.handles[gid] = None
            return ret

        super().__init__(pack_hook, unpack_hook)


# NB: this helper wraps fn before calling checkpoint_impl. kwargs and
# saving/restoring of global state is handled here.
def _checkpoint_without_reentrant(
    fn,
    preserve_rng_state=True,
    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
    *args,
    **kwargs
):
    """Checkpointing without re-entrant autograd

    Args:
        fn: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if the user passes
            ``(activation, hidden)``, :attr:`fn` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): if ``False``, omit stashing and
            restoring the RNG state during each checkpoint.
            Default: ``True``
        context_fn(Callable, optional): A callable returning a tuple of two
            context managers. The function and its recomputation will be run
            under the first and second context managers respectively.
        *args: Arguments to pass in to the given ``fn``.
        **kwargs: Keyword arguments to pass into the given ``fn``.
    """
    device = _infer_device_type(*args)
    device_module = _get_device_module(device)
    forward_context, recompute_context = context_fn()
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    device_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs(device=device)

    if preserve_rng_state:
        fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
        # we have no way to anticipate this will happen before we run the function.
        # If they do so, we raise an error.)
        had_device_in_fwd = False
        if device_module._initialized:
            had_device_in_fwd = True
            fwd_devices, fwd_device_states = get_device_states(*args)

    def recompute_fn(*inputs):
        kwargs, *args = inputs
        # This will be called later during recomputation. This wrapping enables
        # the necessary global state to be captured.
        rng_devices = []
        if preserve_rng_state and had_device_in_fwd:
            rng_devices = fwd_devices
        with torch.random.fork_rng(
            devices=rng_devices, enabled=preserve_rng_state, device_type=device
        ):
            if preserve_rng_state:
                torch.set_rng_state(fwd_cpu_state)
                if had_device_in_fwd:
                    set_device_states(fwd_devices, fwd_device_states)

            with device_module.amp.autocast(
                **device_autocast_kwargs
            ), torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:
                fn(*args, **kwargs)

    new_frame = _CheckpointFrame(recompute_fn, _enable_checkpoint_early_stop)
    dummy = torch.empty((0,), requires_grad=True)
    new_frame.input_saver = _NoopSaveInputs.apply(dummy, kwargs, *args)

    # When ambient grad_mode is False
    if new_frame.input_saver.grad_fn is None:
        return fn(*args, **kwargs)

    with _checkpoint_hook(new_frame), forward_context:
        ret = fn(*args, **kwargs)

    if device_module._initialized and preserve_rng_state and not had_device_in_fwd:
        # Device was not initialized before running the forward, so we didn't
        # stash the device state.
        raise RuntimeError(
            "PyTorch's device state was initialized in the forward pass "
            "of a Checkpoint, which is not allowed. Please open an issue "
            "if you need this feature."
        )

    return ret