Revert "[inductor] consolidate common GEMM triton param retrieval (#159383)"
This reverts commit e7cc42df58a86bee05944f6e80c535aa1d099443. Reverted https://github.com/pytorch/pytorch/pull/159383 on behalf of https://github.com/jataylo due to sorry but rocm CI is broken due to this PR ([comment](https://github.com/pytorch/pytorch/pull/159383#issuecomment-3145604831))
@@ -3,13 +3,17 @@ import logging
from collections.abc import Sequence
from typing import Any

import sympy

import torch
from torch._inductor.select_algorithm import realize_inputs, SymbolicGridFn
from torch._inductor.utils import sympy_product
from torch._inductor.virtualized import V

from .. import config as inductor_config
from ..codegen.wrapper import PythonWrapperCodegen
from ..ir import _IntLike, Layout, TensorBox
from ..utils import get_num_sms, TMA_DESCRIPTOR_SIZE


log = logging.getLogger(__name__)
@@ -45,6 +49,96 @@ def acc_type(dtype):
    return f"tl.{dtype}".replace("torch.", "")


def mm_options(config, sym_m, sym_n, sym_k, layout):
    """
    Common options to matmul triton templates.
    """
    even_k_symbolic = (
        # it isn't worth guarding on this
        sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"]
    )
    allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and (
        not inductor_config.force_same_precision
        or ((sym_m % 16) == 0 and (sym_n % 16) == 0 and (sym_k % 8) == 0)
    )
    options_dict = dict(
        EVEN_K=even_k_symbolic,
        ALLOW_TF32=allow_tf32,
        USE_FAST_ACCUM=False,  # Option for _scaled_mm
        ACC_TYPE=acc_type(layout.dtype),
        num_stages=config.num_stages,
        num_warps=config.num_warps,
        **config.kwargs,
    )

    # If GROUP_M not specified then default to 8
    if "GROUP_M" not in config.kwargs:
        group_m = config.kwargs.get("GROUP_M", 8)
        options_dict["GROUP_M"] = group_m

    return options_dict

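For context (not part of this diff): mm_options flattens a triton.Config into the kwargs the mm Triton templates consume. A minimal sketch of the output shape, assuming triton is installed; the block sizes and values below are illustrative only:

# Hedged sketch, not from this PR; block sizes are made up.
import triton

cfg = triton.Config(
    {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=4, num_warps=8
)
# mm_options(cfg, m, n, k, layout) would return roughly:
# {
#     "EVEN_K": <True iff BLOCK_K divides k>,
#     "ALLOW_TF32": <tf32 enabled, and dims aligned if force_same_precision>,
#     "USE_FAST_ACCUM": False,
#     "ACC_TYPE": "tl.float32",  # for a float32 output layout
#     "num_stages": 4, "num_warps": 8,
#     "BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32,
#     "GROUP_M": 8,  # defaulted when absent from cfg.kwargs
# }
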
def tma_options() -> dict[str, Any]:
    from torch.utils._triton import has_triton_stable_tma_api

    return {"TMA_EXPERIMENTAL_API": not has_triton_stable_tma_api()}


def persistent_mm_options(mat1, mat2):
    res = {
        "A_ROW_MAJOR": not mat1.layout.is_transposed(),
        "B_ROW_MAJOR": not mat2.layout.is_transposed(),
        "NUM_SMS": get_num_sms(),
        "TMA_SIZE": TMA_DESCRIPTOR_SIZE,
    }
    res.update(tma_options())
    return res

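For context (not part of this diff): the persistent template needs each operand's memory order, the SM count for its persistent grid, and which TMA API generation to target. A hedged sketch of the merged result; the concrete values below are illustrative and depend on the GPU and Triton build:

# Hedged sketch, not from this PR; values are made up.
# persistent_mm_options(mat1, mat2) returns roughly:
# {
#     "A_ROW_MAJOR": True,            # mat1 not transposed
#     "B_ROW_MAJOR": False,           # mat2 transposed
#     "NUM_SMS": 132,                 # e.g. an H100
#     "TMA_SIZE": TMA_DESCRIPTOR_SIZE,
#     "TMA_EXPERIMENTAL_API": False,  # stable TMA API available
# }
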
def scaled_mm_options(  # type: ignore[no-untyped-def]
    config,  # triton.Config
    sym_m: sympy.core.numbers.Integer,
    sym_n: sympy.core.numbers.Integer,
    sym_k: sympy.core.numbers.Integer,
    layout: Layout,
    scale_a,
    scale_b,
    use_fast_accum: bool,
    device_tma: bool = False,
) -> dict[str, Any]:
    def are_compatible_scales(size_a, size_b) -> bool:
        # Same sized scales are compatible
        if len(size_a) == len(size_b):
            return True

        # Both need to be scalars or len(1) tensors
        if len(size_a) <= 1 and len(size_b) <= 1:
            return True

        return False

    size_a, size_b = scale_a.get_size(), scale_b.get_size()
    assert are_compatible_scales(size_a, size_b), (
        "Expect scale_a and scale_b to be either both scalars (including single-element tensors) "
        f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}."
    )

    mm_template_options = mm_options(config, sym_m, sym_n, sym_k, layout)

    mm_template_options["ACC_TYPE"] = "tl.float32"
    mm_template_options["USE_FAST_ACCUM"] = use_fast_accum
    mm_template_options["SCALING_ROWWISE"] = len(size_a) == 2

    if device_tma:
        mm_template_options["TMA_SIZE"] = TMA_DESCRIPTOR_SIZE
        mm_template_options["NUM_SMS"] = get_num_sms()

    mm_template_options.update(tma_options())

    return mm_template_options

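For context (not part of this diff): SCALING_ROWWISE keys off the rank of scale_a, and are_compatible_scales accepts only matching ranks or two scalar-like scales. A hedged sketch of the accepted shapes; the concrete shapes below are illustrative:

# Hedged sketch, not from this PR; shapes are made up.
# Tensor-wise scaling: scale_a and scale_b are scalars or 1-element
# tensors, so len(size_a) <= 1  ->  SCALING_ROWWISE = False
# Row-wise scaling: scale_a is (M, 1) and scale_b is (1, N),
# so len(size_a) == 2           ->  SCALING_ROWWISE = True
# Mixed ranks, e.g. a scalar scale_a with an (M, 1) scale_b, fail
# are_compatible_scales and trip the assert above.
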
def mm_args(
    mat1,
    mat2,
@@ -87,6 +181,20 @@ def mm_args(
    return [m, n, k, layout, mat1, mat2, *others]


def mm_config_kwargs(device, exclude_condition, dtype_size=None):
    if device == "cpu":
        return {
            "scale": 0.5,
            "exclude": exclude_condition,
        }

    if dtype_size and inductor_config.max_autotune_gemm_search_space == "EXHAUSTIVE":
        return {
            "dtype_size": dtype_size,
        }
    return {}

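For context (not part of this diff): these kwargs feed the heuristics that generate candidate GEMM configs. A hedged sketch of the three branches:

# Hedged sketch, not from this PR.
# mm_config_kwargs("cpu", exclude)      -> {"scale": 0.5, "exclude": exclude}
# mm_config_kwargs("cuda", exclude, 2)  -> {"dtype_size": 2}, but only when
#     inductor_config.max_autotune_gemm_search_space == "EXHAUSTIVE"
# mm_config_kwargs("cuda", exclude)     -> {}
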
def addmm_epilogue(dtype, alpha, beta):
    def epilogue(acc, bias):
        if alpha != 1: