Add early_stop kwarg to torch.utils.checkpoint (#160781)

We already have a "set_checkpoint_early_stop" context manager, which controls whether recomputation stops early once all the activations needed for backward have been rematerialized. This PR adds an "early_stop" kwarg to "torch.utils.checkpoint" that toggles the same setting on a per-call basis.
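For example, a minimal sketch of both ways to toggle the setting (the kwarg usage mirrors the updated test in the diff below):

```python
import torch
from torch.utils.checkpoint import checkpoint, set_checkpoint_early_stop

def fn(x):
    return x.sin().cos()

a = torch.tensor(1.0, requires_grad=True)

# New in this PR: toggle early stopping per call via the kwarg.
# Early stopping is enabled by default.
out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
out.backward()

# Existing: toggle it for a region of code via the context manager.
with set_checkpoint_early_stop(False):
    out = checkpoint(fn, a, use_reentrant=False)
out.backward()
```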

It is also useful to have a kwarg version of the setting in addition to the context manager, because it is awkward to apply a context manager when activation checkpointing (AC) is applied via CheckpointWrapper.
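For instance, with the kwarg the setting can travel with the wrapped module instead of requiring a context manager around every forward call. This is only a sketch, and it assumes the checkpoint_wrapper helper still forwards extra keyword arguments to the underlying torch.utils.checkpoint.checkpoint call:

```python
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper,
)

block = nn.Linear(16, 16)

# Assumption: extra kwargs are forwarded to the checkpoint function,
# so early_stop applies to every checkpointed forward of this module.
wrapped = checkpoint_wrapper(
    block,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    early_stop=False,
)

out = wrapped(torch.randn(2, 16, requires_grad=True)).sum()
out.backward()
```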

Similar to the "debug" kwarg and its corresponding "set_checkpoint_debug_enabled" context manager, the context manager defaults to None and, when set to a non-None value, overrides the kwarg passed at the call site.
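A short sketch of that precedence (it mirrors the "Test precedence" blocks in the diff below):

```python
import torch
from torch.utils.checkpoint import checkpoint, set_checkpoint_early_stop

def fn(x):
    return x.sin().cos()

a = torch.tensor(1.0, requires_grad=True)

# Kwarg only: early stopping is disabled for this call.
out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
out.backward()

# Context manager set to a non-None value: it wins over the kwarg, so
# early stopping stays enabled here despite early_stop=False.
with set_checkpoint_early_stop(True):
    out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
out.backward()
```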

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160781
Approved by: https://github.com/tianyu-l
Authored by soulitzer on 2025-08-20 11:57:15 -07:00; committed by PyTorch MergeBot.
parent 4d078cfc4e, commit 1e4dfeeb06
3 changed files with 65 additions and 6 deletions


@@ -14143,13 +14143,27 @@ class TestNestedCheckpoint(TestCase):
            # early stop is enabled.
            return clone(x.sin().cos())

        # Test default
        # Early stopping is enabled by default
        a = torch.tensor(1.0, requires_grad=True)
        out = checkpoint(fn, a, use_reentrant=False)
        out.backward()
        self.assertEqual(counter[0], 1)

        # Try using the context manager to set early stopping to False.
        # Test local setting
        counter = [0]
        a = torch.tensor(1.0, requires_grad=True)
        out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
        out.backward()
        self.assertEqual(counter[0], 2)

        counter = [0]
        a = torch.tensor(1.0, requires_grad=True)
        out = checkpoint(fn, a, use_reentrant=False, early_stop=True)
        out.backward()
        self.assertEqual(counter[0], 1)

        # Test context manager
        # Expect early stopping to be disabled for all checkpoints run under
        # the context manager, even though the context manager is no longer active
        # when backward/recomputation is performed.
@@ -14157,10 +14171,40 @@ class TestNestedCheckpoint(TestCase):
        a = torch.tensor(1.0, requires_grad=True)
        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
            out = checkpoint(fn, a, use_reentrant=False)
        out.backward()
        self.assertEqual(counter[0], 2)

        counter = [0]
        a = torch.tensor(1.0, requires_grad=True)
        with torch.utils.checkpoint.set_checkpoint_early_stop(True):
            out = checkpoint(fn, a, use_reentrant=False)
        out.backward()
        self.assertEqual(counter[0], 1)

        # Test context manager nesting
        counter = [0]
        a = torch.tensor(1.0, requires_grad=True)
        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
            with torch.utils.checkpoint.set_checkpoint_early_stop(True):
                out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
        out.backward()
        self.assertEqual(counter[0], 1)

        # Test precedence
        counter = [0]
        a = torch.tensor(1.0, requires_grad=True)
        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
            out = checkpoint(fn, a, use_reentrant=False, early_stop=True)
        out.backward()
        self.assertEqual(counter[0], 2)

        counter = [0]
        a = torch.tensor(1.0, requires_grad=True)
        with torch.utils.checkpoint.set_checkpoint_early_stop(True):
            out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
        out.backward()
        self.assertEqual(counter[0], 1)

    def test_nested_checkpoint_set_early_stop_no_recompution_needed(self):
        # Case 1: We have one tensor saved and it's the input