Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-11 22:34:53 +08:00
[Inductor] Add decomposeK as an autotuning choice for mm (#150654)
As a result of adding subgraphs as an autotuning choice in Inductor (https://github.com/pytorch/pytorch/pull/149761) and enabling FP32 output from PyTorch GEMMs with FP16/BF16 inputs (https://github.com/pytorch/pytorch/pull/150812), this PR enables decompose_k as an autotuning choice for Inductor when generating the fastest matmuls with Triton. decompose_k is currently only enabled for `torch.compile`.

Follow-ups:
* decompose_k does not currently support epilogue fusion, which will take some work to enable
* Autotune the bmm with Triton templates as well, without adding a lot of compile time, via async compilation. Anecdotal evidence shows that Triton bmm usually performs better than aten bmm
* Add support for addmm
* Enable for inference and AOTI

Below are the results of running TritonBench on split-K shapes, comparing aten against pt2_triton, which now autotunes over decompose_k. We see a >10% average speedup over aten, and for some shapes over 3x the performance of the previous best Triton mm:

<img width="929" alt="Screenshot 2025-04-28 at 9 15 39 PM" src="https://github.com/user-attachments/assets/27d85bbc-4f3a-43a6-a8fa-d4a5bbb8c999" />

TorchInductor Benchmark Dashboard:

<img width="1727" alt="Screenshot 2025-04-30 at 2 02 53 PM" src="https://github.com/user-attachments/assets/4acd7ffc-407f-4cfd-98bb-2e3d8b1f00b3" />

We see speedups across all training runs. Compile time increased as expected, since there are more `mm` options to tune over.

Differential Revision: [D73820115](https://our.internmc.facebook.com/intern/diff/D73820115)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150654
Approved by: https://github.com/eellison
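For intuition, here is a minimal, standalone sketch of the split-K rewrite that decompose_k autotunes over; it mirrors the `decomposeK` helper added in the diff below. The function name, shapes, and `k_splits` value are illustrative assumptions only, and fp32 accumulation is emulated with explicit casts rather than the `torch.bmm(..., out_dtype=...)` overload this PR stack introduces.

```python
import torch


def decompose_k_reference(a: torch.Tensor, b: torch.Tensor, k_splits: int) -> torch.Tensor:
    # a: (M, K), b: (K, N); K must be divisible by k_splits.
    m, k = a.shape
    n = b.shape[1]
    k_parts = k // k_splits

    # Slice the K dimension into k_splits chunks and batch them:
    # (k_splits, M, k_parts) @ (k_splits, k_parts, N) -> (k_splits, M, N)
    a_batched = a.reshape(m, k_splits, k_parts).permute(1, 0, 2)
    b_batched = b.reshape(k_splits, k_parts, n)

    # Compute the partial products in fp32 and reduce over the split dimension,
    # then cast back to the input dtype (the PR does this via out_dtype=torch.float32).
    partials = torch.bmm(a_batched.float(), b_batched.float())
    return partials.sum(dim=0).to(a.dtype)


# K much larger than M and N is the regime where Inductor considers decompose_k.
m, n, k = 64, 64, 32768
a = torch.randn(m, k)
b = torch.randn(k, n)
torch.testing.assert_close(
    decompose_k_reference(a, b, k_splits=128), a @ b, atol=1e-2, rtol=1e-2
)
```

To exercise the new choice end to end, the test added below compiles a plain matmul under `max_autotune=True` with `max_autotune_gemm_backends="TRITON"` and checks that a `decompose_k` subgraph with a `{split}_split` suffix shows up in the generated code.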
Committed by PyTorch MergeBot
parent 5e9682719f
commit 84aa0985fb
@@ -1194,6 +1194,95 @@ class TestMaxAutotune(TestCase):
        actual = (opt_f(x, y), x.grad, linear.weight.grad, linear.bias.grad)
        assert same(expect, actual, tol=1e-2), f"ref:\n{expect}\nact:\n{actual}"

    @skipIfXpu
    @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm")
    @unittest.skipIf(
        config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet"
    )
    @parametrize("dynamic", (True, False))
    @parametrize("sizes", ((32, 32, 32768), (64, 128, 200000), (64, 64, 177147)))
    @config.patch(
        max_autotune=True,
        max_autotune_gemm_backends="TRITON",
        autotune_fallback_to_aten=False,
    )
    def test_max_autotune_decompose_k(self, sizes, dynamic):
        M, N, K = sizes

        a = torch.randn(M, K, dtype=torch.float16, device="cuda", requires_grad=True)
        b = torch.randn(K, N, dtype=torch.float16, device="cuda", requires_grad=True)

        possible_splits = range(2, min(K // M, K // N) + 1)

        divisors = {split for split in possible_splits if K % split == 0}

        def check_divisors(code):
            for kernel in code:
                if "decompose_k" in kernel:
                    divisor_found = False
                    for divisor in divisors:
                        if f"{divisor}_split" in kernel:
                            divisor_found = True
                            break

                    self.assertTrue(
                        divisor_found,
                        f"Could not find a split in {divisors} in {kernel}",
                    )

        compiled_func = torch.compile(lambda a, b: a @ b, dynamic=dynamic)
        # We assume with the large k dim relative to m, n, decompose_k will be most performant
        out, code = run_and_get_code(compiled_func, a, b)

        if dynamic:
            FileCheck().check_not("extern_kernels.bmm_dtype").check_not(
                "decompose_k"
            ).run(code[0])
        else:
            FileCheck().check("extern_kernels.bmm_dtype").check_regex(
                "triton_.*_fused_0.run"
            ).check("decompose_k").run(code[0])
        check_divisors(code)
        torch.testing.assert_close(out, a @ b, atol=1e-2, rtol=1e-2)

        # Test adding epilogue also equivalent to eager
        compiled_func = torch.compile(lambda a, b: (a @ b).relu(), dynamic=dynamic)
        out, code = run_and_get_code(compiled_func, a, b)
        if dynamic:
            FileCheck().check_not("extern_kernels.bmm_dtype").check_not(
                "decompose_k"
            ).run(code[0])
        else:
            FileCheck().check("extern_kernels.bmm_dtype").check_regex(
                "triton_.*_fused_0.run"
            ).check("decompose_k").run(code[0])
        check_divisors(code)
        torch.testing.assert_close(
            compiled_func(a, b), (a @ b).relu(), atol=1e-2, rtol=1e-2
        )

        # Test adding reinterpret view before subgraph
        a = a.transpose(0, 1)
        compiled_func = torch.compile(
            lambda a, b: (a.transpose(0, 1) @ b).relu(), dynamic=dynamic
        )
        out, code = run_and_get_code(compiled_func, a, b)
        if dynamic:
            FileCheck().check_not("extern_kernels.bmm_dtype").check_not(
                "decompose_k"
            ).run(code[0])
        else:
            FileCheck().check("extern_kernels.bmm_dtype").check_regex(
                "triton_.*_fused_0.run"
            ).check("decompose_k").run(code[0])
        check_divisors(code)
        torch.testing.assert_close(
            compiled_func(a, b),
            (a.transpose(0, 1) @ b).relu(),
            atol=1e-2,
            rtol=1e-2,
        )


@instantiate_parametrized_tests
class TestMaxAutotuneRemoteCache(TestCase):
@@ -521,8 +521,8 @@ class PadMMTest(TestCase):
            return x @ y

        args = [
            torch.randn(2**4, 2**14 - 1, device="cuda", dtype=torch.float16),
            torch.randn(2**14 - 1, 2**4, device="cuda", dtype=torch.float16),
            torch.randn(2**4, 2**8 - 1, device="cuda", dtype=torch.float16),
            torch.randn(2**8 - 1, 2**4, device="cuda", dtype=torch.float16),
        ]

        counters.clear()
@@ -534,6 +534,7 @@ class PadMMTest(TestCase):
        ret, code = run_and_get_code(opt_fn, *args)
        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)

        code = [c for c in code if "decompose_k" not in c]
        # The mm kernel should use a template (because we set max_autotune_gemm_backends = TRITON).
        # Its name should contain `mm` because `mm` was the original aten op where the mm came from.
        FileCheck().check("def triton_tem_fused_mm").run(code[0])
@@ -1,5 +1,6 @@
# Owner(s): ["module: inductor"]
import functools
import unittest

import torch
from torch._dispatch.python import enable_python_dispatcher
@@ -13,6 +14,7 @@ from torch._inductor.select_algorithm import (
)
from torch._inductor.test_case import run_tests, TestCase
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import skipIfXpu, TEST_WITH_ROCM
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
@@ -26,6 +28,8 @@ class TestSubgraphChoice(TestCase):
            layout=FixedLayout(torch.device(f"{GPU_TYPE}:0"), dtype=dtype, size=shape),
        )

    @skipIfXpu
    @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm")
    def test_subgraph_decompose_k(self):
        from torch._inductor.kernel.mm import aten_mm
        from torch._inductor.kernel.mm_common import mm_args
@@ -46,7 +50,7 @@
            B = k // kPartitions
            a_reshaped = torch.permute(a.reshape(m, B, kPartitions), (1, 0, 2))
            b_reshaped = b.reshape(B, kPartitions, n)
            result = torch.bmm(a_reshaped, b_reshaped)
            result = torch.bmm(a_reshaped, b_reshaped, out_dtype=torch.float32)
            result_fp32 = result.to(torch.float32)
            reduced_buf = torch.sum(result_fp32, 0)
            return reduced_buf.to(a.dtype)
@@ -1,3 +1,4 @@
import itertools
import logging
from typing import Any, Callable
@@ -63,6 +64,7 @@ class SubgraphChoiceCaller(ir.ChoiceCaller):
        bm_graph_lowering.run(*self.example_inputs)
        mod = bm_graph_lowering.compile_to_module()
        bm_func = mod.call

        bm_func([*args])

        return benchmarker.benchmark_gpu(lambda: bm_func([*args]))
@@ -76,6 +78,11 @@
                    for arg in self.example_inputs
                    if isinstance(arg, torch.Tensor)
                ],
                *[
                    str(arg.stride())
                    for arg in self.example_inputs
                    if isinstance(arg, torch.Tensor)
                ],
                str(self.gm.graph),
            ]
        )
@@ -111,6 +118,8 @@ class SubgraphTemplate(KernelTemplate):
    optimal implementations for complex operations.
    """

    index_counter = itertools.count()

    def __init__(
        self,
        name: str,
@@ -123,7 +132,7 @@ class SubgraphTemplate(KernelTemplate):
            name: The name of this template
            graph: The FX graph
        """
        self.name = name
        self.name = f"{name}_{next(SubgraphTemplate.index_counter)}"
        self.make_fx_graph = make_fx_graph

    def generate(  # type: ignore[override]
@@ -2644,7 +2644,7 @@ class PythonWrapperCodegen(CodeGen):
        if (
            name in V.graph.removed_buffers
            or name in self.allocated
            or isinstance(buffer, ir.DonatedBuffer)
            or isinstance(buffer, (ir.DonatedBuffer, ir.SubgraphBuffer))
        ):
            return
        self.allocated.add(name)
@@ -265,6 +265,7 @@ def round_dec(x: torch.Tensor, decimals: int = 0) -> torch.Tensor:
def bmm(
    self: torch.Tensor,
    batch2: torch.Tensor,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    # TODO: Re-enable for mps once our reductions are performant enough
    # (https://github.com/pytorch/pytorch/issues/150121)
@@ -291,6 +292,7 @@ def addmm(
    self: torch.Tensor,
    mat1: torch.Tensor,
    mat2: torch.Tensor,
    out_dtype: Optional[torch.dtype] = None,
    beta: torch.types.Number = 1,
    alpha: torch.types.Number = 1,
) -> torch.Tensor:
@@ -319,6 +321,7 @@ def addmm(
def mm(
    self: torch.Tensor,
    input2: torch.Tensor,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    # Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning.
    # todo: Look into why and fix it (hopefully)
@@ -5720,6 +5720,7 @@ class ExternKernel(InputsKernel):
                return
            size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size())
            stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride())

            wrapper.writeline(
                f"assert_size_stride({self.get_name()}, {size}, {stride})"
            )
@@ -6015,7 +6016,6 @@ class SubgraphBuffer(ExternKernel):
        self.subgraph = V.graph.make_subgraph(
            self.gm, self.example_inputs, subgraph_name
        )

        import torch._inductor.config as inductor_config

        with V.set_graph_handler(self.subgraph):
@@ -6033,9 +6033,11 @@
                self.graph = graph
                self.name = graph.name

        outer_inputs = [t.codegen_reference() for t in self.inputs]

        wrapper.codegen_subgraph_with_flattened_outputs(
            CodegenGraph(self.subgraph),
            [*[buffer.get_name() for buffer in self.inputs]],
            outer_inputs,
            [self.name],
        )
@@ -121,13 +121,22 @@ bmm_template = TritonTemplate(
)

aten_bmm = ExternKernelChoice(torch.bmm, "at::bmm_out")
aten_bmm_dtype = ExternKernelChoice(
    torch.bmm,
    "at::_bmm_out_dtype_cuda",
    name="bmm_dtype",
    op_overload=aten.bmm.dtype_out,
)
aten_baddbmm = ExternKernelChoice(
    torch.baddbmm, "at::baddbmm_out", op_overload=aten.baddbmm.out
)


@L.register_lowering(aten.bmm)
def tuned_bmm(mat1, mat2, *, layout=None):
def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
    """
    Lowering for autotuning aten.bmm with different backends (Aten, Triton, CUTLASS, etc.)
    """
    if all(x.get_device().type == "cpu" for x in [mat1, mat2]):
        # decompose to small ops when memory bound
        if mat1.get_size()[1] == 1 or mat2.get_size()[2] == 1:
@@ -165,7 +174,9 @@ def tuned_bmm(mat1, mat2, *, layout=None):
        meta_mat2 = V.graph.current_node.args[1]
        mat2 = may_require_contiguous(mat2, meta_mat2)

    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
    m, n, k, layout, mat1, mat2 = mm_args(
        mat1, mat2, layout=layout, out_dtype=out_dtype
    )

    # below is for getting an overview logging info of inductor mms
    counters["aten_mm_info"][f"aten.bmm_{m}_{n}_{k}"] += 1
@@ -179,13 +190,21 @@ def tuned_bmm(mat1, mat2, *, layout=None):
        layout,
    )

    if out_dtype:
        assert mat1.get_device().type == "cuda", "out_dtype is only supported for CUDA"
        aten_func = aten_bmm_dtype.bind((mat1, mat2), layout, out_dtype=out_dtype)
    else:
        aten_func = aten_bmm.bind((mat1, mat2), layout)

    # options to tune from
    choices = [aten_bmm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
    choices = [aten_func] if use_aten_gemm_kernels() else []

    device_type = ir.get_device_type(mat1)
    bmm_configs = V.choices.get_base_mm_configs(device_type)

    if use_triton_template(layout):
        # TODO: add out_dtype support for Triton Template
        assert out_dtype is None, "out_dtype is not supported for Triton"
        for config in bmm_configs(
            m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
        ):
@@ -200,7 +219,7 @@ def tuned_bmm(mat1, mat2, *, layout=None):
    if batch_stride_largest and is_nonzero and use_cutlass_template(layout, m, n, k):
        from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate

        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])
        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])  # type: ignore[arg-type]

    if use_cpp_bmm_template(layout, mat1, mat2):
        from ..codegen.cpp_bmm_template import CppBmmTemplate
@@ -3,6 +3,8 @@ import functools
import logging
from typing import Any, Optional

import sympy

import torch
from torch._dynamo.utils import counters
from torch._inductor.autoheuristic.autoheuristic import AutoHeuristicSelectAlgorithm
@@ -14,12 +16,14 @@ from torch._inductor.autoheuristic.autoheuristic_utils import (
)
from torch._inductor.codegen.cpp_gemm_template import CppGemmTemplate
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.torch_version import TorchVersion

from .. import config as inductor_config, ir
from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate
from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
from ..ir import FlexibleLayout, is_triton
from ..codegen.subgraph import SubgraphTemplate
from ..ir import FlexibleLayout, ir_node_to_tensor, is_triton
from ..lowering import (
    add_layout_constraint,
    constrain_to_fx_strides,
@@ -33,11 +37,13 @@ from ..select_algorithm import (
    TritonTemplate,
)
from ..utils import (
    get_k_splits,
    get_tma_workspace_arg,
    use_aten_gemm_kernels,
    use_ck_gemm_template,
    use_cpp_gemm_template,
    use_cutlass_template,
    use_decompose_k_choice,
    use_max_autotune,
    use_triton_template,
    use_triton_tma_template,
@@ -68,6 +74,7 @@ except ImportError:

log = logging.getLogger(__name__)
aten = torch.ops.aten
prims = torch.ops.prims

mm_template = TritonTemplate(
    name="mm",
@@ -584,8 +591,25 @@ def check_supported_striding(mat_a, mat_b) -> None:
aten_bias_addmm = ExternKernelChoice(bias_addmm, None)


def decomposeK(a, b, k_splits):
    m = a.shape[0]
    n = b.shape[1]
    k = a.shape[1]

    k_parts = k // k_splits
    B = k_splits
    a_reshaped = torch.permute(a.reshape(m, B, k_parts), (1, 0, 2))
    b_reshaped = b.reshape(B, k_parts, n)
    result = torch.bmm(a_reshaped, b_reshaped, out_dtype=torch.float32)
    reduced_buf = torch.sum(result, 0)
    return reduced_buf.to(a.dtype)


@register_lowering(aten.mm, type_promotion_kind=None)
def tuned_mm(mat1, mat2, *, layout=None):
    """
    Lowering for autotuning aten.mm with different backends (Aten, Triton, CUTLASS, etc.)
    """
    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
    device_type = ir.get_device_type(mat1)
    name = "mm"
@@ -620,7 +644,10 @@ def tuned_mm(mat1, mat2, *, layout=None):

    if is_nonzero and use_triton_template(layout):
        for config in mm_configs(
            m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
            m,
            n,
            k,
            **mm_config_kwargs(device_type, _is_large_block_for_cpu),
        ):
            mm_template.maybe_append_choice(
                choices,
@@ -628,9 +655,13 @@
                layout=layout,
                **mm_options(config, m, n, k, layout),
            )

        if use_triton_tma_template(mat1, mat2):
            for config in persistent_mm_configs(
                m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
                m,
                n,
                k,
                **mm_config_kwargs(device_type, _is_large_block_for_cpu),
            ):
                persistent_tma_mm_template.maybe_append_choice(
                    choices,
@@ -643,6 +674,40 @@
                    **mm_options(config, m, n, k, layout),
                    **persistent_mm_options(mat1, mat2),
                )
        # Only do split-k optimization if K is much larger than m, n and m, n are small
        if use_decompose_k_choice(m, n, k):
            from torch._dispatch.python import enable_python_dispatcher

            from ..decomposition import select_decomp_table

            k_splits = get_k_splits(m, n, k)
            for k_split in k_splits:
                if not V.graph.sizevars.is_expr_static_and_true(
                    sympy.Eq(sympy.Mod(k, k_split), 0)
                ):
                    continue

                with enable_python_dispatcher():
                    decompositions = select_decomp_table()

                    decompose_k_subgraph_template = SubgraphTemplate(
                        name=f"decompose_k_mm_{k_split}_split",
                        make_fx_graph=make_fx(
                            functools.partial(decomposeK, k_splits=k_split),
                            decompositions,
                        ),
                    )

                with V.fake_mode:
                    mat1_tensor = ir_node_to_tensor(mat1)
                    mat2_tensor = ir_node_to_tensor(mat2)

                decompose_k_subgraph_template.maybe_append_choice(
                    choices,
                    input_nodes=(mat1, mat2),
                    layout=layout,
                    example_inputs=[mat1_tensor, mat2_tensor],
                )

    if is_nonzero and use_cutlass_template(layout, m, n, k):
        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])
@@ -706,7 +771,6 @@

    for k in inductor_config.external_matmul:
        choices.append(lazy_register_extern_choice(k).bind((mat1, mat2), layout))

    if should_fallback_to_aten(choices):
        return aten_mm.bind((mat1, mat2), aten_layout).output_node()
@@ -50,6 +50,7 @@ from .codegen.common import (
    WorkspaceZeroMode,
)
from .codegen.simd_kernel_features import SIMDKernelFeatures
from .codegen.subgraph import SubgraphChoiceCaller
from .codegen.triton import (
    gen_common_triton_imports,
    texpr,
@@ -2195,7 +2196,7 @@ class AlgorithmSelectorCache(PersistentCache):
        def benchmark_choice_in_current_process(
            choice: ChoiceCaller, autotune_args: AutotuneArgs
        ) -> float:
            is_extern = isinstance(choice, ExternKernelCaller)
            is_extern = isinstance(choice, (ExternKernelCaller, SubgraphChoiceCaller))
            benchmark_tensors = autotune_args.get_benchmark_tensors(is_extern)
            inpts, output = benchmark_tensors.unpack()
            output.zero_()
@@ -43,6 +43,7 @@ from typing_extensions import (
    dataclass_transform,
    ParamSpec,
    Self,
    TypeAlias,
    TypeGuard,
)
from unittest import mock
@@ -1560,6 +1561,85 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
    return res


decompose_k_threshold = 32

# To limit compile time
k_splits_limit = 5

# Hand-tuned
default_k_splits = [16, 32, 64, 128, 256]

_IntLike: TypeAlias = Union[int, sympy.Expr]


def use_decompose_k_choice(m: _IntLike, n: _IntLike, k: _IntLike) -> bool:
    from torch._inductor.virtualized import V

    return (
        V.graph.sizevars.is_expr_static_and_true(
            sympy.And(
                sympy.Ge(k, decompose_k_threshold * m),
                sympy.Ge(k, decompose_k_threshold * n),
            )
        )
        and not V.graph.aot_mode  # TODO: Support AOTI for decomposeK
        and not V.graph.cpp_wrapper
    )


@functools.lru_cache(None)
def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]:
    # If k is a sympy expression, we can't do any splitting
    if isinstance(k, sympy.Expr) and not k.is_number:
        return default_k_splits

    if (isinstance(m, sympy.Expr) and not m.is_number) or (
        isinstance(n, sympy.Expr) and not n.is_number
    ):
        max_k_split = 256
    else:
        max_k_split = min(k // m, k // n)

    min_k_split = 2
    # Get all divisors of k, k has to be divisible by kPart
    divisors = sympy.divisors(k)

    divisors = [
        divisor
        for divisor in divisors
        if divisor <= max_k_split and divisor >= min_k_split
    ]

    pow_of_2_divisors, mul_of_32_divisors, rest_of_splits = [], [], []

    for d in divisors:
        kPart = k // d

        # Smaller than 128 might not even fit in a single tile, BLOCK_K can be 128
        if kPart < 128:
            continue

        # Power of 2 divisors are best performing, conform to hardware
        if (kPart & kPart - 1) == 0 and kPart >= 128:
            pow_of_2_divisors.append(d)
        # Else check if creates a multiple of 32
        elif kPart % 32 == 0:
            mul_of_32_divisors.append(d)
        # otherwise, take the smallest values
        else:
            rest_of_splits.append(d)

    # If the # of power of 2 divisors are greater than k_splits_limit, return all
    # This should be ok for compile time, all perfect squares between 128 and min(k / m, k / n)
    # should never be a massive amount
    if len(pow_of_2_divisors) >= k_splits_limit:
        return pow_of_2_divisors
    else:
        best_splits = pow_of_2_divisors + mul_of_32_divisors + rest_of_splits
        # Otherwise, conform results to k_splits_limit
        return best_splits[:k_splits_limit]


@functools.lru_cache(None)
def _rocm_native_device_arch_name(device: str) -> str:
    return torch.cuda.get_device_properties(device).gcnArchName
@@ -4298,7 +4298,7 @@ def meta_alias(self):
    return self.view(self.shape)


def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None):
def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None, out_dtype=None):
    torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
    torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
@@ -4316,10 +4316,18 @@ def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None):
        lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}"
        f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].",
    )

    # TODO: handle out

    output = batch2.new_empty(output_size)
    if out_dtype:
        supported_out_dtype = (
            batch1.dtype == torch.float16 or batch1.dtype == torch.bfloat16
        ) and out_dtype == torch.float32
        torch._check(
            out_dtype == batch1.dtype or supported_out_dtype,
            lambda: "out_dtype only supported for torch.float32 output with float16/bfloat16 inputs or same as input dtypes",
        )
        output = batch2.new_empty(output_size).to(out_dtype)
    else:
        # TODO: handle out
        output = batch2.new_empty(output_size)

    if not is_bmm and self_baddbmm is not None:
        torch._check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor")
@@ -4336,6 +4344,11 @@ def meta_bmm(self, mat2):
    return common_meta_baddbmm_bmm(self, mat2, True)


@register_meta(aten.bmm.dtype)
def meta_bmm_dtype(self, mat2, out_dtype):
    return common_meta_baddbmm_bmm(self, mat2, True, out_dtype=out_dtype)


def div_rtn(x, y):
    q = x // y
    r = x % y