mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[dynamo][ac] Config flag to allow eager and compile AC divergence for side-effects (#165775)
Eager AC/SAC reapplies the mutations (like global dict mutations) in the backward during the recomputation of forward. torch.compile has no easy way to reapply python mutations in the backward. But many users might be ok to skip reapplication of side effects in the backward. They can set this config flag to accept this eager and compile divergence. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165775 Approved by: https://github.com/zou3519 ghstack dependencies: #165734
This commit is contained in:
committed by
PyTorch MergeBot
parent
c18ddfc572
commit
616c6bdf8f
@ -1647,6 +1647,29 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
|
||||
self.assertEqual(opt_fn(x), fn(x))
|
||||
|
||||
@torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
|
||||
def test_nonlocal_mutation(self):
|
||||
counter = 0
|
||||
|
||||
def gn(x):
|
||||
nonlocal counter
|
||||
counter += 1
|
||||
return torch.sin(x)
|
||||
|
||||
def fn(x):
|
||||
return torch.utils.checkpoint.checkpoint(gn, x, use_reentrant=True)
|
||||
|
||||
x = torch.randn(4, 4, requires_grad=True)
|
||||
fn(x).sum().backward()
|
||||
# The mutation is reapplied in the backward as well
|
||||
self.assertEqual(counter, 2)
|
||||
counter = 0
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
opt_fn(x).sum().backward()
|
||||
# The mutation is not reapplied in the backward because the flag was on.
|
||||
self.assertEqual(counter, 1)
|
||||
|
||||
|
||||
devices = ["cuda", "hpu"]
|
||||
instantiate_device_type_tests(
|
||||
|
Reference in New Issue
Block a user