[inductor] move max-autotune logic inside V.choices.get_mm_configs (#161344)

# why

- heuristics providers now decide whether to add choices (and which ones)
  in the max-autotune case
- enables an eventual override point to gracefully fall back to the
  standard behavior

# what

- the max-autotune check is now determined inside V.choices.get_mm_configs
- because this is mm-only right now, the check is simply
  `config.max_autotune or config.max_autotune_gemm`; a TODO notes that this
  can change in the future when this expands to more templates (see the
  sketch below)
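
In condensed form (a sketch, not the literal diff; the base.py and gemm.py
hunks below carry the full signatures and type annotations), the new hook
looks roughly like this:

```
# condensed sketch of the should_run hook added in this stack
from torch._inductor import config as inductor_config


class TemplateConfigHeuristics:
    def should_run(self, inputs, layout) -> bool:
        # default: the heuristic always contributes its configs
        return True

    def get_template_configs(self, kernel_inputs, layout, op_name):
        # gate all config generation on the heuristic-specific check
        if not self.should_run(kernel_inputs, layout):
            return
        yield from self._get_template_configs_impl(kernel_inputs, layout, op_name)

    def _get_template_configs_impl(self, kernel_inputs, layout, op_name):
        # base implementation yields no entries; subclasses override this
        yield from []


class GemmMaxAutotuneTemplateConfigHeuristics(TemplateConfigHeuristics):
    def should_run(self, inputs, layout) -> bool:
        # GEMM-family templates only add choices under max-autotune
        return inductor_config.max_autotune or inductor_config.max_autotune_gemm
```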

# testing

```
python3 -bb -m pytest test/inductor/test_max_autotune.py -v
```
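
To target just the new choice-validation test from this stack, a narrower
invocation using pytest's `-k` filter would be something like:

```
python3 -bb -m pytest test/inductor/test_max_autotune.py -v -k test_autotune_gemm_choice_validation
```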

Differential Revision: [D81520573](https://our.internmc.facebook.com/intern/diff/D81520573)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161344
Approved by: https://github.com/jansel
ghstack dependencies: #162075, #161340, #161341, #161342, #161343
Ruben Rodriguez Buchillon
2025-09-05 07:38:42 -07:00
committed by PyTorch MergeBot
parent a301dc3b60
commit 031d79cb51
11 changed files with 174 additions and 34 deletions


@@ -31,8 +31,11 @@ from torch._inductor.ir import Buffer, ChoiceCaller, FixedLayout
from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm
from torch._inductor.select_algorithm import (
add_feedback_saver,
add_preprocessing_fn,
AlgorithmSelectorCache,
clear_feedback_savers,
clear_preprocessing_fns,
ExternKernelCaller,
TritonTemplate,
TritonTemplateCaller,
)
@@ -1900,6 +1903,76 @@ class TestMaxAutotune(TestCase):
counters["inductor"]["select_algorithm_num_precompilation_exceptions"], 0
)
@parametrize("op", ("mm", "addmm", "bmm", "baddbmm", "mm_plus_mm"))
@parametrize("max_autotune", (False, True))
@config.patch(
{"test_configs.max_mm_configs": 4, "max_autotune_gemm_backends": "ATEN,TRITON"}
)
def test_autotune_gemm_choice_validation(self, op, max_autotune):
def generate_inputs_and_func(op_name):
# Base config with just x and w
base_inputs = [
torch.randn(128, 256, device=GPU_TYPE),
torch.randn(256, 128, device=GPU_TYPE),
]
func = torch.mm
if op_name == "mm":
# default
pass
elif op_name == "addmm":
# Add bias for addmm
base_inputs = [torch.randn(128, device=GPU_TYPE)] + base_inputs
func = torch.addmm
elif op_name in ["bmm", "baddbmm"]:
# Override for batch dimensions
base_inputs[0] = torch.randn(4, 128, 256, device=GPU_TYPE)
base_inputs[1] = torch.randn(4, 256, 128, device=GPU_TYPE)
func = torch.bmm
if op_name == "baddbmm":
# Add batch bias
base_inputs = [
torch.torch.randn(4, 128, 128, device=GPU_TYPE)
] + base_inputs
func = torch.baddbmm
elif op_name == "mm_plus_mm":
# Add second matrix pair
base_inputs += [
torch.randn(128, 256, device=GPU_TYPE),
torch.randn(256, 128, device=GPU_TYPE),
]
def mmpmm(x, w, x2, w2):
return torch.mm(x, w) + torch.mm(x2, w2)
func = mmpmm
else:
raise ValueError(f"Unsupported op: {op_name}")
return base_inputs, func
choice_types_seen = set()
def choice_validator(choices):
for choice in choices:
choice_types_seen.add(type(choice))
return choices
inputs, fn = generate_inputs_and_func(op)
add_preprocessing_fn(choice_validator)
try:
with config.patch({"max_autotune": max_autotune}):
compiled_fn = torch.compile(fn, dynamic=False)
compiled_fn(*inputs)
if max_autotune:
self.assertIn(ExternKernelCaller, choice_types_seen)
self.assertIn(TritonTemplateCaller, choice_types_seen)
else:
self.assertIn(ExternKernelCaller, choice_types_seen)
self.assertNotIn(TritonTemplateCaller, choice_types_seen)
finally:
clear_preprocessing_fns()
class TestMaxAutotunePrecompile(TestCase):
def test_precompilation_threads(self):


@@ -207,7 +207,7 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
**extra_kwargs,
)
if use_triton_template(layout):
if use_triton_template(layout, check_max_autotune=False):
# TODO: add out_dtype support for Triton Template
assert out_dtype is None, "out_dtype is not supported for Triton"
@@ -288,7 +288,7 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
**extra_kwargs,
)
if use_triton_template(layout):
if use_triton_template(layout, check_max_autotune=False):
for kwargs, extra_kwargs in V.choices.get_mm_configs(
kernel_inputs,
layout,


@@ -762,7 +762,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
)
static_shape, is_nonzero = _is_static_problem(layout)
if is_nonzero and use_triton_template(layout):
if is_nonzero and use_triton_template(layout, check_max_autotune=False):
# Get template params using the new unified function
for kwargs, extra_kwargs in V.choices.get_mm_configs(
kernel_inputs, layout, mm_template, "mm"
@@ -941,7 +941,9 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
choices, layout, kernel_inputs.nodes(), fuseable=True, non_fuseable=True
)
if is_nonzero and use_triton_template(layout, enable_int32=True):
if is_nonzero and use_triton_template(
layout, enable_int32=True, check_max_autotune=False
):
for kwargs, extra_kwargs in V.choices.get_mm_configs(
kernel_inputs, layout, mm_template, name
):
@@ -1035,7 +1037,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
**extra_kwargs,
)
if is_nonzero and use_triton_template(layout):
if is_nonzero and use_triton_template(layout, check_max_autotune=False):
# all the triton templates use the extra_kwargs
# Get template params using the new unified function
for kwargs, extra_kwargs in V.choices.get_mm_configs(
@@ -1248,7 +1250,9 @@ def tuned_scaled_mm(
_, is_nonzero = _is_static_problem(layout)
if is_nonzero and use_triton_template(layout, enable_float8=True):
if is_nonzero and use_triton_template(
layout, enable_float8=True, check_max_autotune=False
):
overriders = dict(USE_FAST_ACCUM=use_fast_accum)
# TODO (paulzhan): There is no template that exists for bias and TMA
# Don't run tma template currently if bias exists


@@ -165,7 +165,7 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
**extra_kwargs,
)
if use_triton_template(layout1):
if use_triton_template(layout1, check_max_autotune=False):
# Get template params using the new unified function
for kwargs, extra_kwargs in V.choices.get_mm_configs(
kernel_inputs, layout1, mm_plus_mm_template, "mm_plus_mm"


@@ -8,6 +8,7 @@ from ..kernel.bmm import aten_baddbmm, aten_bmm, aten_bmm_dtype
from ..kernel.mm import aten__fp8_mm, aten__int_mm, aten_addmm, aten_bias_addmm, aten_mm
from ..kernel.mm_plus_mm import aten_mm_plus_mm
from .base import TemplateConfigHeuristics
from .gemm import GemmMaxAutotuneTemplateConfigHeuristics
if TYPE_CHECKING:
@@ -37,7 +38,7 @@ class ATenConfigHeuristics(TemplateConfigHeuristics):
If you want to use this with an ATen choice that has kwargs, just subclass
"""
def get_template_configs(
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
layout: Layout,
@@ -68,17 +69,15 @@ class ATenAddMMConfigHeuristics(ATenConfigHeuristics):
@register_template_heuristic(aten_bias_addmm.uid, None, op_name="addmm")
class ATenBiasAddMMConfigHeuristics(ATenAddMMConfigHeuristics):
def get_template_configs(
class ATenBiasAddMMConfigHeuristics(
ATenAddMMConfigHeuristics, GemmMaxAutotuneTemplateConfigHeuristics
):
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
layout: Layout,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
if not (inductor_config.max_autotune or inductor_config.max_autotune_gemm):
# NOTE: this preserves the original logic that if there is not max-autotune
# then we skip bias_addmm
return
nodes = kernel_inputs.nodes()
# for addmm, bias is the first input
bias = nodes[0]


@@ -11,6 +11,20 @@ if TYPE_CHECKING:
class TemplateConfigHeuristics:
"""Base class for generating sets of configs for an associated template."""
def should_run(self, inputs: KernelInputs, layout: Layout) -> bool:
"""
hookup to check whether the configs are right to run at all e.g. you can check
max-autotune specific to your heuristic here or other things
If this returns False, get_template_configs will yield no configs
Args:
inputs: KernelInputs
layout: Layout
"""
return True
def get_template_configs(
self,
kernel_inputs: KernelInputs,
@@ -19,10 +33,30 @@ class TemplateConfigHeuristics:
) -> Generator[dict[str, Any], None, None]:
"""
Get template configs for the given inputs.
Prefer to override the _get_template_configs_impl method
to leverage things like should_run
"""
if not self.should_run(kernel_inputs, layout):
return
yield from self._get_template_configs_impl(
kernel_inputs,
layout,
op_name,
)
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
layout: Layout,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
"""
Get template configs for the given inputs.
This is the main entry point for template-specific logic.
"""
# NOTE: not an abstract class, because that clashed below for the mixin
# functionality. Can be adjusted, but not a high priority
# base implementation yields no entries
yield from []
def get_extra_kwargs(


@@ -12,6 +12,7 @@ from ..kernel.mm import (
from ..kernel_inputs import KernelInputs, MMKernelInputs
from ..utils import use_contiguous
from .base import TemplateConfigHeuristics
from .gemm import GemmMaxAutotuneTemplateConfigHeuristics
from .registry import register_template_heuristic
@@ -41,8 +42,8 @@ class EmptyContiguousMMConfigHeuristics(TemplateConfigHeuristics):
register=torch.version.hip is not None,
op_name="addmm",
)
class ContiguousMMHeuristics(TemplateConfigHeuristics):
def get_template_configs(
class ContiguousMMHeuristics(GemmMaxAutotuneTemplateConfigHeuristics):
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
layout: Layout,
@@ -54,7 +55,6 @@ class ContiguousMMHeuristics(TemplateConfigHeuristics):
assert isinstance(kernel_inputs, MMKernelInputs), (
f"{self.__class__.__name__} requires MMKernelInputs"
)
# Check for unbacked symbols - if found, yield nothing
unbacked_symbols = any(
len(get_free_symbols(itr, unbacked_only=True)) > 0

View File

@@ -12,6 +12,7 @@ from ..kernel_inputs import KernelInputs, MMKernelInputs
from ..utils import get_k_splits
from ..virtualized import V
from .base import TemplateConfigHeuristics
from .gemm import GemmMaxAutotuneTemplateConfigHeuristics
from .registry import register_template_heuristic
@@ -38,8 +39,8 @@ class EmptyDecomposeKConfigHeuristics(TemplateConfigHeuristics):
# TODO(coconutruben): enable decompose k on other devices (xpu, cpu, mps, mtia)
# by either adding specific register_template_heuristic tags, or setting the
# device to None (enabled on all devices)
class DecomposeKConfigHeuristics(TemplateConfigHeuristics):
def get_template_configs(
class DecomposeKConfigHeuristics(GemmMaxAutotuneTemplateConfigHeuristics):
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
layout: Layout,


@@ -0,0 +1,19 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from .. import config as inductor_config
from .base import TemplateConfigHeuristics
if TYPE_CHECKING:
from ..ir import Layout
from ..kernel_inputs import KernelInputs
class GemmMaxAutotuneTemplateConfigHeuristics(TemplateConfigHeuristics):
def should_run(self, inputs: KernelInputs, layout: Layout) -> bool:
"""
simple base override for GEMM family templates that run only in max-autotune
"""
return inductor_config.max_autotune or inductor_config.max_autotune_gemm


@@ -32,7 +32,7 @@ from ..utils import (
using_b200,
)
from ..virtualized import V
from .base import TemplateConfigHeuristics
from .gemm import GemmMaxAutotuneTemplateConfigHeuristics
from .registry import register_template_heuristic
@@ -1417,7 +1417,7 @@ class MTIAConfigHeuristic(BaseConfigHeuristic):
# Template-specific mixin classes
class MMTemplateConfigMixin(TemplateConfigHeuristics):
class MMTemplateConfigMixin(GemmMaxAutotuneTemplateConfigHeuristics):
"""
Mixin class that converts config lists to template kwargs.
This handles the logic that was previously in choices.get_mm_configs.
@@ -1448,7 +1448,7 @@ class MMTemplateConfigMixin(TemplateConfigHeuristics):
else:
return self.get_mm_configs()
def get_template_configs(
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
layout: Any,
@@ -1466,6 +1466,7 @@ class MMTemplateConfigMixin(TemplateConfigHeuristics):
raise ValueError(f"Need at least 2 input tensors, got {len(input_nodes)}")
if not self._valid(kernel_inputs):
return
# Extract M, N, K from kernel_inputs
m, n, k = kernel_inputs.mnk_symbolic()
@@ -1593,7 +1594,7 @@ class TMATemplateConfigMixin(TMAWorkspaceMixin, MMTemplateConfigMixin):
This inherits from MMTemplateConfigMixin and overrides config generation.
"""
def get_template_configs(
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
layout: Any,
@@ -1614,8 +1615,10 @@ class TMATemplateConfigMixin(TMAWorkspaceMixin, MMTemplateConfigMixin):
"TMA_EXPERIMENTAL_API": not has_triton_stable_tma_api(),
}
# Get base template configs from superclass
for template_kwargs in super().get_template_configs(
kernel_inputs, layout, op_name
for template_kwargs in super()._get_template_configs_impl(
kernel_inputs,
layout,
op_name,
):
yield {**template_kwargs, **tma_opts}
@@ -1661,7 +1664,7 @@ class BaseScaledMMConfigMixin(MMTemplateConfigMixin):
nodes, mat1_idx=kernel_inputs._mat1_idx, mat2_idx=kernel_inputs._mat2_idx
)
def get_template_configs(
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
layout: Any,
@@ -1713,7 +1716,7 @@ class BaseScaledMMConfigMixin(MMTemplateConfigMixin):
return
# Get base template configs from superclass
for template_kwargs in super().get_template_configs(
for template_kwargs in super()._get_template_configs_impl(
kernel_inputs, layout, op_name
):
# Add scaled MM-specific options (moved from mm_common.scaled_mm_options)
@@ -1769,7 +1772,7 @@ class ScaledTMAConfigMixin(TMAWorkspaceMixin, BaseScaledMMConfigMixin):
This inherits from BaseScaledMMConfigMixin and adds TMA-specific options.
"""
def get_template_configs(
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
layout: Any,
@@ -1779,8 +1782,10 @@ class ScaledTMAConfigMixin(TMAWorkspaceMixin, BaseScaledMMConfigMixin):
Generate scaled TMA template configs with both scaled MM and TMA-specific options.
"""
# Get base scaled MM template configs from superclass
for template_kwargs in super().get_template_configs(
kernel_inputs, layout, op_name
for template_kwargs in super()._get_template_configs_impl(
kernel_inputs,
layout,
op_name,
):
# Add TMA-specific options for device TMA scaled MM
template_kwargs["TMA_SIZE"] = TMA_DESCRIPTOR_SIZE

View File

@@ -1634,7 +1634,11 @@ def _use_conv_autotune_backend(backend: str) -> bool:
def use_triton_template(
layout: Layout, *, enable_int32: bool = False, enable_float8: bool = False
layout: Layout,
*,
enable_int32: bool = False,
enable_float8: bool = False,
check_max_autotune: bool = True,
) -> bool:
from .codegen.common import BackendFeature, has_backend_feature
@@ -1651,7 +1655,8 @@ def use_triton_template(
)
or (layout.device.type == "cpu" and layout.dtype in layout_dtypes)
)
and (config.max_autotune or config.max_autotune_gemm)
# some callers handle max-autotune checking externally
and (config.max_autotune or config.max_autotune_gemm or not check_max_autotune)
and _use_autotune_backend("TRITON")
and has_backend_feature(layout.device, BackendFeature.TRITON_TEMPLATES)
)