[ROCm][inductor] heuristic improvements for pointwise kernels (#163197)

Heuristic improvements for pointwise kernels on MI350.
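
At a high level, the change raises TRITON_MAX_BLOCK["X"] to 8192 on ROCm, lets `triton_config` accept explicit `num_warps`, `matrix_instr`, and `waves_per_eu` overrides, and appends extra ROCm-only autotune candidates for 1D and 2D pointwise kernels. Below is a minimal sketch of how the ROCm-specific launch hints end up on a Triton config; the helper name is hypothetical, and only `triton.Config`, its `kwargs` dict, and `torch.version.hip` are assumed, mirroring the diff.

```python
import torch
import triton


def rocm_pointwise_config(xblock, num_warps=4, num_stages=1,
                          waves_per_eu=None, matrix_instr=None):
    """Hypothetical helper: attach ROCm launch hints the way the diff does."""
    config = triton.Config({"XBLOCK": xblock}, num_warps=num_warps, num_stages=num_stages)
    if torch.version.hip:  # only set AMD-specific kwargs on ROCm builds
        if waves_per_eu is not None:
            config.kwargs["waves_per_eu"] = waves_per_eu  # occupancy hint
        if matrix_instr is not None:
            config.kwargs["matrix_instr_nonkdim"] = matrix_instr  # MFMA shape hint
    return config


# Example: the new 1D pointwise candidate from the diff
# (XBLOCK=2048, 8 warps, 2 stages, waves_per_eu=1).
candidate = rocm_pointwise_config(2048, num_warps=8, num_stages=2, waves_per_eu=1)
```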

Contributions from several members of the AMD Inductor and Triton teams:
@jataylo @AmdSampsa @iupaikov-amd @xiaohuguo2023

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163197
Approved by: https://github.com/PaulZhang12, https://github.com/eellison, https://github.com/jansel

Co-authored-by: AmdSampsa <sampsa.riikonen@amd.com>
Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com>
Author: Nichols A. Romero
Date: 2025-10-18 07:23:37 +00:00
Committed by: PyTorch MergeBot
Parent: 24520b8386
Commit: 0bbdd6b8db
2 changed files with 60 additions and 6 deletions


@@ -6,13 +6,14 @@ import functools
import typing
from enum import auto, Enum
import torch
from torch.utils._triton import has_triton_package
# The following maximums only apply to runtime autotuning; when using FixedTritonConfig one may see larger values
# NOTE: if these assertions fail, submit a PR to increase them
TRITON_MAX_BLOCK = {
"X": 4096,
"X": 8192 if torch.version.hip else 4096,
"Y": 1024,
"Z": 1024,
"R0_": 4096 * 16, # * 16 is multi-kernel only


@@ -2244,6 +2244,9 @@ def triton_config(
num_stages=1,
num_elements_per_warp=256,
min_elem_per_thread=0,
num_warps=None,
matrix_instr=None,
waves_per_eu=None,
) -> Config:
"""
Construct a pointwise triton config with some adjustment heuristics
@@ -2300,9 +2303,11 @@ def triton_config(
):
z *= 2
num_warps = _num_warps(
conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
)
# Calculate num_warps if it is not explicitly passed to the config
if num_warps is None:
num_warps = _num_warps(
conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
)
# we are going to arrive at 2 warps only if bs was too small due to
# numel being too small. However, to work around some ptx bugs we still
# want at least 4 warps if there are enough elements per thread
@@ -2332,7 +2337,15 @@ def triton_config(
cfg["ZBLOCK"] = z
check_max_block(cfg)
check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
if torch.version.hip:
if matrix_instr is not None:
config.kwargs["matrix_instr_nonkdim"] = matrix_instr
if waves_per_eu is not None:
config.kwargs["waves_per_eu"] = waves_per_eu
return config
def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:
@@ -2578,10 +2591,32 @@ def pointwise(
),
*hinted_configs,
]
# Additional configs appended for ROCm builds
if torch.version.hip:
configs.extend(
[
triton_config_with_settings(
size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
),
triton_config_with_settings(
size_hints,
4096, # better than the max block for some kernels
),
triton_config_with_settings(
size_hints,
2048,
num_warps=8,
num_stages=2,
waves_per_eu=1, # 20% improvement
),
]
)
if len(size_hints) == 2:
# Only avoid tuning on TileHint.SQUARE when not on a ROCm build;
# ROCm has observed improvements from diverging here
if (
not inductor_meta.get("autotune_pointwise", True)
or tile_hint == TileHint.SQUARE
or (torch.version.hip is None and tile_hint == TileHint.SQUARE)
) and not (
inductor_meta.get("max_autotune")
or inductor_meta.get("max_autotune_pointwise")
@@ -2597,6 +2632,24 @@ def pointwise(
triton_config_with_settings(size_hints, 1, bs),
*hinted_configs,
]
# Additional configs appended for ROCm builds
if torch.version.hip:
configs.extend(
[
triton_config_with_settings(
size_hints, 64, 32
), # better for some kernels
triton_config_with_settings(
size_hints, 128, 16
), # +10% for some kernels
triton_config_with_settings(
size_hints, 128, 32
), # a further 10% for some kernels
triton_config_with_settings(
size_hints, 32, 512
), # +30% for some kernels
]
)
if len(size_hints) == 3:
if not inductor_meta.get("autotune_pointwise", True):
configs = [triton_config_with_settings(size_hints, 16, 16, 16)]