diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py
index 5dfaa14067d3..9c168a8e04ae 100644
--- a/test/dynamo/test_activation_checkpointing.py
+++ b/test/dynamo/test_activation_checkpointing.py
@@ -1647,6 +1647,29 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
         self.assertEqual(opt_fn(x), fn(x))
 
+    @torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
+    def test_nonlocal_mutation(self):
+        counter = 0
+
+        def gn(x):
+            nonlocal counter
+            counter += 1
+            return torch.sin(x)
+
+        def fn(x):
+            return torch.utils.checkpoint.checkpoint(gn, x, use_reentrant=True)
+
+        x = torch.randn(4, 4, requires_grad=True)
+        fn(x).sum().backward()
+        # The mutation is reapplied in the backward as well
+        self.assertEqual(counter, 2)
+        counter = 0
+
+        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
+        opt_fn(x).sum().backward()
+        # The mutation is not reapplied in the backward because the flag was on.
+        self.assertEqual(counter, 1)
+
 
 
 devices = ["cuda", "hpu"]
 instantiate_device_type_tests(
diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py
index d62dd086f055..d35ba10ef1af 100644
--- a/torch/_dynamo/config.py
+++ b/torch/_dynamo/config.py
@@ -633,6 +633,14 @@ compiled_autograd = False
 # See https://github.com/pytorch/pytorch/issues/157452 for more context
 graph_break_on_nn_param_ctor = True
 
+# Eager AC/SAC reapplies the mutations (like global dict mutations) in the
+# backward during the recomputation of forward. torch.compile has no easy way to
+# reapply python mutations in the backward. But many users might be ok to skip
+# reapplication of side effects in the backward. They can set this config flag
+# to accept this eager and compile divergence.
+skip_fwd_side_effects_in_bwd_under_checkpoint = False
+
+
 # Overrides torch.compile() kwargs for Compiled Autograd:
 compiled_autograd_kwargs_override: dict[str, Any] = {}
 """Overrides torch.compile() kwargs for Compiled Autograd.
diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py
index 4e45dc7446d2..47912dadb941 100644
--- a/torch/_dynamo/side_effects.py
+++ b/torch/_dynamo/side_effects.py
@@ -218,7 +218,10 @@ class SideEffects:
         return bool(
             output_graph
             and output_graph.current_tx.output.current_tracer.under_activation_checkpoint
-            and output_graph.current_tx.output.current_tracer.allow_side_effects_under_checkpoint
+            and (
+                output_graph.current_tx.output.current_tracer.allow_side_effects_under_checkpoint
+                or torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint
+            )
         )
 
     def should_allow_externally_visible_side_effects_in_subtracer(self) -> bool:
diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py
index 8c08a68e3b27..956eb4676018 100644
--- a/torch/_dynamo/variables/higher_order_ops.py
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -2145,6 +2145,9 @@ class ReparametrizeModuleCallVariable(FunctorchHigherOrderVariable):
 class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
     supports_input_mutation = True
     supports_aliasing = True
+    # TODO - Go through all subclasses of WrapHigherOrderVariable to see if
+    # restore_side_effects can be ignored. For now, this is conservative.
+    restore_side_effects = True
 
     def install_subgraph_in_output_graph(
         self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name="wrap_body"
@@ -2178,6 +2181,7 @@ class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
             kwargs,
             description,
             source_target=self.value,
+            restore_side_effects=self.restore_side_effects,
             should_flatten_outputs=True,
             under_activation_checkpoint=under_activation_checkpoint,
             supports_input_mutation=self.supports_input_mutation,
@@ -2565,6 +2569,14 @@ class StrictModeHigherOrderVariable(TorchHigherOrderOperatorVariable):
 
 
 class CheckpointHigherOrderVariable(WrapHigherOrderVariable):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        # If side effects are allowed under checkpoint, we should not restore
+        # the side effects after speculate subgraph.
+        self.restore_side_effects = (
+            not torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint
+        )
+
     def _call_function(
         self,
         tx: "InstructionTranslator",
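
For reference, a minimal usage sketch adapted from the test added above (a module-level `global` counter stands in for the test's `nonlocal` one): eager activation checkpointing reapplies the Python mutation while recomputing the forward in the backward, whereas the compiled run with the new flag enabled applies it only once.

```python
import torch

counter = 0

def gn(x):
    global counter
    counter += 1  # Python side effect inside the checkpointed region
    return torch.sin(x)

def fn(x):
    return torch.utils.checkpoint.checkpoint(gn, x, use_reentrant=True)

x = torch.randn(4, 4, requires_grad=True)

# Eager AC recomputes the forward during backward and reapplies the mutation.
fn(x).sum().backward()
assert counter == 2
counter = 0

# With the flag on, torch.compile skips reapplying the side effect in the
# backward, accepting the eager/compile divergence.
with torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True):
    opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
    opt_fn(x).sum().backward()
assert counter == 1
```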