Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[export] Fix custom ops in subgraphs (#160004)
Fixes https://github.com/pytorch/pytorch/issues/159995

Currently there are two problems with extern kernels in subgraphs:
1. They don't get serialized to the extern kernel JSON file, because serialization only looks at the toplevel graph.
2. Since each extern_kernels list is scoped to its own subgraph, the indices referencing the operators are wrong: each subgraph starts counting from 0.

So, this PR moves the extern_kernels list to a global view (under Virtualized) so that we can count the extern kernels across the subgraphs and the toplevel graph.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160004
Approved by: https://github.com/ydwu4
Commit: 3c8c509a9c
Parent: 1091165826
Committed by: PyTorch MergeBot
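
The fix builds on Inductor's Virtualized pattern: a named thread-local slot whose current handler is installed with a context manager and read back from anywhere in the compile. The sketch below is a minimal, self-contained mimic of that pattern (illustrative only; the real implementation lives in torch/_inductor/virtualized.py and differs in detail). It shows why installing one list for the whole compile keeps indices consistent across the toplevel graph and its subgraphs.

    # Minimal, self-contained mimic of the Virtualized pattern used by this fix
    # (illustrative only; not the real torch/_inductor/virtualized.py code).
    import threading
    from contextlib import contextmanager

    class _Virtualized:
        """Holds one thread-local 'handler' that callers swap in via a context manager."""

        def __init__(self, name):
            self._name = name
            self._local = threading.local()

        @contextmanager
        def _set_handler(self, handler):
            prior = getattr(self._local, "handler", None)
            self._local.handler = handler
            try:
                yield
            finally:
                self._local.handler = prior

        def _get_handler(self):
            return self._local.handler

    _extern_kernel_nodes = _Virtualized("extern_kernel_nodes")

    # One list is installed for the whole compile, so the toplevel graph and every
    # subgraph append to (and index into) the same sequence:
    with _extern_kernel_nodes._set_handler([]):
        _extern_kernel_nodes._get_handler().append("toplevel_op")  # index 0
        _extern_kernel_nodes._get_handler().append("subgraph_op")  # index 1
        print(_extern_kernel_nodes._get_handler())  # ['toplevel_op', 'subgraph_op']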
@@ -6772,6 +6772,49 @@ class AOTInductorTestsTemplate:
         # compare against eager
         self.assertEqual(optimized(**model_kwargs), model(**model_kwargs))
 
+    def test_custom_op_in_subgraph(self):
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+            torch.library.define(
+                "mylib::foo_add1",
+                "(Tensor a) -> Tensor",
+                tags=torch.Tag.pt2_compliant_tag,
+                lib=lib,
+            )
+
+            @torch.library.impl("mylib::foo_add1", "CompositeExplicitAutograd", lib=lib)
+            @torch.library.register_fake("mylib::foo_add1", lib=lib)
+            def foo_add1_impl(a: torch.Tensor) -> torch.Tensor:
+                return a + 1
+
+            torch.library.define(
+                "mylib::foo_add2",
+                "(Tensor a) -> Tensor",
+                tags=torch.Tag.pt2_compliant_tag,
+                lib=lib,
+            )
+
+            @torch.library.impl("mylib::foo_add2", "CompositeExplicitAutograd", lib=lib)
+            @torch.library.register_fake("mylib::foo_add2", lib=lib)
+            def foo_add2_impl(a: torch.Tensor) -> torch.Tensor:
+                return a + 2
+
+            class M(torch.nn.Module):
+                def forward(self, x):
+                    return torch.cond(
+                        x.shape[0] < 5,
+                        torch.ops.mylib.foo_add1,
+                        torch.ops.mylib.foo_add2,
+                        (x,),
+                    )
+
+            list_example_inputs = [
+                (torch.ones(6, device=self.device),),
+                (torch.ones(3, device=self.device),),
+            ]
+            self.check_model_with_multiple_inputs(
+                M(), list_example_inputs, dynamic_shapes=({0: Dim.DYNAMIC},)
+            )
+
     def test_clamp_decomposition(self):
         class Model1(torch.nn.Module):
             def forward(self, x):
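
As a side note, the branch selection this test exercises can be sanity-checked in eager mode. The sketch below is separate from the test, with plain functions standing in for the registered custom ops.

    # Eager-mode sanity check of the predicate in M.forward (sketch only; add1/add2
    # stand in for the mylib custom ops registered above).
    import torch

    def add1(x):
        return x + 1

    def add2(x):
        return x + 2

    small, large = torch.ones(3), torch.ones(6)
    # shape[0] < 5 selects the true branch; otherwise the false branch runs
    assert torch.equal(torch.cond(small.shape[0] < 5, add1, add2, (small,)), small + 1)
    assert torch.equal(torch.cond(large.shape[0] < 5, add1, add2, (large,)), large + 2)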
@@ -70,6 +70,7 @@ CPU_TEST_FAILURES = {
     "test_cond_with_multiple_outputs": fail_minimal_arrayref_interface(),
     "test_cond_with_parameters": fail_minimal_arrayref_interface(),
     "test_cond_with_reinterpret_view_inputs_outputs": fail_minimal_arrayref_interface(),
+    "test_custom_op_in_subgraph": fail_minimal_arrayref_interface(),
     "test_cond_share_predicte": fail_stack_allocation(is_skip=True),
     "test_cond_unbacked_symint_closure_dynamic_True": fail_minimal_arrayref_interface(),
     "test_while_loop_with_unbacked_symint_closure_dynamic_True": fail_minimal_arrayref_interface(),
@@ -2604,7 +2604,7 @@ if (!custom_op_wrapper) {
             "AtenTensorHandle", tensor_call_args, force_mutable=True
         )
 
-        extern_kernel_node_index = len(V.graph.extern_kernel_nodes) - 1
+        extern_kernel_node_index = len(V.extern_kernel_nodes) - 1
         self.writeline(
             f"aoti_torch_proxy_executor_call_function(proxy_executor, "
             f"{extern_kernel_node_index}, "
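
To see why this index has to come from a single shared list, consider a toy rendering of the numbering (hypothetical op names; heavily simplified relative to the real ExternKernelNode records):

    # Toy illustration of the index bug this hunk fixes (assumption, simplified).
    # Before: each (sub)graph kept its own list, so indices restarted at 0 per
    # subgraph, and only the toplevel list ever reached the extern-kernel JSON file.
    per_graph = {"toplevel": ["opA"], "subgraph": ["opB"]}
    serialized = per_graph["toplevel"]  # "opB" never serialized; both calls emit index 0

    # After: one shared list spans the toplevel graph and all subgraphs.
    shared = []
    for op in ("opA", "opB"):
        shared.append(op)
        print(f"proxy_executor call uses index {len(shared) - 1} -> {op}")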
@@ -1390,7 +1390,10 @@ class _InProcessFxCompile(FxCompile):
                     is_backward=is_backward,
                     is_const_graph=True,
                 )
-                with V.set_graph_handler(const_graph):
+                with (
+                    V.set_graph_handler(const_graph),
+                    V.set_extern_kernel_nodes([]),
+                ):
                     assert cpp_wrapper, "AOT mode only supports C++ wrapper"
                     const_graph.run()
                     const_wrapper_code, const_kernel_code = (
@@ -1425,7 +1428,7 @@ class _InProcessFxCompile(FxCompile):
             # We are going to start code generating runtime asserts, so make sure
             # you don't start adding new ones in the lowering process
             graph.freeze_runtime_asserts()
-            with V.set_graph_handler(graph):
+            with V.set_graph_handler(graph), V.set_extern_kernel_nodes([]):
                 graph.run(*example_inputs)
                 output_strides: list[Optional[tuple[_StrideExprStr, ...]]] = []
                 if graph.graph_outputs is not None:
@@ -1472,11 +1475,9 @@ class _InProcessFxCompile(FxCompile):
                     )
 
                     serialized_extern_kernel_nodes = None
-                    if graph.extern_kernel_nodes:
+                    if V.extern_kernel_nodes:
                         serialized_extern_kernel_nodes = (
-                            graph.extern_node_serializer(
-                                graph.extern_kernel_nodes
-                            )
+                            graph.extern_node_serializer(V.extern_kernel_nodes)
                         )
                         output_code_log.debug(
                             "Serialized Extern Kernel Nodes: \n%s",
@@ -392,8 +392,6 @@ class GraphLowering(torch.fx.Interpreter):
         self.inplaced_to_remove: OrderedSet[str] = OrderedSet()
         self.device_ops: DeviceOpOverrides = None  # type: ignore[assignment]
         self.wrapper_code: PythonWrapperCodegen = None  # type: ignore[assignment]
-        # See `ProxyExecutor Design Note` in ir.py for more details
-        self.extern_kernel_nodes: list[ir.ExternKernelNode] = []
 
         from torch._inductor.extern_node_serializer import extern_node_json_serializer
 
@@ -7656,7 +7656,7 @@ class FallbackKernel(ExternKernelAlloc):
                ),
            )
 
-           V.graph.extern_kernel_nodes.append(node)
+           V.extern_kernel_nodes.append(node)
 
           return [*args, *ordered_kwargs]
 
@@ -80,6 +80,7 @@ if TYPE_CHECKING:
     from torch._inductor.codegen.cpp_utils import LocalBufferContext
     from torch._inductor.debug import DebugContext
     from torch._inductor.graph import GraphLowering
+    from torch._inductor.ir import ExternKernelNode
     from torch._inductor.loop_body import InterpreterShim
     from torch._subclasses import FakeTensorMode
 
@@ -183,6 +184,9 @@ _ops: Virtualized[OpsHandler[Any]] = Virtualized(
     "ops", cast(type[OpsHandler[Any]], MockHandler)
 )
 _graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler)
+_extern_kernel_nodes: Virtualized[list[ExternKernelNode]] = Virtualized(
+    "extern_kernel_nodes", NullHandler
+)
 _real_inputs: Virtualized[list[torch.Tensor]] = Virtualized("real_inputs", NullHandler)
 _fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler)
 _kernel: Virtualized[NullKernelHandler] = Virtualized(
@@ -343,6 +347,9 @@ class _V:
     )
     get_ops_handler: Callable[[], OpsHandler[Any]] = _ops._get_handler
     set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler
+    set_extern_kernel_nodes: Callable[[list[ExternKernelNode]], Any] = (
+        _extern_kernel_nodes._set_handler
+    )
     set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler
     get_real_inputs: Callable[[], Any] = _real_inputs._get_handler
     set_fake_mode: Callable[[Any], Any] = _fake_mode._set_handler
@@ -368,6 +375,15 @@ class _V:
         """The graph currently being generated"""
         return _graph._get_handler()
 
+    @property
+    def extern_kernel_nodes(self) -> list[ExternKernelNode]:
+        """
+        The extern_kernel_nodes needed for the entire graph, including the
+        subgraphs.
+        See `ProxyExecutor Design Note` in ir.py for more details
+        """
+        return _extern_kernel_nodes._get_handler()
+
     @property
     def real_inputs(self):
         """non-fake example inputs"""
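
Taken together, the new accessors can be exercised directly, as a quick check (requires a build that includes this change; the string entries stand in for real ExternKernelNode objects, which the list does not type-check at runtime):

    # Quick check of the new V accessors (strings stand in for ExternKernelNode objects).
    from torch._inductor.virtualized import V

    with V.set_extern_kernel_nodes([]):
        V.extern_kernel_nodes.append("node_from_toplevel")  # what FallbackKernel now does
        V.extern_kernel_nodes.append("node_from_subgraph")  # ...from any graph depth
        assert len(V.extern_kernel_nodes) == 2              # one flat, compile-wide list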