mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] move max-autotune logic inside V.choices.get_mm_configs (#161344)
# why

- heuristic providers now decide whether to add choices (and which ones) in the max-autotune case
- enables an eventual override point to gracefully fall back to the standard behavior

# what

- max-autotune is determined inside V.choices.get_mm_configs
- because this is mm-only right now, we can just check `config.max_autotune or config.max_autotune_gemm`; a TODO indicates that this can change in the future when this expands to more templates

# testing

```
python3 -bb -m pytest test/inductor/test_max_autotune.py -v
```

Differential Revision: [D81520573](https://our.internmc.facebook.com/intern/diff/D81520573)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161344
Approved by: https://github.com/jansel
ghstack dependencies: #162075, #161340, #161341, #161342, #161343
Committed by: PyTorch MergeBot
Parent: a301dc3b60
Commit: 031d79cb51
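To make the intent of the diff below easier to follow, here is a minimal, self-contained sketch of the pattern this stack lands. The class and flag names are simplified, hypothetical stand-ins (not the actual Inductor classes, which read `torch._inductor.config` directly): the heuristic's `should_run` hook, rather than each call site, decides whether GEMM template configs are produced under max-autotune.

```
from typing import Any, Generator


class TemplateConfigHeuristics:
    """Simplified stand-in for the Inductor base class touched below."""

    def should_run(self) -> bool:
        # Hook point: return False to suppress all configs for this template.
        return True

    def get_template_configs(self) -> Generator[dict[str, Any], None, None]:
        if not self.should_run():
            return
        yield from self._get_template_configs_impl()

    def _get_template_configs_impl(self) -> Generator[dict[str, Any], None, None]:
        yield from []  # base implementation emits nothing


class GemmMaxAutotuneHeuristics(TemplateConfigHeuristics):
    """GEMM templates only participate when max-autotune is requested."""

    def __init__(self, max_autotune: bool, max_autotune_gemm: bool) -> None:
        self.max_autotune = max_autotune
        self.max_autotune_gemm = max_autotune_gemm

    def should_run(self) -> bool:
        return self.max_autotune or self.max_autotune_gemm

    def _get_template_configs_impl(self) -> Generator[dict[str, Any], None, None]:
        yield {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}  # placeholder config


# Without max-autotune the heuristic yields nothing; with it, configs flow through.
assert list(GemmMaxAutotuneHeuristics(False, False).get_template_configs()) == []
assert list(GemmMaxAutotuneHeuristics(True, False).get_template_configs()) != []
```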
@@ -31,8 +31,11 @@ from torch._inductor.ir import Buffer, ChoiceCaller, FixedLayout
from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm
from torch._inductor.select_algorithm import (
    add_feedback_saver,
    add_preprocessing_fn,
    AlgorithmSelectorCache,
    clear_feedback_savers,
    clear_preprocessing_fns,
    ExternKernelCaller,
    TritonTemplate,
    TritonTemplateCaller,
)

@@ -1900,6 +1903,76 @@ class TestMaxAutotune(TestCase):
            counters["inductor"]["select_algorithm_num_precompilation_exceptions"], 0
        )

    @parametrize("op", ("mm", "addmm", "bmm", "baddbmm", "mm_plus_mm"))
    @parametrize("max_autotune", (False, True))
    @config.patch(
        {"test_configs.max_mm_configs": 4, "max_autotune_gemm_backends": "ATEN,TRITON"}
    )
    def test_autotune_gemm_choice_validation(self, op, max_autotune):
        def generate_inputs_and_func(op_name):
            # Base config with just x and w
            base_inputs = [
                torch.randn(128, 256, device=GPU_TYPE),
                torch.randn(256, 128, device=GPU_TYPE),
            ]
            func = torch.mm
            if op_name == "mm":
                # default
                pass
            elif op_name == "addmm":
                # Add bias for addmm
                base_inputs = [torch.randn(128, device=GPU_TYPE)] + base_inputs
                func = torch.addmm
            elif op_name in ["bmm", "baddbmm"]:
                # Override for batch dimensions
                base_inputs[0] = torch.randn(4, 128, 256, device=GPU_TYPE)
                base_inputs[1] = torch.randn(4, 256, 128, device=GPU_TYPE)
                func = torch.bmm
                if op_name == "baddbmm":
                    # Add batch bias
                    base_inputs = [
                        torch.randn(4, 128, 128, device=GPU_TYPE)
                    ] + base_inputs
                    func = torch.baddbmm
            elif op_name == "mm_plus_mm":
                # Add second matrix pair
                base_inputs += [
                    torch.randn(128, 256, device=GPU_TYPE),
                    torch.randn(256, 128, device=GPU_TYPE),
                ]

                def mmpmm(x, w, x2, w2):
                    return torch.mm(x, w) + torch.mm(x2, w2)

                func = mmpmm
            else:
                raise ValueError(f"Unsupported op: {op_name}")
            return base_inputs, func

        choice_types_seen = set()

        def choice_validator(choices):
            for choice in choices:
                choice_types_seen.add(type(choice))
            return choices

        inputs, fn = generate_inputs_and_func(op)

        add_preprocessing_fn(choice_validator)
        try:
            with config.patch({"max_autotune": max_autotune}):
                compiled_fn = torch.compile(fn, dynamic=False)
                compiled_fn(*inputs)

                if max_autotune:
                    self.assertIn(ExternKernelCaller, choice_types_seen)
                    self.assertIn(TritonTemplateCaller, choice_types_seen)
                else:
                    self.assertIn(ExternKernelCaller, choice_types_seen)
                    self.assertNotIn(TritonTemplateCaller, choice_types_seen)
        finally:
            clear_preprocessing_fns()


class TestMaxAutotunePrecompile(TestCase):
    def test_precompilation_threads(self):
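For context on the hook the new test uses, a hedged usage sketch follows. It assumes a CUDA-capable build and relies only on the behavior exercised by the test above: the callback passed to add_preprocessing_fn receives the candidate choice list and must return it.

```
import torch
from torch._inductor import config
from torch._inductor.select_algorithm import (
    add_preprocessing_fn,
    clear_preprocessing_fns,
)

seen_choice_types = set()

def record_choice_types(choices):
    # Record which kinds of autotune choices were considered, then pass them through.
    for choice in choices:
        seen_choice_types.add(type(choice).__name__)
    return choices

add_preprocessing_fn(record_choice_types)
try:
    with config.patch({"max_autotune": True}):
        a = torch.randn(128, 256, device="cuda")
        b = torch.randn(256, 128, device="cuda")
        torch.compile(torch.mm, dynamic=False)(a, b)
finally:
    clear_preprocessing_fns()

# Expect the ATen extern kernel and, under max-autotune, Triton template callers.
print(sorted(seen_choice_types))
```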
@@ -207,7 +207,7 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
            **extra_kwargs,
        )

-    if use_triton_template(layout):
+    if use_triton_template(layout, check_max_autotune=False):
        # TODO: add out_dtype support for Triton Template
        assert out_dtype is None, "out_dtype is not supported for Triton"

@@ -288,7 +288,7 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
            **extra_kwargs,
        )

-    if use_triton_template(layout):
+    if use_triton_template(layout, check_max_autotune=False):
        for kwargs, extra_kwargs in V.choices.get_mm_configs(
            kernel_inputs,
            layout,

@@ -762,7 +762,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
        )
    static_shape, is_nonzero = _is_static_problem(layout)

-    if is_nonzero and use_triton_template(layout):
+    if is_nonzero and use_triton_template(layout, check_max_autotune=False):
        # Get template params using the new unified function
        for kwargs, extra_kwargs in V.choices.get_mm_configs(
            kernel_inputs, layout, mm_template, "mm"

@@ -941,7 +941,9 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
            choices, layout, kernel_inputs.nodes(), fuseable=True, non_fuseable=True
        )

-    if is_nonzero and use_triton_template(layout, enable_int32=True):
+    if is_nonzero and use_triton_template(
+        layout, enable_int32=True, check_max_autotune=False
+    ):
        for kwargs, extra_kwargs in V.choices.get_mm_configs(
            kernel_inputs, layout, mm_template, name
        ):

@@ -1035,7 +1037,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
            **extra_kwargs,
        )

-    if is_nonzero and use_triton_template(layout):
+    if is_nonzero and use_triton_template(layout, check_max_autotune=False):
        # all the triton templates use the extra_kwargs
        # Get template params using the new unified function
        for kwargs, extra_kwargs in V.choices.get_mm_configs(

@@ -1248,7 +1250,9 @@ def tuned_scaled_mm(

    _, is_nonzero = _is_static_problem(layout)

-    if is_nonzero and use_triton_template(layout, enable_float8=True):
+    if is_nonzero and use_triton_template(
+        layout, enable_float8=True, check_max_autotune=False
+    ):
        overriders = dict(USE_FAST_ACCUM=use_fast_accum)
        # TODO (paulzhan): There is no template that exists for bias and TMA
        # Don't run tma template currently if bias exists

@@ -165,7 +165,7 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
            **extra_kwargs,
        )

-    if use_triton_template(layout1):
+    if use_triton_template(layout1, check_max_autotune=False):
        # Get template params using the new unified function
        for kwargs, extra_kwargs in V.choices.get_mm_configs(
            kernel_inputs, layout1, mm_plus_mm_template, "mm_plus_mm"
@@ -8,6 +8,7 @@ from ..kernel.bmm import aten_baddbmm, aten_bmm, aten_bmm_dtype
from ..kernel.mm import aten__fp8_mm, aten__int_mm, aten_addmm, aten_bias_addmm, aten_mm
from ..kernel.mm_plus_mm import aten_mm_plus_mm
from .base import TemplateConfigHeuristics
+from .gemm import GemmMaxAutotuneTemplateConfigHeuristics


if TYPE_CHECKING:

@@ -37,7 +38,7 @@ class ATenConfigHeuristics(TemplateConfigHeuristics):
    If you want to use this with an ATen choice that has kwargs, just subclass
    """

-    def get_template_configs(
+    def _get_template_configs_impl(
        self,
        kernel_inputs: KernelInputs,
        layout: Layout,

@@ -68,17 +69,15 @@ class ATenAddMMConfigHeuristics(ATenConfigHeuristics):


@register_template_heuristic(aten_bias_addmm.uid, None, op_name="addmm")
-class ATenBiasAddMMConfigHeuristics(ATenAddMMConfigHeuristics):
-    def get_template_configs(
+class ATenBiasAddMMConfigHeuristics(
+    ATenAddMMConfigHeuristics, GemmMaxAutotuneTemplateConfigHeuristics
+):
+    def _get_template_configs_impl(
        self,
        kernel_inputs: KernelInputs,
        layout: Layout,
        op_name: str,
    ) -> Generator[dict[str, Any], None, None]:
-        if not (inductor_config.max_autotune or inductor_config.max_autotune_gemm):
-            # NOTE: this preserves the original logic that if there is not max-autotune
-            # then we skip bias_addmm
-            return
        nodes = kernel_inputs.nodes()
        # for addmm, bias is the first input
        bias = nodes[0]
@@ -11,6 +11,20 @@ if TYPE_CHECKING:


class TemplateConfigHeuristics:
    """Base class for generating sets of configs for an associated template."""

    def should_run(self, inputs: KernelInputs, layout: Layout) -> bool:
        """
        hookup to check whether the configs are right to run at all e.g. you can check
        max-autotune specific to your heuristic here or other things
        If this returns False, get_template_configs will yield no configs

        Args:
            inputs: KernelInputs
            layout: Layout
        """
        return True

    def get_template_configs(
        self,
        kernel_inputs: KernelInputs,

@@ -19,10 +33,30 @@ class TemplateConfigHeuristics:
    ) -> Generator[dict[str, Any], None, None]:
        """
        Get template configs for the given inputs.

        Prefer to override the _get_template_configs_impl method
        to leverage things like should_run
        """
        if not self.should_run(kernel_inputs, layout):
            return

        yield from self._get_template_configs_impl(
            kernel_inputs,
            layout,
            op_name,
        )

    def _get_template_configs_impl(
        self,
        kernel_inputs: KernelInputs,
        layout: Layout,
        op_name: str,
    ) -> Generator[dict[str, Any], None, None]:
        """
        Get template configs for the given inputs.
        This is the main entry point for template-specific logic.
        """
        # NOTE: not an abstract class, because that clashed below for the mixin
        # functionality. Can be adjusted, but not a high priority
        # base implementation yields no entries
        yield from []

    def get_extra_kwargs(
@@ -12,6 +12,7 @@ from ..kernel.mm import (
from ..kernel_inputs import KernelInputs, MMKernelInputs
from ..utils import use_contiguous
from .base import TemplateConfigHeuristics
+from .gemm import GemmMaxAutotuneTemplateConfigHeuristics
from .registry import register_template_heuristic


@@ -41,8 +42,8 @@ class EmptyContiguousMMConfigHeuristics(TemplateConfigHeuristics):
    register=torch.version.hip is not None,
    op_name="addmm",
)
-class ContiguousMMHeuristics(TemplateConfigHeuristics):
-    def get_template_configs(
+class ContiguousMMHeuristics(GemmMaxAutotuneTemplateConfigHeuristics):
+    def _get_template_configs_impl(
        self,
        kernel_inputs: KernelInputs,
        layout: Layout,

@@ -54,7 +55,6 @@ class ContiguousMMHeuristics(TemplateConfigHeuristics):
        assert isinstance(kernel_inputs, MMKernelInputs), (
            f"{self.__class__.__name__} requires MMKernelInputs"
        )

        # Check for unbacked symbols - if found, yield nothing
        unbacked_symbols = any(
            len(get_free_symbols(itr, unbacked_only=True)) > 0
@@ -12,6 +12,7 @@ from ..kernel_inputs import KernelInputs, MMKernelInputs
from ..utils import get_k_splits
from ..virtualized import V
from .base import TemplateConfigHeuristics
+from .gemm import GemmMaxAutotuneTemplateConfigHeuristics
from .registry import register_template_heuristic


@@ -38,8 +39,8 @@ class EmptyDecomposeKConfigHeuristics(TemplateConfigHeuristics):
# TODO(coconutruben): enable decompose k on other devices (xpu, cpu, mps, mtia)
# by either adding specific register_template_heuristic tags, or setting the
# device to None (enabled on all devices)
-class DecomposeKConfigHeuristics(TemplateConfigHeuristics):
-    def get_template_configs(
+class DecomposeKConfigHeuristics(GemmMaxAutotuneTemplateConfigHeuristics):
+    def _get_template_configs_impl(
        self,
        kernel_inputs: KernelInputs,
        layout: Layout,
torch/_inductor/template_heuristics/gemm.py (new file, 19 lines)

@@ -0,0 +1,19 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from .. import config as inductor_config
from .base import TemplateConfigHeuristics


if TYPE_CHECKING:
    from ..ir import Layout
    from ..kernel_inputs import KernelInputs


class GemmMaxAutotuneTemplateConfigHeuristics(TemplateConfigHeuristics):
    def should_run(self, inputs: KernelInputs, layout: Layout) -> bool:
        """
        simple base override for GEMM family templates that run only in max-autotune
        """
        return inductor_config.max_autotune or inductor_config.max_autotune_gemm
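A small usage sketch of the new class, assuming the module path torch/_inductor/template_heuristics/gemm.py added above. Since should_run only consults the two config flags, the inputs/layout arguments are passed as None here purely for illustration.

```
from torch._inductor import config
from torch._inductor.template_heuristics.gemm import (
    GemmMaxAutotuneTemplateConfigHeuristics,
)

heuristic = GemmMaxAutotuneTemplateConfigHeuristics()

# Without max-autotune the heuristic opts out, so no template configs are emitted.
with config.patch({"max_autotune": False, "max_autotune_gemm": False}):
    assert not heuristic.should_run(None, None)

# Either max_autotune or max_autotune_gemm re-enables GEMM template configs.
with config.patch({"max_autotune": True}):
    assert heuristic.should_run(None, None)
```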
@@ -32,7 +32,7 @@ from ..utils import (
    using_b200,
)
from ..virtualized import V
-from .base import TemplateConfigHeuristics
+from .gemm import GemmMaxAutotuneTemplateConfigHeuristics
from .registry import register_template_heuristic


@@ -1417,7 +1417,7 @@ class MTIAConfigHeuristic(BaseConfigHeuristic):


# Template-specific mixin classes
-class MMTemplateConfigMixin(TemplateConfigHeuristics):
+class MMTemplateConfigMixin(GemmMaxAutotuneTemplateConfigHeuristics):
    """
    Mixin class that converts config lists to template kwargs.
    This handles the logic that was previously in choices.get_mm_configs.

@@ -1448,7 +1448,7 @@ class MMTemplateConfigMixin(TemplateConfigHeuristics):
        else:
            return self.get_mm_configs()

-    def get_template_configs(
+    def _get_template_configs_impl(
        self,
        kernel_inputs: KernelInputs,
        layout: Any,

@@ -1466,6 +1466,7 @@ class MMTemplateConfigMixin(TemplateConfigHeuristics):
            raise ValueError(f"Need at least 2 input tensors, got {len(input_nodes)}")
        if not self._valid(kernel_inputs):
            return

        # Extract M, N, K from kernel_inputs
        m, n, k = kernel_inputs.mnk_symbolic()

@@ -1593,7 +1594,7 @@ class TMATemplateConfigMixin(TMAWorkspaceMixin, MMTemplateConfigMixin):
    This inherits from MMTemplateConfigMixin and overrides config generation.
    """

-    def get_template_configs(
+    def _get_template_configs_impl(
        self,
        kernel_inputs: KernelInputs,
        layout: Any,

@@ -1614,8 +1615,10 @@ class TMATemplateConfigMixin(TMAWorkspaceMixin, MMTemplateConfigMixin):
            "TMA_EXPERIMENTAL_API": not has_triton_stable_tma_api(),
        }
        # Get base template configs from superclass
-        for template_kwargs in super().get_template_configs(
-            kernel_inputs, layout, op_name
+        for template_kwargs in super()._get_template_configs_impl(
+            kernel_inputs,
+            layout,
+            op_name,
        ):
            yield {**template_kwargs, **tma_opts}

@@ -1661,7 +1664,7 @@ class BaseScaledMMConfigMixin(MMTemplateConfigMixin):
            nodes, mat1_idx=kernel_inputs._mat1_idx, mat2_idx=kernel_inputs._mat2_idx
        )

-    def get_template_configs(
+    def _get_template_configs_impl(
        self,
        kernel_inputs: KernelInputs,
        layout: Any,

@@ -1713,7 +1716,7 @@ class BaseScaledMMConfigMixin(MMTemplateConfigMixin):
            return

        # Get base template configs from superclass
-        for template_kwargs in super().get_template_configs(
+        for template_kwargs in super()._get_template_configs_impl(
            kernel_inputs, layout, op_name
        ):
            # Add scaled MM-specific options (moved from mm_common.scaled_mm_options)

@@ -1769,7 +1772,7 @@ class ScaledTMAConfigMixin(TMAWorkspaceMixin, BaseScaledMMConfigMixin):
    This inherits from BaseScaledMMConfigMixin and adds TMA-specific options.
    """

-    def get_template_configs(
+    def _get_template_configs_impl(
        self,
        kernel_inputs: KernelInputs,
        layout: Any,

@@ -1779,8 +1782,10 @@ class ScaledTMAConfigMixin(TMAWorkspaceMixin, BaseScaledMMConfigMixin):
        Generate scaled TMA template configs with both scaled MM and TMA-specific options.
        """
        # Get base scaled MM template configs from superclass
-        for template_kwargs in super().get_template_configs(
-            kernel_inputs, layout, op_name
+        for template_kwargs in super()._get_template_configs_impl(
+            kernel_inputs,
+            layout,
+            op_name,
        ):
            # Add TMA-specific options for device TMA scaled MM
            template_kwargs["TMA_SIZE"] = TMA_DESCRIPTOR_SIZE
@@ -1634,7 +1634,11 @@ def _use_conv_autotune_backend(backend: str) -> bool:


def use_triton_template(
-    layout: Layout, *, enable_int32: bool = False, enable_float8: bool = False
+    layout: Layout,
+    *,
+    enable_int32: bool = False,
+    enable_float8: bool = False,
+    check_max_autotune: bool = True,
) -> bool:
    from .codegen.common import BackendFeature, has_backend_feature

@@ -1651,7 +1655,8 @@ def use_triton_template(
            )
            or (layout.device.type == "cpu" and layout.dtype in layout_dtypes)
        )
-        and (config.max_autotune or config.max_autotune_gemm)
+        # some callers handle max-autotune checking externally
+        and (config.max_autotune or config.max_autotune_gemm or not check_max_autotune)
        and _use_autotune_backend("TRITON")
        and has_backend_feature(layout.device, BackendFeature.TRITON_TEMPLATES)
    )
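To spell out the amended condition, a hedged truth-table sketch follows. The helper name triton_template_autotune_gate is a hypothetical stand-in for the boolean expression above, not an Inductor function.

```
# Hypothetical helper reproducing the amended condition, to make the truth table explicit.
def triton_template_autotune_gate(
    max_autotune: bool, max_autotune_gemm: bool, check_max_autotune: bool
) -> bool:
    return max_autotune or max_autotune_gemm or not check_max_autotune

# Old behavior is preserved for callers that keep the default check_max_autotune=True.
assert triton_template_autotune_gate(False, False, check_max_autotune=True) is False
# The GEMM lowerings now pass check_max_autotune=False and defer the decision to the
# template heuristics (should_run), so this gate no longer blocks them.
assert triton_template_autotune_gate(False, False, check_max_autotune=False) is True
assert triton_template_autotune_gate(True, False, check_max_autotune=True) is True
```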