[export] Fix custom ops in subgraphs (#160004)

Fixes https://github.com/pytorch/pytorch/issues/159995

Currently there are two problems with extern kernels in subgraphs:
1. They don't get serialized to the extern kernel json file because we only look at the toplevel graph.
2. Since the scope of each extern_kernel list is within its own subgraph, the indices referencing the operator is messed up because each subgraph will start counting from 0.

So, this PR moves the extern_kernels list to a global view (under virtualized) so that we can count the extern kernels across subgraphs and the toplevel graph.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160004
Approved by: https://github.com/ydwu4
This commit is contained in:
angelayi
2025-08-18 15:42:17 +00:00
committed by PyTorch MergeBot
parent 1091165826
commit 3c8c509a9c
7 changed files with 69 additions and 10 deletions

View File

@ -6772,6 +6772,49 @@ class AOTInductorTestsTemplate:
# compare against eager
self.assertEqual(optimized(**model_kwargs), model(**model_kwargs))
def test_custom_op_in_subgraph(self):
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
torch.library.define(
"mylib::foo_add1",
"(Tensor a) -> Tensor",
tags=torch.Tag.pt2_compliant_tag,
lib=lib,
)
@torch.library.impl("mylib::foo_add1", "CompositeExplicitAutograd", lib=lib)
@torch.library.register_fake("mylib::foo_add1", lib=lib)
def foo_add1_impl(a: torch.Tensor) -> torch.Tensor:
return a + 1
torch.library.define(
"mylib::foo_add2",
"(Tensor a) -> Tensor",
tags=torch.Tag.pt2_compliant_tag,
lib=lib,
)
@torch.library.impl("mylib::foo_add2", "CompositeExplicitAutograd", lib=lib)
@torch.library.register_fake("mylib::foo_add2", lib=lib)
def foo_add2_impl(a: torch.Tensor) -> torch.Tensor:
return a + 2
class M(torch.nn.Module):
def forward(self, x):
return torch.cond(
x.shape[0] < 5,
torch.ops.mylib.foo_add1,
torch.ops.mylib.foo_add2,
(x,),
)
list_example_inputs = [
(torch.ones(6, device=self.device),),
(torch.ones(3, device=self.device),),
]
self.check_model_with_multiple_inputs(
M(), list_example_inputs, dynamic_shapes=({0: Dim.DYNAMIC},)
)
def test_clamp_decomposition(self):
class Model1(torch.nn.Module):
def forward(self, x):

View File

@ -70,6 +70,7 @@ CPU_TEST_FAILURES = {
"test_cond_with_multiple_outputs": fail_minimal_arrayref_interface(),
"test_cond_with_parameters": fail_minimal_arrayref_interface(),
"test_cond_with_reinterpret_view_inputs_outputs": fail_minimal_arrayref_interface(),
"test_custom_op_in_subgraph": fail_minimal_arrayref_interface(),
"test_cond_share_predicte": fail_stack_allocation(is_skip=True),
"test_cond_unbacked_symint_closure_dynamic_True": fail_minimal_arrayref_interface(),
"test_while_loop_with_unbacked_symint_closure_dynamic_True": fail_minimal_arrayref_interface(),

View File

@ -2604,7 +2604,7 @@ if (!custom_op_wrapper) {
"AtenTensorHandle", tensor_call_args, force_mutable=True
)
extern_kernel_node_index = len(V.graph.extern_kernel_nodes) - 1
extern_kernel_node_index = len(V.extern_kernel_nodes) - 1
self.writeline(
f"aoti_torch_proxy_executor_call_function(proxy_executor, "
f"{extern_kernel_node_index}, "

View File

@ -1390,7 +1390,10 @@ class _InProcessFxCompile(FxCompile):
is_backward=is_backward,
is_const_graph=True,
)
with V.set_graph_handler(const_graph):
with (
V.set_graph_handler(const_graph),
V.set_extern_kernel_nodes([]),
):
assert cpp_wrapper, "AOT mode only supports C++ wrapper"
const_graph.run()
const_wrapper_code, const_kernel_code = (
@ -1425,7 +1428,7 @@ class _InProcessFxCompile(FxCompile):
# We are going to start code generating runtime asserts, so make sure
# you don't start adding new ones in the lowering process
graph.freeze_runtime_asserts()
with V.set_graph_handler(graph):
with V.set_graph_handler(graph), V.set_extern_kernel_nodes([]):
graph.run(*example_inputs)
output_strides: list[Optional[tuple[_StrideExprStr, ...]]] = []
if graph.graph_outputs is not None:
@ -1472,11 +1475,9 @@ class _InProcessFxCompile(FxCompile):
)
serialized_extern_kernel_nodes = None
if graph.extern_kernel_nodes:
if V.extern_kernel_nodes:
serialized_extern_kernel_nodes = (
graph.extern_node_serializer(
graph.extern_kernel_nodes
)
graph.extern_node_serializer(V.extern_kernel_nodes)
)
output_code_log.debug(
"Serialized Extern Kernel Nodes: \n%s",

View File

@ -392,8 +392,6 @@ class GraphLowering(torch.fx.Interpreter):
self.inplaced_to_remove: OrderedSet[str] = OrderedSet()
self.device_ops: DeviceOpOverrides = None # type: ignore[assignment]
self.wrapper_code: PythonWrapperCodegen = None # type: ignore[assignment]
# See `ProxyExecutor Design Note` in ir.py for more details
self.extern_kernel_nodes: list[ir.ExternKernelNode] = []
from torch._inductor.extern_node_serializer import extern_node_json_serializer

View File

@ -7656,7 +7656,7 @@ class FallbackKernel(ExternKernelAlloc):
),
)
V.graph.extern_kernel_nodes.append(node)
V.extern_kernel_nodes.append(node)
return [*args, *ordered_kwargs]

View File

@ -80,6 +80,7 @@ if TYPE_CHECKING:
from torch._inductor.codegen.cpp_utils import LocalBufferContext
from torch._inductor.debug import DebugContext
from torch._inductor.graph import GraphLowering
from torch._inductor.ir import ExternKernelNode
from torch._inductor.loop_body import InterpreterShim
from torch._subclasses import FakeTensorMode
@ -183,6 +184,9 @@ _ops: Virtualized[OpsHandler[Any]] = Virtualized(
"ops", cast(type[OpsHandler[Any]], MockHandler)
)
_graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler)
_extern_kernel_nodes: Virtualized[list[ExternKernelNode]] = Virtualized(
"extern_kernel_nodes", NullHandler
)
_real_inputs: Virtualized[list[torch.Tensor]] = Virtualized("real_inputs", NullHandler)
_fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler)
_kernel: Virtualized[NullKernelHandler] = Virtualized(
@ -343,6 +347,9 @@ class _V:
)
get_ops_handler: Callable[[], OpsHandler[Any]] = _ops._get_handler
set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler
set_extern_kernel_nodes: Callable[[list[ExternKernelNode]], Any] = (
_extern_kernel_nodes._set_handler
)
set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler
get_real_inputs: Callable[[], Any] = _real_inputs._get_handler
set_fake_mode: Callable[[Any], Any] = _fake_mode._set_handler
@ -368,6 +375,15 @@ class _V:
"""The graph currently being generated"""
return _graph._get_handler()
@property
def extern_kernel_nodes(self) -> list[ExternKernelNode]:
"""
The extern_kernel_nodes needed for the entire graph, including the
subgraphs.
See `ProxyExecutor Design Note` in ir.py for more details
"""
return _extern_kernel_nodes._get_handler()
@property
def real_inputs(self):
"""non-fake example inputs"""