mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor][ez] add template/externchoice uid (#161341)
# why - to have a central registry of templates/externkernelchoice to match them to heuristics etc, they need unique names - mm is both the triton template name and the aten_mm name # what - add a uid() to KernelTemplate/ExternKernelChoice that returns name - override in ExternKernelChoice to prepend "aten::" - override in TritonTemplate to prepend "triton::" This id is just used to find template heuristics, so it has no other impact # testing ``` python3 -bb -m pytest test/inductor/test_max_autotune.py -v ``` Differential Revision: [D81520579](https://our.internmc.facebook.com/intern/diff/D81520579) Pull Request resolved: https://github.com/pytorch/pytorch/pull/161341 Approved by: https://github.com/jansel, https://github.com/eellison ghstack dependencies: #162075, #161340
This commit is contained in:
committed by
PyTorch MergeBot
parent
9602590b15
commit
4902c76c65
@ -2391,6 +2391,17 @@ class KernelTemplate:
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
|
||||
@property
|
||||
def uid(self) -> str:
|
||||
"""
|
||||
entry point to override for templates to ensure a uid e.g. through a prefix
|
||||
|
||||
the purpose of this is that every KernelTemplate/ExternKernelChoice is unique
|
||||
in the system, but reproducible e.g. restarting pytorch should yield the same id
|
||||
"""
|
||||
# TODO(coconutruben): add some central registration to assert on global uniqueness
|
||||
return self.name
|
||||
|
||||
def maybe_append_choice(
|
||||
self, choices: list[Any], **kwargs: Any
|
||||
) -> Optional[NotImplementedError]:
|
||||
|
||||
@ -205,7 +205,7 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
|
||||
assert out_dtype is None, "out_dtype is not supported for Triton"
|
||||
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout, bmm_template.name, name
|
||||
kernel_inputs, layout, bmm_template.uid, name
|
||||
):
|
||||
bmm_template.maybe_append_choice(
|
||||
choices,
|
||||
@ -284,7 +284,7 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs,
|
||||
layout,
|
||||
bmm_template.name,
|
||||
bmm_template.uid,
|
||||
name,
|
||||
):
|
||||
bmm_template.maybe_append_choice(
|
||||
|
||||
@ -762,7 +762,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
|
||||
if is_nonzero and use_triton_template(layout):
|
||||
# Get template params using the new unified function
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout, mm_template.name, "mm"
|
||||
kernel_inputs, layout, mm_template.uid, "mm"
|
||||
):
|
||||
mm_template.maybe_append_choice(
|
||||
choices,
|
||||
@ -773,7 +773,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
|
||||
if use_triton_tma_template(mat1, mat2):
|
||||
# Get TMA template params using the new unified function
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout, persistent_tma_mm_template.name, "mm"
|
||||
kernel_inputs, layout, persistent_tma_mm_template.uid, "mm"
|
||||
):
|
||||
persistent_tma_mm_template.maybe_append_choice(
|
||||
choices,
|
||||
@ -784,7 +784,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
|
||||
# Only do split-k optimization if K is much larger than m, n and m, n are small
|
||||
if use_decompose_k_choice(m, n, k):
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout, decompose_k_subgraph_template.name, "mm"
|
||||
kernel_inputs, layout, decompose_k_subgraph_template.uid, "mm"
|
||||
):
|
||||
decompose_k_subgraph_template.maybe_append_choice(
|
||||
choices,
|
||||
@ -931,7 +931,7 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
|
||||
|
||||
if is_nonzero and use_triton_template(layout, enable_int32=True):
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout, mm_template.name, "int_mm"
|
||||
kernel_inputs, layout, mm_template.uid, "int_mm"
|
||||
):
|
||||
mm_template.maybe_append_choice(
|
||||
choices,
|
||||
@ -1038,7 +1038,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs,
|
||||
layout,
|
||||
mm_template.name,
|
||||
mm_template.uid,
|
||||
"addmm",
|
||||
):
|
||||
mm_template.maybe_append_choice(
|
||||
@ -1052,7 +1052,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs,
|
||||
layout,
|
||||
persistent_tma_mm_template.name,
|
||||
persistent_tma_mm_template.uid,
|
||||
"addmm",
|
||||
):
|
||||
persistent_tma_mm_template.maybe_append_choice(
|
||||
@ -1246,7 +1246,7 @@ def tuned_scaled_mm(
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs,
|
||||
layout,
|
||||
scaled_mm_device_tma_template.name,
|
||||
scaled_mm_device_tma_template.uid,
|
||||
"scaled_mm",
|
||||
overriders,
|
||||
):
|
||||
@ -1260,7 +1260,7 @@ def tuned_scaled_mm(
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs,
|
||||
layout,
|
||||
mm_template.name,
|
||||
mm_template.uid,
|
||||
"scaled_mm",
|
||||
overriders,
|
||||
):
|
||||
|
||||
@ -159,7 +159,7 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
|
||||
if use_triton_template(layout1):
|
||||
# Get template params using the new unified function
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout1, mm_plus_mm_template.name, "mm_plus_mm"
|
||||
kernel_inputs, layout1, mm_plus_mm_template.uid, "mm_plus_mm"
|
||||
):
|
||||
# Apply BLOCK_K constraint specific to mm_plus_mm
|
||||
# see https://github.com/triton-lang/triton/issues/1298
|
||||
|
||||
@ -1443,6 +1443,11 @@ class TritonTemplate(KernelTemplate):
|
||||
# was not used are the same.
|
||||
test_cache = False
|
||||
|
||||
@property
|
||||
def uid(self) -> str:
|
||||
# unique by prefixing with triton
|
||||
return f"triton::{self.name}"
|
||||
|
||||
def maybe_append_choice(
|
||||
self, choices: list[Any], **kwargs: Any
|
||||
) -> Optional[NotImplementedError]:
|
||||
@ -1909,6 +1914,11 @@ class ExternKernelChoice:
|
||||
self, input_nodes, layout, kwargs, has_out_variant=self.has_out_variant
|
||||
)
|
||||
|
||||
@property
|
||||
def uid(self) -> str:
|
||||
# unique by prefixing with aten
|
||||
return f"aten::{self.name}"
|
||||
|
||||
def maybe_append_choice(
|
||||
self, choices: list[Any], **kwargs: Any
|
||||
) -> Optional[NotImplementedError]:
|
||||
|
||||
@ -5,6 +5,10 @@ from typing import Any, TYPE_CHECKING
|
||||
import torch
|
||||
|
||||
from ..ir import get_free_symbols
|
||||
from ..kernel.mm import (
|
||||
addmm_contiguous_subgraph_template,
|
||||
mm_contiguous_subgraph_template,
|
||||
)
|
||||
from ..kernel_inputs import KernelInputs, MMKernelInputs
|
||||
from ..utils import use_contiguous
|
||||
from .base import TemplateConfigHeuristics
|
||||
@ -17,17 +21,25 @@ if TYPE_CHECKING:
|
||||
from ..ir import Layout
|
||||
|
||||
|
||||
@register_template_heuristic("contiguous_mm", None, op_name="mm")
|
||||
@register_template_heuristic("contiguous_addmm", None, op_name="addmm")
|
||||
@register_template_heuristic(mm_contiguous_subgraph_template.uid, None, op_name="mm")
|
||||
@register_template_heuristic(
|
||||
addmm_contiguous_subgraph_template.uid, None, op_name="addmm"
|
||||
)
|
||||
class EmptyContiguousMMConfigHeuristics(TemplateConfigHeuristics):
|
||||
"""empty heuristics to skip contiguous mm on not cuda"""
|
||||
|
||||
|
||||
@register_template_heuristic(
|
||||
"contiguous_mm", "cuda", register=torch.version.hip is not None, op_name="mm"
|
||||
mm_contiguous_subgraph_template.uid,
|
||||
"cuda",
|
||||
register=torch.version.hip is not None,
|
||||
op_name="mm",
|
||||
)
|
||||
@register_template_heuristic(
|
||||
"contiguous_addmm", "cuda", register=torch.version.hip is not None, op_name="addmm"
|
||||
addmm_contiguous_subgraph_template.uid,
|
||||
"cuda",
|
||||
register=torch.version.hip is not None,
|
||||
op_name="addmm",
|
||||
)
|
||||
class ContiguousMMHeuristics(TemplateConfigHeuristics):
|
||||
def get_template_configs(
|
||||
|
||||
@ -7,6 +7,7 @@ import sympy
|
||||
import torch
|
||||
|
||||
from ..ir import get_free_symbols
|
||||
from ..kernel.mm import decompose_k_subgraph_template
|
||||
from ..kernel_inputs import KernelInputs, MMKernelInputs
|
||||
from ..utils import get_k_splits
|
||||
from ..virtualized import V
|
||||
@ -20,14 +21,17 @@ if TYPE_CHECKING:
|
||||
from ..ir import Layout
|
||||
|
||||
|
||||
@register_template_heuristic("decompose_k", None, op_name="mm")
|
||||
@register_template_heuristic(decompose_k_subgraph_template.uid, None, op_name="mm")
|
||||
class EmptyDecomposeKConfigHeuristics(TemplateConfigHeuristics):
|
||||
"""empty heuristics to skip decompose k on anything not cuda"""
|
||||
|
||||
|
||||
# on CUDA, we don't support hip for decompose_k yet
|
||||
@register_template_heuristic(
|
||||
"decompose_k", "cuda", register=torch.version.hip is None, op_name="mm"
|
||||
decompose_k_subgraph_template.uid,
|
||||
"cuda",
|
||||
register=torch.version.hip is None,
|
||||
op_name="mm",
|
||||
)
|
||||
# TODO(coconutruben): enable decompose k on AMD by removing the register bool
|
||||
# and benchmarking it for performance and stability
|
||||
|
||||
@ -16,6 +16,13 @@ from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._triton import has_triton_stable_tma_api
|
||||
|
||||
from .. import config, config as inductor_config
|
||||
from ..kernel.bmm import bmm_template
|
||||
from ..kernel.mm import (
|
||||
mm_template,
|
||||
persistent_tma_mm_template,
|
||||
scaled_mm_device_tma_template,
|
||||
)
|
||||
from ..kernel.mm_plus_mm import mm_plus_mm_template
|
||||
from ..kernel_inputs import KernelInputs, MMKernelInputs
|
||||
from ..utils import (
|
||||
get_backend_num_stages,
|
||||
@ -1786,28 +1793,31 @@ class ScaledTMAConfigMixin(TMAWorkspaceMixin, BaseScaledMMConfigMixin):
|
||||
# Template-specific heuristic classes using multiple inheritance
|
||||
|
||||
|
||||
# TODO(coconutruben): replace with template.name once templates are importable
|
||||
@register_template_heuristic("mm", "cuda", register=torch.version.hip is None)
|
||||
# TODO(coconutruben): replace with template.name once templates are importable
|
||||
@register_template_heuristic("bmm", "cuda", register=torch.version.hip is None)
|
||||
@register_template_heuristic(
|
||||
mm_template.uid,
|
||||
"cuda",
|
||||
register=torch.version.hip is None,
|
||||
)
|
||||
@register_template_heuristic(
|
||||
bmm_template.uid,
|
||||
"cuda",
|
||||
register=torch.version.hip is None,
|
||||
)
|
||||
class CUDAMMTemplateConfigHeuristic(MMTemplateConfigMixin, CUDAConfigHeuristic):
|
||||
"""Standard MM template heuristic for CUDA"""
|
||||
|
||||
|
||||
# TODO(coconutruben): replace with template.name once templates are importable
|
||||
@register_template_heuristic(
|
||||
"mm", "cuda", register=torch.version.hip is None, op_name="addmm"
|
||||
mm_template.uid, "cuda", register=torch.version.hip is None, op_name="addmm"
|
||||
)
|
||||
# TODO(coconutruben): replace with template.name once templates are importable
|
||||
@register_template_heuristic(
|
||||
"bmm", "cuda", register=torch.version.hip is None, op_name="baddbmm"
|
||||
bmm_template.uid, "cuda", register=torch.version.hip is None, op_name="baddbmm"
|
||||
)
|
||||
class CUDAAddMMTemplateConfigHeuristic(AddMMConfigMixin, CUDAMMTemplateConfigHeuristic):
|
||||
"""Addmm specific mixin for CUDA"""
|
||||
|
||||
|
||||
# TODO(coconutruben): deprecate once autoheuristic is deprecated
|
||||
# TODO(coconutruben): replace with template.name once templates are importable
|
||||
@register_template_heuristic("mm-ah", "cuda", register=torch.version.hip is None)
|
||||
class CUDAMMAHTemplateConfigHeuristic(MMTemplateConfigMixin, CUDAConfigHeuristic):
|
||||
"""Standard MM template heuristic for CUDA using the extra mm configs only (for autoheuristic)"""
|
||||
@ -1819,9 +1829,10 @@ class CUDAMMAHTemplateConfigHeuristic(MMTemplateConfigMixin, CUDAConfigHeuristic
|
||||
self.exhaustive_configs = self.extra_mm_configs
|
||||
|
||||
|
||||
# TODO(coconutruben): replace with template.name once templates are importable
|
||||
@register_template_heuristic(
|
||||
"mm_persistent_tma", "cuda", register=torch.version.hip is None
|
||||
persistent_tma_mm_template.uid,
|
||||
"cuda",
|
||||
register=torch.version.hip is None,
|
||||
)
|
||||
class CUDAPersistentTMATemplateConfigHeuristic(
|
||||
TMATemplateConfigMixin, CUDAConfigHeuristic
|
||||
@ -1834,9 +1845,11 @@ class CUDAPersistentTMATemplateConfigHeuristic(
|
||||
self.mm_configs = self.persistent_mm_configs
|
||||
|
||||
|
||||
# TODO(coconutruben): replace with template.name once templates are importable
|
||||
@register_template_heuristic(
|
||||
"mm_persistent_tma", "cuda", register=torch.version.hip is None, op_name="addmm"
|
||||
persistent_tma_mm_template.uid,
|
||||
"cuda",
|
||||
register=torch.version.hip is None,
|
||||
op_name="addmm",
|
||||
)
|
||||
class CUDAAddmmPersistentTMATemplateConfigHeuristic(
|
||||
AddMMConfigMixin, CUDAPersistentTMATemplateConfigHeuristic
|
||||
@ -1844,9 +1857,8 @@ class CUDAAddmmPersistentTMATemplateConfigHeuristic(
|
||||
"""Addmm specific mixin for CUDA"""
|
||||
|
||||
|
||||
# TODO(coconutruben): replace with template.name once templates are importable
|
||||
@register_template_heuristic(
|
||||
"mm", "cuda", register=torch.version.hip is None, op_name="scaled_mm"
|
||||
mm_template.uid, "cuda", register=torch.version.hip is None, op_name="scaled_mm"
|
||||
)
|
||||
class CUDAScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, CUDAConfigHeuristic):
|
||||
"""Scaled MM template heuristic for CUDA"""
|
||||
@ -1862,9 +1874,10 @@ class CUDAScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, CUDAConfigHeurist
|
||||
self.exhaustive_configs = self.scaled_mm_configs
|
||||
|
||||
|
||||
# TODO(coconutruben): replace with template.name once templates are importable
|
||||
@register_template_heuristic(
|
||||
"scaled_mm_device_tma", "cuda", register=torch.version.hip is None
|
||||
scaled_mm_device_tma_template.uid,
|
||||
"cuda",
|
||||
register=torch.version.hip is None,
|
||||
)
|
||||
class CUDAScaledTMATemplateConfigHeuristic(ScaledTMAConfigMixin, CUDAConfigHeuristic):
|
||||
"""Scaled TMA template heuristic for CUDA"""
|
||||
@ -1880,8 +1893,11 @@ class CUDAScaledTMATemplateConfigHeuristic(ScaledTMAConfigMixin, CUDAConfigHeuri
|
||||
self.exhaustive_configs = self.scaled_persistent_mm_configs
|
||||
|
||||
|
||||
# TODO(coconutruben): replace with template.name once templates are importable
|
||||
@register_template_heuristic("mm_plus_mm", "cuda", register=torch.version.hip is None)
|
||||
@register_template_heuristic(
|
||||
mm_plus_mm_template.uid,
|
||||
"cuda",
|
||||
register=torch.version.hip is None,
|
||||
)
|
||||
class CUDAMMPlusMMTemplateConfigHeuristic(
|
||||
MMPlusMMTemplateConfigMixin, CUDAConfigHeuristic
|
||||
):
|
||||
@ -1898,9 +1914,11 @@ class CUDAMMPlusMMTemplateConfigHeuristic(
|
||||
self.exhaustive_configs = self.mm_plus_mm_configs
|
||||
|
||||
|
||||
# TODO(coconutruben): replace with template.name once templates are importable
|
||||
@register_template_heuristic(
|
||||
"mm", "cuda", register=torch.version.hip is None, op_name="int_mm"
|
||||
mm_template.uid,
|
||||
"cuda",
|
||||
register=torch.version.hip is None,
|
||||
op_name="int_mm",
|
||||
)
|
||||
class CUDAInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, CUDAConfigHeuristic):
|
||||
"""Int8 MM template heuristic for CUDA"""
|
||||
@ -1919,28 +1937,33 @@ class CUDAInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, CUDAConfigHeu
|
||||
# ROCm template-specific classes
|
||||
|
||||
|
||||
# TODO(coconutruben): replace with template.name once templates are importable
|
||||
@register_template_heuristic("mm", "cuda", register=torch.version.hip is not None)
|
||||
# TODO(coconutruben): replace with template.name once templates are importable
|
||||
@register_template_heuristic("bmm", "cuda", register=torch.version.hip is not None)
|
||||
@register_template_heuristic(
|
||||
mm_template.uid,
|
||||
"cuda",
|
||||
register=torch.version.hip is not None,
|
||||
)
|
||||
@register_template_heuristic(
|
||||
bmm_template.uid,
|
||||
"cuda",
|
||||
register=torch.version.hip is not None,
|
||||
)
|
||||
class ROCmMMTemplateConfigHeuristic(MMTemplateConfigMixin, ROCmConfigHeuristic):
|
||||
"""Standard MM template heuristic for ROCm"""
|
||||
|
||||
|
||||
# TODO(coconutruben): replace with template.name once templates are importable
|
||||
@register_template_heuristic(
|
||||
"mm", "cuda", register=torch.version.hip is not None, op_name="addmm"
|
||||
mm_template.uid, "cuda", register=torch.version.hip is not None, op_name="addmm"
|
||||
)
|
||||
# TODO(coconutruben): replace with template.name once templates are importable
|
||||
@register_template_heuristic(
|
||||
"bmm", "cuda", register=torch.version.hip is not None, op_name="baddbmm"
|
||||
bmm_template.uid, "cuda", register=torch.version.hip is not None, op_name="baddbmm"
|
||||
)
|
||||
class ROCmAddMMTemplateConfigHeuristic(AddMMConfigMixin, ROCmMMTemplateConfigHeuristic):
|
||||
"""Addmm specific mixin for ROCm"""
|
||||
|
||||
|
||||
# TODO(coconutruben): deprecate once autoheuristic is deprecated
|
||||
# TODO(coconutruben): replace with template.name once templates are importable
|
||||
@register_template_heuristic("mm-ah", "cuda", register=torch.version.hip is not None)
|
||||
class ROCmMMAHTemplateConfigHeuristic(MMTemplateConfigMixin, ROCmConfigHeuristic):
|
||||
"""Standard MM template heuristic for ROCm using the extra mm configs only (for autoheuristic)"""
|
||||
@ -1952,9 +1975,11 @@ class ROCmMMAHTemplateConfigHeuristic(MMTemplateConfigMixin, ROCmConfigHeuristic
|
||||
self.exhaustive_configs = self.extra_mm_configs
|
||||
|
||||
|
||||
# TODO(coconutruben): replace with template.name once templates are importable
|
||||
@register_template_heuristic(
|
||||
"mm", "cuda", register=torch.version.hip is not None, op_name="scaled_mm"
|
||||
mm_template.uid,
|
||||
"cuda",
|
||||
register=torch.version.hip is not None,
|
||||
op_name="scaled_mm",
|
||||
)
|
||||
class ROCmScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, ROCmConfigHeuristic):
|
||||
"""Scaled MM template heuristic for ROCm (non-TMA)"""
|
||||
@ -1970,9 +1995,11 @@ class ROCmScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, ROCmConfigHeurist
|
||||
self.exhaustive_configs = self.scaled_mm_configs
|
||||
|
||||
|
||||
# TODO(coconutruben): replace with template.name once templates are importable
|
||||
@register_template_heuristic(
|
||||
"mm", "cuda", register=torch.version.hip is not None, op_name="int_mm"
|
||||
mm_template.uid,
|
||||
"cuda",
|
||||
register=torch.version.hip is not None,
|
||||
op_name="int_mm",
|
||||
)
|
||||
class ROCmInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, ROCmConfigHeuristic):
|
||||
"""Int8 MM template heuristic for ROCm"""
|
||||
@ -1988,9 +2015,10 @@ class ROCmInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, ROCmConfigHeu
|
||||
self.exhaustive_configs = self.int8_mm_configs
|
||||
|
||||
|
||||
# TODO(coconutruben): replace with template.name once templates are importable
|
||||
@register_template_heuristic(
|
||||
"mm_plus_mm", "cuda", register=torch.version.hip is not None
|
||||
mm_plus_mm_template.uid,
|
||||
"cuda",
|
||||
register=torch.version.hip is not None,
|
||||
)
|
||||
class ROCmMMPlusMMTemplateConfigHeuristic(
|
||||
MMPlusMMTemplateConfigMixin, ROCmConfigHeuristic
|
||||
@ -2014,19 +2042,19 @@ class ROCmMMPlusMMTemplateConfigHeuristic(
|
||||
# CPU template-specific classes
|
||||
|
||||
|
||||
@register_template_heuristic("mm", "cpu")
|
||||
@register_template_heuristic("bmm", "cpu")
|
||||
@register_template_heuristic(mm_template.uid, "cpu")
|
||||
@register_template_heuristic(bmm_template.uid, "cpu")
|
||||
class CPUMMTemplateConfigHeuristic(MMTemplateConfigMixin, CPUConfigHeuristic):
|
||||
"""Standard MM template heuristic for CPU"""
|
||||
|
||||
|
||||
@register_template_heuristic("mm", "cpu", op_name="addmm")
|
||||
@register_template_heuristic("bmm", "cpu", op_name="baddbmm")
|
||||
@register_template_heuristic(mm_template.uid, "cpu", op_name="addmm")
|
||||
@register_template_heuristic(bmm_template.uid, "cpu", op_name="baddbmm")
|
||||
class CPUAddmmTemplateConfigHeuristic(AddMMConfigMixin, CPUMMTemplateConfigHeuristic):
|
||||
"""Addmm specific mixin for CPU"""
|
||||
|
||||
|
||||
@register_template_heuristic("mm", "cpu", op_name="scaled_mm")
|
||||
@register_template_heuristic(mm_template.uid, "cpu", op_name="scaled_mm")
|
||||
class CPUScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, CPUConfigHeuristic):
|
||||
"""Scaled MM template heuristic for CPU (non-TMA)"""
|
||||
|
||||
@ -2041,7 +2069,7 @@ class CPUScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, CPUConfigHeuristic
|
||||
self.exhaustive_configs = self.scaled_mm_configs
|
||||
|
||||
|
||||
@register_template_heuristic("mm", "cpu", op_name="int_mm")
|
||||
@register_template_heuristic(mm_template.uid, "cpu", op_name="int_mm")
|
||||
class CPUInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, CPUConfigHeuristic):
|
||||
"""Int8 MM template heuristic for CPU"""
|
||||
|
||||
@ -2056,7 +2084,7 @@ class CPUInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, CPUConfigHeuri
|
||||
self.exhaustive_configs = self.int8_mm_configs
|
||||
|
||||
|
||||
@register_template_heuristic("mm_plus_mm", "cpu")
|
||||
@register_template_heuristic(mm_plus_mm_template.uid, "cpu")
|
||||
class CPUMMPlusMMTemplateConfigHeuristic(
|
||||
MMPlusMMTemplateConfigMixin, CPUConfigHeuristic
|
||||
):
|
||||
@ -2076,20 +2104,20 @@ class CPUMMPlusMMTemplateConfigHeuristic(
|
||||
# XPU template-specific classes
|
||||
|
||||
|
||||
@register_template_heuristic("mm", "xpu")
|
||||
@register_template_heuristic("bmm", "xpu")
|
||||
@register_template_heuristic(mm_template.uid, "xpu")
|
||||
@register_template_heuristic(bmm_template.uid, "xpu")
|
||||
class XPUMMTemplateConfigHeuristic(MMTemplateConfigMixin, XPUConfigHeuristic):
|
||||
"""Standard MM template heuristic for XPU"""
|
||||
|
||||
|
||||
@register_template_heuristic("mm", "xpu", op_name="addmm")
|
||||
@register_template_heuristic("bmm", "xpu", op_name="baddbmm")
|
||||
@register_template_heuristic(mm_template.uid, "xpu", op_name="addmm")
|
||||
@register_template_heuristic(bmm_template.uid, "xpu", op_name="baddbmm")
|
||||
class XPUAddmmTemplateConfigHeuristic(AddMMConfigMixin, XPUMMTemplateConfigHeuristic):
|
||||
"""Addmm specific mixin for XPU"""
|
||||
|
||||
|
||||
@register_template_heuristic(
|
||||
"mm_persistent_tma",
|
||||
persistent_tma_mm_template.uid,
|
||||
"xpu",
|
||||
)
|
||||
class XPUPersistentTMATemplateConfigHeuristic(
|
||||
@ -2103,14 +2131,14 @@ class XPUPersistentTMATemplateConfigHeuristic(
|
||||
self.mm_configs = self.persistent_mm_configs
|
||||
|
||||
|
||||
@register_template_heuristic("mm_persistent_tma", "xpu", op_name="addmm")
|
||||
@register_template_heuristic(persistent_tma_mm_template.uid, "xpu", op_name="addmm")
|
||||
class XPUAddmmPersistentTMATemplateConfigHeuristic(
|
||||
AddMMConfigMixin, XPUPersistentTMATemplateConfigHeuristic
|
||||
):
|
||||
"""Addmm specific mixin for XPU"""
|
||||
|
||||
|
||||
@register_template_heuristic("mm", "xpu", op_name="scaled_mm")
|
||||
@register_template_heuristic(mm_template.uid, "xpu", op_name="scaled_mm")
|
||||
class XPUScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, XPUConfigHeuristic):
|
||||
"""Scaled MM template heuristic for XPU (non-TMA)"""
|
||||
|
||||
@ -2125,7 +2153,7 @@ class XPUScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, XPUConfigHeuristic
|
||||
self.exhaustive_configs = self.scaled_mm_configs
|
||||
|
||||
|
||||
@register_template_heuristic("mm", "xpu", op_name="int_mm")
|
||||
@register_template_heuristic(mm_template.uid, "xpu", op_name="int_mm")
|
||||
class XPUInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, XPUConfigHeuristic):
|
||||
"""Int8 MM template heuristic for XPU"""
|
||||
|
||||
@ -2140,7 +2168,7 @@ class XPUInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, XPUConfigHeuri
|
||||
self.exhaustive_configs = self.int8_mm_configs
|
||||
|
||||
|
||||
@register_template_heuristic("mm_plus_mm", "xpu")
|
||||
@register_template_heuristic(mm_plus_mm_template.uid, "xpu")
|
||||
class XPUMMPlusMMTemplateConfigHeuristic(
|
||||
MMPlusMMTemplateConfigMixin, XPUConfigHeuristic
|
||||
):
|
||||
@ -2160,19 +2188,19 @@ class XPUMMPlusMMTemplateConfigHeuristic(
|
||||
# MTIA template-specific classes
|
||||
|
||||
|
||||
@register_template_heuristic("mm", "mtia")
|
||||
@register_template_heuristic("bmm", "mtia")
|
||||
@register_template_heuristic(mm_template.uid, "mtia")
|
||||
@register_template_heuristic(bmm_template.uid, "mtia")
|
||||
class MTIAMMTemplateConfigHeuristic(MMTemplateConfigMixin, MTIAConfigHeuristic):
|
||||
"""Standard MM template heuristic for MTIA"""
|
||||
|
||||
|
||||
@register_template_heuristic("mm", "mtia", op_name="addmm")
|
||||
@register_template_heuristic("bmm", "mtia", op_name="baddbmm")
|
||||
@register_template_heuristic(mm_template.uid, "mtia", op_name="addmm")
|
||||
@register_template_heuristic(bmm_template.uid, "mtia", op_name="baddbmm")
|
||||
class MTIAAddMMTemplateConfigHeuristic(AddMMConfigMixin, MTIAMMTemplateConfigHeuristic):
|
||||
"""Addmm specific mixin for MTIA"""
|
||||
|
||||
|
||||
@register_template_heuristic("mm", "mtia", op_name="scaled_mm")
|
||||
@register_template_heuristic(mm_template.uid, "mtia", op_name="scaled_mm")
|
||||
class MTIAScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, MTIAConfigHeuristic):
|
||||
"""Scaled MM template heuristic for MTIA (non-TMA)"""
|
||||
|
||||
@ -2187,7 +2215,7 @@ class MTIAScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, MTIAConfigHeurist
|
||||
self.exhaustive_configs = self.scaled_mm_configs
|
||||
|
||||
|
||||
@register_template_heuristic("mm", "mtia", op_name="int_mm")
|
||||
@register_template_heuristic(mm_template.uid, "mtia", op_name="int_mm")
|
||||
class MTIAInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, MTIAConfigHeuristic):
|
||||
"""Int8 MM template heuristic for MTIA"""
|
||||
|
||||
@ -2202,7 +2230,7 @@ class MTIAInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, MTIAConfigHeu
|
||||
self.exhaustive_configs = self.int8_mm_configs
|
||||
|
||||
|
||||
@register_template_heuristic("mm_plus_mm", "mtia")
|
||||
@register_template_heuristic(mm_plus_mm_template.uid, "mtia")
|
||||
class MTIAMMPlusMMTemplateConfigHeuristic(
|
||||
MMPlusMMTemplateConfigMixin, MTIAConfigHeuristic
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user