[inductor][template heuristics] don't take layout to generate choices (#162238)

# why

- unnecessary, as we only ever need to know the dtype and maybe the
  device
- we already take in the kernel inputs, which have the device
- enables us to specify the layout after finding all the configs
  but before generating the ChoiceCallers (see the sketch after this list)

# what

- replace all calls in template_heuristics that used to take a Layout
  so that they now take just out_dtype (condensed before/after below)
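
In condensed form, the heuristics-side change looks roughly like this (trimmed from the `AddMMConfigMixin` hunk below; the method bodies are abbreviated):

```python
# before: a Layout was threaded through, mostly just to read .dtype off it
def get_extra_kwargs(self, kernel_inputs, layout, op_name):
    alpha = kernel_inputs.get_scalar("alpha")
    beta = kernel_inputs.get_scalar("beta")
    return {"epilogue_fn": addmm_epilogue(layout.dtype, alpha, beta)}

# after: no Layout argument; the output dtype comes from the kernel inputs
def get_extra_kwargs(self, kernel_inputs, op_name):
    alpha = kernel_inputs.get_scalar("alpha")
    beta = kernel_inputs.get_scalar("beta")
    return {"epilogue_fn": addmm_epilogue(kernel_inputs.out_dtype(), alpha, beta)}
```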

# testing

ci

Differential Revision: [D81820115](https://our.internmc.facebook.com/intern/diff/D81820115)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162238
Approved by: https://github.com/eellison
ghstack dependencies: #161347, #161348, #161349
Ruben Rodriguez Buchillon authored on 2025-09-08 17:11:12 -07:00; committed by PyTorch MergeBot
parent 24a4dae85b, commit d91eecc9a5
12 changed files with 104 additions and 93 deletions

View File

@ -163,10 +163,9 @@ class InductorChoices:
heuristic = get_template_heuristic(template_name, device_type, op_name)
cs = heuristic.get_template_configs(
kernel_inputs,
layout,
op_name,
)
extra_kwargs = heuristic.get_extra_kwargs(kernel_inputs, layout, op_name)
extra_kwargs = heuristic.get_extra_kwargs(kernel_inputs, op_name)
# adjust the kernel inputs to the template-specific heuristic, if needed
# default here is to just return the kernel_inputs as is
inputs_val = heuristic.adjust_kernel_inputs(kernel_inputs, op_name)
@ -184,9 +183,9 @@ class InductorChoices:
def get_mm_configs(
self,
kernel_inputs: KernelInputs,
layout: Any,
templates: list[Union[KernelTemplate, ExternKernelChoice]],
op_name: str,
layout: Optional[Layout] = None,
kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None,
) -> list[ChoiceCaller]:
"""
@ -207,7 +206,11 @@ class InductorChoices:
input_tensors = kernel_inputs.nodes()
if len(input_tensors) < 2:
raise ValueError(f"Need at least 2 input tensors, got {len(input_tensors)}")
if layout is None:
# TODO(coconutruben): remove this once we remove the layout argument entirely
# This is just here to bridge the brief gap between commits where we still need this
# to accommodate the fixed vs flexible layout decision externally
layout = kernel_inputs.output_layout(flexible=False)
# First pass: Create dict of template.uid to generator of KernelTemplateChoice objects
template_choices = {}
for template in templates:

View File

@ -173,7 +173,7 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
name = "bmm"
# Create MMKernelInputs for BMM at the top
kernel_inputs = MMKernelInputs([mat1, mat2])
kernel_inputs = MMKernelInputs([mat1, mat2], out_dtype=out_dtype)
# below is for getting an overview logging info of inductor mms
batch_size = mat1.get_size()[0] # Extract batch dimension
@ -201,10 +201,9 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
choices.extend(
V.choices.get_mm_configs(
kernel_inputs,
layout,
[aten_handler],
name,
{aten_handler.uid: aten_extra_kwargs},
kwarg_overrides={aten_handler.uid: aten_extra_kwargs},
)
)
@ -212,9 +211,7 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
# TODO: add out_dtype support for Triton Template
assert out_dtype is None, "out_dtype is not supported for Triton"
choices.extend(
V.choices.get_mm_configs(kernel_inputs, layout, [bmm_template], name)
)
choices.extend(V.choices.get_mm_configs(kernel_inputs, [bmm_template], name))
_, is_nonzero = _is_static_problem(layout)
batch_stride_largest_or_zero = is_batch_stride_largest_or_zero(mat1, mat2, layout)
if (
@ -275,15 +272,12 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
# options to tune from
choices: list[ChoiceCaller] = []
if use_aten_gemm_kernels():
choices.extend(
V.choices.get_mm_configs(kernel_inputs, layout, [aten_baddbmm], name)
)
choices.extend(V.choices.get_mm_configs(kernel_inputs, [aten_baddbmm], name))
if use_triton_template(layout, check_max_autotune=False):
choices.extend(
V.choices.get_mm_configs(
kernel_inputs,
layout,
[bmm_template],
name,
)

View File

@ -753,32 +753,30 @@ def tuned_mm(mat1, mat2, *, layout=None):
choices: list[ChoiceCaller] = []
if use_aten_gemm_kernels():
choices.extend(
V.choices.get_mm_configs(kernel_inputs, aten_layout, [aten_mm], "mm")
V.choices.get_mm_configs(kernel_inputs, [aten_mm], "mm", aten_layout)
)
static_shape, is_nonzero = _is_static_problem(layout)
if is_nonzero and use_triton_template(layout, check_max_autotune=False):
# Get template choices using the new unified function
choices.extend(
V.choices.get_mm_configs(kernel_inputs, layout, [mm_template], "mm")
)
choices.extend(V.choices.get_mm_configs(kernel_inputs, [mm_template], "mm"))
if use_triton_tma_template(mat1, mat2):
# Get TMA template choices using the new unified function
choices.extend(
V.choices.get_mm_configs(
kernel_inputs, layout, [persistent_tma_mm_template], "mm"
kernel_inputs, [persistent_tma_mm_template], "mm"
)
)
if use_decompose_k_choice(m, n, k):
choices.extend(
V.choices.get_mm_configs(
kernel_inputs, layout, [decompose_k_subgraph_template], "mm"
kernel_inputs, [decompose_k_subgraph_template], "mm"
)
)
choices.extend(
V.choices.get_mm_configs(
kernel_inputs, layout, [mm_contiguous_subgraph_template], "mm"
kernel_inputs, [mm_contiguous_subgraph_template], "mm"
)
)
@ -820,7 +818,6 @@ def tuned_mm(mat1, mat2, *, layout=None):
# mm-extra is a hack to keep the ah functionality alive
# while we transition to the unified kwargs retrieval
kernel_inputs,
layout,
[mm_template],
"mm-ah",
)
@ -896,12 +893,11 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
choices: list[ChoiceCaller] = []
# Create MMKernelInputs for Int MM
kernel_inputs = MMKernelInputs([mat1, mat2])
kernel_inputs = MMKernelInputs([mat1, mat2], out_dtype=torch.int32)
if use_aten_gemm_kernels():
choices.extend(
V.choices.get_mm_configs(
kernel_inputs,
layout,
[aten__int_mm],
name,
)
@ -915,9 +911,7 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
if is_nonzero and use_triton_template(
layout, enable_int32=True, check_max_autotune=False
):
choices.extend(
V.choices.get_mm_configs(kernel_inputs, layout, [mm_template], name)
)
choices.extend(V.choices.get_mm_configs(kernel_inputs, [mm_template], name))
return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
@ -969,9 +963,9 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
choices.extend(
V.choices.get_mm_configs(
kernel_inputs,
aten_layout,
[aten_addmm],
name,
aten_layout,
)
)
return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
@ -980,7 +974,6 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
choices.extend(
V.choices.get_mm_configs(
kernel_inputs,
aten_layout,
[aten_bias_addmm],
name,
)
@ -988,7 +981,6 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
choices.extend(
V.choices.get_mm_configs(
kernel_inputs,
aten_layout,
[aten_addmm],
name,
)
@ -1000,7 +992,6 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
choices.extend(
V.choices.get_mm_configs(
kernel_inputs,
layout,
[mm_template],
name,
)
@ -1011,7 +1002,6 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
choices.extend(
V.choices.get_mm_configs(
kernel_inputs,
layout,
[persistent_tma_mm_template],
name,
)
@ -1020,7 +1010,6 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
choices.extend(
V.choices.get_mm_configs(
kernel_inputs,
layout,
[addmm_contiguous_subgraph_template],
"addmm",
)
@ -1174,14 +1163,15 @@ def tuned_scaled_mm(
input_nodes = [mat_a, mat_b, scale_a_real, scale_b_real, bias_real]
# Create MMKernelInputs for Scaled MM (matrices are at indices 0, 1)
kernel_inputs = MMKernelInputs(input_nodes, mat1_idx=0, mat2_idx=1)
kernel_inputs = MMKernelInputs(
input_nodes, mat1_idx=0, mat2_idx=1, out_dtype=out_dtype
)
choices: list[ChoiceCaller] = []
if use_aten_gemm_kernels():
choices.extend(
V.choices.get_mm_configs(
kernel_inputs,
layout,
[aten__fp8_mm],
name,
kwarg_overrides={
@ -1209,7 +1199,6 @@ def tuned_scaled_mm(
choices.extend(
V.choices.get_mm_configs(
kernel_inputs,
layout,
[scaled_mm_device_tma_template],
name,
kwarg_overrides={scaled_mm_device_tma_template.uid: overriders},
@ -1220,7 +1209,6 @@ def tuned_scaled_mm(
choices.extend(
V.choices.get_mm_configs(
kernel_inputs,
layout,
[mm_template],
name,
kwarg_overrides={mm_template.uid: overriders},

View File

@ -157,17 +157,13 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
choices: list[ChoiceCaller] = []
if use_aten_gemm_kernels():
choices.extend(
V.choices.get_mm_configs(
kernel_inputs, layout1, [aten_mm_plus_mm], "mm_plus_mm"
)
V.choices.get_mm_configs(kernel_inputs, [aten_mm_plus_mm], "mm_plus_mm")
)
if use_triton_template(layout1, check_max_autotune=False):
# Get template choices using the new unified function
choices.extend(
V.choices.get_mm_configs(
kernel_inputs, layout1, [mm_plus_mm_template], "mm_plus_mm"
)
V.choices.get_mm_configs(kernel_inputs, [mm_plus_mm_template], "mm_plus_mm")
)
return autotune_select_algorithm(

View File

@ -1,5 +1,6 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Optional, TYPE_CHECKING, Union
import torch
@ -7,6 +8,8 @@ import torch._inductor.config
from torch._inductor import ir
from torch._inductor.virtualized import V
from .ir import FixedLayout, FlexibleLayout, Layout
if TYPE_CHECKING:
from collections.abc import Sequence
@ -14,7 +17,7 @@ if TYPE_CHECKING:
import sympy
class KernelInputs:
class KernelInputs(ABC):
"""
Class to store and provide access to input nodes for kernels.
This class takes in a tuple of input nodes and provides methods to access
@ -25,16 +28,19 @@ class KernelInputs:
self,
input_nodes: list[Any],
scalars: Optional[dict[str, Union[float, int]]] = None,
out_dtype: Optional[torch.dtype] = None,
):
"""
Initialize with a tuple of input nodes.
Args:
input_nodes: A tuple of input nodes to store
out_dtype: Optional output dtype to store
"""
self._input_nodes = input_nodes
self._device_name: Optional[str] = None
self._scalars = scalars if scalars is not None else {}
self._out_dtype = out_dtype
assert len(input_nodes) > 0, "Expected at least one input node"
def nodes(self, reorder: Optional[Sequence[int]] = None) -> list[Any]:
@ -168,6 +174,15 @@ class KernelInputs:
"""
return self._input_nodes[idx].get_dtype()
@abstractmethod
def out_dtype(self) -> torch.dtype:
"""
Get the output dtype, whether passed in or inferred from the nodes
Returns:
The output dtype
"""
def get_scalar(self, name: str) -> Union[float, int]:
"""
Get the scalar value for a given name.
@ -181,6 +196,16 @@ class KernelInputs:
assert name in self._scalars, f"Scalar {name} not found, but required"
return self._scalars[name]
@abstractmethod
def output_layout(self, flexible: bool = True) -> Layout:
"""
Abstract method to handle output layout generation.
Args:
flexible: If True, return FlexibleLayout, otherwise FixedLayout
"""
class MMKernelInputs(KernelInputs):
"""
@ -192,6 +217,7 @@ class MMKernelInputs(KernelInputs):
self,
input_nodes: list[Any],
scalars: Optional[dict[str, Union[float, int]]] = None,
out_dtype: Optional[torch.dtype] = None,
mat1_idx: int = -2,
mat2_idx: int = -1,
):
@ -201,7 +227,7 @@ class MMKernelInputs(KernelInputs):
By default, we assume the last 2 input nodes are mat1 and mat2, but
the caller can adjust when necessary
"""
super().__init__(input_nodes, scalars)
super().__init__(input_nodes, scalars, out_dtype)
# for mm, we need at least 2 nodes, and we need to know which nodes
# are the main matrixes e.g. addmm is (bias, mat1, mat2) whereas others
# might be (mat1, mat2, scale), etc.
@ -246,6 +272,37 @@ class MMKernelInputs(KernelInputs):
V.graph.sizevars.check_equals(k, k0)
return (m, n, k)
def out_dtype(self) -> torch.dtype:
"""
Get the output dtype, whether passed in or inferred from the nodes
Returns:
The output dtype
"""
if self._out_dtype is not None:
return self._out_dtype
return self.mat1mat2()[0].get_dtype()
def output_layout(self, flexible: bool = True) -> Layout:
"""
Handle output layout generation for matrix multiplication.
Args:
flexible: If True, return FlexibleLayout, otherwise FixedLayout
"""
mat1, mat2 = self.mat1mat2()
out_dtype = self.out_dtype()
# NOTE: taken from mm_common.mm_args
*b1, m, k1 = mat1.get_size()
*b2, k2, n = mat2.get_size()
b = [V.graph.sizevars.check_equals_and_simplify(a, b) for a, b in zip(b1, b2)]
size = [*b, m, n]
if flexible:
return FlexibleLayout(self.device(), out_dtype, size)
else:
return FixedLayout(self.device(), out_dtype, size)
def mat1mat2(self) -> tuple[Any, Any]:
"""
Get the mat1 and mat2 nodes.

View File

@ -9,16 +9,14 @@ from ..kernel.mm import aten__fp8_mm, aten__int_mm, aten_addmm, aten_bias_addmm,
from ..kernel.mm_plus_mm import aten_mm_plus_mm
from .base import TemplateConfigHeuristics
from .gemm import GemmMaxAutotuneTemplateConfigHeuristics
from .registry import register_template_heuristic
if TYPE_CHECKING:
from collections.abc import Generator
from ..ir import Layout
from ..kernel_inputs import KernelInputs
from .registry import register_template_heuristic
# These are all labeled as device type None to indicate that they
# are valid for all device types
@ -41,7 +39,6 @@ class ATenConfigHeuristics(TemplateConfigHeuristics):
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
layout: Layout,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
yield dict()
@ -55,10 +52,9 @@ class ATenAddMMConfigHeuristics(ATenConfigHeuristics):
def get_extra_kwargs(
self,
kernel_inputs: KernelInputs,
layout: Layout,
op_name: str,
) -> dict[str, Any]:
kwargs = super().get_extra_kwargs(kernel_inputs, layout, op_name)
kwargs = super().get_extra_kwargs(kernel_inputs, op_name)
alpha = kernel_inputs.get_scalar("alpha")
beta = kernel_inputs.get_scalar("beta")
return {
@ -75,7 +71,6 @@ class ATenBiasAddMMConfigHeuristics(
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
layout: Layout,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
nodes = kernel_inputs.nodes()

View File

@ -6,14 +6,13 @@ from typing import Any, TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Generator
from ..ir import Layout
from ..kernel_inputs import KernelInputs
class TemplateConfigHeuristics:
"""Base class for generating sets of configs for an associated template."""
def should_run(self, inputs: KernelInputs, layout: Layout) -> bool:
def should_run(self, inputs: KernelInputs) -> bool:
"""
hookup to check whether the configs are right to run at all e.g. you can check
max-autotune specific to your heuristic here or other things
@ -21,14 +20,12 @@ class TemplateConfigHeuristics:
Args:
inputs: KernelInputs
layout: Layout
"""
return True
def get_template_configs(
self,
kernel_inputs: KernelInputs,
layout: Layout,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
"""
@ -37,19 +34,17 @@ class TemplateConfigHeuristics:
Prefer to override the _get_template_configs_impl method
to leverage things like should_run
"""
if not self.should_run(kernel_inputs, layout):
if not self.should_run(kernel_inputs):
return
yield from self._get_template_configs_impl(
kernel_inputs,
layout,
op_name,
)
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
layout: Layout,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
"""
@ -62,7 +57,6 @@ class TemplateConfigHeuristics:
def get_extra_kwargs(
self,
kernel_inputs: KernelInputs,
layout: Layout,
op_name: str,
) -> dict[str, Any]:
"""

View File

@ -19,8 +19,6 @@ from .registry import register_template_heuristic
if TYPE_CHECKING:
from collections.abc import Generator
from ..ir import Layout
@register_template_heuristic(mm_contiguous_subgraph_template.uid, None, op_name="mm")
@register_template_heuristic(
@ -46,7 +44,6 @@ class ContiguousMMHeuristics(GemmMaxAutotuneTemplateConfigHeuristics):
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
layout: Layout,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
"""

View File

@ -19,8 +19,6 @@ from .registry import register_template_heuristic
if TYPE_CHECKING:
from collections.abc import Generator
from ..ir import Layout
@register_template_heuristic(decompose_k_subgraph_template.uid, None, op_name="mm")
class EmptyDecomposeKConfigHeuristics(TemplateConfigHeuristics):
@ -43,7 +41,6 @@ class DecomposeKConfigHeuristics(GemmMaxAutotuneTemplateConfigHeuristics):
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
layout: Layout,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
"""

View File

@ -7,12 +7,11 @@ from .base import TemplateConfigHeuristics
if TYPE_CHECKING:
from ..ir import Layout
from ..kernel_inputs import KernelInputs
class GemmMaxAutotuneTemplateConfigHeuristics(TemplateConfigHeuristics):
def should_run(self, inputs: KernelInputs, layout: Layout) -> bool:
def should_run(self, inputs: KernelInputs) -> bool:
"""
simple base override for GEMM family templates that run only in max-autotune
"""

View File

@ -41,8 +41,6 @@ if TYPE_CHECKING:
from triton import Config as TritonConfig
from ..ir import Layout
# Gemm Configs
@dataclasses.dataclass
@ -1451,7 +1449,6 @@ class MMTemplateConfigMixin(GemmMaxAutotuneTemplateConfigHeuristics):
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
layout: Any,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
"""
@ -1479,7 +1476,11 @@ class MMTemplateConfigMixin(GemmMaxAutotuneTemplateConfigHeuristics):
# Generate and process configs
for c in configs(m, n, k, dtype_size=dtype.itemsize, op_name=op_name):
template_kwargs = self._convert_config_to_template_kwargs(
c, m, n, k, layout
c,
m,
n,
k,
kernel_inputs.out_dtype(),
)
yield template_kwargs
@ -1489,7 +1490,7 @@ class MMTemplateConfigMixin(GemmMaxAutotuneTemplateConfigHeuristics):
m: sympy.Integer,
n: sympy.Integer,
k: sympy.Integer,
layout: Any,
out_dtype: torch.dtype,
) -> dict[str, Any]:
"""
Convert triton config to template kwargs.
@ -1513,7 +1514,7 @@ class MMTemplateConfigMixin(GemmMaxAutotuneTemplateConfigHeuristics):
EVEN_K=even_k_symbolic,
ALLOW_TF32=allow_tf32,
USE_FAST_ACCUM=False, # Option for _scaled_mm
ACC_TYPE=self._get_acc_type(layout.dtype),
ACC_TYPE=self._get_acc_type(out_dtype),
num_stages=triton_config.num_stages,
num_warps=triton_config.num_warps,
**triton_config.kwargs,
@ -1562,14 +1563,11 @@ class MMPlusMMTemplateConfigMixin(MMTemplateConfigMixin):
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
layout: Any,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
assert isinstance(kernel_inputs, MMKernelInputs), "Expect MMKernelInputs"
m, n, k = kernel_inputs.mnk_symbolic()
for kwargs in super()._get_template_configs_impl(
kernel_inputs, layout, op_name
):
for kwargs in super()._get_template_configs_impl(kernel_inputs, op_name):
# Apply BLOCK_K constraint specific to mm_plus_mm
# see https://github.com/triton-lang/triton/issues/1298
# BLOCK_K = K causes llvm error
@ -1586,10 +1584,9 @@ class TMAWorkspaceMixin(MMTemplateConfigMixin):
def get_extra_kwargs(
self,
kernel_inputs: KernelInputs,
layout: Layout,
op_name: str,
) -> dict[str, Any]:
kwargs = super().get_extra_kwargs(kernel_inputs, layout, op_name)
kwargs = super().get_extra_kwargs(kernel_inputs, op_name)
kwargs["workspace_arg"] = get_tma_workspace_arg(
num_tma_descriptors=2,
device=kernel_inputs.device(),
@ -1614,7 +1611,6 @@ class TMATemplateConfigMixin(TMAWorkspaceMixin, MMTemplateConfigMixin):
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
layout: Any,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
"""
@ -1634,7 +1630,6 @@ class TMATemplateConfigMixin(TMAWorkspaceMixin, MMTemplateConfigMixin):
# Get base template configs from superclass
for template_kwargs in super()._get_template_configs_impl(
kernel_inputs,
layout,
op_name,
):
yield {**template_kwargs, **tma_opts}
@ -1684,7 +1679,6 @@ class BaseScaledMMConfigMixin(MMTemplateConfigMixin):
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
layout: Any,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
"""
@ -1734,7 +1728,7 @@ class BaseScaledMMConfigMixin(MMTemplateConfigMixin):
# Get base template configs from superclass
for template_kwargs in super()._get_template_configs_impl(
kernel_inputs, layout, op_name
kernel_inputs, op_name
):
# Add scaled MM-specific options (moved from mm_common.scaled_mm_options)
# Override accumulator type for scaled MM
@ -1752,10 +1746,9 @@ class ScaledMMConfigMixin(BaseScaledMMConfigMixin):
def get_extra_kwargs(
self,
kernel_inputs: KernelInputs,
layout: Layout,
op_name: str,
) -> dict[str, Any]:
kwargs = super().get_extra_kwargs(kernel_inputs, layout, op_name)
kwargs = super().get_extra_kwargs(kernel_inputs, op_name)
from ..kernel.mm_common import scale_mm_epilogue
return {
@ -1792,7 +1785,6 @@ class ScaledTMAConfigMixin(TMAWorkspaceMixin, BaseScaledMMConfigMixin):
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
layout: Any,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
"""
@ -1801,7 +1793,6 @@ class ScaledTMAConfigMixin(TMAWorkspaceMixin, BaseScaledMMConfigMixin):
# Get base scaled MM template configs from superclass
for template_kwargs in super()._get_template_configs_impl(
kernel_inputs,
layout,
op_name,
):
# Add TMA-specific options for device TMA scaled MM

View File

@ -7,7 +7,6 @@ from .base import TemplateConfigHeuristics
if TYPE_CHECKING:
from ..ir import Layout
from ..kernel_inputs import KernelInputs
@ -19,10 +18,9 @@ class AddMMConfigMixin(TemplateConfigHeuristics):
def get_extra_kwargs(
self,
kernel_inputs: KernelInputs,
layout: Layout,
op_name: str,
) -> dict[str, Any]:
kwargs = super().get_extra_kwargs(kernel_inputs, layout, op_name)
kwargs = super().get_extra_kwargs(kernel_inputs, op_name)
assert op_name in [
"addmm",
"baddbmm",
@ -31,7 +29,9 @@ class AddMMConfigMixin(TemplateConfigHeuristics):
beta = kernel_inputs.get_scalar("beta")
return {
**kwargs,
"epilogue_fn": addmm_epilogue(layout.dtype, alpha, beta),
"epilogue_fn_hash": str(["addmm_epilogue", layout.dtype, alpha, beta]),
"epilogue_fn": addmm_epilogue(kernel_inputs.out_dtype(), alpha, beta),
"epilogue_fn_hash": str(
["addmm_epilogue", kernel_inputs.out_dtype(), alpha, beta]
),
"prefix_args": 1,
}