From 9f6ca9edeb18b79eb9ebb09648be64476f4380cc Mon Sep 17 00:00:00 2001
From: Tugsbayasgalan Manlaibaatar
Date: Thu, 16 Oct 2025 22:36:18 -0700
Subject: [PATCH] Don't run compile inside kernel invocation

ghstack-source-id: 455ad8decd40a3fd0e91b4da4fd1c7212806dff1
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165687
---
 test/dynamo/test_misc.py    | 51 +++++++++++++++++++++++++++++++++++++
 torch/_dynamo/eval_frame.py |  8 ++++++
 2 files changed, 59 insertions(+)

diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 365f5f1b1693..43e79960fff8 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -242,6 +242,57 @@ class MiscTests(torch._inductor.test_case.TestCase):
         self.assertTrue(same(val4, correct1))
         self.assertEqual(counter.frame_count, 3)
 
+    def test_dynamo_inside_custom_op(self):
+        cnt = torch._dynamo.testing.InductorAndRecordGraphs()
+        cnt1 = torch._dynamo.testing.InductorAndRecordGraphs()
+
+        with torch.library._scoped_library("mylib", "FRAGMENT") as m:
+            m.define("foo(Tensor x) -> Tensor")
+
+            def inner(x):
+                return x.sin().cos()
+
+            def foo_impl(x):
+                return torch.compile(inner, fullgraph=True, dynamic=True, backend=cnt)(
+                    x
+                )
+
+            m.impl("foo", foo_impl, "CompositeExplicitAutograd")
+
+            @torch.compile(fullgraph=True, dynamic=True, backend=cnt1)
+            def f(x):
+                return torch.ops.mylib.foo.default(x)
+
+            x = torch.randn(3)
+            res = f(x)
+            res1 = f(x)
+            res2 = f(x)
+            expected = x.sin().cos()
+            self.assertEqual(res, expected)
+            self.assertEqual(res1, expected)
+            self.assertEqual(res2, expected)
+            self.assertEqual(len(cnt.inductor_graphs), 1)
+            self.assertEqual(len(cnt1.inductor_graphs), 1)
+            self.assertExpectedInline(
+                str(cnt.inductor_graphs[0].graph).strip(),
+                """\
+graph():
+    %arg0_1 : [num_users=0] = placeholder[target=arg0_1]
+    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
+    %sin : [num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%arg1_1,), kwargs = {})
+    %cos : [num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%sin,), kwargs = {})
+    return (cos,)""",
+            )
+            self.assertExpectedInline(
+                str(cnt1.inductor_graphs[0].graph).strip(),
+                """\
+graph():
+    %arg0_1 : [num_users=0] = placeholder[target=arg0_1]
+    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
+    %foo : [num_users=1] = call_function[target=torch.ops.mylib.foo.default](args = (%arg1_1,), kwargs = {})
+    return (foo,)""",
+            )
+
     @torch._dynamo.config.patch(accumulated_recompile_limit=1)
     def test_dynamo_disabled_in_custom_op_kernels(self):
         counters.clear()
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index 472905eca6c1..036f1ba7d01a 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -847,6 +847,14 @@ class _TorchDynamoContext:
         def compile_wrapper(*args: Any, **kwargs: Any) -> Any:
             prior = set_eval_frame(None)
             try:
+                # We shouldn't compile inside a kernel invocation; just run the
+                # function eagerly instead.
+                if tracing_context := torch._guards.TracingContext.try_get():
+                    if (
+                        tracing_context.fake_mode is not None
+                        and tracing_context.fake_mode.in_kernel_invocation
+                    ):
+                        return fn(*args, **kwargs)
+
                 # Skip nested compile - just inline the function
                 if is_fx_symbolic_tracing():
                     if config.error_on_nested_fx_trace:
                         raise RuntimeError(
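
Note (not part of the patch): the snippet below is a minimal standalone sketch of the scenario the new test exercises, a custom op whose kernel itself calls torch.compile. It mirrors the test but uses torch.library.Library in place of the test-only _scoped_library helper; the op name "mylib::foo" and the inlined lambda are illustrative assumptions, not upstream code.

    import torch

    # Register a custom op whose CompositeExplicitAutograd kernel calls torch.compile.
    lib = torch.library.Library("mylib", "FRAGMENT")
    lib.define("foo(Tensor x) -> Tensor")

    def foo_impl(x):
        # While the outer torch.compile traces f() under FakeTensorMode, this kernel
        # runs with fake_mode.in_kernel_invocation set; with the patch applied, the
        # nested torch.compile call is skipped and the wrapped function simply runs
        # eagerly on the fake tensors instead of starting a second compilation.
        return torch.compile(lambda t: t.sin().cos())(x)

    lib.impl("foo", foo_impl, "CompositeExplicitAutograd")

    @torch.compile(fullgraph=True)
    def f(x):
        return torch.ops.mylib.foo.default(x)

    x = torch.randn(3)
    print(torch.allclose(f(x), x.sin().cos()))  # expected: True with this change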