Autoformat torch/utils/checkpoint (#101649)

Per title: this is an autoformat-only change to torch/utils/checkpoint.py (line wrapping, import sorting, quote style); no functional changes.

Differential Revision: [D45933467](https://our.internmc.facebook.com/intern/diff/D45933467/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101649
Approved by: https://github.com/Skylion007, https://github.com/soulitzer
Authored by Rohan Varma on 2023-05-16 18:00:10 -07:00, committed by PyTorch MergeBot
parent d7f6bfe651
commit 60547fcbee
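The commit message does not name the formatting tool. As a hedged sketch only: PyTorch's Python formatting is typically applied through lintrunner/ufmt (black plus usort-style import sorting), and the rewrapping seen in the diff below could be approximated with black's Python API roughly like this (the tool choice is an assumption, not stated in the commit; only the file path is taken from the commit title):

# Hedged sketch: approximate the rewrapping with black's API. The actual tool used
# for this commit is not stated, and the import reordering seen below would
# additionally require an import sorter such as usort.
import black

path = "torch/utils/checkpoint.py"
with open(path) as f:
    src = f.read()

# format_str applies black's line wrapping, trailing commas, and double-quote style.
formatted = black.format_str(src, mode=black.Mode())

with open(path, "w") as f:
    f.write(formatted)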


@@ -1,19 +1,37 @@
import torch
import contextlib
import uuid
import warnings
import weakref
from weakref import ReferenceType
from typing import Any, Callable, ContextManager, Iterable, List, Tuple, Dict, Optional, DefaultDict
from collections import defaultdict
import uuid
import contextlib
from typing import (
Any,
Callable,
ContextManager,
DefaultDict,
Dict,
Iterable,
List,
Optional,
Tuple,
)
from weakref import ReferenceType
import torch
__all__ = [
"checkpoint", "checkpoint_sequential", "CheckpointFunction",
"check_backward_validity", "detach_variable", "get_device_states",
"set_device_states", "noop_context_fn", "set_checkpoint_early_stop",
"DefaultDeviceType"
"checkpoint",
"checkpoint_sequential",
"CheckpointFunction",
"check_backward_validity",
"detach_variable",
"get_device_states",
"set_device_states",
"noop_context_fn",
"set_checkpoint_early_stop",
"DefaultDeviceType",
]
def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
if isinstance(inputs, tuple):
out = []
@@ -28,18 +46,23 @@ def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
return tuple(out)
else:
raise RuntimeError(
"Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)
"Only tuple of tensors is supported. Got Unsupported input type: ",
type(inputs).__name__,
)
def check_backward_validity(inputs: Iterable[Any]) -> None:
if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
warnings.warn(
"None of the inputs have requires_grad=True. Gradients will be None"
)
def _get_device_module(device='cuda'):
def _get_device_module(device="cuda"):
device_module = getattr(torch, device)
return device_module
class DefaultDeviceType(object):
r"""
A class that manages the default device type for checkpointing.
@@ -70,15 +93,23 @@ class DefaultDeviceType(object):
"""
return DefaultDeviceType._default_device_type
def _infer_device_type(*args):
device_types = list({arg.device.type for arg in args
if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu"})
device_types = list(
{
arg.device.type
for arg in args
if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu"
}
)
if len(device_types) > 1:
warnings.warn("Tensor arguments, excluding CPU tensors, are detected on at least two types of devices. "
"Device state will only be saved for devices of a single device type, and the remaining "
"devices will be ignored. Consequently, if any checkpointed functions involve randomness, "
"this may result in incorrect gradients. (Note that if CUDA devices are among the devices "
"detected, it will be prioritized; otherwise, the first device encountered will be selected.)")
warnings.warn(
"Tensor arguments, excluding CPU tensors, are detected on at least two types of devices. "
"Device state will only be saved for devices of a single device type, and the remaining "
"devices will be ignored. Consequently, if any checkpointed functions involve randomness, "
"this may result in incorrect gradients. (Note that if CUDA devices are among the devices "
"detected, it will be prioritized; otherwise, the first device encountered will be selected.)"
)
if len(device_types) == 0:
return DefaultDeviceType.get_device_type()
elif "cuda" in device_types:
@@ -86,6 +117,7 @@ def _infer_device_type(*args):
else:
return device_types[0]
# We can't know if the run_fn will internally move some args to different devices,
# which would require logic to preserve rng states for those devices as well.
# We could paranoically stash and restore ALL the rng states for all visible devices,
@@ -96,8 +128,13 @@ def _infer_device_type(*args):
def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
# This will not error out if "arg" is a CPU tensor or a non-tensor type because
# the conditionals short-circuit.
fwd_device_ids = list({arg.get_device() for arg in args
if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu"})
fwd_device_ids = list(
{
arg.get_device()
for arg in args
if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu"
}
)
fwd_device_states = []
device_module = _get_device_module(_infer_device_type(*args))
@@ -119,23 +156,29 @@ def set_device_states(devices, states) -> None:
def _get_autocast_kwargs(device="cuda"):
if device == "cuda":
device_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
"dtype": torch.get_autocast_gpu_dtype(),
"cache_enabled": torch.is_autocast_cache_enabled()}
device_autocast_kwargs = {
"enabled": torch.is_autocast_enabled(),
"dtype": torch.get_autocast_gpu_dtype(),
"cache_enabled": torch.is_autocast_cache_enabled(),
}
else:
device_module = _get_device_module(device)
device_autocast_kwargs = {"enabled": device_module.is_autocast_enabled(),
"dtype": device_module.get_autocast_dtype(),
"cache_enabled": torch.is_autocast_cache_enabled()}
device_autocast_kwargs = {
"enabled": device_module.is_autocast_enabled(),
"dtype": device_module.get_autocast_dtype(),
"cache_enabled": torch.is_autocast_cache_enabled(),
}
cpu_autocast_kwargs = {"enabled": torch.is_autocast_cpu_enabled(),
"dtype": torch.get_autocast_cpu_dtype(),
"cache_enabled": torch.is_autocast_cache_enabled()}
cpu_autocast_kwargs = {
"enabled": torch.is_autocast_cpu_enabled(),
"dtype": torch.get_autocast_cpu_dtype(),
"cache_enabled": torch.is_autocast_cache_enabled(),
}
return device_autocast_kwargs, cpu_autocast_kwargs
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, preserve_rng_state, *args):
check_backward_validity(args)
@@ -143,7 +186,9 @@ class CheckpointFunction(torch.autograd.Function):
ctx.preserve_rng_state = preserve_rng_state
# Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
ctx.device = _infer_device_type(*args)
ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(ctx.device)
ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(
ctx.device
)
if preserve_rng_state:
ctx.fwd_cpu_state = torch.get_rng_state()
# Don't eagerly initialize the cuda context by accident.
@@ -181,7 +226,8 @@ class CheckpointFunction(torch.autograd.Function):
raise RuntimeError(
"Checkpointing is not compatible with .grad() or when an `inputs` parameter"
" is passed to .backward(). Please use .backward() and do not pass its `inputs`"
" argument.")
" argument."
)
# Copy the list to avoid modifying original list.
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
@@ -198,15 +244,17 @@ class CheckpointFunction(torch.autograd.Function):
rng_devices = []
if ctx.preserve_rng_state and ctx.had_device_in_fwd:
rng_devices = ctx.fwd_devices
with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device):
with torch.random.fork_rng(
devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device
):
if ctx.preserve_rng_state:
torch.set_rng_state(ctx.fwd_cpu_state)
if ctx.had_device_in_fwd:
set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
detached_inputs = detach_variable(tuple(inputs))
with torch.enable_grad(), \
device_module.amp.autocast(**ctx.device_autocast_kwargs), \
torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
with torch.enable_grad(), device_module.amp.autocast(
**ctx.device_autocast_kwargs
), torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
outputs = ctx.run_function(*detached_inputs)
if isinstance(outputs, torch.Tensor):
@@ -222,10 +270,13 @@ class CheckpointFunction(torch.autograd.Function):
if len(outputs_with_grad) == 0:
raise RuntimeError(
"none of output has requires_grad=True,"
" this checkpoint() is not necessary")
" this checkpoint() is not necessary"
)
torch.autograd.backward(outputs_with_grad, args_with_grad)
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
for inp in detached_inputs)
grads = tuple(
inp.grad if isinstance(inp, torch.Tensor) else None
for inp in detached_inputs
)
return (None, None) + grads
@@ -343,16 +394,21 @@ def checkpoint(
"will be updated to be False in the future. To maintain current "
"behavior, pass use_reentrant=True. It is recommended that you use "
"use_reentrant=False. Refer to docs for more details on the "
"differences between the two variants.")
"differences between the two variants."
)
use_reentrant = True
# Hack to mix *args with **kwargs in a python 2.7-compliant way
preserve = kwargs.pop('preserve_rng_state', True)
preserve = kwargs.pop("preserve_rng_state", True)
if kwargs and use_reentrant:
raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
raise ValueError(
"Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
)
if use_reentrant:
if context_fn is not noop_context_fn:
raise ValueError("Passing context_fn is only supported when use_reentrant=False.")
raise ValueError(
"Passing context_fn is only supported when use_reentrant=False."
)
return CheckpointFunction.apply(function, preserve, *args)
else:
return _checkpoint_without_reentrant(
@@ -409,15 +465,18 @@ def checkpoint_sequential(functions, segments, input, use_reentrant=True, **kwar
>>> input_var = checkpoint_sequential(model, chunks, input_var)
"""
# Hack for keyword-only parameter in a python 2.7-compliant way
preserve = kwargs.pop('preserve_rng_state', True)
preserve = kwargs.pop("preserve_rng_state", True)
if kwargs:
raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
raise ValueError(
"Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
)
def run_function(start, end, functions):
def forward(input):
for j in range(start, end + 1):
input = functions[j](input)
return input
return forward
if isinstance(functions, torch.nn.Sequential):
@@ -432,10 +491,11 @@ def checkpoint_sequential(functions, segments, input, use_reentrant=True, **kwar
run_function(start, end, functions),
input,
use_reentrant=use_reentrant,
preserve_rng_state=preserve
preserve_rng_state=preserve,
)
return run_function(end + 1, len(functions) - 1, functions)(input)
def _internal_assert(cond):
if not cond:
raise AssertionError(
@@ -443,6 +503,7 @@ def _internal_assert(cond):
"Please report this bug by filing an issue to PyTorch."
)
# NOTE [ Nestable Checkpoint ]
#
# The semantics of nested checkpoint can be defined by two basic rules.
@@ -576,6 +637,7 @@ def _internal_assert(cond):
_enable_checkpoint_early_stop = True
@contextlib.contextmanager
def set_checkpoint_early_stop(enable: bool):
"""Context manager that sets whether checkpoint should stop recomputation
@@ -607,13 +669,16 @@ def set_checkpoint_early_stop(enable: bool):
finally:
_enable_checkpoint_early_stop = prev
class _Handle():
class _Handle:
pass
class _Holder():
class _Holder:
def __init__(self):
self.handles: Dict[int, Optional[_Handle]] = dict()
class _NoopSaveInputs(torch.autograd.Function):
@staticmethod
def forward(*args):
@@ -623,7 +688,9 @@ class _NoopSaveInputs(torch.autograd.Function):
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
# Only tensors can be saved with ctx.save_for_backward, everything else
# is captured by get_args, which is saved directly on ctx
tensor_indices, tensors = zip(*[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)])
tensor_indices, tensors = zip(
*[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)]
)
idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)}
# args but with tensors replaced with None as placeholders
args = [None if isinstance(o, torch.Tensor) else o for o in inputs]
@@ -633,7 +700,10 @@ class _NoopSaveInputs(torch.autograd.Function):
# ctx.saved_tensors (which may be saved on a parent checkpoint if
# this checkpoint is nested, and that would trigger a recursive
# unpack!)
ret = [saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o for i, o in enumerate(args)]
ret = [
saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o
for i, o in enumerate(args)
]
# grab the tail since we also saved the dummy to avoid having to explicitly
# handle the case where there are no tensor inputs
return ret[1:]
@@ -645,7 +715,8 @@ class _NoopSaveInputs(torch.autograd.Function):
def backward(ctx, *grad_outputs):
raise AssertionError("Did not expect to backward on this graph")
class _CheckpointFrame():
class _CheckpointFrame:
def __init__(self, recompute_fn, early_stop):
self.recompute_fn = recompute_fn
self.input_saver = None
@@ -653,8 +724,9 @@ class _CheckpointFrame():
# We store this as a weakkeydictionary so that in the case of a partial
# backward, the entries in the dict are cleared alongside the Holder
# which will be removed when the SavedVariable is cleared.
self.recomputed: DefaultDict[int, weakref.WeakKeyDictionary[_Handle, torch.Tensor]] = \
defaultdict(weakref.WeakKeyDictionary)
self.recomputed: DefaultDict[
int, weakref.WeakKeyDictionary[_Handle, torch.Tensor]
] = defaultdict(weakref.WeakKeyDictionary)
# We need both recomp_counter and recomputed since they can diverge
# https://github.com/pytorch/pytorch/pull/90105#discussion_r1135889885
self.recomp_counter: DefaultDict[int, int] = defaultdict(int)
@@ -663,10 +735,12 @@ class _CheckpointFrame():
# See Rule 5
self.early_stop = early_stop
# See Rule 5
class _StopRecomputationError(Exception):
pass
class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
def __init__(self, target_frame_ref: ReferenceType, gid: int):
def pack_hook(x):
@@ -688,8 +762,9 @@ class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
holder.handles[gid] = _Handle()
target_frame.recomputed[gid][holder.handles[gid]] = x.detach()
if target_frame.early_stop and \
target_frame.recomp_counter[gid] == len(target_frame.weak_holders):
if target_frame.early_stop and target_frame.recomp_counter[gid] == len(
target_frame.weak_holders
):
raise _StopRecomputationError()
# See Rule 6: [ retain_graph is True ] above
return x.detach()
@@ -701,6 +776,7 @@ class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
super().__init__(pack_hook, unpack_hook)
class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
def __init__(self, frame):
def pack_hook(_unused_x):
@@ -720,10 +796,14 @@ class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
args = ctx.get_args(ctx.saved_tensors)
try:
with _recomputation_hook(weakref.ref(frame), gid), torch.autograd.enable_grad():
with _recomputation_hook(
weakref.ref(frame), gid
), torch.autograd.enable_grad():
frame.recompute_fn(*args)
if frame.early_stop:
raise AssertionError("if early stop is enabled, we don't expect to reach here")
raise AssertionError(
"if early stop is enabled, we don't expect to reach here"
)
except _StopRecomputationError:
pass
frame.is_recomputed[gid] = True
@@ -742,6 +822,7 @@ class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
super().__init__(pack_hook, unpack_hook)
# NB: this helper wraps fn before calling checkpoint_impl. kwargs and
# saving/restoring of global state is handled here.
def _checkpoint_without_reentrant(
@@ -792,15 +873,17 @@ def _checkpoint_without_reentrant(
rng_devices = []
if preserve_rng_state and had_device_in_fwd:
rng_devices = fwd_devices
with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state, device_type=device):
with torch.random.fork_rng(
devices=rng_devices, enabled=preserve_rng_state, device_type=device
):
if preserve_rng_state:
torch.set_rng_state(fwd_cpu_state)
if had_device_in_fwd:
set_device_states(fwd_devices, fwd_device_states)
with device_module.amp.autocast(**device_autocast_kwargs), \
torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
recompute_context:
with device_module.amp.autocast(
**device_autocast_kwargs
), torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:
fn(*args, **kwargs)
new_frame = _CheckpointFrame(recompute_fn, _enable_checkpoint_early_stop)
@@ -811,8 +894,7 @@ def _checkpoint_without_reentrant(
if new_frame.input_saver.grad_fn is None:
return fn(*args, **kwargs)
with _checkpoint_hook(new_frame), \
forward_context:
with _checkpoint_hook(new_frame), forward_context:
ret = fn(*args, **kwargs)
if device_module._initialized and preserve_rng_state and not had_device_in_fwd:
@@ -821,6 +903,7 @@ def _checkpoint_without_reentrant(
raise RuntimeError(
"PyTorch's device state was initialized in the forward pass "
"of a Checkpoint, which is not allowed. Please open an issue "
"if you need this feature.")
"if you need this feature."
)
return ret
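Since this is a formatting-only change, the public API behaves exactly as before. For context, a minimal usage sketch of the checkpoint API defined in this file (illustrative only, not part of the diff):

import torch
from torch.utils.checkpoint import checkpoint


def block(x):
    # An arbitrary sub-graph whose activations are recomputed during backward
    # instead of being stored.
    return torch.relu(x @ x.t())


x = torch.randn(8, 8, requires_grad=True)
# Passing use_reentrant explicitly avoids the warning (visible in the diff above)
# about the default value changing in the future.
out = checkpoint(block, x, use_reentrant=False)
out.sum().backward()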