Add early_stop kwarg to torch.utils.checkpoint (#160781)
We already have a context manager, "set_checkpoint_early_stop". This PR adds a kwarg that toggles the same setting. A kwarg version of the setting is useful in addition to the context manager because it is awkward to apply a context manager when activation checkpointing (AC) is applied via CheckpointWrapper. As with the "debug" kwarg and its corresponding "set_checkpoint_debug_enabled" context manager, the context manager defaults to None and overrides the local setting when it is non-None.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160781
Approved by: https://github.com/tianyu-l
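For illustration, here is a minimal usage sketch of the new kwarg alongside the existing context manager, assuming the precedence described above; the toy function fn is illustrative and not part of the PR.

    import torch
    from torch.utils.checkpoint import checkpoint, set_checkpoint_early_stop

    def fn(x):
        # Toy checkpointed computation (illustrative only).
        return x.sin().cos()

    a = torch.tensor(1.0, requires_grad=True)

    # Per-call setting via the new kwarg; the tests exercise it with
    # use_reentrant=False (non-reentrant checkpointing).
    out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
    out.backward()

    # The context manager defaults to None; when given a non-None value it
    # overrides the per-call kwarg for checkpoints created under it, so early
    # stop is enabled here despite early_stop=False.
    with set_checkpoint_early_stop(True):
        out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
    out.backward()

With a kwarg, wrappers such as CheckpointWrapper can presumably forward the setting as an ordinary argument instead of requiring a surrounding context manager, which is the motivation given above.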
Commit 1e4dfeeb06 (parent 4d078cfc4e), committed by PyTorch MergeBot.
@@ -14143,13 +14143,27 @@ class TestNestedCheckpoint(TestCase):
            # early stop is enabled.
            return clone(x.sin().cos())

        # Test default
        # Early stopping is enabled by default
        a = torch.tensor(1.0, requires_grad=True)
        out = checkpoint(fn, a, use_reentrant=False)
        out.backward()
        self.assertEqual(counter[0], 1)

        # Try using the context manager to set early stopping to False.
        # Test local setting
        counter = [0]
        a = torch.tensor(1.0, requires_grad=True)
        out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
        out.backward()
        self.assertEqual(counter[0], 2)

        counter = [0]
        a = torch.tensor(1.0, requires_grad=True)
        out = checkpoint(fn, a, use_reentrant=False, early_stop=True)
        out.backward()
        self.assertEqual(counter[0], 1)

        # Test context manager
        # Expect early stopping to be disabled for all checkpoints run under
        # the context manager, even though the context manager is no longer
        # active when backward/recomputation is performed.
@@ -14157,10 +14171,40 @@ class TestNestedCheckpoint(TestCase):
        a = torch.tensor(1.0, requires_grad=True)
        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
            out = checkpoint(fn, a, use_reentrant=False)

        out.backward()
        self.assertEqual(counter[0], 2)

        counter = [0]
        a = torch.tensor(1.0, requires_grad=True)
        with torch.utils.checkpoint.set_checkpoint_early_stop(True):
            out = checkpoint(fn, a, use_reentrant=False)
        out.backward()
        self.assertEqual(counter[0], 1)

        # Test context manager nesting
        counter = [0]
        a = torch.tensor(1.0, requires_grad=True)
        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
            with torch.utils.checkpoint.set_checkpoint_early_stop(True):
                out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
        out.backward()
        self.assertEqual(counter[0], 1)

        # Test precedence
        counter = [0]
        a = torch.tensor(1.0, requires_grad=True)
        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
            out = checkpoint(fn, a, use_reentrant=False, early_stop=True)
        out.backward()
        self.assertEqual(counter[0], 2)

        counter = [0]
        a = torch.tensor(1.0, requires_grad=True)
        with torch.utils.checkpoint.set_checkpoint_early_stop(True):
            out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
        out.backward()
        self.assertEqual(counter[0], 1)

    def test_nested_checkpoint_set_early_stop_no_recompution_needed(self):
        # Case 1: We have one tensor saved and it's the input
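The hunks above rely on a counter, a clone helper, and an fn defined just before the first hunk, which the excerpt does not show. A minimal reconstruction of what such a setup plausibly looks like (the exact helper in the test may differ) is:

    import torch
    from torch.utils.checkpoint import checkpoint

    counter = [0]

    def clone(x):
        # Counts executions: once in the original forward, and once more
        # during recomputation unless early stop skips it.
        counter[0] += 1
        return x.clone()

    def fn(x):
        # clone() saves nothing for backward, so when early stop is enabled
        # the recomputation can halt before reaching it and counter[0] stays
        # at 1; with early stop disabled the whole forward is replayed and
        # counter[0] reaches 2.
        return clone(x.sin().cos())

    a = torch.tensor(1.0, requires_grad=True)
    out = checkpoint(fn, a, use_reentrant=False)
    out.backward()
    assert counter[0] == 1  # early stop is on by default

This is why the assertions above expect counter[0] == 1 when early stop is in effect and counter[0] == 2 when it is not.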