Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[dynamo][ac] Config flag to allow eager and compile AC divergence for side-effects (#165775)
Eager AC/SAC reapplies mutations (such as global dict mutations) in the backward pass during the recomputation of the forward. torch.compile has no easy way to reapply Python mutations in the backward, but many users may be fine with skipping the reapplication of side effects in the backward. They can set this config flag to accept this divergence between eager and compile.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165775
Approved by: https://github.com/zou3519
ghstack dependencies: #165734
committed by: PyTorch MergeBot
parent: c18ddfc572
commit: 616c6bdf8f
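Below is a minimal usage sketch of the new flag, modeled on the test added in this PR; the toy functions gn and fn, the counter, and the eager backend are illustrative only, while the flag name and the expected counts come from the commit itself.

import torch
import torch.utils.checkpoint

# Opt in: let compiled checkpoint skip reapplying forward side effects
# during the backward recomputation (diverging from eager AC/SAC).
torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint = True

counter = 0

def gn(x):
    global counter
    counter += 1  # Python side effect inside the checkpointed region
    return torch.sin(x)

def fn(x):
    return torch.utils.checkpoint.checkpoint(gn, x, use_reentrant=True)

x = torch.randn(4, 4, requires_grad=True)

fn(x).sum().backward()
assert counter == 2  # eager reapplies the mutation during the backward recompute

counter = 0
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
opt_fn(x).sum().backward()
assert counter == 1  # compile skips the reapplication when the flag is on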
@@ -1647,6 +1647,29 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
         self.assertEqual(opt_fn(x), fn(x))
 
+    @torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
+    def test_nonlocal_mutation(self):
+        counter = 0
+
+        def gn(x):
+            nonlocal counter
+            counter += 1
+            return torch.sin(x)
+
+        def fn(x):
+            return torch.utils.checkpoint.checkpoint(gn, x, use_reentrant=True)
+
+        x = torch.randn(4, 4, requires_grad=True)
+        fn(x).sum().backward()
+        # The mutation is reapplied in the backward as well
+        self.assertEqual(counter, 2)
+        counter = 0
+
+        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
+        opt_fn(x).sum().backward()
+        # The mutation is not reapplied in the backward because the flag was on.
+        self.assertEqual(counter, 1)
+
 
 devices = ["cuda", "hpu"]
 instantiate_device_type_tests(
@@ -633,6 +633,14 @@ compiled_autograd = False
 # See https://github.com/pytorch/pytorch/issues/157452 for more context
 graph_break_on_nn_param_ctor = True
 
+# Eager AC/SAC reapplies the mutations (like global dict mutations) in the
+# backward during the recomputation of forward. torch.compile has no easy way to
+# reapply python mutations in the backward. But many users might be ok to skip
+# reapplication of side effects in the backward. They can set this config flag
+# to accept this eager and compile divergence.
+skip_fwd_side_effects_in_bwd_under_checkpoint = False
+
+
 # Overrides torch.compile() kwargs for Compiled Autograd:
 compiled_autograd_kwargs_override: dict[str, Any] = {}
 """Overrides torch.compile() kwargs for Compiled Autograd.
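The flag defaults to False. As with other Dynamo config options, it can be flipped globally by assignment or scoped with torch._dynamo.config.patch (the new test uses the decorator form). A brief sketch follows; run_scoped, run_scoped_cm, and the wrapped function f are hypothetical placeholders.

import torch

# Process-wide toggle.
torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint = True

# Scoped toggle, decorator form (as in the new test).
@torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
def run_scoped(f, x):
    return torch.compile(f, backend="eager")(x)

# Scoped toggle, context-manager form.
def run_scoped_cm(f, x):
    with torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True):
        return torch.compile(f, backend="eager")(x)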
@@ -218,7 +218,10 @@ class SideEffects:
         return bool(
             output_graph
             and output_graph.current_tx.output.current_tracer.under_activation_checkpoint
-            and output_graph.current_tx.output.current_tracer.allow_side_effects_under_checkpoint
+            and (
+                output_graph.current_tx.output.current_tracer.allow_side_effects_under_checkpoint
+                or torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint
+            )
         )
 
     def should_allow_externally_visible_side_effects_in_subtracer(self) -> bool:
@@ -2145,6 +2145,9 @@ class ReparametrizeModuleCallVariable(FunctorchHigherOrderVariable):
 class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
     supports_input_mutation = True
     supports_aliasing = True
+    # TODO - Go through all subclasses of WrapHigherOrderVariable to see if
+    # restore_side_effects can be ignored. For now, this is conservative.
+    restore_side_effects = True
 
     def install_subgraph_in_output_graph(
         self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name="wrap_body"
@@ -2178,6 +2181,7 @@ class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
             kwargs,
             description,
             source_target=self.value,
+            restore_side_effects=self.restore_side_effects,
             should_flatten_outputs=True,
             under_activation_checkpoint=under_activation_checkpoint,
             supports_input_mutation=self.supports_input_mutation,
@@ -2565,6 +2569,14 @@ class StrictModeHigherOrderVariable(TorchHigherOrderOperatorVariable):
 
 
 class CheckpointHigherOrderVariable(WrapHigherOrderVariable):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        # If side effects are allowed under checkpoint, we should not restore
+        # the side effects after speculate subgraph.
+        self.restore_side_effects = (
+            not torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint
+        )
+
     def _call_function(
         self,
         tx: "InstructionTranslator",