from __future__ import annotations

import dataclasses
import itertools
import os
from functools import partial
from threading import Lock
from typing import Any, Callable, TYPE_CHECKING

from torch.utils._ordered_set import OrderedSet

from . import config
from .utils import get_backend_num_stages
from .virtualized import V


if TYPE_CHECKING:
    from collections.abc import Generator

    from triton import Config as TritonConfig


@dataclasses.dataclass
class BaseConfig:
    """
    Base Gemm configuration used for most backends (CPU, CUDA)
    """

    block_m: int
    block_n: int
    block_k: int
    num_stages: int
    num_warps: int


@dataclasses.dataclass
class GemmConfig(BaseConfig):
    """
    Gemm configuration used for most backends (CPU, CUDA)
    """

    group_m: int = 8
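

# Illustrative note (derived from the dataclass fields above, not part of the
# original file): the config lists below construct GemmConfig positionally, so
#
#   GemmConfig(64, 64, 32, 3, 4)
#
# means block_m=64, block_n=64, block_k=32, num_stages=3, num_warps=4, with
# group_m left at its default of 8.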


ConvConfig = BaseConfig


@dataclasses.dataclass
class ROCmGemmConfig(GemmConfig):
    """
    ROCm subclass for GEMMs, with AMD backend-specific tunable kernel args
    """

    matrix_instr_nonkdim: int = 16
    waves_per_eu: int = 0
    kpack: int = 2


@dataclasses.dataclass
class ROCmConvConfig(ConvConfig):
    """
    ROCm subclass for Conv, with AMD backend-specific tunable kernel args
    """

    matrix_instr_nonkdim: int = 16
    waves_per_eu: int = 0
    kpack: int = 2


class BaseHeuristicSingleton(type):
    """
    Thread-safe singleton metaclass used by the config heuristic subclasses
    to ensure their heavy __init__ calls are not repeatedly run
    """

    _instances: dict[type[Any], Any] = {}
    _lock: Lock = Lock()

    def __call__(
        cls: BaseHeuristicSingleton, *args: Any, **kwargs: Any
    ) -> BaseConfigHeuristic:
        with cls._lock:
            if cls not in cls._instances:
                # Constructor args are ignored; every class gets exactly one
                # shared, argument-free instance.
                instance = super().__call__()
                cls._instances[cls] = instance
            return cls._instances[cls]
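

# Minimal sketch of the singleton behavior (illustrative, not part of the
# original file): constructing a heuristic class twice yields the same cached
# object, so its heavy __init__ runs only once.
#
#   h1 = CUDAConfigHeuristic()
#   h2 = CUDAConfigHeuristic()
#   assert h1 is h2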


class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
    """
    Base class for mm configs; device-specific Triton kernel config heuristics
    inherit from here
    """

    @property
    def mm_configs(self) -> list[BaseConfig]:
        # Each config is (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps).
        # The experimental config set below is opted into via the
        # TORCHINDUCTOR_NEW_CONFIGS environment variable or the
        # config.new_configs flag.
        if (
            os.environ.get("TORCHINDUCTOR_NEW_CONFIGS", "0") == "1"
            or config.new_configs
        ):
            return [
                # GemmConfig(16, 16, 128, 5, 1),
                # GemmConfig(16, 16, 256, 4, 1),
                GemmConfig(64, 16, 128, 4, 4),
                GemmConfig(64, 16, 256, 4, 4),
                GemmConfig(64, 32, 128, 4, 4),
                GemmConfig(64, 32, 128, 5, 8),
                # GemmConfig(64, 32, 256, 1, 8),
                GemmConfig(64, 64, 128, 4, 4),
                GemmConfig(64, 128, 64, 4, 4),
                GemmConfig(64, 128, 128, 3, 4),
                GemmConfig(128, 16, 128, 5, 8),
                GemmConfig(128, 128, 32, 5, 8),
                GemmConfig(128, 128, 64, 3, 4),
                GemmConfig(128, 128, 64, 3, 8),
                GemmConfig(128, 128, 64, 4, 4),
                GemmConfig(128, 128, 64, 4, 8),
                GemmConfig(128, 256, 32, 5, 8),
                GemmConfig(128, 256, 64, 3, 8),
                GemmConfig(128, 256, 64, 4, 8),
                # GemmConfig(128, 256, 64, 5, 8),
                GemmConfig(256, 128, 32, 5, 8),
            ]
        else:
            return [
                GemmConfig(32, 32, 16, 1, 2),
                GemmConfig(32, 32, 128, 2, 4),
                GemmConfig(32, 64, 32, 5, 8),
                GemmConfig(64, 32, 32, 5, 8),
                GemmConfig(64, 32, 128, 5, 4),
                GemmConfig(64, 64, 16, 2, 4),
                GemmConfig(64, 64, 32, 2, 4),
                GemmConfig(64, 64, 64, 3, 8),
                GemmConfig(64, 64, 128, 5, 4),
                GemmConfig(64, 128, 32, 3, 4),
                GemmConfig(64, 128, 32, 4, 8),
                GemmConfig(64, 128, 64, 3, 4),
                GemmConfig(64, 128, 128, 4, 4),
                GemmConfig(128, 64, 32, 3, 4),
                GemmConfig(128, 64, 32, 4, 8),
                GemmConfig(128, 128, 32, 2, 8),
                GemmConfig(128, 128, 32, 3, 4),
                GemmConfig(128, 128, 64, 3, 4),
                GemmConfig(128, 128, 64, 5, 8),
            ]

    def __init__(self) -> None:
        # The baseline mm kernel configs are defined by the mm_configs
        # property above.

        # Exhaustive search for mm configs
        self.exhaustive_configs: list[BaseConfig] = [
            GemmConfig(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps, group_m)
            for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
                [16, 32, 64, 128, 256], repeat=3
            )
            for num_stages in [1, 2, 3, 4, 5]
            for num_warps in [2, 4, 8]
            for group_m in [8]
        ]
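        # Size check (derived from the comprehension above): 5**3 block-shape
        # combinations x 5 stage counts x 3 warp counts x 1 group size
        # = 1,875 candidate configs.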

        # These are only used in tuned_mm when AutoHeuristic is enabled.
        # The idea is that when AutoHeuristic collects data to learn a
        # heuristic, more configs are autotuned; when the learned heuristic is
        # used, it reduces the number of configs down to 10, which saves
        # compilation time (since fewer configs are autotuned) and potentially
        # increases performance, because the learned heuristic might predict a
        # config that is not part of mm_configs.
        self.extra_mm_configs: list[BaseConfig] = [
            GemmConfig(16, 32, 16, 3, 2),
            GemmConfig(16, 32, 32, 4, 2),
            GemmConfig(16, 32, 32, 5, 2),
            GemmConfig(64, 64, 128, 3, 4),
            GemmConfig(128, 64, 32, 2, 2),
            GemmConfig(128, 64, 64, 3, 8),
            GemmConfig(128, 64, 128, 4, 8),
            GemmConfig(128, 128, 32, 4, 4),
            GemmConfig(128, 128, 64, 3, 8),
            GemmConfig(128, 128, 64, 5, 4),
        ]

        self.int8_mm_configs: list[BaseConfig] = [
            GemmConfig(64, 64, 32, 2, 4),
            GemmConfig(64, 128, 32, 3, 4),
            GemmConfig(128, 64, 32, 3, 4),
            GemmConfig(64, 128, 32, 4, 8),
            GemmConfig(128, 64, 32, 4, 8),
            GemmConfig(64, 32, 32, 5, 8),
            GemmConfig(32, 64, 32, 5, 8),
            GemmConfig(128, 128, 32, 2, 8),
            GemmConfig(64, 64, 64, 3, 8),
            GemmConfig(128, 256, 128, 3, 8),
            GemmConfig(256, 128, 128, 3, 8),
        ]

        self.mixed_mm_configs: list[BaseConfig] = [
            GemmConfig(16, 128, 256, 3, 4),
            GemmConfig(16, 128, 256, 5, 8),
        ]

        self.persistent_mm_configs: list[BaseConfig] = [
            GemmConfig(128, 256, 64, 3, 8),
            GemmConfig(128, 128, 64, 3, 8),
            GemmConfig(128, 128, 128, 3, 8),
            GemmConfig(128, 128, 128, 3, 4),
            GemmConfig(128, 128, 64, 4, 8),
            GemmConfig(128, 128, 64, 5, 8),
            GemmConfig(256, 128, 64, 4, 8),
            GemmConfig(128, 128, 64, 5, 4),
        ]

        self.scaled_mm_configs: list[BaseConfig] = [
            GemmConfig(128, 256, 32, 3, 8),
            GemmConfig(256, 128, 32, 3, 8),
            GemmConfig(256, 64, 32, 4, 4),
            GemmConfig(64, 256, 32, 4, 4),
            GemmConfig(128, 128, 32, 4, 4),
            GemmConfig(128, 64, 32, 4, 4),
            GemmConfig(64, 128, 32, 4, 4),
            GemmConfig(128, 32, 32, 4, 4),
            GemmConfig(64, 32, 32, 5, 2),
            GemmConfig(256, 128, 128, 3, 8),
            GemmConfig(256, 64, 128, 4, 4),
            GemmConfig(64, 256, 128, 4, 4),
            GemmConfig(128, 128, 128, 4, 4),
            GemmConfig(128, 64, 64, 4, 4),
            GemmConfig(64, 128, 64, 4, 4),
            GemmConfig(128, 32, 64, 4, 4),
            GemmConfig(64, 32, 64, 5, 2),
            GemmConfig(16, 32, 32, 2, 2),
            GemmConfig(16, 64, 32, 2, 2),
            GemmConfig(16, 128, 32, 2, 4),
            GemmConfig(16, 256, 32, 2, 4),
            GemmConfig(16, 32, 64, 2, 2),
            GemmConfig(16, 64, 64, 2, 2),
            GemmConfig(16, 128, 64, 2, 4),
            GemmConfig(16, 256, 64, 2, 4),
            GemmConfig(32, 32, 32, 2, 2),
            GemmConfig(32, 64, 32, 2, 2),
            GemmConfig(32, 128, 32, 2, 4),
            GemmConfig(32, 256, 32, 2, 4),
            GemmConfig(32, 32, 64, 2, 2),
            GemmConfig(32, 64, 64, 2, 2),
            GemmConfig(32, 128, 64, 2, 4),
            GemmConfig(32, 256, 64, 2, 4),
            GemmConfig(16, 32, 32, 3, 2),
            GemmConfig(16, 64, 32, 3, 2),
            GemmConfig(16, 128, 32, 3, 4),
            GemmConfig(16, 256, 32, 3, 4),
            GemmConfig(16, 32, 64, 3, 2),
            GemmConfig(16, 64, 64, 3, 2),
            GemmConfig(16, 128, 64, 3, 4),
            GemmConfig(16, 256, 64, 3, 4),
            GemmConfig(32, 32, 32, 3, 2),
            GemmConfig(32, 64, 32, 3, 2),
            GemmConfig(32, 128, 32, 3, 4),
            GemmConfig(32, 256, 32, 3, 4),
            GemmConfig(32, 32, 64, 3, 2),
            GemmConfig(32, 64, 64, 3, 2),
            GemmConfig(32, 128, 64, 3, 4),
            GemmConfig(32, 256, 64, 3, 4),
            GemmConfig(16, 32, 32, 4, 2),
            GemmConfig(16, 64, 32, 4, 2),
            GemmConfig(16, 128, 32, 4, 4),
            GemmConfig(16, 256, 32, 4, 4),
            GemmConfig(16, 32, 64, 4, 2),
            GemmConfig(16, 64, 64, 4, 2),
            GemmConfig(16, 128, 64, 4, 4),
            GemmConfig(16, 256, 64, 4, 4),
            GemmConfig(32, 32, 32, 4, 2),
            GemmConfig(32, 64, 32, 4, 2),
            GemmConfig(32, 128, 32, 4, 4),
            GemmConfig(32, 256, 32, 4, 4),
            GemmConfig(32, 32, 64, 4, 2),
            GemmConfig(32, 64, 64, 4, 2),
            GemmConfig(32, 128, 64, 4, 4),
            GemmConfig(32, 256, 64, 4, 4),
            GemmConfig(16, 32, 32, 5, 2),
            GemmConfig(16, 64, 32, 5, 2),
            GemmConfig(16, 128, 32, 5, 4),
            GemmConfig(16, 256, 32, 5, 4),
            GemmConfig(16, 32, 64, 5, 2),
            GemmConfig(16, 64, 64, 5, 2),
            GemmConfig(16, 128, 64, 5, 4),
            GemmConfig(16, 256, 64, 5, 4),
            GemmConfig(32, 32, 32, 5, 2),
            GemmConfig(32, 64, 32, 5, 2),
            GemmConfig(32, 128, 32, 5, 4),
            GemmConfig(32, 256, 32, 5, 4),
            GemmConfig(32, 32, 64, 5, 2),
            GemmConfig(32, 64, 64, 5, 2),
            GemmConfig(32, 128, 64, 5, 4),
            GemmConfig(32, 256, 64, 5, 4),
            GemmConfig(16, 32, 32, 6, 2),
            GemmConfig(16, 64, 32, 6, 2),
            GemmConfig(16, 128, 32, 6, 4),
            GemmConfig(16, 256, 32, 6, 4),
            GemmConfig(16, 32, 64, 6, 2),
            GemmConfig(16, 64, 64, 6, 2),
            GemmConfig(16, 128, 64, 6, 4),
            GemmConfig(16, 256, 64, 6, 4),
            GemmConfig(32, 32, 32, 6, 2),
            GemmConfig(32, 64, 32, 6, 2),
            GemmConfig(32, 128, 32, 6, 4),
            GemmConfig(32, 256, 32, 6, 4),
            GemmConfig(32, 32, 64, 6, 2),
            GemmConfig(32, 64, 64, 6, 2),
            GemmConfig(32, 128, 64, 6, 4),
            GemmConfig(32, 256, 64, 6, 4),
        ]

        self.scaled_persistent_mm_configs: list[BaseConfig] = [
            GemmConfig(128, 128, 64, 3, 8),
            GemmConfig(128, 128, 128, 3, 8),
            GemmConfig(128, 128, 128, 4, 8),
            GemmConfig(128, 128, 128, 4, 4),
            GemmConfig(128, 128, 128, 3, 4),
            GemmConfig(128, 128, 128, 5, 4),
            GemmConfig(128, 128, 128, 5, 8),
            GemmConfig(128, 128, 128, 6, 8),
            GemmConfig(128, 128, 64, 4, 8),
        ]

        # TODO: Unify with the other gemm patterns; mm_plus_mm currently
        # follows a slightly different pattern than the rest
        self.mm_plus_mm_configs: list[BaseConfig] = [
            GemmConfig(64, 64, 32, 2, 4),
            GemmConfig(64, 64, 32, 3, 8),
            GemmConfig(64, 64, 32, 4, 16),
            GemmConfig(64, 32, 32, 4, 8),
            GemmConfig(32, 64, 32, 4, 8),
            GemmConfig(128, 128, 32, 1, 8),
            GemmConfig(64, 64, 64, 1, 8),
            GemmConfig(32, 32, 128, 1, 8),
            GemmConfig(64, 64, 16, 2, 4),
            GemmConfig(32, 32, 16, 1, 2),
        ]

        self.conv_configs: list[BaseConfig] = [
            ConvConfig(64, 256, 16, 2, 4),
            ConvConfig(256, 64, 16, 2, 4),
            ConvConfig(1024, 16, 16, 1, 8),
            ConvConfig(128, 128, 32, 2, 8),
            ConvConfig(64, 64, 32, 2, 4),
            ConvConfig(64, 256, 32, 2, 8),
            ConvConfig(256, 64, 32, 2, 8),
        ]

    def _finalize_mm_configs(
        self,
        configs: list[BaseConfig],
    ) -> Generator[TritonConfig, None, None]:
        """
        Finalizes configs after scaling, applying additional constraints.
        """
        used: OrderedSet[tuple[int, ...]] = OrderedSet()

        max_mm_configs = config.test_configs.max_mm_configs

        for conf in configs:
            # Each warp computes a 16x16 tile = 256 elements, so cap num_warps
            # at one warp per tile
            num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256)

            # Construct key for finding duplicate configs
            key: tuple[int, ...] = (
                conf.block_m,
                conf.block_n,
                conf.block_k,
                conf.num_stages,
                num_warps,
            )

            # Check if a gemm-specific arg exists; add it to the key if it does
            group_m = getattr(conf, "group_m", None)
            if group_m is not None:
                key += (group_m,)

            if key not in used and (
                max_mm_configs is None or len(used) < max_mm_configs
            ):
                used.add(key)
                kwargs = {
                    "BLOCK_M": conf.block_m,
                    "BLOCK_N": conf.block_n,
                    "BLOCK_K": conf.block_k,
                    "num_stages": conf.num_stages,
                    "num_warps": num_warps,
                }
                if group_m is not None:
                    kwargs["GROUP_M"] = group_m
                yield self.triton_config(**kwargs)
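
    # Worked example (illustrative, not part of the original file): a config
    # with block_m=16 and block_n=16 covers a single 16x16 tile, so
    # block_m * block_n // 256 == 1 and any requested num_warps is clamped
    # down to 1 before the triton.Config is emitted.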

    def _scale_mm_configs(
        self,
        m: int,
        n: int,
        k: int,
        configs: list[BaseConfig],
        scale: float,
        has_int8_tensor: bool,
        exclude: Callable[[int, int, int], bool],
    ) -> list[BaseConfig]:
        """
        Scales and filters matrix multiplication configs based on input size.
        """
        from .runtime.runtime_utils import next_power_of_2

        min_block_size = 16
        min_block_size_k = 32 if has_int8_tensor else 16

        m = max(
            next_power_of_2(
                V.graph.sizevars.size_hint(
                    m,
                    fallback=config.unbacked_symint_fallback,  # type: ignore[arg-type]
                )
            ),
            min_block_size,
        )
        n = max(
            next_power_of_2(
                V.graph.sizevars.size_hint(
                    n,
                    fallback=config.unbacked_symint_fallback,  # type: ignore[arg-type]
                )
            ),
            min_block_size,
        )
        k = max(
            next_power_of_2(
                V.graph.sizevars.size_hint(
                    k,
                    fallback=config.unbacked_symint_fallback,  # type: ignore[arg-type]
                )
            ),
            min_block_size_k,
        )

        scaled_configs = []
        for c in configs:
            scaled_config = dataclasses.replace(
                c,
                block_m=max(min(int(c.block_m * scale), m), min_block_size),
                block_n=max(min(int(c.block_n * scale), n), min_block_size),
                block_k=max(min(int(c.block_k * scale), k), min_block_size_k),
            )

            if not exclude(
                scaled_config.block_m, scaled_config.block_n, scaled_config.block_k
            ):
                scaled_configs.append(scaled_config)

        return scaled_configs
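
    # Worked example (illustrative, not part of the original file): for
    # m = n = k = 48 and scale = 1, the size hints round up to 64 via
    # next_power_of_2, so a GemmConfig(128, 128, 64, ...) is clamped to
    # block_m = block_n = 64 while block_k stays at 64.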

    def preprocess_mm_configs(
        self,
        m: int,
        n: int,
        k: int,
        configs: list[BaseConfig],
        has_int8_tensor: bool = False,
        scale: float = 1,
        exclude: Callable[[int, int, int], bool] = lambda m, n, k: False,
    ) -> Generator[TritonConfig, None, None]:
        scaled_configs = self._scale_mm_configs(
            m, n, k, configs, scale, has_int8_tensor, exclude
        )
        return self._finalize_mm_configs(scaled_configs)

    def triton_config(
        self, num_stages: int, num_warps: int, **kwargs: Any
    ) -> TritonConfig:
        from triton import Config as TritonConfig  # type: ignore[attr-defined]

        return TritonConfig(kwargs, num_stages=num_stages, num_warps=num_warps)

    def get_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        return partial(self.preprocess_mm_configs, configs=self.mm_configs)

    def get_exhaustive_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        return partial(self.preprocess_mm_configs, configs=self.exhaustive_configs)

    def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        return partial(self.preprocess_mm_configs, configs=self.extra_mm_configs)

    def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        return partial(self.preprocess_mm_configs, configs=self.int8_mm_configs)

    def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        mm_configs = (
            self.mm_configs + self.mixed_mm_configs
            if config.max_autotune_gemm_search_space == "EXHAUSTIVE"
            else self.mm_configs
        )
        return partial(self.preprocess_mm_configs, configs=mm_configs)

    def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        return partial(self.preprocess_mm_configs, configs=self.persistent_mm_configs)

    def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        return partial(self.preprocess_mm_configs, configs=self.scaled_mm_configs)

    def get_scaled_persistent_mm_configs(
        self,
    ) -> partial[Generator[TritonConfig, None, None]]:
        return partial(
            self.preprocess_mm_configs, configs=self.scaled_persistent_mm_configs
        )

    def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        return partial(self._finalize_mm_configs, configs=self.mm_plus_mm_configs)

    def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        return partial(self.preprocess_mm_configs, configs=self.conv_configs)
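

# Minimal usage sketch (illustrative; the shapes are assumptions, not from the
# original file): each getter returns a partial that is called with the
# problem shape, inside an active V.graph context, and yields concrete
# triton.Config objects.
#
#   heuristic = BaseConfigHeuristic()
#   for cfg in heuristic.get_mm_configs()(m=1024, n=1024, k=512):
#       ...  # hand each config to the autotuner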


class CPUConfigHeuristic(BaseConfigHeuristic):
    pass


class CUDAConfigHeuristic(BaseConfigHeuristic):
    pass


class ROCmConfigHeuristic(BaseConfigHeuristic):
    def __init__(self) -> None:
        super().__init__()

        self.default_num_stages = get_backend_num_stages()

        # mm_configs is a read-only property on the base class, so the
        # ROCm-tuned defaults are kept on a private attribute and surfaced
        # through the property override below.
        self._mm_configs: list[BaseConfig] = [
            ROCmGemmConfig(
                16, 16, 256, self.default_num_stages, 4, group_m=4, waves_per_eu=2
            ),
            ROCmGemmConfig(32, 16, 256, self.default_num_stages, 4, group_m=4),
            ROCmGemmConfig(
                32, 32, 16, self.default_num_stages, 4, group_m=8, waves_per_eu=2
            ),
            ROCmGemmConfig(32, 32, 128, self.default_num_stages, 4, group_m=8),
            ROCmGemmConfig(32, 64, 64, self.default_num_stages, 4, group_m=8),
            ROCmGemmConfig(
                64, 16, 128, self.default_num_stages, 4, group_m=8, waves_per_eu=2
            ),
            ROCmGemmConfig(64, 32, 32, self.default_num_stages, 4, group_m=8),
            ROCmGemmConfig(64, 32, 64, self.default_num_stages, 4, group_m=8),
            ROCmGemmConfig(64, 32, 64, self.default_num_stages, 8, group_m=8),
            ROCmGemmConfig(64, 32, 128, self.default_num_stages, 4, group_m=8),
            ROCmGemmConfig(64, 64, 16, self.default_num_stages, 4, group_m=8),
            ROCmGemmConfig(64, 64, 64, self.default_num_stages, 4, group_m=4),
            ROCmGemmConfig(64, 64, 128, self.default_num_stages, 8, group_m=16),
            ROCmGemmConfig(64, 64, 256, self.default_num_stages, 8, group_m=4),
            ROCmGemmConfig(
                64, 128, 32, self.default_num_stages, 4, group_m=4, waves_per_eu=2
            ),
            ROCmGemmConfig(64, 128, 32, self.default_num_stages, 8, group_m=8),
            ROCmGemmConfig(64, 128, 64, self.default_num_stages, 8, group_m=4),
            ROCmGemmConfig(64, 128, 128, self.default_num_stages, 8, group_m=4),
            ROCmGemmConfig(128, 32, 32, self.default_num_stages, 4, group_m=8),
            ROCmGemmConfig(128, 32, 64, self.default_num_stages, 4, group_m=8),
            ROCmGemmConfig(
                128, 64, 32, self.default_num_stages, 4, group_m=8, waves_per_eu=2
            ),
            ROCmGemmConfig(128, 64, 64, self.default_num_stages, 4, group_m=16),
            ROCmGemmConfig(128, 64, 128, self.default_num_stages, 8, group_m=4),
            ROCmGemmConfig(
                128, 128, 32, self.default_num_stages, 4, group_m=16, waves_per_eu=2
            ),
            ROCmGemmConfig(128, 128, 32, self.default_num_stages, 8, group_m=16),
            ROCmGemmConfig(
                128, 128, 32, self.default_num_stages, 8, group_m=16, waves_per_eu=2
            ),
            ROCmGemmConfig(128, 128, 64, self.default_num_stages, 4, group_m=16),
            ROCmGemmConfig(128, 128, 64, self.default_num_stages, 8, group_m=8),
            ROCmGemmConfig(128, 128, 128, self.default_num_stages, 8, group_m=16),
            ROCmGemmConfig(
                128, 256, 32, self.default_num_stages, 4, group_m=16, waves_per_eu=2
            ),
            ROCmGemmConfig(128, 256, 64, self.default_num_stages, 8, group_m=4),
            ROCmGemmConfig(256, 64, 64, self.default_num_stages, 8, group_m=4),
            ROCmGemmConfig(
                256, 128, 32, self.default_num_stages, 4, group_m=4, waves_per_eu=2
            ),
            ROCmGemmConfig(256, 128, 32, self.default_num_stages, 8, group_m=16),
            ROCmGemmConfig(256, 128, 64, self.default_num_stages, 8, group_m=4),
            ROCmGemmConfig(256, 256, 64, self.default_num_stages, 8, group_m=4),
        ]

        # Exhaustive search for mm configs
        self.exhaustive_configs: list[BaseConfig] = [
            ROCmGemmConfig(
                BLOCK_M,
                BLOCK_N,
                BLOCK_K,
                num_stages,
                num_warps,
                group_m,
                matrix_instr_nonkdim,
                waves_per_eu,
                kpack,
            )
            for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
                [16, 32, 64, 128, 256], repeat=3
            )
            for num_stages in [1, self.default_num_stages]
            for num_warps in [4, 8]
            for group_m in [4, 8, 16]
            for matrix_instr_nonkdim in [0, 16]
            for waves_per_eu in [0, 2]
            for kpack in [2]
        ]
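        # Size check (derived from the comprehension above): 5**3 block shapes
        # x 2 stage counts x 2 warp counts x 3 group sizes x 2 nonkdim values
        # x 2 waves_per_eu values x 1 kpack value = 6,000 candidate configs.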

    @property
    def mm_configs(self) -> list[BaseConfig]:
        # Override the base property with the ROCm-tuned defaults built in
        # __init__ (a plain self.mm_configs assignment would fail against the
        # setter-less base property).
        return self._mm_configs

    def _filter_configs(
        self, configs: list[BaseConfig], new_num_stages: int
    ) -> list[BaseConfig]:
        # TODO: _filter_configs can be removed once backend-specific configs
        # are added for all methods
        # NOTE: new_num_stages is currently unused; every config is forced to
        # the backend default stage count.
        for c in configs:
            c.num_stages = self.default_num_stages
        return configs

    def _finalize_mm_configs(
        self,
        configs: list[BaseConfig],
    ) -> Generator[TritonConfig, None, None]:
        """
        Finalizes configs after scaling, applying additional constraints.
        """
        used: OrderedSet[tuple[int, ...]] = OrderedSet()

        max_mm_configs = config.test_configs.max_mm_configs

        for conf in configs:
            # Each warp computes a 16x16 tile = 256 elements, so cap num_warps
            # at one warp per tile (kept local to avoid mutating the shared
            # config objects)
            num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256)

            # Defaults for AMD Triton backend kernel args if not set
            matrix_instr_nonkdim = getattr(conf, "matrix_instr_nonkdim", 16)
            waves_per_eu = getattr(conf, "waves_per_eu", 0)
            kpack = getattr(conf, "kpack", 2)

            if matrix_instr_nonkdim != 0 and (
                conf.block_m % matrix_instr_nonkdim != 0
                or conf.block_n % matrix_instr_nonkdim != 0
            ):
                # block_m and block_n must be a multiple of matrix_instr_nonkdim
                continue

            # Construct key for finding duplicate configs
            key: tuple[int, ...] = (
                conf.block_m,
                conf.block_n,
                conf.block_k,
                conf.num_stages,
                num_warps,
                waves_per_eu,
                matrix_instr_nonkdim,
                kpack,
            )

            # Check if a gemm-specific arg exists; add it to the key if it does
            group_m = getattr(conf, "group_m", None)
            if group_m is not None:
                key += (group_m,)

            # When waves_per_eu is requested (non-zero), recompute it from the
            # final warp count; the dedup key above keeps the original hint
            if waves_per_eu != 0:
                waves_per_eu = int(8 // num_warps)

            if key not in used and (
                max_mm_configs is None or len(used) < max_mm_configs
            ):
                used.add(key)
                kwargs = {
                    "BLOCK_M": conf.block_m,
                    "BLOCK_N": conf.block_n,
                    "BLOCK_K": conf.block_k,
                    "num_stages": conf.num_stages,
                    "num_warps": num_warps,
                    "matrix_instr_nonkdim": matrix_instr_nonkdim,
                    "waves_per_eu": waves_per_eu,
                    "kpack": kpack,
                }
                if group_m is not None:
                    kwargs["GROUP_M"] = group_m
                yield self.triton_config(**kwargs)
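
    # Worked example (illustrative, not part of the original file): with
    # matrix_instr_nonkdim=16, every power-of-two block size >= 16 passes the
    # divisibility check; a surviving config with num_warps=4 and a non-zero
    # waves_per_eu hint is emitted with waves_per_eu recomputed to 8 // 4 == 2.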

    def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.extra_mm_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.int8_mm_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        mm_configs = (
            self.mm_configs + self.mixed_mm_configs
            if config.max_autotune_gemm_search_space == "EXHAUSTIVE"
            else self.mm_configs
        )
        filtered_configs = self._filter_configs(mm_configs, self.default_num_stages)
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.persistent_mm_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.scaled_mm_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_scaled_persistent_mm_configs(
        self,
    ) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.scaled_persistent_mm_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)

    def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(self.mm_plus_mm_configs, 1)
        return partial(self._finalize_mm_configs, configs=filtered_configs)

    def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]:
        filtered_configs = self._filter_configs(
            self.conv_configs, self.default_num_stages
        )
        return partial(self.preprocess_mm_configs, configs=filtered_configs)


class XPUConfigHeuristic(BaseConfigHeuristic):
    pass