[inductor][ez] pass template rather than template.uid (#161343)

# why

- simpler interface
- makes it possible to extract more information from the template in the
  future, e.g. a hash (see the sketch below)
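
As a purely illustrative sketch (the hashing helper below is hypothetical and not part of this change), having the whole template object available inside get_mm_configs would allow deriving more than just the name:

```
import hashlib

def _template_fingerprint(template) -> str:
    # Hypothetical helper: template.uid exists today; "src" is an assumed
    # optional attribute, used here only to show the idea of a richer key.
    payload = template.uid + getattr(template, "src", "")
    return hashlib.sha256(payload.encode()).hexdigest()
```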

# what

V.choices.get_mm_configs now takes the whole template object rather than
just template.uid
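
Concretely, call sites go from passing the uid string to passing the template
object, and get_mm_configs extracts template.uid internally. For example
(taken from the tuned_mm changes below):

```
# before
for kwargs, extra_kwargs in V.choices.get_mm_configs(
    kernel_inputs, layout, mm_template.uid, "mm"
):
    ...

# after
for kwargs, extra_kwargs in V.choices.get_mm_configs(
    kernel_inputs, layout, mm_template, "mm"
):
    ...
```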

# testing

```
python3 -bb -m pytest test/inductor/test_max_autotune.py -v
```

Differential Revision: [D81520576](https://our.internmc.facebook.com/intern/diff/D81520576)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161343
Approved by: https://github.com/jansel
ghstack dependencies: #162075, #161340, #161341, #161342

Author: Ruben Rodriguez Buchillon (2025-09-05 07:38:41 -07:00)
Committed by: PyTorch MergeBot
parent af590cb729
commit a301dc3b60
4 changed files with 41 additions and 31 deletions


@@ -1,7 +1,7 @@
from __future__ import annotations
import typing
-from typing import Any, Optional, TYPE_CHECKING
+from typing import Any, Optional, TYPE_CHECKING, Union
import sympy
@@ -33,8 +33,10 @@ if TYPE_CHECKING:
from torch.utils._ordered_set import OrderedSet
+from .codegen.common import KernelTemplate
from .codegen.simd_kernel_features import SIMDKernelFeatures
from .codegen.triton import TritonKernel
+from .select_algorithm import ExternKernelChoice
class Sortable(typing.Protocol):
@@ -104,7 +106,7 @@ class InductorChoices:
self,
kernel_inputs: KernelInputs,
layout: Any,
-template_name: str,
+template: Union[KernelTemplate, ExternKernelChoice],
op_name: str,
kwarg_overrides: Optional[dict[str, Any]] = None,
) -> Generator[tuple[dict[str, Any], dict[str, Any]], None, None]:
@@ -114,7 +116,7 @@ class InductorChoices:
Args:
kernel_inputs: MMKernelInputs containing input tensor nodes and matrix indices
layout: Output layout
-template_name: Template name (e.g., "bmm", "mm", "mm_persistent_tma")
+template: Template object (KernelTemplate or ExternKernelChoice)
op_name: Operation name (e.g., "bmm", "baddbmm", "addmm", "mm_plus_mm")
kwarg_overrides: Optional dict of kwargs to override for the template heuristic
these only override the per config kwargs, not the extra kwargs
@@ -125,6 +127,9 @@ class InductorChoices:
if len(input_tensors) < 2:
raise ValueError(f"Need at least 2 input tensors, got {len(input_tensors)}")
+# Extract template_name from the template object
+template_name = template.uid
# Extract device_type from kernel_inputs
device_type = kernel_inputs.device_type
assert device_type is not None, "get_mm_configs requires a valid device type"


@@ -199,7 +199,7 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
choices: list[ChoiceCaller] = []
if use_aten_gemm_kernels():
for kwargs, extra_kwargs in V.choices.get_mm_configs(
-kernel_inputs, layout, aten_handler.uid, name, aten_extra_kwargs
+kernel_inputs, layout, aten_handler, name, aten_extra_kwargs
):
aten_handler.maybe_append_choice(
choices,
@@ -212,7 +212,7 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
assert out_dtype is None, "out_dtype is not supported for Triton"
for kwargs, extra_kwargs in V.choices.get_mm_configs(
-kernel_inputs, layout, bmm_template.uid, name
+kernel_inputs, layout, bmm_template, name
):
bmm_template.maybe_append_choice(
choices,
@@ -280,7 +280,7 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
choices: list[ChoiceCaller] = []
if use_aten_gemm_kernels():
for kwargs, extra_kwargs in V.choices.get_mm_configs(
-kernel_inputs, layout, aten_baddbmm.uid, name
+kernel_inputs, layout, aten_baddbmm, name
):
aten_baddbmm.maybe_append_choice(
choices,
@@ -292,7 +292,7 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
for kwargs, extra_kwargs in V.choices.get_mm_configs(
kernel_inputs,
layout,
-bmm_template.uid,
+bmm_template,
name,
):
bmm_template.maybe_append_choice(


@@ -753,7 +753,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
choices: list[ChoiceCaller] = []
if use_aten_gemm_kernels():
for kwargs, extra_kwargs in V.choices.get_mm_configs(
-kernel_inputs, aten_layout, aten_mm.uid, "mm"
+kernel_inputs, aten_layout, aten_mm, "mm"
):
aten_mm.maybe_append_choice(
choices=choices,
@@ -765,7 +765,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
if is_nonzero and use_triton_template(layout):
# Get template params using the new unified function
for kwargs, extra_kwargs in V.choices.get_mm_configs(
kernel_inputs, layout, mm_template.uid, "mm"
kernel_inputs, layout, mm_template, "mm"
):
mm_template.maybe_append_choice(
choices,
@@ -776,7 +776,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
if use_triton_tma_template(mat1, mat2):
# Get TMA template params using the new unified function
for kwargs, extra_kwargs in V.choices.get_mm_configs(
-kernel_inputs, layout, persistent_tma_mm_template.uid, "mm"
+kernel_inputs, layout, persistent_tma_mm_template, "mm"
):
persistent_tma_mm_template.maybe_append_choice(
choices,
@@ -787,7 +787,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
# Only do split-k optimization if K is much larger than m, n and m, n are small
if use_decompose_k_choice(m, n, k):
for kwargs, extra_kwargs in V.choices.get_mm_configs(
-kernel_inputs, layout, decompose_k_subgraph_template.uid, "mm"
+kernel_inputs, layout, decompose_k_subgraph_template, "mm"
):
decompose_k_subgraph_template.maybe_append_choice(
choices,
@@ -795,7 +795,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
**extra_kwargs,
)
for kwargs, extra_kwargs in V.choices.get_mm_configs(
-kernel_inputs, layout, mm_contiguous_subgraph_template.name, "mm"
+kernel_inputs, layout, mm_contiguous_subgraph_template, "mm"
):
mm_contiguous_subgraph_template.maybe_append_choice(
choices,
@@ -841,7 +841,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
# while we transition to the unified kwargs retrieval
kernel_inputs,
layout,
-mm_template.uid,
+mm_template,
"mm-ah",
):
assert not kwargs, "mm-ah should not have any extra kwargs"
@@ -904,7 +904,7 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
m, n, k, layout, mat1, mat2 = mm_args(
mat1, mat2, layout=layout, out_dtype=torch.int32
)
name = "int_mm"
# below is for getting an overview logging info of inductor mms
counters["aten_mm_info"][f"aten._int_mm_{m}_{n}_{k}"] += 1
log.info(
@@ -925,7 +925,10 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
kernel_inputs = MMKernelInputs([mat1, mat2])
if use_aten_gemm_kernels():
for kwargs, extra_kwargs in V.choices.get_mm_configs(
-kernel_inputs, layout, aten__int_mm.uid, "int_mm"
+kernel_inputs,
+layout,
+aten__int_mm,
+name,
):
aten__int_mm.maybe_append_choice(
choices=choices,
@@ -933,14 +936,14 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
**extra_kwargs,
)
if use_cutlass and _use_cutlass_for_op("int_mm"):
if use_cutlass and _use_cutlass_for_op(name):
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
choices, layout, kernel_inputs.nodes(), fuseable=True, non_fuseable=True
)
if is_nonzero and use_triton_template(layout, enable_int32=True):
for kwargs, extra_kwargs in V.choices.get_mm_configs(
kernel_inputs, layout, mm_template.uid, "int_mm"
kernel_inputs, layout, mm_template, name
):
mm_template.maybe_append_choice(
choices,
@@ -948,7 +951,7 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
**extra_kwargs,
)
return autotune_select_algorithm("int_mm", choices, kernel_inputs.nodes(), layout)
return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
@register_lowering(aten.addmm, type_promotion_kind=None)
@@ -998,7 +1001,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
for kwargs, extra_kwargs in V.choices.get_mm_configs(
kernel_inputs,
aten_layout,
-aten_addmm.uid,
+aten_addmm,
name,
):
aten_addmm.maybe_append_choice(
@@ -1012,7 +1015,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
for kwargs, extra_kwargs in V.choices.get_mm_configs(
kernel_inputs,
aten_layout,
-aten_addmm.uid,
+aten_addmm,
name,
):
aten_addmm.maybe_append_choice(
@@ -1023,7 +1026,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
for kwargs, extra_kwargs in V.choices.get_mm_configs(
kernel_inputs,
aten_layout,
-aten_bias_addmm.uid,
+aten_bias_addmm,
name,
):
aten_bias_addmm.maybe_append_choice(
@@ -1038,7 +1041,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
for kwargs, extra_kwargs in V.choices.get_mm_configs(
kernel_inputs,
layout,
-mm_template.uid,
+mm_template,
name,
):
mm_template.maybe_append_choice(
@@ -1052,7 +1055,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
for kwargs, extra_kwargs in V.choices.get_mm_configs(
kernel_inputs,
layout,
-persistent_tma_mm_template.uid,
+persistent_tma_mm_template,
name,
):
persistent_tma_mm_template.maybe_append_choice(
@@ -1064,7 +1067,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
for kwargs, extra_kwargs in V.choices.get_mm_configs(
kernel_inputs,
layout,
-addmm_contiguous_subgraph_template.name,
+addmm_contiguous_subgraph_template,
"addmm",
):
addmm_contiguous_subgraph_template.maybe_append_choice(
@@ -1229,7 +1232,7 @@ def tuned_scaled_mm(
for kwargs, extra_kwargs in V.choices.get_mm_configs(
kernel_inputs,
layout,
-aten__fp8_mm.uid,
+aten__fp8_mm,
name,
kwarg_overrides=aten_extra_kwargs,
):
@@ -1254,7 +1257,7 @@ def tuned_scaled_mm(
for kwargs, extra_kwargs in V.choices.get_mm_configs(
kernel_inputs,
layout,
-scaled_mm_device_tma_template.uid,
+scaled_mm_device_tma_template,
name,
overriders,
):
@@ -1268,7 +1271,7 @@ def tuned_scaled_mm(
for kwargs, extra_kwargs in V.choices.get_mm_configs(
kernel_inputs,
layout,
-mm_template.uid,
+mm_template,
name,
overriders,
):


@@ -130,7 +130,7 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
# TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout)
name = "mm_plus_mm"
# Optimization is optional, because we can always just not do the fusion
if (
m1 * n1 == 0
@@ -157,7 +157,7 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
choices: list[ChoiceCaller] = []
if use_aten_gemm_kernels():
for kwargs, extra_kwargs in V.choices.get_mm_configs(
-kernel_inputs, layout1, aten_mm_plus_mm.uid, name
+kernel_inputs, layout1, aten_mm_plus_mm, "mm_plus_mm"
):
aten_mm_plus_mm.maybe_append_choice(
choices,
@@ -168,7 +168,7 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
if use_triton_template(layout1):
# Get template params using the new unified function
for kwargs, extra_kwargs in V.choices.get_mm_configs(
-kernel_inputs, layout1, mm_plus_mm_template.uid, name
+kernel_inputs, layout1, mm_plus_mm_template, "mm_plus_mm"
):
# Apply BLOCK_K constraint specific to mm_plus_mm
# see https://github.com/triton-lang/triton/issues/1298
@@ -180,4 +180,6 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
**extra_kwargs,
)
-return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout1)
+return autotune_select_algorithm(
+"mm_plus_mm", choices, kernel_inputs.nodes(), layout1
+)