diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py
index 4b02bb1956e1..a6275ac85c11 100644
--- a/torch/_inductor/choices.py
+++ b/torch/_inductor/choices.py
@@ -163,10 +163,9 @@ class InductorChoices:
         heuristic = get_template_heuristic(template_name, device_type, op_name)
         cs = heuristic.get_template_configs(
             kernel_inputs,
-            layout,
             op_name,
         )
-        extra_kwargs = heuristic.get_extra_kwargs(kernel_inputs, layout, op_name)
+        extra_kwargs = heuristic.get_extra_kwargs(kernel_inputs, op_name)
         # adjust the kernel inputs to the template-specific heuristic, if needed
         # default here is to just return the kernel_inputs as is
         inputs_val = heuristic.adjust_kernel_inputs(kernel_inputs, op_name)
@@ -184,9 +183,9 @@ class InductorChoices:
     def get_mm_configs(
         self,
         kernel_inputs: KernelInputs,
-        layout: Any,
         templates: list[Union[KernelTemplate, ExternKernelChoice]],
         op_name: str,
+        layout: Optional[Layout] = None,
         kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None,
     ) -> list[ChoiceCaller]:
         """
@@ -207,7 +206,11 @@ class InductorChoices:
         input_tensors = kernel_inputs.nodes()
         if len(input_tensors) < 2:
             raise ValueError(f"Need at least 2 input tensors, got {len(input_tensors)}")
-
+        if layout is None:
+            # TODO(coconutruben): remove this once we remove the layout argument entirely
+            # This is just here to bridge the brief gap between commits where we still need
+            # it to accommodate the fixed vs. flexible layout decision externally
+            layout = kernel_inputs.output_layout(flexible=False)
         # First pass: Create dict of template.uid to generator of KernelTemplateChoice objects
         template_choices = {}
         for template in templates:
diff --git a/torch/_inductor/kernel/bmm.py b/torch/_inductor/kernel/bmm.py
index a843c7369fb5..e882be6df0df 100644
--- a/torch/_inductor/kernel/bmm.py
+++ b/torch/_inductor/kernel/bmm.py
@@ -173,7 +173,7 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
     name = "bmm"
 
     # Create MMKernelInputs for BMM at the top
-    kernel_inputs = MMKernelInputs([mat1, mat2])
+    kernel_inputs = MMKernelInputs([mat1, mat2], out_dtype=out_dtype)
 
     # below is for getting an overview logging info of inductor mms
     batch_size = mat1.get_size()[0]  # Extract batch dimension
@@ -201,10 +201,9 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
         choices.extend(
             V.choices.get_mm_configs(
                 kernel_inputs,
-                layout,
                 [aten_handler],
                 name,
-                {aten_handler.uid: aten_extra_kwargs},
+                kwarg_overrides={aten_handler.uid: aten_extra_kwargs},
             )
         )
 
@@ -212,9 +211,7 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
         # TODO: add out_dtype support for Triton Template
         assert out_dtype is None, "out_dtype is not supported for Triton"
-        choices.extend(
-            V.choices.get_mm_configs(kernel_inputs, layout, [bmm_template], name)
-        )
+        choices.extend(V.choices.get_mm_configs(kernel_inputs, [bmm_template], name))
 
     _, is_nonzero = _is_static_problem(layout)
     batch_stride_largest_or_zero = is_batch_stride_largest_or_zero(mat1, mat2, layout)
     if (
@@ -275,15 +272,12 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
 
     # options to tune from
     choices: list[ChoiceCaller] = []
     if use_aten_gemm_kernels():
-        choices.extend(
-            V.choices.get_mm_configs(kernel_inputs, layout, [aten_baddbmm], name)
-        )
+        choices.extend(V.choices.get_mm_configs(kernel_inputs, [aten_baddbmm], name))
     if use_triton_template(layout, check_max_autotune=False):
         choices.extend(
             V.choices.get_mm_configs(
                 kernel_inputs,
-                layout,
                 [bmm_template],
                 name,
             )
diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py
index a597107510e7..155c461775cb 100644
--- a/torch/_inductor/kernel/mm.py
+++ b/torch/_inductor/kernel/mm.py
@@ -753,32 +753,30 @@ def tuned_mm(mat1, mat2, *, layout=None):
     choices: list[ChoiceCaller] = []
     if use_aten_gemm_kernels():
         choices.extend(
-            V.choices.get_mm_configs(kernel_inputs, aten_layout, [aten_mm], "mm")
+            V.choices.get_mm_configs(kernel_inputs, [aten_mm], "mm", aten_layout)
         )
 
     static_shape, is_nonzero = _is_static_problem(layout)
     if is_nonzero and use_triton_template(layout, check_max_autotune=False):
         # Get template choices using the new unified function
-        choices.extend(
-            V.choices.get_mm_configs(kernel_inputs, layout, [mm_template], "mm")
-        )
+        choices.extend(V.choices.get_mm_configs(kernel_inputs, [mm_template], "mm"))
         if use_triton_tma_template(mat1, mat2):
             # Get TMA template choices using the new unified function
             choices.extend(
                 V.choices.get_mm_configs(
-                    kernel_inputs, layout, [persistent_tma_mm_template], "mm"
+                    kernel_inputs, [persistent_tma_mm_template], "mm"
                 )
             )
 
         if use_decompose_k_choice(m, n, k):
             choices.extend(
                 V.choices.get_mm_configs(
-                    kernel_inputs, layout, [decompose_k_subgraph_template], "mm"
+                    kernel_inputs, [decompose_k_subgraph_template], "mm"
                 )
             )
         choices.extend(
             V.choices.get_mm_configs(
-                kernel_inputs, layout, [mm_contiguous_subgraph_template], "mm"
+                kernel_inputs, [mm_contiguous_subgraph_template], "mm"
             )
         )
 
@@ -820,7 +818,6 @@ def tuned_mm(mat1, mat2, *, layout=None):
                 # mm-extra is a hack to keep the ah functionality alive
                 # while we transition to the unified kwargs retrieval
                 kernel_inputs,
-                layout,
                 [mm_template],
                 "mm-ah",
             )
@@ -896,12 +893,11 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
     choices: list[ChoiceCaller] = []
     # Create MMKernelInputs for Int MM
-    kernel_inputs = MMKernelInputs([mat1, mat2])
+    kernel_inputs = MMKernelInputs([mat1, mat2], out_dtype=torch.int32)
     if use_aten_gemm_kernels():
         choices.extend(
             V.choices.get_mm_configs(
                 kernel_inputs,
-                layout,
                 [aten__int_mm],
                 name,
             )
         )
@@ -915,9 +911,7 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
     if is_nonzero and use_triton_template(
         layout, enable_int32=True, check_max_autotune=False
     ):
-        choices.extend(
-            V.choices.get_mm_configs(kernel_inputs, layout, [mm_template], name)
-        )
+        choices.extend(V.choices.get_mm_configs(kernel_inputs, [mm_template], name))
 
     return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
 
@@ -969,9 +963,9 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
         choices.extend(
             V.choices.get_mm_configs(
                 kernel_inputs,
-                aten_layout,
                 [aten_addmm],
                 name,
+                aten_layout,
             )
         )
         return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
@@ -980,7 +974,6 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
             choices.extend(
                 V.choices.get_mm_configs(
                     kernel_inputs,
-                    aten_layout,
                     [aten_bias_addmm],
                     name,
                 )
             )
         choices.extend(
             V.choices.get_mm_configs(
                 kernel_inputs,
-                aten_layout,
                 [aten_addmm],
                 name,
             )
@@ -1000,7 +992,6 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
         choices.extend(
             V.choices.get_mm_configs(
                 kernel_inputs,
-                layout,
                 [mm_template],
                 name,
             )
@@ -1011,7 +1002,6 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
             choices.extend(
                 V.choices.get_mm_configs(
                     kernel_inputs,
-                    layout,
                     [persistent_tma_mm_template],
                     name,
                 )
@@ -1020,7 +1010,6 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
         choices.extend(
             V.choices.get_mm_configs(
                 kernel_inputs,
-                layout,
                 [addmm_contiguous_subgraph_template],
                 "addmm",
             )
@@ -1174,14 +1163,15 @@ def tuned_scaled_mm(
     input_nodes = [mat_a, mat_b, scale_a_real, scale_b_real, bias_real]
 
     # Create MMKernelInputs for Scaled MM (matrices are at indices 0, 1)
-    kernel_inputs = MMKernelInputs(input_nodes, mat1_idx=0, mat2_idx=1)
+    kernel_inputs = MMKernelInputs(
+        input_nodes, mat1_idx=0, mat2_idx=1, out_dtype=out_dtype
+    )
 
     choices: list[ChoiceCaller] = []
     if use_aten_gemm_kernels():
         choices.extend(
             V.choices.get_mm_configs(
                 kernel_inputs,
-                layout,
                 [aten__fp8_mm],
                 name,
                 kwarg_overrides={
@@ -1209,7 +1199,6 @@ def tuned_scaled_mm(
         choices.extend(
             V.choices.get_mm_configs(
                 kernel_inputs,
-                layout,
                 [scaled_mm_device_tma_template],
                 name,
                 kwarg_overrides={scaled_mm_device_tma_template.uid: overriders},
@@ -1220,7 +1209,6 @@ def tuned_scaled_mm(
         choices.extend(
             V.choices.get_mm_configs(
                 kernel_inputs,
-                layout,
                 [mm_template],
                 name,
                 kwarg_overrides={mm_template.uid: overriders},
diff --git a/torch/_inductor/kernel/mm_plus_mm.py b/torch/_inductor/kernel/mm_plus_mm.py
index 60e1b01a5b03..213393181594 100644
--- a/torch/_inductor/kernel/mm_plus_mm.py
+++ b/torch/_inductor/kernel/mm_plus_mm.py
@@ -157,17 +157,13 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
     choices: list[ChoiceCaller] = []
     if use_aten_gemm_kernels():
         choices.extend(
-            V.choices.get_mm_configs(
-                kernel_inputs, layout1, [aten_mm_plus_mm], "mm_plus_mm"
-            )
+            V.choices.get_mm_configs(kernel_inputs, [aten_mm_plus_mm], "mm_plus_mm")
         )
 
     if use_triton_template(layout1, check_max_autotune=False):
         # Get template choices using the new unified function
         choices.extend(
-            V.choices.get_mm_configs(
-                kernel_inputs, layout1, [mm_plus_mm_template], "mm_plus_mm"
-            )
+            V.choices.get_mm_configs(kernel_inputs, [mm_plus_mm_template], "mm_plus_mm")
         )
 
     return autotune_select_algorithm(
diff --git a/torch/_inductor/kernel_inputs.py b/torch/_inductor/kernel_inputs.py
index 83ef996831a2..c579cf756577 100644
--- a/torch/_inductor/kernel_inputs.py
+++ b/torch/_inductor/kernel_inputs.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from abc import ABC, abstractmethod
 from typing import Any, Optional, TYPE_CHECKING, Union
 
 import torch
@@ -7,6 +8,8 @@
 import torch._inductor.config
 from torch._inductor import ir
 from torch._inductor.virtualized import V
 
+from .ir import FixedLayout, FlexibleLayout, Layout
+
 if TYPE_CHECKING:
     from collections.abc import Sequence
 
@@ -14,7 +17,7 @@
     import sympy
 
 
-class KernelInputs:
+class KernelInputs(ABC):
     """
     Class to store and provide access to input nodes for kernels.
     This class takes in a tuple of input nodes and provides methods to access
@@ -25,16 +28,19 @@ class KernelInputs:
     def __init__(
         self,
         input_nodes: list[Any],
         scalars: Optional[dict[str, Union[float, int]]] = None,
+        out_dtype: Optional[torch.dtype] = None,
     ):
         """
         Initialize with a tuple of input nodes.
 
         Args:
             input_nodes: A tuple of input nodes to store
+            out_dtype: Optional output dtype to store
         """
         self._input_nodes = input_nodes
         self._device_name: Optional[str] = None
         self._scalars = scalars if scalars is not None else {}
+        self._out_dtype = out_dtype
         assert len(input_nodes) > 0, "Expected at least one input node"
 
     def nodes(self, reorder: Optional[Sequence[int]] = None) -> list[Any]:
@@ -168,6 +174,15 @@ class KernelInputs:
         """
         return self._input_nodes[idx].get_dtype()
 
+    @abstractmethod
+    def out_dtype(self) -> torch.dtype:
+        """
+        Get the output dtype, whether passed in or inferred from the nodes
+
+        Returns:
+            The output dtype
+        """
+
     def get_scalar(self, name: str) -> Union[float, int]:
         """
         Get the scalar value for a given name.
@@ -181,6 +196,16 @@ class KernelInputs:
         assert name in self._scalars, f"Scalar {name} not found, but required"
         return self._scalars[name]
 
+    @abstractmethod
+    def output_layout(self, flexible: bool = True) -> Layout:
+        """
+        Abstract method to handle output layout generation.
+        The dtype of the layout is taken from self.out_dtype().
+
+        Args:
+            flexible: If True, return FlexibleLayout, otherwise FixedLayout
+        """
+
 
 class MMKernelInputs(KernelInputs):
     """
@@ -192,6 +217,7 @@ class MMKernelInputs(KernelInputs):
     def __init__(
         self,
         input_nodes: list[Any],
         scalars: Optional[dict[str, Union[float, int]]] = None,
+        out_dtype: Optional[torch.dtype] = None,
         mat1_idx: int = -2,
         mat2_idx: int = -1,
     ):
         """
@@ -201,7 +227,7 @@ class MMKernelInputs(KernelInputs):
         By default, we assume the last 2 input nodes are mat1 and mat2, but
         the caller can adjust when necessary
         """
-        super().__init__(input_nodes, scalars)
+        super().__init__(input_nodes, scalars, out_dtype)
         # for mm, we need at least 2 nodes, and we need to know which nodes
         # are the main matrixes e.g. addmm is (bias, mat1, mat2) whereas others
         # might be (mat1, mat2, scale), etc.
@@ -246,6 +272,37 @@ class MMKernelInputs(KernelInputs):
             V.graph.sizevars.check_equals(k, k0)
         return (m, n, k)
 
+    def out_dtype(self) -> torch.dtype:
+        """
+        Get the output dtype, whether passed in or inferred from the nodes
+
+        Returns:
+            The output dtype
+        """
+        if self._out_dtype is not None:
+            return self._out_dtype
+        return self.mat1mat2()[0].get_dtype()
+
+    def output_layout(self, flexible: bool = True) -> Layout:
+        """
+        Handle output layout generation for matrix multiplication.
+        The dtype of the layout is taken from self.out_dtype().
+
+        Args:
+            flexible: If True, return FlexibleLayout, otherwise FixedLayout
+        """
+        mat1, mat2 = self.mat1mat2()
+        out_dtype = self.out_dtype()
+        # NOTE: taken from mm_common.mm_args
+        *b1, m, k1 = mat1.get_size()
+        *b2, k2, n = mat2.get_size()
+        b = [V.graph.sizevars.check_equals_and_simplify(a, b) for a, b in zip(b1, b2)]
+        size = [*b, m, n]
+        if flexible:
+            return FlexibleLayout(self.device(), out_dtype, size)
+        else:
+            return FixedLayout(self.device(), out_dtype, size)
+
     def mat1mat2(self) -> tuple[Any, Any]:
         """
         Get the mat1 and mat2 nodes.
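
Aside (not part of the patch): a minimal, self-contained sketch of the out_dtype / output_layout behavior that MMKernelInputs gains above. ToyMatrix and ToyMMKernelInputs are illustrative stand-ins, since the real Inductor IR nodes are assumed to be too heavy to construct standalone.

# toy_mm_kernel_inputs.py -- illustration only, not Inductor code
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ToyMatrix:
    size: tuple[int, ...]
    dtype: torch.dtype


class ToyMMKernelInputs:
    """Mirrors the new surface: an explicit out_dtype wins, otherwise infer from mat1."""

    def __init__(
        self,
        mat1: ToyMatrix,
        mat2: ToyMatrix,
        out_dtype: Optional[torch.dtype] = None,
    ) -> None:
        self.mat1, self.mat2 = mat1, mat2
        self._out_dtype = out_dtype

    def out_dtype(self) -> torch.dtype:
        return self._out_dtype if self._out_dtype is not None else self.mat1.dtype

    def output_size(self) -> list[int]:
        # same [*b, m, n] shape computation as MMKernelInputs.output_layout above
        *b1, m, k1 = self.mat1.size
        *b2, k2, n = self.mat2.size
        assert k1 == k2 and b1 == b2, "inner/batch dims must match"
        return [*b1, m, n]


inp = ToyMMKernelInputs(
    ToyMatrix((8, 128, 64), torch.float16),
    ToyMatrix((8, 64, 256), torch.float16),
    out_dtype=torch.float32,
)
print(inp.out_dtype(), inp.output_size())  # torch.float32 [8, 128, 256]
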
diff --git a/torch/_inductor/template_heuristics/aten.py b/torch/_inductor/template_heuristics/aten.py
index 1b797319586f..72e66b1c1476 100644
--- a/torch/_inductor/template_heuristics/aten.py
+++ b/torch/_inductor/template_heuristics/aten.py
@@ -9,16 +9,14 @@ from ..kernel.mm import aten__fp8_mm, aten__int_mm, aten_addmm, aten_bias_addmm,
 from ..kernel.mm_plus_mm import aten_mm_plus_mm
 from .base import TemplateConfigHeuristics
 from .gemm import GemmMaxAutotuneTemplateConfigHeuristics
+from .registry import register_template_heuristic
 
 
 if TYPE_CHECKING:
     from collections.abc import Generator
 
-    from ..ir import Layout
     from ..kernel_inputs import KernelInputs
 
-from .registry import register_template_heuristic
-
 
 # These are all labeled as device type None to indicate that they
 # are valid for all device types
@@ -41,7 +39,6 @@ class ATenConfigHeuristics(TemplateConfigHeuristics):
     def _get_template_configs_impl(
         self,
         kernel_inputs: KernelInputs,
-        layout: Layout,
         op_name: str,
     ) -> Generator[dict[str, Any], None, None]:
         yield dict()
@@ -55,10 +52,9 @@ class ATenAddMMConfigHeuristics(ATenConfigHeuristics):
     def get_extra_kwargs(
         self,
         kernel_inputs: KernelInputs,
-        layout: Layout,
         op_name: str,
     ) -> dict[str, Any]:
-        kwargs = super().get_extra_kwargs(kernel_inputs, layout, op_name)
+        kwargs = super().get_extra_kwargs(kernel_inputs, op_name)
         alpha = kernel_inputs.get_scalar("alpha")
         beta = kernel_inputs.get_scalar("beta")
         return {
@@ -75,7 +71,6 @@ class ATenBiasAddMMConfigHeuristics(
     def _get_template_configs_impl(
         self,
         kernel_inputs: KernelInputs,
-        layout: Layout,
         op_name: str,
     ) -> Generator[dict[str, Any], None, None]:
         nodes = kernel_inputs.nodes()
diff --git a/torch/_inductor/template_heuristics/base.py b/torch/_inductor/template_heuristics/base.py
index 5054de625e87..def2a2f59bee 100644
--- a/torch/_inductor/template_heuristics/base.py
+++ b/torch/_inductor/template_heuristics/base.py
@@ -6,14 +6,13 @@ from typing import Any, TYPE_CHECKING
 if TYPE_CHECKING:
     from collections.abc import Generator
 
-    from ..ir import Layout
     from ..kernel_inputs import KernelInputs
 
 
 class TemplateConfigHeuristics:
     """Base class for generating sets of configs for an associated template."""
 
-    def should_run(self, inputs: KernelInputs, layout: Layout) -> bool:
+    def should_run(self, inputs: KernelInputs) -> bool:
         """
         hookup to check whether the configs are right to run at all
         e.g.
         you can check max-autotune specific to your heuristic here or other things
@@ -21,14 +20,12 @@ class TemplateConfigHeuristics:
 
         Args:
             inputs: KernelInputs
-            layout: Layout
         """
         return True
 
     def get_template_configs(
         self,
         kernel_inputs: KernelInputs,
-        layout: Layout,
         op_name: str,
     ) -> Generator[dict[str, Any], None, None]:
         """
@@ -37,19 +34,17 @@ class TemplateConfigHeuristics:
         Prefer to override the _get_template_configs_impl method to leverage
         things like should_run
         """
-        if not self.should_run(kernel_inputs, layout):
+        if not self.should_run(kernel_inputs):
             return
 
         yield from self._get_template_configs_impl(
             kernel_inputs,
-            layout,
             op_name,
         )
 
     def _get_template_configs_impl(
         self,
         kernel_inputs: KernelInputs,
-        layout: Layout,
         op_name: str,
     ) -> Generator[dict[str, Any], None, None]:
         """
@@ -62,7 +57,6 @@ class TemplateConfigHeuristics:
     def get_extra_kwargs(
         self,
         kernel_inputs: KernelInputs,
-        layout: Layout,
         op_name: str,
     ) -> dict[str, Any]:
         """
diff --git a/torch/_inductor/template_heuristics/contiguous_mm.py b/torch/_inductor/template_heuristics/contiguous_mm.py
index 3c3c8c6796a9..f7b65eba9c76 100644
--- a/torch/_inductor/template_heuristics/contiguous_mm.py
+++ b/torch/_inductor/template_heuristics/contiguous_mm.py
@@ -19,8 +19,6 @@ from .registry import register_template_heuristic
 if TYPE_CHECKING:
     from collections.abc import Generator
 
-    from ..ir import Layout
-
 
 @register_template_heuristic(mm_contiguous_subgraph_template.uid, None, op_name="mm")
 @register_template_heuristic(
@@ -46,7 +44,6 @@ class ContiguousMMHeuristics(GemmMaxAutotuneTemplateConfigHeuristics):
     def _get_template_configs_impl(
         self,
         kernel_inputs: KernelInputs,
-        layout: Layout,
         op_name: str,
     ) -> Generator[dict[str, Any], None, None]:
         """
diff --git a/torch/_inductor/template_heuristics/decompose_k.py b/torch/_inductor/template_heuristics/decompose_k.py
index 6005e421eb3b..7954396a1086 100644
--- a/torch/_inductor/template_heuristics/decompose_k.py
+++ b/torch/_inductor/template_heuristics/decompose_k.py
@@ -19,8 +19,6 @@ from .registry import register_template_heuristic
 if TYPE_CHECKING:
     from collections.abc import Generator
 
-    from ..ir import Layout
-
 
 @register_template_heuristic(decompose_k_subgraph_template.uid, None, op_name="mm")
 class EmptyDecomposeKConfigHeuristics(TemplateConfigHeuristics):
@@ -43,7 +41,6 @@ class DecomposeKConfigHeuristics(GemmMaxAutotuneTemplateConfigHeuristics):
     def _get_template_configs_impl(
         self,
         kernel_inputs: KernelInputs,
-        layout: Layout,
         op_name: str,
     ) -> Generator[dict[str, Any], None, None]:
         """
diff --git a/torch/_inductor/template_heuristics/gemm.py b/torch/_inductor/template_heuristics/gemm.py
index e1119af0d026..2d56f4c481cc 100644
--- a/torch/_inductor/template_heuristics/gemm.py
+++ b/torch/_inductor/template_heuristics/gemm.py
@@ -7,12 +7,11 @@ from .base import TemplateConfigHeuristics
 
 
 if TYPE_CHECKING:
-    from ..ir import Layout
     from ..kernel_inputs import KernelInputs
 
 
 class GemmMaxAutotuneTemplateConfigHeuristics(TemplateConfigHeuristics):
-    def should_run(self, inputs: KernelInputs, layout: Layout) -> bool:
+    def should_run(self, inputs: KernelInputs) -> bool:
         """
         simple base override for GEMM family templates that run only in max-autotune
         """
diff --git a/torch/_inductor/template_heuristics/triton.py b/torch/_inductor/template_heuristics/triton.py
index a7f4d9f5763f..0aaf70ae3f24 100644
--- a/torch/_inductor/template_heuristics/triton.py
+++ b/torch/_inductor/template_heuristics/triton.py
@@ -41,8 +41,6 @@ if TYPE_CHECKING:
     from triton import Config as TritonConfig
 
-    from ..ir import Layout
-
 
 
 # Gemm Configs
 @dataclasses.dataclass
@@ -1451,7 +1449,6 @@ class MMTemplateConfigMixin(GemmMaxAutotuneTemplateConfigHeuristics):
     def _get_template_configs_impl(
         self,
         kernel_inputs: KernelInputs,
-        layout: Any,
         op_name: str,
     ) -> Generator[dict[str, Any], None, None]:
         """
@@ -1479,7 +1476,11 @@ class MMTemplateConfigMixin(GemmMaxAutotuneTemplateConfigHeuristics):
         # Generate and process configs
         for c in configs(m, n, k, dtype_size=dtype.itemsize, op_name=op_name):
             template_kwargs = self._convert_config_to_template_kwargs(
-                c, m, n, k, layout
+                c,
+                m,
+                n,
+                k,
+                kernel_inputs.out_dtype(),
             )
             yield template_kwargs
 
@@ -1489,7 +1490,7 @@ class MMTemplateConfigMixin(GemmMaxAutotuneTemplateConfigHeuristics):
         m: sympy.Integer,
         n: sympy.Integer,
         k: sympy.Integer,
-        layout: Any,
+        out_dtype: torch.dtype,
     ) -> dict[str, Any]:
         """
         Convert triton config to template kwargs.
@@ -1513,7 +1514,7 @@ class MMTemplateConfigMixin(GemmMaxAutotuneTemplateConfigHeuristics):
             EVEN_K=even_k_symbolic,
             ALLOW_TF32=allow_tf32,
             USE_FAST_ACCUM=False,  # Option for _scaled_mm
-            ACC_TYPE=self._get_acc_type(layout.dtype),
+            ACC_TYPE=self._get_acc_type(out_dtype),
             num_stages=triton_config.num_stages,
             num_warps=triton_config.num_warps,
             **triton_config.kwargs,
@@ -1562,14 +1563,11 @@ class MMPlusMMTemplateConfigMixin(MMTemplateConfigMixin):
     def _get_template_configs_impl(
         self,
         kernel_inputs: KernelInputs,
-        layout: Any,
         op_name: str,
     ) -> Generator[dict[str, Any], None, None]:
         assert isinstance(kernel_inputs, MMKernelInputs), "Expect MMKernelInputs"
         m, n, k = kernel_inputs.mnk_symbolic()
-        for kwargs in super()._get_template_configs_impl(
-            kernel_inputs, layout, op_name
-        ):
+        for kwargs in super()._get_template_configs_impl(kernel_inputs, op_name):
             # Apply BLOCK_K constraint specific to mm_plus_mm
             # see https://github.com/triton-lang/triton/issues/1298
             # BLOCK_K = K causes llvm error
@@ -1586,10 +1584,9 @@ class TMAWorkspaceMixin(MMTemplateConfigMixin):
     def get_extra_kwargs(
         self,
         kernel_inputs: KernelInputs,
-        layout: Layout,
         op_name: str,
     ) -> dict[str, Any]:
-        kwargs = super().get_extra_kwargs(kernel_inputs, layout, op_name)
+        kwargs = super().get_extra_kwargs(kernel_inputs, op_name)
         kwargs["workspace_arg"] = get_tma_workspace_arg(
             num_tma_descriptors=2,
             device=kernel_inputs.device(),
@@ -1614,7 +1611,6 @@ class TMATemplateConfigMixin(TMAWorkspaceMixin, MMTemplateConfigMixin):
     def _get_template_configs_impl(
         self,
         kernel_inputs: KernelInputs,
-        layout: Any,
         op_name: str,
     ) -> Generator[dict[str, Any], None, None]:
         """
@@ -1634,7 +1630,6 @@ class TMATemplateConfigMixin(TMAWorkspaceMixin, MMTemplateConfigMixin):
         # Get base template configs from superclass
         for template_kwargs in super()._get_template_configs_impl(
             kernel_inputs,
-            layout,
             op_name,
         ):
             yield {**template_kwargs, **tma_opts}
@@ -1684,7 +1679,6 @@ class BaseScaledMMConfigMixin(MMTemplateConfigMixin):
     def _get_template_configs_impl(
         self,
         kernel_inputs: KernelInputs,
-        layout: Any,
         op_name: str,
     ) -> Generator[dict[str, Any], None, None]:
         """
@@ -1734,7 +1728,7 @@ class BaseScaledMMConfigMixin(MMTemplateConfigMixin):
 
         # Get base template configs from superclass
         for template_kwargs in super()._get_template_configs_impl(
-            kernel_inputs, layout, op_name
+            kernel_inputs, op_name
         ):
             # Add scaled MM-specific options (moved from mm_common.scaled_mm_options)
             # Override accumulator type for scaled MM
@@ -1752,10 +1746,9 @@ class ScaledMMConfigMixin(BaseScaledMMConfigMixin):
     def get_extra_kwargs(
         self,
         kernel_inputs: KernelInputs,
-        layout: Layout,
         op_name: str,
     ) -> dict[str, Any]:
-        kwargs = super().get_extra_kwargs(kernel_inputs, layout, op_name)
+        kwargs = super().get_extra_kwargs(kernel_inputs, op_name)
         from ..kernel.mm_common import scale_mm_epilogue
 
         return {
@@ -1792,7 +1785,6 @@ class ScaledTMAConfigMixin(TMAWorkspaceMixin, BaseScaledMMConfigMixin):
     def _get_template_configs_impl(
         self,
         kernel_inputs: KernelInputs,
-        layout: Any,
         op_name: str,
     ) -> Generator[dict[str, Any], None, None]:
         """
@@ -1801,7 +1793,6 @@ class ScaledTMAConfigMixin(TMAWorkspaceMixin, BaseScaledMMConfigMixin):
         # Get base scaled MM template configs from superclass
         for template_kwargs in super()._get_template_configs_impl(
             kernel_inputs,
-            layout,
             op_name,
         ):
             # Add TMA-specific options for device TMA scaled MM
diff --git a/torch/_inductor/template_heuristics/triton_addmm.py b/torch/_inductor/template_heuristics/triton_addmm.py
index 5ce99a6049e8..a6643d1ce2a9 100644
--- a/torch/_inductor/template_heuristics/triton_addmm.py
+++ b/torch/_inductor/template_heuristics/triton_addmm.py
@@ -7,7 +7,6 @@ from .base import TemplateConfigHeuristics
 
 
 if TYPE_CHECKING:
-    from ..ir import Layout
     from ..kernel_inputs import KernelInputs
 
 
@@ -19,10 +18,9 @@ class AddMMConfigMixin(TemplateConfigHeuristics):
     def get_extra_kwargs(
         self,
         kernel_inputs: KernelInputs,
-        layout: Layout,
         op_name: str,
     ) -> dict[str, Any]:
-        kwargs = super().get_extra_kwargs(kernel_inputs, layout, op_name)
+        kwargs = super().get_extra_kwargs(kernel_inputs, op_name)
         assert op_name in [
             "addmm",
             "baddbmm",
@@ -31,7 +29,9 @@ class AddMMConfigMixin(TemplateConfigHeuristics):
         beta = kernel_inputs.get_scalar("beta")
         return {
             **kwargs,
-            "epilogue_fn": addmm_epilogue(layout.dtype, alpha, beta),
-            "epilogue_fn_hash": str(["addmm_epilogue", layout.dtype, alpha, beta]),
+            "epilogue_fn": addmm_epilogue(kernel_inputs.out_dtype(), alpha, beta),
+            "epilogue_fn_hash": str(
+                ["addmm_epilogue", kernel_inputs.out_dtype(), alpha, beta]
+            ),
             "prefix_args": 1,
         }
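
A second aside, also not from the patch: the layout=None default kept on get_mm_configs in choices.py is a transitional shim (see the TODO above) that derives the value whenever a caller omits it. A generic sketch of that deprecation pattern, with hypothetical names rather than the real Inductor API:

from typing import Optional


def get_configs(inputs: dict, templates: list[str], op_name: str,
                layout: Optional[str] = None) -> list[str]:
    # transitional shim: old call sites may still pass layout explicitly,
    # new call sites omit it and the value is derived from the inputs instead
    if layout is None:
        layout = inputs.get("layout", "fixed")
    return [f"{op_name}/{t}/{layout}" for t in templates]


print(get_configs({"layout": "fixed"}, ["triton_mm"], "mm"))  # new-style call
print(get_configs({}, ["aten_mm"], "mm", layout="flexible"))  # old-style call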