[ca] make torch.compile API respect ambient disable contexts (#155473)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155473
Approved by: https://github.com/jansel
Author: Simon Fan
Date: 2025-06-10 19:12:13 -07:00
Committed by: PyTorch MergeBot
Parent: be124a61a4
Commit: 87b002b6fb
3 changed files with 84 additions and 36 deletions
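For context, a minimal sketch of the behavior this change targets (not part of the commit; it mirrors the new test_compile_api_disable test added below and assumes the same private helpers the test uses): a function compiled while torch._dynamo.config.compiled_autograd is set should fall back to the normal autograd engine when it is called inside an ambient compiled_autograd._disable() context, instead of re-enabling compiled autograd.

    # Sketch only: illustrates the post-change behavior exercised by the new test.
    import torch
    from torch._dynamo import compiled_autograd, config
    from torch._dynamo.utils import counters

    def fn(model, x):
        loss = model(x).sum()
        loss.backward()
        return model[0].weight.grad

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Sigmoid())
    x = torch.randn(2, 4)

    with config.patch(compiled_autograd=True):
        compiled_fn = torch.compile(fn, backend="eager")
        # Ambient disable context: after this change, the compiled wrapper's
        # compiled-autograd hook defers to it instead of turning CA back on.
        with compiled_autograd._disable():
            compiled_fn(model, x)

    assert "compiled_autograd" not in counters  # no CA graphs were captured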

@@ -99,6 +99,7 @@ def reset():
     torch._logging.set_logs(compiled_autograd_verbose=False)
     config.compiled_autograd = False
     compiled_autograd.reset()
+    torch._dynamo.utils.counters.clear()
 
 
 class TestCompiledAutograd(TestCase):
@@ -706,6 +707,44 @@ main()
         self.assertEqual(expected, actual)
         self.assertEqual(counters["compiled_autograd"]["captures"], 2)
 
+    @parametrize("api", ("compile", "optimize"))
+    @parametrize("backend", ("eager", "aot_eager", "inductor"))
+    def test_compile_api_disable(self, api, backend):
+        def wrap(fn, backend):
+            if api == "compile":
+                return torch.compile(fn, backend=backend)
+            elif api == "optimize":
+                return torch._dynamo.optimize(backend)(fn)
+
+        def fn(model, inputs):
+            res = []
+            for inp in inputs:
+                result = model(inp).sum()
+                result.backward()
+                res.append(model[0].weight.grad)
+                res.append(model[0].bias.grad)
+                model.zero_grad()
+            return res
+
+        torch.manual_seed(123)
+        model = torch.nn.Sequential(
+            torch.nn.Linear(4, 4),
+            torch.nn.Sigmoid(),
+        )
+        inputs = [
+            torch.randn([1, 4]),
+            torch.randn([2, 4]),
+            torch.randn([3, 4]),
+        ]
+
+        expected = fn(model, inputs)
+        with config.patch(compiled_autograd=True):
+            compiled_fn = wrap(fn, backend)
+            with torch._dynamo.compiled_autograd._disable():
+                actual = compiled_fn(model, inputs)
+        self.assertEqual(expected, actual)
+        self.assertTrue("compiled_autograd" not in counters)
+
     @parametrize("backend", ("eager", "aot_eager", "inductor"))
     def test_optimize_assert(self, backend):
         # can be merged into the test above once we support

@@ -1373,9 +1373,11 @@ compiled_autograd_enabled_force_eager = False
 # global flag to check if we are processing graphs produced from a compiled autograd graph
 in_compiled_autograd_region = False
 
+active_disable_ctx = False
+
 
 @contextlib.contextmanager
-def _enable(compiler_fn, dynamic: bool = True):
+def _enable(compiler_fn, dynamic: bool = True, ignore_active_disable_ctx=True):
     # The entrypoint to enable CA.
     # It is recommended to enable via `torch._dynamo.config.compiled_autograd = True` rather
     # than using this context manager directly. If you are torch.compiling the corresponding
@@ -1396,44 +1398,47 @@ def _enable(compiler_fn, dynamic: bool = True):
     # - dynamic: Whether compiled autograd will treat tensors in the autograd graph (params, activations) as dynamic.
     #   This doesn't affect the dynamic configuration of the compilation wrapper.
-    if dynamic:
-        assert type(dynamic) is bool
-
-    from torch._dynamo import eval_frame
-
-    if eval_frame._stance.stance == "force_eager":
-        # If user explicitly sets Dynamo stance to "force_eager", we want Compiled Autograd
-        # to fall back to eager as well.
-        global compiled_autograd_enabled_force_eager
-        compiled_autograd_enabled_force_eager = True
-        try:
-            yield
-        finally:
-            compiled_autograd_enabled_force_eager = False
+    if not ignore_active_disable_ctx and active_disable_ctx:
+        yield
     else:
-        # we need to import this, because user might not have imported it if they directly use this context manager
-        # we need to lazily import it, because of circular dependencies
-        import torch._inductor.cudagraph_trees
+        if dynamic:
+            assert type(dynamic) is bool
 
-        (
-            prior_compiler,
-            prior_dynamic,
-        ) = torch._C._dynamo.compiled_autograd.set_autograd_compiler(
-            functools.partial(AutogradCompilerInstance, compiler_fn), dynamic
-        )
-        if snapshot_verbose_logging_enabled():
-            torch._C._dynamo.compiled_autograd.set_verbose_logger(verbose_log)  # type:ignore[arg-type]
-        global compiled_autograd_enabled
-        compiled_autograd_enabled = True
-        try:
-            with torch.autograd.set_multithreading_enabled(False):
+        from torch._dynamo import eval_frame
+
+        if eval_frame._stance.stance == "force_eager":
+            # If user explicitly sets Dynamo stance to "force_eager", we want Compiled Autograd
+            # to fall back to eager as well.
+            global compiled_autograd_enabled_force_eager
+            compiled_autograd_enabled_force_eager = True
+            try:
                 yield
-        finally:
-            if not prior_compiler:
-                compiled_autograd_enabled = False
-            torch._C._dynamo.compiled_autograd.set_autograd_compiler(
-                prior_compiler, prior_dynamic
+            finally:
+                compiled_autograd_enabled_force_eager = False
+        else:
+            # we need to import this, because user might not have imported it if they directly use this context manager
+            # we need to lazily import it, because of circular dependencies
+            import torch._inductor.cudagraph_trees
+
+            (
+                prior_compiler,
+                prior_dynamic,
+            ) = torch._C._dynamo.compiled_autograd.set_autograd_compiler(
+                functools.partial(AutogradCompilerInstance, compiler_fn), dynamic
             )
+            if snapshot_verbose_logging_enabled():
+                torch._C._dynamo.compiled_autograd.set_verbose_logger(verbose_log)  # type:ignore[arg-type]
+            global compiled_autograd_enabled
+            compiled_autograd_enabled = True
+            try:
+                with torch.autograd.set_multithreading_enabled(False):
+                    yield
+            finally:
+                if not prior_compiler:
+                    compiled_autograd_enabled = False
+                torch._C._dynamo.compiled_autograd.set_autograd_compiler(
+                    prior_compiler, prior_dynamic
+                )
 
 
 @contextlib.contextmanager
@@ -1444,11 +1449,15 @@ def _disable():
     ) = torch._C._dynamo.compiled_autograd.set_autograd_compiler(None, False)
     global compiled_autograd_enabled
     compiled_autograd_enabled = False
+    global active_disable_ctx
+    if not active_disable_ctx:
+        active_disable_ctx = True
     try:
         yield
     finally:
         if prior_compiler:
             compiled_autograd_enabled = True
+        active_disable_ctx = False
         torch._C._dynamo.compiled_autograd.set_autograd_compiler(
             prior_compiler, prior_dynamic
         )
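The two hunks above implement a simple handshake: _disable() raises a module-level active_disable_ctx flag, and _enable(..., ignore_active_disable_ctx=False) becomes a pass-through while that flag is set. A standalone toy of the same pattern (illustrative names, not the PyTorch internals):

    import contextlib

    active_disable_ctx = False  # ambient "disable" flag, as in the diff above
    feature_enabled = False     # stands in for compiled_autograd_enabled

    @contextlib.contextmanager
    def _enable(ignore_active_disable_ctx=True):
        global feature_enabled
        if not ignore_active_disable_ctx and active_disable_ctx:
            yield  # respect the ambient disable context: do not enable
        else:
            prior, feature_enabled = feature_enabled, True
            try:
                yield
            finally:
                feature_enabled = prior

    @contextlib.contextmanager
    def _disable():
        global active_disable_ctx, feature_enabled
        prior_flag, prior_enabled = active_disable_ctx, feature_enabled
        active_disable_ctx, feature_enabled = True, False
        try:
            yield
        finally:
            active_disable_ctx, feature_enabled = prior_flag, prior_enabled

    with _disable():
        with _enable(ignore_active_disable_ctx=False):  # wrapper-style call: defers
            assert not feature_enabled
        with _enable():                                 # explicit call: still wins
            assert feature_enabled

In this commit, only the torch.compile wrapper path (the eval_frame.py hunk below) passes ignore_active_disable_ctx=False, so callers that enter _enable() directly keep the previous precedence.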

@@ -816,7 +816,7 @@ class OptimizeContext(_TorchDynamoContext):
                 assert rebuild_ctx is not None
                 compiler_fn = rebuild_ctx()
                 ctx = torch._dynamo.compiled_autograd._enable(
-                    compiler_fn, dynamic=_dynamic
+                    compiler_fn, dynamic=_dynamic, ignore_active_disable_ctx=False
                 )
                 ctx.__enter__()
                 return functools.partial(ctx.__exit__, None, None, None)
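The changed call sits in a hook that opens the compiled-autograd context and hands back a zero-argument callable to close it later (the ctx.__enter__() / functools.partial(ctx.__exit__, None, None, None) pair above). A sketch of that enter/exit-hook idiom, with illustrative helper names rather than the real OptimizeContext machinery:

    import contextlib
    import functools

    def make_enter_hook(ctx_factory):
        # Open the context on entry, return the matching exit callable.
        def hook():
            ctx = ctx_factory()
            ctx.__enter__()
            return functools.partial(ctx.__exit__, None, None, None)
        return hook

    def run_with_hooks(fn, enter_hooks, *args, **kwargs):
        exits = [enter() for enter in enter_hooks]
        try:
            return fn(*args, **kwargs)
        finally:
            for exit_fn in reversed(exits):
                exit_fn()

    @contextlib.contextmanager
    def demo_ctx():
        print("enter")
        try:
            yield
        finally:
            print("exit")

    print(run_with_hooks(lambda: "body", [make_enter_hook(demo_ctx)]))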