From 4863e5c843722eb2a34fb0ca1d518a33431a38c0 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Wed, 14 May 2025 14:45:09 -0700 Subject: [PATCH] [ca][dynamo] always run eager checkpoint region's recomputation in eager (#153300) I slap disable on the recomputation hook, otherwise the partitioner may save less/more activations and mismatch with the expected eager count in checkpoint. See code comment `Note: [compiled autograd and checkpoint unpack hook]`. This fixes all non-nested checkpointing tests. I also wrap nested checkpointing tests, and a few of them still fail. This also seems to fix all PYTORCH_TEST_WITH_DYNAMO checkpointing tests except for `TestAutograd.test_checkpointing_without_reentrant_custom_function_works`. For those tests, it looks like we fail to HOPify the checkpointed region and when the backward executes the unpack hooks, dynamo tried to trace them. This messed up the internal state tracking of checkpointing, some raising the _StopRecomputationError and others raising the same count mismatch error as CA. FIXES https://github.com/pytorch/pytorch/issues/127115 Pull Request resolved: https://github.com/pytorch/pytorch/pull/153300 Approved by: https://github.com/jansel --- ...d_tensor_twice_without_recomputation_works | 0 ...ad.test_checkpoint_detects_non_determinism | 0 ...t_checkpointing_non_reentrant_autocast_cpu | 0 ...stAutograd.test_save_on_cpu_and_checkpoint | 0 ...nt.test_nested_checkpoint_early_stop_False | 0 ...int.test_nested_checkpoint_early_stop_True | 0 ...t_nested_checkpoint_kwargs_early_stop_True | 0 ..._tensor_inputs_and_outputs_early_stop_True | 0 ...kpoint_reentrant_backwards_early_stop_True | 0 ...sted_checkpoint_same_graph_early_stop_True | 0 ...oint.test_nested_checkpoint_set_early_stop | 0 ...d_checkpoint_two_children_early_stop_False | 0 ...ed_checkpoint_two_children_early_stop_True | 0 test/inductor/test_compiled_autograd.py | 41 +++++++++++++------ test/test_autograd.py | 5 --- torch/utils/checkpoint.py | 17 ++++++++ 16 files changed, 45 insertions(+), 18 deletions(-) delete mode 100644 test/dynamo_expected_failures/TestAutograd.test_access_saved_tensor_twice_without_recomputation_works delete mode 100644 test/dynamo_expected_failures/TestAutograd.test_checkpoint_detects_non_determinism delete mode 100644 test/dynamo_expected_failures/TestAutograd.test_checkpointing_non_reentrant_autocast_cpu delete mode 100644 test/dynamo_expected_failures/TestAutograd.test_save_on_cpu_and_checkpoint delete mode 100644 test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_early_stop_False delete mode 100644 test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_early_stop_True delete mode 100644 test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_kwargs_early_stop_True delete mode 100644 test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_non_tensor_inputs_and_outputs_early_stop_True delete mode 100644 test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_reentrant_backwards_early_stop_True delete mode 100644 test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_same_graph_early_stop_True delete mode 100644 test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_set_early_stop delete mode 100644 test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_two_children_early_stop_False delete mode 100644 test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_two_children_early_stop_True diff --git a/test/dynamo_expected_failures/TestAutograd.test_access_saved_tensor_twice_without_recomputation_works b/test/dynamo_expected_failures/TestAutograd.test_access_saved_tensor_twice_without_recomputation_works deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestAutograd.test_checkpoint_detects_non_determinism b/test/dynamo_expected_failures/TestAutograd.test_checkpoint_detects_non_determinism deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestAutograd.test_checkpointing_non_reentrant_autocast_cpu b/test/dynamo_expected_failures/TestAutograd.test_checkpointing_non_reentrant_autocast_cpu deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestAutograd.test_save_on_cpu_and_checkpoint b/test/dynamo_expected_failures/TestAutograd.test_save_on_cpu_and_checkpoint deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_early_stop_False b/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_early_stop_False deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_early_stop_True b/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_early_stop_True deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_kwargs_early_stop_True b/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_kwargs_early_stop_True deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_non_tensor_inputs_and_outputs_early_stop_True b/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_non_tensor_inputs_and_outputs_early_stop_True deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_reentrant_backwards_early_stop_True b/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_reentrant_backwards_early_stop_True deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_same_graph_early_stop_True b/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_same_graph_early_stop_True deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_set_early_stop b/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_set_early_stop deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_two_children_early_stop_False b/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_two_children_early_stop_False deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_two_children_early_stop_True b/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_two_children_early_stop_True deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index d6ef53052b60..034ca4d41100 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -4209,10 +4209,13 @@ def wrap_test_class(orig_cls): ): dct[name] = unittest.expectedFailure elif name.startswith("test_"): + backend = lookup_backend(name) + if not HAS_CUDA and backend == "inductor": + continue ctxs = [ compiled_autograd._enable( make_compiler_fn( - backend=lookup_backend(name), + backend=backend, fullgraph=name not in known_graph_breaks_tests, ) ), @@ -4305,6 +4308,21 @@ known_graph_breaks_tests = { "test_full_backward_hook_double_backward", # _pack_with_none "test_grad_mode_restored_reentrant", # assertTrue "test_multi_grad_any_hooks", # register_multi_grad_hook + "test_saved_variable_packing_unpacking_did_not_save_original_with_hooks", # register_hooks + "test_graph_save_on_cpu", # dynamo disabled + "test_nested_checkpoint_early_stop_False", # dynamo disable + "test_nested_checkpoint_early_stop_True", # dynamo disable + "test_nested_checkpoint_kwargs_early_stop_False", # dynamo disable + "test_nested_checkpoint_kwargs_early_stop_True", # dynamo disable + "test_nested_checkpoint_non_tensor_inputs_and_outputs_early_stop_False", # dynamo disable + "test_nested_checkpoint_non_tensor_inputs_and_outputs_early_stop_True", # dynamo disable + "test_nested_checkpoint_reentrant_backwards_early_stop_False", # dynamo disable + "test_nested_checkpoint_reentrant_backwards_early_stop_True", # dynamo disable + "test_nested_checkpoint_same_graph_early_stop_False", # dynamo disable + "test_nested_checkpoint_same_graph_early_stop_True", # dynamo disable + "test_nested_checkpoint_set_early_stop", # dynamo disable + "test_nested_checkpoint_two_children_early_stop_False", # dynamo disable + "test_nested_checkpoint_two_children_early_stop_True", # dynamo disable } test_contexts = { @@ -4329,6 +4347,7 @@ xfail_by_backend = { "test_current_graph_task_execution_order", # nodes are already freed by the time dynamo traces the lifted hook "test_autograd_inplace_views_cross_dtype", # view_fn not supported by compiled autograd "test_current_node", # TorchDispatchMode not yet implemented for compiled autograd + "test_nested_checkpoint_set_early_stop_no_recompution_needed", # TorchDispatchMode not yet implemented "test_post_accumulate_grad_hook_ordering", # accuracy error "test_current_graph_task_id", # autograd state already cleared once dynamo is called "test_custom_function_forward_mode_forward_is_no_op", # forward AD @@ -4362,6 +4381,10 @@ xfail_by_backend = { "test_return_duplicate_inplace", # batched gradients "test_naughty_autograd_function_stashing_ctx", # error not raised "test_unrelated_inputs", # batched gradients + "test_nested_checkpoint_early_stop_False", # unpack hook grad_fn semantics + "test_nested_checkpoint_early_stop_True", # unpack hook grad_fn semantics + "test_nested_checkpoint_two_children_early_stop_False", # unpack hook grad_fn semantics + "test_nested_checkpoint_two_children_early_stop_True", # unpack hook grad_fn semantics }, "eager": { # will be run without torch.compiling the CA graph "test_setup_context_when_forward_has_default_args", # autograd.Function with class methods @@ -4370,25 +4393,14 @@ xfail_by_backend = { "test_to_sparse_backward", # Out of bounds: frame_state_entry.stride[i] is None "test_custom_function_non_tensor_inputs_outputs", # gradient batching rule not implemented for aten::sym_size.int "test_setitem", # CopySlices accuracy error - "test_save_on_cpu_and_checkpoint", # https://github.com/pytorch/pytorch/issues/147565 - "test_checkpoint_detects_non_determinism", # different error - "test_checkpointing_non_reentrant_autocast_cpu", # saved != recompute - "test_checkpointing_non_reentrant_autocast_gpu", # saved != recompute "test_checkpointing_without_reentrant_saved_object_identity", # same as https://github.com/pytorch/pytorch/issues/136193 - "test_saved_variable_packing_unpacking_did_not_save_original_with_hooks", # register_hooks multiple times - "test_saved_variable_saved_original_inplace_detach", # RuntimeError not raised - "test_access_saved_tensor_twice_without_recomputation_works", # saved != recompute - "test_checkpointing_without_reentrant_dataparallel", # https://github.com/pytorch/pytorch/issues/127115 - "test_checkpointing", # takes very very long - "test_checkpointing_without_reentrant_input_requires_grad_False", # takes very very long - "test_checkpointing_without_reentrant_input_requires_grad_True", # takes very very long - "test_checkpointing_without_reentrant_memory_savings", # takes very very long "test_dtensor_different_gradient_placement", # Dynamo failed to run FX node with fake tensors "test_dtensor_noncontiguous_output", # Dynamo failed to run FX node with fake tensors "test_dtensor_partial_placement_graph_output", # Dynamo failed to run FX node with fake tensors "test_unwrap_async_collective_tensor_tangent", # AttributeError: 'PlainTensorMeta' object has no attribute 'attrs' "test_graph_save_on_cpu", # torch.save should no-op and be recorded in the graph "test_saving_variable_to_disk", # torch.save should no-op and be recorded in the graph + "test_nested_checkpoint_early_stop_False", # AOT backward higher order gradients }, "aot_eager": { # will be run with torch.compile(backend="eager") # Category: FakeTensor @@ -4430,6 +4442,9 @@ test_autograd = load_test_module("test_autograd") test_custom_ops = load_test_module("test_custom_ops") TestAutogradWithCompiledAutograd = wrap_test_class(test_autograd.TestAutograd) +TestNestedCheckpointWithCompiledAutograd = wrap_test_class( + test_autograd.TestNestedCheckpoint +) TestCustomOpWithCompiledAutograd = wrap_test_class(test_custom_ops.TestCustomOp) if torch.distributed.is_available() and HAS_CUDA: test_dtensor = load_test_module("distributed/tensor/test_dtensor_compile") diff --git a/test/test_autograd.py b/test/test_autograd.py index 72217c86b003..efe11302d8c8 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -78,7 +78,6 @@ from torch.testing._internal.common_utils import ( skipIfWindows, slowTest, TestCase, - xfailIfTorchDynamo, ) from torch.utils._mode_utils import no_dispatch from torch.utils._python_dispatch import TorchDispatchMode @@ -7430,8 +7429,6 @@ for shape in [(1,), ()]: self.assertEqual(b_grad, c_grad) self.assertEqual(b_grad, d_grad) - # PYTORCH_TEST_WITH_DYNAMO=1 test fails on CI but can't repro locally - @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127115") def test_checkpointing_without_reentrant_dataparallel(self): """ Verifies gradient correctness when checkpoint without reentrant autograd @@ -7489,8 +7486,6 @@ for shape in [(1,), ()]: # should only call hook once self.assertEqual(count, 1) - # https://github.com/pytorch/pytorch/issues/127115 - @xfailIfTorchDynamo def test_checkpointing_without_reentrant_arbitrary_input_output(self): """ Ensures checkpointing without reentrant autograd works with functions diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py index ccde9707a7eb..f53db14753c0 100644 --- a/torch/utils/checkpoint.py +++ b/torch/utils/checkpoint.py @@ -328,6 +328,7 @@ class CheckpointFunction(torch.autograd.Function): def noop_context_fn(): return contextlib.nullcontext(), contextlib.nullcontext() +# Note: [torch.compile and checkpoint] # TorchDynamo does not step inside utils.checkpoint function. The flow # looks likes this # 1) TorchDynamo tries to wrap utils.checkpoint in a HigherOrderOp by @@ -1491,6 +1492,8 @@ def _checkpoint_without_reentrant_generator( had_device_in_fwd = True fwd_devices, fwd_device_states = get_device_states(*args) + # See Note: [compiled autograd and checkpoint unpack hook] + @torch._disable_dynamo def recompute_fn(*inputs): kwargs, *args = inputs # This will be called later during recomputation. This wrapping enables @@ -1541,3 +1544,17 @@ def _checkpoint_without_reentrant_generator( ) return + +# Note: [compiled autograd and checkpoint unpack hook] +# When tracing via compiled autograd, this hook will be visible to the +# compiler if the forward of this checkpointed region ran in eager. +# If the forward had ran under compile, it would have been wrapped in a +# higher order op. See Note: [torch.compile and checkpoint]. +# +# Since we run the recomputation hook under a enable_grad context, +# AOTDispatch will trace a joint graph for this hook, and may +# save different activations than in eager. This conflicts with the +# strict activation count checks in `frame.check_recomputed_tensors_match`. +# So, we disable this hook to force it to recompute eager checkpointed regions +# in eager. This could be removed if we can disable the partitioner for this +# graph segment.