diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py
index 9587d94e188b..d6df07740e08 100644
--- a/test/inductor/test_compiled_autograd.py
+++ b/test/inductor/test_compiled_autograd.py
@@ -99,6 +99,7 @@ def reset():
     torch._logging.set_logs(compiled_autograd_verbose=False)
     config.compiled_autograd = False
     compiled_autograd.reset()
+    torch._dynamo.utils.counters.clear()
 
 
 class TestCompiledAutograd(TestCase):
@@ -706,6 +707,44 @@ main()
         self.assertEqual(expected, actual)
         self.assertEqual(counters["compiled_autograd"]["captures"], 2)
 
+    @parametrize("api", ("compile", "optimize"))
+    @parametrize("backend", ("eager", "aot_eager", "inductor"))
+    def test_compile_api_disable(self, api, backend):
+        def wrap(fn, backend):
+            if api == "compile":
+                return torch.compile(fn, backend=backend)
+            elif api == "optimize":
+                return torch._dynamo.optimize(backend)(fn)
+
+        def fn(model, inputs):
+            res = []
+            for inp in inputs:
+                result = model(inp).sum()
+                result.backward()
+                res.append(model[0].weight.grad)
+                res.append(model[0].bias.grad)
+                model.zero_grad()
+            return res
+
+        torch.manual_seed(123)
+        model = torch.nn.Sequential(
+            torch.nn.Linear(4, 4),
+            torch.nn.Sigmoid(),
+        )
+        inputs = [
+            torch.randn([1, 4]),
+            torch.randn([2, 4]),
+            torch.randn([3, 4]),
+        ]
+
+        expected = fn(model, inputs)
+        with config.patch(compiled_autograd=True):
+            compiled_fn = wrap(fn, backend)
+            with torch._dynamo.compiled_autograd._disable():
+                actual = compiled_fn(model, inputs)
+        self.assertEqual(expected, actual)
+        self.assertTrue("compiled_autograd" not in counters)
+
     @parametrize("backend", ("eager", "aot_eager", "inductor"))
     def test_optimize_assert(self, backend):
         # can be merged into the test above once we support
diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py
index 2cb8a3b6ff8a..548eb45b3863 100644
--- a/torch/_dynamo/compiled_autograd.py
+++ b/torch/_dynamo/compiled_autograd.py
@@ -1373,9 +1373,11 @@ compiled_autograd_enabled_force_eager = False
 # global flag to check if we are processing graphs produced from a compiled autograd graph
 in_compiled_autograd_region = False
 
+active_disable_ctx = False
+
 
 @contextlib.contextmanager
-def _enable(compiler_fn, dynamic: bool = True):
+def _enable(compiler_fn, dynamic: bool = True, ignore_active_disable_ctx=True):
     # The entrypoint to enable CA.
     # It is recommended to enable via `torch._dynamo.config.compiled_autograd = True` rather
     # than using this context manager directly. If you are torch.compiling the corresponding
@@ -1396,44 +1398,47 @@ def _enable(compiler_fn, dynamic: bool = True):
     # - dynamic: Whether compiled autograd will treat tensors in the autograd graph (params, activations) as dynamic.
     #   This doesn't affect the dynamic configuration of the compilation wrapper.
 
-    if dynamic:
-        assert type(dynamic) is bool
-
-    from torch._dynamo import eval_frame
-
-    if eval_frame._stance.stance == "force_eager":
-        # If user explicitly sets Dynamo stance to "force_eager", we want Compiled Autograd
-        # to fall back to eager as well.
-        global compiled_autograd_enabled_force_eager
-        compiled_autograd_enabled_force_eager = True
-        try:
-            yield
-        finally:
-            compiled_autograd_enabled_force_eager = False
+    if not ignore_active_disable_ctx and active_disable_ctx:
+        yield
     else:
-        # we need to import this, because user might not have imported it if they directly use this context manager
-        # we need to lazily import it, because of circular dependencies
-        import torch._inductor.cudagraph_trees
+        if dynamic:
+            assert type(dynamic) is bool
 
-        (
-            prior_compiler,
-            prior_dynamic,
-        ) = torch._C._dynamo.compiled_autograd.set_autograd_compiler(
-            functools.partial(AutogradCompilerInstance, compiler_fn), dynamic
-        )
-        if snapshot_verbose_logging_enabled():
-            torch._C._dynamo.compiled_autograd.set_verbose_logger(verbose_log)  # type:ignore[arg-type]
-        global compiled_autograd_enabled
-        compiled_autograd_enabled = True
-        try:
-            with torch.autograd.set_multithreading_enabled(False):
+        from torch._dynamo import eval_frame
+
+        if eval_frame._stance.stance == "force_eager":
+            # If user explicitly sets Dynamo stance to "force_eager", we want Compiled Autograd
+            # to fall back to eager as well.
+            global compiled_autograd_enabled_force_eager
+            compiled_autograd_enabled_force_eager = True
+            try:
                 yield
-        finally:
-            if not prior_compiler:
-                compiled_autograd_enabled = False
-            torch._C._dynamo.compiled_autograd.set_autograd_compiler(
-                prior_compiler, prior_dynamic
+            finally:
+                compiled_autograd_enabled_force_eager = False
+        else:
+            # we need to import this, because user might not have imported it if they directly use this context manager
+            # we need to lazily import it, because of circular dependencies
+            import torch._inductor.cudagraph_trees
+
+            (
+                prior_compiler,
+                prior_dynamic,
+            ) = torch._C._dynamo.compiled_autograd.set_autograd_compiler(
+                functools.partial(AutogradCompilerInstance, compiler_fn), dynamic
             )
+            if snapshot_verbose_logging_enabled():
+                torch._C._dynamo.compiled_autograd.set_verbose_logger(verbose_log)  # type:ignore[arg-type]
+            global compiled_autograd_enabled
+            compiled_autograd_enabled = True
+            try:
+                with torch.autograd.set_multithreading_enabled(False):
+                    yield
+            finally:
+                if not prior_compiler:
+                    compiled_autograd_enabled = False
+                torch._C._dynamo.compiled_autograd.set_autograd_compiler(
+                    prior_compiler, prior_dynamic
+                )
 
 
 @contextlib.contextmanager
@@ -1444,11 +1449,15 @@ def _disable():
     ) = torch._C._dynamo.compiled_autograd.set_autograd_compiler(None, False)
     global compiled_autograd_enabled
     compiled_autograd_enabled = False
+    global active_disable_ctx
+    if not active_disable_ctx:
+        active_disable_ctx = True
     try:
         yield
     finally:
         if prior_compiler:
             compiled_autograd_enabled = True
+        active_disable_ctx = False
         torch._C._dynamo.compiled_autograd.set_autograd_compiler(
             prior_compiler, prior_dynamic
         )
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index 291b175fb031..88d40c6dafa4 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -816,7 +816,7 @@ class OptimizeContext(_TorchDynamoContext):
                 assert rebuild_ctx is not None
                 compiler_fn = rebuild_ctx()
                 ctx = torch._dynamo.compiled_autograd._enable(
-                    compiler_fn, dynamic=_dynamic
+                    compiler_fn, dynamic=_dynamic, ignore_active_disable_ctx=False
                 )
                 ctx.__enter__()
                 return functools.partial(ctx.__exit__, None, None, None)
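
Usage sketch (not part of the diff): with this change, torch._dynamo.compiled_autograd._disable() takes precedence over torch._dynamo.config.compiled_autograd = True when a torch.compile'd function runs its backward pass, which is what test_compile_api_disable above checks. A minimal illustration, assuming a PyTorch build that includes this patch; the model, input shape, and backend choice here are arbitrary:

    import torch
    import torch._dynamo.compiled_autograd as compiled_autograd

    model = torch.nn.Linear(4, 4)

    def step(inp):
        loss = model(inp).sum()
        loss.backward()
        grad = model.weight.grad.clone()
        model.zero_grad()
        return grad

    # Enable compiled autograd globally via config, then compile the train step.
    torch._dynamo.config.compiled_autograd = True
    compiled_step = torch.compile(step, backend="eager")

    # Inside _disable(), the backward pass falls back to eager autograd even though
    # compiled autograd is enabled via config, so no "compiled_autograd" counters appear.
    with compiled_autograd._disable():
        grad = compiled_step(torch.randn(2, 4))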