[inductor][ez] add template/externchoice uid (#161341)

# why

- to build a central registry of templates/ExternKernelChoices and
  match them to heuristics etc., they need unique names
- "mm" is both the Triton template name and the aten_mm name

# what

- add a uid property to KernelTemplate/ExternKernelChoice that returns the name
- override it in ExternKernelChoice to prepend "aten::"
- override it in TritonTemplate to prepend "triton::"

This id is only used to look up template heuristics, so it has no other
impact.
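
For illustration, a minimal sketch of the naming scheme; these are
simplified stand-ins, not the actual inductor classes:

```python
class KernelTemplate:
    def __init__(self, name: str) -> None:
        self.name = name

    @property
    def uid(self) -> str:
        # base case: the uid is just the template name
        return self.name


class TritonTemplate(KernelTemplate):
    @property
    def uid(self) -> str:
        # disambiguate from extern kernels that share the same name
        return f"triton::{self.name}"


class ExternKernelChoice:
    def __init__(self, name: str) -> None:
        self.name = name

    @property
    def uid(self) -> str:
        return f"aten::{self.name}"


# "mm" alone is ambiguous; the uids are not
assert TritonTemplate("mm").uid == "triton::mm"
assert ExternKernelChoice("mm").uid == "aten::mm"
```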

# testing

```
python3 -bb -m pytest test/inductor/test_max_autotune.py -v
```

Differential Revision: [D81520579](https://our.internmc.facebook.com/intern/diff/D81520579)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161341
Approved by: https://github.com/jansel, https://github.com/eellison
ghstack dependencies: #162075, #161340
Authored by Ruben Rodriguez Buchillon on 2025-09-05 07:38:40 -07:00; committed by PyTorch MergeBot
parent 9602590b15
commit 4902c76c65
8 changed files with 139 additions and 74 deletions

View File

@@ -2391,6 +2391,17 @@ class KernelTemplate:
     def __init__(self, name: str) -> None:
         self.name = name
 
+    @property
+    def uid(self) -> str:
+        """
+        entry point to override for templates to ensure a uid e.g. through a prefix
+        the purpose of this is that every KernelTemplate/ExternKernelChoice is unique
+        in the system, but reproducible e.g. restarting pytorch should yield the same id
+        """
+        # TODO(coconutruben): add some central registration to assert on global uniqueness
+        return self.name
+
     def maybe_append_choice(
         self, choices: list[Any], **kwargs: Any
     ) -> Optional[NotImplementedError]:

View File

@@ -205,7 +205,7 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
         assert out_dtype is None, "out_dtype is not supported for Triton"
         for kwargs, extra_kwargs in V.choices.get_mm_configs(
-            kernel_inputs, layout, bmm_template.name, name
+            kernel_inputs, layout, bmm_template.uid, name
         ):
             bmm_template.maybe_append_choice(
                 choices,
@@ -284,7 +284,7 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
     for kwargs, extra_kwargs in V.choices.get_mm_configs(
         kernel_inputs,
         layout,
-        bmm_template.name,
+        bmm_template.uid,
         name,
     ):
         bmm_template.maybe_append_choice(

View File

@@ -762,7 +762,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
     if is_nonzero and use_triton_template(layout):
         # Get template params using the new unified function
         for kwargs, extra_kwargs in V.choices.get_mm_configs(
-            kernel_inputs, layout, mm_template.name, "mm"
+            kernel_inputs, layout, mm_template.uid, "mm"
         ):
             mm_template.maybe_append_choice(
                 choices,
@@ -773,7 +773,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
         if use_triton_tma_template(mat1, mat2):
             # Get TMA template params using the new unified function
             for kwargs, extra_kwargs in V.choices.get_mm_configs(
-                kernel_inputs, layout, persistent_tma_mm_template.name, "mm"
+                kernel_inputs, layout, persistent_tma_mm_template.uid, "mm"
             ):
                 persistent_tma_mm_template.maybe_append_choice(
                     choices,
@@ -784,7 +784,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
         # Only do split-k optimization if K is much larger than m, n and m, n are small
         if use_decompose_k_choice(m, n, k):
             for kwargs, extra_kwargs in V.choices.get_mm_configs(
-                kernel_inputs, layout, decompose_k_subgraph_template.name, "mm"
+                kernel_inputs, layout, decompose_k_subgraph_template.uid, "mm"
             ):
                 decompose_k_subgraph_template.maybe_append_choice(
                     choices,
@@ -931,7 +931,7 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
     if is_nonzero and use_triton_template(layout, enable_int32=True):
         for kwargs, extra_kwargs in V.choices.get_mm_configs(
-            kernel_inputs, layout, mm_template.name, "int_mm"
+            kernel_inputs, layout, mm_template.uid, "int_mm"
        ):
             mm_template.maybe_append_choice(
                 choices,
@@ -1038,7 +1038,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
         for kwargs, extra_kwargs in V.choices.get_mm_configs(
             kernel_inputs,
             layout,
-            mm_template.name,
+            mm_template.uid,
             "addmm",
         ):
             mm_template.maybe_append_choice(
@@ -1052,7 +1052,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
         for kwargs, extra_kwargs in V.choices.get_mm_configs(
             kernel_inputs,
             layout,
-            persistent_tma_mm_template.name,
+            persistent_tma_mm_template.uid,
             "addmm",
         ):
             persistent_tma_mm_template.maybe_append_choice(
@@ -1246,7 +1246,7 @@ def tuned_scaled_mm(
         for kwargs, extra_kwargs in V.choices.get_mm_configs(
             kernel_inputs,
             layout,
-            scaled_mm_device_tma_template.name,
+            scaled_mm_device_tma_template.uid,
             "scaled_mm",
             overriders,
         ):
@@ -1260,7 +1260,7 @@ def tuned_scaled_mm(
         for kwargs, extra_kwargs in V.choices.get_mm_configs(
             kernel_inputs,
             layout,
-            mm_template.name,
+            mm_template.uid,
             "scaled_mm",
             overriders,
         ):

View File

@@ -159,7 +159,7 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
     if use_triton_template(layout1):
         # Get template params using the new unified function
         for kwargs, extra_kwargs in V.choices.get_mm_configs(
-            kernel_inputs, layout1, mm_plus_mm_template.name, "mm_plus_mm"
+            kernel_inputs, layout1, mm_plus_mm_template.uid, "mm_plus_mm"
         ):
             # Apply BLOCK_K constraint specific to mm_plus_mm
             # see https://github.com/triton-lang/triton/issues/1298

View File

@@ -1443,6 +1443,11 @@ class TritonTemplate(KernelTemplate):
     # was not used are the same.
     test_cache = False
 
+    @property
+    def uid(self) -> str:
+        # unique by prefixing with triton
+        return f"triton::{self.name}"
+
     def maybe_append_choice(
         self, choices: list[Any], **kwargs: Any
     ) -> Optional[NotImplementedError]:
@@ -1909,6 +1914,11 @@ class ExternKernelChoice:
             self, input_nodes, layout, kwargs, has_out_variant=self.has_out_variant
         )
 
+    @property
+    def uid(self) -> str:
+        # unique by prefixing with aten
+        return f"aten::{self.name}"
+
     def maybe_append_choice(
         self, choices: list[Any], **kwargs: Any
     ) -> Optional[NotImplementedError]:

View File

@@ -5,6 +5,10 @@ from typing import Any, TYPE_CHECKING
 import torch
 
 from ..ir import get_free_symbols
+from ..kernel.mm import (
+    addmm_contiguous_subgraph_template,
+    mm_contiguous_subgraph_template,
+)
 from ..kernel_inputs import KernelInputs, MMKernelInputs
 from ..utils import use_contiguous
 from .base import TemplateConfigHeuristics
@@ -17,17 +21,25 @@ if TYPE_CHECKING:
     from ..ir import Layout
 
 
-@register_template_heuristic("contiguous_mm", None, op_name="mm")
-@register_template_heuristic("contiguous_addmm", None, op_name="addmm")
+@register_template_heuristic(mm_contiguous_subgraph_template.uid, None, op_name="mm")
+@register_template_heuristic(
+    addmm_contiguous_subgraph_template.uid, None, op_name="addmm"
+)
 class EmptyContiguousMMConfigHeuristics(TemplateConfigHeuristics):
     """empty heuristics to skip contiguous mm on not cuda"""
 
 
 @register_template_heuristic(
-    "contiguous_mm", "cuda", register=torch.version.hip is not None, op_name="mm"
+    mm_contiguous_subgraph_template.uid,
+    "cuda",
+    register=torch.version.hip is not None,
+    op_name="mm",
 )
 @register_template_heuristic(
-    "contiguous_addmm", "cuda", register=torch.version.hip is not None, op_name="addmm"
+    addmm_contiguous_subgraph_template.uid,
+    "cuda",
+    register=torch.version.hip is not None,
+    op_name="addmm",
 )
 class ContiguousMMHeuristics(TemplateConfigHeuristics):
     def get_template_configs(

View File

@@ -7,6 +7,7 @@ import sympy
 import torch
 
 from ..ir import get_free_symbols
+from ..kernel.mm import decompose_k_subgraph_template
 from ..kernel_inputs import KernelInputs, MMKernelInputs
 from ..utils import get_k_splits
 from ..virtualized import V
@@ -20,14 +21,17 @@ if TYPE_CHECKING:
     from ..ir import Layout
 
 
-@register_template_heuristic("decompose_k", None, op_name="mm")
+@register_template_heuristic(decompose_k_subgraph_template.uid, None, op_name="mm")
 class EmptyDecomposeKConfigHeuristics(TemplateConfigHeuristics):
     """empty heuristics to skip decompose k on anything not cuda"""
 
 
 # on CUDA, we don't support hip for decompose_k yet
 @register_template_heuristic(
-    "decompose_k", "cuda", register=torch.version.hip is None, op_name="mm"
+    decompose_k_subgraph_template.uid,
+    "cuda",
+    register=torch.version.hip is None,
+    op_name="mm",
 )
 # TODO(coconutruben): enable decompose k on AMD by removing the register bool
 # and benchmarking it for performance and stability

View File

@@ -16,6 +16,13 @@ from torch.utils._ordered_set import OrderedSet
 from torch.utils._triton import has_triton_stable_tma_api
 
 from .. import config, config as inductor_config
+from ..kernel.bmm import bmm_template
+from ..kernel.mm import (
+    mm_template,
+    persistent_tma_mm_template,
+    scaled_mm_device_tma_template,
+)
+from ..kernel.mm_plus_mm import mm_plus_mm_template
 from ..kernel_inputs import KernelInputs, MMKernelInputs
 from ..utils import (
     get_backend_num_stages,
@@ -1786,28 +1793,31 @@ class ScaledTMAConfigMixin(TMAWorkspaceMixin, BaseScaledMMConfigMixin):
 # Template-specific heuristic classes using multiple inheritance
 
 
-# TODO(coconutruben): replace with template.name once templates are importable
-@register_template_heuristic("mm", "cuda", register=torch.version.hip is None)
-# TODO(coconutruben): replace with template.name once templates are importable
-@register_template_heuristic("bmm", "cuda", register=torch.version.hip is None)
+@register_template_heuristic(
+    mm_template.uid,
+    "cuda",
+    register=torch.version.hip is None,
+)
+@register_template_heuristic(
+    bmm_template.uid,
+    "cuda",
+    register=torch.version.hip is None,
+)
 class CUDAMMTemplateConfigHeuristic(MMTemplateConfigMixin, CUDAConfigHeuristic):
     """Standard MM template heuristic for CUDA"""
 
 
-# TODO(coconutruben): replace with template.name once templates are importable
 @register_template_heuristic(
-    "mm", "cuda", register=torch.version.hip is None, op_name="addmm"
+    mm_template.uid, "cuda", register=torch.version.hip is None, op_name="addmm"
 )
-# TODO(coconutruben): replace with template.name once templates are importable
 @register_template_heuristic(
-    "bmm", "cuda", register=torch.version.hip is None, op_name="baddbmm"
+    bmm_template.uid, "cuda", register=torch.version.hip is None, op_name="baddbmm"
 )
 class CUDAAddMMTemplateConfigHeuristic(AddMMConfigMixin, CUDAMMTemplateConfigHeuristic):
     """Addmm specific mixin for CUDA"""
 
 
 # TODO(coconutruben): deprecate once autoheuristic is deprecated
-# TODO(coconutruben): replace with template.name once templates are importable
 @register_template_heuristic("mm-ah", "cuda", register=torch.version.hip is None)
 class CUDAMMAHTemplateConfigHeuristic(MMTemplateConfigMixin, CUDAConfigHeuristic):
     """Standard MM template heuristic for CUDA using the extra mm configs only (for autoheuristic)"""
@@ -1819,9 +1829,10 @@ class CUDAMMAHTemplateConfigHeuristic(MMTemplateConfigMixin, CUDAConfigHeuristic
         self.exhaustive_configs = self.extra_mm_configs
 
 
-# TODO(coconutruben): replace with template.name once templates are importable
 @register_template_heuristic(
-    "mm_persistent_tma", "cuda", register=torch.version.hip is None
+    persistent_tma_mm_template.uid,
+    "cuda",
+    register=torch.version.hip is None,
 )
 class CUDAPersistentTMATemplateConfigHeuristic(
     TMATemplateConfigMixin, CUDAConfigHeuristic
@@ -1834,9 +1845,11 @@ class CUDAPersistentTMATemplateConfigHeuristic(
         self.mm_configs = self.persistent_mm_configs
 
 
-# TODO(coconutruben): replace with template.name once templates are importable
 @register_template_heuristic(
-    "mm_persistent_tma", "cuda", register=torch.version.hip is None, op_name="addmm"
+    persistent_tma_mm_template.uid,
+    "cuda",
+    register=torch.version.hip is None,
+    op_name="addmm",
 )
 class CUDAAddmmPersistentTMATemplateConfigHeuristic(
     AddMMConfigMixin, CUDAPersistentTMATemplateConfigHeuristic
@@ -1844,9 +1857,8 @@ class CUDAAddmmPersistentTMATemplateConfigHeuristic(
     """Addmm specific mixin for CUDA"""
 
 
-# TODO(coconutruben): replace with template.name once templates are importable
 @register_template_heuristic(
-    "mm", "cuda", register=torch.version.hip is None, op_name="scaled_mm"
+    mm_template.uid, "cuda", register=torch.version.hip is None, op_name="scaled_mm"
 )
 class CUDAScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, CUDAConfigHeuristic):
     """Scaled MM template heuristic for CUDA"""
@@ -1862,9 +1874,10 @@ class CUDAScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, CUDAConfigHeurist
         self.exhaustive_configs = self.scaled_mm_configs
 
 
-# TODO(coconutruben): replace with template.name once templates are importable
 @register_template_heuristic(
-    "scaled_mm_device_tma", "cuda", register=torch.version.hip is None
+    scaled_mm_device_tma_template.uid,
+    "cuda",
+    register=torch.version.hip is None,
 )
 class CUDAScaledTMATemplateConfigHeuristic(ScaledTMAConfigMixin, CUDAConfigHeuristic):
     """Scaled TMA template heuristic for CUDA"""
@@ -1880,8 +1893,11 @@ class CUDAScaledTMATemplateConfigHeuristic(ScaledTMAConfigMixin, CUDAConfigHeuri
         self.exhaustive_configs = self.scaled_persistent_mm_configs
 
 
-# TODO(coconutruben): replace with template.name once templates are importable
-@register_template_heuristic("mm_plus_mm", "cuda", register=torch.version.hip is None)
+@register_template_heuristic(
+    mm_plus_mm_template.uid,
+    "cuda",
+    register=torch.version.hip is None,
+)
 class CUDAMMPlusMMTemplateConfigHeuristic(
     MMPlusMMTemplateConfigMixin, CUDAConfigHeuristic
 ):
@@ -1898,9 +1914,11 @@ class CUDAMMPlusMMTemplateConfigHeuristic(
         self.exhaustive_configs = self.mm_plus_mm_configs
 
 
-# TODO(coconutruben): replace with template.name once templates are importable
 @register_template_heuristic(
-    "mm", "cuda", register=torch.version.hip is None, op_name="int_mm"
+    mm_template.uid,
+    "cuda",
+    register=torch.version.hip is None,
+    op_name="int_mm",
 )
 class CUDAInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, CUDAConfigHeuristic):
     """Int8 MM template heuristic for CUDA"""
@@ -1919,28 +1937,33 @@ class CUDAInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, CUDAConfigHeu
 # ROCm template-specific classes
 
 
-# TODO(coconutruben): replace with template.name once templates are importable
-@register_template_heuristic("mm", "cuda", register=torch.version.hip is not None)
-# TODO(coconutruben): replace with template.name once templates are importable
-@register_template_heuristic("bmm", "cuda", register=torch.version.hip is not None)
+@register_template_heuristic(
+    mm_template.uid,
+    "cuda",
+    register=torch.version.hip is not None,
+)
+@register_template_heuristic(
+    bmm_template.uid,
+    "cuda",
+    register=torch.version.hip is not None,
+)
 class ROCmMMTemplateConfigHeuristic(MMTemplateConfigMixin, ROCmConfigHeuristic):
     """Standard MM template heuristic for ROCm"""
 
 
-# TODO(coconutruben): replace with template.name once templates are importable
 @register_template_heuristic(
-    "mm", "cuda", register=torch.version.hip is not None, op_name="addmm"
+    mm_template.uid, "cuda", register=torch.version.hip is not None, op_name="addmm"
 )
-# TODO(coconutruben): replace with template.name once templates are importable
 @register_template_heuristic(
-    "bmm", "cuda", register=torch.version.hip is not None, op_name="baddbmm"
+    bmm_template.uid, "cuda", register=torch.version.hip is not None, op_name="baddbmm"
 )
 class ROCmAddMMTemplateConfigHeuristic(AddMMConfigMixin, ROCmMMTemplateConfigHeuristic):
     """Addmm specific mixin for ROCm"""
 
 
 # TODO(coconutruben): deprecate once autoheuristic is deprecated
-# TODO(coconutruben): replace with template.name once templates are importable
 @register_template_heuristic("mm-ah", "cuda", register=torch.version.hip is not None)
 class ROCmMMAHTemplateConfigHeuristic(MMTemplateConfigMixin, ROCmConfigHeuristic):
     """Standard MM template heuristic for ROCm using the extra mm configs only (for autoheuristic)"""
@@ -1952,9 +1975,11 @@ class ROCmMMAHTemplateConfigHeuristic(MMTemplateConfigMixin, ROCmConfigHeuristic
         self.exhaustive_configs = self.extra_mm_configs
 
 
-# TODO(coconutruben): replace with template.name once templates are importable
 @register_template_heuristic(
-    "mm", "cuda", register=torch.version.hip is not None, op_name="scaled_mm"
+    mm_template.uid,
+    "cuda",
+    register=torch.version.hip is not None,
+    op_name="scaled_mm",
 )
 class ROCmScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, ROCmConfigHeuristic):
     """Scaled MM template heuristic for ROCm (non-TMA)"""
@@ -1970,9 +1995,11 @@ class ROCmScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, ROCmConfigHeurist
         self.exhaustive_configs = self.scaled_mm_configs
 
 
-# TODO(coconutruben): replace with template.name once templates are importable
 @register_template_heuristic(
-    "mm", "cuda", register=torch.version.hip is not None, op_name="int_mm"
+    mm_template.uid,
+    "cuda",
+    register=torch.version.hip is not None,
+    op_name="int_mm",
 )
 class ROCmInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, ROCmConfigHeuristic):
     """Int8 MM template heuristic for ROCm"""
@@ -1988,9 +2015,10 @@ class ROCmInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, ROCmConfigHeu
         self.exhaustive_configs = self.int8_mm_configs
 
 
-# TODO(coconutruben): replace with template.name once templates are importable
 @register_template_heuristic(
-    "mm_plus_mm", "cuda", register=torch.version.hip is not None
+    mm_plus_mm_template.uid,
+    "cuda",
+    register=torch.version.hip is not None,
 )
 class ROCmMMPlusMMTemplateConfigHeuristic(
     MMPlusMMTemplateConfigMixin, ROCmConfigHeuristic
@@ -2014,19 +2042,19 @@ class ROCmMMPlusMMTemplateConfigHeuristic(
 # CPU template-specific classes
 
 
-@register_template_heuristic("mm", "cpu")
-@register_template_heuristic("bmm", "cpu")
+@register_template_heuristic(mm_template.uid, "cpu")
+@register_template_heuristic(bmm_template.uid, "cpu")
 class CPUMMTemplateConfigHeuristic(MMTemplateConfigMixin, CPUConfigHeuristic):
     """Standard MM template heuristic for CPU"""
 
 
-@register_template_heuristic("mm", "cpu", op_name="addmm")
-@register_template_heuristic("bmm", "cpu", op_name="baddbmm")
+@register_template_heuristic(mm_template.uid, "cpu", op_name="addmm")
+@register_template_heuristic(bmm_template.uid, "cpu", op_name="baddbmm")
 class CPUAddmmTemplateConfigHeuristic(AddMMConfigMixin, CPUMMTemplateConfigHeuristic):
     """Addmm specific mixin for CPU"""
 
 
-@register_template_heuristic("mm", "cpu", op_name="scaled_mm")
+@register_template_heuristic(mm_template.uid, "cpu", op_name="scaled_mm")
 class CPUScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, CPUConfigHeuristic):
     """Scaled MM template heuristic for CPU (non-TMA)"""
@@ -2041,7 +2069,7 @@ class CPUScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, CPUConfigHeuristic
         self.exhaustive_configs = self.scaled_mm_configs
 
 
-@register_template_heuristic("mm", "cpu", op_name="int_mm")
+@register_template_heuristic(mm_template.uid, "cpu", op_name="int_mm")
 class CPUInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, CPUConfigHeuristic):
     """Int8 MM template heuristic for CPU"""
@@ -2056,7 +2084,7 @@ class CPUInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, CPUConfigHeuri
         self.exhaustive_configs = self.int8_mm_configs
 
 
-@register_template_heuristic("mm_plus_mm", "cpu")
+@register_template_heuristic(mm_plus_mm_template.uid, "cpu")
 class CPUMMPlusMMTemplateConfigHeuristic(
     MMPlusMMTemplateConfigMixin, CPUConfigHeuristic
 ):
@@ -2076,20 +2104,20 @@ class CPUMMPlusMMTemplateConfigHeuristic(
 # XPU template-specific classes
 
 
-@register_template_heuristic("mm", "xpu")
-@register_template_heuristic("bmm", "xpu")
+@register_template_heuristic(mm_template.uid, "xpu")
+@register_template_heuristic(bmm_template.uid, "xpu")
 class XPUMMTemplateConfigHeuristic(MMTemplateConfigMixin, XPUConfigHeuristic):
     """Standard MM template heuristic for XPU"""
 
 
-@register_template_heuristic("mm", "xpu", op_name="addmm")
-@register_template_heuristic("bmm", "xpu", op_name="baddbmm")
+@register_template_heuristic(mm_template.uid, "xpu", op_name="addmm")
+@register_template_heuristic(bmm_template.uid, "xpu", op_name="baddbmm")
 class XPUAddmmTemplateConfigHeuristic(AddMMConfigMixin, XPUMMTemplateConfigHeuristic):
     """Addmm specific mixin for XPU"""
 
 
 @register_template_heuristic(
-    "mm_persistent_tma",
+    persistent_tma_mm_template.uid,
     "xpu",
 )
 class XPUPersistentTMATemplateConfigHeuristic(
@@ -2103,14 +2131,14 @@ class XPUPersistentTMATemplateConfigHeuristic(
         self.mm_configs = self.persistent_mm_configs
 
 
-@register_template_heuristic("mm_persistent_tma", "xpu", op_name="addmm")
+@register_template_heuristic(persistent_tma_mm_template.uid, "xpu", op_name="addmm")
 class XPUAddmmPersistentTMATemplateConfigHeuristic(
     AddMMConfigMixin, XPUPersistentTMATemplateConfigHeuristic
 ):
     """Addmm specific mixin for XPU"""
 
 
-@register_template_heuristic("mm", "xpu", op_name="scaled_mm")
+@register_template_heuristic(mm_template.uid, "xpu", op_name="scaled_mm")
 class XPUScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, XPUConfigHeuristic):
     """Scaled MM template heuristic for XPU (non-TMA)"""
@@ -2125,7 +2153,7 @@ class XPUScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, XPUConfigHeuristic
         self.exhaustive_configs = self.scaled_mm_configs
 
 
-@register_template_heuristic("mm", "xpu", op_name="int_mm")
+@register_template_heuristic(mm_template.uid, "xpu", op_name="int_mm")
 class XPUInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, XPUConfigHeuristic):
     """Int8 MM template heuristic for XPU"""
@@ -2140,7 +2168,7 @@ class XPUInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, XPUConfigHeuri
         self.exhaustive_configs = self.int8_mm_configs
 
 
-@register_template_heuristic("mm_plus_mm", "xpu")
+@register_template_heuristic(mm_plus_mm_template.uid, "xpu")
 class XPUMMPlusMMTemplateConfigHeuristic(
     MMPlusMMTemplateConfigMixin, XPUConfigHeuristic
 ):
@@ -2160,19 +2188,19 @@ class XPUMMPlusMMTemplateConfigHeuristic(
 # MTIA template-specific classes
 
 
-@register_template_heuristic("mm", "mtia")
-@register_template_heuristic("bmm", "mtia")
+@register_template_heuristic(mm_template.uid, "mtia")
+@register_template_heuristic(bmm_template.uid, "mtia")
 class MTIAMMTemplateConfigHeuristic(MMTemplateConfigMixin, MTIAConfigHeuristic):
     """Standard MM template heuristic for MTIA"""
 
 
-@register_template_heuristic("mm", "mtia", op_name="addmm")
-@register_template_heuristic("bmm", "mtia", op_name="baddbmm")
+@register_template_heuristic(mm_template.uid, "mtia", op_name="addmm")
+@register_template_heuristic(bmm_template.uid, "mtia", op_name="baddbmm")
 class MTIAAddMMTemplateConfigHeuristic(AddMMConfigMixin, MTIAMMTemplateConfigHeuristic):
     """Addmm specific mixin for MTIA"""
 
 
-@register_template_heuristic("mm", "mtia", op_name="scaled_mm")
+@register_template_heuristic(mm_template.uid, "mtia", op_name="scaled_mm")
 class MTIAScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, MTIAConfigHeuristic):
     """Scaled MM template heuristic for MTIA (non-TMA)"""
@@ -2187,7 +2215,7 @@ class MTIAScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, MTIAConfigHeurist
         self.exhaustive_configs = self.scaled_mm_configs
 
 
-@register_template_heuristic("mm", "mtia", op_name="int_mm")
+@register_template_heuristic(mm_template.uid, "mtia", op_name="int_mm")
 class MTIAInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, MTIAConfigHeuristic):
     """Int8 MM template heuristic for MTIA"""
@@ -2202,7 +2230,7 @@ class MTIAInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, MTIAConfigHeu
         self.exhaustive_configs = self.int8_mm_configs
 
 
-@register_template_heuristic("mm_plus_mm", "mtia")
+@register_template_heuristic(mm_plus_mm_template.uid, "mtia")
 class MTIAMMPlusMMTemplateConfigHeuristic(
     MMPlusMMTemplateConfigMixin, MTIAConfigHeuristic
 ):