Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[inductor][template heuristics] don't take layout to generate choices (#162238)
# why

- unnecessary, as we only ever need to know the dtype and maybe the device
- we already take in the kernel inputs, which have the device
- enables us to specify the layout after finding all the configs, but before generating the ChoiceCallers

# what

- replace all calls in template_heuristics that used to take a Layout with calls that take just out_dtype

# testing

ci

Differential Revision: [D81820115](https://our.internmc.facebook.com/intern/diff/D81820115)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162238
Approved by: https://github.com/eellison
ghstack dependencies: #161347, #161348, #161349
Committed by: PyTorch MergeBot
Parent: 24a4dae85b
Commit: d91eecc9a5
@@ -163,10 +163,9 @@ class InductorChoices:
         heuristic = get_template_heuristic(template_name, device_type, op_name)

         cs = heuristic.get_template_configs(
             kernel_inputs,
-            layout,
             op_name,
         )
-        extra_kwargs = heuristic.get_extra_kwargs(kernel_inputs, layout, op_name)
+        extra_kwargs = heuristic.get_extra_kwargs(kernel_inputs, op_name)
         # adjust the kernel inputs to the template-specific heuristic, if needed
         # default here is to just return the kernel_inputs as is
         inputs_val = heuristic.adjust_kernel_inputs(kernel_inputs, op_name)
@@ -184,9 +183,9 @@ class InductorChoices:
     def get_mm_configs(
         self,
         kernel_inputs: KernelInputs,
-        layout: Any,
         templates: list[Union[KernelTemplate, ExternKernelChoice]],
         op_name: str,
+        layout: Optional[Layout] = None,
         kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None,
     ) -> list[ChoiceCaller]:
         """
@@ -207,7 +206,11 @@ class InductorChoices:
         input_tensors = kernel_inputs.nodes()
         if len(input_tensors) < 2:
             raise ValueError(f"Need at least 2 input tensors, got {len(input_tensors)}")
+        if layout is None:
+            # TODO(coconutruben): remove this once we remove the layout argument entirely
+            # This is just here to bridge the brief gap between commits where we still need this
+            # to accommodate the fixed vs flexible layout decision externally
+            layout = kernel_inputs.output_layout(flexible=False)
         # First pass: Create dict of template.uid to generator of KernelTemplateChoice objects
         template_choices = {}
         for template in templates:
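The hunks above are the core of the change: `get_mm_configs` now receives only the kernel inputs, the templates, and the op name, and falls back to `kernel_inputs.output_layout(flexible=False)` only when no layout is passed. Below is a minimal, self-contained sketch of that ordering (configs are generated from shapes and dtype alone, the layout is resolved at the end); `SimpleInputs`, `heuristic_configs`, and `get_mm_choices` are illustrative stand-ins, not Inductor APIs.

```python
from dataclasses import dataclass
from typing import Iterator, Optional

import torch


@dataclass
class SimpleInputs:
    """Stand-in for KernelInputs: knows its problem shape, dtype, and optional out_dtype."""
    m: int
    n: int
    k: int
    dtype: torch.dtype
    out_dtype: Optional[torch.dtype] = None

    def resolved_out_dtype(self) -> torch.dtype:
        # mirrors the new MMKernelInputs.out_dtype(): explicit value wins, else infer from mat1
        return self.out_dtype if self.out_dtype is not None else self.dtype

    def output_layout(self, flexible: bool = False) -> tuple:
        # the layout is only built on demand, after the configs have been generated
        kind = "flexible" if flexible else "fixed"
        return (kind, self.resolved_out_dtype(), (self.m, self.n))


def heuristic_configs(inputs: SimpleInputs, op_name: str) -> Iterator[dict]:
    # template heuristics only need shapes and dtype, not a Layout object
    for block_m, block_n in [(64, 64), (128, 64)]:
        yield {
            "BLOCK_M": block_m,
            "BLOCK_N": block_n,
            "ACC_TYPE": str(inputs.resolved_out_dtype()),
            "op": op_name,
        }


def get_mm_choices(inputs: SimpleInputs, op_name: str, layout=None) -> list:
    configs = list(heuristic_configs(inputs, op_name))  # 1) gather configs, layout-free
    if layout is None:
        layout = inputs.output_layout(flexible=False)   # 2) resolve the layout afterwards
    return [{"layout": layout, **cfg} for cfg in configs]  # 3) materialize the choices


if __name__ == "__main__":
    inp = SimpleInputs(m=256, n=128, k=512, dtype=torch.float16)
    for choice in get_mm_choices(inp, "mm"):
        print(choice)
```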
@@ -173,7 +173,7 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
     name = "bmm"

     # Create MMKernelInputs for BMM at the top
-    kernel_inputs = MMKernelInputs([mat1, mat2])
+    kernel_inputs = MMKernelInputs([mat1, mat2], out_dtype=out_dtype)

     # below is for getting an overview logging info of inductor mms
     batch_size = mat1.get_size()[0]  # Extract batch dimension
@@ -201,10 +201,9 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
         choices.extend(
             V.choices.get_mm_configs(
                 kernel_inputs,
-                layout,
                 [aten_handler],
                 name,
-                {aten_handler.uid: aten_extra_kwargs},
+                kwarg_overrides={aten_handler.uid: aten_extra_kwargs},
             )
         )

@@ -212,9 +211,7 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
         # TODO: add out_dtype support for Triton Template
         assert out_dtype is None, "out_dtype is not supported for Triton"

-        choices.extend(
-            V.choices.get_mm_configs(kernel_inputs, layout, [bmm_template], name)
-        )
+        choices.extend(V.choices.get_mm_configs(kernel_inputs, [bmm_template], name))

     _, is_nonzero = _is_static_problem(layout)
     batch_stride_largest_or_zero = is_batch_stride_largest_or_zero(mat1, mat2, layout)
     if (
@@ -275,15 +272,12 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
     # options to tune from
     choices: list[ChoiceCaller] = []
     if use_aten_gemm_kernels():
-        choices.extend(
-            V.choices.get_mm_configs(kernel_inputs, layout, [aten_baddbmm], name)
-        )
+        choices.extend(V.choices.get_mm_configs(kernel_inputs, [aten_baddbmm], name))

     if use_triton_template(layout, check_max_autotune=False):
         choices.extend(
             V.choices.get_mm_configs(
                 kernel_inputs,
-                layout,
                 [bmm_template],
                 name,
             )
@@ -753,32 +753,30 @@ def tuned_mm(mat1, mat2, *, layout=None):
     choices: list[ChoiceCaller] = []
     if use_aten_gemm_kernels():
         choices.extend(
-            V.choices.get_mm_configs(kernel_inputs, aten_layout, [aten_mm], "mm")
+            V.choices.get_mm_configs(kernel_inputs, [aten_mm], "mm", aten_layout)
         )
     static_shape, is_nonzero = _is_static_problem(layout)

     if is_nonzero and use_triton_template(layout, check_max_autotune=False):
         # Get template choices using the new unified function
-        choices.extend(
-            V.choices.get_mm_configs(kernel_inputs, layout, [mm_template], "mm")
-        )
+        choices.extend(V.choices.get_mm_configs(kernel_inputs, [mm_template], "mm"))
         if use_triton_tma_template(mat1, mat2):
             # Get TMA template choices using the new unified function
             choices.extend(
                 V.choices.get_mm_configs(
-                    kernel_inputs, layout, [persistent_tma_mm_template], "mm"
+                    kernel_inputs, [persistent_tma_mm_template], "mm"
                 )
             )

         if use_decompose_k_choice(m, n, k):
             choices.extend(
                 V.choices.get_mm_configs(
-                    kernel_inputs, layout, [decompose_k_subgraph_template], "mm"
+                    kernel_inputs, [decompose_k_subgraph_template], "mm"
                 )
             )
         choices.extend(
             V.choices.get_mm_configs(
-                kernel_inputs, layout, [mm_contiguous_subgraph_template], "mm"
+                kernel_inputs, [mm_contiguous_subgraph_template], "mm"
             )
         )

@@ -820,7 +818,6 @@ def tuned_mm(mat1, mat2, *, layout=None):
                 # mm-extra is a hack to keep the ah functionality alive
                 # while we transition to the unified kwargs retrieval
                 kernel_inputs,
-                layout,
                 [mm_template],
                 "mm-ah",
             )
@@ -896,12 +893,11 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
     choices: list[ChoiceCaller] = []

     # Create MMKernelInputs for Int MM
-    kernel_inputs = MMKernelInputs([mat1, mat2])
+    kernel_inputs = MMKernelInputs([mat1, mat2], out_dtype=torch.int32)
     if use_aten_gemm_kernels():
         choices.extend(
             V.choices.get_mm_configs(
                 kernel_inputs,
-                layout,
                 [aten__int_mm],
                 name,
             )
@@ -915,9 +911,7 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
     if is_nonzero and use_triton_template(
         layout, enable_int32=True, check_max_autotune=False
     ):
-        choices.extend(
-            V.choices.get_mm_configs(kernel_inputs, layout, [mm_template], name)
-        )
+        choices.extend(V.choices.get_mm_configs(kernel_inputs, [mm_template], name))

     return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
@@ -969,9 +963,9 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
         choices.extend(
             V.choices.get_mm_configs(
                 kernel_inputs,
-                aten_layout,
                 [aten_addmm],
                 name,
+                aten_layout,
             )
         )
         return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
@@ -980,7 +974,6 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
         choices.extend(
             V.choices.get_mm_configs(
                 kernel_inputs,
-                aten_layout,
                 [aten_bias_addmm],
                 name,
             )
@@ -988,7 +981,6 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
         choices.extend(
             V.choices.get_mm_configs(
                 kernel_inputs,
-                aten_layout,
                 [aten_addmm],
                 name,
             )
@@ -1000,7 +992,6 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
         choices.extend(
             V.choices.get_mm_configs(
                 kernel_inputs,
-                layout,
                 [mm_template],
                 name,
             )
@@ -1011,7 +1002,6 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
         choices.extend(
             V.choices.get_mm_configs(
                 kernel_inputs,
-                layout,
                 [persistent_tma_mm_template],
                 name,
             )
@@ -1020,7 +1010,6 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
         choices.extend(
             V.choices.get_mm_configs(
                 kernel_inputs,
-                layout,
                 [addmm_contiguous_subgraph_template],
                 "addmm",
             )
@@ -1174,14 +1163,15 @@ def tuned_scaled_mm(
     input_nodes = [mat_a, mat_b, scale_a_real, scale_b_real, bias_real]

     # Create MMKernelInputs for Scaled MM (matrices are at indices 0, 1)
-    kernel_inputs = MMKernelInputs(input_nodes, mat1_idx=0, mat2_idx=1)
+    kernel_inputs = MMKernelInputs(
+        input_nodes, mat1_idx=0, mat2_idx=1, out_dtype=out_dtype
+    )

     choices: list[ChoiceCaller] = []
     if use_aten_gemm_kernels():
         choices.extend(
             V.choices.get_mm_configs(
                 kernel_inputs,
-                layout,
                 [aten__fp8_mm],
                 name,
                 kwarg_overrides={
@@ -1209,7 +1199,6 @@ def tuned_scaled_mm(
         choices.extend(
             V.choices.get_mm_configs(
                 kernel_inputs,
-                layout,
                 [scaled_mm_device_tma_template],
                 name,
                 kwarg_overrides={scaled_mm_device_tma_template.uid: overriders},
@@ -1220,7 +1209,6 @@ def tuned_scaled_mm(
         choices.extend(
             V.choices.get_mm_configs(
                 kernel_inputs,
-                layout,
                 [mm_template],
                 name,
                 kwarg_overrides={mm_template.uid: overriders},
@@ -157,17 +157,13 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
     choices: list[ChoiceCaller] = []
     if use_aten_gemm_kernels():
         choices.extend(
-            V.choices.get_mm_configs(
-                kernel_inputs, layout1, [aten_mm_plus_mm], "mm_plus_mm"
-            )
+            V.choices.get_mm_configs(kernel_inputs, [aten_mm_plus_mm], "mm_plus_mm")
         )

     if use_triton_template(layout1, check_max_autotune=False):
         # Get template choices using the new unified function
         choices.extend(
-            V.choices.get_mm_configs(
-                kernel_inputs, layout1, [mm_plus_mm_template], "mm_plus_mm"
-            )
+            V.choices.get_mm_configs(kernel_inputs, [mm_plus_mm_template], "mm_plus_mm")
         )

     return autotune_select_algorithm(
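Every lowering above follows the same pattern: build the `MMKernelInputs`, collect candidate `ChoiceCaller`s from each applicable template via `V.choices.get_mm_configs(kernel_inputs, [template], name)`, then hand the collected choices to `autotune_select_algorithm`. The toy below models that collect-then-benchmark flow with plain callables; `make_candidates` and `autotune_select` are invented names for illustration, not Inductor functions.

```python
import time
from typing import Callable

import torch


def make_candidates(name: str, fns: list) -> list:
    # analogous to choices.extend(V.choices.get_mm_configs(kernel_inputs, [template], name))
    return [(f"{name}/{fn.__name__}", fn) for fn in fns]


def autotune_select(candidates: list, *args):
    best = None
    for label, fn in candidates:
        start = time.perf_counter()
        fn(*args)  # run once; a real autotuner would warm up and average
        elapsed = time.perf_counter() - start
        if best is None or elapsed < best[0]:
            best = (elapsed, label, fn)
    return best[1], best[2]


def mm_eager(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a @ b


def mm_transposed(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return (b.t() @ a.t()).t()


a, b = torch.randn(256, 512), torch.randn(512, 128)
choices = []
choices.extend(make_candidates("mm", [mm_eager]))
choices.extend(make_candidates("mm", [mm_transposed]))
label, fn = autotune_select(choices, a, b)
print("selected:", label)
```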
@@ -1,5 +1,6 @@
 from __future__ import annotations

+from abc import ABC, abstractmethod
 from typing import Any, Optional, TYPE_CHECKING, Union

 import torch
@@ -7,6 +8,8 @@ import torch._inductor.config
 from torch._inductor import ir
 from torch._inductor.virtualized import V

+from .ir import FixedLayout, FlexibleLayout, Layout
+

 if TYPE_CHECKING:
     from collections.abc import Sequence
@@ -14,7 +17,7 @@ if TYPE_CHECKING:
     import sympy


-class KernelInputs:
+class KernelInputs(ABC):
     """
     Class to store and provide access to input nodes for kernels.

     This class takes in a tuple of input nodes and provides methods to access
@@ -25,16 +28,19 @@ class KernelInputs:
         self,
         input_nodes: list[Any],
         scalars: Optional[dict[str, Union[float, int]]] = None,
+        out_dtype: Optional[torch.dtype] = None,
     ):
         """
         Initialize with a tuple of input nodes.

         Args:
             input_nodes: A tuple of input nodes to store
+            out_dtype: Optional output dtype to store
         """
         self._input_nodes = input_nodes
         self._device_name: Optional[str] = None
         self._scalars = scalars if scalars is not None else {}
+        self._out_dtype = out_dtype
         assert len(input_nodes) > 0, "Expected at least one input node"

     def nodes(self, reorder: Optional[Sequence[int]] = None) -> list[Any]:
@@ -168,6 +174,15 @@ class KernelInputs:
         """
         return self._input_nodes[idx].get_dtype()

+    @abstractmethod
+    def out_dtype(self) -> torch.dtype:
+        """
+        Get the output dtype, whether passed in or inferred from the nodes
+
+        Returns:
+            The output dtype
+        """
+
     def get_scalar(self, name: str) -> Union[float, int]:
         """
         Get the scalar value for a given name.
@@ -181,6 +196,16 @@ class KernelInputs:
         assert name in self._scalars, f"Scalar {name} not found, but required"
         return self._scalars[name]

+    @abstractmethod
+    def output_layout(self, flexible: bool = True) -> Layout:
+        """
+        Abstract method to handle output layout generation.
+
+        Args:
+            out_dtype: Optional output dtype. If not provided, infer from inputs
+            flexible: If True, return FlexibleLayout, otherwise FixedLayout
+        """
+

 class MMKernelInputs(KernelInputs):
     """
@@ -192,6 +217,7 @@ class MMKernelInputs(KernelInputs):
         self,
         input_nodes: list[Any],
         scalars: Optional[dict[str, Union[float, int]]] = None,
+        out_dtype: Optional[torch.dtype] = None,
         mat1_idx: int = -2,
         mat2_idx: int = -1,
     ):
@@ -201,7 +227,7 @@ class MMKernelInputs(KernelInputs):
         By default, we assume the last 2 input nodes are mat1 and mat2, but
         the caller can adjust when necessary
         """
-        super().__init__(input_nodes, scalars)
+        super().__init__(input_nodes, scalars, out_dtype)
         # for mm, we need at least 2 nodes, and we need to know which nodes
         # are the main matrixes e.g. addmm is (bias, mat1, mat2) whereas others
         # might be (mat1, mat2, scale), etc.
@@ -246,6 +272,37 @@ class MMKernelInputs(KernelInputs):
             V.graph.sizevars.check_equals(k, k0)
         return (m, n, k)

+    def out_dtype(self) -> torch.dtype:
+        """
+        Get the output dtype, whether passed in or inferred from the nodes
+
+        Returns:
+            The output dtype
+        """
+        if self._out_dtype is not None:
+            return self._out_dtype
+        return self.mat1mat2()[0].get_dtype()
+
+    def output_layout(self, flexible: bool = True) -> Layout:
+        """
+        Handle output layout generation for matrix multiplication.
+
+        Args:
+            out_dtype: Optional output dtype. If not provided, infer from inputs
+            flexible: If True, return FlexibleLayout, otherwise FixedLayout
+        """
+        mat1, mat2 = self.mat1mat2()
+        out_dtype = self.out_dtype()
+        # NOTE: taken from mm_common.mm_args
+        *b1, m, k1 = mat1.get_size()
+        *b2, k2, n = mat2.get_size()
+        b = [V.graph.sizevars.check_equals_and_simplify(a, b) for a, b in zip(b1, b2)]
+        size = [*b, m, n]
+        if flexible:
+            return FlexibleLayout(self.device(), out_dtype, size)
+        else:
+            return FixedLayout(self.device(), out_dtype, size)
+
     def mat1mat2(self) -> tuple[Any, Any]:
         """
         Get the mat1 and mat2 nodes.
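The new `KernelInputs` surface is small: store an optional `out_dtype`, expose it through an abstract `out_dtype()` hook, and let each subclass decide how to build the output layout. The standalone mock-up below mirrors that shape with plain tensors; `Inputs`/`MMInputs` are illustrative classes, not the real `torch._inductor.kernel_inputs` ones, and the layout is modeled as a simple tuple.

```python
from abc import ABC, abstractmethod
from typing import Optional

import torch


class Inputs(ABC):
    def __init__(self, nodes: list, out_dtype: Optional[torch.dtype] = None):
        assert nodes, "Expected at least one input node"
        self._nodes = nodes
        self._out_dtype = out_dtype

    @abstractmethod
    def out_dtype(self) -> torch.dtype: ...

    @abstractmethod
    def output_layout(self, flexible: bool = True): ...


class MMInputs(Inputs):
    def out_dtype(self) -> torch.dtype:
        # explicit dtype wins; otherwise fall back to mat1's dtype
        return self._out_dtype if self._out_dtype is not None else self._nodes[-2].dtype

    def output_layout(self, flexible: bool = True):
        mat1, mat2 = self._nodes[-2], self._nodes[-1]
        *b, m, _k = mat1.shape
        n = mat2.shape[-1]
        kind = "FlexibleLayout" if flexible else "FixedLayout"
        return (kind, self.out_dtype(), [*b, m, n])


inp = MMInputs(
    [torch.randn(8, 64, 32, dtype=torch.float16), torch.randn(8, 32, 16, dtype=torch.float16)]
)
print(inp.out_dtype())                   # torch.float16 (inferred from mat1)
print(inp.output_layout(flexible=False))  # ('FixedLayout', torch.float16, [8, 64, 16])
```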
@@ -9,16 +9,14 @@ from ..kernel.mm import aten__fp8_mm, aten__int_mm, aten_addmm, aten_bias_addmm,
 from ..kernel.mm_plus_mm import aten_mm_plus_mm
 from .base import TemplateConfigHeuristics
 from .gemm import GemmMaxAutotuneTemplateConfigHeuristics
+from .registry import register_template_heuristic


 if TYPE_CHECKING:
     from collections.abc import Generator

-    from ..ir import Layout
     from ..kernel_inputs import KernelInputs

-    from .registry import register_template_heuristic
-

 # These are all labeled as device type None to indicate that they
 # are valid for all device types
@@ -41,7 +39,6 @@ class ATenConfigHeuristics(TemplateConfigHeuristics):
     def _get_template_configs_impl(
         self,
         kernel_inputs: KernelInputs,
-        layout: Layout,
         op_name: str,
     ) -> Generator[dict[str, Any], None, None]:
         yield dict()
@@ -55,10 +52,9 @@ class ATenAddMMConfigHeuristics(ATenConfigHeuristics):
     def get_extra_kwargs(
         self,
         kernel_inputs: KernelInputs,
-        layout: Layout,
         op_name: str,
     ) -> dict[str, Any]:
-        kwargs = super().get_extra_kwargs(kernel_inputs, layout, op_name)
+        kwargs = super().get_extra_kwargs(kernel_inputs, op_name)
         alpha = kernel_inputs.get_scalar("alpha")
         beta = kernel_inputs.get_scalar("beta")
         return {
@@ -75,7 +71,6 @@ class ATenBiasAddMMConfigHeuristics(
     def _get_template_configs_impl(
         self,
         kernel_inputs: KernelInputs,
-        layout: Layout,
         op_name: str,
     ) -> Generator[dict[str, Any], None, None]:
         nodes = kernel_inputs.nodes()
@@ -6,14 +6,13 @@ from typing import Any, TYPE_CHECKING
 if TYPE_CHECKING:
     from collections.abc import Generator

-    from ..ir import Layout
     from ..kernel_inputs import KernelInputs


 class TemplateConfigHeuristics:
     """Base class for generating sets of configs for an associated template."""

-    def should_run(self, inputs: KernelInputs, layout: Layout) -> bool:
+    def should_run(self, inputs: KernelInputs) -> bool:
         """
         hookup to check whether the configs are right to run at all e.g. you can check
         max-autotune specific to your heuristic here or other things
@@ -21,14 +20,12 @@ class TemplateConfigHeuristics:

         Args:
             inputs: KernelInputs
-            layout: Layout
         """
         return True

     def get_template_configs(
         self,
         kernel_inputs: KernelInputs,
-        layout: Layout,
         op_name: str,
     ) -> Generator[dict[str, Any], None, None]:
         """
@@ -37,19 +34,17 @@ class TemplateConfigHeuristics:
         Prefer to override the _get_template_configs_impl method
         to leverage things like should_run
         """
-        if not self.should_run(kernel_inputs, layout):
+        if not self.should_run(kernel_inputs):
             return

         yield from self._get_template_configs_impl(
             kernel_inputs,
-            layout,
             op_name,
         )

     def _get_template_configs_impl(
         self,
         kernel_inputs: KernelInputs,
-        layout: Layout,
         op_name: str,
     ) -> Generator[dict[str, Any], None, None]:
         """
@@ -62,7 +57,6 @@ class TemplateConfigHeuristics:
     def get_extra_kwargs(
         self,
         kernel_inputs: KernelInputs,
-        layout: Layout,
         op_name: str,
     ) -> dict[str, Any]:
         """
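For downstream heuristics, the contract after this change is: `should_run(inputs)` gates the generator, and `_get_template_configs_impl(kernel_inputs, op_name)` yields plain kwargs dicts, with no Layout threaded through. Below is a hypothetical subclass written against a simplified stand-in of that base class (the stand-in accepts any inputs object, unlike the real KernelInputs-typed API).

```python
from typing import Any, Generator


class TemplateConfigHeuristicsSketch:
    """Simplified stand-in for the base class above, with the layout-free signatures."""

    def should_run(self, inputs: Any) -> bool:
        return True

    def get_template_configs(self, kernel_inputs: Any, op_name: str) -> Generator[dict, None, None]:
        if not self.should_run(kernel_inputs):
            return
        yield from self._get_template_configs_impl(kernel_inputs, op_name)

    def _get_template_configs_impl(self, kernel_inputs: Any, op_name: str) -> Generator[dict, None, None]:
        yield {}


class SmallMMOnlyHeuristics(TemplateConfigHeuristicsSketch):
    """Example subclass: only emit configs for small problems, no Layout needed."""

    def should_run(self, inputs: Any) -> bool:
        m, n, k = inputs["m"], inputs["n"], inputs["k"]
        return m * n * k <= 1 << 24

    def _get_template_configs_impl(self, kernel_inputs: Any, op_name: str) -> Generator[dict, None, None]:
        for block in (32, 64):
            yield {"BLOCK_M": block, "BLOCK_N": block, "op": op_name}


for cfg in SmallMMOnlyHeuristics().get_template_configs({"m": 128, "n": 128, "k": 64}, "mm"):
    print(cfg)
```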
@@ -19,8 +19,6 @@ from .registry import register_template_heuristic
 if TYPE_CHECKING:
     from collections.abc import Generator

-    from ..ir import Layout
-

 @register_template_heuristic(mm_contiguous_subgraph_template.uid, None, op_name="mm")
 @register_template_heuristic(
@@ -46,7 +44,6 @@ class ContiguousMMHeuristics(GemmMaxAutotuneTemplateConfigHeuristics):
     def _get_template_configs_impl(
         self,
         kernel_inputs: KernelInputs,
-        layout: Layout,
         op_name: str,
     ) -> Generator[dict[str, Any], None, None]:
         """
@@ -19,8 +19,6 @@ from .registry import register_template_heuristic
 if TYPE_CHECKING:
     from collections.abc import Generator

-    from ..ir import Layout
-

 @register_template_heuristic(decompose_k_subgraph_template.uid, None, op_name="mm")
 class EmptyDecomposeKConfigHeuristics(TemplateConfigHeuristics):
@@ -43,7 +41,6 @@ class DecomposeKConfigHeuristics(GemmMaxAutotuneTemplateConfigHeuristics):
     def _get_template_configs_impl(
         self,
         kernel_inputs: KernelInputs,
-        layout: Layout,
         op_name: str,
     ) -> Generator[dict[str, Any], None, None]:
         """
@@ -7,12 +7,11 @@ from .base import TemplateConfigHeuristics


 if TYPE_CHECKING:
-    from ..ir import Layout
     from ..kernel_inputs import KernelInputs


 class GemmMaxAutotuneTemplateConfigHeuristics(TemplateConfigHeuristics):
-    def should_run(self, inputs: KernelInputs, layout: Layout) -> bool:
+    def should_run(self, inputs: KernelInputs) -> bool:
         """
         simple base override for GEMM family templates that run only in max-autotune
         """
@@ -41,8 +41,6 @@ if TYPE_CHECKING:

     from triton import Config as TritonConfig

-    from ..ir import Layout
-

 # Gemm Configs
 @dataclasses.dataclass
@@ -1451,7 +1449,6 @@ class MMTemplateConfigMixin(GemmMaxAutotuneTemplateConfigHeuristics):
     def _get_template_configs_impl(
         self,
         kernel_inputs: KernelInputs,
-        layout: Any,
         op_name: str,
     ) -> Generator[dict[str, Any], None, None]:
         """
@@ -1479,7 +1476,11 @@ class MMTemplateConfigMixin(GemmMaxAutotuneTemplateConfigHeuristics):
         # Generate and process configs
         for c in configs(m, n, k, dtype_size=dtype.itemsize, op_name=op_name):
             template_kwargs = self._convert_config_to_template_kwargs(
-                c, m, n, k, layout
+                c,
+                m,
+                n,
+                k,
+                kernel_inputs.out_dtype(),
             )
             yield template_kwargs

@@ -1489,7 +1490,7 @@ class MMTemplateConfigMixin(GemmMaxAutotuneTemplateConfigHeuristics):
         m: sympy.Integer,
         n: sympy.Integer,
         k: sympy.Integer,
-        layout: Any,
+        out_dtype: torch.dtype,
     ) -> dict[str, Any]:
         """
         Convert triton config to template kwargs.
@@ -1513,7 +1514,7 @@ class MMTemplateConfigMixin(GemmMaxAutotuneTemplateConfigHeuristics):
             EVEN_K=even_k_symbolic,
             ALLOW_TF32=allow_tf32,
             USE_FAST_ACCUM=False,  # Option for _scaled_mm
-            ACC_TYPE=self._get_acc_type(layout.dtype),
+            ACC_TYPE=self._get_acc_type(out_dtype),
             num_stages=triton_config.num_stages,
             num_warps=triton_config.num_warps,
             **triton_config.kwargs,
|
|||||||
def _get_template_configs_impl(
|
def _get_template_configs_impl(
|
||||||
self,
|
self,
|
||||||
kernel_inputs: KernelInputs,
|
kernel_inputs: KernelInputs,
|
||||||
layout: Any,
|
|
||||||
op_name: str,
|
op_name: str,
|
||||||
) -> Generator[dict[str, Any], None, None]:
|
) -> Generator[dict[str, Any], None, None]:
|
||||||
assert isinstance(kernel_inputs, MMKernelInputs), "Expect MMKernelInputs"
|
assert isinstance(kernel_inputs, MMKernelInputs), "Expect MMKernelInputs"
|
||||||
m, n, k = kernel_inputs.mnk_symbolic()
|
m, n, k = kernel_inputs.mnk_symbolic()
|
||||||
for kwargs in super()._get_template_configs_impl(
|
for kwargs in super()._get_template_configs_impl(kernel_inputs, op_name):
|
||||||
kernel_inputs, layout, op_name
|
|
||||||
):
|
|
||||||
# Apply BLOCK_K constraint specific to mm_plus_mm
|
# Apply BLOCK_K constraint specific to mm_plus_mm
|
||||||
# see https://github.com/triton-lang/triton/issues/1298
|
# see https://github.com/triton-lang/triton/issues/1298
|
||||||
# BLOCK_K = K causes llvm error
|
# BLOCK_K = K causes llvm error
|
||||||
@ -1586,10 +1584,9 @@ class TMAWorkspaceMixin(MMTemplateConfigMixin):
|
|||||||
def get_extra_kwargs(
|
def get_extra_kwargs(
|
||||||
self,
|
self,
|
||||||
kernel_inputs: KernelInputs,
|
kernel_inputs: KernelInputs,
|
||||||
layout: Layout,
|
|
||||||
op_name: str,
|
op_name: str,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
kwargs = super().get_extra_kwargs(kernel_inputs, layout, op_name)
|
kwargs = super().get_extra_kwargs(kernel_inputs, op_name)
|
||||||
kwargs["workspace_arg"] = get_tma_workspace_arg(
|
kwargs["workspace_arg"] = get_tma_workspace_arg(
|
||||||
num_tma_descriptors=2,
|
num_tma_descriptors=2,
|
||||||
device=kernel_inputs.device(),
|
device=kernel_inputs.device(),
|
||||||
@ -1614,7 +1611,6 @@ class TMATemplateConfigMixin(TMAWorkspaceMixin, MMTemplateConfigMixin):
|
|||||||
def _get_template_configs_impl(
|
def _get_template_configs_impl(
|
||||||
self,
|
self,
|
||||||
kernel_inputs: KernelInputs,
|
kernel_inputs: KernelInputs,
|
||||||
layout: Any,
|
|
||||||
op_name: str,
|
op_name: str,
|
||||||
) -> Generator[dict[str, Any], None, None]:
|
) -> Generator[dict[str, Any], None, None]:
|
||||||
"""
|
"""
|
||||||
@ -1634,7 +1630,6 @@ class TMATemplateConfigMixin(TMAWorkspaceMixin, MMTemplateConfigMixin):
|
|||||||
# Get base template configs from superclass
|
# Get base template configs from superclass
|
||||||
for template_kwargs in super()._get_template_configs_impl(
|
for template_kwargs in super()._get_template_configs_impl(
|
||||||
kernel_inputs,
|
kernel_inputs,
|
||||||
layout,
|
|
||||||
op_name,
|
op_name,
|
||||||
):
|
):
|
||||||
yield {**template_kwargs, **tma_opts}
|
yield {**template_kwargs, **tma_opts}
|
||||||
@ -1684,7 +1679,6 @@ class BaseScaledMMConfigMixin(MMTemplateConfigMixin):
|
|||||||
def _get_template_configs_impl(
|
def _get_template_configs_impl(
|
||||||
self,
|
self,
|
||||||
kernel_inputs: KernelInputs,
|
kernel_inputs: KernelInputs,
|
||||||
layout: Any,
|
|
||||||
op_name: str,
|
op_name: str,
|
||||||
) -> Generator[dict[str, Any], None, None]:
|
) -> Generator[dict[str, Any], None, None]:
|
||||||
"""
|
"""
|
||||||
@ -1734,7 +1728,7 @@ class BaseScaledMMConfigMixin(MMTemplateConfigMixin):
|
|||||||
|
|
||||||
# Get base template configs from superclass
|
# Get base template configs from superclass
|
||||||
for template_kwargs in super()._get_template_configs_impl(
|
for template_kwargs in super()._get_template_configs_impl(
|
||||||
kernel_inputs, layout, op_name
|
kernel_inputs, op_name
|
||||||
):
|
):
|
||||||
# Add scaled MM-specific options (moved from mm_common.scaled_mm_options)
|
# Add scaled MM-specific options (moved from mm_common.scaled_mm_options)
|
||||||
# Override accumulator type for scaled MM
|
# Override accumulator type for scaled MM
|
||||||
@ -1752,10 +1746,9 @@ class ScaledMMConfigMixin(BaseScaledMMConfigMixin):
|
|||||||
def get_extra_kwargs(
|
def get_extra_kwargs(
|
||||||
self,
|
self,
|
||||||
kernel_inputs: KernelInputs,
|
kernel_inputs: KernelInputs,
|
||||||
layout: Layout,
|
|
||||||
op_name: str,
|
op_name: str,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
kwargs = super().get_extra_kwargs(kernel_inputs, layout, op_name)
|
kwargs = super().get_extra_kwargs(kernel_inputs, op_name)
|
||||||
from ..kernel.mm_common import scale_mm_epilogue
|
from ..kernel.mm_common import scale_mm_epilogue
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -1792,7 +1785,6 @@ class ScaledTMAConfigMixin(TMAWorkspaceMixin, BaseScaledMMConfigMixin):
|
|||||||
def _get_template_configs_impl(
|
def _get_template_configs_impl(
|
||||||
self,
|
self,
|
||||||
kernel_inputs: KernelInputs,
|
kernel_inputs: KernelInputs,
|
||||||
layout: Any,
|
|
||||||
op_name: str,
|
op_name: str,
|
||||||
) -> Generator[dict[str, Any], None, None]:
|
) -> Generator[dict[str, Any], None, None]:
|
||||||
"""
|
"""
|
||||||
@ -1801,7 +1793,6 @@ class ScaledTMAConfigMixin(TMAWorkspaceMixin, BaseScaledMMConfigMixin):
|
|||||||
# Get base scaled MM template configs from superclass
|
# Get base scaled MM template configs from superclass
|
||||||
for template_kwargs in super()._get_template_configs_impl(
|
for template_kwargs in super()._get_template_configs_impl(
|
||||||
kernel_inputs,
|
kernel_inputs,
|
||||||
layout,
|
|
||||||
op_name,
|
op_name,
|
||||||
):
|
):
|
||||||
# Add TMA-specific options for device TMA scaled MM
|
# Add TMA-specific options for device TMA scaled MM
|
||||||
|
@@ -7,7 +7,6 @@ from .base import TemplateConfigHeuristics


 if TYPE_CHECKING:
-    from ..ir import Layout
     from ..kernel_inputs import KernelInputs


@@ -19,10 +18,9 @@ class AddMMConfigMixin(TemplateConfigHeuristics):
     def get_extra_kwargs(
         self,
         kernel_inputs: KernelInputs,
-        layout: Layout,
         op_name: str,
     ) -> dict[str, Any]:
-        kwargs = super().get_extra_kwargs(kernel_inputs, layout, op_name)
+        kwargs = super().get_extra_kwargs(kernel_inputs, op_name)
         assert op_name in [
             "addmm",
             "baddbmm",
@@ -31,7 +29,9 @@ class AddMMConfigMixin(TemplateConfigHeuristics):
         beta = kernel_inputs.get_scalar("beta")
         return {
             **kwargs,
-            "epilogue_fn": addmm_epilogue(layout.dtype, alpha, beta),
-            "epilogue_fn_hash": str(["addmm_epilogue", layout.dtype, alpha, beta]),
+            "epilogue_fn": addmm_epilogue(kernel_inputs.out_dtype(), alpha, beta),
+            "epilogue_fn_hash": str(
+                ["addmm_epilogue", kernel_inputs.out_dtype(), alpha, beta]
+            ),
             "prefix_args": 1,
         }
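The addmm heuristic now derives the epilogue's dtype from `kernel_inputs.out_dtype()` instead of `layout.dtype`. Below is a plain-PyTorch sketch of what such an epilogue computes (mirroring `torch.addmm` semantics, not Inductor's actual codegen), which shows why only the output dtype plus alpha/beta is needed.

```python
import torch


def addmm_epilogue_sketch(out_dtype: torch.dtype, alpha: float, beta: float):
    # returns a function applied to the matmul accumulator and the bias input
    def epilogue(acc: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
        return (alpha * acc + beta * bias).to(out_dtype)

    return epilogue


bias = torch.randn(128)
a, b = torch.randn(64, 32), torch.randn(32, 128)
epilogue = addmm_epilogue_sketch(torch.float16, alpha=1.0, beta=1.0)
out = epilogue(a @ b, bias)
# matches torch.addmm up to fp16 rounding introduced by the final cast
assert torch.allclose(out.float(), torch.addmm(bias, a, b), atol=1e-2, rtol=1e-2)
print(out.dtype, out.shape)
```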