Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Differential Revision: D76843916

Exhaustive autotuning is meant to autotune GEMM configs across the entire search space of possible configs. Some of these configs can cause extremely long compilation times and OOMs, especially configs of the following nature:

- Excessive register spillage
- Using much larger amounts of shared memory than available on the hardware

This diff prunes out those configs to make exhaustive autotuning more viable, along with supporting exhaustive autotuning for the persistent+tma template and decompose_k. Previously, exhaustive autotuning would hang; now we are able to tune shapes in ~5 minutes. Below is a sample log for autotuning with the exhaustive search space:

```
AUTOTUNE mm(1152x21504, 21504x1024)
strides: [21504, 1], [1, 21504]
dtypes: torch.bfloat16, torch.bfloat16
  mm 0.1167 ms 100.0%
  triton_mm_6270 0.1172 ms 99.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=256, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0
  triton_mm_6522 0.1183 ms 98.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0
  triton_mm_persistent_tma_7482 0.1190 ms 98.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, A_ROW_MAJOR=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_ROW_MAJOR=False, EVEN_K=True, GROUP_M=8, NUM_SMS=132, TMA_SIZE=128, USE_FAST_ACCUM=False, num_stages=5, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0
  triton_mm_persistent_tma_7483 0.1195 ms 97.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, A_ROW_MAJOR=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_ROW_MAJOR=False, EVEN_K=True, GROUP_M=8, NUM_SMS=132, TMA_SIZE=128, USE_FAST_ACCUM=False, num_stages=5, num_warps=8, num_consumer_groups=0, num_buffers_warp_spec=0
  triton_mm_6523 0.1274 ms 91.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8, num_consumer_groups=0, num_buffers_warp_spec=0
  triton_mm_6267 0.1285 ms 90.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=256, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0
  triton_mm_6519 0.1287 ms 90.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0
  triton_mm_persistent_tma_7480 0.1298 ms 89.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, A_ROW_MAJOR=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_ROW_MAJOR=False, EVEN_K=True, GROUP_M=8, NUM_SMS=132, TMA_SIZE=128, USE_FAST_ACCUM=False, num_stages=4, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0
  triton_mm_persistent_tma_7312 0.1302 ms 89.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, A_ROW_MAJOR=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=256, B_ROW_MAJOR=False, EVEN_K=True, GROUP_M=8, NUM_SMS=132, TMA_SIZE=128, USE_FAST_ACCUM=False, num_stages=4, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0
SingleProcess AUTOTUNE benchmarking takes 298.7185 seconds and 21.2569 seconds precompiling for 2210 choices
INFO:tritonbench.utils.triton_op:Took 333894.46ms to get benchmark function for pt2_matmul_maxautotune
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156610
Approved by: https://github.com/jansel
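For reference, here is a minimal sketch (not part of the PR) of how the exhaustive search space can be exercised. It assumes a CUDA build of PyTorch and simply reuses the shape from the log above; the tuning knob is `torch._inductor.config.max_autotune_gemm_search_space`, which the heuristics file below keys off.

```python
# Illustrative only: opt into the exhaustive GEMM search space, then compile a
# matmul so Inductor autotunes over the (pruned) exhaustive config set.
import torch
import torch._inductor.config as inductor_config

inductor_config.max_autotune_gemm_search_space = "EXHAUSTIVE"

@torch.compile(mode="max-autotune")
def mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a @ b

a = torch.randn(1152, 21504, device="cuda", dtype=torch.bfloat16)
b = torch.randn(21504, 1024, device="cuda", dtype=torch.bfloat16)
out = mm(a, b)  # the first call triggers compilation and autotuning
```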
1181 lines
44 KiB
Python
from __future__ import annotations
|
|
|
|
import dataclasses
|
|
import itertools
|
|
import math
|
|
from functools import partial
|
|
from threading import Lock
|
|
from typing import Any, Callable, TYPE_CHECKING
|
|
|
|
import torch
|
|
from torch.utils._ordered_set import OrderedSet
|
|
|
|
from . import config
|
|
from .utils import get_backend_num_stages
|
|
from .virtualized import V
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Generator
|
|
|
|
from triton import Config as TritonConfig
|
|
|
|
|
|
# Gemm Configs
|
|
@dataclasses.dataclass
|
|
class BaseConfig:
|
|
"""
|
|
Base Gemm configuration used for most backends (CPU, CUDA)
|
|
"""
|
|
|
|
block_m: int
|
|
block_n: int
|
|
block_k: int
|
|
num_stages: int
|
|
num_warps: int
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class GemmConfig(BaseConfig):
|
|
"""
|
|
Gemm configuration used for most backends (CPU, CUDA)
|
|
"""
|
|
|
|
group_m: int = 8
|
|
|
|
|
|
ConvConfig = BaseConfig
|
|
|
|
|
|
# FlexAttention Configs
|
|
@dataclasses.dataclass
|
|
class FlexConfig:
|
|
"""
|
|
Base Config class for flex attention
|
|
- FlexAttn forward, backward and flex decode will use this
|
|
|
|
NOTE:
|
|
For flex_attn bwd, block_m and block_n are reused for block_m1, block_m2, block_n1, block_n2
|
|
|
|
"""
|
|
|
|
block_m: int
|
|
block_n: int
|
|
num_stages: int
|
|
num_warps: int
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class FlexDecodeConfig:
|
|
"""
|
|
Config class for flex decoding
|
|
"""
|
|
|
|
block_n: int
|
|
num_stages: int
|
|
num_warps: int
|
|
|
|
|
|
# ROCm classes
|
|
@dataclasses.dataclass
|
|
class ROCmGemmConfig(GemmConfig):
|
|
"""
|
|
ROCm subclass for GEMMs, with AMD backend-specific tunable kernel args
|
|
"""
|
|
|
|
matrix_instr_nonkdim: int = 16
|
|
waves_per_eu: int = 0
|
|
kpack: int = 2
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class ROCmConvConfig(ConvConfig):
|
|
"""
|
|
ROCm subclass for Conv, with AMD backend-specific tunable kernel args
|
|
"""
|
|
|
|
matrix_instr_nonkdim: int = 16
|
|
waves_per_eu: int = 0
|
|
kpack: int = 2
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class ROCmFlexConfig(FlexConfig):
|
|
"""
|
|
ROCm subclass for FlexAttn, with AMD backend-specific tunable kernel args
|
|
"""
|
|
|
|
matrix_instr_nonkdim: int = 0
|
|
waves_per_eu: int = 0
|
|
kpack: int = 2
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class ROCmFlexDecodeConfig(FlexDecodeConfig):
|
|
"""
|
|
ROCm subclass for FlexDecode, with AMD backend-specific tunable kernel args
|
|
"""
|
|
|
|
matrix_instr_nonkdim: int = 0
|
|
waves_per_eu: int = 0
|
|
kpack: int = 2
|
|
|
|
|
|
class BaseHeuristicSingleton(type):
|
|
"""
|
|
Thread-safe singleton metaclass to be used in the config heuristic subclasses
|
|
to ensure heavy __init__ calls are not repeatedly run
|
|
"""
|
|
|
|
_instances: dict[type[Any], Any] = {}
|
|
_lock: Lock = Lock()
|
|
|
|
def __call__(
|
|
cls: BaseHeuristicSingleton, *args: Any, **kwargs: Any
|
|
) -> BaseConfigHeuristic:
|
|
with cls._lock:
|
|
if cls not in cls._instances:
|
|
instance = super().__call__()
|
|
cls._instances[cls] = instance
|
|
return cls._instances[cls]
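# Note added for clarity (not in the original source): because this metaclass
# caches one instance per class, constructing a heuristic repeatedly returns the
# same object, e.g. `CUDAConfigHeuristic() is CUDAConfigHeuristic()` is True, so
# the large config lists built in __init__ are only materialised once per process.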
|
|
|
|
|
|
class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
|
|
"""
|
|
Base class for mm_configs; device-specific Triton kernel configs inherit from here
|
|
"""
|
|
|
|
def __init__(self) -> None:
|
|
# List of GemmConfig objects defining the kernel config search space for the
# target platform. The fields are:
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
|
|
self.mm_configs: list[BaseConfig] = [
|
|
GemmConfig(32, 32, 16, 1, 2),
|
|
GemmConfig(32, 32, 128, 2, 4),
|
|
GemmConfig(32, 64, 32, 5, 8),
|
|
GemmConfig(64, 32, 32, 5, 8),
|
|
GemmConfig(64, 32, 128, 5, 4),
|
|
GemmConfig(64, 64, 16, 2, 4),
|
|
GemmConfig(64, 64, 32, 2, 4),
|
|
GemmConfig(64, 64, 64, 3, 8),
|
|
GemmConfig(64, 64, 128, 5, 4),
|
|
GemmConfig(64, 128, 32, 3, 4),
|
|
GemmConfig(64, 128, 32, 4, 8),
|
|
GemmConfig(64, 128, 64, 3, 4),
|
|
GemmConfig(64, 128, 128, 4, 4),
|
|
GemmConfig(128, 64, 32, 3, 4),
|
|
GemmConfig(128, 64, 32, 4, 8),
|
|
GemmConfig(128, 128, 32, 2, 8),
|
|
GemmConfig(128, 128, 32, 3, 4),
|
|
GemmConfig(128, 128, 64, 3, 4),
|
|
GemmConfig(128, 128, 64, 5, 8),
|
|
]
|
|
|
|
# Exhaustive search for mm configs
|
|
self.exhaustive_configs: list[BaseConfig] = [
|
|
GemmConfig(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps, group_m)
|
|
for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
|
|
[16, 32, 64, 128, 256], repeat=3
|
|
)
|
|
for num_stages in [1, 2, 3, 4, 5]
|
|
for num_warps in [2, 4, 8]
|
|
for group_m in [8]
|
|
]
|
|
|
|
# These are only used in tuned_mm when AutoHeuristic is enabled.
# The idea is that when AutoHeuristic collects data to learn a heuristic, more configs are autotuned.
# When the learned heuristic is used, it reduces the number of configs down to 10,
# which saves compilation time (since fewer configs are autotuned) and potentially increases performance,
# because the learned heuristic might predict a config that is not part of mm_configs.
|
|
self.extra_mm_configs: list[BaseConfig] = [
|
|
GemmConfig(16, 32, 16, 3, 2),
|
|
GemmConfig(16, 32, 32, 4, 2),
|
|
GemmConfig(16, 32, 32, 5, 2),
|
|
GemmConfig(64, 64, 128, 3, 4),
|
|
GemmConfig(128, 64, 32, 2, 2),
|
|
GemmConfig(128, 64, 64, 3, 8),
|
|
GemmConfig(128, 64, 128, 4, 8),
|
|
GemmConfig(128, 128, 32, 4, 4),
|
|
GemmConfig(128, 128, 64, 3, 8),
|
|
GemmConfig(128, 128, 64, 5, 4),
|
|
]
|
|
|
|
self.int8_mm_configs: list[BaseConfig] = [
|
|
GemmConfig(64, 64, 32, 2, 4),
|
|
GemmConfig(64, 128, 32, 3, 4),
|
|
GemmConfig(128, 64, 32, 3, 4),
|
|
GemmConfig(64, 128, 32, 4, 8),
|
|
GemmConfig(128, 64, 32, 4, 8),
|
|
GemmConfig(64, 32, 32, 5, 8),
|
|
GemmConfig(32, 64, 32, 5, 8),
|
|
GemmConfig(128, 128, 32, 2, 8),
|
|
GemmConfig(64, 64, 64, 3, 8),
|
|
GemmConfig(128, 256, 128, 3, 8),
|
|
GemmConfig(256, 128, 128, 3, 8),
|
|
]
|
|
|
|
self.mixed_mm_configs: list[BaseConfig] = [
|
|
GemmConfig(16, 128, 256, 3, 4),
|
|
GemmConfig(16, 128, 256, 5, 8),
|
|
]
|
|
|
|
self.persistent_mm_configs: list[BaseConfig] = [
|
|
GemmConfig(128, 256, 64, 3, 8),
|
|
GemmConfig(128, 128, 64, 3, 8),
|
|
GemmConfig(128, 128, 128, 3, 8),
|
|
GemmConfig(128, 128, 128, 3, 4),
|
|
GemmConfig(128, 128, 64, 4, 8),
|
|
GemmConfig(128, 128, 64, 5, 8),
|
|
GemmConfig(256, 128, 64, 4, 8),
|
|
GemmConfig(128, 128, 64, 5, 4),
|
|
]
|
|
|
|
self.scaled_mm_configs: list[BaseConfig] = [
|
|
GemmConfig(128, 256, 32, 3, 8),
|
|
GemmConfig(256, 128, 32, 3, 8),
|
|
GemmConfig(256, 64, 32, 4, 4),
|
|
GemmConfig(64, 256, 32, 4, 4),
|
|
GemmConfig(128, 128, 32, 4, 4),
|
|
GemmConfig(128, 64, 32, 4, 4),
|
|
GemmConfig(64, 128, 32, 4, 4),
|
|
GemmConfig(128, 32, 32, 4, 4),
|
|
GemmConfig(64, 32, 32, 5, 2),
|
|
GemmConfig(256, 128, 128, 3, 8),
|
|
GemmConfig(256, 64, 128, 4, 4),
|
|
GemmConfig(64, 256, 128, 4, 4),
|
|
GemmConfig(128, 128, 128, 4, 4),
|
|
GemmConfig(128, 64, 64, 4, 4),
|
|
GemmConfig(64, 128, 64, 4, 4),
|
|
GemmConfig(128, 32, 64, 4, 4),
|
|
GemmConfig(64, 32, 64, 5, 2),
|
|
GemmConfig(16, 32, 32, 2, 2),
|
|
GemmConfig(16, 64, 32, 2, 2),
|
|
GemmConfig(16, 128, 32, 2, 4),
|
|
GemmConfig(16, 256, 32, 2, 4),
|
|
GemmConfig(16, 32, 64, 2, 2),
|
|
GemmConfig(16, 64, 64, 2, 2),
|
|
GemmConfig(16, 128, 64, 2, 4),
|
|
GemmConfig(16, 256, 64, 2, 4),
|
|
GemmConfig(32, 32, 32, 2, 2),
|
|
GemmConfig(32, 64, 32, 2, 2),
|
|
GemmConfig(32, 128, 32, 2, 4),
|
|
GemmConfig(32, 256, 32, 2, 4),
|
|
GemmConfig(32, 32, 64, 2, 2),
|
|
GemmConfig(32, 64, 64, 2, 2),
|
|
GemmConfig(32, 128, 64, 2, 4),
|
|
GemmConfig(32, 256, 64, 2, 4),
|
|
GemmConfig(16, 32, 32, 3, 2),
|
|
GemmConfig(16, 64, 32, 3, 2),
|
|
GemmConfig(16, 128, 32, 3, 4),
|
|
GemmConfig(16, 256, 32, 3, 4),
|
|
GemmConfig(16, 32, 64, 3, 2),
|
|
GemmConfig(16, 64, 64, 3, 2),
|
|
GemmConfig(16, 128, 64, 3, 4),
|
|
GemmConfig(16, 256, 64, 3, 4),
|
|
GemmConfig(32, 32, 32, 3, 2),
|
|
GemmConfig(32, 64, 32, 3, 2),
|
|
GemmConfig(32, 128, 32, 3, 4),
|
|
GemmConfig(32, 256, 32, 3, 4),
|
|
GemmConfig(32, 32, 64, 3, 2),
|
|
GemmConfig(32, 64, 64, 3, 2),
|
|
GemmConfig(32, 128, 64, 3, 4),
|
|
GemmConfig(32, 256, 64, 3, 4),
|
|
GemmConfig(16, 32, 32, 4, 2),
|
|
GemmConfig(16, 64, 32, 4, 2),
|
|
GemmConfig(16, 128, 32, 4, 4),
|
|
GemmConfig(16, 256, 32, 4, 4),
|
|
GemmConfig(16, 32, 64, 4, 2),
|
|
GemmConfig(16, 64, 64, 4, 2),
|
|
GemmConfig(16, 128, 64, 4, 4),
|
|
GemmConfig(16, 256, 64, 4, 4),
|
|
GemmConfig(32, 32, 32, 4, 2),
|
|
GemmConfig(32, 64, 32, 4, 2),
|
|
GemmConfig(32, 128, 32, 4, 4),
|
|
GemmConfig(32, 256, 32, 4, 4),
|
|
GemmConfig(32, 32, 64, 4, 2),
|
|
GemmConfig(32, 64, 64, 4, 2),
|
|
GemmConfig(32, 128, 64, 4, 4),
|
|
GemmConfig(32, 256, 64, 4, 4),
|
|
GemmConfig(16, 32, 32, 5, 2),
|
|
GemmConfig(16, 64, 32, 5, 2),
|
|
GemmConfig(16, 128, 32, 5, 4),
|
|
GemmConfig(16, 256, 32, 5, 4),
|
|
GemmConfig(16, 32, 64, 5, 2),
|
|
GemmConfig(16, 64, 64, 5, 2),
|
|
GemmConfig(16, 128, 64, 5, 4),
|
|
GemmConfig(16, 256, 64, 5, 4),
|
|
GemmConfig(32, 32, 32, 5, 2),
|
|
GemmConfig(32, 64, 32, 5, 2),
|
|
GemmConfig(32, 128, 32, 5, 4),
|
|
GemmConfig(32, 256, 32, 5, 4),
|
|
GemmConfig(32, 32, 64, 5, 2),
|
|
GemmConfig(32, 64, 64, 5, 2),
|
|
GemmConfig(32, 128, 64, 5, 4),
|
|
GemmConfig(32, 256, 64, 5, 4),
|
|
GemmConfig(16, 32, 32, 6, 2),
|
|
GemmConfig(16, 64, 32, 6, 2),
|
|
GemmConfig(16, 128, 32, 6, 4),
|
|
GemmConfig(16, 256, 32, 6, 4),
|
|
GemmConfig(16, 32, 64, 6, 2),
|
|
GemmConfig(16, 64, 64, 6, 2),
|
|
GemmConfig(16, 128, 64, 6, 4),
|
|
GemmConfig(16, 256, 64, 6, 4),
|
|
GemmConfig(32, 32, 32, 6, 2),
|
|
GemmConfig(32, 64, 32, 6, 2),
|
|
GemmConfig(32, 128, 32, 6, 4),
|
|
GemmConfig(32, 256, 32, 6, 4),
|
|
GemmConfig(32, 32, 64, 6, 2),
|
|
GemmConfig(32, 64, 64, 6, 2),
|
|
GemmConfig(32, 128, 64, 6, 4),
|
|
GemmConfig(32, 256, 64, 6, 4),
|
|
]
|
|
|
|
self.scaled_persistent_mm_configs: list[BaseConfig] = [
|
|
GemmConfig(128, 128, 64, 3, 8),
|
|
GemmConfig(128, 128, 128, 3, 8),
|
|
GemmConfig(128, 128, 128, 4, 8),
|
|
GemmConfig(128, 128, 128, 4, 4),
|
|
GemmConfig(128, 128, 128, 3, 4),
|
|
GemmConfig(128, 128, 128, 5, 4),
|
|
GemmConfig(128, 128, 128, 5, 8),
|
|
GemmConfig(128, 128, 128, 6, 8),
|
|
GemmConfig(128, 128, 64, 4, 8),
|
|
]
|
|
|
|
# TODO: Unify with other gemm patterns; mm_plus_mm currently follows a
# slightly different pattern than the rest
|
|
self.mm_plus_mm_configs: list[BaseConfig] = [
|
|
GemmConfig(64, 64, 32, 2, 4),
|
|
GemmConfig(64, 64, 32, 3, 8),
|
|
GemmConfig(64, 64, 32, 4, 16),
|
|
GemmConfig(64, 32, 32, 4, 8),
|
|
GemmConfig(32, 64, 32, 4, 8),
|
|
GemmConfig(128, 128, 32, 1, 8),
|
|
GemmConfig(64, 64, 64, 1, 8),
|
|
GemmConfig(32, 32, 128, 1, 8),
|
|
GemmConfig(64, 64, 16, 2, 4),
|
|
GemmConfig(32, 32, 16, 1, 2),
|
|
]
|
|
|
|
self.conv_configs: list[BaseConfig] = [
|
|
ConvConfig(64, 256, 16, 2, 4),
|
|
ConvConfig(256, 64, 16, 2, 4),
|
|
ConvConfig(1024, 16, 16, 1, 8),
|
|
ConvConfig(128, 128, 32, 2, 8),
|
|
ConvConfig(64, 64, 32, 2, 4),
|
|
ConvConfig(64, 256, 32, 2, 8),
|
|
ConvConfig(256, 64, 32, 2, 8),
|
|
]
|
|
|
|
self.flex_attn_fwd_autotune_configs: list[FlexConfig] = [
|
|
FlexConfig(128, 64, 3, 4),
|
|
FlexConfig(128, 128, 3, 4),
|
|
FlexConfig(128, 128, 2, 8),
|
|
FlexConfig(64, 128, 3, 4),
|
|
FlexConfig(64, 64, 3, 4),
|
|
]
|
|
|
|
self.flex_attn_bwd_autotune_configs: list[FlexConfig] = [
|
|
FlexConfig(BLOCK1, BLOCK2, s, w)
|
|
for BLOCK1 in [32, 64]
|
|
for BLOCK2 in [32, 64, 128]
|
|
for s in [1, 3, 4, 5] # num_stages
|
|
for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4])
|
|
if BLOCK2 % BLOCK1 == 0
|
|
]
|
|
|
|
self.flex_decode_autotune_configs: list[FlexDecodeConfig] = [
|
|
FlexDecodeConfig(64, 3, 2),
|
|
FlexDecodeConfig(32, 3, 2),
|
|
FlexDecodeConfig(128, 3, 2),
|
|
]
|
|
|
|
self.exhaustive_flex_attn_fwd_configs: list[FlexConfig] = [
|
|
FlexConfig(BLOCK_M, BLOCK_N, num_stages, num_warps)
|
|
for BLOCK_M in [16, 32, 64, 128]
|
|
for BLOCK_N in [32, 64, 128]
|
|
for num_stages in [1, 3, 4, 5]
|
|
for num_warps in [2, 4, 8]
|
|
]
|
|
|
|
self.exhaustive_flex_attn_bwd_configs: list[FlexConfig] = [
|
|
FlexConfig(BLOCK1, BLOCK2, num_stages, num_warps)
|
|
for BLOCK1 in [16, 32, 64, 128]
|
|
for BLOCK2 in [16, 32, 64, 128]
|
|
for num_stages in [1, 3, 4, 5]
|
|
for num_warps in [2, 4, 8]
|
|
if BLOCK2 % BLOCK1 == 0
|
|
]
|
|
|
|
self.exhaustive_flex_decode_configs: list[FlexDecodeConfig] = [
|
|
FlexDecodeConfig(block_n, num_stages, num_warps)
|
|
for block_n in [16, 32, 64, 128]
|
|
for num_stages in [1, 3, 4, 5]
|
|
for num_warps in [2, 4, 8]
|
|
]
|
|
|
|
def _finalize_mm_configs(
|
|
self,
|
|
configs: list[BaseConfig],
|
|
) -> Generator[TritonConfig, None, None]:
|
|
"""
|
|
Finalizes configs after scaling, applying additional constraints.
|
|
"""
|
|
used: OrderedSet[tuple[int, ...]] = OrderedSet()
|
|
|
|
max_mm_configs = config.test_configs.max_mm_configs
|
|
|
|
for conf in configs:
|
|
# Each warp computes a 16x16 tile = 256 elements
|
|
num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256)
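# e.g. a 32x32 output tile has 1024 elements, so at most 1024 // 256 = 4 warps are kept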
|
|
|
|
# Construct key for finding duplicate configs
|
|
key: tuple[int, ...] = (
|
|
conf.block_m,
|
|
conf.block_n,
|
|
conf.block_k,
|
|
conf.num_stages,
|
|
num_warps,
|
|
)
|
|
|
|
# Check if the gemm-specific group_m arg exists; add it to the key if it does
|
|
group_m = getattr(conf, "group_m", None)
|
|
if group_m is not None:
|
|
key += (group_m,)
|
|
|
|
if key not in used and (
|
|
max_mm_configs is None or len(used) < max_mm_configs
|
|
):
|
|
used.add(key)
|
|
kwargs = {
|
|
"BLOCK_M": conf.block_m,
|
|
"BLOCK_N": conf.block_n,
|
|
"BLOCK_K": conf.block_k,
|
|
"num_stages": conf.num_stages,
|
|
"num_warps": num_warps,
|
|
}
|
|
if group_m is not None:
|
|
kwargs["GROUP_M"] = group_m
|
|
yield self.triton_config(**kwargs)
|
|
|
|
def _scale_mm_configs(
|
|
self,
|
|
m: int,
|
|
n: int,
|
|
k: int,
|
|
configs: list[BaseConfig],
|
|
scale: float,
|
|
has_int8_tensor: bool,
|
|
exclude: Callable[[int, int, int], bool],
|
|
) -> list[BaseConfig]:
|
|
"""
|
|
Scales and filters matrix multiplication configs based on input size.
|
|
"""
|
|
from .runtime.runtime_utils import next_power_of_2
|
|
|
|
min_block_size = 16
|
|
min_block_size_k = 32 if has_int8_tensor else 16
|
|
|
|
m = max(
|
|
next_power_of_2(
|
|
V.graph.sizevars.size_hint(
|
|
m,
|
|
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
|
|
)
|
|
),
|
|
min_block_size,
|
|
)
|
|
n = max(
|
|
next_power_of_2(
|
|
V.graph.sizevars.size_hint(
|
|
n,
|
|
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
|
|
)
|
|
),
|
|
min_block_size,
|
|
)
|
|
k = max(
|
|
next_power_of_2(
|
|
V.graph.sizevars.size_hint(
|
|
k,
|
|
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
|
|
)
|
|
),
|
|
min_block_size_k,
|
|
)
|
|
|
|
scaled_configs = []
|
|
for c in configs:
|
|
scaled_config = dataclasses.replace(
|
|
c,
|
|
block_m=max(min(int(c.block_m * scale), m), min_block_size),
|
|
block_n=max(min(int(c.block_n * scale), n), min_block_size),
|
|
block_k=max(min(int(c.block_k * scale), k), min_block_size_k),
|
|
)
|
|
|
|
if not exclude(
|
|
scaled_config.block_m, scaled_config.block_n, scaled_config.block_k
|
|
):
|
|
scaled_configs.append(scaled_config)
|
|
|
|
return scaled_configs
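# Worked example (added note): with size hints m = n = k = 32 and scale = 1, a
# GemmConfig(128, 128, 64, ...) is clamped to block_m = block_n = 32 and
# block_k = 32, since each block dimension is capped at the rounded-up problem
# size and floored at the minimum block size.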
|
|
|
|
def _prune_exhaustive_configs(
|
|
self,
|
|
configs: list[BaseConfig],
|
|
dtype_size: int,
|
|
) -> list[BaseConfig]:
|
|
import torch
|
|
|
|
pruned_configs = []
|
|
for gemm_config in configs:
|
|
device = torch.cuda.current_device()
|
|
props = torch.cuda.get_device_properties(device)
|
|
sm_available = props.shared_memory_per_block_optin # type: ignore[attr-defined]
|
|
NUM_REG = 255
|
|
|
|
acc_regs = math.ceil(
|
|
gemm_config.block_m * gemm_config.block_n / (gemm_config.num_warps * 32)
|
|
)
|
|
|
|
shared_mem_accum = dtype_size * (
|
|
gemm_config.block_m * gemm_config.block_k
|
|
+ gemm_config.block_n * gemm_config.block_k
|
|
)
|
|
|
|
# Skip configs that would use more shared memory than is available
|
|
if shared_mem_accum * gemm_config.num_stages > sm_available:
|
|
continue
|
|
# Lower bound on accumulator register usage; if it exceeds NUM_REG the kernel will certainly spill
|
|
elif acc_regs > NUM_REG:
|
|
continue
|
|
|
|
pruned_configs.append(gemm_config)
|
|
|
|
return pruned_configs
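# Worked example of the two bounds above (added note, assuming bf16 inputs so
# dtype_size = 2 and an H100-class GPU with ~227 KB of opt-in shared memory per
# block): a (block_m, block_n, block_k) = (256, 256, 128) config needs
# 2 * (256*128 + 256*128) = 131072 bytes of shared memory per stage, so at
# num_stages = 5 it wants ~640 KB and is pruned; with num_warps = 4 its
# accumulator alone needs 256*256 / (4*32) = 512 registers per thread > 255,
# so it would certainly spill and is likewise pruned.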
|
|
|
|
def preprocess_mm_configs(
|
|
self,
|
|
m: int,
|
|
n: int,
|
|
k: int,
|
|
configs: list[BaseConfig],
|
|
has_int8_tensor: bool = False,
|
|
scale: int = 1,
|
|
exclude: Callable[[int, int, int], bool] = lambda m, n, k: False,
|
|
dtype_size: int = 0,
|
|
) -> Generator[TritonConfig, None, None]:
|
|
scaled_configs = self._scale_mm_configs(
|
|
m, n, k, configs, scale, has_int8_tensor, exclude
|
|
)
|
|
|
|
if config.max_autotune_gemm_search_space == "EXHAUSTIVE":
|
|
assert dtype_size > 0, "dtype_size must be provided for exhaustive search"
|
|
scaled_configs = self._prune_exhaustive_configs(scaled_configs, dtype_size)
|
|
return self._finalize_mm_configs(scaled_configs)
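# Summary of the preprocessing pipeline (added note): configs are first scaled /
# clamped to the problem size (_scale_mm_configs); then, when the exhaustive
# search space is enabled, hardware-infeasible configs are dropped
# (_prune_exhaustive_configs); finally duplicates are removed and
# triton.Config objects are emitted (_finalize_mm_configs).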
|
|
|
|
def triton_config(
|
|
self, num_stages: int, num_warps: int, **kwargs: Any
|
|
) -> TritonConfig:
|
|
from triton import Config as TritonConfig # type: ignore[attr-defined]
|
|
|
|
return TritonConfig(kwargs, num_stages=num_stages, num_warps=num_warps)
|
|
|
|
def get_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
|
return partial(self.preprocess_mm_configs, configs=self.mm_configs)
|
|
|
|
def get_exhaustive_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
|
return partial(self.preprocess_mm_configs, configs=self.exhaustive_configs)
|
|
|
|
def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
|
return partial(self.preprocess_mm_configs, configs=self.extra_mm_configs)
|
|
|
|
def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
|
return partial(self.preprocess_mm_configs, configs=self.int8_mm_configs)
|
|
|
|
def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
|
mm_configs = (
|
|
self.mm_configs + self.mixed_mm_configs
|
|
if config.max_autotune_gemm_search_space == "EXHAUSTIVE"
|
|
else self.mm_configs
|
|
)
|
|
return partial(self.preprocess_mm_configs, configs=mm_configs)
|
|
|
|
def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
|
persistent_mm_configs = (
|
|
self.exhaustive_configs
|
|
if config.max_autotune_gemm_search_space == "EXHAUSTIVE"
|
|
else self.persistent_mm_configs
|
|
)
|
|
|
|
# num_warps=2 not safe for TMA
|
|
persistent_mm_configs = [
|
|
config for config in persistent_mm_configs if config.num_warps != 2
|
|
]
|
|
return partial(self.preprocess_mm_configs, configs=persistent_mm_configs)
|
|
|
|
def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
|
return partial(self.preprocess_mm_configs, configs=self.scaled_mm_configs)
|
|
|
|
def get_scaled_persistent_mm_configs(
|
|
self,
|
|
) -> partial[Generator[TritonConfig, None, None]]:
|
|
return partial(
|
|
self.preprocess_mm_configs, configs=self.scaled_persistent_mm_configs
|
|
)
|
|
|
|
def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
|
return partial(self._finalize_mm_configs, configs=self.mm_plus_mm_configs)
|
|
|
|
def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
|
return partial(self.preprocess_mm_configs, configs=self.conv_configs)
|
|
|
|
# Flex attn helpers
|
|
def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]:
|
|
flex_attn_fwd_configs: list[FlexConfig] = []
|
|
|
|
if config.max_autotune:
|
|
if config.max_autotune_flex_search_space == "EXHAUSTIVE":
|
|
return self.exhaustive_flex_attn_fwd_configs
|
|
flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs
|
|
|
|
if head_dim <= 256:
|
|
if dtype == torch.float32:
|
|
default_config = FlexConfig(64, 64, 3, 4)
|
|
else:
|
|
default_config = FlexConfig(128, 64, 3, 4)
|
|
else:
|
|
if dtype == torch.float32:
|
|
default_config = FlexConfig(32, 16, 3, 4)
|
|
else:
|
|
default_config = FlexConfig(64, 32, 3, 4)
|
|
|
|
if default_config not in flex_attn_fwd_configs:
|
|
flex_attn_fwd_configs.append(default_config)
|
|
|
|
return flex_attn_fwd_configs
|
|
|
|
def get_flex_attn_bwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]:
|
|
flex_attn_bwd_configs: list[FlexConfig] = []
|
|
|
|
if config.max_autotune:
|
|
if config.max_autotune_flex_search_space == "EXHAUSTIVE":
|
|
return self.exhaustive_flex_attn_bwd_configs
|
|
flex_attn_bwd_configs += self.flex_attn_bwd_autotune_configs
|
|
|
|
default_config = FlexConfig(16, 16, 1, 4)
|
|
|
|
if default_config not in flex_attn_bwd_configs:
|
|
flex_attn_bwd_configs.append(default_config)
|
|
|
|
return flex_attn_bwd_configs
|
|
|
|
def get_flex_decode_configs(
|
|
self, head_dim: int, dtype: Any
|
|
) -> list[FlexDecodeConfig]:
|
|
flex_decode_configs: list[FlexDecodeConfig] = []
|
|
|
|
if config.max_autotune:
|
|
if config.max_autotune_flex_search_space == "EXHAUSTIVE":
|
|
return self.exhaustive_flex_decode_configs
|
|
flex_decode_configs += self.flex_decode_autotune_configs
|
|
|
|
default_config = FlexDecodeConfig(block_n=64, num_stages=1, num_warps=2)
|
|
|
|
if default_config not in flex_decode_configs:
|
|
flex_decode_configs.append(default_config)
|
|
|
|
return flex_decode_configs
|
|
|
|
|
|
class CPUConfigHeuristic(BaseConfigHeuristic):
|
|
pass
|
|
|
|
|
|
class CUDAConfigHeuristic(BaseConfigHeuristic):
|
|
"""
|
|
Child class for CUDA device-specific gemm/flex attention/conv configs.
|
|
"""
|
|
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
self.h100_default_flex_config = {
|
|
(torch.float32, 64): FlexConfig(128, 32, 3, 4),
|
|
(torch.float32, 128): FlexConfig(32, 64, 3, 4),
|
|
(torch.float32, 256): FlexConfig(32, 32, 3, 4),
|
|
(torch.bfloat16, 64): FlexConfig(128, 128, 3, 4),
|
|
(torch.bfloat16, 128): FlexConfig(128, 64, 3, 8),
|
|
(torch.bfloat16, 256): FlexConfig(64, 32, 3, 4),
|
|
(torch.float16, 64): FlexConfig(128, 128, 3, 4),
|
|
(torch.float16, 128): FlexConfig(128, 128, 3, 8),
|
|
(torch.float16, 256): FlexConfig(64, 32, 3, 4),
|
|
}
|
|
|
|
self.a100_default_flex_config = {
|
|
(torch.float32, 64): FlexConfig(128, 32, 3, 4),
|
|
(torch.float32, 128): FlexConfig(128, 32, 3, 4),
|
|
(torch.float32, 256): FlexConfig(64, 16, 3, 4),
|
|
(torch.bfloat16, 64): FlexConfig(128, 64, 3, 4),
|
|
(torch.bfloat16, 128): FlexConfig(128, 64, 3, 8),
|
|
(torch.bfloat16, 256): FlexConfig(32, 64, 3, 4),
|
|
(torch.float16, 64): FlexConfig(128, 64, 3, 4),
|
|
(torch.float16, 128): FlexConfig(128, 64, 3, 8),
|
|
(torch.float16, 256): FlexConfig(32, 64, 3, 4),
|
|
}
|
|
|
|
def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]:
|
|
capability = torch.cuda.get_device_capability()
|
|
flex_attn_fwd_configs: list[FlexConfig] = []
|
|
|
|
if config.max_autotune:
|
|
if config.max_autotune_flex_search_space == "EXHAUSTIVE":
|
|
return self.exhaustive_flex_attn_fwd_configs
|
|
flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs
|
|
|
|
if head_dim <= 256:
|
|
if dtype == torch.float32:
|
|
default_config = FlexConfig(64, 64, 3, 4)
|
|
else:
|
|
default_config = FlexConfig(128, 64, 3, 4)
|
|
if capability >= (9, 0):
|
|
default_config = self.h100_default_flex_config.get(
|
|
(dtype, head_dim), default_config
|
|
)
|
|
elif capability >= (8, 0):
|
|
default_config = self.a100_default_flex_config.get(
|
|
(dtype, head_dim), default_config
|
|
)
|
|
else:
|
|
if dtype == torch.float32:
|
|
default_config = FlexConfig(32, 16, 3, 4)
|
|
else:
|
|
default_config = FlexConfig(64, 32, 3, 4)
|
|
|
|
if default_config not in flex_attn_fwd_configs:
|
|
flex_attn_fwd_configs.append(default_config)
|
|
|
|
return flex_attn_fwd_configs
|
|
|
|
def get_flex_attn_bwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]:
|
|
capability = torch.cuda.get_device_capability()
|
|
|
|
flex_attn_bwd_configs: list[FlexConfig] = []
|
|
|
|
if config.max_autotune:
|
|
if config.max_autotune_flex_search_space == "EXHAUSTIVE":
|
|
return self.exhaustive_flex_attn_bwd_configs
|
|
flex_attn_bwd_configs += self.flex_attn_bwd_autotune_configs
|
|
|
|
if dtype == torch.float32:
|
|
default_config = FlexConfig(16, 16, 1, 4)
|
|
elif head_dim <= 256 and capability >= (9, 0): # H100
|
|
if head_dim == 64:
|
|
default_config = FlexConfig(64, 64, 3, 4)
|
|
elif head_dim == 128:
|
|
default_config = FlexConfig(64, 128, 3, 8)
|
|
else:
|
|
default_config = FlexConfig(64, 64, 2, 4)
|
|
elif capability >= (8, 0): # A100
|
|
if head_dim == 64:
|
|
default_config = FlexConfig(32, 128, 3, 4)
|
|
elif head_dim == 128:
|
|
# SM86/89 have smaller shared memory sizes
|
|
num_stages = 3 if capability[1] == 0 else 2
|
|
default_config = FlexConfig(64, 64, num_stages, 4)
|
|
else:
|
|
default_config = FlexConfig(64, 64, 2, 4)
|
|
else: # modest hardware or extremely large head_dim
|
|
default_config = FlexConfig(16, 16, 1, 4)
|
|
|
|
if default_config not in flex_attn_bwd_configs:
|
|
flex_attn_bwd_configs.append(default_config)
|
|
|
|
return flex_attn_bwd_configs
|
|
|
|
def get_flex_decode_configs(
|
|
self, head_dim: int, dtype: Any
|
|
) -> list[FlexDecodeConfig]:
|
|
capability = torch.cuda.get_device_capability()
|
|
|
|
default_config = FlexDecodeConfig(64, 1, 2)
|
|
|
|
flex_decode_configs: list[FlexDecodeConfig] = []
|
|
|
|
if config.max_autotune:
|
|
if config.max_autotune_flex_search_space == "EXHAUSTIVE":
|
|
return self.exhaustive_flex_decode_configs
|
|
flex_decode_configs += self.flex_decode_autotune_configs
|
|
|
|
if capability >= (9, 0): # sm_90+
|
|
if head_dim > 128 and dtype == torch.float32:
|
|
default_config = FlexDecodeConfig(64, 1, 2)
|
|
else:
|
|
default_config = FlexDecodeConfig(64, 3, 2)
|
|
else:
|
|
default_config = FlexDecodeConfig(64, 1, 2)
|
|
|
|
if default_config not in flex_decode_configs:
|
|
flex_decode_configs.append(default_config)
|
|
|
|
return flex_decode_configs
|
|
|
|
|
|
class ROCmConfigHeuristic(BaseConfigHeuristic):
|
|
"""
|
|
Child class for ROCm-specific gemm/flex attention/conv configs.
|
|
"""
|
|
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
self.default_num_stages = get_backend_num_stages()
|
|
|
|
self.mm_configs: list[BaseConfig] = [
|
|
ROCmGemmConfig(
|
|
16, 16, 256, self.default_num_stages, 4, group_m=4, waves_per_eu=2
|
|
),
|
|
ROCmGemmConfig(32, 16, 256, self.default_num_stages, 4, group_m=4),
|
|
ROCmGemmConfig(
|
|
32, 32, 16, self.default_num_stages, 4, group_m=8, waves_per_eu=2
|
|
),
|
|
ROCmGemmConfig(32, 32, 128, self.default_num_stages, 4, group_m=8),
|
|
ROCmGemmConfig(32, 64, 64, self.default_num_stages, 4, group_m=8),
|
|
ROCmGemmConfig(
|
|
64, 16, 128, self.default_num_stages, 4, group_m=8, waves_per_eu=2
|
|
),
|
|
ROCmGemmConfig(64, 32, 32, self.default_num_stages, 4, group_m=8),
|
|
ROCmGemmConfig(64, 32, 64, self.default_num_stages, 4, group_m=8),
|
|
ROCmGemmConfig(64, 32, 64, self.default_num_stages, 8, group_m=8),
|
|
ROCmGemmConfig(64, 32, 128, self.default_num_stages, 4, group_m=8),
|
|
ROCmGemmConfig(64, 64, 16, self.default_num_stages, 4, group_m=8),
|
|
ROCmGemmConfig(64, 64, 64, self.default_num_stages, 4, group_m=4),
|
|
ROCmGemmConfig(64, 64, 128, self.default_num_stages, 8, group_m=16),
|
|
ROCmGemmConfig(64, 64, 256, self.default_num_stages, 8, group_m=4),
|
|
ROCmGemmConfig(
|
|
64, 128, 32, self.default_num_stages, 4, group_m=4, waves_per_eu=2
|
|
),
|
|
ROCmGemmConfig(64, 128, 32, self.default_num_stages, 8, group_m=8),
|
|
ROCmGemmConfig(64, 128, 64, self.default_num_stages, 8, group_m=4),
|
|
ROCmGemmConfig(64, 128, 128, self.default_num_stages, 8, group_m=4),
|
|
ROCmGemmConfig(128, 32, 32, self.default_num_stages, 4, group_m=8),
|
|
ROCmGemmConfig(128, 32, 64, self.default_num_stages, 4, group_m=8),
|
|
ROCmGemmConfig(
|
|
128, 64, 32, self.default_num_stages, 4, group_m=8, waves_per_eu=2
|
|
),
|
|
ROCmGemmConfig(128, 64, 64, self.default_num_stages, 4, group_m=16),
|
|
ROCmGemmConfig(128, 64, 128, self.default_num_stages, 8, group_m=4),
|
|
ROCmGemmConfig(
|
|
128, 128, 32, self.default_num_stages, 4, group_m=16, waves_per_eu=2
|
|
),
|
|
ROCmGemmConfig(128, 128, 32, self.default_num_stages, 8, group_m=16),
|
|
ROCmGemmConfig(
|
|
128, 128, 32, self.default_num_stages, 8, group_m=16, waves_per_eu=2
|
|
),
|
|
ROCmGemmConfig(128, 128, 64, self.default_num_stages, 4, group_m=16),
|
|
ROCmGemmConfig(128, 128, 64, self.default_num_stages, 8, group_m=8),
|
|
ROCmGemmConfig(128, 128, 128, self.default_num_stages, 8, group_m=16),
|
|
ROCmGemmConfig(
|
|
128, 256, 32, self.default_num_stages, 4, group_m=16, waves_per_eu=2
|
|
),
|
|
ROCmGemmConfig(128, 256, 64, self.default_num_stages, 8, group_m=4),
|
|
ROCmGemmConfig(256, 64, 64, self.default_num_stages, 8, group_m=4),
|
|
ROCmGemmConfig(
|
|
256, 128, 32, self.default_num_stages, 4, group_m=4, waves_per_eu=2
|
|
),
|
|
ROCmGemmConfig(256, 128, 32, self.default_num_stages, 8, group_m=16),
|
|
ROCmGemmConfig(256, 128, 64, self.default_num_stages, 8, group_m=4),
|
|
ROCmGemmConfig(256, 256, 64, self.default_num_stages, 8, group_m=4),
|
|
]
|
|
|
|
# Exhaustive search for mm configs
|
|
self.exhaustive_configs: list[BaseConfig] = [
|
|
ROCmGemmConfig(
|
|
BLOCK_M,
|
|
BLOCK_N,
|
|
BLOCK_K,
|
|
num_stages,
|
|
num_warps,
|
|
group_m,
|
|
matrix_instr_nonkdim,
|
|
waves_per_eu,
|
|
kpack,
|
|
)
|
|
for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
|
|
[16, 32, 64, 128, 256], repeat=3
|
|
)
|
|
for num_stages in [1, self.default_num_stages]
|
|
for num_warps in [4, 8]
|
|
for group_m in [4, 8, 16]
|
|
for matrix_instr_nonkdim in [0, 16]
|
|
for waves_per_eu in [0, 2]
|
|
for kpack in [2]
|
|
]
|
|
|
|
self.default_flex_config = {
|
|
(torch.float32, 64): ROCmFlexConfig(128, 32, 1, 4),
|
|
(torch.float32, 128): ROCmFlexConfig(128, 32, 1, 4),
|
|
(torch.float32, 256): ROCmFlexConfig(64, 16, 1, 4),
|
|
(torch.bfloat16, 64): ROCmFlexConfig(128, 64, 1, 8),
|
|
(torch.bfloat16, 128): ROCmFlexConfig(128, 64, 1, 8),
|
|
(torch.bfloat16, 256): ROCmFlexConfig(32, 64, 1, 8),
|
|
(torch.float16, 64): ROCmFlexConfig(128, 64, 1, 8),
|
|
(torch.float16, 128): ROCmFlexConfig(128, 64, 1, 8),
|
|
(torch.float16, 256): ROCmFlexConfig(32, 64, 1, 4),
|
|
}
|
|
|
|
self.flex_attn_fwd_autotune_configs: list[FlexConfig] = [
|
|
ROCmFlexConfig(BLOCK1, BLOCK2, 1, w)
|
|
for BLOCK1 in [16, 64, 128]
|
|
for BLOCK2 in [16, 32, 64, 128]
|
|
for w in [4, 8]
|
|
]
|
|
|
|
self.flex_attn_bwd_autotune_configs: list[FlexConfig] = [
|
|
ROCmFlexConfig(BLOCK1, BLOCK2, 1, w, mfma)
|
|
for BLOCK1 in [16, 32, 64]
|
|
for BLOCK2 in [32, 64, 128]
|
|
for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4])
|
|
for mfma in [0, 16]
|
|
if BLOCK2 % BLOCK1 == 0
|
|
]
|
|
|
|
self.flex_decode_autotune_configs: list[FlexDecodeConfig] = [
|
|
ROCmFlexDecodeConfig(32, 1, 4),
|
|
ROCmFlexDecodeConfig(64, 1, 4),
|
|
ROCmFlexDecodeConfig(128, 1, 4),
|
|
ROCmFlexDecodeConfig(32, 1, 8),
|
|
ROCmFlexDecodeConfig(64, 1, 8),
|
|
ROCmFlexDecodeConfig(128, 1, 8),
|
|
]
|
|
|
|
self.exhaustive_flex_attn_fwd_configs: list[FlexConfig] = [
|
|
ROCmFlexConfig(BLOCK_M, BLOCK_N, num_stages, num_warps, mfma, wpeu)
|
|
for BLOCK_M in [16, 32, 64, 128]
|
|
for BLOCK_N in [32, 64, 128]
|
|
for num_stages in [1, 2]
|
|
for num_warps in [2, 4, 8]
|
|
for mfma in [0, 16]
|
|
for wpeu in [0, int(8 // num_warps)]
|
|
]
|
|
|
|
self.exhaustive_flex_attn_bwd_configs: list[FlexConfig] = [
|
|
ROCmFlexConfig(BLOCK1, BLOCK2, num_stages, num_warps, mfma, wpeu)
|
|
for BLOCK1 in [16, 32, 64, 128]
|
|
for BLOCK2 in [16, 32, 64, 128]
|
|
for num_stages in [1, 2]
|
|
for num_warps in [2, 4, 8]
|
|
for mfma in [0, 16]
|
|
for wpeu in [0, int(8 // num_warps)]
|
|
if BLOCK2 % BLOCK1 == 0
|
|
]
|
|
|
|
self.exhaustive_flex_decode_configs: list[FlexDecodeConfig] = [
|
|
ROCmFlexDecodeConfig(block_n, num_stages, num_warps, mfma, wpeu, kpack=2)
|
|
for block_n in [16, 32, 64, 128]
|
|
for num_stages in [1, 2]
|
|
for num_warps in [2, 4, 8]
|
|
for mfma in [0, 16]
|
|
for wpeu in [0, int(8 // num_warps)]
|
|
]
|
|
|
|
def _filter_configs(
|
|
self, configs: list[BaseConfig], new_num_stages: int
|
|
) -> list[BaseConfig]:
|
|
# TODO: _filter_configs can be removed once backend specific configs are added
|
|
# for all methods
|
|
for c in configs:
|
|
c.num_stages = self.default_num_stages
|
|
return configs
|
|
|
|
def _finalize_mm_configs(
|
|
self,
|
|
configs: list[BaseConfig],
|
|
) -> Generator[TritonConfig, None, None]:
|
|
"""
|
|
Finalizes configs after scaling, applying additional constraints.
|
|
"""
|
|
used: OrderedSet[tuple[int, ...]] = OrderedSet()
|
|
|
|
max_mm_configs = config.test_configs.max_mm_configs
|
|
|
|
for conf in configs:
|
|
# Each warp computes a 16x16 tile = 256 elements
|
|
conf.num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256)
|
|
|
|
# Defaults for AMD Triton backend kernel args if not set
|
|
matrix_instr_nonkdim = getattr(conf, "matrix_instr_nonkdim", 16)
|
|
waves_per_eu = getattr(conf, "waves_per_eu", 0)
|
|
kpack = getattr(conf, "kpack", 2)
|
|
|
|
if matrix_instr_nonkdim != 0 and (
|
|
conf.block_m % matrix_instr_nonkdim != 0
|
|
or conf.block_n % matrix_instr_nonkdim != 0
|
|
):
|
|
# block_m and block_n must be multiples of matrix_instr_nonkdim
|
|
continue
|
|
|
|
# Construct key for finding duplicate configs
|
|
key: tuple[int, ...] = (
|
|
conf.block_m,
|
|
conf.block_n,
|
|
conf.block_k,
|
|
conf.num_stages,
|
|
conf.num_warps,
|
|
waves_per_eu,
|
|
matrix_instr_nonkdim,
|
|
kpack,
|
|
)
|
|
|
|
# Check if the gemm-specific group_m arg exists; add it to the key if it does
|
|
group_m = getattr(conf, "group_m", None)
|
|
if group_m is not None:
|
|
key += (group_m,)
|
|
|
|
if waves_per_eu != 0:
|
|
waves_per_eu = int(8 // conf.num_warps)
|
|
|
|
if key not in used and (
|
|
max_mm_configs is None or len(used) < max_mm_configs
|
|
):
|
|
used.add(key)
|
|
kwargs = {
|
|
"BLOCK_M": conf.block_m,
|
|
"BLOCK_N": conf.block_n,
|
|
"BLOCK_K": conf.block_k,
|
|
"num_stages": conf.num_stages,
|
|
"num_warps": conf.num_warps,
|
|
"matrix_instr_nonkdim": matrix_instr_nonkdim,
|
|
"waves_per_eu": waves_per_eu,
|
|
"kpack": kpack,
|
|
}
|
|
if group_m is not None:
|
|
kwargs["GROUP_M"] = group_m
|
|
yield self.triton_config(**kwargs)
|
|
|
|
def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
|
filtered_configs = self._filter_configs(
|
|
self.extra_mm_configs, self.default_num_stages
|
|
)
|
|
return partial(self.preprocess_mm_configs, configs=filtered_configs)
|
|
|
|
def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
|
filtered_configs = self._filter_configs(
|
|
self.int8_mm_configs, self.default_num_stages
|
|
)
|
|
return partial(self.preprocess_mm_configs, configs=filtered_configs)
|
|
|
|
def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
|
mm_configs = (
|
|
self.mm_configs + self.mixed_mm_configs
|
|
if config.max_autotune_gemm_search_space == "EXHAUSTIVE"
|
|
else self.mm_configs
|
|
)
|
|
filtered_configs = self._filter_configs(mm_configs, self.default_num_stages)
|
|
return partial(self.preprocess_mm_configs, configs=filtered_configs)
|
|
|
|
def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
|
filtered_configs = self._filter_configs(
|
|
self.persistent_mm_configs, self.default_num_stages
|
|
)
|
|
return partial(self.preprocess_mm_configs, configs=filtered_configs)
|
|
|
|
def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
|
filtered_configs = self._filter_configs(
|
|
self.scaled_mm_configs, self.default_num_stages
|
|
)
|
|
return partial(self.preprocess_mm_configs, configs=filtered_configs)
|
|
|
|
def get_scaled_persistent_mm_configs(
|
|
self,
|
|
) -> partial[Generator[TritonConfig, None, None]]:
|
|
filtered_configs = self._filter_configs(
|
|
self.scaled_persistent_mm_configs, self.default_num_stages
|
|
)
|
|
return partial(self.preprocess_mm_configs, configs=filtered_configs)
|
|
|
|
def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
|
filtered_configs = self._filter_configs(self.mm_plus_mm_configs, 1)
|
|
return partial(self._finalize_mm_configs, configs=filtered_configs)
|
|
|
|
def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]:
|
|
filtered_configs = self._filter_configs(
|
|
self.conv_configs, self.default_num_stages
|
|
)
|
|
return partial(self.preprocess_mm_configs, configs=filtered_configs)
|
|
|
|
def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]:
|
|
flex_attn_fwd_configs: list[FlexConfig] = []
|
|
|
|
if config.max_autotune:
|
|
if config.max_autotune_flex_search_space == "EXHAUSTIVE":
|
|
return self.exhaustive_flex_attn_fwd_configs
|
|
flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs
|
|
|
|
if head_dim <= 256:
|
|
if dtype == torch.float32:
|
|
default_config = ROCmFlexConfig(64, 64, 1, 4)
|
|
else:
|
|
default_config = ROCmFlexConfig(128, 64, 1, 8)
|
|
default_config = self.default_flex_config.get(
|
|
(dtype, head_dim), default_config
|
|
)
|
|
else:
|
|
if dtype == torch.float32:
|
|
default_config = ROCmFlexConfig(32, 16, 1, 4)
|
|
else:
|
|
default_config = ROCmFlexConfig(64, 32, 1, 4)
|
|
|
|
if default_config not in flex_attn_fwd_configs:
|
|
flex_attn_fwd_configs.append(default_config)
|
|
|
|
return flex_attn_fwd_configs
|
|
|
|
def get_flex_attn_bwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]:
|
|
flex_attn_bwd_configs: list[FlexConfig] = []
|
|
|
|
if config.max_autotune:
|
|
if config.max_autotune_flex_search_space == "EXHAUSTIVE":
|
|
return self.exhaustive_flex_attn_bwd_configs
|
|
flex_attn_bwd_configs += self.flex_attn_bwd_autotune_configs
|
|
|
|
if dtype == torch.float32:
|
|
default_config = ROCmFlexConfig(16, 16, 1, 4)
|
|
elif head_dim <= 256:
|
|
if head_dim == 64:
|
|
default_config = ROCmFlexConfig(64, 64, 1, 4)
|
|
elif head_dim == 128:
|
|
default_config = ROCmFlexConfig(64, 128, 1, 8)
|
|
else:
|
|
default_config = ROCmFlexConfig(64, 64, 1, 4)
|
|
else:
|
|
default_config = ROCmFlexConfig(16, 16, 1, 4)
|
|
|
|
if default_config not in flex_attn_bwd_configs:
|
|
flex_attn_bwd_configs.append(default_config)
|
|
|
|
return flex_attn_bwd_configs
|
|
|
|
def get_flex_decode_configs(
|
|
self, head_dim: int, dtype: Any
|
|
) -> list[FlexDecodeConfig]:
|
|
flex_decode_configs: list[FlexDecodeConfig] = []
|
|
|
|
if config.max_autotune:
|
|
if config.max_autotune_flex_search_space == "EXHAUSTIVE":
|
|
return self.exhaustive_flex_decode_configs
|
|
flex_decode_configs += self.flex_decode_autotune_configs
|
|
|
|
default_config = ROCmFlexDecodeConfig(64, 1, 4)
|
|
|
|
if default_config not in flex_decode_configs:
|
|
flex_decode_configs.append(default_config)
|
|
|
|
return flex_decode_configs
|
|
|
|
|
|
class XPUConfigHeuristic(BaseConfigHeuristic):
|
|
"""
|
|
Placeholder child class for XPU-specific overrides.
|
|
"""
|