[ROCm][inductor] heuristic improvements for pointwise kernels (#163197)

Heuristic improvements for pointwise kernels on MI350.

Contributions from several members of the AMD Inductor and Triton teams:
@jataylo @AmdSampsa @iupaikov-amd @xiaohuguo2023

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163197
Approved by: https://github.com/PaulZhang12, https://github.com/eellison, https://github.com/jansel

Co-authored-by: AmdSampsa <sampsa.riikonen@amd.com>
Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com>
Nichols A. Romero
2025-10-18 07:23:37 +00:00
committed by PyTorch MergeBot
parent 24520b8386
commit 0bbdd6b8db
2 changed files with 60 additions and 6 deletions

View File

@@ -6,13 +6,14 @@ import functools
 import typing
 from enum import auto, Enum
 
+import torch
 from torch.utils._triton import has_triton_package
 
 
 # The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values
 # NOTE: if these fail asserts submit a PR to increase them
 TRITON_MAX_BLOCK = {
-    "X": 4096,
+    "X": 8192 if torch.version.hip else 4096,
     "Y": 1024,
     "Z": 1024,
     "R0_": 4096 * 16,  # * 16 is multi-kernel only

View File

@@ -2244,6 +2244,9 @@ def triton_config(
     num_stages=1,
     num_elements_per_warp=256,
     min_elem_per_thread=0,
+    num_warps=None,
+    matrix_instr=None,
+    waves_per_eu=None,
 ) -> Config:
     """
     Construct a pointwise triton config with some adjustment heuristics
@@ -2300,9 +2303,11 @@ def triton_config(
     ):
         z *= 2
 
-    num_warps = _num_warps(
-        conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
-    )
+    # Calculate num_warps if they are not hard passed to config
+    if num_warps is None:
+        num_warps = _num_warps(
+            conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
+        )
     # we are going to arrive at 2 warps only if bs was too small due to
     # numel being too small. However to workaround some ptx bugs we still
     # want at least 4 warps if there's enough elements per thread
@@ -2332,7 +2337,15 @@ def triton_config(
         cfg["ZBLOCK"] = z
     check_max_block(cfg)
     check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    if torch.version.hip:
+        if matrix_instr is not None:
+            config.kwargs["matrix_instr_nonkdim"] = matrix_instr
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+    return config
 
 
 def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:
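Taken together, these hunks let a caller pin num_warps and thread two ROCm-only compiler knobs into the generated config. A minimal sketch of what the new branch produces, built on triton.Config directly rather than the private helper; the block size and knob values are illustrative, not taken from the PR:

import torch
from triton import Config

# Mirror of the hunk above: build the config, then attach AMD-specific compiler
# options to Config.kwargs on a HIP build. waves_per_eu is an occupancy hint for
# the AMD backend; matrix_instr_nonkdim selects the MFMA instruction shape.
config = Config({"XBLOCK": 2048}, num_warps=8, num_stages=2)
if torch.version.hip:
    config.kwargs["waves_per_eu"] = 1
    config.kwargs["matrix_instr_nonkdim"] = 16  # illustrative value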
@@ -2578,10 +2591,32 @@ def pointwise(
                 ),
                 *hinted_configs,
             ]
+        # Additional configs appended for ROCm builds
+        if torch.version.hip:
+            configs.extend(
+                [
+                    triton_config_with_settings(
+                        size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
+                    ),
+                    triton_config_with_settings(
+                        size_hints,
+                        4096,  # wrt: better than the max_block for some kernel
+                    ),
+                    triton_config_with_settings(
+                        size_hints,
+                        2048,
+                        num_warps=8,
+                        num_stages=2,
+                        waves_per_eu=1,  # 20% improvement
+                    ),
+                ]
+            )
     if len(size_hints) == 2:
+        # Only avoiding tuning on TileHint.SQUARE if not on ROCm builds
+        # ROCm has observed improvement by diverging here
         if (
             not inductor_meta.get("autotune_pointwise", True)
-            or tile_hint == TileHint.SQUARE
+            or (torch.version.hip is None and tile_hint == TileHint.SQUARE)
         ) and not (
             inductor_meta.get("max_autotune")
             or inductor_meta.get("max_autotune_pointwise")
@@ -2597,6 +2632,24 @@ def pointwise(
             triton_config_with_settings(size_hints, 1, bs),
             *hinted_configs,
         ]
+        # Additional configs appended for ROCm builds
+        if torch.version.hip:
+            configs.extend(
+                [
+                    triton_config_with_settings(
+                        size_hints, 64, 32
+                    ),  # better for some kernels
+                    triton_config_with_settings(
+                        size_hints, 128, 16
+                    ),  # +10% for some kernels
+                    triton_config_with_settings(
+                        size_hints, 128, 32
+                    ),  # additional 10% more
+                    triton_config_with_settings(
+                        size_hints, 32, 512
+                    ),  # +30% for some kernels
+                ]
+            )
     if len(size_hints) == 3:
         if not inductor_meta.get("autotune_pointwise", True):
             configs = [triton_config_with_settings(size_hints, 16, 16, 16)]
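One way to see the new 1D and 2D candidates in action on a ROCm machine is to compile a simple pointwise kernel; enabling max_autotune_pointwise widens the benchmarked search over candidate configs. A small sketch; the fused function, shapes, and flag usage are arbitrary illustration, not from the PR:

import torch
import torch._inductor.config as inductor_config

# Ask Inductor to benchmark pointwise candidate configs (including the
# ROCm-only ones appended above) rather than settling on one heuristically.
inductor_config.max_autotune_pointwise = True

@torch.compile
def fused(a, b):
    return torch.nn.functional.gelu(a) * b + 1.0

# On ROCm builds of PyTorch, HIP GPUs are addressed with the "cuda" device string.
a = torch.randn(4096, 4096, device="cuda")
b = torch.randn(4096, 4096, device="cuda")
fused(a, b)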