Add early_stop kwarg to torch.utils.checkpoint (#160781)

We already have a context manager, "set_checkpoint_early_stop". This PR adds an "early_stop" kwarg to torch.utils.checkpoint.checkpoint that toggles the same setting.

It is also useful to have a kwarg version of the setting in addition to the context manager, because it is awkward to apply a context manager when activation checkpointing (AC) is applied via CheckpointWrapper (see the sketch below).
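
A minimal sketch of that wrapper use case, assuming checkpoint_wrapper forwards its extra `**checkpoint_fn_kwargs` on to torch.utils.checkpoint.checkpoint (the wrapper itself is not touched by this diff):

```python
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper,
)

block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))

# Previously, disabling early stop here meant entering
# set_checkpoint_early_stop(False) around every forward call; with the new
# kwarg the setting rides along with the wrapper's checkpoint kwargs.
wrapped = checkpoint_wrapper(
    block,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    early_stop=False,  # new kwarg added by this PR
)

out = wrapped(torch.randn(4, 16, requires_grad=True))
out.sum().backward()
```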

Similar to the "debug" kwarg and its corresponding "set_checkpoint_debug_enabled" context manager, the "set_checkpoint_early_stop" context manager defaults to None and overrides the local kwarg setting whenever it is set to a non-None value.
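
A short sketch of the intended precedence, mirroring the new tests below:

```python
import torch
from torch.utils.checkpoint import checkpoint, set_checkpoint_early_stop

def fn(x):
    return x.sin().cos()

a = torch.tensor(1.0, requires_grad=True)

# Kwarg alone: early stopping is disabled for this call only.
out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
out.backward()

# Context manager is non-None, so it takes precedence over the kwarg:
# early stopping stays enabled here despite early_stop=False.
with set_checkpoint_early_stop(True):
    out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
out.backward()
```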

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160781
Approved by: https://github.com/tianyu-l
soulitzer
2025-08-20 11:57:15 -07:00
committed by PyTorch MergeBot
parent 4d078cfc4e
commit 1e4dfeeb06
3 changed files with 65 additions and 6 deletions


@ -14143,13 +14143,27 @@ class TestNestedCheckpoint(TestCase):
# early stop is enabled.
return clone(x.sin().cos())
# Test default
# Early stopping is enabled by default
a = torch.tensor(1.0, requires_grad=True)
out = checkpoint(fn, a, use_reentrant=False)
out.backward()
self.assertEqual(counter[0], 1)
# Try using the context manager to set early stopping to False.
# Test local setting
counter = [0]
a = torch.tensor(1.0, requires_grad=True)
out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
out.backward()
self.assertEqual(counter[0], 2)
counter = [0]
a = torch.tensor(1.0, requires_grad=True)
out = checkpoint(fn, a, use_reentrant=False, early_stop=True)
out.backward()
self.assertEqual(counter[0], 1)
# Test context manager
# Expect early stopping to be disabled for all checkpoints run under
# the context manager, even though the context manager is no longer active
# when backward/recomputation is performed.
@ -14157,10 +14171,40 @@ class TestNestedCheckpoint(TestCase):
a = torch.tensor(1.0, requires_grad=True)
with torch.utils.checkpoint.set_checkpoint_early_stop(False):
out = checkpoint(fn, a, use_reentrant=False)
out.backward()
self.assertEqual(counter[0], 2)
counter = [0]
a = torch.tensor(1.0, requires_grad=True)
with torch.utils.checkpoint.set_checkpoint_early_stop(True):
out = checkpoint(fn, a, use_reentrant=False)
out.backward()
self.assertEqual(counter[0], 1)
# Test context manager nesting
counter = [0]
a = torch.tensor(1.0, requires_grad=True)
with torch.utils.checkpoint.set_checkpoint_early_stop(False):
with torch.utils.checkpoint.set_checkpoint_early_stop(True):
out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
out.backward()
self.assertEqual(counter[0], 1)
# Test precedence
counter = [0]
a = torch.tensor(1.0, requires_grad=True)
with torch.utils.checkpoint.set_checkpoint_early_stop(False):
out = checkpoint(fn, a, use_reentrant=False, early_stop=True)
out.backward()
self.assertEqual(counter[0], 2)
counter = [0]
a = torch.tensor(1.0, requires_grad=True)
with torch.utils.checkpoint.set_checkpoint_early_stop(True):
out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
out.backward()
self.assertEqual(counter[0], 1)
def test_nested_checkpoint_set_early_stop_no_recompution_needed(self):
# Case 1: We have one tensor saved and it's the input


@ -79,6 +79,7 @@ def checkpoint(module: nn.Module, **kwargs) -> nn.Module:
user_context_fns = kwargs.pop("context_fn", None)
determinism_check = kwargs.pop("determinism_check", _DEFAULT_DETERMINISM_MODE)
debug = kwargs.pop("debug", False)
early_stop = kwargs.pop("early_stop", True)
if kwargs:
raise ValueError(
@ -103,6 +104,7 @@ def checkpoint(module: nn.Module, **kwargs) -> nn.Module:
context_fns,
determinism_check,
debug,
early_stop,
*args,
**kwargs,
)


@ -347,6 +347,7 @@ def checkpoint(
context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
determinism_check: str = _DEFAULT_DETERMINISM_MODE,
debug: bool = False,
early_stop: bool = True,
**kwargs
):
r"""Checkpoint a model or part of the model.
@ -425,6 +426,9 @@ def checkpoint(
passed as the tuple. For example, in LSTM, if user passes
``(activation, hidden)``, :attr:`function` should correctly use the
first input as ``activation`` and the second input as ``hidden``
args: tuple containing inputs to the :attr:`function`
Keyword args:
preserve_rng_state(bool, optional): Omit stashing and restoring
the RNG state during each checkpoint. Note that under torch.compile,
this flag doesn't take effect and we always preserve RNG state.
@ -455,7 +459,11 @@ def checkpoint(
a trace of the operators run during the original forward computation
as well as the recomputation. This argument is only supported if
``use_reentrant=False``.
args: tuple containing inputs to the :attr:`function`
early_stop(bool, optional): If ``True``, non-reentrant checkpoint stops
recomputation as soon as it has computed all needed Tensors. This
argument is ignored if ``use_reentrant=True``. Can be overridden
globally using :func:`set_checkpoint_early_stop` context manager.
Default: ``True``.
Returns:
Output of running :attr:`function` on :attr:`*args`
@ -488,7 +496,7 @@ def checkpoint(
return CheckpointFunction.apply(function, preserve, *args)
else:
gen = _checkpoint_without_reentrant_generator(
function, preserve, context_fn, determinism_check, debug, *args, **kwargs
function, preserve, context_fn, determinism_check, debug, early_stop, *args, **kwargs
)
# Runs pre-forward logic
next(gen)
@ -731,7 +739,7 @@ def _internal_assert(cond):
# by holder=None. We skip over them. We still save x at (4) (since its holder
# is still alive.)
_enable_checkpoint_early_stop = True
_enable_checkpoint_early_stop: Optional[bool] = None
@contextlib.contextmanager
@ -1448,6 +1456,7 @@ def _checkpoint_without_reentrant_generator(
context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
determinism_check: str = _DEFAULT_DETERMINISM_MODE,
debug: bool = False,
early_stop: bool = True,
*args,
**kwargs
):
@ -1475,6 +1484,10 @@ def _checkpoint_without_reentrant_generator(
debug(bool, optional): If ``True``, error messages will also include
a trace of the operators run during the original forward computation
as well as the recomputation.
early_stop(bool, optional): If ``True``, non-reentrant checkpoint stops
recomputation as soon as it has computed all needed Tensors. Can be
overridden globally using :func:`set_checkpoint_early_stop` context
manager. Default: ``True``.
*args: Arguments to pass in to the given ``function``.
**kwargs: Keyword arguments to pass into the given ``function``.
"""
@ -1543,7 +1556,7 @@ def _checkpoint_without_reentrant_generator(
new_frame = _CheckpointFrame(
recompute_fn,
_enable_checkpoint_early_stop,
_enable_checkpoint_early_stop if _enable_checkpoint_early_stop is not None else early_stop,
unpack_error_cb,
metadata_fn
)