mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 08:24:57 +08:00
[inductor][ez] pass template rather than template.uid (#161343)
# why - simpler interface - enables future of extracting more things out of the template e.g. a hash # what V.choices.get_mm_configs now takes the whole template rather than just the template.uid # testing ``` python3 -bb -m pytest test/inductor/test_max_autotune.py -v ``` Differential Revision: [D81520576](https://our.internmc.facebook.com/intern/diff/D81520576) Pull Request resolved: https://github.com/pytorch/pytorch/pull/161343 Approved by: https://github.com/jansel ghstack dependencies: #162075, #161340, #161341, #161342
This commit is contained in:
committed by
PyTorch MergeBot
parent
af590cb729
commit
a301dc3b60
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
from typing import Any, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import sympy
|
||||
|
||||
@ -33,8 +33,10 @@ if TYPE_CHECKING:
|
||||
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
from .codegen.common import KernelTemplate
|
||||
from .codegen.simd_kernel_features import SIMDKernelFeatures
|
||||
from .codegen.triton import TritonKernel
|
||||
from .select_algorithm import ExternKernelChoice
|
||||
|
||||
|
||||
class Sortable(typing.Protocol):
|
||||
@ -104,7 +106,7 @@ class InductorChoices:
|
||||
self,
|
||||
kernel_inputs: KernelInputs,
|
||||
layout: Any,
|
||||
template_name: str,
|
||||
template: Union[KernelTemplate, ExternKernelChoice],
|
||||
op_name: str,
|
||||
kwarg_overrides: Optional[dict[str, Any]] = None,
|
||||
) -> Generator[tuple[dict[str, Any], dict[str, Any]], None, None]:
|
||||
@ -114,7 +116,7 @@ class InductorChoices:
|
||||
Args:
|
||||
kernel_inputs: MMKernelInputs containing input tensor nodes and matrix indices
|
||||
layout: Output layout
|
||||
template_name: Template name (e.g., "bmm", "mm", "mm_persistent_tma")
|
||||
template: Template object (KernelTemplate or ExternKernelChoice)
|
||||
op_name: Operation name (e.g., "bmm", "baddbmm", "addmm", "mm_plus_mm")
|
||||
kwarg_overrides: Optional dict of kwargs to override for the template heuristic
|
||||
these only override the per config kwargs, not the extra kwargs
|
||||
@ -125,6 +127,9 @@ class InductorChoices:
|
||||
if len(input_tensors) < 2:
|
||||
raise ValueError(f"Need at least 2 input tensors, got {len(input_tensors)}")
|
||||
|
||||
# Extract template_name from the template object
|
||||
template_name = template.uid
|
||||
|
||||
# Extract device_type from kernel_inputs
|
||||
device_type = kernel_inputs.device_type
|
||||
assert device_type is not None, "get_mm_configs requires a valid device type"
|
||||
|
||||
@ -199,7 +199,7 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
|
||||
choices: list[ChoiceCaller] = []
|
||||
if use_aten_gemm_kernels():
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout, aten_handler.uid, name, aten_extra_kwargs
|
||||
kernel_inputs, layout, aten_handler, name, aten_extra_kwargs
|
||||
):
|
||||
aten_handler.maybe_append_choice(
|
||||
choices,
|
||||
@ -212,7 +212,7 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
|
||||
assert out_dtype is None, "out_dtype is not supported for Triton"
|
||||
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout, bmm_template.uid, name
|
||||
kernel_inputs, layout, bmm_template, name
|
||||
):
|
||||
bmm_template.maybe_append_choice(
|
||||
choices,
|
||||
@ -280,7 +280,7 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
||||
choices: list[ChoiceCaller] = []
|
||||
if use_aten_gemm_kernels():
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout, aten_baddbmm.uid, name
|
||||
kernel_inputs, layout, aten_baddbmm, name
|
||||
):
|
||||
aten_baddbmm.maybe_append_choice(
|
||||
choices,
|
||||
@ -292,7 +292,7 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs,
|
||||
layout,
|
||||
bmm_template.uid,
|
||||
bmm_template,
|
||||
name,
|
||||
):
|
||||
bmm_template.maybe_append_choice(
|
||||
|
||||
@ -753,7 +753,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
|
||||
choices: list[ChoiceCaller] = []
|
||||
if use_aten_gemm_kernels():
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, aten_layout, aten_mm.uid, "mm"
|
||||
kernel_inputs, aten_layout, aten_mm, "mm"
|
||||
):
|
||||
aten_mm.maybe_append_choice(
|
||||
choices=choices,
|
||||
@ -765,7 +765,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
|
||||
if is_nonzero and use_triton_template(layout):
|
||||
# Get template params using the new unified function
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout, mm_template.uid, "mm"
|
||||
kernel_inputs, layout, mm_template, "mm"
|
||||
):
|
||||
mm_template.maybe_append_choice(
|
||||
choices,
|
||||
@ -776,7 +776,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
|
||||
if use_triton_tma_template(mat1, mat2):
|
||||
# Get TMA template params using the new unified function
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout, persistent_tma_mm_template.uid, "mm"
|
||||
kernel_inputs, layout, persistent_tma_mm_template, "mm"
|
||||
):
|
||||
persistent_tma_mm_template.maybe_append_choice(
|
||||
choices,
|
||||
@ -787,7 +787,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
|
||||
# Only do split-k optimization if K is much larger than m, n and m, n are small
|
||||
if use_decompose_k_choice(m, n, k):
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout, decompose_k_subgraph_template.uid, "mm"
|
||||
kernel_inputs, layout, decompose_k_subgraph_template, "mm"
|
||||
):
|
||||
decompose_k_subgraph_template.maybe_append_choice(
|
||||
choices,
|
||||
@ -795,7 +795,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
|
||||
**extra_kwargs,
|
||||
)
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout, mm_contiguous_subgraph_template.name, "mm"
|
||||
kernel_inputs, layout, mm_contiguous_subgraph_template, "mm"
|
||||
):
|
||||
mm_contiguous_subgraph_template.maybe_append_choice(
|
||||
choices,
|
||||
@ -841,7 +841,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
|
||||
# while we transition to the unified kwargs retrieval
|
||||
kernel_inputs,
|
||||
layout,
|
||||
mm_template.uid,
|
||||
mm_template,
|
||||
"mm-ah",
|
||||
):
|
||||
assert not kwargs, "mm-ah should not have any extra kwargs"
|
||||
@ -904,7 +904,7 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
|
||||
m, n, k, layout, mat1, mat2 = mm_args(
|
||||
mat1, mat2, layout=layout, out_dtype=torch.int32
|
||||
)
|
||||
|
||||
name = "int_mm"
|
||||
# below is for getting an overview logging info of inductor mms
|
||||
counters["aten_mm_info"][f"aten._int_mm_{m}_{n}_{k}"] += 1
|
||||
log.info(
|
||||
@ -925,7 +925,10 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
|
||||
kernel_inputs = MMKernelInputs([mat1, mat2])
|
||||
if use_aten_gemm_kernels():
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout, aten__int_mm.uid, "int_mm"
|
||||
kernel_inputs,
|
||||
layout,
|
||||
aten__int_mm,
|
||||
name,
|
||||
):
|
||||
aten__int_mm.maybe_append_choice(
|
||||
choices=choices,
|
||||
@ -933,14 +936,14 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
|
||||
**extra_kwargs,
|
||||
)
|
||||
|
||||
if use_cutlass and _use_cutlass_for_op("int_mm"):
|
||||
if use_cutlass and _use_cutlass_for_op(name):
|
||||
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
|
||||
choices, layout, kernel_inputs.nodes(), fuseable=True, non_fuseable=True
|
||||
)
|
||||
|
||||
if is_nonzero and use_triton_template(layout, enable_int32=True):
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout, mm_template.uid, "int_mm"
|
||||
kernel_inputs, layout, mm_template, name
|
||||
):
|
||||
mm_template.maybe_append_choice(
|
||||
choices,
|
||||
@ -948,7 +951,7 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
|
||||
**extra_kwargs,
|
||||
)
|
||||
|
||||
return autotune_select_algorithm("int_mm", choices, kernel_inputs.nodes(), layout)
|
||||
return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
|
||||
|
||||
|
||||
@register_lowering(aten.addmm, type_promotion_kind=None)
|
||||
@ -998,7 +1001,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs,
|
||||
aten_layout,
|
||||
aten_addmm.uid,
|
||||
aten_addmm,
|
||||
name,
|
||||
):
|
||||
aten_addmm.maybe_append_choice(
|
||||
@ -1012,7 +1015,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs,
|
||||
aten_layout,
|
||||
aten_addmm.uid,
|
||||
aten_addmm,
|
||||
name,
|
||||
):
|
||||
aten_addmm.maybe_append_choice(
|
||||
@ -1023,7 +1026,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs,
|
||||
aten_layout,
|
||||
aten_bias_addmm.uid,
|
||||
aten_bias_addmm,
|
||||
name,
|
||||
):
|
||||
aten_bias_addmm.maybe_append_choice(
|
||||
@ -1038,7 +1041,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs,
|
||||
layout,
|
||||
mm_template.uid,
|
||||
mm_template,
|
||||
name,
|
||||
):
|
||||
mm_template.maybe_append_choice(
|
||||
@ -1052,7 +1055,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs,
|
||||
layout,
|
||||
persistent_tma_mm_template.uid,
|
||||
persistent_tma_mm_template,
|
||||
name,
|
||||
):
|
||||
persistent_tma_mm_template.maybe_append_choice(
|
||||
@ -1064,7 +1067,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs,
|
||||
layout,
|
||||
addmm_contiguous_subgraph_template.name,
|
||||
addmm_contiguous_subgraph_template,
|
||||
"addmm",
|
||||
):
|
||||
addmm_contiguous_subgraph_template.maybe_append_choice(
|
||||
@ -1229,7 +1232,7 @@ def tuned_scaled_mm(
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs,
|
||||
layout,
|
||||
aten__fp8_mm.uid,
|
||||
aten__fp8_mm,
|
||||
name,
|
||||
kwarg_overrides=aten_extra_kwargs,
|
||||
):
|
||||
@ -1254,7 +1257,7 @@ def tuned_scaled_mm(
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs,
|
||||
layout,
|
||||
scaled_mm_device_tma_template.uid,
|
||||
scaled_mm_device_tma_template,
|
||||
name,
|
||||
overriders,
|
||||
):
|
||||
@ -1268,7 +1271,7 @@ def tuned_scaled_mm(
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs,
|
||||
layout,
|
||||
mm_template.uid,
|
||||
mm_template,
|
||||
name,
|
||||
overriders,
|
||||
):
|
||||
|
||||
@ -130,7 +130,7 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
|
||||
# TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
|
||||
m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
|
||||
m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout)
|
||||
name = "mm_plus_mm"
|
||||
|
||||
# Optimization is optional, because we can always just not do the fusion
|
||||
if (
|
||||
m1 * n1 == 0
|
||||
@ -157,7 +157,7 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
|
||||
choices: list[ChoiceCaller] = []
|
||||
if use_aten_gemm_kernels():
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout1, aten_mm_plus_mm.uid, name
|
||||
kernel_inputs, layout1, aten_mm_plus_mm, "mm_plus_mm"
|
||||
):
|
||||
aten_mm_plus_mm.maybe_append_choice(
|
||||
choices,
|
||||
@ -168,7 +168,7 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
|
||||
if use_triton_template(layout1):
|
||||
# Get template params using the new unified function
|
||||
for kwargs, extra_kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout1, mm_plus_mm_template.uid, name
|
||||
kernel_inputs, layout1, mm_plus_mm_template, "mm_plus_mm"
|
||||
):
|
||||
# Apply BLOCK_K constraint specific to mm_plus_mm
|
||||
# see https://github.com/triton-lang/triton/issues/1298
|
||||
@ -180,4 +180,6 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
|
||||
**extra_kwargs,
|
||||
)
|
||||
|
||||
return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout1)
|
||||
return autotune_select_algorithm(
|
||||
"mm_plus_mm", choices, kernel_inputs.nodes(), layout1
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user