mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
# why

- unnecessary as we only ever need to know the dtype and maybe the device
- we already take in the kernel inputs which have the device
- enable us to specify the layout after finding all the configs but before generating the ChoiceCallers

# what

- replace all calls in template_heuristics that used to take Layout with now just taking out_dtype

# testing

ci

Differential Revision: [D81820115](https://our.internmc.facebook.com/intern/diff/D81820115)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162238
Approved by: https://github.com/eellison
ghstack dependencies: #161347, #161348, #161349
81 lines
2.7 KiB
Python
81 lines
2.7 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any, TYPE_CHECKING
|
|
|
|
from torch._inductor import config as inductor_config
|
|
|
|
from ..kernel.bmm import aten_baddbmm, aten_bmm, aten_bmm_dtype
|
|
from ..kernel.mm import aten__fp8_mm, aten__int_mm, aten_addmm, aten_bias_addmm, aten_mm
|
|
from ..kernel.mm_plus_mm import aten_mm_plus_mm
|
|
from .base import TemplateConfigHeuristics
|
|
from .gemm import GemmMaxAutotuneTemplateConfigHeuristics
|
|
from .registry import register_template_heuristic
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Generator
|
|
|
|
from ..kernel_inputs import KernelInputs
|
|
|
|
|
|
# These are all labeled as device type None to indicate that they
|
|
# are valid for all device types
|
|
@register_template_heuristic(aten_mm.uid, None)
@register_template_heuristic(aten__fp8_mm.uid, None)
@register_template_heuristic(aten__int_mm.uid, None)
@register_template_heuristic(aten_bmm.uid, None)
@register_template_heuristic(aten_mm_plus_mm.uid, None)
# bmm dtype is only valid on cuda
@register_template_heuristic(aten_bmm_dtype.uid, "cuda")
class ATenConfigHeuristics(TemplateConfigHeuristics):
    """
    Pass-through heuristic that lets ATen choices travel through the same
    config flow as the other templates.

    It always produces exactly one config, and that config carries no kwargs.

    ATen choices that do need kwargs should subclass this class.
    """

    def _get_template_configs_impl(
        self,
        kernel_inputs: KernelInputs,
        op_name: str,
    ) -> Generator[dict[str, Any], None, None]:
        """Yield a single empty config; neither argument is consulted."""
        no_kwargs: dict[str, Any] = {}
        yield no_kwargs
|
|
|
|
|
|
# None here indicates that this is valid for all device types on that op
|
|
# Note (None, op) takes precedence over (device_type, None)
|
|
@register_template_heuristic(aten_addmm.uid, None, op_name="addmm")
@register_template_heuristic(aten_baddbmm.uid, None, op_name="baddbmm")
class ATenAddMMConfigHeuristics(ATenConfigHeuristics):
    """ATen heuristic that also forwards the ``alpha`` and ``beta`` scalars."""

    def get_extra_kwargs(
        self,
        kernel_inputs: KernelInputs,
        op_name: str,
    ) -> dict[str, Any]:
        """
        Return the parent's extra kwargs with ``alpha`` and ``beta`` added.

        Both scalars are read from ``kernel_inputs`` via ``get_scalar`` and
        override any entries of the same name supplied by the parent class.
        """
        # Copy first so the mapping returned by the parent is never mutated.
        extra = dict(super().get_extra_kwargs(kernel_inputs, op_name))
        extra["alpha"] = kernel_inputs.get_scalar("alpha")
        extra["beta"] = kernel_inputs.get_scalar("beta")
        return extra
|
|
|
|
|
|
@register_template_heuristic(aten_bias_addmm.uid, None, op_name="addmm")
class ATenBiasAddMMConfigHeuristics(
    ATenAddMMConfigHeuristics, GemmMaxAutotuneTemplateConfigHeuristics
):
    """
    Heuristic for the bias-addmm ATen choice.

    Yields its single empty config only when the bias has a zero leading
    stride and ``inductor_config.triton.autotune_cublasLt`` is enabled;
    otherwise it yields nothing.
    """

    def _get_template_configs_impl(
        self,
        kernel_inputs: KernelInputs,
        op_name: str,
    ) -> Generator[dict[str, Any], None, None]:
        """Yield one empty config if the bias/cublasLt conditions hold."""
        # for addmm, bias is the first input
        bias = kernel_inputs.nodes()[0]
        # NOTE(review): a zero leading stride presumably means the bias is
        # broadcast along its first dimension - confirm against callers.
        leading_stride_is_zero = bias.get_stride()[0] == 0
        if not (leading_stride_is_zero and inductor_config.triton.autotune_cublasLt):
            return
        yield {}
|