Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Autoformat torch/utils/checkpoint (#101649)
Per title.

Differential Revision: [D45933467](https://our.internmc.facebook.com/intern/diff/D45933467/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101649
Approved by: https://github.com/Skylion007, https://github.com/soulitzer
committed by PyTorch MergeBot
parent d7f6bfe651
commit 60547fcbee
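For context, torch/utils/checkpoint.py implements activation checkpointing: rather than keeping intermediate activations alive for backward, the checkpointed function is re-run during the backward pass. A minimal usage sketch of the two public entry points touched below (the toy model and sizes are illustrative only, not part of this commit):

>>> import torch
>>> import torch.nn as nn
>>> from torch.utils.checkpoint import checkpoint, checkpoint_sequential
>>> model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
>>> input_var = torch.randn(8, 64, requires_grad=True)
>>> # Run the Sequential in 2 segments; activations inside each segment are
>>> # recomputed during backward instead of being stored.
>>> out = checkpoint_sequential(model, 2, input_var)
>>> out.sum().backward()
>>> # checkpoint() wraps an arbitrary callable; use_reentrant selects between
>>> # the autograd.Function path and the non-reentrant path in this file.
>>> out = checkpoint(model, input_var, use_reentrant=False)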
@@ -1,19 +1,37 @@
-import torch
 import contextlib
 import uuid
 import warnings
 import weakref
-from weakref import ReferenceType
-from typing import Any, Callable, ContextManager, Iterable, List, Tuple, Dict, Optional, DefaultDict
 from collections import defaultdict
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    DefaultDict,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+)
+from weakref import ReferenceType
+
+import torch
+
 __all__ = [
-    "checkpoint", "checkpoint_sequential", "CheckpointFunction",
-    "check_backward_validity", "detach_variable", "get_device_states",
-    "set_device_states", "noop_context_fn", "set_checkpoint_early_stop",
-    "DefaultDeviceType"
+    "checkpoint",
+    "checkpoint_sequential",
+    "CheckpointFunction",
+    "check_backward_validity",
+    "detach_variable",
+    "get_device_states",
+    "set_device_states",
+    "noop_context_fn",
+    "set_checkpoint_early_stop",
+    "DefaultDeviceType",
 ]


 def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
     if isinstance(inputs, tuple):
         out = []
@@ -28,18 +46,23 @@ def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
         return tuple(out)
     else:
         raise RuntimeError(
-            "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)
+            "Only tuple of tensors is supported. Got Unsupported input type: ",
+            type(inputs).__name__,
+        )

+
 def check_backward_validity(inputs: Iterable[Any]) -> None:
     if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
-        warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
+        warnings.warn(
+            "None of the inputs have requires_grad=True. Gradients will be None"
+        )


-def _get_device_module(device='cuda'):
+def _get_device_module(device="cuda"):
     device_module = getattr(torch, device)
     return device_module


 class DefaultDeviceType(object):
     r"""
     A class that manages the default device type for checkpointing.
@@ -70,15 +93,23 @@ class DefaultDeviceType(object):
         """
         return DefaultDeviceType._default_device_type

+
 def _infer_device_type(*args):
-    device_types = list({arg.device.type for arg in args
-                         if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu"})
+    device_types = list(
+        {
+            arg.device.type
+            for arg in args
+            if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu"
+        }
+    )
     if len(device_types) > 1:
-        warnings.warn("Tensor arguments, excluding CPU tensors, are detected on at least two types of devices. "
-                      "Device state will only be saved for devices of a single device type, and the remaining "
-                      "devices will be ignored. Consequently, if any checkpointed functions involve randomness, "
-                      "this may result in incorrect gradients. (Note that if CUDA devices are among the devices "
-                      "detected, it will be prioritized; otherwise, the first device encountered will be selected.)")
+        warnings.warn(
+            "Tensor arguments, excluding CPU tensors, are detected on at least two types of devices. "
+            "Device state will only be saved for devices of a single device type, and the remaining "
+            "devices will be ignored. Consequently, if any checkpointed functions involve randomness, "
+            "this may result in incorrect gradients. (Note that if CUDA devices are among the devices "
+            "detected, it will be prioritized; otherwise, the first device encountered will be selected.)"
+        )
     if len(device_types) == 0:
         return DefaultDeviceType.get_device_type()
     elif "cuda" in device_types:
@@ -86,6 +117,7 @@ def _infer_device_type(*args):
     else:
         return device_types[0]

+
 # We can't know if the run_fn will internally move some args to different devices,
 # which would require logic to preserve rng states for those devices as well.
 # We could paranoically stash and restore ALL the rng states for all visible devices,
@@ -96,8 +128,13 @@ def _infer_device_type(*args):
 def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
     # This will not error out if "arg" is a CPU tensor or a non-tensor type because
     # the conditionals short-circuit.
-    fwd_device_ids = list({arg.get_device() for arg in args
-                           if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu"})
+    fwd_device_ids = list(
+        {
+            arg.get_device()
+            for arg in args
+            if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu"
+        }
+    )

     fwd_device_states = []
     device_module = _get_device_module(_infer_device_type(*args))
@@ -119,23 +156,29 @@ def set_device_states(devices, states) -> None:
 def _get_autocast_kwargs(device="cuda"):
-
     if device == "cuda":
-        device_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
-                                  "dtype": torch.get_autocast_gpu_dtype(),
-                                  "cache_enabled": torch.is_autocast_cache_enabled()}
+        device_autocast_kwargs = {
+            "enabled": torch.is_autocast_enabled(),
+            "dtype": torch.get_autocast_gpu_dtype(),
+            "cache_enabled": torch.is_autocast_cache_enabled(),
+        }
     else:
         device_module = _get_device_module(device)
-        device_autocast_kwargs = {"enabled": device_module.is_autocast_enabled(),
-                                  "dtype": device_module.get_autocast_dtype(),
-                                  "cache_enabled": torch.is_autocast_cache_enabled()}
+        device_autocast_kwargs = {
+            "enabled": device_module.is_autocast_enabled(),
+            "dtype": device_module.get_autocast_dtype(),
+            "cache_enabled": torch.is_autocast_cache_enabled(),
+        }

-    cpu_autocast_kwargs = {"enabled": torch.is_autocast_cpu_enabled(),
-                           "dtype": torch.get_autocast_cpu_dtype(),
-                           "cache_enabled": torch.is_autocast_cache_enabled()}
+    cpu_autocast_kwargs = {
+        "enabled": torch.is_autocast_cpu_enabled(),
+        "dtype": torch.get_autocast_cpu_dtype(),
+        "cache_enabled": torch.is_autocast_cache_enabled(),
+    }

     return device_autocast_kwargs, cpu_autocast_kwargs

-class CheckpointFunction(torch.autograd.Function):
+
+class CheckpointFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, run_function, preserve_rng_state, *args):
         check_backward_validity(args)
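Context for the hunk above (behavior is unchanged by the reformat): the autocast state captured by _get_autocast_kwargs is replayed when the checkpointed function is re-run during backward, so recomputed activations match the original forward. A hedged sketch of that interaction (run_fn and x are placeholders, and a CUDA device is assumed):

>>> with torch.autocast("cuda", dtype=torch.float16):
...     out = checkpoint(run_fn, x, use_reentrant=True)
>>> # During .backward(), run_fn is re-executed under the same
>>> # enabled/dtype/cache_enabled settings stashed at forward time.
>>> out.sum().backward()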
@@ -143,7 +186,9 @@ class CheckpointFunction(torch.autograd.Function):
         ctx.preserve_rng_state = preserve_rng_state
         # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
         ctx.device = _infer_device_type(*args)
-        ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(ctx.device)
+        ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(
+            ctx.device
+        )
         if preserve_rng_state:
             ctx.fwd_cpu_state = torch.get_rng_state()
             # Don't eagerly initialize the cuda context by accident.
@@ -181,7 +226,8 @@ class CheckpointFunction(torch.autograd.Function):
             raise RuntimeError(
                 "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                 " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
-                " argument.")
+                " argument."
+            )
         # Copy the list to avoid modifying original list.
         inputs = list(ctx.inputs)
         tensor_indices = ctx.tensor_indices
@@ -198,15 +244,17 @@ class CheckpointFunction(torch.autograd.Function):
         rng_devices = []
         if ctx.preserve_rng_state and ctx.had_device_in_fwd:
             rng_devices = ctx.fwd_devices
-        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device):
+        with torch.random.fork_rng(
+            devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device
+        ):
             if ctx.preserve_rng_state:
                 torch.set_rng_state(ctx.fwd_cpu_state)
                 if ctx.had_device_in_fwd:
                     set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
             detached_inputs = detach_variable(tuple(inputs))
-            with torch.enable_grad(), \
-                 device_module.amp.autocast(**ctx.device_autocast_kwargs), \
-                 torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
+            with torch.enable_grad(), device_module.amp.autocast(
+                **ctx.device_autocast_kwargs
+            ), torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
                 outputs = ctx.run_function(*detached_inputs)

         if isinstance(outputs, torch.Tensor):
@@ -222,10 +270,13 @@ class CheckpointFunction(torch.autograd.Function):
         if len(outputs_with_grad) == 0:
             raise RuntimeError(
                 "none of output has requires_grad=True,"
-                " this checkpoint() is not necessary")
+                " this checkpoint() is not necessary"
+            )
         torch.autograd.backward(outputs_with_grad, args_with_grad)
-        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
-                      for inp in detached_inputs)
+        grads = tuple(
+            inp.grad if isinstance(inp, torch.Tensor) else None
+            for inp in detached_inputs
+        )

         return (None, None) + grads
@@ -343,16 +394,21 @@ def checkpoint(
             "will be updated to be False in the future. To maintain current "
             "behavior, pass use_reentrant=True. It is recommended that you use "
             "use_reentrant=False. Refer to docs for more details on the "
-            "differences between the two variants.")
+            "differences between the two variants."
+        )
         use_reentrant = True
     # Hack to mix *args with **kwargs in a python 2.7-compliant way
-    preserve = kwargs.pop('preserve_rng_state', True)
+    preserve = kwargs.pop("preserve_rng_state", True)
     if kwargs and use_reentrant:
-        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
+        raise ValueError(
+            "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
+        )

     if use_reentrant:
         if context_fn is not noop_context_fn:
-            raise ValueError("Passing context_fn is only supported when use_reentrant=False.")
+            raise ValueError(
+                "Passing context_fn is only supported when use_reentrant=False."
+            )
         return CheckpointFunction.apply(function, preserve, *args)
     else:
         return _checkpoint_without_reentrant(
@@ -409,15 +465,18 @@ def checkpoint_sequential(functions, segments, input, use_reentrant=True, **kwargs):
         >>> input_var = checkpoint_sequential(model, chunks, input_var)
     """
     # Hack for keyword-only parameter in a python 2.7-compliant way
-    preserve = kwargs.pop('preserve_rng_state', True)
+    preserve = kwargs.pop("preserve_rng_state", True)
     if kwargs:
-        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
+        raise ValueError(
+            "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
+        )

     def run_function(start, end, functions):
         def forward(input):
             for j in range(start, end + 1):
                 input = functions[j](input)
             return input
+
         return forward

     if isinstance(functions, torch.nn.Sequential):
@@ -432,10 +491,11 @@ def checkpoint_sequential(functions, segments, input, use_reentrant=True, **kwargs):
             run_function(start, end, functions),
             input,
             use_reentrant=use_reentrant,
-            preserve_rng_state=preserve
+            preserve_rng_state=preserve,
         )
     return run_function(end + 1, len(functions) - 1, functions)(input)

+
 def _internal_assert(cond):
     if not cond:
         raise AssertionError(
@@ -443,6 +503,7 @@ def _internal_assert(cond):
             "Please report this bug by filing an issue to PyTorch."
         )

+
 # NOTE [ Nestable Checkpoint ]
 #
 # The semantics of nested checkpoint can be defined by two basic rules.
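The NOTE above is cut off by this hunk's context window, so the two rules themselves are not visible in the diff. Purely to illustrate what "nested" means here, a sketch of one checkpoint call inside another (the function names are made up):

>>> def inner(x):
...     return x.sin().exp()
>>> def outer(x):
...     return checkpoint(inner, x.cos(), use_reentrant=False)
>>> y = checkpoint(outer, torch.randn(4, requires_grad=True), use_reentrant=False)
>>> y.sum().backward()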
@@ -576,6 +637,7 @@ def _internal_assert(cond):

 _enable_checkpoint_early_stop = True

+
 @contextlib.contextmanager
 def set_checkpoint_early_stop(enable: bool):
     """Context manager that sets whether checkpoint should stop recomputation
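The docstring continues past this hunk. As a usage sketch, assuming early stopping defaults to enabled per the `_enable_checkpoint_early_stop = True` context line above (fn and inputs are placeholders):

>>> from torch.utils.checkpoint import checkpoint, set_checkpoint_early_stop
>>> with set_checkpoint_early_stop(False):
...     out = checkpoint(fn, inputs, use_reentrant=False)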
@@ -607,13 +669,16 @@ def set_checkpoint_early_stop(enable: bool):
     finally:
         _enable_checkpoint_early_stop = prev

-class _Handle():
+
+class _Handle:
     pass

-class _Holder():
+
+class _Holder:
     def __init__(self):
         self.handles: Dict[int, Optional[_Handle]] = dict()

+
 class _NoopSaveInputs(torch.autograd.Function):
     @staticmethod
     def forward(*args):
@@ -623,7 +688,9 @@ class _NoopSaveInputs(torch.autograd.Function):
     def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
         # Only tensors can be saved with ctx.save_for_backward, everything else
         # is captured by get_args, which is saved directly on ctx
-        tensor_indices, tensors = zip(*[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)])
+        tensor_indices, tensors = zip(
+            *[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)]
+        )
         idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)}
         # args but with tensors replaced with None as placeholders
         args = [None if isinstance(o, torch.Tensor) else o for o in inputs]
@@ -633,7 +700,10 @@ class _NoopSaveInputs(torch.autograd.Function):
             # ctx.saved_tensors (which may be saved on a parent checkpoint if
             # this checkpoint is nested, and that would trigger a recursive
             # unpack!)
-            ret = [saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o for i, o in enumerate(args)]
+            ret = [
+                saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o
+                for i, o in enumerate(args)
+            ]
             # grab the tail since we also saved the dummy to avoid having to explicitly
             # handle the case where there are no tensor inputs
             return ret[1:]
@@ -645,7 +715,8 @@ class _NoopSaveInputs(torch.autograd.Function):
     def backward(ctx, *grad_outputs):
         raise AssertionError("Did not expect to backward on this graph")

-class _CheckpointFrame():
+
+class _CheckpointFrame:
     def __init__(self, recompute_fn, early_stop):
         self.recompute_fn = recompute_fn
         self.input_saver = None
@@ -653,8 +724,9 @@ class _CheckpointFrame():
         # We store this as a weakkeydictionary so that in the case of a partial
         # backward, the entries in the dict are cleared alongside the Holder
         # which will be removed when the SavedVariable is cleared.
-        self.recomputed: DefaultDict[int, weakref.WeakKeyDictionary[_Handle, torch.Tensor]] = \
-            defaultdict(weakref.WeakKeyDictionary)
+        self.recomputed: DefaultDict[
+            int, weakref.WeakKeyDictionary[_Handle, torch.Tensor]
+        ] = defaultdict(weakref.WeakKeyDictionary)
         # We need both recomp_counter and recomputed since they can diverge
         # https://github.com/pytorch/pytorch/pull/90105#discussion_r1135889885
         self.recomp_counter: DefaultDict[int, int] = defaultdict(int)
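For readers unfamiliar with the pattern in this hunk: keying recomputed tensors by a _Handle inside a WeakKeyDictionary means each entry is dropped as soon as its handle is collected, which is how a partial backward releases recomputed tensors early. A standalone sketch of the mechanism (plain Python, independent of this file; immediate collection assumes CPython refcounting):

>>> import weakref
>>> from collections import defaultdict
>>> class Handle:
...     pass
>>> recomputed = defaultdict(weakref.WeakKeyDictionary)
>>> h = Handle()
>>> recomputed[0][h] = "stand-in for a recomputed tensor"
>>> len(recomputed[0])
1
>>> del h  # the entry disappears along with its key
>>> len(recomputed[0])
0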
@@ -663,10 +735,12 @@ class _CheckpointFrame():
         # See Rule 5
         self.early_stop = early_stop

+
 # See Rule 5
 class _StopRecomputationError(Exception):
     pass

+
 class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
     def __init__(self, target_frame_ref: ReferenceType, gid: int):
         def pack_hook(x):
@@ -688,8 +762,9 @@ class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
                 holder.handles[gid] = _Handle()
             target_frame.recomputed[gid][holder.handles[gid]] = x.detach()

-            if target_frame.early_stop and \
-                    target_frame.recomp_counter[gid] == len(target_frame.weak_holders):
+            if target_frame.early_stop and target_frame.recomp_counter[gid] == len(
+                target_frame.weak_holders
+            ):
                 raise _StopRecomputationError()
             # See Rule 6: [ retain_graph is True ] above
             return x.detach()
@@ -701,6 +776,7 @@ class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):

         super().__init__(pack_hook, unpack_hook)

+
 class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
     def __init__(self, frame):
         def pack_hook(_unused_x):
@@ -720,10 +796,14 @@ class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
                 args = ctx.get_args(ctx.saved_tensors)

                 try:
-                    with _recomputation_hook(weakref.ref(frame), gid), torch.autograd.enable_grad():
+                    with _recomputation_hook(
+                        weakref.ref(frame), gid
+                    ), torch.autograd.enable_grad():
                         frame.recompute_fn(*args)
                         if frame.early_stop:
-                            raise AssertionError("if early stop is enabled, we don't expect to reach here")
+                            raise AssertionError(
+                                "if early stop is enabled, we don't expect to reach here"
+                            )
                 except _StopRecomputationError:
                     pass
                 frame.is_recomputed[gid] = True
@@ -742,6 +822,7 @@ class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):

         super().__init__(pack_hook, unpack_hook)

+
 # NB: this helper wraps fn before calling checkpoint_impl. kwargs and
 # saving/restoring of global state is handled here.
 def _checkpoint_without_reentrant(
@@ -792,15 +873,17 @@ def _checkpoint_without_reentrant(
         rng_devices = []
         if preserve_rng_state and had_device_in_fwd:
             rng_devices = fwd_devices
-        with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state, device_type=device):
+        with torch.random.fork_rng(
+            devices=rng_devices, enabled=preserve_rng_state, device_type=device
+        ):
             if preserve_rng_state:
                 torch.set_rng_state(fwd_cpu_state)
                 if had_device_in_fwd:
                     set_device_states(fwd_devices, fwd_device_states)

-            with device_module.amp.autocast(**device_autocast_kwargs), \
-                 torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
-                 recompute_context:
+            with device_module.amp.autocast(
+                **device_autocast_kwargs
+            ), torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:
                 fn(*args, **kwargs)

     new_frame = _CheckpointFrame(recompute_fn, _enable_checkpoint_early_stop)
@@ -811,8 +894,7 @@ def _checkpoint_without_reentrant(
     if new_frame.input_saver.grad_fn is None:
         return fn(*args, **kwargs)

-    with _checkpoint_hook(new_frame), \
-            forward_context:
+    with _checkpoint_hook(new_frame), forward_context:
         ret = fn(*args, **kwargs)

     if device_module._initialized and preserve_rng_state and not had_device_in_fwd:
@@ -821,6 +903,7 @@ def _checkpoint_without_reentrant(
         raise RuntimeError(
             "PyTorch's device state was initialized in the forward pass "
             "of a Checkpoint, which is not allowed. Please open an issue "
-            "if you need this feature.")
+            "if you need this feature."
+        )

     return ret
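Since this commit is formatting-only, gradients should be identical before and after it. A quick smoke test one could run against either revision (toy CPU-only function, not part of the commit; the non-reentrant variant is used because, as the RuntimeError in the backward hunk above notes, the reentrant one rejects torch.autograd.grad):

>>> import torch
>>> from torch.utils.checkpoint import checkpoint
>>> def fn(x):
...     return torch.nn.functional.relu(x @ x.t()).sum()
>>> x = torch.randn(5, 5, requires_grad=True)
>>> (g_plain,) = torch.autograd.grad(fn(x), x)
>>> (g_ckpt,) = torch.autograd.grad(checkpoint(fn, x, use_reentrant=False), x)
>>> torch.allclose(g_plain, g_ckpt)
True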