mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Inductor] FX backend via Wrapper IR (#146942)
# Sub-PRs These PRs contain refactors from the main one. They should be reviewed and merged first. - https://github.com/pytorch/pytorch/pull/150458 - https://github.com/pytorch/pytorch/pull/152391 - https://github.com/pytorch/pytorch/pull/152587 # Feature The goals of this PR are twofold. ## Goal 1: Introduce Wrapper IR as an intermediate step in wrapper codegen. In addition to Triton/C++/Halide kernels, Inductor also generates "wrapper" code which allocates memory and calls the kernels. Originally, this wrapper code was fairly standard Python which resembled a user-written PyTorch program. Over time, various wrapper code generators have been added to accommodate things like AOTInductor, which prefers C++ code for static compilation. This complexity has bled into other parts of the codebase, as we now need if/else statements to choose between Python and C++ macros. (See an example [here](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py#L5515-L5522).) Since most of these code generation steps are conceptually identical across target languages, it seems reasonable to refactor them into some kind of intermediate representation which can be shared between the various backends. This might also make it easier to develop out-of-tree backends which cannot put their own macros in core Inductor components. This PR takes some initial steps to formalize Inductor's wrapper codegen by generalizing the existing Memory Planning IR into a fully fledged Wrapper IR. This is pretty much identical to the existing Memory Planning IR, but it supports a richer set of ops for things like kernel definitions and calls. This refactor could help encapsulate wrapper codegen. Ideally, we don't need to worry about direct Python/C++ codegen in the main compiler files such as `ir.py`, and can instead defer these to classes like `PythonWrapperCodegen` and `CppWrapperCpu`, which operate on the Wrapper IR. ## Goal 2: Convert Wrapper IR into FX IR. One of the main benefits of Wrapper IR is to enable more diverse Inductor backends. This PR introduces a converter from Wrapper IR into [FX IR](https://pytorch.org/docs/stable/fx.html), which is the intermediate representation most commonly used in PyTorch graph compilers. The purpose of this is to enable out-of-tree backends to consume Inductor's output in FX IR, which would hopefully make Inductor easier to leverage in novel compilers, hardware accelerators, etc. It's not trivial to generate Python or C++ code which Inductor can compile and run, and doing so may require changes to other core Inductor files, for the reasons outlined in the previous section. The goal of supporting FX output is to enable something like `torch.compile`'s [custom backend](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html) system, in which an out-of-tree backend can receive an optimized FX graph from Inductor, and compile and run it however it likes. The typical users of this feature would likely not be part of PyTorch, and may or may not support running a kernel in eager mode. However, they can understand what `torch.empty_strided` means, compile and run Triton kernels, etc. So we just need to present them with an FX graph saying what code Inductor wants to run, which should be easier to analyze and transform in a third party system than Python or C++ source. Since FX IR is fairly stable, this mechanism should hopefully isolate third-party backends, hardware accelerators, etc. from the implementation details of Inductor, and vice versa. # Current status Things that seem to work: - Converted a lot of the most common Python codegen lines to Wrapper IR lines. - Handled the following cases, in addition to what was already in the Memory Planning IR: - Comments - Triton kernels - Extern/fallback kernels - Freeing tensors (`del buf0`) - MultiOutput - Graph outputs - ReinterpretView / StorageBox, for both call args and outputs. - FX conversion asserts that the program only contains Wrapper IR lines, and not strings of Python/C++ code. - Prototype FX converter which can handle some of the most common use cases. - Defining Triton kernels, and putting them in a side table using TorchDynamo's existing [utilities](https://dev-discuss.pytorch.org/t/higher-order-operators-2023-10/1565). - Calling wrapped Triton kernels. - Calling extern kernels and certain types of fallback kernels. - Support both `extern_kernels.*` and `aten.*`. - Support multi-output kernels like `torch.topk`. - Graphs with multiple inputs/outputs. - Training i.e. calling `Tensor.backward()` in a compiled function. - Graph breaks (training). - Run the `torch.fx.GraphModule` on GPU using the standard `__call__` method. This makes it easy to test the correctness of FX codegen. Things that don't work: - Both Wrapper IR and Wrapper -> FX coverage are currently best effort. There are still features which aren't captured as Wrapper IR lines, and fall back to plain strings. This representation is functionally correct but probably not rich enough to achieve the goals outlined in the previous sections. - Fallback kernels seem like the most difficult thing to fully cover, since they each define their own Python/C++ macros that would need to be converted to FX. - Size/alignment asserts are currently disabled via the config file. It's possible to generate FX IR for these, but it seems reasonable to defer these sanity checks to a later PR. - CommBuffer's and distributed communication are not yet supported. An earlier version of this PR attempted to implement this by calling `empty_strided_p2p`. However, building and testing distributed support seems non-trivial, so it's probably better to defer this. # Out-of-tree compilers With this PR, out of tree backends will be able to do further compilation on the FX graphs by subclassing `WrapperFxCodegen` and overriding the `compile_graph` function. This follows the same API as torch.compile's [custom backends](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html), where the user simply returns a callable running the graph. The callable need not be a method of `GraphModule` or any other PyTorch class. See an example below. ``` from torch._inductor.codegen.wrapper_fxir import WrapperFxCodegen class MyCustomBackend(WrapperFxCodegen): def compile_graph(self, gm): # Add 1 to the graph's outputs def compiled_fn(*args): return [x + 1 for x in gm.graph.forward(*args)] return compiled_fn ``` # Example FX graphs This section contains some example FX graphs generated by Inductor. The correctness of these graphs was verified against eager mode by calling the corresponding `GraphModule`. Here's an FX graph calling a basic Triton kernel. Notice how outputs are allocated with `torch.empty_strided`, and the Triton kernel is called by reference to Dynamo's triton side table. ``` graph(): %arg0_1 : [num_users=1] = placeholder[target=arg0_1] %arg1_1 : [num_users=1] = placeholder[target=arg1_1] %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((8,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0}) %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, in_ptr1: %arg0_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}}) return (buf0,) ``` Here's a more complicated graph that calls a `torch.addmm` extern kernel. ``` graph(): %arg0_1 : [num_users=1] = placeholder[target=arg0_1] %arg1_1 : [num_users=2] = placeholder[target=arg1_1] %buf0 : [num_users=3] = call_function[target=torch.empty_strided](args = ((), ()), kwargs = {dtype: torch.float32, device: cuda:0}) %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(1,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, out_ptr0: %buf0, xnumel: 1, r0_numel: 129, XBLOCK: 1}}) %buf2 : [num_users=2] = call_function[target=torch.empty_strided](args = ((129, 1), (1, 1)), kwargs = {dtype: torch.float32, device: cuda:0}) %addmm : [num_users=0] = call_function[target=torch.addmm](args = (%buf0, %arg0_1, %arg1_1), kwargs = {alpha: 1, beta: 1, out: %buf2}) %delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {}) return (buf2,) ``` Here's a graph which indexes into a tuple using `operator.getitem`. This is necessary to use the output of the `torch.topk` operation. ``` graph(): %arg0_1 : [num_users=1] = placeholder[target=arg0_1] %buf0 : [num_users=3] = call_function[target=torch.ops.aten.topk.default](args = (%arg0_1, 2), kwargs = {}) %buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 0), kwargs = {}) %buf2 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 1), kwargs = {}) %delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {}) %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 2, XBLOCK: 2}}) %triton_kernel_wrapper_mutation_1 : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 1, constant_args_idx: 1, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf2, xnumel: 2, XBLOCK: 2}}) return (buf1, buf2) ``` Here's a graph that reinterprets an output tensor using `torch.as_strided`. This is one way to handle Inductor's `ReinterpretView` op. ``` graph(): %arg0_1 : [num_users=1] = placeholder[target=arg0_1] %arg1_1 : [num_users=1] = placeholder[target=arg1_1] %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((2, 4), (4, 1)), kwargs = {dtype: torch.float32, device: cuda:0}) %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg0_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}}) %buf0_view_buf0_0 : [num_users=1] = call_function[target=torch.as_strided](args = (%buf0, (8,), (1,), 0), kwargs = {}) return (buf0_view_buf0_0,) ``` Here's a graph with dynamic shapes. This one is a little bit funky. Inductor provides a graph input for each shape symbol, which we map to a placeholder, in this example `s6`. Then, shape expressions in the generated code can refer to the symbol `s6`. The size hint for `s6` is stored in `node.meta["val"]` where `node` is the placeholder defining it. This works out in the generated python code because the placeholder defines a Python variable with the name `s6`. ``` graph(): %s6 : [num_users=0] = placeholder[target=s6] %arg1_1 : [num_users=1] = placeholder[target=arg1_1] %arg2_1 : [num_users=1] = placeholder[target=arg2_1] %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((s6,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0}) %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((-s6)//8)), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s6, XBLOCK: 8}}) return buf0 ``` Here's another graph, this time with dynamic shapes and strides. The grid expression is more complex since the numel is a product of dimensions. ``` graph(): %s10 : [num_users=0] = placeholder[target=s10] %arg1_1 : [num_users=1] = placeholder[target=arg1_1] %arg2_1 : [num_users=1] = placeholder[target=arg2_1] %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ([s10, s10], [s10, 1]), kwargs = {dtype: torch.float32, device: cuda:0}) %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((s10**2)//(-64))), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s10**2, XBLOCK: 64}}) return buf0 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/146942 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
fdadda21b6
commit
a7691140a0
@ -436,6 +436,23 @@ def convert_shape_to_inductor(
|
||||
return [sympy.sympify(i) for i in lst]
|
||||
|
||||
|
||||
def convert_to_symint(i: Union[int, sympy.Expr]) -> Union[int, torch.SymInt]:
|
||||
"""
|
||||
Like convert_shape_to_symint, but operates on a single expression.
|
||||
"""
|
||||
from .virtualized import V
|
||||
|
||||
return (
|
||||
i
|
||||
if isinstance(i, int)
|
||||
else (
|
||||
int(i)
|
||||
if isinstance(i, sympy.Integer)
|
||||
else V.graph.sizevars.shape_env.create_symintnode(i, hint=None)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def convert_shape_to_symint(
|
||||
lst: Iterable[Union[int, sympy.Expr]],
|
||||
) -> list[Union[int, torch.SymInt]]:
|
||||
@ -443,20 +460,7 @@ def convert_shape_to_symint(
|
||||
Takes a list of shapes from Inductor and converts them into symints (or just
|
||||
ints if all shapes are static).
|
||||
"""
|
||||
from .virtualized import V
|
||||
|
||||
return [
|
||||
(
|
||||
i
|
||||
if isinstance(i, int)
|
||||
else (
|
||||
int(i)
|
||||
if isinstance(i, sympy.Integer)
|
||||
else V.graph.sizevars.shape_env.create_symintnode(i, hint=None)
|
||||
)
|
||||
)
|
||||
for i in lst
|
||||
]
|
||||
return [convert_to_symint(i) for i in lst]
|
||||
|
||||
|
||||
def is_view(op: torch._ops.OpOverload) -> bool:
|
||||
|
Reference in New Issue
Block a user