Don't run compile inside kernel invocation (#165687)
When torch.compile is called during fake tensor propagation, we shouldn't actually compile, because there is no guarantee that the compiled artifact can itself be fake-tensor-propagated (for example, with the inductor backend). Instead, we skip compiling and run the function as-is; the inner compile is still triggered when the kernel executes at runtime.
Fixes: https://github.com/pytorch/pytorch/issues/151328
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165687
Approved by: https://github.com/zou3519
Commit 08c97b4a1f (parent fae74cd52f), committed by PyTorch MergeBot.
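As a quick repro of the scenario (condensed from the new test below; the mylib/foo op, the lambda, and the default backend choice are illustrative), the issue shows up when a custom op's kernel itself calls torch.compile and the op is then traced under fake tensors:

    # Condensed, illustrative sketch; not part of the diff.
    import torch

    with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
        lib.define("foo(Tensor x) -> Tensor")

        def foo_impl(x):
            # Nested torch.compile inside the kernel: with this change it is
            # skipped while the op is traced under fake tensors and only
            # kicks in when the kernel actually runs.
            return torch.compile(lambda t: t.sin().cos(), fullgraph=True)(x)

        lib.impl("foo", foo_impl, "CompositeExplicitAutograd")

        @torch.compile(fullgraph=True)
        def f(x):
            return torch.ops.mylib.foo.default(x)

        out = f(torch.randn(3))  # previously failed during fake tensor prop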
@@ -242,6 +242,57 @@ class MiscTests(torch._inductor.test_case.TestCase):
        self.assertTrue(same(val4, correct1))
        self.assertEqual(counter.frame_count, 3)

    def test_dynamo_inside_custom_op(self):
        cnt = torch._dynamo.testing.InductorAndRecordGraphs()
        cnt1 = torch._dynamo.testing.InductorAndRecordGraphs()

        with torch.library._scoped_library("mylib", "FRAGMENT") as m:
            m.define("foo(Tensor x) -> Tensor")

            def inner(x):
                return x.sin().cos()

            def foo_impl(x):
                return torch.compile(inner, fullgraph=True, dynamic=True, backend=cnt)(
                    x
                )

            m.impl("foo", foo_impl, "CompositeExplicitAutograd")

            @torch.compile(fullgraph=True, dynamic=True, backend=cnt1)
            def f(x):
                return torch.ops.mylib.foo.default(x)

            x = torch.randn(3)
            res = f(x)
            res1 = f(x)
            res2 = f(x)
            expected = x.sin().cos()
            self.assertEqual(res, expected)
            self.assertEqual(res1, expected)
            self.assertEqual(res2, expected)
            self.assertTrue(len(cnt.inductor_graphs), 1)
            self.assertTrue(len(cnt1.inductor_graphs), 1)
            self.assertExpectedInline(
                str(cnt.inductor_graphs[0].graph).strip(),
                """\
graph():
    %arg0_1 : [num_users=0] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %sin : [num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%arg1_1,), kwargs = {})
    %cos : [num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%sin,), kwargs = {})
    return (cos,)""",
            )
            self.assertExpectedInline(
                str(cnt1.inductor_graphs[0].graph).strip(),
                """\
graph():
    %arg0_1 : [num_users=0] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %foo : [num_users=1] = call_function[target=torch.ops.mylib.foo.default](args = (%arg1_1,), kwargs = {})
    return (foo,)""",
            )

    @torch._dynamo.config.patch(accumulated_recompile_limit=1)
    def test_dynamo_disabled_in_custom_op_kernels(self):
        counters.clear()
@@ -847,6 +847,14 @@ class _TorchDynamoContext:
        def compile_wrapper(*args: Any, **kwargs: Any) -> Any:
            prior = set_eval_frame(None)
            try:
                # We shouldn't compile inside kernel invocation.
                if tracing_context := torch._guards.TracingContext.try_get():
                    if (
                        tracing_context.fake_mode is not None
                        and tracing_context.fake_mode.in_kernel_invocation
                    ):
                        return fn(*args, **kwargs)
                # Skip nested compile - just inline the function
                if is_fx_symbolic_tracing():
                    if config.error_on_nested_fx_trace:
                        raise RuntimeError(
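For context, the flag consulted by the new guard lives on FakeTensorMode; a rough illustration of the relationship (illustrative only, not part of the diff):

    # Illustrative only; not part of the diff. FakeTensorMode carries an
    # in_kernel_invocation flag, and the new guard in compile_wrapper falls
    # back to running fn eagerly whenever the active tracing context's fake
    # mode has that flag set.
    from torch._subclasses.fake_tensor import FakeTensorMode

    mode = FakeTensorMode()
    print(mode.in_kernel_invocation)  # False outside of a kernel implementation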