Added a new configuration option `cutlass_enabled_ops` that allows users to control which operations use CUTLASS lowerings. By default, CUTLASS is enabled for all operations (maintaining backward compatibility), but users can now selectively enable it only for specific operations to optimize compilation time.
**Fixes #155718**
## Usage Examples
```bash
# Enable CUTLASS for all operations (default behavior)
export TORCHINDUCTOR_CUTLASS_ENABLED_OPS="ALL"
# Enable CUTLASS only for matrix multiplication operations
export TORCHINDUCTOR_CUTLASS_ENABLED_OPS="mm,addmm"
# Enable CUTLASS only for batch operations
export TORCHINDUCTOR_CUTLASS_ENABLED_OPS="bmm,baddbmm"
# Disable CUTLASS for all operations
export TORCHINDUCTOR_CUTLASS_ENABLED_OPS=""
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155770
Approved by: https://github.com/henrylhtsang
... and fix ck-tile instances not being generated due to incorrect caching
### Testing
Added test cases for CKTILE instances
```
pytest test/inductor/test_ck_backend.py -k gemm_backends_CKTILE
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155294
Approved by: https://github.com/coconutruben
For graph partition, `write_get_raw_stream_header_once` is done once so the autotune code may not have the header. This PR additionally calls `write_get_raw_stream_header` in `codegen_device_guard_enter` before `get_raw_stream` is used.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154698
Approved by: https://github.com/oulgen
Fix for https://github.com/pytorch/pytorch/issues/152425
inductor specializes whether or not a tensor is 16-bit aligned on the first invocation. then, on subsequent invocations, if we inferred alignment but are passed a non-aligned tensor we clone the tensor.
If we infer alignment, then run with unaligned, and mutate the input, we need to reflect back the mutation to the input. This pr adds back that mutation.
We could have also been less aggressive about inferring alignment for mutated tensors, but that has a pretty perf hit.See the following benchmark:
```
import torch
t = torch.rand(4096 * 4096, device="cuda", dtype=torch.float16)
@torch.compile(dynamic=False)
def foo(x):
return x.add_(1)
import triton
print(triton.testing.do_bench(lambda: foo(t[:-1])))
torch._dynamo.reset()
print(triton.testing.do_bench(lambda: foo(t[1:])))
```
gives
```
0.04063070610165596
0.07613472988113162
```
So almost twice as slow for non-aligned tensors. Tensors changing alignment is a relatively rare case.
In the future, we could considering a multi-kernel approach, or codegening a triton kernel that does most of the loads with aligned instructions, and a prologue/epilogue of un-alignment. But, it's yet to be seen this is a huge issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154442
Approved by: https://github.com/bobrenjc93, https://github.com/bdhirsh
Summary: We enable the activation quantization in the forward pass, and users can customize the dtype they want to quantize.
Test Plan:
# unit test
```
buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:quantization -- test_activation_quantization_aten
```
Buck UI: https://www.internalfb.com/buck2/776d3911-bb86-4ac8-a527-540cf1510b9d
Test UI: https://www.internalfb.com/intern/testinfra/testrun/4785074873051017
Network: Up: 4.3MiB Down: 42MiB (reSessionID-fef7e727-68b1-4645-a519-5652854df38d)
Executing actions. Remaining 0/4 6.7s exec time total
Command: test. Finished 2 local
Time elapsed: 3:11.5s
Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0
# E2E
### how to enable (you can overrite the dtype, if nothing given, the default is fp8)
```
post_grad_fusion_options={
"activation_quantization_aten_pass": {"quant_type": "torch.float8_e5m2"}
},
```
Differential Revision: D70522237
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148380
Approved by: https://github.com/Mingming-Ding, https://github.com/Hahu803
# Sub-PRs
These PRs contain refactors from the main one. They should be reviewed and merged first.
- https://github.com/pytorch/pytorch/pull/150458
- https://github.com/pytorch/pytorch/pull/152391
- https://github.com/pytorch/pytorch/pull/152587
# Feature
The goals of this PR are twofold.
## Goal 1: Introduce Wrapper IR as an intermediate step in wrapper codegen.
In addition to Triton/C++/Halide kernels, Inductor also generates "wrapper" code which allocates memory and calls the kernels. Originally, this wrapper code was fairly standard Python which resembled a user-written PyTorch program. Over time, various wrapper code generators have been added to accommodate things like AOTInductor, which prefers C++ code for static compilation. This complexity has bled into other parts of the codebase, as we now need if/else statements to choose between Python and C++ macros. (See an example [here](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py#L5515-L5522).) Since most of these code generation steps are conceptually identical across target languages, it seems reasonable to refactor them into some kind of intermediate representation which can be shared between the various backends. This might also make it easier to develop out-of-tree backends which cannot put their own macros in core Inductor components.
This PR takes some initial steps to formalize Inductor's wrapper codegen by generalizing the existing Memory Planning IR into a fully fledged Wrapper IR. This is pretty much identical to the existing Memory Planning IR, but it supports a richer set of ops for things like kernel definitions and calls. This refactor could help encapsulate wrapper codegen. Ideally, we don't need to worry about direct Python/C++ codegen in the main compiler files such as `ir.py`, and can instead defer these to classes like `PythonWrapperCodegen` and `CppWrapperCpu`, which operate on the Wrapper IR.
## Goal 2: Convert Wrapper IR into FX IR.
One of the main benefits of Wrapper IR is to enable more diverse Inductor backends. This PR introduces a converter from Wrapper IR into [FX IR](https://pytorch.org/docs/stable/fx.html), which is the intermediate representation most commonly used in PyTorch graph compilers. The purpose of this is to enable out-of-tree backends to consume Inductor's output in FX IR, which would hopefully make Inductor easier to leverage in novel compilers, hardware accelerators, etc.
It's not trivial to generate Python or C++ code which Inductor can compile and run, and doing so may require changes to other core Inductor files, for the reasons outlined in the previous section. The goal of supporting FX output is to enable something like `torch.compile`'s [custom backend](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html) system, in which an out-of-tree backend can receive an optimized FX graph from Inductor, and compile and run it however it likes.
The typical users of this feature would likely not be part of PyTorch, and may or may not support running a kernel in eager mode. However, they can understand what `torch.empty_strided` means, compile and run Triton kernels, etc. So we just need to present them with an FX graph saying what code Inductor wants to run, which should be easier to analyze and transform in a third party system than Python or C++ source.
Since FX IR is fairly stable, this mechanism should hopefully isolate third-party backends, hardware accelerators, etc. from the implementation details of Inductor, and vice versa.
# Current status
Things that seem to work:
- Converted a lot of the most common Python codegen lines to Wrapper IR lines.
- Handled the following cases, in addition to what was already in the Memory Planning IR:
- Comments
- Triton kernels
- Extern/fallback kernels
- Freeing tensors (`del buf0`)
- MultiOutput
- Graph outputs
- ReinterpretView / StorageBox, for both call args and outputs.
- FX conversion asserts that the program only contains Wrapper IR lines, and not strings of Python/C++ code.
- Prototype FX converter which can handle some of the most common use cases.
- Defining Triton kernels, and putting them in a side table using TorchDynamo's existing [utilities](https://dev-discuss.pytorch.org/t/higher-order-operators-2023-10/1565).
- Calling wrapped Triton kernels.
- Calling extern kernels and certain types of fallback kernels.
- Support both `extern_kernels.*` and `aten.*`.
- Support multi-output kernels like `torch.topk`.
- Graphs with multiple inputs/outputs.
- Training i.e. calling `Tensor.backward()` in a compiled function.
- Graph breaks (training).
- Run the `torch.fx.GraphModule` on GPU using the standard `__call__` method. This makes it easy to test the correctness of FX codegen.
Things that don't work:
- Both Wrapper IR and Wrapper -> FX coverage are currently best effort. There are still features which aren't captured as Wrapper IR lines, and fall back to plain strings. This representation is functionally correct but probably not rich enough to achieve the goals outlined in the previous sections.
- Fallback kernels seem like the most difficult thing to fully cover, since they each define their own Python/C++ macros that would need to be converted to FX.
- Size/alignment asserts are currently disabled via the config file. It's possible to generate FX IR for these, but it seems reasonable to defer these sanity checks to a later PR.
- CommBuffer's and distributed communication are not yet supported. An earlier version of this PR attempted to implement this by calling `empty_strided_p2p`. However, building and testing distributed support seems non-trivial, so it's probably better to defer this.
# Out-of-tree compilers
With this PR, out of tree backends will be able to do further compilation on the FX graphs by subclassing `WrapperFxCodegen` and overriding the `compile_graph` function. This follows the same API as torch.compile's [custom backends](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html), where the user simply returns a callable running the graph. The callable need not be a method of `GraphModule` or any other PyTorch class. See an example below.
```
from torch._inductor.codegen.wrapper_fxir import WrapperFxCodegen
class MyCustomBackend(WrapperFxCodegen):
def compile_graph(self, gm):
# Add 1 to the graph's outputs
def compiled_fn(*args):
return [x + 1 for x in gm.graph.forward(*args)]
return compiled_fn
```
# Example FX graphs
This section contains some example FX graphs generated by Inductor. The correctness of these graphs was verified against eager mode by calling the corresponding `GraphModule`.
Here's an FX graph calling a basic Triton kernel. Notice how outputs are allocated with `torch.empty_strided`, and the Triton kernel is called by reference to Dynamo's triton side table.
```
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((8,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, in_ptr1: %arg0_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
return (buf0,)
```
Here's a more complicated graph that calls a `torch.addmm` extern kernel.
```
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%buf0 : [num_users=3] = call_function[target=torch.empty_strided](args = ((), ()), kwargs = {dtype: torch.float32, device: cuda:0})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(1,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, out_ptr0: %buf0, xnumel: 1, r0_numel: 129, XBLOCK: 1}})
%buf2 : [num_users=2] = call_function[target=torch.empty_strided](args = ((129, 1), (1, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
%addmm : [num_users=0] = call_function[target=torch.addmm](args = (%buf0, %arg0_1, %arg1_1), kwargs = {alpha: 1, beta: 1, out: %buf2})
%delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
return (buf2,)
```
Here's a graph which indexes into a tuple using `operator.getitem`. This is necessary to use the output of the `torch.topk` operation.
```
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%buf0 : [num_users=3] = call_function[target=torch.ops.aten.topk.default](args = (%arg0_1, 2), kwargs = {})
%buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 0), kwargs = {})
%buf2 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 1), kwargs = {})
%delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 2, XBLOCK: 2}})
%triton_kernel_wrapper_mutation_1 : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 1, constant_args_idx: 1, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf2, xnumel: 2, XBLOCK: 2}})
return (buf1, buf2)
```
Here's a graph that reinterprets an output tensor using `torch.as_strided`. This is one way to handle Inductor's `ReinterpretView` op.
```
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((2, 4), (4, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg0_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
%buf0_view_buf0_0 : [num_users=1] = call_function[target=torch.as_strided](args = (%buf0, (8,), (1,), 0), kwargs = {})
return (buf0_view_buf0_0,)
```
Here's a graph with dynamic shapes. This one is a little bit funky. Inductor provides a graph input for each shape symbol, which we map to a placeholder, in this example `s6`. Then, shape expressions in the generated code can refer to the symbol `s6`. The size hint for `s6` is stored in `node.meta["val"]` where `node` is the placeholder defining it. This works out in the generated python code because the placeholder defines a Python variable with the name `s6`.
```
graph():
%s6 : [num_users=0] = placeholder[target=s6]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=1] = placeholder[target=arg2_1]
%buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((s6,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((-s6)//8)), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s6, XBLOCK: 8}})
return buf0
```
Here's another graph, this time with dynamic shapes and strides. The grid expression is more complex since the numel is a product of dimensions.
```
graph():
%s10 : [num_users=0] = placeholder[target=s10]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=1] = placeholder[target=arg2_1]
%buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ([s10, s10], [s10, 1]), kwargs = {dtype: torch.float32, device: cuda:0})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((s10**2)//(-64))), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s10**2, XBLOCK: 64}})
return buf0
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146942
Approved by: https://github.com/jansel
Summary:
Allow TMA workspace creation allow specification for `num_programs`, which defaults to `num_sms` when not specified.
We need a total `num_programs * num_tma_descriptors` no. of descriptors for a kernel.
Test Plan: CI.
Differential Revision: D74189599
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152844
Approved by: https://github.com/drisspg
# Sub-PRs
These PRs contain refactors from the main one. They should be reviewed and merged first.
- https://github.com/pytorch/pytorch/pull/150458
- https://github.com/pytorch/pytorch/pull/152391
- https://github.com/pytorch/pytorch/pull/152587
# Feature
The goals of this PR are twofold.
## Goal 1: Introduce Wrapper IR as an intermediate step in wrapper codegen.
In addition to Triton/C++/Halide kernels, Inductor also generates "wrapper" code which allocates memory and calls the kernels. Originally, this wrapper code was fairly standard Python which resembled a user-written PyTorch program. Over time, various wrapper code generators have been added to accommodate things like AOTInductor, which prefers C++ code for static compilation. This complexity has bled into other parts of the codebase, as we now need if/else statements to choose between Python and C++ macros. (See an example [here](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py#L5515-L5522).) Since most of these code generation steps are conceptually identical across target languages, it seems reasonable to refactor them into some kind of intermediate representation which can be shared between the various backends. This might also make it easier to develop out-of-tree backends which cannot put their own macros in core Inductor components.
This PR takes some initial steps to formalize Inductor's wrapper codegen by generalizing the existing Memory Planning IR into a fully fledged Wrapper IR. This is pretty much identical to the existing Memory Planning IR, but it supports a richer set of ops for things like kernel definitions and calls. This refactor could help encapsulate wrapper codegen. Ideally, we don't need to worry about direct Python/C++ codegen in the main compiler files such as `ir.py`, and can instead defer these to classes like `PythonWrapperCodegen` and `CppWrapperCpu`, which operate on the Wrapper IR.
## Goal 2: Convert Wrapper IR into FX IR.
One of the main benefits of Wrapper IR is to enable more diverse Inductor backends. This PR introduces a converter from Wrapper IR into [FX IR](https://pytorch.org/docs/stable/fx.html), which is the intermediate representation most commonly used in PyTorch graph compilers. The purpose of this is to enable out-of-tree backends to consume Inductor's output in FX IR, which would hopefully make Inductor easier to leverage in novel compilers, hardware accelerators, etc.
It's not trivial to generate Python or C++ code which Inductor can compile and run, and doing so may require changes to other core Inductor files, for the reasons outlined in the previous section. The goal of supporting FX output is to enable something like `torch.compile`'s [custom backend](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html) system, in which an out-of-tree backend can receive an optimized FX graph from Inductor, and compile and run it however it likes.
The typical users of this feature would likely not be part of PyTorch, and may or may not support running a kernel in eager mode. However, they can understand what `torch.empty_strided` means, compile and run Triton kernels, etc. So we just need to present them with an FX graph saying what code Inductor wants to run, which should be easier to analyze and transform in a third party system than Python or C++ source.
Since FX IR is fairly stable, this mechanism should hopefully isolate third-party backends, hardware accelerators, etc. from the implementation details of Inductor, and vice versa.
# Current status
Things that seem to work:
- Converted a lot of the most common Python codegen lines to Wrapper IR lines.
- Handled the following cases, in addition to what was already in the Memory Planning IR:
- Comments
- Triton kernels
- Extern/fallback kernels
- Freeing tensors (`del buf0`)
- MultiOutput
- Graph outputs
- ReinterpretView / StorageBox, for both call args and outputs.
- FX conversion asserts that the program only contains Wrapper IR lines, and not strings of Python/C++ code.
- Prototype FX converter which can handle some of the most common use cases.
- Defining Triton kernels, and putting them in a side table using TorchDynamo's existing [utilities](https://dev-discuss.pytorch.org/t/higher-order-operators-2023-10/1565).
- Calling wrapped Triton kernels.
- Calling extern kernels and certain types of fallback kernels.
- Support both `extern_kernels.*` and `aten.*`.
- Support multi-output kernels like `torch.topk`.
- Graphs with multiple inputs/outputs.
- Training i.e. calling `Tensor.backward()` in a compiled function.
- Graph breaks (training).
- Run the `torch.fx.GraphModule` on GPU using the standard `__call__` method. This makes it easy to test the correctness of FX codegen.
Things that don't work:
- Both Wrapper IR and Wrapper -> FX coverage are currently best effort. There are still features which aren't captured as Wrapper IR lines, and fall back to plain strings. This representation is functionally correct but probably not rich enough to achieve the goals outlined in the previous sections.
- Fallback kernels seem like the most difficult thing to fully cover, since they each define their own Python/C++ macros that would need to be converted to FX.
- Size/alignment asserts are currently disabled via the config file. It's possible to generate FX IR for these, but it seems reasonable to defer these sanity checks to a later PR.
- CommBuffer's and distributed communication are not yet supported. An earlier version of this PR attempted to implement this by calling `empty_strided_p2p`. However, building and testing distributed support seems non-trivial, so it's probably better to defer this.
# Out-of-tree compilers
With this PR, out of tree backends will be able to do further compilation on the FX graphs by subclassing `WrapperFxCodegen` and overriding the `compile_graph` function. This follows the same API as torch.compile's [custom backends](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html), where the user simply returns a callable running the graph. The callable need not be a method of `GraphModule` or any other PyTorch class. See an example below.
```
from torch._inductor.codegen.wrapper_fxir import WrapperFxCodegen
class MyCustomBackend(WrapperFxCodegen):
def compile_graph(self, gm):
# Add 1 to the graph's outputs
def compiled_fn(*args):
return [x + 1 for x in gm.graph.forward(*args)]
return compiled_fn
```
# Example FX graphs
This section contains some example FX graphs generated by Inductor. The correctness of these graphs was verified against eager mode by calling the corresponding `GraphModule`.
Here's an FX graph calling a basic Triton kernel. Notice how outputs are allocated with `torch.empty_strided`, and the Triton kernel is called by reference to Dynamo's triton side table.
```
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((8,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, in_ptr1: %arg0_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
return (buf0,)
```
Here's a more complicated graph that calls a `torch.addmm` extern kernel.
```
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%buf0 : [num_users=3] = call_function[target=torch.empty_strided](args = ((), ()), kwargs = {dtype: torch.float32, device: cuda:0})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(1,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, out_ptr0: %buf0, xnumel: 1, r0_numel: 129, XBLOCK: 1}})
%buf2 : [num_users=2] = call_function[target=torch.empty_strided](args = ((129, 1), (1, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
%addmm : [num_users=0] = call_function[target=torch.addmm](args = (%buf0, %arg0_1, %arg1_1), kwargs = {alpha: 1, beta: 1, out: %buf2})
%delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
return (buf2,)
```
Here's a graph which indexes into a tuple using `operator.getitem`. This is necessary to use the output of the `torch.topk` operation.
```
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%buf0 : [num_users=3] = call_function[target=torch.ops.aten.topk.default](args = (%arg0_1, 2), kwargs = {})
%buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 0), kwargs = {})
%buf2 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 1), kwargs = {})
%delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 2, XBLOCK: 2}})
%triton_kernel_wrapper_mutation_1 : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 1, constant_args_idx: 1, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf2, xnumel: 2, XBLOCK: 2}})
return (buf1, buf2)
```
Here's a graph that reinterprets an output tensor using `torch.as_strided`. This is one way to handle Inductor's `ReinterpretView` op.
```
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((2, 4), (4, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg0_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
%buf0_view_buf0_0 : [num_users=1] = call_function[target=torch.as_strided](args = (%buf0, (8,), (1,), 0), kwargs = {})
return (buf0_view_buf0_0,)
```
Here's a graph with dynamic shapes. This one is a little bit funky. Inductor provides a graph input for each shape symbol, which we map to a placeholder, in this example `s6`. Then, shape expressions in the generated code can refer to the symbol `s6`. The size hint for `s6` is stored in `node.meta["val"]` where `node` is the placeholder defining it. This works out in the generated python code because the placeholder defines a Python variable with the name `s6`.
```
graph():
%s6 : [num_users=0] = placeholder[target=s6]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=1] = placeholder[target=arg2_1]
%buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((s6,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((-s6)//8)), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s6, XBLOCK: 8}})
return buf0
```
Here's another graph, this time with dynamic shapes and strides. The grid expression is more complex since the numel is a product of dimensions.
```
graph():
%s10 : [num_users=0] = placeholder[target=s10]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=1] = placeholder[target=arg2_1]
%buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ([s10, s10], [s10, 1]), kwargs = {dtype: torch.float32, device: cuda:0})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((s10**2)//(-64))), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s10**2, XBLOCK: 64}})
return buf0
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146942
Approved by: https://github.com/jansel
As a result of adding subgraph as a choice to inductor https://github.com/pytorch/pytorch/pull/149761 and enabling FP32 output from PyTorch GEMMs from FP16/BF16 inputs: https://github.com/pytorch/pytorch/pull/150812, this PR enables decompose_k as an autotuning choice for Inductor in generating the fastest matmuls with Triton. DecomposeK is currently only enabled for `torch.compile`.
Followups:
* decompose_k does not currently support epilogue fusion, which will take some work to enable
* Enable autotuning the bmm with Triton Templates as well without requiring tons of more compile time, async compilation. Anecdotal evidence shows that Triton BMM performs better usually than aten BMM
* Add for addmm
* Enable for Inference and AOTI
Below are the results of running TritonBench for Split-K shapes, comparing the aten performance versus pt2_triton, which now autotunes on decompose_k, seeing >10% speedup compared to aten on average, and for some shapes over 3x the performance of the best Triton mm previously:
<img width="929" alt="Screenshot 2025-04-28 at 9 15 39 PM" src="https://github.com/user-attachments/assets/27d85bbc-4f3a-43a6-a8fa-d4a5bbb8c999" />
TorchInductor Benchmark Dashboard:
<img width="1727" alt="Screenshot 2025-04-30 at 2 02 53 PM" src="https://github.com/user-attachments/assets/4acd7ffc-407f-4cfd-98bb-2e3d8b1f00b3" />
We see speedups across all runs for training. Compile time increased as expected, with more `mm` options to tune over.
Differential Revision: [D73820115](https://our.internmc.facebook.com/intern/diff/D73820115)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150654
Approved by: https://github.com/eellison
As a result of adding subgraph as a choice to inductor https://github.com/pytorch/pytorch/pull/149761 and enabling FP32 output from PyTorch GEMMs from FP16/BF16 inputs: https://github.com/pytorch/pytorch/pull/150812, this PR enables decompose_k as an autotuning choice for Inductor in generating the fastest matmuls with Triton. DecomposeK is currently only enabled for `torch.compile`.
Followups:
* decompose_k does not currently support epilogue fusion, which will take some work to enable
* Enable autotuning the bmm with Triton Templates as well without requiring tons of more compile time, async compilation. Anecdotal evidence shows that Triton BMM performs better usually than aten BMM
* Add for addmm
* Enable for Inference and AOTI
Below are the results of running TritonBench for Split-K shapes, comparing the aten performance versus pt2_triton, which now autotunes on decompose_k, seeing >10% speedup compared to aten on average, and for some shapes over 3x the performance of the best Triton mm previously:
<img width="929" alt="Screenshot 2025-04-28 at 9 15 39 PM" src="https://github.com/user-attachments/assets/27d85bbc-4f3a-43a6-a8fa-d4a5bbb8c999" />
TorchInductor Benchmark Dashboard:
<img width="1727" alt="Screenshot 2025-04-30 at 2 02 53 PM" src="https://github.com/user-attachments/assets/4acd7ffc-407f-4cfd-98bb-2e3d8b1f00b3" />
We see speedups across all runs for training. Compile time increased as expected, with more `mm` options to tune over.
Differential Revision: [D73820115](https://our.internmc.facebook.com/intern/diff/D73820115)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150654
Approved by: https://github.com/eellison
Add precedence to the infix printing done by sympy_str.
Without this change sympy_str will print the same string for both `a+b*(c+d)` and `(a+b)*(c+d)`.
While there I also cleaned up the printing for `-a` and `a - b`.
Added some tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151920
Approved by: https://github.com/jansel
This PR fixes two things.
The first problem is that in the vLLM style standalone_compile is
called from within a custom torch.compile backend. If there already is a
FakeTensorMode (which there is), we shouldn't create a new
FakeTensorMode with the same shape_env, instead we should just reuse the
same FakeTensorMode.
The second thing is that compile_fx can mutate the passed in gm, so we
deepcopy (since standalone_compile should be standalone)
Test Plan:
- new test
- updated old tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151502
Approved by: https://github.com/oulgen
ghstack dependencies: #151501, #151551
Summary: In most instances, this action would take ~33% of the total run time, which means that our benchmark would previously differ from the end results by a lot.
Test Plan:
We can compare the benchmark results for
```
CUDA_VISIBLE_DEVICES=4,5 buck run mode/opt -c python.package_style=inplace -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100a //caffe2/torch/fb/model_transform/experimental/benchmark:mts_gpu_benchmark -- --model-snapshot-id=672308665_0 --lower-backend=AOT_INDUCTOR --node-replacement-dict="{'torch.nn.Linear':{'(autotune)': 'fp8_float_model_dynamic_quantization_rowwise'}}" --trace-aot-inductor-module=True --disable-acc-tracer=False --batch-size=1024
```
before and after the diff, and notice that on average, the benchmark results decrease by ~0.1ms per iteration, which is more closely aligned with the lowered modules.
Differential Revision: D72469845
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150696
Approved by: https://github.com/frank-wei
This PR adds standalone_compile API that does precompilation via caching to support vLLM use case in the short term while we work on the longer term precompilation solution.
```
standalone_compile(gm, example_inputs, options) -> CompiledArtifact
CompiledArtifact.save(path, format: binary|unpacked = binary)
CompiledArtifact.load(path, format: binary|unpacked = binary)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150670
Approved by: https://github.com/jamesjwu, https://github.com/zou3519
This PR adds standalone_compile API that does precompilation via caching to support vLLM use case in the short term while we work on the longer term precompilation solution.
```
standalone_compile(gm, example_inputs, options) -> CompiledArtifact
CompiledArtifact.save(path, format: binary|unpacked = binary)
CompiledArtifact.load(path, format: binary|unpacked = binary)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150670
Approved by: https://github.com/jamesjwu, https://github.com/zou3519
This PR adds standalone_compile API that does precompilation via caching to support vLLM use case in the short term while we work on the longer term precompilation solution.
```
standalone_compile(gm, example_inputs, options) -> CompiledArtifact
CompiledArtifact.save(path, format: binary|unpacked = binary)
CompiledArtifact.load(path, format: binary|unpacked = binary)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150670
Approved by: https://github.com/jamesjwu, https://github.com/zou3519