# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from contextlib import contextmanager, nullcontext
from typing import Any, ContextManager, Dict, Optional, Tuple

import torch
import torch.nn as nn
from torch.utils.checkpoint import (
    _checkpoint_without_reentrant_generator,
    _DEFAULT_DETERMINISM_MODE,
)

from .contract import contract


@contextmanager
def _no_hook(module: nn.Module, user_ctx: Optional[ContextManager] = None):
    r"""
    Disable hooks installed by checkpoint to avoid unintentional recursion
    during backward recomputation.
    """
    with user_ctx if user_ctx else nullcontext():
        orig_enable_hook = checkpoint.state(module).enable_hook
        checkpoint.state(module).enable_hook = False
        try:
            yield
        finally:
            checkpoint.state(module).enable_hook = orig_enable_hook
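

# NOTE: The ``@contract()`` decorator below attaches a per-module state
# accessor to ``checkpoint``: ``checkpoint.state(module)`` is where
# ``enable_hook`` and the recomputation generator are stored. During the
# backward pass, the non-reentrant checkpoint re-runs the module's forward
# to recompute activations; without ``_no_hook`` flipping ``enable_hook``
# off, that inner forward would fire ``forward_pre_hook`` again and
# recursively start a new checkpoint.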


@contract()
def checkpoint(module: nn.Module, **kwargs) -> nn.Module:
    r"""
    This is a composable activation checkpointing API. Unlike functional
    activation checkpointing APIs, this one does not require changing model
    source code. Unlike ``nn.Module`` wrapper activation checkpointing APIs,
    this one does not modify model structure or fully-qualified names either.
    Under the hood, it registers activation checkpointing logic as pre- and
    post-forward hooks. Hence, this API can be easily applied to any model or
    sub-modules in the model.

    Args:
        module (nn.Module): the target model or sub-module to apply activation
            checkpointing.

    Example::
        >>> # xdoctest: +SKIP
        >>> import torch.nn as nn
        >>>
        >>> class MyModel(nn.Module):
        >>>     def __init__(self) -> None:
        >>>         super().__init__()
        >>>         self.l1 = nn.Linear(10, 10)
        >>>         self.l2 = nn.Linear(10, 10)
        >>>
        >>>     def forward(self, x):
        >>>         return self.l2(self.l1(x))
        >>>
        >>> model = MyModel()
        >>> checkpoint(model.l1)  # apply activation checkpointing only to l1
        >>> model(torch.zeros(2, 10)).sum().backward()

    """
    torch._C._log_api_usage_once("torch.distributed.checkpoint")

    use_reentrant = kwargs.pop("use_reentrant", False)
    if use_reentrant:
        raise NotImplementedError(
            "use_reentrant=True is not supported in composable checkpoint. "
            "Please use torch.utils.checkpoint.checkpoint instead."
        )
    preserve_rng_state = kwargs.pop("preserve_rng_state", True)
    user_context_fns = kwargs.pop("context_fn", None)
    determinism_check = kwargs.pop("determinism_check", _DEFAULT_DETERMINISM_MODE)
    debug = kwargs.pop("debug", False)

    if kwargs:
        raise ValueError(
            "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
        )
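
    # The two hooks below drive the generator protocol of the non-reentrant
    # checkpoint: ``forward_pre_hook`` creates the generator and advances it
    # to its single yield (which installs the saved-tensor pack/unpack
    # hooks), the wrapped forward then runs under those hooks, and
    # ``forward_hook`` advances the generator once more and expects it to be
    # exhausted.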

    def forward_pre_hook(
        module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> None:
        if checkpoint.state(module).enable_hook:

            def context_fns():
                # Follow torch.utils.checkpoint's ``context_fn`` convention:
                # the first context manager wraps the original forward and
                # the second wraps the recomputation. Chain the second with
                # ``_no_hook`` so recomputation runs with hooks disabled.
                if user_context_fns is not None:
                    ctx1, ctx2 = user_context_fns()
                    return ctx1, _no_hook(module, ctx2)
                else:
                    return nullcontext(), _no_hook(module)

            checkpoint.state(
                module
            )._ac_generator = _checkpoint_without_reentrant_generator(
                module,
                preserve_rng_state,
                context_fns,
                determinism_check,
                debug,
                *args,
                **kwargs,
            )
            # Advance to the generator's yield; saved-tensor hooks are now
            # installed for the upcoming forward.
            next(checkpoint.state(module)._ac_generator)

    def forward_hook(module: nn.Module, inputs: Tuple[Any, ...], output: Any) -> Any:
        if checkpoint.state(module).enable_hook:
            try:
                # The generator yields exactly once, so this second ``next``
                # must raise StopIteration.
                next(checkpoint.state(module)._ac_generator)
            except StopIteration:
                pass
            else:
                raise RuntimeError(
                    "Expected non-reentrant activation checkpoint generator to be exhausted, but it was not!"
                )

        # Ensure that we no longer hold on to the generator. always_call=True
        # ensures this runs even if the forward pass raises an exception.
        checkpoint.state(module)._ac_generator = None

    # ``with_kwargs=True`` forwards the call's keyword arguments to the
    # pre-hook; ``prepend=True`` and ``always_call=True`` make the post-hook
    # run first and run even when forward raises, so the generator cleanup
    # above is never skipped.
    checkpoint.state(module).enable_hook = True
    module.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
    module.register_forward_hook(forward_hook, prepend=True, always_call=True)
    return module
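

# A minimal usage sketch. The demo model, shapes, and the pass-through
# ``context_fn`` below are illustrative assumptions, not part of this API;
# because of the relative import above, this only runs with package context
# (e.g. ``python -m torch.distributed._composable.checkpoint_activation``).
if __name__ == "__main__":

    class _Demo(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.l1 = nn.Linear(10, 10)
            self.l2 = nn.Linear(10, 10)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.l2(self.l1(x))

    model = _Demo()
    # Checkpoint only l1: its activations are dropped after forward and
    # recomputed during backward. The kwargs mirror the ones popped above.
    checkpoint(
        model.l1,
        preserve_rng_state=True,
        context_fn=lambda: (nullcontext(), nullcontext()),
        determinism_check=_DEFAULT_DETERMINISM_MODE,
        debug=False,
    )
    model(torch.zeros(2, 10)).sum().backward()
    assert model.l1.weight.grad is not None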