We already have a context manager, "set_checkpoint_early_stop". This PR adds an "early_stop" kwarg that toggles the same setting. A kwarg version is useful in addition to the context manager because it is awkward to apply a context manager when the AC is being applied via CheckpointWrapper. As with the "debug" kwarg and its corresponding "set_checkpoint_debug_enabled" context manager, the context manager defaults to None and overrides the local kwarg setting when non-None.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160781
Approved by: https://github.com/tianyu-l
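A minimal sketch of the two ways to toggle the setting, based on the description above (it assumes the composable checkpoint is importable as torch.distributed._composable.checkpoint; the model and call sites are illustrative, not part of this PR):

    import torch
    import torch.nn as nn
    from torch.distributed._composable import checkpoint
    from torch.utils.checkpoint import set_checkpoint_early_stop

    model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
    x = torch.zeros(2, 10)

    # Kwarg form: convenient when checkpointing is applied at wrap time
    # (e.g. via CheckpointWrapper) and the forward/backward call sites are
    # not easy to wrap in a context manager.
    checkpoint(model[0], early_stop=False)
    model(x).sum().backward()

    # Context-manager form: a non-None value overrides the local kwarg
    # setting for checkpointed regions whose forward runs under it.
    with set_checkpoint_early_stop(True):
        model(x).sum().backward()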
135 lines
4.7 KiB
Python
# mypy: allow-untyped-defs
from collections.abc import Generator
from contextlib import AbstractContextManager, contextmanager, nullcontext
from typing import Any, Optional

import torch
import torch.nn as nn
from torch.utils.checkpoint import (
    _checkpoint_without_reentrant_generator,
    _DEFAULT_DETERMINISM_MODE,
)

from .contract import _State, contract


@contextmanager
def _no_hook(module: nn.Module, user_ctx: Optional[AbstractContextManager] = None):
    r"""
    Disable hooks installed by checkpoint to avoid unintentional recursion
    during backward recomputation.
    """
    with user_ctx if user_ctx else nullcontext():
        orig_enable_hook = checkpoint.state(module).enable_hook
        checkpoint.state(module).enable_hook = False
        try:
            yield
        finally:
            checkpoint.state(module).enable_hook = orig_enable_hook


class _CheckpointState(_State):
    enable_hook: bool = False
    _ac_generator: Optional[Generator[None, None, None]]


@contract(_CheckpointState)
def checkpoint(module: nn.Module, **kwargs) -> nn.Module:
    r"""
    This is a composable activation checkpointing API. Unlike functional
    activation checkpointing APIs, this one does not require changing model
    source code. Unlike ``nn.Module`` wrapper activation checkpointing APIs,
    this one does not modify model structure or fully-qualified names either.
    Under the hood, it registers activation checkpointing logic as pre- and
    post-forward hooks. Hence, this API can be easily applied to any model or
    sub-modules in the model.

    Args:
        module (nn.Module): the target model or sub-module to apply activation
            checkpointing.

    Example::
        >>> # xdoctest: +SKIP
        >>> import torch.nn as nn
        >>>
        >>> class MyModel(nn.Module):
        >>>     def __init__(self) -> None:
        >>>         super().__init__()
        >>>         self.l1 = nn.Linear(10, 10)
        >>>         self.l2 = nn.Linear(10, 10)
        >>>
        >>>     def forward(self, x):
        >>>         return self.l2(self.l1(x))
        >>>
        >>> model = MyModel()
        >>> checkpoint(model.l1)  # apply activation checkpointing only to l1
        >>> model(torch.zeros(2, 10)).sum().backward()

    """
    torch._C._log_api_usage_once("torch.distributed.checkpoint")

    use_reentrant = kwargs.pop("use_reentrant", False)
    if use_reentrant:
        raise NotImplementedError(
            "use_reentrant=True is not supported in composable checkpoint. "
            "Please use torch.utils.checkpoint.checkpoint instead."
        )
    preserve_rng_state = kwargs.pop("preserve_rng_state", True)
    user_context_fns = kwargs.pop("context_fn", None)
    determinism_check = kwargs.pop("determinism_check", _DEFAULT_DETERMINISM_MODE)
    debug = kwargs.pop("debug", False)
    early_stop = kwargs.pop("early_stop", True)

    if kwargs:
        raise ValueError(
            "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
        )

    def forward_pre_hook(
        module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> None:
        if checkpoint.state(module).enable_hook:

            def context_fns():
                if user_context_fns is not None:
                    ctx1, ctx2 = user_context_fns()
                    return ctx1, _no_hook(module, ctx2)
                else:
                    return nullcontext(), _no_hook(module)

            gen = _checkpoint_without_reentrant_generator(
                module,
                preserve_rng_state,
                context_fns,
                determinism_check,
                debug,
                early_stop,
                *args,
                **kwargs,
            )
            checkpoint.state(module)._ac_generator = gen
            next(gen)

    def forward_hook(module: nn.Module, inputs: tuple[Any, ...], output: Any) -> Any:
        if checkpoint.state(module).enable_hook:
            try:
                gen = checkpoint.state(module)._ac_generator
                assert gen is not None
                next(gen)
            except StopIteration:
                pass
            else:
                raise RuntimeError(
                    "Expected non-reentrant activation checkpoint generator to be exhausted, but it was not!"
                )

        # Ensure that we no longer hold on to the generator. always_call=True helps ensure we
        # clear this even in the case of exception in fwd pass.
        checkpoint.state(module)._ac_generator = None

    checkpoint.state(module).enable_hook = True
    module.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
    module.register_forward_hook(forward_hook, prepend=True, always_call=True)
    return module
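As a quick end-to-end sanity check (a sketch, not part of the file above; it assumes the import path shown), gradients under composable checkpointing should match those of an identical model run without it:

    import copy

    import torch
    import torch.nn as nn
    from torch.distributed._composable import checkpoint

    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
    ref = copy.deepcopy(model)

    checkpoint(model[0], early_stop=False)  # exercise the new kwarg

    x = torch.randn(2, 10)
    model(x).sum().backward()
    ref(x).sum().backward()

    # Recomputation in backward should reproduce the same gradients.
    for p, q in zip(model.parameters(), ref.parameters()):
        torch.testing.assert_close(p.grad, q.grad)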