Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
[ca][dynamo] always run eager checkpoint region's recomputation in eager (#153300)
I slap `disable` on the recomputation hook; otherwise the partitioner may save more or fewer activations than eager checkpointing expects, and the counts mismatch. See the code comment `Note: [compiled autograd and checkpoint unpack hook]`. This fixes all non-nested checkpointing tests. I also wrap the nested checkpointing tests, and a few of them still fail.

This also seems to fix all PYTORCH_TEST_WITH_DYNAMO checkpointing tests except for `TestAutograd.test_checkpointing_without_reentrant_custom_function_works`. For those tests, it looks like we fail to HOPify the checkpointed region, and when the backward executes the unpack hooks, dynamo tries to trace them. This messes up checkpointing's internal state tracking, with some tests raising `_StopRecomputationError` and others raising the same count-mismatch error as compiled autograd (CA).

Fixes https://github.com/pytorch/pytorch/issues/127115

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153300
Approved by: https://github.com/jansel
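As a rough illustration only (this is not the actual `torch/utils/checkpoint.py` change from this PR): the idea is to make the unpack/recomputation path opaque to dynamo, e.g. with `torch._dynamo.disable`, so the recomputation always runs in eager and the activation counts line up with what eager checkpointing expects. The class name and `recompute_fn` below are hypothetical and exist only for the sketch.

```python
# Sketch only, assuming torch._dynamo.disable is the mechanism; the real
# checkpoint code stores much more state than this. `recompute_fn` is a
# hypothetical stand-in for the checkpointed region's recomputation logic.
import torch


class _EagerUnpackHooks(torch.autograd.graph.saved_tensors_hooks):
    def __init__(self, recompute_fn):
        def pack_hook(tensor):
            # Keep only a detached handle so this toy example stays
            # self-contained; real checkpointing would drop the tensor and
            # recompute it on unpack.
            return tensor.detach()

        @torch._dynamo.disable  # never let dynamo trace the recomputation path
        def unpack_hook(saved):
            recompute_fn()  # stand-in for re-running the forward in eager
            return saved

        super().__init__(pack_hook, unpack_hook)


# Usage sketch: tensors saved for backward inside this context are unpacked
# through the dynamo-disabled hook when backward runs.
with _EagerUnpackHooks(recompute_fn=lambda: None):
    y = torch.randn(2, requires_grad=True).sin()
y.sum().backward()
```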
Committed by: PyTorch MergeBot
Parent: 71027b13b2
Commit: 4863e5c843
@@ -78,7 +78,6 @@ from torch.testing._internal.common_utils import (
     skipIfWindows,
     slowTest,
     TestCase,
-    xfailIfTorchDynamo,
 )
 from torch.utils._mode_utils import no_dispatch
 from torch.utils._python_dispatch import TorchDispatchMode
@@ -7430,8 +7429,6 @@ for shape in [(1,), ()]:
         self.assertEqual(b_grad, c_grad)
         self.assertEqual(b_grad, d_grad)

-    # PYTORCH_TEST_WITH_DYNAMO=1 test fails on CI but can't repro locally
-    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127115")
     def test_checkpointing_without_reentrant_dataparallel(self):
         """
         Verifies gradient correctness when checkpoint without reentrant autograd
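For context, the test un-skipped in the hunk above verifies that non-reentrant checkpointing produces the same gradients as a plain run (per its name, the real test also exercises `nn.DataParallel`; that part is omitted here). A minimal sketch of that gradient-parity pattern, not the actual test body:

```python
# Sketch of the gradient-parity check such a test performs: non-reentrant
# checkpointing must yield the same parameter gradients as a plain run.
import torch
from torch.utils.checkpoint import checkpoint

model = torch.nn.Sequential(
    torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1)
)
x = torch.randn(4, 8, requires_grad=True)

# Reference gradients without checkpointing.
model(x).sum().backward()
ref = [p.grad.clone() for p in model.parameters()]
model.zero_grad()

# Same computation, but activations are recomputed during backward.
checkpoint(model, x, use_reentrant=False).sum().backward()
for p, g in zip(model.parameters(), ref):
    torch.testing.assert_close(p.grad, g)
```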
@@ -7489,8 +7486,6 @@ for shape in [(1,), ()]:
         # should only call hook once
         self.assertEqual(count, 1)

-    # https://github.com/pytorch/pytorch/issues/127115
-    @xfailIfTorchDynamo
     def test_checkpointing_without_reentrant_arbitrary_input_output(self):
         """
         Ensures checkpointing without reentrant autograd works with functions
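The test un-skipped in this last hunk covers checkpointing of functions with non-trivial input/output structure, which `use_reentrant=False` is meant to support. Below is a hedged, minimal sketch of that kind of usage (the function `fn` is hypothetical, not the test's own code), assuming non-reentrant checkpoint forwards keyword arguments and returns the function's output structure unchanged:

```python
# Sketch: checkpoint a function with mixed tensor/non-tensor inputs and a
# dict output under use_reentrant=False. `fn` is hypothetical.
import torch
from torch.utils.checkpoint import checkpoint


def fn(x, scale, add_bias=False):
    out = x * scale
    if add_bias:
        out = out + 1.0
    return {"out": out.sin(), "tag": "checkpointed"}


x = torch.randn(3, requires_grad=True)
res = checkpoint(fn, x, 2.0, add_bias=True, use_reentrant=False)
res["out"].sum().backward()
print(res["tag"], x.grad.shape)
```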