[inductor][template heuristics] don't take layout to generate choices (#162238)

# why

- unnecessary as we only ever need to know the dtype and maybe the
  device
- we already take in the kernel inputs which have the device
- enable us to specify the layout after finding all the configs
  but before generating the ChoiceCallers

# what

- replace all calls in template_heuristics that used to take Layout
  with now just taking out_dtype

# testing

ci

Differential Revision: [D81820115](https://our.internmc.facebook.com/intern/diff/D81820115)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162238
Approved by: https://github.com/eellison
ghstack dependencies: #161347, #161348, #161349
This commit is contained in:
Ruben Rodriguez Buchillon
2025-09-08 17:11:12 -07:00
committed by PyTorch MergeBot
parent 24a4dae85b
commit d91eecc9a5
12 changed files with 104 additions and 93 deletions

View File

@ -163,10 +163,9 @@ class InductorChoices:
heuristic = get_template_heuristic(template_name, device_type, op_name) heuristic = get_template_heuristic(template_name, device_type, op_name)
cs = heuristic.get_template_configs( cs = heuristic.get_template_configs(
kernel_inputs, kernel_inputs,
layout,
op_name, op_name,
) )
extra_kwargs = heuristic.get_extra_kwargs(kernel_inputs, layout, op_name) extra_kwargs = heuristic.get_extra_kwargs(kernel_inputs, op_name)
# adjust the kernel inputs to the template-specific heuristic, if needed # adjust the kernel inputs to the template-specific heuristic, if needed
# default here is to just return the kernel_inputs as is # default here is to just return the kernel_inputs as is
inputs_val = heuristic.adjust_kernel_inputs(kernel_inputs, op_name) inputs_val = heuristic.adjust_kernel_inputs(kernel_inputs, op_name)
@ -184,9 +183,9 @@ class InductorChoices:
def get_mm_configs( def get_mm_configs(
self, self,
kernel_inputs: KernelInputs, kernel_inputs: KernelInputs,
layout: Any,
templates: list[Union[KernelTemplate, ExternKernelChoice]], templates: list[Union[KernelTemplate, ExternKernelChoice]],
op_name: str, op_name: str,
layout: Optional[Layout] = None,
kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None, kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None,
) -> list[ChoiceCaller]: ) -> list[ChoiceCaller]:
""" """
@ -207,7 +206,11 @@ class InductorChoices:
input_tensors = kernel_inputs.nodes() input_tensors = kernel_inputs.nodes()
if len(input_tensors) < 2: if len(input_tensors) < 2:
raise ValueError(f"Need at least 2 input tensors, got {len(input_tensors)}") raise ValueError(f"Need at least 2 input tensors, got {len(input_tensors)}")
if layout is None:
# TODO(coconutruben): remove this once we remove the layout argument entirely
# This is just here to the brief gap between commits where we still need this
# to accommodate fixed vs flexible layout decision externally
layout = kernel_inputs.output_layout(flexible=False)
# First pass: Create dict of template.uid to generator of KernelTemplateChoice objects # First pass: Create dict of template.uid to generator of KernelTemplateChoice objects
template_choices = {} template_choices = {}
for template in templates: for template in templates:

View File

@ -173,7 +173,7 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
name = "bmm" name = "bmm"
# Create MMKernelInputs for BMM at the top # Create MMKernelInputs for BMM at the top
kernel_inputs = MMKernelInputs([mat1, mat2]) kernel_inputs = MMKernelInputs([mat1, mat2], out_dtype=out_dtype)
# below is for getting an overview logging info of inductor mms # below is for getting an overview logging info of inductor mms
batch_size = mat1.get_size()[0] # Extract batch dimension batch_size = mat1.get_size()[0] # Extract batch dimension
@ -201,10 +201,9 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
choices.extend( choices.extend(
V.choices.get_mm_configs( V.choices.get_mm_configs(
kernel_inputs, kernel_inputs,
layout,
[aten_handler], [aten_handler],
name, name,
{aten_handler.uid: aten_extra_kwargs}, kwarg_overrides={aten_handler.uid: aten_extra_kwargs},
) )
) )
@ -212,9 +211,7 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
# TODO: add out_dtype support for Triton Template # TODO: add out_dtype support for Triton Template
assert out_dtype is None, "out_dtype is not supported for Triton" assert out_dtype is None, "out_dtype is not supported for Triton"
choices.extend( choices.extend(V.choices.get_mm_configs(kernel_inputs, [bmm_template], name))
V.choices.get_mm_configs(kernel_inputs, layout, [bmm_template], name)
)
_, is_nonzero = _is_static_problem(layout) _, is_nonzero = _is_static_problem(layout)
batch_stride_largest_or_zero = is_batch_stride_largest_or_zero(mat1, mat2, layout) batch_stride_largest_or_zero = is_batch_stride_largest_or_zero(mat1, mat2, layout)
if ( if (
@ -275,15 +272,12 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
# options to tune from # options to tune from
choices: list[ChoiceCaller] = [] choices: list[ChoiceCaller] = []
if use_aten_gemm_kernels(): if use_aten_gemm_kernels():
choices.extend( choices.extend(V.choices.get_mm_configs(kernel_inputs, [aten_baddbmm], name))
V.choices.get_mm_configs(kernel_inputs, layout, [aten_baddbmm], name)
)
if use_triton_template(layout, check_max_autotune=False): if use_triton_template(layout, check_max_autotune=False):
choices.extend( choices.extend(
V.choices.get_mm_configs( V.choices.get_mm_configs(
kernel_inputs, kernel_inputs,
layout,
[bmm_template], [bmm_template],
name, name,
) )

View File

@ -753,32 +753,30 @@ def tuned_mm(mat1, mat2, *, layout=None):
choices: list[ChoiceCaller] = [] choices: list[ChoiceCaller] = []
if use_aten_gemm_kernels(): if use_aten_gemm_kernels():
choices.extend( choices.extend(
V.choices.get_mm_configs(kernel_inputs, aten_layout, [aten_mm], "mm") V.choices.get_mm_configs(kernel_inputs, [aten_mm], "mm", aten_layout)
) )
static_shape, is_nonzero = _is_static_problem(layout) static_shape, is_nonzero = _is_static_problem(layout)
if is_nonzero and use_triton_template(layout, check_max_autotune=False): if is_nonzero and use_triton_template(layout, check_max_autotune=False):
# Get template choices using the new unified function # Get template choices using the new unified function
choices.extend( choices.extend(V.choices.get_mm_configs(kernel_inputs, [mm_template], "mm"))
V.choices.get_mm_configs(kernel_inputs, layout, [mm_template], "mm")
)
if use_triton_tma_template(mat1, mat2): if use_triton_tma_template(mat1, mat2):
# Get TMA template choices using the new unified function # Get TMA template choices using the new unified function
choices.extend( choices.extend(
V.choices.get_mm_configs( V.choices.get_mm_configs(
kernel_inputs, layout, [persistent_tma_mm_template], "mm" kernel_inputs, [persistent_tma_mm_template], "mm"
) )
) )
if use_decompose_k_choice(m, n, k): if use_decompose_k_choice(m, n, k):
choices.extend( choices.extend(
V.choices.get_mm_configs( V.choices.get_mm_configs(
kernel_inputs, layout, [decompose_k_subgraph_template], "mm" kernel_inputs, [decompose_k_subgraph_template], "mm"
) )
) )
choices.extend( choices.extend(
V.choices.get_mm_configs( V.choices.get_mm_configs(
kernel_inputs, layout, [mm_contiguous_subgraph_template], "mm" kernel_inputs, [mm_contiguous_subgraph_template], "mm"
) )
) )
@ -820,7 +818,6 @@ def tuned_mm(mat1, mat2, *, layout=None):
# mm-extra is a hack to keep the ah functionality alive # mm-extra is a hack to keep the ah functionality alive
# while we transition to the unified kwargs retrieval # while we transition to the unified kwargs retrieval
kernel_inputs, kernel_inputs,
layout,
[mm_template], [mm_template],
"mm-ah", "mm-ah",
) )
@ -896,12 +893,11 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
choices: list[ChoiceCaller] = [] choices: list[ChoiceCaller] = []
# Create MMKernelInputs for Int MM # Create MMKernelInputs for Int MM
kernel_inputs = MMKernelInputs([mat1, mat2]) kernel_inputs = MMKernelInputs([mat1, mat2], out_dtype=torch.int32)
if use_aten_gemm_kernels(): if use_aten_gemm_kernels():
choices.extend( choices.extend(
V.choices.get_mm_configs( V.choices.get_mm_configs(
kernel_inputs, kernel_inputs,
layout,
[aten__int_mm], [aten__int_mm],
name, name,
) )
@ -915,9 +911,7 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
if is_nonzero and use_triton_template( if is_nonzero and use_triton_template(
layout, enable_int32=True, check_max_autotune=False layout, enable_int32=True, check_max_autotune=False
): ):
choices.extend( choices.extend(V.choices.get_mm_configs(kernel_inputs, [mm_template], name))
V.choices.get_mm_configs(kernel_inputs, layout, [mm_template], name)
)
return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout) return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
@ -969,9 +963,9 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
choices.extend( choices.extend(
V.choices.get_mm_configs( V.choices.get_mm_configs(
kernel_inputs, kernel_inputs,
aten_layout,
[aten_addmm], [aten_addmm],
name, name,
aten_layout,
) )
) )
return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout) return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
@ -980,7 +974,6 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
choices.extend( choices.extend(
V.choices.get_mm_configs( V.choices.get_mm_configs(
kernel_inputs, kernel_inputs,
aten_layout,
[aten_bias_addmm], [aten_bias_addmm],
name, name,
) )
@ -988,7 +981,6 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
choices.extend( choices.extend(
V.choices.get_mm_configs( V.choices.get_mm_configs(
kernel_inputs, kernel_inputs,
aten_layout,
[aten_addmm], [aten_addmm],
name, name,
) )
@ -1000,7 +992,6 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
choices.extend( choices.extend(
V.choices.get_mm_configs( V.choices.get_mm_configs(
kernel_inputs, kernel_inputs,
layout,
[mm_template], [mm_template],
name, name,
) )
@ -1011,7 +1002,6 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
choices.extend( choices.extend(
V.choices.get_mm_configs( V.choices.get_mm_configs(
kernel_inputs, kernel_inputs,
layout,
[persistent_tma_mm_template], [persistent_tma_mm_template],
name, name,
) )
@ -1020,7 +1010,6 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
choices.extend( choices.extend(
V.choices.get_mm_configs( V.choices.get_mm_configs(
kernel_inputs, kernel_inputs,
layout,
[addmm_contiguous_subgraph_template], [addmm_contiguous_subgraph_template],
"addmm", "addmm",
) )
@ -1174,14 +1163,15 @@ def tuned_scaled_mm(
input_nodes = [mat_a, mat_b, scale_a_real, scale_b_real, bias_real] input_nodes = [mat_a, mat_b, scale_a_real, scale_b_real, bias_real]
# Create MMKernelInputs for Scaled MM (matrices are at indices 0, 1) # Create MMKernelInputs for Scaled MM (matrices are at indices 0, 1)
kernel_inputs = MMKernelInputs(input_nodes, mat1_idx=0, mat2_idx=1) kernel_inputs = MMKernelInputs(
input_nodes, mat1_idx=0, mat2_idx=1, out_dtype=out_dtype
)
choices: list[ChoiceCaller] = [] choices: list[ChoiceCaller] = []
if use_aten_gemm_kernels(): if use_aten_gemm_kernels():
choices.extend( choices.extend(
V.choices.get_mm_configs( V.choices.get_mm_configs(
kernel_inputs, kernel_inputs,
layout,
[aten__fp8_mm], [aten__fp8_mm],
name, name,
kwarg_overrides={ kwarg_overrides={
@ -1209,7 +1199,6 @@ def tuned_scaled_mm(
choices.extend( choices.extend(
V.choices.get_mm_configs( V.choices.get_mm_configs(
kernel_inputs, kernel_inputs,
layout,
[scaled_mm_device_tma_template], [scaled_mm_device_tma_template],
name, name,
kwarg_overrides={scaled_mm_device_tma_template.uid: overriders}, kwarg_overrides={scaled_mm_device_tma_template.uid: overriders},
@ -1220,7 +1209,6 @@ def tuned_scaled_mm(
choices.extend( choices.extend(
V.choices.get_mm_configs( V.choices.get_mm_configs(
kernel_inputs, kernel_inputs,
layout,
[mm_template], [mm_template],
name, name,
kwarg_overrides={mm_template.uid: overriders}, kwarg_overrides={mm_template.uid: overriders},

View File

@ -157,17 +157,13 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
choices: list[ChoiceCaller] = [] choices: list[ChoiceCaller] = []
if use_aten_gemm_kernels(): if use_aten_gemm_kernels():
choices.extend( choices.extend(
V.choices.get_mm_configs( V.choices.get_mm_configs(kernel_inputs, [aten_mm_plus_mm], "mm_plus_mm")
kernel_inputs, layout1, [aten_mm_plus_mm], "mm_plus_mm"
)
) )
if use_triton_template(layout1, check_max_autotune=False): if use_triton_template(layout1, check_max_autotune=False):
# Get template choices using the new unified function # Get template choices using the new unified function
choices.extend( choices.extend(
V.choices.get_mm_configs( V.choices.get_mm_configs(kernel_inputs, [mm_plus_mm_template], "mm_plus_mm")
kernel_inputs, layout1, [mm_plus_mm_template], "mm_plus_mm"
)
) )
return autotune_select_algorithm( return autotune_select_algorithm(

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Optional, TYPE_CHECKING, Union from typing import Any, Optional, TYPE_CHECKING, Union
import torch import torch
@ -7,6 +8,8 @@ import torch._inductor.config
from torch._inductor import ir from torch._inductor import ir
from torch._inductor.virtualized import V from torch._inductor.virtualized import V
from .ir import FixedLayout, FlexibleLayout, Layout
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Sequence from collections.abc import Sequence
@ -14,7 +17,7 @@ if TYPE_CHECKING:
import sympy import sympy
class KernelInputs: class KernelInputs(ABC):
""" """
Class to store and provide access to input nodes for kernels. Class to store and provide access to input nodes for kernels.
This class takes in a tuple of input nodes and provides methods to access This class takes in a tuple of input nodes and provides methods to access
@ -25,16 +28,19 @@ class KernelInputs:
self, self,
input_nodes: list[Any], input_nodes: list[Any],
scalars: Optional[dict[str, Union[float, int]]] = None, scalars: Optional[dict[str, Union[float, int]]] = None,
out_dtype: Optional[torch.dtype] = None,
): ):
""" """
Initialize with a tuple of input nodes. Initialize with a tuple of input nodes.
Args: Args:
input_nodes: A tuple of input nodes to store input_nodes: A tuple of input nodes to store
out_dtype: Optional output dtype to store
""" """
self._input_nodes = input_nodes self._input_nodes = input_nodes
self._device_name: Optional[str] = None self._device_name: Optional[str] = None
self._scalars = scalars if scalars is not None else {} self._scalars = scalars if scalars is not None else {}
self._out_dtype = out_dtype
assert len(input_nodes) > 0, "Expected at least one input node" assert len(input_nodes) > 0, "Expected at least one input node"
def nodes(self, reorder: Optional[Sequence[int]] = None) -> list[Any]: def nodes(self, reorder: Optional[Sequence[int]] = None) -> list[Any]:
@ -168,6 +174,15 @@ class KernelInputs:
""" """
return self._input_nodes[idx].get_dtype() return self._input_nodes[idx].get_dtype()
@abstractmethod
def out_dtype(self) -> torch.dtype:
"""
Get the output dtype, whether passed in or inferred from the nodes
Returns:
The output dtype
"""
def get_scalar(self, name: str) -> Union[float, int]: def get_scalar(self, name: str) -> Union[float, int]:
""" """
Get the scalar value for a given name. Get the scalar value for a given name.
@ -181,6 +196,16 @@ class KernelInputs:
assert name in self._scalars, f"Scalar {name} not found, but required" assert name in self._scalars, f"Scalar {name} not found, but required"
return self._scalars[name] return self._scalars[name]
@abstractmethod
def output_layout(self, flexible: bool = True) -> Layout:
"""
Abstract method to handle output layout generation.
Args:
out_dtype: Optional output dtype. If not provided, infer from inputs
flexible: If True, return FlexibleLayout, otherwise FixedLayout
"""
class MMKernelInputs(KernelInputs): class MMKernelInputs(KernelInputs):
""" """
@ -192,6 +217,7 @@ class MMKernelInputs(KernelInputs):
self, self,
input_nodes: list[Any], input_nodes: list[Any],
scalars: Optional[dict[str, Union[float, int]]] = None, scalars: Optional[dict[str, Union[float, int]]] = None,
out_dtype: Optional[torch.dtype] = None,
mat1_idx: int = -2, mat1_idx: int = -2,
mat2_idx: int = -1, mat2_idx: int = -1,
): ):
@ -201,7 +227,7 @@ class MMKernelInputs(KernelInputs):
By default, we assume the last 2 input nodes are mat1 and mat2, but By default, we assume the last 2 input nodes are mat1 and mat2, but
the caller can adjust when necessary the caller can adjust when necessary
""" """
super().__init__(input_nodes, scalars) super().__init__(input_nodes, scalars, out_dtype)
# for mm, we need at least 2 nodes, and we need to know which nodes # for mm, we need at least 2 nodes, and we need to know which nodes
# are the main matrixes e.g. addmm is (bias, mat1, mat2) whereas others # are the main matrixes e.g. addmm is (bias, mat1, mat2) whereas others
# might be (mat1, mat2, scale), etc. # might be (mat1, mat2, scale), etc.
@ -246,6 +272,37 @@ class MMKernelInputs(KernelInputs):
V.graph.sizevars.check_equals(k, k0) V.graph.sizevars.check_equals(k, k0)
return (m, n, k) return (m, n, k)
def out_dtype(self) -> torch.dtype:
"""
Get the output dtype, whether passed in or inferred from the nodes
Returns:
The output dtype
"""
if self._out_dtype is not None:
return self._out_dtype
return self.mat1mat2()[0].get_dtype()
def output_layout(self, flexible: bool = True) -> Layout:
"""
Handle output layout generation for matrix multiplication.
Args:
out_dtype: Optional output dtype. If not provided, infer from inputs
flexible: If True, return FlexibleLayout, otherwise FixedLayout
"""
mat1, mat2 = self.mat1mat2()
out_dtype = self.out_dtype()
# NOTE: taken from mm_common.mm_args
*b1, m, k1 = mat1.get_size()
*b2, k2, n = mat2.get_size()
b = [V.graph.sizevars.check_equals_and_simplify(a, b) for a, b in zip(b1, b2)]
size = [*b, m, n]
if flexible:
return FlexibleLayout(self.device(), out_dtype, size)
else:
return FixedLayout(self.device(), out_dtype, size)
def mat1mat2(self) -> tuple[Any, Any]: def mat1mat2(self) -> tuple[Any, Any]:
""" """
Get the mat1 and mat2 nodes. Get the mat1 and mat2 nodes.

View File

@ -9,16 +9,14 @@ from ..kernel.mm import aten__fp8_mm, aten__int_mm, aten_addmm, aten_bias_addmm,
from ..kernel.mm_plus_mm import aten_mm_plus_mm from ..kernel.mm_plus_mm import aten_mm_plus_mm
from .base import TemplateConfigHeuristics from .base import TemplateConfigHeuristics
from .gemm import GemmMaxAutotuneTemplateConfigHeuristics from .gemm import GemmMaxAutotuneTemplateConfigHeuristics
from .registry import register_template_heuristic
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Generator from collections.abc import Generator
from ..ir import Layout
from ..kernel_inputs import KernelInputs from ..kernel_inputs import KernelInputs
from .registry import register_template_heuristic
# These are all labeled as device type None to indicate that they # These are all labeled as device type None to indicate that they
# are valid for all device types # are valid for all device types
@ -41,7 +39,6 @@ class ATenConfigHeuristics(TemplateConfigHeuristics):
def _get_template_configs_impl( def _get_template_configs_impl(
self, self,
kernel_inputs: KernelInputs, kernel_inputs: KernelInputs,
layout: Layout,
op_name: str, op_name: str,
) -> Generator[dict[str, Any], None, None]: ) -> Generator[dict[str, Any], None, None]:
yield dict() yield dict()
@ -55,10 +52,9 @@ class ATenAddMMConfigHeuristics(ATenConfigHeuristics):
def get_extra_kwargs( def get_extra_kwargs(
self, self,
kernel_inputs: KernelInputs, kernel_inputs: KernelInputs,
layout: Layout,
op_name: str, op_name: str,
) -> dict[str, Any]: ) -> dict[str, Any]:
kwargs = super().get_extra_kwargs(kernel_inputs, layout, op_name) kwargs = super().get_extra_kwargs(kernel_inputs, op_name)
alpha = kernel_inputs.get_scalar("alpha") alpha = kernel_inputs.get_scalar("alpha")
beta = kernel_inputs.get_scalar("beta") beta = kernel_inputs.get_scalar("beta")
return { return {
@ -75,7 +71,6 @@ class ATenBiasAddMMConfigHeuristics(
def _get_template_configs_impl( def _get_template_configs_impl(
self, self,
kernel_inputs: KernelInputs, kernel_inputs: KernelInputs,
layout: Layout,
op_name: str, op_name: str,
) -> Generator[dict[str, Any], None, None]: ) -> Generator[dict[str, Any], None, None]:
nodes = kernel_inputs.nodes() nodes = kernel_inputs.nodes()

View File

@ -6,14 +6,13 @@ from typing import Any, TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Generator from collections.abc import Generator
from ..ir import Layout
from ..kernel_inputs import KernelInputs from ..kernel_inputs import KernelInputs
class TemplateConfigHeuristics: class TemplateConfigHeuristics:
"""Base class for generating sets of configs for an associated template.""" """Base class for generating sets of configs for an associated template."""
def should_run(self, inputs: KernelInputs, layout: Layout) -> bool: def should_run(self, inputs: KernelInputs) -> bool:
""" """
hookup to check whether the configs are right to run at all e.g. you can check hookup to check whether the configs are right to run at all e.g. you can check
max-autotune specific to your heuristic here or other things max-autotune specific to your heuristic here or other things
@ -21,14 +20,12 @@ class TemplateConfigHeuristics:
Args: Args:
inputs: KernelInputs inputs: KernelInputs
layout: Layout
""" """
return True return True
def get_template_configs( def get_template_configs(
self, self,
kernel_inputs: KernelInputs, kernel_inputs: KernelInputs,
layout: Layout,
op_name: str, op_name: str,
) -> Generator[dict[str, Any], None, None]: ) -> Generator[dict[str, Any], None, None]:
""" """
@ -37,19 +34,17 @@ class TemplateConfigHeuristics:
Prefer to override the _get_template_configs_impl method Prefer to override the _get_template_configs_impl method
to leverage things like should_run to leverage things like should_run
""" """
if not self.should_run(kernel_inputs, layout): if not self.should_run(kernel_inputs):
return return
yield from self._get_template_configs_impl( yield from self._get_template_configs_impl(
kernel_inputs, kernel_inputs,
layout,
op_name, op_name,
) )
def _get_template_configs_impl( def _get_template_configs_impl(
self, self,
kernel_inputs: KernelInputs, kernel_inputs: KernelInputs,
layout: Layout,
op_name: str, op_name: str,
) -> Generator[dict[str, Any], None, None]: ) -> Generator[dict[str, Any], None, None]:
""" """
@ -62,7 +57,6 @@ class TemplateConfigHeuristics:
def get_extra_kwargs( def get_extra_kwargs(
self, self,
kernel_inputs: KernelInputs, kernel_inputs: KernelInputs,
layout: Layout,
op_name: str, op_name: str,
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """

View File

@ -19,8 +19,6 @@ from .registry import register_template_heuristic
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Generator from collections.abc import Generator
from ..ir import Layout
@register_template_heuristic(mm_contiguous_subgraph_template.uid, None, op_name="mm") @register_template_heuristic(mm_contiguous_subgraph_template.uid, None, op_name="mm")
@register_template_heuristic( @register_template_heuristic(
@ -46,7 +44,6 @@ class ContiguousMMHeuristics(GemmMaxAutotuneTemplateConfigHeuristics):
def _get_template_configs_impl( def _get_template_configs_impl(
self, self,
kernel_inputs: KernelInputs, kernel_inputs: KernelInputs,
layout: Layout,
op_name: str, op_name: str,
) -> Generator[dict[str, Any], None, None]: ) -> Generator[dict[str, Any], None, None]:
""" """

View File

@ -19,8 +19,6 @@ from .registry import register_template_heuristic
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Generator from collections.abc import Generator
from ..ir import Layout
@register_template_heuristic(decompose_k_subgraph_template.uid, None, op_name="mm") @register_template_heuristic(decompose_k_subgraph_template.uid, None, op_name="mm")
class EmptyDecomposeKConfigHeuristics(TemplateConfigHeuristics): class EmptyDecomposeKConfigHeuristics(TemplateConfigHeuristics):
@ -43,7 +41,6 @@ class DecomposeKConfigHeuristics(GemmMaxAutotuneTemplateConfigHeuristics):
def _get_template_configs_impl( def _get_template_configs_impl(
self, self,
kernel_inputs: KernelInputs, kernel_inputs: KernelInputs,
layout: Layout,
op_name: str, op_name: str,
) -> Generator[dict[str, Any], None, None]: ) -> Generator[dict[str, Any], None, None]:
""" """

View File

@ -7,12 +7,11 @@ from .base import TemplateConfigHeuristics
if TYPE_CHECKING: if TYPE_CHECKING:
from ..ir import Layout
from ..kernel_inputs import KernelInputs from ..kernel_inputs import KernelInputs
class GemmMaxAutotuneTemplateConfigHeuristics(TemplateConfigHeuristics): class GemmMaxAutotuneTemplateConfigHeuristics(TemplateConfigHeuristics):
def should_run(self, inputs: KernelInputs, layout: Layout) -> bool: def should_run(self, inputs: KernelInputs) -> bool:
""" """
simple base override for GEMM family templates that run only in max-autotune simple base override for GEMM family templates that run only in max-autotune
""" """

View File

@ -41,8 +41,6 @@ if TYPE_CHECKING:
from triton import Config as TritonConfig from triton import Config as TritonConfig
from ..ir import Layout
# Gemm Configs # Gemm Configs
@dataclasses.dataclass @dataclasses.dataclass
@ -1451,7 +1449,6 @@ class MMTemplateConfigMixin(GemmMaxAutotuneTemplateConfigHeuristics):
def _get_template_configs_impl( def _get_template_configs_impl(
self, self,
kernel_inputs: KernelInputs, kernel_inputs: KernelInputs,
layout: Any,
op_name: str, op_name: str,
) -> Generator[dict[str, Any], None, None]: ) -> Generator[dict[str, Any], None, None]:
""" """
@ -1479,7 +1476,11 @@ class MMTemplateConfigMixin(GemmMaxAutotuneTemplateConfigHeuristics):
# Generate and process configs # Generate and process configs
for c in configs(m, n, k, dtype_size=dtype.itemsize, op_name=op_name): for c in configs(m, n, k, dtype_size=dtype.itemsize, op_name=op_name):
template_kwargs = self._convert_config_to_template_kwargs( template_kwargs = self._convert_config_to_template_kwargs(
c, m, n, k, layout c,
m,
n,
k,
kernel_inputs.out_dtype(),
) )
yield template_kwargs yield template_kwargs
@ -1489,7 +1490,7 @@ class MMTemplateConfigMixin(GemmMaxAutotuneTemplateConfigHeuristics):
m: sympy.Integer, m: sympy.Integer,
n: sympy.Integer, n: sympy.Integer,
k: sympy.Integer, k: sympy.Integer,
layout: Any, out_dtype: torch.dtype,
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
Convert triton config to template kwargs. Convert triton config to template kwargs.
@ -1513,7 +1514,7 @@ class MMTemplateConfigMixin(GemmMaxAutotuneTemplateConfigHeuristics):
EVEN_K=even_k_symbolic, EVEN_K=even_k_symbolic,
ALLOW_TF32=allow_tf32, ALLOW_TF32=allow_tf32,
USE_FAST_ACCUM=False, # Option for _scaled_mm USE_FAST_ACCUM=False, # Option for _scaled_mm
ACC_TYPE=self._get_acc_type(layout.dtype), ACC_TYPE=self._get_acc_type(out_dtype),
num_stages=triton_config.num_stages, num_stages=triton_config.num_stages,
num_warps=triton_config.num_warps, num_warps=triton_config.num_warps,
**triton_config.kwargs, **triton_config.kwargs,
@ -1562,14 +1563,11 @@ class MMPlusMMTemplateConfigMixin(MMTemplateConfigMixin):
def _get_template_configs_impl( def _get_template_configs_impl(
self, self,
kernel_inputs: KernelInputs, kernel_inputs: KernelInputs,
layout: Any,
op_name: str, op_name: str,
) -> Generator[dict[str, Any], None, None]: ) -> Generator[dict[str, Any], None, None]:
assert isinstance(kernel_inputs, MMKernelInputs), "Expect MMKernelInputs" assert isinstance(kernel_inputs, MMKernelInputs), "Expect MMKernelInputs"
m, n, k = kernel_inputs.mnk_symbolic() m, n, k = kernel_inputs.mnk_symbolic()
for kwargs in super()._get_template_configs_impl( for kwargs in super()._get_template_configs_impl(kernel_inputs, op_name):
kernel_inputs, layout, op_name
):
# Apply BLOCK_K constraint specific to mm_plus_mm # Apply BLOCK_K constraint specific to mm_plus_mm
# see https://github.com/triton-lang/triton/issues/1298 # see https://github.com/triton-lang/triton/issues/1298
# BLOCK_K = K causes llvm error # BLOCK_K = K causes llvm error
@ -1586,10 +1584,9 @@ class TMAWorkspaceMixin(MMTemplateConfigMixin):
def get_extra_kwargs( def get_extra_kwargs(
self, self,
kernel_inputs: KernelInputs, kernel_inputs: KernelInputs,
layout: Layout,
op_name: str, op_name: str,
) -> dict[str, Any]: ) -> dict[str, Any]:
kwargs = super().get_extra_kwargs(kernel_inputs, layout, op_name) kwargs = super().get_extra_kwargs(kernel_inputs, op_name)
kwargs["workspace_arg"] = get_tma_workspace_arg( kwargs["workspace_arg"] = get_tma_workspace_arg(
num_tma_descriptors=2, num_tma_descriptors=2,
device=kernel_inputs.device(), device=kernel_inputs.device(),
@ -1614,7 +1611,6 @@ class TMATemplateConfigMixin(TMAWorkspaceMixin, MMTemplateConfigMixin):
def _get_template_configs_impl( def _get_template_configs_impl(
self, self,
kernel_inputs: KernelInputs, kernel_inputs: KernelInputs,
layout: Any,
op_name: str, op_name: str,
) -> Generator[dict[str, Any], None, None]: ) -> Generator[dict[str, Any], None, None]:
""" """
@ -1634,7 +1630,6 @@ class TMATemplateConfigMixin(TMAWorkspaceMixin, MMTemplateConfigMixin):
# Get base template configs from superclass # Get base template configs from superclass
for template_kwargs in super()._get_template_configs_impl( for template_kwargs in super()._get_template_configs_impl(
kernel_inputs, kernel_inputs,
layout,
op_name, op_name,
): ):
yield {**template_kwargs, **tma_opts} yield {**template_kwargs, **tma_opts}
@ -1684,7 +1679,6 @@ class BaseScaledMMConfigMixin(MMTemplateConfigMixin):
def _get_template_configs_impl( def _get_template_configs_impl(
self, self,
kernel_inputs: KernelInputs, kernel_inputs: KernelInputs,
layout: Any,
op_name: str, op_name: str,
) -> Generator[dict[str, Any], None, None]: ) -> Generator[dict[str, Any], None, None]:
""" """
@ -1734,7 +1728,7 @@ class BaseScaledMMConfigMixin(MMTemplateConfigMixin):
# Get base template configs from superclass # Get base template configs from superclass
for template_kwargs in super()._get_template_configs_impl( for template_kwargs in super()._get_template_configs_impl(
kernel_inputs, layout, op_name kernel_inputs, op_name
): ):
# Add scaled MM-specific options (moved from mm_common.scaled_mm_options) # Add scaled MM-specific options (moved from mm_common.scaled_mm_options)
# Override accumulator type for scaled MM # Override accumulator type for scaled MM
@ -1752,10 +1746,9 @@ class ScaledMMConfigMixin(BaseScaledMMConfigMixin):
def get_extra_kwargs( def get_extra_kwargs(
self, self,
kernel_inputs: KernelInputs, kernel_inputs: KernelInputs,
layout: Layout,
op_name: str, op_name: str,
) -> dict[str, Any]: ) -> dict[str, Any]:
kwargs = super().get_extra_kwargs(kernel_inputs, layout, op_name) kwargs = super().get_extra_kwargs(kernel_inputs, op_name)
from ..kernel.mm_common import scale_mm_epilogue from ..kernel.mm_common import scale_mm_epilogue
return { return {
@ -1792,7 +1785,6 @@ class ScaledTMAConfigMixin(TMAWorkspaceMixin, BaseScaledMMConfigMixin):
def _get_template_configs_impl( def _get_template_configs_impl(
self, self,
kernel_inputs: KernelInputs, kernel_inputs: KernelInputs,
layout: Any,
op_name: str, op_name: str,
) -> Generator[dict[str, Any], None, None]: ) -> Generator[dict[str, Any], None, None]:
""" """
@ -1801,7 +1793,6 @@ class ScaledTMAConfigMixin(TMAWorkspaceMixin, BaseScaledMMConfigMixin):
# Get base scaled MM template configs from superclass # Get base scaled MM template configs from superclass
for template_kwargs in super()._get_template_configs_impl( for template_kwargs in super()._get_template_configs_impl(
kernel_inputs, kernel_inputs,
layout,
op_name, op_name,
): ):
# Add TMA-specific options for device TMA scaled MM # Add TMA-specific options for device TMA scaled MM

View File

@ -7,7 +7,6 @@ from .base import TemplateConfigHeuristics
if TYPE_CHECKING: if TYPE_CHECKING:
from ..ir import Layout
from ..kernel_inputs import KernelInputs from ..kernel_inputs import KernelInputs
@ -19,10 +18,9 @@ class AddMMConfigMixin(TemplateConfigHeuristics):
def get_extra_kwargs( def get_extra_kwargs(
self, self,
kernel_inputs: KernelInputs, kernel_inputs: KernelInputs,
layout: Layout,
op_name: str, op_name: str,
) -> dict[str, Any]: ) -> dict[str, Any]:
kwargs = super().get_extra_kwargs(kernel_inputs, layout, op_name) kwargs = super().get_extra_kwargs(kernel_inputs, op_name)
assert op_name in [ assert op_name in [
"addmm", "addmm",
"baddbmm", "baddbmm",
@ -31,7 +29,9 @@ class AddMMConfigMixin(TemplateConfigHeuristics):
beta = kernel_inputs.get_scalar("beta") beta = kernel_inputs.get_scalar("beta")
return { return {
**kwargs, **kwargs,
"epilogue_fn": addmm_epilogue(layout.dtype, alpha, beta), "epilogue_fn": addmm_epilogue(kernel_inputs.out_dtype(), alpha, beta),
"epilogue_fn_hash": str(["addmm_epilogue", layout.dtype, alpha, beta]), "epilogue_fn_hash": str(
["addmm_epilogue", kernel_inputs.out_dtype(), alpha, beta]
),
"prefix_args": 1, "prefix_args": 1,
} }