Don't run compile inside kernel invocation (#165687)

When torch.compile is called during fake tensor propagation, we shouldn't actually compile, because there is no guarantee that the compiled artifact can itself be fake-tensor propagated (the Inductor backend, for example, cannot). Instead we simply skip compiling and run the function eagerly; the inner compile is still triggered when the kernel is executed at runtime.
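
A minimal sketch of the scenario (it mirrors the new test added below; `mylib::foo` and `inner` are just illustrative names):

    import torch

    def inner(x):
        return x.sin().cos()

    with torch.library._scoped_library("mylib", "FRAGMENT") as m:
        m.define("foo(Tensor x) -> Tensor")

        def foo_impl(x):
            # Nested compile: skipped while the outer compile fake-props this
            # kernel, but triggered when the kernel actually runs.
            return torch.compile(inner, fullgraph=True)(x)

        m.impl("foo", foo_impl, "CompositeExplicitAutograd")

        @torch.compile(fullgraph=True)
        def f(x):
            return torch.ops.mylib.foo.default(x)

        f(torch.randn(3))  # compiles f; foo_impl's inner compile happens at runtime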

Fixes: https://github.com/pytorch/pytorch/issues/151328

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165687
Approved by: https://github.com/zou3519
Author: Tugsbayasgalan Manlaibaatar
Date: 2025-10-16 22:36:18 -07:00
Committed by: PyTorch MergeBot
Parent: fae74cd52f
Commit: 08c97b4a1f
2 changed files with 59 additions and 0 deletions


@@ -242,6 +242,57 @@ class MiscTests(torch._inductor.test_case.TestCase):
        self.assertTrue(same(val4, correct1))
        self.assertEqual(counter.frame_count, 3)

    def test_dynamo_inside_custom_op(self):
        cnt = torch._dynamo.testing.InductorAndRecordGraphs()
        cnt1 = torch._dynamo.testing.InductorAndRecordGraphs()

        with torch.library._scoped_library("mylib", "FRAGMENT") as m:
            m.define("foo(Tensor x) -> Tensor")

            def inner(x):
                return x.sin().cos()

            def foo_impl(x):
                return torch.compile(inner, fullgraph=True, dynamic=True, backend=cnt)(
                    x
                )

            m.impl("foo", foo_impl, "CompositeExplicitAutograd")

            @torch.compile(fullgraph=True, dynamic=True, backend=cnt1)
            def f(x):
                return torch.ops.mylib.foo.default(x)

            x = torch.randn(3)
            res = f(x)
            res1 = f(x)
            res2 = f(x)

            expected = x.sin().cos()
            self.assertEqual(res, expected)
            self.assertEqual(res1, expected)
            self.assertEqual(res2, expected)

            self.assertTrue(len(cnt.inductor_graphs), 1)
            self.assertTrue(len(cnt1.inductor_graphs), 1)

            self.assertExpectedInline(
                str(cnt.inductor_graphs[0].graph).strip(),
                """\
graph():
    %arg0_1 : [num_users=0] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %sin : [num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%arg1_1,), kwargs = {})
    %cos : [num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%sin,), kwargs = {})
    return (cos,)""",
            )
            self.assertExpectedInline(
                str(cnt1.inductor_graphs[0].graph).strip(),
                """\
graph():
    %arg0_1 : [num_users=0] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %foo : [num_users=1] = call_function[target=torch.ops.mylib.foo.default](args = (%arg1_1,), kwargs = {})
    return (foo,)""",
            )

    @torch._dynamo.config.patch(accumulated_recompile_limit=1)
    def test_dynamo_disabled_in_custom_op_kernels(self):
        counters.clear()


@@ -847,6 +847,14 @@ class _TorchDynamoContext:
        def compile_wrapper(*args: Any, **kwargs: Any) -> Any:
            prior = set_eval_frame(None)
            try:
                # We shouldn't compile inside kernel invocation.
                if tracing_context := torch._guards.TracingContext.try_get():
                    if (
                        tracing_context.fake_mode is not None
                        and tracing_context.fake_mode.in_kernel_invocation
                    ):
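                        # `fn` is the original, undecorated callable, so fake tensor
                        # prop just traces through the kernel body eagerly; the nested
                        # compile still fires at runtime, where no TracingContext is set.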
                        return fn(*args, **kwargs)

                # Skip nested compile - just inline the function
                if is_fx_symbolic_tracing():
                    if config.error_on_nested_fx_trace:
                        raise RuntimeError(