Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Contiguous subgraph decomposition (#161241)
## Summary

Adds a subgraph decomposition for `addmm` and `mm` that performs well when `K` is large relative to `M` and `N`, and serves as an alternative to `split-k` on AMD (transposed inputs only), since `split-k` is not currently supported on AMD.

## Background

On AMD (MI300x), for a matmul A * B, if B is non-contiguous the resulting matmul is quite a bit slower. For example:

```
args[0]: TensorBox(StorageBox(
  InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
))
args[1]: TensorBox(StorageBox(
  InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[1, 178176]))
))
```

is a lot slower than:

```
args[0]: TensorBox(StorageBox(
  InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
))
args[1]: TensorBox(StorageBox(
  InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[6144, 1]))
))
```

This PR adds a subgraph decomposition to test whether making B contiguous first is faster than just using the normal kernels.
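As a rough standalone illustration of the layout effect (not part of this PR), the sketch below times `mm` with a contiguous versus a transposed second operand, plus the copy-then-mm form that the subgraph choice amounts to. The shapes match the example layouts above; the iteration count and the use of CUDA events for timing are arbitrary choices.

```python
import torch

# Shapes from the layout example above: A is (1024, 178176), B is (178176, 6144).
M, K, N = 1024, 178176, 6144
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b_contig = torch.randn(K, N, device="cuda", dtype=torch.float16)         # stride = [6144, 1]
b_noncontig = torch.randn(N, K, device="cuda", dtype=torch.float16).t()  # stride = [1, 178176]


def time_ms(fn, iters=10):
    # Simple event-based timing; good enough to see the layout gap.
    fn()  # warmup
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


print("contiguous B:     ", time_ms(lambda: torch.mm(a, b_contig)))
print("non-contiguous B: ", time_ms(lambda: torch.mm(a, b_noncontig)))
# What the contiguous subgraph decomposition effectively autotunes against:
print("copy B, then mm:  ", time_ms(lambda: torch.mm(a, b_noncontig.contiguous())))
```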
## Data

I ran this on unique non-contiguous shapes from torchbench/huggingface and got these speedups:

```
Parsed 420 unique shapes from benchmark output

addmm improvements when best:
  addmm_16448x512x2048: +0.14%
  addmm_128x2048x2048: +0.01%
  addmm_128x768x1000: +0.75%
  addmm_12672x3072x768: +1.08%
  addmm_512x768x32000: +0.62%
  addmm_12608x384x384: +0.00%
  addmm_4160x1024x4096: +0.90%
  addmm_16x768x2: +0.56%
  addmm_12608x3072x768: +0.09%
  addmm_64x4096x1000: +2.77%
  addmm_256x1024x512: +1.99%
  addmm_30x256x256: +1.12%
  addmm_100480x128x384: +0.91%
  addmm_6400x2048x512: +0.25%
  addmm_61568x1024x256: +0.08%
  addmm_1x768x768: +0.93%
  addmm_12544x384x384: +0.19%
  addmm_128x512x1000: +0.77%
  addmm_2048x128x128: +1.32%
  addmm_128x3072x1000: +0.24%
  addmm_7936x512x2048: +0.07%
  addmm_8192x512x2048: +0.33%
  addmm_64x1024x1000: +1.43%
  addmm_128x2304x1000: +0.01%
  addmm_32768x256x2: +0.75%
  addmm_64x384x1152: +0.79%
  addmm_64x640x1000: +0.01%
  addmm_100480x128x128: +0.87%
  addmm_1152x3072x768: +1.13%
  addmm_8192x256x2048: +1.40%
  addmm_4096x128x768: +0.01%
  addmm_128x2560x1000: +0.01%
  addmm_12544x2048x512: +0.43%
  addmm_200704x24x96: +0.14%
  addmm_8448x512x2048: +0.96%
  addmm_50176x256x1024: +0.62%
  addmm_4160x4096x1024: +0.22%
  addmm_4096x768x768: +0.32%
  addmm_220x2048x512: +0.56%
  addmm_8x2048x1000: +1.12%
  addmm_256x197951x512: +26.99%
  addmm_401536x64x192: +0.60%
  addmm_2040x2048x512: +0.47%
  addmm_512x1024x256: +1.32%
  addmm_128x4096x1000: +1.67%
  addmm_12672x768x768: +0.34%
  addmm_128x368x1000: +0.77%
  addmm_96x1280x1000: +0.01%
  addmm_12544x512x2048: +0.41%
  addmm_6272x320x1280: +0.76%
  addmm_12544x3072x768: +0.09%
  addmm_64x384x1000: +0.39%

mm improvements when best:
  mm_200704x128x512: +1.29%
  mm_663552x16x16: +0.80%
  mm_4096x768x768: +0.51%
  mm_131072x64x31: +0.24%
  mm_12544x1152x384: +0.11%
  mm_128x2048x2: +0.46%
  mm_262144x16x23: +0.62%
  mm_50176x576x192: +0.37%
  mm_131072x16x31: +0.26%

================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================

Operation: addmm
----------------------------------------
  Total shapes analyzed: 247
  Average Subgraph placement: 3.38
  Median Subgraph placement: 2.0
  Subgraph is best choice: 52/247 shapes (21.1%)
  Average improvement when best: 1.15%
  Median improvement when best: 0.58%
  Largest improvement when best: +26.99%

Operation: bmm
----------------------------------------
  Total shapes analyzed: 85
  Average Subgraph placement: 24.00
  Median Subgraph placement: 21.0
  Subgraph is best choice: 0/85 shapes (0.0%)
  Average improvement when best: N/A (never best)
  Median improvement when best: N/A (never best)
  Largest improvement when best: N/A (never best)

Operation: mm
----------------------------------------
  Total shapes analyzed: 88
  Average Subgraph placement: 15.08
  Median Subgraph placement: 4.0
  Subgraph is best choice: 9/88 shapes (10.2%)
  Average improvement when best: 0.52%
  Median improvement when best: 0.46%
  Largest improvement when best: +1.29%
```

## Results

The largest shape gain, `256,197951,512`, seemed to be driven by a case where the extern kernel is way faster than the best Triton configs in the recursive autotune:

```
addmm,Extern,extern_kernels.addmm,256,197951,512,0.38024500012397766
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.005444049835205
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.04189395904541
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.1911399364471436
addmm,Triton,256,197951,512,64,128,32,2,4,8,2.496040105819702
addmm,Triton,256,197951,512,64,128,64,2,8,16,2.9306790828704834
addmm,Triton,256,197951,512,64,64,32,2,4,8,3.0347819328308105
...
```

Compared to the non-transposed autotune:

```
addmm,Subgraph,contiguous_addmm_1384,256,197951,512,0.5024129748344421
addmm,Extern,extern_kernels.addmm,256,197951,512,0.6881489753723145
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.5115010738372803
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.5167479515075684
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.9507460594177246
addmm,Triton,256,197951,512,64,256,64,2,8,4,2.9673290252685547
addmm,Triton,256,197951,512,64,128,64,2,8,16,3.3906331062316895
addmm,Triton,256,197951,512,64,128,32,2,4,8,3.496859073638916
```

It seems to perform really well when `K` is high relative to `M` and `N`. Testing this hypothesis with some custom shapes:

```
Parsed 64 unique shapes from benchmark output

addmm improvements when best:
  addmm_128x16384x128: +0.18%
  addmm_128x262144x256: +38.24%
  addmm_128x200000x512: +14.76%
  addmm_256x800000x128: +0.06%
  addmm_131072x128x256: +0.27%
  addmm_128x256x131072: +0.25%
  addmm_2048x200000x64: +12.45%

mm improvements when best:
  mm_128x16384x128: +0.18%
  mm_128x262144x256: +38.05%
  mm_128x200000x512: +9.47%
  mm_256x800000x128: +0.99%
  mm_512x6400000x256: +3.17%
  mm_524288x64x64: +0.29%
  mm_2048x200000x64: +11.19%
  mm_8192x1000000x256: +34.14%
  mm_128x4096x100000: +0.40%
  mm_128x3072x150000: +0.27%

================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================

Operation: addmm
----------------------------------------
  Total shapes analyzed: 33
  Average Subgraph placement: 4.39
  Median Subgraph placement: 2.0
  Subgraph is best choice: 7/33 shapes (21.2%)
  Average improvement when best: 9.46%
  Median improvement when best: 0.27%
  Largest improvement when best: +38.24%

Operation: mm
----------------------------------------
  Total shapes analyzed: 30
  Average Subgraph placement: 7.63
  Median Subgraph placement: 2.0
  Subgraph is best choice: 10/30 shapes (33.3%)
  Average improvement when best: 9.81%
  Median improvement when best: 2.08%
  Largest improvement when best: +38.05%
```

## Conclusion

The contiguous subgraph decomposition seems worthwhile for `mm` and `addmm`, but not `bmm`, and it gives a very large improvement on low-`M`, low-`N`, high-`K` shapes.

Data gathering scripts: https://gist.github.com/exclamaforte/4a896c064d301b27bf5ca0a4f8fc3866

## Test Plan

New unit tests.
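A hedged end-to-end usage sketch, mirroring the new unit tests: on a ROCm build with max-autotune enabled, compiling a matmul whose second operand is transposed lets the autotuner consider the `contiguous_mm` subgraph alongside the extern and Triton choices. The shapes below come from the custom-shape run above; nothing here forces the subgraph to win the autotune.

```python
import torch
import torch._inductor.config as inductor_config


def mm_transpose(a, b):
    return a @ b.transpose(0, 1)


# M=128, K=200000, N=512: a K-dominated shape from the data above.
a = torch.randn(128, 200000, device="cuda", dtype=torch.float16)
b = torch.randn(512, 200000, device="cuda", dtype=torch.float16)

with inductor_config.patch(max_autotune=True):
    compiled = torch.compile(mm_transpose)
    out = compiled(a, b)  # autotuning may pick contiguous_mm when K >> M, N
```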
Differential Revision: D80771648

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161241
Approved by: https://github.com/eellison
committed by: PyTorch MergeBot
parent: eb18d32bda
commit: d647185037
@@ -1279,6 +1279,189 @@ class TestMaxAutotune(TestCase):

```python
            code[0]
        )

    @unittest.skipIf(not torch.version.hip, "ROCM only")
    @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float32))
    @parametrize("sizes", ((64, 128, 256), (128, 256, 512), (256, 512, 1024)))
    @config.patch(
        max_autotune=True,
    )
    def test_max_autotune_contiguous_transform_mm(self, sizes, dtype):
        """
        Test the contiguous subgraph transform with A * transpose(B) pattern.
        This transform makes the second matrix contiguous before the matmul.
        """
        M, N, K = sizes

        def mm_transpose(a, b):
            return a @ b.transpose(0, 1)

        a = torch.randn(M, K, dtype=dtype, device=GPU_TYPE, requires_grad=True)
        b = torch.randn(N, K, dtype=dtype, device=GPU_TYPE, requires_grad=True)

        # Compute fp64 baseline
        a_fp64 = a.to(torch.float64)
        b_fp64 = b.to(torch.float64)
        expected_fp64 = mm_transpose(a_fp64, b_fp64)

        # Force only contiguous choice to test the transform
        with (
            mock.patch("torch._inductor.kernel.mm.use_contiguous") as contiguous_mock,
        ):
            contiguous_mock.return_value = True

            compiled_func = torch.compile(mm_transpose)
            out, code = run_and_get_code(compiled_func, a, b)

        # Verify correctness against fp64 baseline
        torch.testing.assert_close(
            out, expected_fp64.to(dtype), atol=1e-2, rtol=1e-2
        )

        # Check that contiguous transform was used
        FileCheck().check("contiguous_mm").run(code[0])

    @unittest.skipIf(not torch.version.hip, "ROCM only")
    @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float32))
    @parametrize("sizes", ((64, 128, 256), (128, 256, 512), (256, 512, 1024)))
    @config.patch(
        max_autotune=True,
    )
    def test_max_autotune_contiguous_transform_addmm(self, sizes, dtype):
        """
        Test the contiguous subgraph transform for addmm with non-contiguous second matrix.
        """
        M, N, K = sizes

        def addmm_transpose(inp, a, b):
            return torch.addmm(inp, a, b.transpose(0, 1))

        inp = torch.randn(M, N, dtype=dtype, device=GPU_TYPE, requires_grad=True)
        a = torch.randn(M, K, dtype=dtype, device=GPU_TYPE, requires_grad=True)
        b = torch.randn(N, K, dtype=dtype, device=GPU_TYPE, requires_grad=True)

        # Compute fp64 baseline
        inp_fp64 = inp.to(torch.float64)
        a_fp64 = a.to(torch.float64)
        b_fp64 = b.to(torch.float64)
        expected_fp64 = addmm_transpose(inp_fp64, a_fp64, b_fp64)

        # Force contiguous choice to test the transform
        with (
            mock.patch("torch._inductor.kernel.mm.use_contiguous") as contiguous_mock,
        ):
            contiguous_mock.return_value = True

            compiled_func = torch.compile(addmm_transpose)
            out, code = run_and_get_code(compiled_func, inp, a, b)

        # Verify correctness against fp64 baseline
        torch.testing.assert_close(
            out, expected_fp64.to(dtype), atol=1e-2, rtol=1e-2
        )

        # Check that contiguous transform was used
        FileCheck().check("contiguous_addmm").run(code[0])

    @unittest.skipIf(not torch.version.hip, "ROCM only")
    @parametrize("dynamic", (False, True))
    def test_max_autotune_contiguous_transform_non_contiguous_second_matrix(
        self, dynamic
    ):
        """
        Test that contiguous transform is only applied when the second matrix is non-contiguous.
        """
        M, N, K = 64, 128, 64

        def mm(a, b):
            return a @ b

        a = torch.randn(M, K, dtype=torch.float32, device=GPU_TYPE)
        b_contiguous = torch.randn(K, N, dtype=torch.float32, device=GPU_TYPE)
        b_non_contiguous = torch.randn(
            N, K, dtype=torch.float32, device=GPU_TYPE
        ).transpose(0, 1)

        # Compute fp64 baselines without max_autotune (since fp64 doesn't work with max_autotune=True)
        a_fp64 = a.to(torch.float64)
        b_contiguous_fp64 = b_contiguous.to(torch.float64)
        b_non_contiguous_fp64 = b_non_contiguous.to(torch.float64)

        expected1_fp64 = mm(a_fp64, b_contiguous_fp64)
        expected2_fp64 = mm(a_fp64, b_non_contiguous_fp64)

        with config.patch(
            max_autotune=True,
        ):
            # Test with contiguous second matrix - should not use contiguous transform
            compiled_func_contiguous = torch.compile(mm, dynamic=dynamic)
            out1, code1 = run_and_get_code(compiled_func_contiguous, a, b_contiguous)

            # Should not contain contiguous transform
            try:
                FileCheck().check("contiguous_mm").run(code1[0])
                self.fail(
                    "Contiguous transform should not be used for contiguous matrices"
                )
            except RuntimeError:
                pass  # Expected - contiguous transform should not be used

            # Test with non-contiguous second matrix - should use contiguous transform
            with (
                mock.patch(
                    "torch._inductor.kernel.mm.use_contiguous"
                ) as contiguous_mock,
            ):
                contiguous_mock.return_value = True

                compiled_func_non_contiguous = torch.compile(mm, dynamic=dynamic)
                out2, code2 = run_and_get_code(
                    compiled_func_non_contiguous, a, b_non_contiguous
                )

                # Should contain contiguous transform
                FileCheck().check("contiguous_mm").run(code2[0])

        # Verify correctness against fp64 baselines
        torch.testing.assert_close(
            out1, expected1_fp64.to(torch.float32), atol=1e-2, rtol=1e-2
        )
        torch.testing.assert_close(
            out2, expected2_fp64.to(torch.float32), atol=1e-2, rtol=1e-2
        )

    @unittest.skipIf(not torch.version.hip, "ROCM only")
    @config.patch(
        max_autotune=True,
        max_autotune_gemm_backends="TRITON",
    )
    def test_max_autotune_contiguous_transform_with_epilogue(self):
        """
        Test contiguous transform with epilogue operations like relu.
        """
        M, N, K = 128, 256, 512

        def mm_transpose_relu(a, b):
            return (a @ b.transpose(0, 1)).relu()

        a = torch.randn(M, K, dtype=torch.float32, device=GPU_TYPE)
        b = torch.randn(N, K, dtype=torch.float32, device=GPU_TYPE)

        # Force contiguous transform
        with (
            mock.patch("torch._inductor.kernel.mm.use_contiguous") as contiguous_mock,
        ):
            contiguous_mock.return_value = True

            compiled_func = torch.compile(mm_transpose_relu)
            out, code = run_and_get_code(compiled_func, a, b)

        # Verify correctness
        expected = mm_transpose_relu(a, b)
        torch.testing.assert_close(out, expected, atol=1e-2, rtol=1e-2)

        # Check that contiguous transform was used
        FileCheck().check("contiguous_mm").run(code[0])

    def test_triton_template_generated_code_cache_key(self):
        generate_and_load_args = len(
            inspect.signature(
```
@@ -243,6 +243,9 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):

```python

@L.register_lowering(aten.baddbmm)
def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
    """
    Lowering for autotuning aten.mm with different backends (Aten, Triton, CUTLASS, etc.)
    """
    # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
    m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout)
```
@@ -42,6 +42,7 @@ from ..utils import (

```python
    use_aten_gemm_kernels,
    use_ck_gemm_template,
    use_ck_tile_gemm_template,
    use_contiguous,
    use_cpp_gemm_template,
    use_cutlass_template,
    use_decompose_k_choice,
```
@@ -675,6 +676,56 @@ class DecomposeKSugraphTemplate(SubgraphTemplate):

```python
decompose_k_subgraph_template = DecomposeKSugraphTemplate()


class ContiguousTemplate(SubgraphTemplate):
    def __init__(self, name: str, description: str, fn: Any):
        self.name = name
        self.description = description
        self.fn = fn
        super().__init__(
            name=name,
        )

    def generate(  # type: ignore[override]
        self,
        input_nodes: list[Buffer],
        layout: Layout,
    ) -> SubgraphChoiceCaller:
        from torch._dispatch.python import enable_python_dispatcher

        from ..decomposition import select_decomp_table

        with enable_python_dispatcher():
            decompositions = select_decomp_table()
            fn = make_fx(
                self.fn,
                decompositions,
            )

        return super().generate(
            name=self.name,
            input_nodes=input_nodes,
            layout=layout,
            make_fx_graph=fn,
            description=self.description,
        )


def contiguous_mm(a, b):
    return torch.mm(a, b.contiguous())


def contiguous_addmm(inp, a, b):
    return torch.addmm(inp, a, b.contiguous())


mm_contiguous_subgraph_template = ContiguousTemplate(
    "contiguous_mm", "contiguous mm", contiguous_mm
)
addmm_contiguous_subgraph_template = ContiguousTemplate(
    "contiguous_addmm", "contiguous addmm", contiguous_addmm
)


@register_lowering(aten.mm, type_promotion_kind=None)
def tuned_mm(mat1, mat2, *, layout=None):
    """
```
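For readers unfamiliar with the subgraph mechanism, here is a small, hedged illustration of what `make_fx` does with the `contiguous_mm` decomposition. This is not part of the diff: the example shapes are arbitrary and no Inductor decomposition table is passed.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def contiguous_mm(a, b):
    return torch.mm(a, b.contiguous())


a = torch.randn(8, 32)
b = torch.randn(16, 32).t()  # shape (32, 16), non-contiguous

# Trace the decomposition into an FX graph; during autotuning this graph is
# benchmarked as one more choice alongside the extern and Triton candidates.
gm = make_fx(contiguous_mm)(a, b)
print(gm.graph)  # shows a contiguous/clone node followed by mm
```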
@@ -746,6 +797,12 @@ def tuned_mm(mat1, mat2, *, layout=None):

```python
                **kwargs,
                **extra_kwargs,
            )
    if not mat2.get_layout().is_contiguous() and use_contiguous(m, n, k):
        mm_contiguous_subgraph_template.maybe_append_choice(
            choices,
            input_nodes=(mat1, mat2),
            layout=layout,
        )

    if (
        is_nonzero
```
@@ -891,6 +948,9 @@ def tuned_int_mm(mat1, mat2, *, layout=None):

```python

@register_lowering(aten.addmm, type_promotion_kind=None)
def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
    """
    Lowering for autotuning aten.addmm with different backends (Aten, Triton, CUTLASS, etc.)
    """
    # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
    m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
    static_shape, is_nonzero = _is_static_problem(layout)
```
@@ -1005,6 +1065,13 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):

```python
                **extra_kwargs,
            )

    if not mat2.get_layout().is_contiguous() and use_contiguous(m, n, k):
        addmm_contiguous_subgraph_template.maybe_append_choice(
            choices,
            input_nodes=(inp_expanded, mat1, mat2),
            layout=layout,
        )

    if (
        is_nonzero
        and use_cutlass_template(layout, m, n, k)
```
torch/_inductor/template_heuristics/contiguous_mm.py (new file, 56 lines)

@@ -0,0 +1,56 @@
```python
from __future__ import annotations

from typing import Any, TYPE_CHECKING

import torch

from ..ir import get_free_symbols
from ..kernel_inputs import KernelInputs, MMKernelInputs
from .base import TemplateConfigHeuristics
from .registry import register_template_heuristic


if TYPE_CHECKING:
    from collections.abc import Generator

    from ..ir import Layout


@register_template_heuristic("contiguous_mm", None, op_name="mm")
@register_template_heuristic("contiguous_addmm", None, op_name="addmm")
class EmptyContiguousMMConfigHeuristics(TemplateConfigHeuristics):
    """empty heuristics to skip contiguous mm on not hip"""


@register_template_heuristic(
    "contiguous_mm", "hip", register=torch.version.hip is not None, op_name="mm"
)
@register_template_heuristic(
    "contiguous_addmm", "hip", register=torch.version.hip is not None, op_name="addmm"
)
class ContiguousMMHeuristics(TemplateConfigHeuristics):
    def get_template_configs(
        self,
        kernel_inputs: KernelInputs,
        layout: Layout,
        op_name: str,
    ) -> Generator[dict[str, Any], None, None]:
        """
        Get all the valid k_splits for the given m, n, k.
        """
        assert isinstance(kernel_inputs, MMKernelInputs), (
            f"{self.__class__.__name__} requires MMKernelInputs"
        )

        # Check for unbacked symbols - if found, yield nothing
        unbacked_symbols = any(
            len(get_free_symbols(itr, unbacked_only=True)) > 0
            for itr in (
                *kernel_inputs.shapes_symbolic(),
                *kernel_inputs.strides_symbolic(),
            )
        )
        if unbacked_symbols:
            return

        yield {}
```
@@ -1813,6 +1813,30 @@ def use_decompose_k_choice(m: _IntLike, n: _IntLike, k: _IntLike) -> bool:

```python
    )


@functools.cache
def use_contiguous(m: _IntLike, n: _IntLike, k: _IntLike) -> bool:
    """
    Check if we should use the contiguous subgraph transform.
    This transform makes the second matrix contiguous before the matmul.
    """
    decompose_k_threshold = config.triton.decompose_k_threshold

    # Similar conditions to decompose_k but for contiguous transform
    from torch._inductor.virtualized import V

    return (
        bool(torch.version.hip)  # Only relevant on AMD
        and V.graph.sizevars.statically_known_true(
            sympy.And(
                sympy.Ge(k, decompose_k_threshold * m),
                sympy.Ge(k, decompose_k_threshold * n),
            )
        )
        and not V.graph.aot_mode
        and not V.graph.cpp_wrapper
    )


@functools.cache
def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]:
    # To limit compile time
```
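To make the shape gate concrete, here is a hedged back-of-the-envelope check for the standout `addmm` shape from the data above. The threshold value of 32 is the assumed default for `config.triton.decompose_k_threshold` and may differ between releases, and the M x K x N reading of the benchmark naming is inferred from the discussion above.

```python
# Hand-evaluating the use_contiguous() shape condition for 256 x 197951 x 512
# (read as M x K x N per the benchmark naming above).
m, n, k = 256, 512, 197951
decompose_k_threshold = 32  # assumed default; see torch/_inductor/config.py

passes = k >= decompose_k_threshold * m and k >= decompose_k_threshold * n
print(passes)  # True: 197951 >= 8192 and 197951 >= 16384, so the subgraph choice is considered
```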