mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	[inductor][ez] pass template rather than template.uid (#161343)
# why

- simpler interface
- enables extracting more information from the template in the future, e.g. a hash

# what

- `V.choices.get_mm_configs` now takes the whole template rather than just `template.uid`

# testing

```
python3 -bb -m pytest test/inductor/test_max_autotune.py -v
```

Differential Revision: [D81520576](https://our.internmc.facebook.com/intern/diff/D81520576)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161343
Approved by: https://github.com/jansel
ghstack dependencies: #162075, #161340, #161341, #161342
This commit is contained in:
		
				
					committed by
					
						
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							af590cb729
						
					
				
				
					commit
					a301dc3b60
				
			@ -1,7 +1,7 @@
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
import typing
 | 
			
		||||
from typing import Any, Optional, TYPE_CHECKING
 | 
			
		||||
from typing import Any, Optional, TYPE_CHECKING, Union
 | 
			
		||||
 | 
			
		||||
import sympy
 | 
			
		||||
 | 
			
		||||
@ -33,8 +33,10 @@ if TYPE_CHECKING:
 | 
			
		||||
 | 
			
		||||
    from torch.utils._ordered_set import OrderedSet
 | 
			
		||||
 | 
			
		||||
    from .codegen.common import KernelTemplate
 | 
			
		||||
    from .codegen.simd_kernel_features import SIMDKernelFeatures
 | 
			
		||||
    from .codegen.triton import TritonKernel
 | 
			
		||||
    from .select_algorithm import ExternKernelChoice
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Sortable(typing.Protocol):
 | 
			
		||||
@ -104,7 +106,7 @@ class InductorChoices:
 | 
			
		||||
        self,
 | 
			
		||||
        kernel_inputs: KernelInputs,
 | 
			
		||||
        layout: Any,
 | 
			
		||||
        template_name: str,
 | 
			
		||||
        template: Union[KernelTemplate, ExternKernelChoice],
 | 
			
		||||
        op_name: str,
 | 
			
		||||
        kwarg_overrides: Optional[dict[str, Any]] = None,
 | 
			
		||||
    ) -> Generator[tuple[dict[str, Any], dict[str, Any]], None, None]:
 | 
			
		||||
@ -114,7 +116,7 @@ class InductorChoices:
 | 
			
		||||
        Args:
 | 
			
		||||
            kernel_inputs: MMKernelInputs containing input tensor nodes and matrix indices
 | 
			
		||||
            layout: Output layout
 | 
			
		||||
            template_name: Template name (e.g., "bmm", "mm", "mm_persistent_tma")
 | 
			
		||||
            template: Template object (KernelTemplate or ExternKernelChoice)
 | 
			
		||||
            op_name: Operation name (e.g., "bmm", "baddbmm", "addmm", "mm_plus_mm")
 | 
			
		||||
            kwarg_overrides: Optional dict of kwargs to override for the template heuristic
 | 
			
		||||
                             these only override the per config kwargs, not the extra kwargs
 | 
			
		||||
@ -125,6 +127,9 @@ class InductorChoices:
 | 
			
		||||
        if len(input_tensors) < 2:
 | 
			
		||||
            raise ValueError(f"Need at least 2 input tensors, got {len(input_tensors)}")
 | 
			
		||||
 | 
			
		||||
        # Extract template_name from the template object
 | 
			
		||||
        template_name = template.uid
 | 
			
		||||
 | 
			
		||||
        # Extract device_type from kernel_inputs
 | 
			
		||||
        device_type = kernel_inputs.device_type
 | 
			
		||||
        assert device_type is not None, "get_mm_configs requires a valid device type"
 | 
			
		||||
 | 
			
		||||
@ -199,7 +199,7 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
 | 
			
		||||
    choices: list[ChoiceCaller] = []
 | 
			
		||||
    if use_aten_gemm_kernels():
 | 
			
		||||
        for kwargs, extra_kwargs in V.choices.get_mm_configs(
 | 
			
		||||
            kernel_inputs, layout, aten_handler.uid, name, aten_extra_kwargs
 | 
			
		||||
            kernel_inputs, layout, aten_handler, name, aten_extra_kwargs
 | 
			
		||||
        ):
 | 
			
		||||
            aten_handler.maybe_append_choice(
 | 
			
		||||
                choices,
 | 
			
		||||
@ -212,7 +212,7 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
 | 
			
		||||
        assert out_dtype is None, "out_dtype is not supported for Triton"
 | 
			
		||||
 | 
			
		||||
        for kwargs, extra_kwargs in V.choices.get_mm_configs(
 | 
			
		||||
            kernel_inputs, layout, bmm_template.uid, name
 | 
			
		||||
            kernel_inputs, layout, bmm_template, name
 | 
			
		||||
        ):
 | 
			
		||||
            bmm_template.maybe_append_choice(
 | 
			
		||||
                choices,
 | 
			
		||||
@ -280,7 +280,7 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
 | 
			
		||||
    choices: list[ChoiceCaller] = []
 | 
			
		||||
    if use_aten_gemm_kernels():
 | 
			
		||||
        for kwargs, extra_kwargs in V.choices.get_mm_configs(
 | 
			
		||||
            kernel_inputs, layout, aten_baddbmm.uid, name
 | 
			
		||||
            kernel_inputs, layout, aten_baddbmm, name
 | 
			
		||||
        ):
 | 
			
		||||
            aten_baddbmm.maybe_append_choice(
 | 
			
		||||
                choices,
 | 
			
		||||
@ -292,7 +292,7 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
 | 
			
		||||
        for kwargs, extra_kwargs in V.choices.get_mm_configs(
 | 
			
		||||
            kernel_inputs,
 | 
			
		||||
            layout,
 | 
			
		||||
            bmm_template.uid,
 | 
			
		||||
            bmm_template,
 | 
			
		||||
            name,
 | 
			
		||||
        ):
 | 
			
		||||
            bmm_template.maybe_append_choice(
 | 
			
		||||
 | 
			
		||||
@ -753,7 +753,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
 | 
			
		||||
    choices: list[ChoiceCaller] = []
 | 
			
		||||
    if use_aten_gemm_kernels():
 | 
			
		||||
        for kwargs, extra_kwargs in V.choices.get_mm_configs(
 | 
			
		||||
            kernel_inputs, aten_layout, aten_mm.uid, "mm"
 | 
			
		||||
            kernel_inputs, aten_layout, aten_mm, "mm"
 | 
			
		||||
        ):
 | 
			
		||||
            aten_mm.maybe_append_choice(
 | 
			
		||||
                choices=choices,
 | 
			
		||||
@ -765,7 +765,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
 | 
			
		||||
    if is_nonzero and use_triton_template(layout):
 | 
			
		||||
        # Get template params using the new unified function
 | 
			
		||||
        for kwargs, extra_kwargs in V.choices.get_mm_configs(
 | 
			
		||||
            kernel_inputs, layout, mm_template.uid, "mm"
 | 
			
		||||
            kernel_inputs, layout, mm_template, "mm"
 | 
			
		||||
        ):
 | 
			
		||||
            mm_template.maybe_append_choice(
 | 
			
		||||
                choices,
 | 
			
		||||
@ -776,7 +776,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
 | 
			
		||||
        if use_triton_tma_template(mat1, mat2):
 | 
			
		||||
            # Get TMA template params using the new unified function
 | 
			
		||||
            for kwargs, extra_kwargs in V.choices.get_mm_configs(
 | 
			
		||||
                kernel_inputs, layout, persistent_tma_mm_template.uid, "mm"
 | 
			
		||||
                kernel_inputs, layout, persistent_tma_mm_template, "mm"
 | 
			
		||||
            ):
 | 
			
		||||
                persistent_tma_mm_template.maybe_append_choice(
 | 
			
		||||
                    choices,
 | 
			
		||||
@ -787,7 +787,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
 | 
			
		||||
        # Only do split-k optimization if K is much larger than m, n and m, n are small
 | 
			
		||||
        if use_decompose_k_choice(m, n, k):
 | 
			
		||||
            for kwargs, extra_kwargs in V.choices.get_mm_configs(
 | 
			
		||||
                kernel_inputs, layout, decompose_k_subgraph_template.uid, "mm"
 | 
			
		||||
                kernel_inputs, layout, decompose_k_subgraph_template, "mm"
 | 
			
		||||
            ):
 | 
			
		||||
                decompose_k_subgraph_template.maybe_append_choice(
 | 
			
		||||
                    choices,
 | 
			
		||||
@ -795,7 +795,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
 | 
			
		||||
                    **extra_kwargs,
 | 
			
		||||
                )
 | 
			
		||||
        for kwargs, extra_kwargs in V.choices.get_mm_configs(
 | 
			
		||||
            kernel_inputs, layout, mm_contiguous_subgraph_template.name, "mm"
 | 
			
		||||
            kernel_inputs, layout, mm_contiguous_subgraph_template, "mm"
 | 
			
		||||
        ):
 | 
			
		||||
            mm_contiguous_subgraph_template.maybe_append_choice(
 | 
			
		||||
                choices,
 | 
			
		||||
@ -841,7 +841,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
 | 
			
		||||
            # while we transition to the unified kwargs retrieval
 | 
			
		||||
            kernel_inputs,
 | 
			
		||||
            layout,
 | 
			
		||||
            mm_template.uid,
 | 
			
		||||
            mm_template,
 | 
			
		||||
            "mm-ah",
 | 
			
		||||
        ):
 | 
			
		||||
            assert not kwargs, "mm-ah should not have any extra kwargs"
 | 
			
		||||
@ -904,7 +904,7 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
 | 
			
		||||
    m, n, k, layout, mat1, mat2 = mm_args(
 | 
			
		||||
        mat1, mat2, layout=layout, out_dtype=torch.int32
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    name = "int_mm"
 | 
			
		||||
    # below is for getting an overview logging info of inductor mms
 | 
			
		||||
    counters["aten_mm_info"][f"aten._int_mm_{m}_{n}_{k}"] += 1
 | 
			
		||||
    log.info(
 | 
			
		||||
@ -925,7 +925,10 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
 | 
			
		||||
    kernel_inputs = MMKernelInputs([mat1, mat2])
 | 
			
		||||
    if use_aten_gemm_kernels():
 | 
			
		||||
        for kwargs, extra_kwargs in V.choices.get_mm_configs(
 | 
			
		||||
            kernel_inputs, layout, aten__int_mm.uid, "int_mm"
 | 
			
		||||
            kernel_inputs,
 | 
			
		||||
            layout,
 | 
			
		||||
            aten__int_mm,
 | 
			
		||||
            name,
 | 
			
		||||
        ):
 | 
			
		||||
            aten__int_mm.maybe_append_choice(
 | 
			
		||||
                choices=choices,
 | 
			
		||||
@ -933,14 +936,14 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
 | 
			
		||||
                **extra_kwargs,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    if use_cutlass and _use_cutlass_for_op("int_mm"):
 | 
			
		||||
    if use_cutlass and _use_cutlass_for_op(name):
 | 
			
		||||
        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
 | 
			
		||||
            choices, layout, kernel_inputs.nodes(), fuseable=True, non_fuseable=True
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    if is_nonzero and use_triton_template(layout, enable_int32=True):
 | 
			
		||||
        for kwargs, extra_kwargs in V.choices.get_mm_configs(
 | 
			
		||||
            kernel_inputs, layout, mm_template.uid, "int_mm"
 | 
			
		||||
            kernel_inputs, layout, mm_template, name
 | 
			
		||||
        ):
 | 
			
		||||
            mm_template.maybe_append_choice(
 | 
			
		||||
                choices,
 | 
			
		||||
@ -948,7 +951,7 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
 | 
			
		||||
                **extra_kwargs,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    return autotune_select_algorithm("int_mm", choices, kernel_inputs.nodes(), layout)
 | 
			
		||||
    return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@register_lowering(aten.addmm, type_promotion_kind=None)
 | 
			
		||||
@ -998,7 +1001,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
 | 
			
		||||
        for kwargs, extra_kwargs in V.choices.get_mm_configs(
 | 
			
		||||
            kernel_inputs,
 | 
			
		||||
            aten_layout,
 | 
			
		||||
            aten_addmm.uid,
 | 
			
		||||
            aten_addmm,
 | 
			
		||||
            name,
 | 
			
		||||
        ):
 | 
			
		||||
            aten_addmm.maybe_append_choice(
 | 
			
		||||
@ -1012,7 +1015,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
 | 
			
		||||
        for kwargs, extra_kwargs in V.choices.get_mm_configs(
 | 
			
		||||
            kernel_inputs,
 | 
			
		||||
            aten_layout,
 | 
			
		||||
            aten_addmm.uid,
 | 
			
		||||
            aten_addmm,
 | 
			
		||||
            name,
 | 
			
		||||
        ):
 | 
			
		||||
            aten_addmm.maybe_append_choice(
 | 
			
		||||
@ -1023,7 +1026,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
 | 
			
		||||
        for kwargs, extra_kwargs in V.choices.get_mm_configs(
 | 
			
		||||
            kernel_inputs,
 | 
			
		||||
            aten_layout,
 | 
			
		||||
            aten_bias_addmm.uid,
 | 
			
		||||
            aten_bias_addmm,
 | 
			
		||||
            name,
 | 
			
		||||
        ):
 | 
			
		||||
            aten_bias_addmm.maybe_append_choice(
 | 
			
		||||
@ -1038,7 +1041,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
 | 
			
		||||
        for kwargs, extra_kwargs in V.choices.get_mm_configs(
 | 
			
		||||
            kernel_inputs,
 | 
			
		||||
            layout,
 | 
			
		||||
            mm_template.uid,
 | 
			
		||||
            mm_template,
 | 
			
		||||
            name,
 | 
			
		||||
        ):
 | 
			
		||||
            mm_template.maybe_append_choice(
 | 
			
		||||
@ -1052,7 +1055,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
 | 
			
		||||
            for kwargs, extra_kwargs in V.choices.get_mm_configs(
 | 
			
		||||
                kernel_inputs,
 | 
			
		||||
                layout,
 | 
			
		||||
                persistent_tma_mm_template.uid,
 | 
			
		||||
                persistent_tma_mm_template,
 | 
			
		||||
                name,
 | 
			
		||||
            ):
 | 
			
		||||
                persistent_tma_mm_template.maybe_append_choice(
 | 
			
		||||
@ -1064,7 +1067,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
 | 
			
		||||
        for kwargs, extra_kwargs in V.choices.get_mm_configs(
 | 
			
		||||
            kernel_inputs,
 | 
			
		||||
            layout,
 | 
			
		||||
            addmm_contiguous_subgraph_template.name,
 | 
			
		||||
            addmm_contiguous_subgraph_template,
 | 
			
		||||
            "addmm",
 | 
			
		||||
        ):
 | 
			
		||||
            addmm_contiguous_subgraph_template.maybe_append_choice(
 | 
			
		||||
@ -1229,7 +1232,7 @@ def tuned_scaled_mm(
 | 
			
		||||
        for kwargs, extra_kwargs in V.choices.get_mm_configs(
 | 
			
		||||
            kernel_inputs,
 | 
			
		||||
            layout,
 | 
			
		||||
            aten__fp8_mm.uid,
 | 
			
		||||
            aten__fp8_mm,
 | 
			
		||||
            name,
 | 
			
		||||
            kwarg_overrides=aten_extra_kwargs,
 | 
			
		||||
        ):
 | 
			
		||||
@ -1254,7 +1257,7 @@ def tuned_scaled_mm(
 | 
			
		||||
            for kwargs, extra_kwargs in V.choices.get_mm_configs(
 | 
			
		||||
                kernel_inputs,
 | 
			
		||||
                layout,
 | 
			
		||||
                scaled_mm_device_tma_template.uid,
 | 
			
		||||
                scaled_mm_device_tma_template,
 | 
			
		||||
                name,
 | 
			
		||||
                overriders,
 | 
			
		||||
            ):
 | 
			
		||||
@ -1268,7 +1271,7 @@ def tuned_scaled_mm(
 | 
			
		||||
        for kwargs, extra_kwargs in V.choices.get_mm_configs(
 | 
			
		||||
            kernel_inputs,
 | 
			
		||||
            layout,
 | 
			
		||||
            mm_template.uid,
 | 
			
		||||
            mm_template,
 | 
			
		||||
            name,
 | 
			
		||||
            overriders,
 | 
			
		||||
        ):
 | 
			
		||||
 | 
			
		||||
@ -130,7 +130,7 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
 | 
			
		||||
    # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
 | 
			
		||||
    m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
 | 
			
		||||
    m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout)
 | 
			
		||||
    name = "mm_plus_mm"
 | 
			
		||||
 | 
			
		||||
    # Optimization is optional, because we can always just not do the fusion
 | 
			
		||||
    if (
 | 
			
		||||
        m1 * n1 == 0
 | 
			
		||||
@ -157,7 +157,7 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
 | 
			
		||||
    choices: list[ChoiceCaller] = []
 | 
			
		||||
    if use_aten_gemm_kernels():
 | 
			
		||||
        for kwargs, extra_kwargs in V.choices.get_mm_configs(
 | 
			
		||||
            kernel_inputs, layout1, aten_mm_plus_mm.uid, name
 | 
			
		||||
            kernel_inputs, layout1, aten_mm_plus_mm, "mm_plus_mm"
 | 
			
		||||
        ):
 | 
			
		||||
            aten_mm_plus_mm.maybe_append_choice(
 | 
			
		||||
                choices,
 | 
			
		||||
@ -168,7 +168,7 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
 | 
			
		||||
    if use_triton_template(layout1):
 | 
			
		||||
        # Get template params using the new unified function
 | 
			
		||||
        for kwargs, extra_kwargs in V.choices.get_mm_configs(
 | 
			
		||||
            kernel_inputs, layout1, mm_plus_mm_template.uid, name
 | 
			
		||||
            kernel_inputs, layout1, mm_plus_mm_template, "mm_plus_mm"
 | 
			
		||||
        ):
 | 
			
		||||
            # Apply BLOCK_K constraint specific to mm_plus_mm
 | 
			
		||||
            # see https://github.com/triton-lang/triton/issues/1298
 | 
			
		||||
@ -180,4 +180,6 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
 | 
			
		||||
                    **extra_kwargs,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
    return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout1)
 | 
			
		||||
    return autotune_select_algorithm(
 | 
			
		||||
        "mm_plus_mm", choices, kernel_inputs.nodes(), layout1
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user