Summary: This is a grab-bag of documentation formatting fixes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9359
Differential Revision: D8831400
Pulled By: soumith
fbshipit-source-id: 8dac02303168b2ea365e23938ee528d8e8c9f9b7
import torch
import warnings

def detach_variable(inputs):
    # Detach every tensor in the tuple from the current autograd graph, but
    # preserve its requires_grad flag so that the checkpointed segment can be
    # re-run with gradient tracking in the backward pass.
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            x = inp.detach()
            x.requires_grad = inp.requires_grad
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only tuple of tensors is supported. Got unsupported input type: "
            + type(inputs).__name__)


def check_backward_validity(inputs):
    # Warn when no input requires gradients: the checkpointed part of the
    # model will then produce no gradients at all.
    if not any(inp.requires_grad for inp in inputs):
        warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")


class CheckpointFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, run_function, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.save_for_backward(*args)
        # Run the function without tracking gradients: only the inputs are
        # saved for backward, not the intermediate activations.
        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
        inputs = ctx.saved_tensors
        detached_inputs = detach_variable(inputs)
        # Re-run the forward pass with gradient tracking enabled to rebuild
        # the intermediate activations needed for this backward pass.
        with torch.enable_grad():
            outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        torch.autograd.backward(outputs, args)
        # The leading None is the gradient for the `run_function` argument.
        return (None,) + tuple(inp.grad for inp in detached_inputs)


def checkpoint(function, *args):
    r"""Checkpoint a model or part of the model

    Checkpointing works by trading compute for memory. Rather than storing all
    intermediate activations of the entire computation graph for computing
    backward, the checkpointed part does **not** save intermediate activations,
    and instead recomputes them in the backward pass. It can be applied on any
    part of a model.

    Specifically, in the forward pass, :attr:`function` will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. Instead, the forward pass saves the inputs tuple and the
    :attr:`function` parameter. In the backward pass, the saved inputs and
    :attr:`function` are retrieved, and the forward pass is computed on
    :attr:`function` again, now tracking the intermediate activations, and then
    the gradients are calculated using these activation values.

    .. warning::
        Checkpointing doesn't work with :func:`torch.autograd.grad`, but only
        with :func:`torch.autograd.backward`.

    .. warning::
        If the :attr:`function` invocation during backward does anything
        different than the one during forward, e.g., due to some global
        variable, the checkpointed version won't be equivalent, and
        unfortunately it can't be detected.

    .. warning::
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of
        the model won't have gradients.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if the user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`
    """
    return CheckpointFunction.apply(function, *args)
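

# A minimal usage sketch (illustrative only, not part of the original module)
# showing how `checkpoint` might wrap one block of a model. The names `block`
# and `x` are hypothetical; the only requirements, per the docstring above,
# are a callable and at least one input with requires_grad=True.
def _checkpoint_usage_sketch():
    block = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
    x = torch.randn(4, 16, requires_grad=True)
    # The forward pass of `block` runs under torch.no_grad(); only `x` is saved.
    y = checkpoint(block, x)
    # During backward, `block` is re-run with grad enabled to rebuild the
    # intermediate activations, and then gradients are computed as usual.
    y.sum().backward()
    return x.grad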


def checkpoint_sequential(functions, segments, *inputs):
    r"""A helper function for checkpointing sequential models.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a model into various segments
    and checkpoint each segment. All segments except the last will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. The inputs of each checkpointed segment will be saved for
    re-running the segment in the backward pass.

    See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.

    .. warning::
        Checkpointing doesn't work with :func:`torch.autograd.grad`, but only
        with :func:`torch.autograd.backward`.

    .. warning::
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of
        the model won't have gradients.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or
            functions (comprising the model) to run sequentially.
        segments: Number of chunks to create in the model
        inputs: tuple of Tensors that are inputs to :attr:`functions`

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`*inputs`

    Example:
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_sequential(model, chunks, input_var)
    """

    def run_function(start, end, functions):
        def forward(*inputs):
            input = inputs[0]
            for j in range(start, end + 1):
                input = functions[j](input)
            return input
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    segment_size = len(functions) // segments
    # all segments except the last are checkpointed in the loop below; the
    # last segment runs normally, with autograd tracking its activations
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        inputs = checkpoint(run_function(start, end, functions), *inputs)
        if not isinstance(inputs, tuple):
            inputs = (inputs,)
    return run_function(end + 1, len(functions) - 1, functions)(*inputs)
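

# A minimal usage sketch for `checkpoint_sequential` (illustrative only, not
# part of the original module). The names `model`, `num_segments`, and `x` are
# hypothetical; the pattern follows the Example in the docstring above.
def _checkpoint_sequential_usage_sketch():
    model = torch.nn.Sequential(
        torch.nn.Linear(32, 32), torch.nn.ReLU(),
        torch.nn.Linear(32, 32), torch.nn.ReLU(),
    )
    num_segments = 2  # the four layers are split into two segments
    x = torch.randn(8, 32, requires_grad=True)
    # All segments except the last run under torch.no_grad(); each segment's
    # input is saved so that segment can be re-run in the backward pass.
    out = checkpoint_sequential(model, num_segments, x)
    out.sum().backward()
    return x.grad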