[dynamo][ac] Config flag to allow eager and compile AC divergence for side-effects (#165775)

Eager AC/SAC reapplies mutations (like global dict mutations) in the backward while recomputing the forward. torch.compile has no easy way to reapply Python mutations in the backward, but many users may be fine with skipping that reapplication. They can set this config flag to accept this divergence between eager and compile.
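
A minimal usage sketch (it mirrors the new test added below; torch._dynamo.config.patch is used here as a context manager, and the eager backend and assertions are only illustrative):

    import torch
    import torch._dynamo
    import torch.utils.checkpoint

    def run():
        counter = 0

        def gn(x):
            nonlocal counter
            counter += 1  # Python side effect inside the checkpointed region
            return torch.sin(x)

        def fn(x):
            return torch.utils.checkpoint.checkpoint(gn, x, use_reentrant=True)

        x = torch.randn(4, 4, requires_grad=True)

        # Eager: reentrant checkpoint recomputes gn during the backward, so the
        # mutation runs twice and counter == 2.
        fn(x).sum().backward()
        assert counter == 2

        # With the flag on, torch.compile does not reapply the mutation in the
        # backward, so counter only reaches 1.
        counter = 0
        with torch._dynamo.config.patch(
            skip_fwd_side_effects_in_bwd_under_checkpoint=True
        ):
            opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
            opt_fn(x).sum().backward()
        assert counter == 1

    run()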

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165775
Approved by: https://github.com/zou3519
ghstack dependencies: #165734
Animesh Jain authored 2025-10-17 11:11:57 -07:00, committed by PyTorch MergeBot
parent c18ddfc572, commit 616c6bdf8f
4 changed files with 47 additions and 1 deletion


@@ -1647,6 +1647,29 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        self.assertEqual(opt_fn(x), fn(x))

    @torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
    def test_nonlocal_mutation(self):
        counter = 0

        def gn(x):
            nonlocal counter
            counter += 1
            return torch.sin(x)

        def fn(x):
            return torch.utils.checkpoint.checkpoint(gn, x, use_reentrant=True)

        x = torch.randn(4, 4, requires_grad=True)
        fn(x).sum().backward()
        # The mutation is reapplied in the backward as well
        self.assertEqual(counter, 2)

        counter = 0
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        opt_fn(x).sum().backward()
        # The mutation is not reapplied in the backward because the flag was on.
        self.assertEqual(counter, 1)


devices = ["cuda", "hpu"]
instantiate_device_type_tests(


@@ -633,6 +633,14 @@ compiled_autograd = False
# See https://github.com/pytorch/pytorch/issues/157452 for more context
graph_break_on_nn_param_ctor = True

# Eager AC/SAC reapplies the mutations (like global dict mutations) in the
# backward during the recomputation of forward. torch.compile has no easy way to
# reapply python mutations in the backward. But many users might be ok to skip
# reapplication of side effects in the backward. They can set this config flag
# to accept this eager and compile divergence.
skip_fwd_side_effects_in_bwd_under_checkpoint = False

# Overrides torch.compile() kwargs for Compiled Autograd:
compiled_autograd_kwargs_override: dict[str, Any] = {}
"""Overrides torch.compile() kwargs for Compiled Autograd.


@@ -218,7 +218,10 @@ class SideEffects:
        return bool(
            output_graph
            and output_graph.current_tx.output.current_tracer.under_activation_checkpoint
            and (
                output_graph.current_tx.output.current_tracer.allow_side_effects_under_checkpoint
                or torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint
            )
        )

    def should_allow_externally_visible_side_effects_in_subtracer(self) -> bool:


@@ -2145,6 +2145,9 @@ class ReparametrizeModuleCallVariable(FunctorchHigherOrderVariable):
class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
    supports_input_mutation = True
    supports_aliasing = True
    # TODO - Go through all subclasses of WrapHigherOrderVariable to see if
    # restore_side_effects can be ignored. For now, this is conservative.
    restore_side_effects = True

    def install_subgraph_in_output_graph(
        self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name="wrap_body"

@@ -2178,6 +2181,7 @@ class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
            kwargs,
            description,
            source_target=self.value,
            restore_side_effects=self.restore_side_effects,
            should_flatten_outputs=True,
            under_activation_checkpoint=under_activation_checkpoint,
            supports_input_mutation=self.supports_input_mutation,

@@ -2565,6 +2569,14 @@ class StrictModeHigherOrderVariable(TorchHigherOrderOperatorVariable):
class CheckpointHigherOrderVariable(WrapHigherOrderVariable):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # If side effects are allowed under checkpoint, we should not restore
        # the side effects after speculate subgraph.
        self.restore_side_effects = (
            not torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint
        )

    def _call_function(
        self,
        tx: "InstructionTranslator",