From 0bbdd6b8dbda2d63820ae46d05536bd1e9a111b9 Mon Sep 17 00:00:00 2001
From: "Nichols A. Romero"
Date: Sat, 18 Oct 2025 07:23:37 +0000
Subject: [PATCH] [ROCm][inductor] heuristic improvements for pointwise kernels (#163197)

Heuristic improvements for pointwise kernels for MI350.

Contributions from several members of the AMD Inductor and Triton teams:
@jataylo @AmdSampsa @iupaikov-amd @xiaohuguo2023

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163197
Approved by: https://github.com/PaulZhang12, https://github.com/eellison, https://github.com/jansel

Co-authored-by: AmdSampsa
Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com>
---
 torch/_inductor/runtime/hints.py             |  3 +-
 torch/_inductor/runtime/triton_heuristics.py | 63 ++++++++++++++++++--
 2 files changed, 60 insertions(+), 6 deletions(-)

diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py
index 1cff04d04079..10a5a9749a51 100644
--- a/torch/_inductor/runtime/hints.py
+++ b/torch/_inductor/runtime/hints.py
@@ -6,13 +6,14 @@ import functools
 import typing
 from enum import auto, Enum
 
+import torch
 from torch.utils._triton import has_triton_package
 
 
 # The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values
 # NOTE: if these fail asserts submit a PR to increase them
 TRITON_MAX_BLOCK = {
-    "X": 4096,
+    "X": 8192 if torch.version.hip else 4096,
     "Y": 1024,
     "Z": 1024,
     "R0_": 4096 * 16,  # * 16 is multi-kernel only
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index 2ae2880fb018..12dc07fe3b1f 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -2244,6 +2244,9 @@ def triton_config(
     num_stages=1,
     num_elements_per_warp=256,
     min_elem_per_thread=0,
+    num_warps=None,
+    matrix_instr=None,
+    waves_per_eu=None,
 ) -> Config:
     """
     Construct a pointwise triton config with some adjustment heuristics
@@ -2300,9 +2303,11 @@ def triton_config(
     ):
         z *= 2
 
-    num_warps = _num_warps(
-        conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
-    )
+    # Calculate num_warps only if it was not passed explicitly to the config
+    if num_warps is None:
+        num_warps = _num_warps(
+            conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
+        )
     # we are going to arrive at 2 warps only if bs was too small due to
     # numel being too small. However to workaround some ptx bugs we still
     # want at least 4 warps if there's enough elements per thread
@@ -2332,7 +2337,15 @@ def triton_config(
     cfg["ZBLOCK"] = z
     check_max_block(cfg)
     check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+    if torch.version.hip:
+        if matrix_instr is not None:
+            config.kwargs["matrix_instr_nonkdim"] = matrix_instr
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config
 
 
 def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:
@@ -2578,10 +2591,32 @@ def pointwise(
             ),
             *hinted_configs,
         ]
+        # Additional configs appended for ROCm builds
+        if torch.version.hip:
+            configs.extend(
+                [
+                    triton_config_with_settings(
+                        size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
+                    ),
+                    triton_config_with_settings(
+                        size_hints,
+                        4096,  # wrt: better than the max_block for some kernels
+                    ),
+                    triton_config_with_settings(
+                        size_hints,
+                        2048,
+                        num_warps=8,
+                        num_stages=2,
+                        waves_per_eu=1,  # 20% improvement
+                    ),
+                ]
+            )
     if len(size_hints) == 2:
+        # Only skip tuning on TileHint.SQUARE when not on a ROCm build;
+        # improvements have been observed on ROCm from diverging here
         if (
             not inductor_meta.get("autotune_pointwise", True)
-            or tile_hint == TileHint.SQUARE
+            or (torch.version.hip is None and tile_hint == TileHint.SQUARE)
         ) and not (
             inductor_meta.get("max_autotune")
             or inductor_meta.get("max_autotune_pointwise")
@@ -2597,6 +2632,24 @@ def pointwise(
             triton_config_with_settings(size_hints, 1, bs),
             *hinted_configs,
         ]
+        # Additional configs appended for ROCm builds
+        if torch.version.hip:
+            configs.extend(
+                [
+                    triton_config_with_settings(
+                        size_hints, 64, 32
+                    ),  # better for some kernels
+                    triton_config_with_settings(
+                        size_hints, 128, 16
+                    ),  # +10% for some kernels
+                    triton_config_with_settings(
+                        size_hints, 128, 32
+                    ),  # additional 10% more
+                    triton_config_with_settings(
+                        size_hints, 32, 512
+                    ),  # +30% for some kernels
+                ]
+            )
     if len(size_hints) == 3:
         if not inductor_meta.get("autotune_pointwise", True):
             configs = [triton_config_with_settings(size_hints, 16, 16, 16)]
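
Note for readers of this patch: the sketch below is illustrative only and is not part of the diff. It shows, under stated assumptions, how the new ROCm-only knobs introduced above (an explicit num_warps, matrix_instr mapped to matrix_instr_nonkdim, and waves_per_eu) are expected to surface on a config object when torch.version.hip is truthy. FakeConfig and make_pointwise_config are hypothetical stand-ins, not code from this PR; in the patch itself the same fields are attached to triton.Config.kwargs inside triton_config().

# Illustrative sketch only: FakeConfig and make_pointwise_config are hypothetical
# stand-ins for triton.Config and the patched triton_config() in this PR.
from dataclasses import dataclass
from typing import Optional

import torch  # torch.version.hip is None on CUDA/CPU builds, a version string on ROCm


@dataclass
class FakeConfig:
    kwargs: dict  # block sizes plus backend-specific knobs such as waves_per_eu
    num_warps: int
    num_stages: int


def make_pointwise_config(
    xblock: int,
    num_warps: int = 4,
    num_stages: int = 1,
    matrix_instr: Optional[int] = None,
    waves_per_eu: Optional[int] = None,
) -> FakeConfig:
    cfg = {"XBLOCK": xblock}
    config = FakeConfig(cfg, num_warps=num_warps, num_stages=num_stages)
    # Mirror the patch: the extra knobs are only attached on ROCm (HIP) builds.
    if torch.version.hip:
        if matrix_instr is not None:
            config.kwargs["matrix_instr_nonkdim"] = matrix_instr
        if waves_per_eu is not None:
            config.kwargs["waves_per_eu"] = waves_per_eu
    return config


if __name__ == "__main__":
    # One of the ROCm-only 1D candidates added by this patch:
    # XBLOCK=2048, num_warps=8, num_stages=2, waves_per_eu=1.
    print(make_pointwise_config(2048, num_warps=8, num_stages=2, waves_per_eu=1))

On a non-ROCm build the kwargs dict is left untouched, matching the torch.version.hip guard in the diff, so CUDA autotuning behavior is unchanged.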