# why
- make it easier to integrate into lookup table later
# what
- current version generates templates on the fly and uses them
to generate a single choice
- lookup table and performance model work best when there is a
stable set of templates (with predictable names) and those
are then parametrized
- this change makes it so that there is a single DecomposeK template
with a stable name, and the k split is the only parametrization we do
# testing
```
python3 -bb -m pytest test/inductor/test_max_autotune.py::TestMaxAutotune::test_max_autotune_decompose_k_dynamic_False_bfloat16_sizes1 -v
```
Differential Revision: [D80670913](https://our.internmc.facebook.com/intern/diff/D80670913)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161026
Approved by: https://github.com/PaulZhang12, https://github.com/jansel
Differential Revision: D76514984
Fix subgraph as a choice for when a symbolic shape is inputted as an expression, i.e. 256 * s0, which typically happens in the backwards pass. The current logic assumes that all symbolic shapes are single inputs, i.e. standalone s0
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156185
Approved by: https://github.com/masnesral
As a result of adding subgraph as a choice to inductor https://github.com/pytorch/pytorch/pull/149761 and enabling FP32 output from PyTorch GEMMs from FP16/BF16 inputs: https://github.com/pytorch/pytorch/pull/150812, this PR enables decompose_k as an autotuning choice for Inductor in generating the fastest matmuls with Triton. DecomposeK is currently only enabled for `torch.compile`.
Followups:
* decompose_k does not currently support epilogue fusion, which will take some work to enable
* Enable autotuning the bmm with Triton Templates as well without requiring tons of more compile time, async compilation. Anecdotal evidence shows that Triton BMM performs better usually than aten BMM
* Add for addmm
* Enable for Inference and AOTI
Below are the results of running TritonBench for Split-K shapes, comparing the aten performance versus pt2_triton, which now autotunes on decompose_k, seeing >10% speedup compared to aten on average, and for some shapes over 3x the performance of the best Triton mm previously:
<img width="929" alt="Screenshot 2025-04-28 at 9 15 39 PM" src="https://github.com/user-attachments/assets/27d85bbc-4f3a-43a6-a8fa-d4a5bbb8c999" />
TorchInductor Benchmark Dashboard:
<img width="1727" alt="Screenshot 2025-04-30 at 2 02 53 PM" src="https://github.com/user-attachments/assets/4acd7ffc-407f-4cfd-98bb-2e3d8b1f00b3" />
We see speedups across all runs for training. Compile time increased as expected, with more `mm` options to tune over.
Differential Revision: [D73820115](https://our.internmc.facebook.com/intern/diff/D73820115)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150654
Approved by: https://github.com/eellison
As a result of adding subgraph as a choice to inductor https://github.com/pytorch/pytorch/pull/149761 and enabling FP32 output from PyTorch GEMMs from FP16/BF16 inputs: https://github.com/pytorch/pytorch/pull/150812, this PR enables decompose_k as an autotuning choice for Inductor in generating the fastest matmuls with Triton. DecomposeK is currently only enabled for `torch.compile`.
Followups:
* decompose_k does not currently support epilogue fusion, which will take some work to enable
* Enable autotuning the bmm with Triton Templates as well without requiring tons of more compile time, async compilation. Anecdotal evidence shows that Triton BMM performs better usually than aten BMM
* Add for addmm
* Enable for Inference and AOTI
Below are the results of running TritonBench for Split-K shapes, comparing the aten performance versus pt2_triton, which now autotunes on decompose_k, seeing >10% speedup compared to aten on average, and for some shapes over 3x the performance of the best Triton mm previously:
<img width="929" alt="Screenshot 2025-04-28 at 9 15 39 PM" src="https://github.com/user-attachments/assets/27d85bbc-4f3a-43a6-a8fa-d4a5bbb8c999" />
TorchInductor Benchmark Dashboard:
<img width="1727" alt="Screenshot 2025-04-30 at 2 02 53 PM" src="https://github.com/user-attachments/assets/4acd7ffc-407f-4cfd-98bb-2e3d8b1f00b3" />
We see speedups across all runs for training. Compile time increased as expected, with more `mm` options to tune over.
Differential Revision: [D73820115](https://our.internmc.facebook.com/intern/diff/D73820115)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150654
Approved by: https://github.com/eellison
Add the option for providing a Subgraph as an autotuning choice in Inductor. This is crucial for implementing the split-k optimization for GEMMs by decomposing a mm -> bmm. https://github.com/pytorch/pytorch/pull/150654 uses these changes to add decomposeK as a default autotuning choice for aten.mm in Inductor.
Using https://github.com/pytorch/pytorch/pull/150654 and a simple script:
```
import torch
def f(a, b):
return torch.matmul(a, b)
def decompose_func(a_in, b_in):
M, K = a_in.shape
K, N = b_in.shape
# TODO: Ideally we want to autotune over this parameter
kPartitions = 256
assert K % kPartitions == 0, "K must be divisible by Kmini"
B = K // kPartitions
a_reshaped = a_in.reshape(M, B, kPartitions).transpose(
0, 1
) # Shape: (B, M, kPartitions)
b_reshaped = b_in.reshape(B, kPartitions, N) # Shape: (B, kPartitions, N)
result = torch.bmm(a_reshaped, b_reshaped) # Shape: (B, M, N)
return result.sum(dim=0).to(torch.float16) # Sum over B dimension, Shape: (M, N)
for k in [4096, 8192, 12288, 16384, 20480, 24576, 28672, 32768]:
a = torch.randn(32, k, dtype=torch.float16, device="cuda", requires_grad=True)
b = torch.randn(k, 32, dtype=torch.float16, device="cuda", requires_grad=True)
compiled_res = torch.compile(f, dynamic=False)(a, b)
decompose_res = decompose_func(a, b)
print(f"Compiled mm result close to aten: {torch.allclose(f(a, b), compiled_res, atol=1e-5, rtol=0.5)}")
print(f"Compiled mm result close to decompose: {torch.allclose(decompose_res, compiled_res, atol=1e-5, rtol=0.5)}")
```
we are able to autotune the decomposeK optimization to aten and the traditional Triton templates in Inductor. DecomposeK is faster than aten by about ~10% on average and > 4x speedup over the best Triton templates on an H100 machine, e.g.:
```
AUTOTUNE mm(32x28672, 28672x32)
decompose_k_mm 0.0126 ms 100.0%
mm 0.0144 ms 87.5%
triton_mm_69 0.0579 ms 21.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
triton_mm_75 0.0677 ms 18.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
triton_mm_76 0.0850 ms 14.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
triton_mm_68 0.1444 ms 8.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
triton_mm_72 0.1546 ms 8.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_74 0.1819 ms 6.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
triton_mm_67 0.1917 ms 6.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
triton_mm_73 0.2766 ms 4.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
```
https://pastebin.com/g3FMaauT is the generated code from Inductor containing the subgraph decomposition for aten.mm.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150653
Approved by: https://github.com/eellison