Contiguous subgraph decomposition (#161241)

## Summary

Adds a subgraph decomposition for `addmm` and `mm` that performs well when `K` is large relative to `M` and `N`, and that serves as an alternative to `split-k` (which currently does not support AMD) for the transposed, non-contiguous case on AMD.

## Background

On AMD (MI300x), a matmul `A @ B` is significantly slower when `B` is non-contiguous.
For example:
```
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[1, 178176]))
  ))
```
is a lot slower than:
```
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[6144, 1]))
  ))
```
This PR adds a subgraph decomposition so that autotuning can test whether copying `B` into contiguous memory before the matmul is faster than running the normal kernels directly on the non-contiguous input.
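
Conceptually, the new choice is just the eager-mode rewrite below, and autotuning decides whether the extra copy pays for itself. A minimal sketch (the shapes here are smaller stand-ins for the example above; the actual templates are the `contiguous_mm`/`contiguous_addmm` functions in the diff further down):
```python
import torch

# Hypothetical stand-in shapes; the motivating case above is
# M=1024, K=178176, N=6144 in fp16 on MI300x.
M, K, N = 256, 8192, 512
a = torch.randn(M, K, dtype=torch.float16, device="cuda")
b_t = torch.randn(N, K, dtype=torch.float16, device="cuda").transpose(0, 1)  # (K, N), non-contiguous

# Direct matmul against the non-contiguous operand (the slow case on MI300x).
out_direct = torch.mm(a, b_t)

# The decomposition: materialize a contiguous copy of B, then run the GEMM.
# Autotuning weighs the copy cost against the faster contiguous GEMM.
out_contig = torch.mm(a, b_t.contiguous())

torch.testing.assert_close(out_direct, out_contig, atol=1e-2, rtol=1e-2)
```
The `ContiguousTemplate` subgraph in the diff below registers exactly this rewrite (`torch.mm(a, b.contiguous())` / `torch.addmm(inp, a, b.contiguous())`) as an additional autotune choice.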

## Data

I ran this on unique non-contiguous shapes from torchbench/huggingface and got these speedups:
```
Parsed 420 unique shapes from benchmark output
addmm improvements when best:
  addmm_16448x512x2048: +0.14%
  addmm_128x2048x2048: +0.01%
  addmm_128x768x1000: +0.75%
  addmm_12672x3072x768: +1.08%
  addmm_512x768x32000: +0.62%
  addmm_12608x384x384: +0.00%
  addmm_4160x1024x4096: +0.90%
  addmm_16x768x2: +0.56%
  addmm_12608x3072x768: +0.09%
  addmm_64x4096x1000: +2.77%
  addmm_256x1024x512: +1.99%
  addmm_30x256x256: +1.12%
  addmm_100480x128x384: +0.91%
  addmm_6400x2048x512: +0.25%
  addmm_61568x1024x256: +0.08%
  addmm_1x768x768: +0.93%
  addmm_12544x384x384: +0.19%
  addmm_128x512x1000: +0.77%
  addmm_2048x128x128: +1.32%
  addmm_128x3072x1000: +0.24%
  addmm_7936x512x2048: +0.07%
  addmm_8192x512x2048: +0.33%
  addmm_64x1024x1000: +1.43%
  addmm_128x2304x1000: +0.01%
  addmm_32768x256x2: +0.75%
  addmm_64x384x1152: +0.79%
  addmm_64x640x1000: +0.01%
  addmm_100480x128x128: +0.87%
  addmm_1152x3072x768: +1.13%
  addmm_8192x256x2048: +1.40%
  addmm_4096x128x768: +0.01%
  addmm_128x2560x1000: +0.01%
  addmm_12544x2048x512: +0.43%
  addmm_200704x24x96: +0.14%
  addmm_8448x512x2048: +0.96%
  addmm_50176x256x1024: +0.62%
  addmm_4160x4096x1024: +0.22%
  addmm_4096x768x768: +0.32%
  addmm_220x2048x512: +0.56%
  addmm_8x2048x1000: +1.12%
  addmm_256x197951x512: +26.99%
  addmm_401536x64x192: +0.60%
  addmm_2040x2048x512: +0.47%
  addmm_512x1024x256: +1.32%
  addmm_128x4096x1000: +1.67%
  addmm_12672x768x768: +0.34%
  addmm_128x368x1000: +0.77%
  addmm_96x1280x1000: +0.01%
  addmm_12544x512x2048: +0.41%
  addmm_6272x320x1280: +0.76%
  addmm_12544x3072x768: +0.09%
  addmm_64x384x1000: +0.39%
mm improvements when best:
  mm_200704x128x512: +1.29%
  mm_663552x16x16: +0.80%
  mm_4096x768x768: +0.51%
  mm_131072x64x31: +0.24%
  mm_12544x1152x384: +0.11%
  mm_128x2048x2: +0.46%
  mm_262144x16x23: +0.62%
  mm_50176x576x192: +0.37%
  mm_131072x16x31: +0.26%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================

Operation: addmm
----------------------------------------
Total shapes analyzed: 247
Average Subgraph placement: 3.38
Median Subgraph placement: 2.0
Subgraph is best choice: 52/247 shapes (21.1%)
Average improvement when best: 1.15%
Median improvement when best: 0.58%
Largest improvement when best: +26.99%

Operation: bmm
----------------------------------------
Total shapes analyzed: 85
Average Subgraph placement: 24.00
Median Subgraph placement: 21.0
Subgraph is best choice: 0/85 shapes (0.0%)
Average improvement when best: N/A (never best)
Median improvement when best: N/A (never best)
Largest improvement when best: N/A (never best)

Operation: mm
----------------------------------------
Total shapes analyzed: 88
Average Subgraph placement: 15.08
Median Subgraph placement: 4.0
Subgraph is best choice: 9/88 shapes (10.2%)
Average improvement when best: 0.52%
Median improvement when best: 0.46%
Largest improvement when best: +1.29%

```

## Results

The largest gain, for shape `256,197951,512`, appears to be driven by the extern kernel being much faster than the best Triton configs in the recursive autotune (i.e., after `B` has been made contiguous):
```
addmm,Extern,extern_kernels.addmm,256,197951,512,0.38024500012397766
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.005444049835205
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.04189395904541
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.1911399364471436
addmm,Triton,256,197951,512,64,128,32,2,4,8,2.496040105819702
addmm,Triton,256,197951,512,64,128,64,2,8,16,2.9306790828704834
addmm,Triton,256,197951,512,64,64,32,2,4,8,3.0347819328308105
...
```
compared to the top-level autotune on the original (transposed, non-contiguous) input:
```
addmm,Subgraph,contiguous_addmm_1384,256,197951,512,0.5024129748344421
addmm,Extern,extern_kernels.addmm,256,197951,512,0.6881489753723145
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.5115010738372803
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.5167479515075684
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.9507460594177246
addmm,Triton,256,197951,512,64,256,64,2,8,4,2.9673290252685547
addmm,Triton,256,197951,512,64,128,64,2,8,16,3.3906331062316895
addmm,Triton,256,197951,512,64,128,32,2,4,8,3.496859073638916
```

It seems to perform especially well when `K` is large relative to `M` and `N`.
Testing this hypothesis with some custom shapes (a small reproduction sketch follows the results below):
```
Parsed 64 unique shapes from benchmark output
addmm improvements when best:
  addmm_128x16384x128: +0.18%
  addmm_128x262144x256: +38.24%
  addmm_128x200000x512: +14.76%
  addmm_256x800000x128: +0.06%
  addmm_131072x128x256: +0.27%
  addmm_128x256x131072: +0.25%
  addmm_2048x200000x64: +12.45%
mm improvements when best:
  mm_128x16384x128: +0.18%
  mm_128x262144x256: +38.05%
  mm_128x200000x512: +9.47%
  mm_256x800000x128: +0.99%
  mm_512x6400000x256: +3.17%
  mm_524288x64x64: +0.29%
  mm_2048x200000x64: +11.19%
  mm_8192x1000000x256: +34.14%
  mm_128x4096x100000: +0.40%
  mm_128x3072x150000: +0.27%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================

Operation: addmm
----------------------------------------
Total shapes analyzed: 33
Average Subgraph placement: 4.39
Median Subgraph placement: 2.0
Subgraph is best choice: 7/33 shapes (21.2%)
Average improvement when best: 9.46%
Median improvement when best: 0.27%
Largest improvement when best: +38.24%

Operation: mm
----------------------------------------
Total shapes analyzed: 30
Average Subgraph placement: 7.63
Median Subgraph placement: 2.0
Subgraph is best choice: 10/30 shapes (33.3%)
Average improvement when best: 9.81%
Median improvement when best: 2.08%
Largest improvement when best: +38.05%

```
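For reference, a rough way to spot-check one of these high-`K` shapes end to end (this is not the methodology of the gist linked below; the shape, iteration count, and timing loop are illustrative):
```python
import torch

def mm_transpose(a, b):
    # The second operand is passed as (N, K) and transposed, so it is non-contiguous.
    return a @ b.transpose(0, 1)

M, K, N = 128, 200000, 512  # one of the high-K shapes from the table above
a = torch.randn(M, K, dtype=torch.float16, device="cuda")
b = torch.randn(N, K, dtype=torch.float16, device="cuda")

compiled = torch.compile(mm_transpose, mode="max-autotune")
compiled(a, b)  # warm-up; triggers autotuning over the available choices

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(20):
    compiled(a, b)
end.record()
torch.cuda.synchronize()
print(f"avg {start.elapsed_time(end) / 20:.3f} ms per call")
```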
## Conclusion
The contiguous subgraph decomposition seems worthwhile for `mm` and `addmm`, but not for `bmm`, and yields very large improvements on shapes with low `M`, low `N`, and high `K`.

Data gathering scripts:
https://gist.github.com/exclamaforte/4a896c064d301b27bf5ca0a4f8fc3866

## Test Plan:
New unit tests.

Differential Revision: D80771648

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161241
Approved by: https://github.com/eellison
Author: Gabriel Ferns
Committed by: PyTorch MergeBot, 2025-09-03 17:02:59 +00:00
Commit: d647185037 (parent: eb18d32bda)
5 changed files with 333 additions and 0 deletions

@@ -1279,6 +1279,189 @@ class TestMaxAutotune(TestCase):
```python
            code[0]
        )

    @unittest.skipIf(not torch.version.hip, "ROCM only")
    @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float32))
    @parametrize("sizes", ((64, 128, 256), (128, 256, 512), (256, 512, 1024)))
    @config.patch(
        max_autotune=True,
    )
    def test_max_autotune_contiguous_transform_mm(self, sizes, dtype):
        """
        Test the contiguous subgraph transform with A * transpose(B) pattern.
        This transform makes the second matrix contiguous before the matmul.
        """
        M, N, K = sizes

        def mm_transpose(a, b):
            return a @ b.transpose(0, 1)

        a = torch.randn(M, K, dtype=dtype, device=GPU_TYPE, requires_grad=True)
        b = torch.randn(N, K, dtype=dtype, device=GPU_TYPE, requires_grad=True)

        # Compute fp64 baseline
        a_fp64 = a.to(torch.float64)
        b_fp64 = b.to(torch.float64)
        expected_fp64 = mm_transpose(a_fp64, b_fp64)

        # Force only contiguous choice to test the transform
        with (
            mock.patch("torch._inductor.kernel.mm.use_contiguous") as contiguous_mock,
        ):
            contiguous_mock.return_value = True
            compiled_func = torch.compile(mm_transpose)
            out, code = run_and_get_code(compiled_func, a, b)

        # Verify correctness against fp64 baseline
        torch.testing.assert_close(
            out, expected_fp64.to(dtype), atol=1e-2, rtol=1e-2
        )

        # Check that contiguous transform was used
        FileCheck().check("contiguous_mm").run(code[0])

    @unittest.skipIf(not torch.version.hip, "ROCM only")
    @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float32))
    @parametrize("sizes", ((64, 128, 256), (128, 256, 512), (256, 512, 1024)))
    @config.patch(
        max_autotune=True,
    )
    def test_max_autotune_contiguous_transform_addmm(self, sizes, dtype):
        """
        Test the contiguous subgraph transform for addmm with non-contiguous second matrix.
        """
        M, N, K = sizes

        def addmm_transpose(inp, a, b):
            return torch.addmm(inp, a, b.transpose(0, 1))

        inp = torch.randn(M, N, dtype=dtype, device=GPU_TYPE, requires_grad=True)
        a = torch.randn(M, K, dtype=dtype, device=GPU_TYPE, requires_grad=True)
        b = torch.randn(N, K, dtype=dtype, device=GPU_TYPE, requires_grad=True)

        # Compute fp64 baseline
        inp_fp64 = inp.to(torch.float64)
        a_fp64 = a.to(torch.float64)
        b_fp64 = b.to(torch.float64)
        expected_fp64 = addmm_transpose(inp_fp64, a_fp64, b_fp64)

        # Force contiguous choice to test the transform
        with (
            mock.patch("torch._inductor.kernel.mm.use_contiguous") as contiguous_mock,
        ):
            contiguous_mock.return_value = True
            compiled_func = torch.compile(addmm_transpose)
            out, code = run_and_get_code(compiled_func, inp, a, b)

        # Verify correctness against fp64 baseline
        torch.testing.assert_close(
            out, expected_fp64.to(dtype), atol=1e-2, rtol=1e-2
        )

        # Check that contiguous transform was used
        FileCheck().check("contiguous_addmm").run(code[0])

    @unittest.skipIf(not torch.version.hip, "ROCM only")
    @parametrize("dynamic", (False, True))
    def test_max_autotune_contiguous_transform_non_contiguous_second_matrix(
        self, dynamic
    ):
        """
        Test that contiguous transform is only applied when the second matrix is non-contiguous.
        """
        M, N, K = 64, 128, 64

        def mm(a, b):
            return a @ b

        a = torch.randn(M, K, dtype=torch.float32, device=GPU_TYPE)
        b_contiguous = torch.randn(K, N, dtype=torch.float32, device=GPU_TYPE)
        b_non_contiguous = torch.randn(
            N, K, dtype=torch.float32, device=GPU_TYPE
        ).transpose(0, 1)

        # Compute fp64 baselines without max_autotune (since fp64 doesn't work with max_autotune=True)
        a_fp64 = a.to(torch.float64)
        b_contiguous_fp64 = b_contiguous.to(torch.float64)
        b_non_contiguous_fp64 = b_non_contiguous.to(torch.float64)
        expected1_fp64 = mm(a_fp64, b_contiguous_fp64)
        expected2_fp64 = mm(a_fp64, b_non_contiguous_fp64)

        with config.patch(
            max_autotune=True,
        ):
            # Test with contiguous second matrix - should not use contiguous transform
            compiled_func_contiguous = torch.compile(mm, dynamic=dynamic)
            out1, code1 = run_and_get_code(compiled_func_contiguous, a, b_contiguous)

            # Should not contain contiguous transform
            try:
                FileCheck().check("contiguous_mm").run(code1[0])
                self.fail(
                    "Contiguous transform should not be used for contiguous matrices"
                )
            except RuntimeError:
                pass  # Expected - contiguous transform should not be used

            # Test with non-contiguous second matrix - should use contiguous transform
            with (
                mock.patch(
                    "torch._inductor.kernel.mm.use_contiguous"
                ) as contiguous_mock,
            ):
                contiguous_mock.return_value = True
                compiled_func_non_contiguous = torch.compile(mm, dynamic=dynamic)
                out2, code2 = run_and_get_code(
                    compiled_func_non_contiguous, a, b_non_contiguous
                )

                # Should contain contiguous transform
                FileCheck().check("contiguous_mm").run(code2[0])

        # Verify correctness against fp64 baselines
        torch.testing.assert_close(
            out1, expected1_fp64.to(torch.float32), atol=1e-2, rtol=1e-2
        )
        torch.testing.assert_close(
            out2, expected2_fp64.to(torch.float32), atol=1e-2, rtol=1e-2
        )

    @unittest.skipIf(not torch.version.hip, "ROCM only")
    @config.patch(
        max_autotune=True,
        max_autotune_gemm_backends="TRITON",
    )
    def test_max_autotune_contiguous_transform_with_epilogue(self):
        """
        Test contiguous transform with epilogue operations like relu.
        """
        M, N, K = 128, 256, 512

        def mm_transpose_relu(a, b):
            return (a @ b.transpose(0, 1)).relu()

        a = torch.randn(M, K, dtype=torch.float32, device=GPU_TYPE)
        b = torch.randn(N, K, dtype=torch.float32, device=GPU_TYPE)

        # Force contiguous transform
        with (
            mock.patch("torch._inductor.kernel.mm.use_contiguous") as contiguous_mock,
        ):
            contiguous_mock.return_value = True
            compiled_func = torch.compile(mm_transpose_relu)
            out, code = run_and_get_code(compiled_func, a, b)

        # Verify correctness
        expected = mm_transpose_relu(a, b)
        torch.testing.assert_close(out, expected, atol=1e-2, rtol=1e-2)

        # Check that contiguous transform was used
        FileCheck().check("contiguous_mm").run(code[0])

    def test_triton_template_generated_code_cache_key(self):
        generate_and_load_args = len(
            inspect.signature(
```

@@ -243,6 +243,9 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
```python
@L.register_lowering(aten.baddbmm)
def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
    """
    Lowering for autotuning aten.mm with different backends (Aten, Triton, CUTLASS, etc.)
    """
    # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
    m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout)
```

@@ -42,6 +42,7 @@ from ..utils import (
```python
    use_aten_gemm_kernels,
    use_ck_gemm_template,
    use_ck_tile_gemm_template,
    use_contiguous,
    use_cpp_gemm_template,
    use_cutlass_template,
    use_decompose_k_choice,
```

@@ -675,6 +676,56 @@ class DecomposeKSugraphTemplate(SubgraphTemplate):
```python
decompose_k_subgraph_template = DecomposeKSugraphTemplate()


class ContiguousTemplate(SubgraphTemplate):
    def __init__(self, name: str, description: str, fn: Any):
        self.name = name
        self.description = description
        self.fn = fn
        super().__init__(
            name=name,
        )

    def generate(  # type: ignore[override]
        self,
        input_nodes: list[Buffer],
        layout: Layout,
    ) -> SubgraphChoiceCaller:
        from torch._dispatch.python import enable_python_dispatcher

        from ..decomposition import select_decomp_table

        with enable_python_dispatcher():
            decompositions = select_decomp_table()
            fn = make_fx(
                self.fn,
                decompositions,
            )

            return super().generate(
                name=self.name,
                input_nodes=input_nodes,
                layout=layout,
                make_fx_graph=fn,
                description=self.description,
            )


def contiguous_mm(a, b):
    return torch.mm(a, b.contiguous())


def contiguous_addmm(inp, a, b):
    return torch.addmm(inp, a, b.contiguous())


mm_contiguous_subgraph_template = ContiguousTemplate(
    "contiguous_mm", "contiguous mm", contiguous_mm
)
addmm_contiguous_subgraph_template = ContiguousTemplate(
    "contiguous_addmm", "contiguous addmm", contiguous_addmm
)


@register_lowering(aten.mm, type_promotion_kind=None)
def tuned_mm(mat1, mat2, *, layout=None):
    """
```

@@ -746,6 +797,12 @@ def tuned_mm(mat1, mat2, *, layout=None):
```python
                **kwargs,
                **extra_kwargs,
            )

    if not mat2.get_layout().is_contiguous() and use_contiguous(m, n, k):
        mm_contiguous_subgraph_template.maybe_append_choice(
            choices,
            input_nodes=(mat1, mat2),
            layout=layout,
        )

    if (
        is_nonzero
```

@@ -891,6 +948,9 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
```python
@register_lowering(aten.addmm, type_promotion_kind=None)
def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
    """
    Lowering for autotuning aten.addmm with different backends (Aten, Triton, CUTLASS, etc.)
    """
    # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
    m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
    static_shape, is_nonzero = _is_static_problem(layout)
```

@@ -1005,6 +1065,13 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
```python
                **extra_kwargs,
            )

    if not mat2.get_layout().is_contiguous() and use_contiguous(m, n, k):
        addmm_contiguous_subgraph_template.maybe_append_choice(
            choices,
            input_nodes=(inp_expanded, mat1, mat2),
            layout=layout,
        )

    if (
        is_nonzero
        and use_cutlass_template(layout, m, n, k)
```

@@ -0,0 +1,56 @@
```python
from __future__ import annotations

from typing import Any, TYPE_CHECKING

import torch

from ..ir import get_free_symbols
from ..kernel_inputs import KernelInputs, MMKernelInputs
from .base import TemplateConfigHeuristics
from .registry import register_template_heuristic

if TYPE_CHECKING:
    from collections.abc import Generator

    from ..ir import Layout


@register_template_heuristic("contiguous_mm", None, op_name="mm")
@register_template_heuristic("contiguous_addmm", None, op_name="addmm")
class EmptyContiguousMMConfigHeuristics(TemplateConfigHeuristics):
    """empty heuristics to skip contiguous mm on not hip"""


@register_template_heuristic(
    "contiguous_mm", "hip", register=torch.version.hip is not None, op_name="mm"
)
@register_template_heuristic(
    "contiguous_addmm", "hip", register=torch.version.hip is not None, op_name="addmm"
)
class ContiguousMMHeuristics(TemplateConfigHeuristics):
    def get_template_configs(
        self,
        kernel_inputs: KernelInputs,
        layout: Layout,
        op_name: str,
    ) -> Generator[dict[str, Any], None, None]:
        """
        Get all the valid k_splits for the given m, n, k.
        """
        assert isinstance(kernel_inputs, MMKernelInputs), (
            f"{self.__class__.__name__} requires MMKernelInputs"
        )
        # Check for unbacked symbols - if found, yield nothing
        unbacked_symbols = any(
            len(get_free_symbols(itr, unbacked_only=True)) > 0
            for itr in (
                *kernel_inputs.shapes_symbolic(),
                *kernel_inputs.strides_symbolic(),
            )
        )
        if unbacked_symbols:
            return

        yield {}
```

@@ -1813,6 +1813,30 @@ def use_decompose_k_choice(m: _IntLike, n: _IntLike, k: _IntLike) -> bool:
```python
    )


@functools.cache
def use_contiguous(m: _IntLike, n: _IntLike, k: _IntLike) -> bool:
    """
    Check if we should use the contiguous subgraph transform.
    This transform makes the second matrix contiguous before the matmul.
    """
    decompose_k_threshold = config.triton.decompose_k_threshold

    # Similar conditions to decompose_k but for contiguous transform
    from torch._inductor.virtualized import V

    return (
        bool(torch.version.hip)  # Only relevant on AMD
        and V.graph.sizevars.statically_known_true(
            sympy.And(
                sympy.Ge(k, decompose_k_threshold * m),
                sympy.Ge(k, decompose_k_threshold * n),
            )
        )
        and not V.graph.aot_mode
        and not V.graph.cpp_wrapper
    )


@functools.cache
def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]:
    # To limit compile time
```