Don't run compile inside kernel invocation

ghstack-source-id: 455ad8decd40a3fd0e91b4da4fd1c7212806dff1
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165687
This commit is contained in:
Tugsbayasgalan Manlaibaatar
2025-10-16 22:36:18 -07:00
parent 12fa4192c5
commit 9f6ca9edeb
2 changed files with 59 additions and 0 deletions

View File

@ -242,6 +242,57 @@ class MiscTests(torch._inductor.test_case.TestCase):
self.assertTrue(same(val4, correct1))
self.assertEqual(counter.frame_count, 3)
def test_dynamo_inside_custom_op(self):
cnt = torch._dynamo.testing.InductorAndRecordGraphs()
cnt1 = torch._dynamo.testing.InductorAndRecordGraphs()
with torch.library._scoped_library("mylib", "FRAGMENT") as m:
m.define("foo(Tensor x) -> Tensor")
def inner(x):
return x.sin().cos()
def foo_impl(x):
return torch.compile(inner, fullgraph=True, dynamic=True, backend=cnt)(
x
)
m.impl("foo", foo_impl, "CompositeExplicitAutograd")
@torch.compile(fullgraph=True, dynamic=True, backend=cnt1)
def f(x):
return torch.ops.mylib.foo.default(x)
x = torch.randn(3)
res = f(x)
res1 = f(x)
res2 = f(x)
expected = x.sin().cos()
self.assertEqual(res, expected)
self.assertEqual(res1, expected)
self.assertEqual(res2, expected)
self.assertTrue(len(cnt.inductor_graphs), 1)
self.assertTrue(len(cnt1.inductor_graphs), 1)
self.assertExpectedInline(
str(cnt.inductor_graphs[0].graph).strip(),
"""\
graph():
%arg0_1 : [num_users=0] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%sin : [num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%arg1_1,), kwargs = {})
%cos : [num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%sin,), kwargs = {})
return (cos,)""",
)
self.assertExpectedInline(
str(cnt1.inductor_graphs[0].graph).strip(),
"""\
graph():
%arg0_1 : [num_users=0] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%foo : [num_users=1] = call_function[target=torch.ops.mylib.foo.default](args = (%arg1_1,), kwargs = {})
return (foo,)""",
)
@torch._dynamo.config.patch(accumulated_recompile_limit=1)
def test_dynamo_disabled_in_custom_op_kernels(self):
counters.clear()

View File

@ -847,6 +847,14 @@ class _TorchDynamoContext:
def compile_wrapper(*args: Any, **kwargs: Any) -> Any:
prior = set_eval_frame(None)
try:
# We shouldn't compile inside kernel invocation.
if tracing_context := torch._guards.TracingContext.try_get():
if (
tracing_context.fake_mode is not None
and tracing_context.fake_mode.in_kernel_invocation
):
return fn(*args, **kwargs)
# Skip nested compile - just inline the function
if is_fx_symbolic_tracing():
if config.error_on_nested_fx_trace:
raise RuntimeError(