[inductor] move max-autotune logic inside V.choices.get_mm_configs (#161344)

# why

- heuristics providers now decide whether to add choices (and which ones)
  in the max-autotune case
- enables an eventual override point to gracefully fall back to the
  standard behavior

# what

- max-autotune is now determined inside V.choices.get_mm_configs
  (see the sketch below). Because this is mm-only right now, we can just
  check `config.max_autotune or config.max_autotune_gemm`; a TODO notes
  that this can change in the future when this expands to more templates
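
The core of the mechanism, condensed from the `template_heuristics` diffs below: the base class gains a `should_run` hook that `get_template_configs` consults before yielding anything, and a GEMM-specific subclass overrides it to gate on max-autotune. A minimal sketch with simplified signatures (the real classes take typed `KernelInputs`/`Layout` arguments and extra kwargs):

```python
from torch._inductor import config as inductor_config


class TemplateConfigHeuristics:
    """Base class for generating sets of configs for an associated template."""

    def should_run(self, inputs, layout) -> bool:
        # Default: a heuristic always contributes its configs.
        return True

    def get_template_configs(self, kernel_inputs, layout, op_name):
        # Public entry point: consult should_run() before yielding anything.
        if not self.should_run(kernel_inputs, layout):
            return
        yield from self._get_template_configs_impl(kernel_inputs, layout, op_name)

    def _get_template_configs_impl(self, kernel_inputs, layout, op_name):
        # Base implementation yields no entries; subclasses override this.
        yield from []


class GemmMaxAutotuneTemplateConfigHeuristics(TemplateConfigHeuristics):
    def should_run(self, inputs, layout) -> bool:
        # GEMM-family heuristics only contribute choices in max-autotune mode.
        return inductor_config.max_autotune or inductor_config.max_autotune_gemm
```

The Triton/TMA/scaled-MM mixins and the bias-addmm ATen heuristic inherit from `GemmMaxAutotuneTemplateConfigHeuristics`, while the plain ATen heuristics keep the default `should_run`, which is why ATen choices still show up without max-autotune.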

# testing

```
python3 -bb -m pytest test/inductor/test_max_autotune.py -v
```
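
To exercise just the new parametrized test added here (test name taken from the diff below), a narrower invocation along these lines should also work:

```
python3 -bb -m pytest test/inductor/test_max_autotune.py -v -k test_autotune_gemm_choice_validation
```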

Differential Revision: [D81520573](https://our.internmc.facebook.com/intern/diff/D81520573)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161344
Approved by: https://github.com/jansel
ghstack dependencies: #162075, #161340, #161341, #161342, #161343
Author: Ruben Rodriguez Buchillon, 2025-09-05 07:38:42 -07:00 (committed by PyTorch MergeBot)
Commit: 031d79cb51 (parent a301dc3b60), 11 changed files with 174 additions and 34 deletions


```diff
@@ -31,8 +31,11 @@ from torch._inductor.ir import Buffer, ChoiceCaller, FixedLayout
 from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm
 from torch._inductor.select_algorithm import (
     add_feedback_saver,
+    add_preprocessing_fn,
     AlgorithmSelectorCache,
     clear_feedback_savers,
+    clear_preprocessing_fns,
+    ExternKernelCaller,
     TritonTemplate,
     TritonTemplateCaller,
 )
@@ -1900,6 +1903,76 @@ class TestMaxAutotune(TestCase):
             counters["inductor"]["select_algorithm_num_precompilation_exceptions"], 0
         )
+
+    @parametrize("op", ("mm", "addmm", "bmm", "baddbmm", "mm_plus_mm"))
+    @parametrize("max_autotune", (False, True))
+    @config.patch(
+        {"test_configs.max_mm_configs": 4, "max_autotune_gemm_backends": "ATEN,TRITON"}
+    )
+    def test_autotune_gemm_choice_validation(self, op, max_autotune):
+        def generate_inputs_and_func(op_name):
+            # Base config with just x and w
+            base_inputs = [
+                torch.randn(128, 256, device=GPU_TYPE),
+                torch.randn(256, 128, device=GPU_TYPE),
+            ]
+            func = torch.mm
+            if op_name == "mm":
+                # default
+                pass
+            elif op_name == "addmm":
+                # Add bias for addmm
+                base_inputs = [torch.randn(128, device=GPU_TYPE)] + base_inputs
+                func = torch.addmm
+            elif op_name in ["bmm", "baddbmm"]:
+                # Override for batch dimensions
+                base_inputs[0] = torch.randn(4, 128, 256, device=GPU_TYPE)
+                base_inputs[1] = torch.randn(4, 256, 128, device=GPU_TYPE)
+                func = torch.bmm
+                if op_name == "baddbmm":
+                    # Add batch bias
+                    base_inputs = [
+                        torch.randn(4, 128, 128, device=GPU_TYPE)
+                    ] + base_inputs
+                    func = torch.baddbmm
+            elif op_name == "mm_plus_mm":
+                # Add second matrix pair
+                base_inputs += [
+                    torch.randn(128, 256, device=GPU_TYPE),
+                    torch.randn(256, 128, device=GPU_TYPE),
+                ]
+
+                def mmpmm(x, w, x2, w2):
+                    return torch.mm(x, w) + torch.mm(x2, w2)
+
+                func = mmpmm
+            else:
+                raise ValueError(f"Unsupported op: {op_name}")
+            return base_inputs, func
+
+        choice_types_seen = set()
+
+        def choice_validator(choices):
+            for choice in choices:
+                choice_types_seen.add(type(choice))
+            return choices
+
+        inputs, fn = generate_inputs_and_func(op)
+        add_preprocessing_fn(choice_validator)
+
+        try:
+            with config.patch({"max_autotune": max_autotune}):
+                compiled_fn = torch.compile(fn, dynamic=False)
+                compiled_fn(*inputs)
+
+            if max_autotune:
+                self.assertIn(ExternKernelCaller, choice_types_seen)
+                self.assertIn(TritonTemplateCaller, choice_types_seen)
+            else:
+                self.assertIn(ExternKernelCaller, choice_types_seen)
+                self.assertNotIn(TritonTemplateCaller, choice_types_seen)
+        finally:
+            clear_preprocessing_fns()
 
 
 class TestMaxAutotunePrecompile(TestCase):
     def test_precompilation_threads(self):
```


```diff
@@ -207,7 +207,7 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
                 **extra_kwargs,
             )
 
-    if use_triton_template(layout):
+    if use_triton_template(layout, check_max_autotune=False):
         # TODO: add out_dtype support for Triton Template
         assert out_dtype is None, "out_dtype is not supported for Triton"
@@ -288,7 +288,7 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
                 **extra_kwargs,
             )
 
-    if use_triton_template(layout):
+    if use_triton_template(layout, check_max_autotune=False):
         for kwargs, extra_kwargs in V.choices.get_mm_configs(
             kernel_inputs,
             layout,
```


```diff
@@ -762,7 +762,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
             )
 
     static_shape, is_nonzero = _is_static_problem(layout)
-    if is_nonzero and use_triton_template(layout):
+    if is_nonzero and use_triton_template(layout, check_max_autotune=False):
         # Get template params using the new unified function
         for kwargs, extra_kwargs in V.choices.get_mm_configs(
             kernel_inputs, layout, mm_template, "mm"
@@ -941,7 +941,9 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
             choices, layout, kernel_inputs.nodes(), fuseable=True, non_fuseable=True
         )
 
-    if is_nonzero and use_triton_template(layout, enable_int32=True):
+    if is_nonzero and use_triton_template(
+        layout, enable_int32=True, check_max_autotune=False
+    ):
         for kwargs, extra_kwargs in V.choices.get_mm_configs(
             kernel_inputs, layout, mm_template, name
         ):
@@ -1035,7 +1037,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
                 **extra_kwargs,
             )
 
-    if is_nonzero and use_triton_template(layout):
+    if is_nonzero and use_triton_template(layout, check_max_autotune=False):
         # all the triton templates use the extra_kwargs
         # Get template params using the new unified function
         for kwargs, extra_kwargs in V.choices.get_mm_configs(
@@ -1248,7 +1250,9 @@ def tuned_scaled_mm(
     _, is_nonzero = _is_static_problem(layout)
 
-    if is_nonzero and use_triton_template(layout, enable_float8=True):
+    if is_nonzero and use_triton_template(
+        layout, enable_float8=True, check_max_autotune=False
+    ):
         overriders = dict(USE_FAST_ACCUM=use_fast_accum)
 
         # TODO (paulzhan): There is no template that exists for bias and TMA
         # Don't run tma template currently if bias exists
```


```diff
@@ -165,7 +165,7 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
                 **extra_kwargs,
             )
 
-    if use_triton_template(layout1):
+    if use_triton_template(layout1, check_max_autotune=False):
         # Get template params using the new unified function
         for kwargs, extra_kwargs in V.choices.get_mm_configs(
             kernel_inputs, layout1, mm_plus_mm_template, "mm_plus_mm"
```


```diff
@@ -8,6 +8,7 @@ from ..kernel.bmm import aten_baddbmm, aten_bmm, aten_bmm_dtype
 from ..kernel.mm import aten__fp8_mm, aten__int_mm, aten_addmm, aten_bias_addmm, aten_mm
 from ..kernel.mm_plus_mm import aten_mm_plus_mm
 from .base import TemplateConfigHeuristics
+from .gemm import GemmMaxAutotuneTemplateConfigHeuristics
 
 
 if TYPE_CHECKING:
@@ -37,7 +38,7 @@ class ATenConfigHeuristics(TemplateConfigHeuristics):
     If you want to use this with an ATen choice that has kwargs, just subclass
     """
 
-    def get_template_configs(
+    def _get_template_configs_impl(
         self,
         kernel_inputs: KernelInputs,
         layout: Layout,
@@ -68,17 +69,15 @@ class ATenAddMMConfigHeuristics(ATenConfigHeuristics):
 
 
 @register_template_heuristic(aten_bias_addmm.uid, None, op_name="addmm")
-class ATenBiasAddMMConfigHeuristics(ATenAddMMConfigHeuristics):
-    def get_template_configs(
+class ATenBiasAddMMConfigHeuristics(
+    ATenAddMMConfigHeuristics, GemmMaxAutotuneTemplateConfigHeuristics
+):
+    def _get_template_configs_impl(
         self,
         kernel_inputs: KernelInputs,
         layout: Layout,
         op_name: str,
     ) -> Generator[dict[str, Any], None, None]:
-        if not (inductor_config.max_autotune or inductor_config.max_autotune_gemm):
-            # NOTE: this preserves the original logic that if there is not max-autotune
-            # then we skip bias_addmm
-            return
         nodes = kernel_inputs.nodes()
         # for addmm, bias is the first input
         bias = nodes[0]
```


```diff
@@ -11,6 +11,20 @@ if TYPE_CHECKING:
 
 
 class TemplateConfigHeuristics:
+    """Base class for generating sets of configs for an associated template."""
+
+    def should_run(self, inputs: KernelInputs, layout: Layout) -> bool:
+        """
+        hookup to check whether the configs are right to run at all e.g. you can check
+        max-autotune specific to your heuristic here or other things
+
+        If this returns False, get_template_configs will yield no configs
+
+        Args:
+            inputs: KernelInputs
+            layout: Layout
+        """
+        return True
+
     def get_template_configs(
         self,
         kernel_inputs: KernelInputs,
@@ -19,10 +33,30 @@ class TemplateConfigHeuristics:
     ) -> Generator[dict[str, Any], None, None]:
         """
         Get template configs for the given inputs.
+
+        Prefer to override the _get_template_configs_impl method
+        to leverage things like should_run
+        """
+        if not self.should_run(kernel_inputs, layout):
+            return
+        yield from self._get_template_configs_impl(
+            kernel_inputs,
+            layout,
+            op_name,
+        )
+
+    def _get_template_configs_impl(
+        self,
+        kernel_inputs: KernelInputs,
+        layout: Layout,
+        op_name: str,
+    ) -> Generator[dict[str, Any], None, None]:
+        """
+        Get template configs for the given inputs.
+
         This is the main entry point for template-specific logic.
         """
-        # NOTE: not an abstract class, because that clashed below for the mixin
-        # functionality. Can be adjusted, but not a high priority
+        # base implementation yields no entries
        yield from []
 
     def get_extra_kwargs(
```


```diff
@@ -12,6 +12,7 @@ from ..kernel.mm import (
 from ..kernel_inputs import KernelInputs, MMKernelInputs
 from ..utils import use_contiguous
 from .base import TemplateConfigHeuristics
+from .gemm import GemmMaxAutotuneTemplateConfigHeuristics
 from .registry import register_template_heuristic
 
 
@@ -41,8 +42,8 @@ class EmptyContiguousMMConfigHeuristics(TemplateConfigHeuristics):
     register=torch.version.hip is not None,
     op_name="addmm",
 )
-class ContiguousMMHeuristics(TemplateConfigHeuristics):
-    def get_template_configs(
+class ContiguousMMHeuristics(GemmMaxAutotuneTemplateConfigHeuristics):
+    def _get_template_configs_impl(
         self,
         kernel_inputs: KernelInputs,
         layout: Layout,
@@ -54,7 +55,6 @@ class ContiguousMMHeuristics(TemplateConfigHeuristics):
         assert isinstance(kernel_inputs, MMKernelInputs), (
             f"{self.__class__.__name__} requires MMKernelInputs"
         )
-
         # Check for unbacked symbols - if found, yield nothing
         unbacked_symbols = any(
             len(get_free_symbols(itr, unbacked_only=True)) > 0
```


```diff
@@ -12,6 +12,7 @@ from ..kernel_inputs import KernelInputs, MMKernelInputs
 from ..utils import get_k_splits
 from ..virtualized import V
 from .base import TemplateConfigHeuristics
+from .gemm import GemmMaxAutotuneTemplateConfigHeuristics
 from .registry import register_template_heuristic
 
 
@@ -38,8 +39,8 @@ class EmptyDecomposeKConfigHeuristics(TemplateConfigHeuristics):
 # TODO(coconutruben): enable decompose k on other devices (xpu, cpu, mps, mtia)
 # by either adding specific register_template_heuristic tags, or setting the
 # device to None (enabled on all devices)
-class DecomposeKConfigHeuristics(TemplateConfigHeuristics):
-    def get_template_configs(
+class DecomposeKConfigHeuristics(GemmMaxAutotuneTemplateConfigHeuristics):
+    def _get_template_configs_impl(
         self,
         kernel_inputs: KernelInputs,
         layout: Layout,
```


```diff
@@ -0,0 +1,19 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from .. import config as inductor_config
+from .base import TemplateConfigHeuristics
+
+
+if TYPE_CHECKING:
+    from ..ir import Layout
+    from ..kernel_inputs import KernelInputs
+
+
+class GemmMaxAutotuneTemplateConfigHeuristics(TemplateConfigHeuristics):
+    def should_run(self, inputs: KernelInputs, layout: Layout) -> bool:
+        """
+        simple base override for GEMM family templates that run only in max-autotune
+        """
+        return inductor_config.max_autotune or inductor_config.max_autotune_gemm
```


```diff
@@ -32,7 +32,7 @@ from ..utils import (
     using_b200,
 )
 from ..virtualized import V
-from .base import TemplateConfigHeuristics
+from .gemm import GemmMaxAutotuneTemplateConfigHeuristics
 from .registry import register_template_heuristic
 
 
@@ -1417,7 +1417,7 @@ class MTIAConfigHeuristic(BaseConfigHeuristic):
 
 
 # Template-specific mixin classes
-class MMTemplateConfigMixin(TemplateConfigHeuristics):
+class MMTemplateConfigMixin(GemmMaxAutotuneTemplateConfigHeuristics):
     """
     Mixin class that converts config lists to template kwargs.
     This handles the logic that was previously in choices.get_mm_configs.
@@ -1448,7 +1448,7 @@ class MMTemplateConfigMixin(TemplateConfigHeuristics):
         else:
             return self.get_mm_configs()
 
-    def get_template_configs(
+    def _get_template_configs_impl(
         self,
         kernel_inputs: KernelInputs,
         layout: Any,
@@ -1466,6 +1466,7 @@ class MMTemplateConfigMixin(TemplateConfigHeuristics):
             raise ValueError(f"Need at least 2 input tensors, got {len(input_nodes)}")
 
         if not self._valid(kernel_inputs):
             return
+
         # Extract M, N, K from kernel_inputs
         m, n, k = kernel_inputs.mnk_symbolic()
@@ -1593,7 +1594,7 @@ class TMATemplateConfigMixin(TMAWorkspaceMixin, MMTemplateConfigMixin):
     This inherits from MMTemplateConfigMixin and overrides config generation.
     """
 
-    def get_template_configs(
+    def _get_template_configs_impl(
         self,
         kernel_inputs: KernelInputs,
         layout: Any,
@@ -1614,8 +1615,10 @@ class TMATemplateConfigMixin(TMAWorkspaceMixin, MMTemplateConfigMixin):
             "TMA_EXPERIMENTAL_API": not has_triton_stable_tma_api(),
         }
 
         # Get base template configs from superclass
-        for template_kwargs in super().get_template_configs(
-            kernel_inputs, layout, op_name
+        for template_kwargs in super()._get_template_configs_impl(
+            kernel_inputs,
+            layout,
+            op_name,
         ):
             yield {**template_kwargs, **tma_opts}
@@ -1661,7 +1664,7 @@ class BaseScaledMMConfigMixin(MMTemplateConfigMixin):
             nodes, mat1_idx=kernel_inputs._mat1_idx, mat2_idx=kernel_inputs._mat2_idx
         )
 
-    def get_template_configs(
+    def _get_template_configs_impl(
         self,
         kernel_inputs: KernelInputs,
         layout: Any,
@@ -1713,7 +1716,7 @@ class BaseScaledMMConfigMixin(MMTemplateConfigMixin):
             return
 
         # Get base template configs from superclass
-        for template_kwargs in super().get_template_configs(
+        for template_kwargs in super()._get_template_configs_impl(
             kernel_inputs, layout, op_name
         ):
             # Add scaled MM-specific options (moved from mm_common.scaled_mm_options)
@@ -1769,7 +1772,7 @@ class ScaledTMAConfigMixin(TMAWorkspaceMixin, BaseScaledMMConfigMixin):
     This inherits from BaseScaledMMConfigMixin and adds TMA-specific options.
     """
 
-    def get_template_configs(
+    def _get_template_configs_impl(
         self,
         kernel_inputs: KernelInputs,
         layout: Any,
@@ -1779,8 +1782,10 @@ class ScaledTMAConfigMixin(TMAWorkspaceMixin, BaseScaledMMConfigMixin):
         Generate scaled TMA template configs with both scaled MM and TMA-specific options.
         """
         # Get base scaled MM template configs from superclass
-        for template_kwargs in super().get_template_configs(
-            kernel_inputs, layout, op_name
+        for template_kwargs in super()._get_template_configs_impl(
+            kernel_inputs,
+            layout,
+            op_name,
         ):
             # Add TMA-specific options for device TMA scaled MM
             template_kwargs["TMA_SIZE"] = TMA_DESCRIPTOR_SIZE
```


```diff
@@ -1634,7 +1634,11 @@ def _use_conv_autotune_backend(backend: str) -> bool:
 
 
 def use_triton_template(
-    layout: Layout, *, enable_int32: bool = False, enable_float8: bool = False
+    layout: Layout,
+    *,
+    enable_int32: bool = False,
+    enable_float8: bool = False,
+    check_max_autotune: bool = True,
 ) -> bool:
     from .codegen.common import BackendFeature, has_backend_feature
 
@@ -1651,7 +1655,8 @@ def use_triton_template(
             )
             or (layout.device.type == "cpu" and layout.dtype in layout_dtypes)
         )
-        and (config.max_autotune or config.max_autotune_gemm)
+        # some callers handle max-autotune checking externally
+        and (config.max_autotune or config.max_autotune_gemm or not check_max_autotune)
         and _use_autotune_backend("TRITON")
         and has_backend_feature(layout.device, BackendFeature.TRITON_TEMPLATES)
     )
```