pytorch/torch/_inductor/template_heuristics/aten.py
Ruben Rodriguez Buchillon d91eecc9a5 [inductor][template heuristics] don't take layout to generate choices (#162238)
# why

- unnecessary, as we only ever need to know the dtype and maybe the
  device
- we already take in the kernel inputs, which carry the device
- enables us to specify the layout after finding all the configs
  but before generating the ChoiceCallers

# what

- replace all calls in template_heuristics that used to take a Layout
  so that they now take just out_dtype (see the sketch below)

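As a rough sketch (hypothetical, simplified signatures, not the exact ones
touched by this PR), the heuristics entry points change along these lines:

```python
# before (hypothetical): the full Layout was threaded through, even though
# only its dtype was ever consulted here
def get_template_configs(kernel_inputs, layout, op_name): ...

# after (hypothetical): pass the output dtype directly; the device is
# already available from kernel_inputs, and the layout can be fixed later,
# once the configs are known but before the ChoiceCallers are generated
def get_template_configs(kernel_inputs, out_dtype, op_name): ...
```
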
# testing

ci

Differential Revision: [D81820115](https://our.internmc.facebook.com/intern/diff/D81820115)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162238
Approved by: https://github.com/eellison
ghstack dependencies: #161347, #161348, #161349
2025-09-09 17:17:04 +00:00


from __future__ import annotations

from typing import Any, TYPE_CHECKING

from torch._inductor import config as inductor_config

from ..kernel.bmm import aten_baddbmm, aten_bmm, aten_bmm_dtype
from ..kernel.mm import aten__fp8_mm, aten__int_mm, aten_addmm, aten_bias_addmm, aten_mm
from ..kernel.mm_plus_mm import aten_mm_plus_mm
from .base import TemplateConfigHeuristics
from .gemm import GemmMaxAutotuneTemplateConfigHeuristics
from .registry import register_template_heuristic

if TYPE_CHECKING:
    from collections.abc import Generator

    from ..kernel_inputs import KernelInputs


# These are all labeled as device type None to indicate that they
# are valid for all device types
@register_template_heuristic(aten_mm.uid, None)
@register_template_heuristic(aten__fp8_mm.uid, None)
@register_template_heuristic(aten__int_mm.uid, None)
@register_template_heuristic(aten_bmm.uid, None)
@register_template_heuristic(aten_mm_plus_mm.uid, None)
# bmm dtype is only valid on cuda
@register_template_heuristic(aten_bmm_dtype.uid, "cuda")
class ATenConfigHeuristics(TemplateConfigHeuristics):
    """
    Pseudo heuristic to make ATen choices go through the same flow as other templates.

    This is a single choice without kwargs. To use this with an ATen choice
    that has kwargs, subclass it (as the addmm heuristics below do).
    """

    def _get_template_configs_impl(
        self,
        kernel_inputs: KernelInputs,
        op_name: str,
    ) -> Generator[dict[str, Any], None, None]:
        yield dict()


# None here indicates that this is valid for all device types on that op
# Note: (None, op) takes precedence over (device_type, None)
@register_template_heuristic(aten_addmm.uid, None, op_name="addmm")
@register_template_heuristic(aten_baddbmm.uid, None, op_name="baddbmm")
class ATenAddMMConfigHeuristics(ATenConfigHeuristics):
    def get_extra_kwargs(
        self,
        kernel_inputs: KernelInputs,
        op_name: str,
    ) -> dict[str, Any]:
        kwargs = super().get_extra_kwargs(kernel_inputs, op_name)
        alpha = kernel_inputs.get_scalar("alpha")
        beta = kernel_inputs.get_scalar("beta")
        return {
            **kwargs,
            "alpha": alpha,
            "beta": beta,
        }


@register_template_heuristic(aten_bias_addmm.uid, None, op_name="addmm")
class ATenBiasAddMMConfigHeuristics(
    ATenAddMMConfigHeuristics, GemmMaxAutotuneTemplateConfigHeuristics
):
    def _get_template_configs_impl(
        self,
        kernel_inputs: KernelInputs,
        op_name: str,
    ) -> Generator[dict[str, Any], None, None]:
        nodes = kernel_inputs.nodes()
        # for addmm, bias is the first input
        bias = nodes[0]
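        # only offer the bias_addmm variant when the bias row is broadcast
        # (stride 0 on the first dim) and cuBLASLt autotuning is enabled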
        if bias.get_stride()[0] == 0 and inductor_config.triton.autotune_cublasLt:
            yield dict()
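
As a usage sketch (hypothetical op, following the registration pattern in this
file), an ATen choice that needs kwargs would subclass ATenConfigHeuristics and
override get_extra_kwargs:

```python
# Hypothetical sketch, not part of this file: a heuristic for an imagined
# ATen choice `aten_scaled_mm` that needs a "scale" kwarg.
@register_template_heuristic(aten_scaled_mm.uid, None, op_name="scaled_mm")
class ATenScaledMMConfigHeuristics(ATenConfigHeuristics):
    def get_extra_kwargs(
        self,
        kernel_inputs: KernelInputs,
        op_name: str,
    ) -> dict[str, Any]:
        kwargs = super().get_extra_kwargs(kernel_inputs, op_name)
        # read the scalar off the kernel inputs, as the addmm heuristics do
        return {**kwargs, "scale": kernel_inputs.get_scalar("scale")}
```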