From 1e4dfeeb069b560b5cf920cfea9b815f01f4248d Mon Sep 17 00:00:00 2001
From: soulitzer
Date: Wed, 20 Aug 2025 11:57:15 -0700
Subject: [PATCH] Add early_stop kwarg to torch.utils.checkpoint (#160781)

We already have a context manager "set_checkpoint_early_stop". This PR adds
a kwarg that toggles the same setting.

It is also useful to have a kwarg version of the setting in addition to the
context manager because it is annoying to apply a context manager when the
AC is being applied via CheckpointWrapper.

Similar to the "debug" kwarg and the corresponding
"set_checkpoint_debug_enabled" context manager, the context manager defaults
to None and overrides the local setting when non-None.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160781
Approved by: https://github.com/tianyu-l
---
 test/test_autograd.py                        | 48 ++++++++++++++++++-
 .../_composable/checkpoint_activation.py     |  2 +
 torch/utils/checkpoint.py                    | 21 ++++++--
 3 files changed, 65 insertions(+), 6 deletions(-)

diff --git a/test/test_autograd.py b/test/test_autograd.py
index 9e8560c6f191..53a98276090c 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -14143,13 +14143,27 @@ class TestNestedCheckpoint(TestCase):
             # early stop is enabled.
             return clone(x.sin().cos())

+        # Test default
         # Early stopping is enabled by default
         a = torch.tensor(1.0, requires_grad=True)
         out = checkpoint(fn, a, use_reentrant=False)
         out.backward()
         self.assertEqual(counter[0], 1)

-        # Try using the context manager to set early stopping to False.
+        # Test local setting
+        counter = [0]
+        a = torch.tensor(1.0, requires_grad=True)
+        out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
+        out.backward()
+        self.assertEqual(counter[0], 2)
+
+        counter = [0]
+        a = torch.tensor(1.0, requires_grad=True)
+        out = checkpoint(fn, a, use_reentrant=False, early_stop=True)
+        out.backward()
+        self.assertEqual(counter[0], 1)
+
+        # Test context manager
         # Expect early stopping to be disabled for all checkpoints ran under
         # the context manager, even though context manager is no longer active
         # when backward/recomputation is performed.
@@ -14157,10 +14171,40 @@ class TestNestedCheckpoint(TestCase):
         a = torch.tensor(1.0, requires_grad=True)
         with torch.utils.checkpoint.set_checkpoint_early_stop(False):
             out = checkpoint(fn, a, use_reentrant=False)
-
         out.backward()
         self.assertEqual(counter[0], 2)

+        counter = [0]
+        a = torch.tensor(1.0, requires_grad=True)
+        with torch.utils.checkpoint.set_checkpoint_early_stop(True):
+            out = checkpoint(fn, a, use_reentrant=False)
+        out.backward()
+        self.assertEqual(counter[0], 1)
+
+        # Test context manager nesting
+        counter = [0]
+        a = torch.tensor(1.0, requires_grad=True)
+        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
+            with torch.utils.checkpoint.set_checkpoint_early_stop(True):
+                out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
+        out.backward()
+        self.assertEqual(counter[0], 1)
+
+        # Test precedence
+        counter = [0]
+        a = torch.tensor(1.0, requires_grad=True)
+        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
+            out = checkpoint(fn, a, use_reentrant=False, early_stop=True)
+        out.backward()
+        self.assertEqual(counter[0], 2)
+
+        counter = [0]
+        a = torch.tensor(1.0, requires_grad=True)
+        with torch.utils.checkpoint.set_checkpoint_early_stop(True):
+            out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
+        out.backward()
+        self.assertEqual(counter[0], 1)
+
     def test_nested_checkpoint_set_early_stop_no_recompution_needed(self):
         # Case 1: We have one tensor saved and its the input

diff --git a/torch/distributed/_composable/checkpoint_activation.py b/torch/distributed/_composable/checkpoint_activation.py
index 0fe23cab72c4..2d109ad56835 100644
--- a/torch/distributed/_composable/checkpoint_activation.py
+++ b/torch/distributed/_composable/checkpoint_activation.py
@@ -79,6 +79,7 @@ def checkpoint(module: nn.Module, **kwargs) -> nn.Module:
     user_context_fns = kwargs.pop("context_fn", None)
     determinism_check = kwargs.pop("determinism_check", _DEFAULT_DETERMINISM_MODE)
     debug = kwargs.pop("debug", False)
+    early_stop = kwargs.pop("early_stop", True)

     if kwargs:
         raise ValueError(
@@ -103,6 +104,7 @@ def checkpoint(module: nn.Module, **kwargs) -> nn.Module:
                 context_fns,
                 determinism_check,
                 debug,
+                early_stop,
                 *args,
                 **kwargs,
             )
diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py
index e2515d9d9268..30d2fc106f5f 100644
--- a/torch/utils/checkpoint.py
+++ b/torch/utils/checkpoint.py
@@ -347,6 +347,7 @@ def checkpoint(
     context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
     determinism_check: str = _DEFAULT_DETERMINISM_MODE,
     debug: bool = False,
+    early_stop: bool = True,
     **kwargs
 ):
     r"""Checkpoint a model or part of the model.
@@ -425,6 +426,9 @@
             passed as the tuple. For example, in LSTM, if user passes
             ``(activation, hidden)``, :attr:`function` should correctly use the
             first input as ``activation`` and the second input as ``hidden``
+        args: tuple containing inputs to the :attr:`function`
+
+    Keyword args:
         preserve_rng_state(bool, optional): Omit stashing and restoring
             the RNG state during each checkpoint. Note that under torch.compile,
             this flag doesn't take effect and we always preserve RNG state.
@@ -455,7 +459,11 @@
             a trace of the operators ran during the original forward computation
             as well as the recomputation. This argument is only supported if
             ``use_reentrant=False``.
-        args: tuple containing inputs to the :attr:`function`
+        early_stop(bool, optional): If ``True``, non-reentrant checkpoint stops
+            recomputation as soon as it has computed all needed Tensors. This
+            argument is ignored if ``use_reentrant=True``. Can be overridden
+            globally using :func:`set_checkpoint_early_stop` context manager.
+            Default: ``True``.

     Returns:
         Output of running :attr:`function` on :attr:`*args`
@@ -488,7 +496,7 @@ def checkpoint(
         return CheckpointFunction.apply(function, preserve, *args)
     else:
         gen = _checkpoint_without_reentrant_generator(
-            function, preserve, context_fn, determinism_check, debug, *args, **kwargs
+            function, preserve, context_fn, determinism_check, debug, early_stop, *args, **kwargs
         )
         # Runs pre-forward logic
         next(gen)
@@ -731,7 +739,7 @@ def _internal_assert(cond):
 # by holder=None. We skip over them. We still save x at (4) (since its holder
 # is still alive.)

-_enable_checkpoint_early_stop = True
+_enable_checkpoint_early_stop: Optional[bool] = None


 @contextlib.contextmanager
@@ -1448,6 +1456,7 @@ def _checkpoint_without_reentrant_generator(
     context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
     determinism_check: str = _DEFAULT_DETERMINISM_MODE,
     debug: bool = False,
+    early_stop: bool = True,
     *args,
     **kwargs
 ):
@@ -1475,6 +1484,10 @@ def _checkpoint_without_reentrant_generator(
         debug(bool, optional): If ``True``, error messages will also include
             a trace of the operators ran during the original forward computation
             as well as the recomputation.
+        early_stop(bool, optional): If ``True``, non-reentrant checkpoint stops
+            recomputation as soon as it has computed all needed Tensors. Can be
+            overridden globally using :func:`set_checkpoint_early_stop` context
+            manager. Default: ``True``.
         *args: Arguments to pass in to the given ``function``.
         **kwargs: Keyword arguments to pass into the given ``function``.
     """
@@ -1543,7 +1556,7 @@ def _checkpoint_without_reentrant_generator(

     new_frame = _CheckpointFrame(
         recompute_fn,
-        _enable_checkpoint_early_stop,
+        _enable_checkpoint_early_stop if _enable_checkpoint_early_stop is not None else early_stop,
         unpack_error_cb,
         metadata_fn
     )
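
Usage sketch (not part of the patch): the snippet below is adapted from the test
added above and assumes a PyTorch build that includes this change. It shows the
per-call early_stop kwarg and the precedence of a non-None
set_checkpoint_early_stop context manager over the kwarg; the counter/clone
helpers are illustrative only.

import torch
from torch.utils.checkpoint import checkpoint, set_checkpoint_early_stop

counter = [0]

def clone(x):
    counter[0] += 1  # counts the forward pass plus any recomputation of clone
    return x.clone()

def fn(x):
    # clone saves nothing, so it is recomputed only when early stop is disabled
    return clone(x.sin().cos())

# Local kwarg: disable early stopping for this call only.
a = torch.tensor(1.0, requires_grad=True)
out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
out.backward()
assert counter[0] == 2  # forward + full recomputation

# A non-None context manager value takes precedence over the kwarg.
counter[0] = 0
a = torch.tensor(1.0, requires_grad=True)
with set_checkpoint_early_stop(False):
    out = checkpoint(fn, a, use_reentrant=False, early_stop=True)
out.backward()
assert counter[0] == 2  # context manager wins; clone is recomputed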