diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 43e79960fff8..3c3a5ae7bb08 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -242,6 +242,67 @@ class MiscTests(torch._inductor.test_case.TestCase):
         self.assertTrue(same(val4, correct1))
         self.assertEqual(counter.frame_count, 3)
 
+    def test_compile_non_infra_inside_compile(self):
+        from torch.utils._python_dispatch import TorchDispatchMode
+
+        backend = torch._dynamo.testing.EagerAndRecordGraphs()
+
+        class YoloMode(TorchDispatchMode):
+            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                out = torch.compile(func, backend=backend, fullgraph=True)(
+                    *args, **kwargs
+                )
+                return out
+
+        x = torch.randn(5)
+        with YoloMode():
+            out = torch.add(x, x)
+
+        self.assertEqual(len(backend.graphs), 1)
+
+    def test_compile_non_infra_empty(self):
+        from torch.utils._python_dispatch import TorchDispatchMode
+
+        backend = torch._dynamo.testing.EagerAndRecordGraphs()
+
+        class YoloMode(TorchDispatchMode):
+            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                return torch.ops.aten.mul.Tensor(args[0], args[1])
+
+        x = torch.ones(5)
+        with YoloMode():
+            out = torch.compile(torch.add, backend=backend, fullgraph=True)(x, x)
+
+        self.assertEqual(out.sum().item(), 5.0)
+        self.assertEqual(len(backend.graphs), 0)
+
+    def test_compile_non_infra_multiple(self):
+        from torch.utils._python_dispatch import TorchDispatchMode
+
+        backend3 = torch._dynamo.testing.EagerAndRecordGraphs()
+        backend2 = torch._dynamo.testing.EagerAndRecordGraphs()
+        backend = torch._dynamo.testing.EagerAndRecordGraphs()
+
+        class YoloMode2(TorchDispatchMode):
+            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                out = torch.compile(func, backend=backend3, fullgraph=True)(
+                    *args, **kwargs
+                )
+                return out
+
+        class YoloMode(TorchDispatchMode):
+            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                out = torch.compile(torch.add, backend=backend2, fullgraph=True)(args[0], args[1])
+                return out
+
+        x = torch.ones(5)
+        with YoloMode(), YoloMode2():
+            torch.compile(lambda x, y: torch.add(x, y), fullgraph=True, backend=backend)(x, x)
+
+        self.assertEqual(len(backend2.graphs), 1)
+        self.assertEqual(len(backend3.graphs), 0)
+        self.assertEqual(len(backend.graphs), 0)
+
     def test_dynamo_inside_custom_op(self):
         cnt = torch._dynamo.testing.InductorAndRecordGraphs()
         cnt1 = torch._dynamo.testing.InductorAndRecordGraphs()
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index cf7392763e6c..173f38cac29d 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -74,8 +74,8 @@ from torch.monitor import _WaitCounter
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.utils._python_dispatch import (
     _disable_current_modes,
-    is_in_any_mode_without_ignore_compile_internals,
+    any_torch_dispatch_mode_on_stack,
     is_in_torch_dispatch_mode,
 )
 from torch.utils._traceback import CapturedTraceback, format_traceback_short
 
@@ -1940,10 +1940,6 @@ class ConvertFrameProtocol(typing.Protocol):
     ) -> ConvertFrameReturn: ...
 
 
-def should_skip_due_to_torch_dispatch_mode() -> bool:
-    return is_in_any_mode_without_ignore_compile_internals()
-
-
 class CatchErrorsWrapper:
     def __init__(self, callback: ConvertFrameProtocol, hooks: Hooks) -> None:
         functools.wraps(callback)(self)
@@ -1964,13 +1960,15 @@ class CatchErrorsWrapper:
             has_started_execution = frame.f_lasti > first_real_inst_idx(frame.f_code)
         else:
            has_started_execution = frame.f_lasti >= first_real_inst_idx(frame.f_code)
+
+        should_skip_due_to_torch_dispatch = any_torch_dispatch_mode_on_stack(include_infra_modes=False, respect_ignore_compile_internals=True)
         if (
             # TODO: the first condition is not covered by any test
             has_started_execution
             or is_skipfile
             or config.disable
             or (
-                should_skip_due_to_torch_dispatch_mode()
+                should_skip_due_to_torch_dispatch
                 and not getattr(self._torchdynamo_orig_backend, "_export", False)
             )
         ):
@@ -1979,7 +1977,7 @@ class CatchErrorsWrapper:
                 skip_reason = "traced frame already"
             elif trace_rules.check(frame.f_code):
                 skip_reason = "in skipfiles"
-            elif is_in_torch_dispatch_mode(include_infra_modes=False):
+            elif should_skip_due_to_torch_dispatch:
                 skip_reason = "non-infra torch dispatch mode present, this is not supported today in torch.compile"
             else:
                 skip_reason = "dynamo tracing is disabled"
diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py
index 7d844cd3f91b..845ce617cffe 100644
--- a/torch/utils/_python_dispatch.py
+++ b/torch/utils/_python_dispatch.py
@@ -48,6 +48,33 @@ def is_in_torch_dispatch_mode(include_infra_modes: bool = True) -> bool:
 def is_in_any_mode_without_ignore_compile_internals() -> bool:
     return _is_in_any_mode_without_ignore_compile_internals
 
 
+def any_torch_dispatch_mode_on_stack(
+    *,
+    include_infra_modes: bool = True,
+    respect_ignore_compile_internals: bool = False,
+) -> bool:
+    """
+    Return True if any torch dispatch mode is currently on the dispatch stack.
+
+    Infra modes are skipped when ``include_infra_modes`` is False, and modes
+    that opt out of compile internals are skipped when
+    ``respect_ignore_compile_internals`` is True.
+    """
+    stack_len = torch._C._len_torch_dispatch_stack()
+
+    for idx in range(stack_len):
+        mode = _get_dispatch_stack_at(idx)
+
+        # Skip infra modes (e.g. FakeTensorMode) unless they were requested.
+        if mode.is_infra_mode() and not include_infra_modes:
+            continue
+        # Skip modes that opt out of being observed by compile internals.
+        if respect_ignore_compile_internals and mode.ignore_compile_internals():
+            continue
+        return True
+    return False
+
+
 class TorchDispatchMode:
     """
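For context (not part of the patch): a minimal usage sketch of the new any_torch_dispatch_mode_on_stack helper, assuming the diff above is applied. NoOpMode is a hypothetical, non-infra dispatch mode used purely for illustration.

# Sketch only; assumes the patch above is applied, so that
# any_torch_dispatch_mode_on_stack is importable from torch.utils._python_dispatch.
from torch.utils._python_dispatch import (
    TorchDispatchMode,
    any_torch_dispatch_mode_on_stack,
)


class NoOpMode(TorchDispatchMode):
    # Hypothetical non-infra mode that simply forwards every op.
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))


# Nothing entered yet: the dispatch stack is empty.
assert not any_torch_dispatch_mode_on_stack(include_infra_modes=False)

with NoOpMode():
    # A user-defined (non-infra) mode is now on the stack, which is the
    # condition CatchErrorsWrapper uses to skip compiling new frames.
    assert any_torch_dispatch_mode_on_stack(
        include_infra_modes=False, respect_ignore_compile_internals=True
    )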