Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[ca][dynamo] always run eager checkpoint region's recomputation in eager (#153300)
I slap a dynamo disable (`torch._disable_dynamo`) on the recomputation hook; otherwise the partitioner may save fewer or more activations than eager, which mismatches the activation count that checkpoint expects. See the code comment `Note: [compiled autograd and checkpoint unpack hook]`. This fixes all non-nested checkpointing tests. I also wrap the nested checkpointing tests; a few of them still fail.

This also seems to fix all PYTORCH_TEST_WITH_DYNAMO checkpointing tests except `TestAutograd.test_checkpointing_without_reentrant_custom_function_works`. For those tests, it looks like we fail to HOPify the checkpointed region, so when the backward executes the unpack hooks, dynamo tries to trace them. That messes up checkpointing's internal state tracking: some tests raise `_StopRecomputationError` and others raise the same count-mismatch error as compiled autograd.

FIXES https://github.com/pytorch/pytorch/issues/127115

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153300
Approved by: https://github.com/jansel
Committed by: PyTorch MergeBot
Parent: 71027b13b2
Commit: 4863e5c843
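To make the failure mode concrete, here is a minimal sketch (not code from this PR) of the setup the fix targets: the checkpointed forward runs in eager and only the backward runs under compiled autograd, so the recomputation/unpack hook is what the tracer would otherwise see. The region function and the compiler_fn shape are assumptions, loosely mirroring make_compiler_fn in the test changes below.

import torch
from torch._dynamo import compiled_autograd
from torch.utils.checkpoint import checkpoint

def region(x):
    return torch.relu(x).sin()

x = torch.randn(4, requires_grad=True)
# Eager forward: activations are dropped and recomputed at backward time.
y = checkpoint(region, x, use_reentrant=False)

# Backward under compiled autograd; with this fix, the recomputation hook is
# dynamo-disabled and re-runs in eager, so the recomputed tensor count matches
# what eager checkpointing expects.
compiler_fn = lambda gm: torch.compile(gm, backend="eager")
with compiled_autograd._enable(compiler_fn):
    y.sum().backward()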
@@ -4209,10 +4209,13 @@ def wrap_test_class(orig_cls):
         ):
             dct[name] = unittest.expectedFailure
         elif name.startswith("test_"):
+            backend = lookup_backend(name)
+            if not HAS_CUDA and backend == "inductor":
+                continue
             ctxs = [
                 compiled_autograd._enable(
                     make_compiler_fn(
-                        backend=lookup_backend(name),
+                        backend=backend,
                         fullgraph=name not in known_graph_breaks_tests,
                     )
                 ),
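For context, the contexts assembled above are entered around each wrapped test roughly as in the following simplified sketch; the make_wrapped helper shown here is illustrative (assumed name and shape), not the exact test-suite code:

import contextlib
import functools

def make_wrapped(fn, ctxs):
    # Wrap a test method so that every context (compiled_autograd._enable(...)
    # plus any per-test context) is active while the original body runs.
    @functools.wraps(fn)
    def wrapped(self):
        with contextlib.ExitStack() as stack:
            for ctx in ctxs:
                stack.enter_context(ctx)
            return fn(self)

    return wrapped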
@@ -4305,6 +4308,21 @@ known_graph_breaks_tests = {
     "test_full_backward_hook_double_backward",  # _pack_with_none
     "test_grad_mode_restored_reentrant",  # assertTrue
     "test_multi_grad_any_hooks",  # register_multi_grad_hook
+    "test_saved_variable_packing_unpacking_did_not_save_original_with_hooks",  # register_hooks
+    "test_graph_save_on_cpu",  # dynamo disabled
+    "test_nested_checkpoint_early_stop_False",  # dynamo disable
+    "test_nested_checkpoint_early_stop_True",  # dynamo disable
+    "test_nested_checkpoint_kwargs_early_stop_False",  # dynamo disable
+    "test_nested_checkpoint_kwargs_early_stop_True",  # dynamo disable
+    "test_nested_checkpoint_non_tensor_inputs_and_outputs_early_stop_False",  # dynamo disable
+    "test_nested_checkpoint_non_tensor_inputs_and_outputs_early_stop_True",  # dynamo disable
+    "test_nested_checkpoint_reentrant_backwards_early_stop_False",  # dynamo disable
+    "test_nested_checkpoint_reentrant_backwards_early_stop_True",  # dynamo disable
+    "test_nested_checkpoint_same_graph_early_stop_False",  # dynamo disable
+    "test_nested_checkpoint_same_graph_early_stop_True",  # dynamo disable
+    "test_nested_checkpoint_set_early_stop",  # dynamo disable
+    "test_nested_checkpoint_two_children_early_stop_False",  # dynamo disable
+    "test_nested_checkpoint_two_children_early_stop_True",  # dynamo disable
 }

 test_contexts = {
@@ -4329,6 +4347,7 @@ xfail_by_backend = {
         "test_current_graph_task_execution_order",  # nodes are already freed by the time dynamo traces the lifted hook
         "test_autograd_inplace_views_cross_dtype",  # view_fn not supported by compiled autograd
         "test_current_node",  # TorchDispatchMode not yet implemented for compiled autograd
+        "test_nested_checkpoint_set_early_stop_no_recompution_needed",  # TorchDispatchMode not yet implemented
         "test_post_accumulate_grad_hook_ordering",  # accuracy error
         "test_current_graph_task_id",  # autograd state already cleared once dynamo is called
         "test_custom_function_forward_mode_forward_is_no_op",  # forward AD
@@ -4362,6 +4381,10 @@ xfail_by_backend = {
         "test_return_duplicate_inplace",  # batched gradients
         "test_naughty_autograd_function_stashing_ctx",  # error not raised
         "test_unrelated_inputs",  # batched gradients
+        "test_nested_checkpoint_early_stop_False",  # unpack hook grad_fn semantics
+        "test_nested_checkpoint_early_stop_True",  # unpack hook grad_fn semantics
+        "test_nested_checkpoint_two_children_early_stop_False",  # unpack hook grad_fn semantics
+        "test_nested_checkpoint_two_children_early_stop_True",  # unpack hook grad_fn semantics
     },
     "eager": {  # will be run without torch.compiling the CA graph
         "test_setup_context_when_forward_has_default_args",  # autograd.Function with class methods
@@ -4370,25 +4393,14 @@ xfail_by_backend = {
         "test_to_sparse_backward",  # Out of bounds: frame_state_entry.stride[i] is None
         "test_custom_function_non_tensor_inputs_outputs",  # gradient batching rule not implemented for aten::sym_size.int
         "test_setitem",  # CopySlices accuracy error
-        "test_save_on_cpu_and_checkpoint",  # https://github.com/pytorch/pytorch/issues/147565
-        "test_checkpoint_detects_non_determinism",  # different error
-        "test_checkpointing_non_reentrant_autocast_cpu",  # saved != recompute
-        "test_checkpointing_non_reentrant_autocast_gpu",  # saved != recompute
         "test_checkpointing_without_reentrant_saved_object_identity",  # same as https://github.com/pytorch/pytorch/issues/136193
-        "test_saved_variable_packing_unpacking_did_not_save_original_with_hooks",  # register_hooks multiple times
-        "test_saved_variable_saved_original_inplace_detach",  # RuntimeError not raised
-        "test_access_saved_tensor_twice_without_recomputation_works",  # saved != recompute
-        "test_checkpointing_without_reentrant_dataparallel",  # https://github.com/pytorch/pytorch/issues/127115
-        "test_checkpointing",  # takes very very long
-        "test_checkpointing_without_reentrant_input_requires_grad_False",  # takes very very long
-        "test_checkpointing_without_reentrant_input_requires_grad_True",  # takes very very long
-        "test_checkpointing_without_reentrant_memory_savings",  # takes very very long
         "test_dtensor_different_gradient_placement",  # Dynamo failed to run FX node with fake tensors
         "test_dtensor_noncontiguous_output",  # Dynamo failed to run FX node with fake tensors
         "test_dtensor_partial_placement_graph_output",  # Dynamo failed to run FX node with fake tensors
         "test_unwrap_async_collective_tensor_tangent",  # AttributeError: 'PlainTensorMeta' object has no attribute 'attrs'
         "test_graph_save_on_cpu",  # torch.save should no-op and be recorded in the graph
         "test_saving_variable_to_disk",  # torch.save should no-op and be recorded in the graph
+        "test_nested_checkpoint_early_stop_False",  # AOT backward higher order gradients
     },
     "aot_eager": {  # will be run with torch.compile(backend="eager")
         # Category: FakeTensor
@@ -4430,6 +4442,9 @@ test_autograd = load_test_module("test_autograd")
 test_custom_ops = load_test_module("test_custom_ops")

 TestAutogradWithCompiledAutograd = wrap_test_class(test_autograd.TestAutograd)
+TestNestedCheckpointWithCompiledAutograd = wrap_test_class(
+    test_autograd.TestNestedCheckpoint
+)
 TestCustomOpWithCompiledAutograd = wrap_test_class(test_custom_ops.TestCustomOp)
 if torch.distributed.is_available() and HAS_CUDA:
     test_dtensor = load_test_module("distributed/tensor/test_dtensor_compile")
@@ -78,7 +78,6 @@ from torch.testing._internal.common_utils import (
     skipIfWindows,
     slowTest,
     TestCase,
-    xfailIfTorchDynamo,
 )
 from torch.utils._mode_utils import no_dispatch
 from torch.utils._python_dispatch import TorchDispatchMode
@@ -7430,8 +7429,6 @@ for shape in [(1,), ()]:
         self.assertEqual(b_grad, c_grad)
         self.assertEqual(b_grad, d_grad)

-    # PYTORCH_TEST_WITH_DYNAMO=1 test fails on CI but can't repro locally
-    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127115")
     def test_checkpointing_without_reentrant_dataparallel(self):
         """
         Verifies gradient correctness when checkpoint without reentrant autograd
@@ -7489,8 +7486,6 @@ for shape in [(1,), ()]:
         # should only call hook once
         self.assertEqual(count, 1)

-    # https://github.com/pytorch/pytorch/issues/127115
-    @xfailIfTorchDynamo
     def test_checkpointing_without_reentrant_arbitrary_input_output(self):
         """
         Ensures checkpointing without reentrant autograd works with functions
@@ -328,6 +328,7 @@ class CheckpointFunction(torch.autograd.Function):
 def noop_context_fn():
     return contextlib.nullcontext(), contextlib.nullcontext()

+# Note: [torch.compile and checkpoint]
 # TorchDynamo does not step inside utils.checkpoint function. The flow
 # looks likes this
 # 1) TorchDynamo tries to wrap utils.checkpoint in a HigherOrderOp by
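A hedged illustration of the note above: when the checkpointed region itself runs under torch.compile, dynamo wraps utils.checkpoint in a higher order op, so the unpack hooks never surface to the tracer at backward time. The region function and backend choice here are assumptions for the sketch.

import torch
from torch.utils.checkpoint import checkpoint

def region(x):
    return torch.nn.functional.gelu(x) @ x

@torch.compile(backend="eager")
def fwd(x):
    # Dynamo captures this checkpoint call as a higher order op rather than
    # inlining the pack/unpack hook machinery.
    return checkpoint(region, x, use_reentrant=False)

x = torch.randn(8, 8, requires_grad=True)
fwd(x).sum().backward()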
@@ -1491,6 +1492,8 @@ def _checkpoint_without_reentrant_generator(
         had_device_in_fwd = True
         fwd_devices, fwd_device_states = get_device_states(*args)

+    # See Note: [compiled autograd and checkpoint unpack hook]
+    @torch._disable_dynamo
     def recompute_fn(*inputs):
         kwargs, *args = inputs
         # This will be called later during recomputation. This wrapping enables
@@ -1541,3 +1544,17 @@ def _checkpoint_without_reentrant_generator(
     )

     return
+
+# Note: [compiled autograd and checkpoint unpack hook]
+# When tracing via compiled autograd, this hook will be visible to the
+# compiler if the forward of this checkpointed region ran in eager.
+# If the forward had ran under compile, it would have been wrapped in a
+# higher order op. See Note: [torch.compile and checkpoint].
+#
+# Since we run the recomputation hook under a enable_grad context,
+# AOTDispatch will trace a joint graph for this hook, and may
+# save different activations than in eager. This conflicts with the
+# strict activation count checks in `frame.check_recomputed_tensors_match`.
+# So, we disable this hook to force it to recompute eager checkpointed regions
+# in eager. This could be removed if we can disable the partitioner for this
+# graph segment.
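To see why the decorator matters, here is a toy, self-contained sketch (assumed names, not the checkpoint internals) of a saved-tensor hook pair that drops the activation at pack time and rebuilds it in the unpack hook; decorating the unpack hook with torch._disable_dynamo mirrors the fix above, so compiled autograd calls it as an opaque eager function instead of tracing and re-partitioning it.

import torch

def recompute_on_unpack(recompute_fn):
    def pack(t):
        return None  # discard the activation autograd wanted to save

    @torch._disable_dynamo
    def unpack(_):
        return recompute_fn()  # rebuild it in eager at backward time

    return torch.autograd.graph.saved_tensors_hooks(pack, unpack)

x = torch.randn(4, requires_grad=True)
with recompute_on_unpack(lambda: x.detach().exp()):
    y = x.exp()  # exp saves its output for backward; we drop and recompute it
y.sum().backward()
print(torch.allclose(x.grad, x.exp()))  # True: d/dx exp(x) = exp(x)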