Add early_stop kwarg to torch.utils.checkpoint (#160781)
We already have a context manager, "set_checkpoint_early_stop". This PR adds a kwarg that toggles the same setting. A kwarg version of the setting is useful in addition to the context manager because it is annoying to apply a context manager when activation checkpointing (AC) is applied via CheckpointWrapper. Similar to the "debug" kwarg and the corresponding "set_checkpoint_debug_enabled" context manager, the context manager defaults to None and overrides the local setting when non-None.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160781
Approved by: https://github.com/tianyu-l
Committed by: PyTorch MergeBot
Parent: 4d078cfc4e
Commit: 1e4dfeeb06
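As a usage sketch (not part of the diff) of the behavior described above after this PR, with placeholder names `fn` and `x`:

import torch
from torch.utils.checkpoint import checkpoint, set_checkpoint_early_stop

def fn(x):
    # Placeholder checkpointed region.
    return x.sin().cos()

x = torch.randn(8, requires_grad=True)

# Per-call kwarg: disable early stopping for this checkpoint only.
out = checkpoint(fn, x, use_reentrant=False, early_stop=False)
out.sum().backward()

# Context manager: a non-None global setting overrides the per-call kwarg.
with set_checkpoint_early_stop(False):
    out = checkpoint(fn, x, use_reentrant=False, early_stop=True)  # early stop still disabled
out.sum().backward()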
@@ -14143,13 +14143,27 @@ class TestNestedCheckpoint(TestCase):
             # early stop is enabled.
             return clone(x.sin().cos())
 
+        # Test default
         # Early stopping is enabled by default
         a = torch.tensor(1.0, requires_grad=True)
         out = checkpoint(fn, a, use_reentrant=False)
         out.backward()
         self.assertEqual(counter[0], 1)
 
-        # Try using the context manager to set early stopping to False.
+        # Test local setting
+        counter = [0]
+        a = torch.tensor(1.0, requires_grad=True)
+        out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
+        out.backward()
+        self.assertEqual(counter[0], 2)
+
+        counter = [0]
+        a = torch.tensor(1.0, requires_grad=True)
+        out = checkpoint(fn, a, use_reentrant=False, early_stop=True)
+        out.backward()
+        self.assertEqual(counter[0], 1)
+
+        # Test context manager
         # Expect early stopping to be disabled for all checkpoints ran under
         # the context manager, even though context manager is no longer active
         # when backward/recomputation is performed.
@@ -14157,10 +14171,40 @@ class TestNestedCheckpoint(TestCase):
         a = torch.tensor(1.0, requires_grad=True)
         with torch.utils.checkpoint.set_checkpoint_early_stop(False):
             out = checkpoint(fn, a, use_reentrant=False)
 
         out.backward()
         self.assertEqual(counter[0], 2)
 
+        counter = [0]
+        a = torch.tensor(1.0, requires_grad=True)
+        with torch.utils.checkpoint.set_checkpoint_early_stop(True):
+            out = checkpoint(fn, a, use_reentrant=False)
+        out.backward()
+        self.assertEqual(counter[0], 1)
+
+        # Test context manager nesting
+        counter = [0]
+        a = torch.tensor(1.0, requires_grad=True)
+        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
+            with torch.utils.checkpoint.set_checkpoint_early_stop(True):
+                out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
+        out.backward()
+        self.assertEqual(counter[0], 1)
+
+        # Test precedence
+        counter = [0]
+        a = torch.tensor(1.0, requires_grad=True)
+        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
+            out = checkpoint(fn, a, use_reentrant=False, early_stop=True)
+        out.backward()
+        self.assertEqual(counter[0], 2)
+
+        counter = [0]
+        a = torch.tensor(1.0, requires_grad=True)
+        with torch.utils.checkpoint.set_checkpoint_early_stop(True):
+            out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
+        out.backward()
+        self.assertEqual(counter[0], 1)
+
     def test_nested_checkpoint_set_early_stop_no_recompution_needed(self):
         # Case 1: We have one tensor saved and its the input
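Outside the test harness, the counting experiment above can be reproduced with a self-contained sketch; the `counter`, `clone`, and `fn` setup here mirrors test fixtures that are not shown in this hunk, and the kwarg is the one added by this PR:

import torch
from torch.utils.checkpoint import checkpoint

counter = [0]

def clone(x):
    # clone's output is not needed as a saved tensor, so with early stopping
    # the recomputation halts before ever reaching it.
    counter[0] += 1
    return x.clone()

def fn(x):
    return clone(x.sin().cos())

a = torch.tensor(1.0, requires_grad=True)
out = checkpoint(fn, a, use_reentrant=False)  # early stop on by default
out.backward()
print(counter[0])  # 1: clone ran only in the original forward

counter[0] = 0
a = torch.tensor(1.0, requires_grad=True)
out = checkpoint(fn, a, use_reentrant=False, early_stop=False)
out.backward()
print(counter[0])  # 2: clone ran again during the full recomputation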
@@ -79,6 +79,7 @@ def checkpoint(module: nn.Module, **kwargs) -> nn.Module:
     user_context_fns = kwargs.pop("context_fn", None)
     determinism_check = kwargs.pop("determinism_check", _DEFAULT_DETERMINISM_MODE)
     debug = kwargs.pop("debug", False)
+    early_stop = kwargs.pop("early_stop", True)
 
     if kwargs:
         raise ValueError(
@@ -103,6 +104,7 @@ def checkpoint(module: nn.Module, **kwargs) -> nn.Module:
            context_fns,
            determinism_check,
            debug,
+           early_stop,
            *args,
            **kwargs,
        )
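The two hunks above thread the new kwarg through the module-level checkpoint API as well. Assuming this is the composable wrapper exposed as torch.distributed._composable.checkpoint (the file name is not shown in this excerpt), usage would look roughly like:

import torch
import torch.nn as nn
from torch.distributed._composable import checkpoint  # assumed import path

block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
checkpoint(block, early_stop=False)  # apply AC in place, forwarding the new kwarg

x = torch.randn(4, 16, requires_grad=True)
block(x).sum().backward()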
@@ -347,6 +347,7 @@ def checkpoint(
    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
    determinism_check: str = _DEFAULT_DETERMINISM_MODE,
    debug: bool = False,
+   early_stop: bool = True,
    **kwargs
):
    r"""Checkpoint a model or part of the model.
@@ -425,6 +426,9 @@ def checkpoint(
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
+        args: tuple containing inputs to the :attr:`function`
+
+    Keyword args:
        preserve_rng_state(bool, optional): Omit stashing and restoring
            the RNG state during each checkpoint. Note that under torch.compile,
            this flag doesn't take effect and we always preserve RNG state.
@@ -455,7 +459,11 @@ def checkpoint(
            a trace of the operators ran during the original forward computation
            as well as the recomputation. This argument is only supported if
            ``use_reentrant=False``.
-        args: tuple containing inputs to the :attr:`function`
+        early_stop(bool, optional): If ``True``, non-reentrant checkpoint stops
+            recomputation as soon as it has computed all needed Tensors. This
+            argument is ignored if ``use_reentrant=True``. Can be overridden
+            globally using :func:`set_checkpoint_early_stop` context manager.
+            Default: ``True``.
 
    Returns:
        Output of running :attr:`function` on :attr:`*args`
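The commit message's motivation is that a kwarg is easier to use than a context manager when AC is applied via CheckpointWrapper. Assuming checkpoint_wrapper still forwards extra keyword arguments to the underlying torch.utils.checkpoint.checkpoint call, that path would look roughly like:

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
# Extra kwargs are assumed to be forwarded to the non-reentrant checkpoint
# call, so early_stop can be set here without wrapping the forward pass in a
# context manager.
wrapped = checkpoint_wrapper(block, early_stop=False)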
@@ -488,7 +496,7 @@ def checkpoint(
        return CheckpointFunction.apply(function, preserve, *args)
    else:
        gen = _checkpoint_without_reentrant_generator(
-            function, preserve, context_fn, determinism_check, debug, *args, **kwargs
+            function, preserve, context_fn, determinism_check, debug, early_stop, *args, **kwargs
        )
        # Runs pre-forward logic
        next(gen)
@@ -731,7 +739,7 @@ def _internal_assert(cond):
 # by holder=None. We skip over them. We still save x at (4) (since its holder
 # is still alive.)
 
-_enable_checkpoint_early_stop = True
+_enable_checkpoint_early_stop: Optional[bool] = None
 
 
 @contextlib.contextmanager
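The global flag becomes tri-state so that the context manager only overrides behavior while it is actually in use. A minimal sketch of that pattern (not the exact PyTorch implementation) looks like:

import contextlib
from typing import Optional

_enable_early_stop: Optional[bool] = None  # None means "no global override"

@contextlib.contextmanager
def set_early_stop(enable: bool):
    # Temporarily install a global override, restoring the previous value
    # (usually None) on exit.
    global _enable_early_stop
    prev = _enable_early_stop
    _enable_early_stop = enable
    try:
        yield
    finally:
        _enable_early_stop = prev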
@@ -1448,6 +1456,7 @@ def _checkpoint_without_reentrant_generator(
    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
    determinism_check: str = _DEFAULT_DETERMINISM_MODE,
    debug: bool = False,
+   early_stop: bool = True,
    *args,
    **kwargs
):
@@ -1475,6 +1484,10 @@ def _checkpoint_without_reentrant_generator(
        debug(bool, optional): If ``True``, error messages will also include
            a trace of the operators ran during the original forward computation
            as well as the recomputation.
+        early_stop(bool, optional): If ``True``, non-reentrant checkpoint stops
+            recomputation as soon as it has computed all needed Tensors. Can be
+            overridden globally using :func:`set_checkpoint_early_stop` context
+            manager. Default: ``True``.
        *args: Arguments to pass in to the given ``function``.
        **kwargs: Keyword arguments to pass into the given ``function``.
    """
@@ -1543,7 +1556,7 @@ def _checkpoint_without_reentrant_generator(
 
    new_frame = _CheckpointFrame(
        recompute_fn,
-        _enable_checkpoint_early_stop,
+        _enable_checkpoint_early_stop if _enable_checkpoint_early_stop is not None else early_stop,
        unpack_error_cb,
        metadata_fn
    )
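The expression in the last hunk is the precedence rule from the commit message in one place. As a standalone illustration (a hypothetical helper, not part of the diff):

from typing import Optional

def resolve_early_stop(global_setting: Optional[bool], local_kwarg: bool) -> bool:
    # A non-None global value (installed by the context manager) wins;
    # otherwise the per-call kwarg decides.
    return global_setting if global_setting is not None else local_kwarg

assert resolve_early_stop(None, True) is True    # default behavior
assert resolve_early_stop(None, False) is False  # kwarg alone
assert resolve_early_stop(False, True) is False  # context manager overrides kwarg
assert resolve_early_stop(True, False) is True   # context manager overrides kwarg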