[dynamo][ac] Config flag to allow eager and compile AC divergence for side-effects (#165775)

Eager AC/SAC reapplies mutations (such as global dict mutations) in the backward pass while recomputing the forward. torch.compile has no easy way to reapply Python mutations in the backward, but many users may be fine with skipping that reapplication. This config flag lets them opt into the resulting eager/compile divergence.
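
For illustration, a minimal sketch of how a user might opt in, mirroring the new test below (the `global` counter stands in for the test's `nonlocal` one; treat the exact usage pattern as an assumption beyond what the test itself shows):

    import torch

    counter = 0

    def gn(x):
        global counter
        counter += 1  # Python side effect inside the checkpointed region
        return torch.sin(x)

    def fn(x):
        return torch.utils.checkpoint.checkpoint(gn, x, use_reentrant=True)

    x = torch.randn(4, 4, requires_grad=True)

    # Eager: backward recomputes the forward, rerunning gn, so the side
    # effect happens twice per forward+backward.
    fn(x).sum().backward()
    assert counter == 2

    # With the flag set, compile skips reapplying the side effect during
    # recomputation, so it runs once; this is the accepted divergence.
    counter = 0
    with torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True):
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        opt_fn(x).sum().backward()
    assert counter == 1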

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165775
Approved by: https://github.com/zou3519
ghstack dependencies: #165734
Author: Animesh Jain (committed by PyTorch MergeBot)
Date: 2025-10-17 11:11:57 -07:00
Parent: c18ddfc572
Commit: 616c6bdf8f
4 changed files with 47 additions and 1 deletion


@@ -1647,6 +1647,29 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no_primals}
        self.assertEqual(opt_fn(x), fn(x))

    @torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
    def test_nonlocal_mutation(self):
        counter = 0

        def gn(x):
            nonlocal counter
            counter += 1
            return torch.sin(x)

        def fn(x):
            return torch.utils.checkpoint.checkpoint(gn, x, use_reentrant=True)

        x = torch.randn(4, 4, requires_grad=True)
        fn(x).sum().backward()
        # The mutation is reapplied in the backward as well
        self.assertEqual(counter, 2)

        counter = 0
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        opt_fn(x).sum().backward()
        # The mutation is not reapplied in the backward because the flag was on.
        self.assertEqual(counter, 1)

devices = ["cuda", "hpu"]
instantiate_device_type_tests(