[ROCm][inductor] heuristic improvements for pointwise kernels (#163197)

Heuristic improvements for pointwise kernels for MI350. Contributions from several members of the AMD Inductor and Triton teams: @jataylo @AmdSampsa @iupaikov-amd @xiaohuguo2023

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163197
Approved by: https://github.com/PaulZhang12, https://github.com/eellison, https://github.com/jansel

Co-authored-by: AmdSampsa <sampsa.riikonen@amd.com>
Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com>
Commit 0bbdd6b8db (parent 24520b8386), committed by PyTorch MergeBot.
@@ -6,13 +6,14 @@ import functools
 import typing
 
 from enum import auto, Enum
 
+import torch
 from torch.utils._triton import has_triton_package
 
 
 # The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values
 # NOTE: if these fail asserts submit a PR to increase them
 TRITON_MAX_BLOCK = {
-    "X": 4096,
+    "X": 8192 if torch.version.hip else 4096,
     "Y": 1024,
     "Z": 1024,
     "R0_": 4096 * 16,  # * 16 is multi-kernel only
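The one-line change above makes the X-block cap platform-dependent: 8192 on ROCm (where torch.version.hip is truthy), 4096 elsewhere. As a minimal sketch of how such a cap is enforced, here is a simplified stand-in for the check_max_block helper that appears later in this diff; the real implementation differs in detail:

```python
# Illustrative sketch only: a simplified stand-in for check_max_block,
# showing how the platform-dependent TRITON_MAX_BLOCK cap is consumed.
import torch

TRITON_MAX_BLOCK = {
    "X": 8192 if torch.version.hip else 4096,  # ROCm allows a larger X block
    "Y": 1024,
    "Z": 1024,
}

def check_max_block(cfg: dict[str, int]) -> None:
    # Assert each {X,Y,Z}BLOCK entry stays within its per-dimension cap.
    for var, val in cfg.items():
        dim = var.removesuffix("BLOCK")
        if dim in TRITON_MAX_BLOCK:
            assert val <= TRITON_MAX_BLOCK[dim], f"{var}={val} > {TRITON_MAX_BLOCK[dim]}"

check_max_block({"XBLOCK": 8192, "YBLOCK": 1})  # passes on ROCm, asserts on CUDA
```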
@@ -2244,6 +2244,9 @@ def triton_config(
     num_stages=1,
     num_elements_per_warp=256,
     min_elem_per_thread=0,
+    num_warps=None,
+    matrix_instr=None,
+    waves_per_eu=None,
 ) -> Config:
     """
     Construct a pointwise triton config with some adjustment heuristics
@@ -2300,6 +2303,8 @@ def triton_config(
     ):
         z *= 2
 
-    num_warps = _num_warps(
-        conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
-    )
+    # Calculate num_warps if they are not hard passed to config
+    if num_warps is None:
+        num_warps = _num_warps(
+            conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
+        )
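The hunk above makes the warp-count heuristic overridable: it only runs when the caller has not passed num_warps explicitly. A hedged sketch of that default derivation, using simplified stand-ins for the conditional_product and _num_warps helpers (the real ones live in torch/_inductor/runtime and handle more cases):

```python
# Simplified stand-ins for Inductor's helpers, for illustration only.
from functools import reduce
from operator import mul

def conditional_product(*args):
    # Product of the truthy factors; None/0 dimensions are skipped.
    return reduce(mul, [a for a in args if a], 1)

def _num_warps(n: int, min_num_warps: int = 1, max_num_warps: int = 8) -> int:
    # Clamp into [min, max], then round down to a power of two.
    n = max(min(n, max_num_warps), min_num_warps)
    return 1 << (n.bit_length() - 1)

# XBLOCK=1024, YBLOCK=ZBLOCK=1 with 256 elements per warp -> 4 warps,
# unless the caller now passes num_warps explicitly (e.g. num_warps=8).
assert _num_warps(conditional_product(1024, 1, 1) // 256) == 4
```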
@@ -2332,7 +2337,15 @@ def triton_config(
     cfg["ZBLOCK"] = z
     check_max_block(cfg)
     check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+    if torch.version.hip:
+        if matrix_instr is not None:
+            config.kwargs["matrix_instr_nonkdim"] = matrix_instr
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config
 
 
 def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:
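With the previous two hunks combined, a caller can now pin the warp count and attach AMD-specific compiler knobs, which land in Config.kwargs only on ROCm and are ignored on CUDA builds. A hypothetical call against a build containing this patch; the dict-shaped size_hints argument is an assumption based on the surrounding pointwise code:

```python
# Hypothetical usage sketch; assumes a PyTorch build containing this patch.
import torch
from torch._inductor.runtime.triton_heuristics import triton_config

cfg = triton_config(
    {"x": 65536},    # size_hints for a 1D pointwise kernel
    2048,            # requested XBLOCK
    num_warps=8,     # bypasses the _num_warps heuristic entirely
    waves_per_eu=1,  # AMD occupancy hint, attached only on ROCm
)
if torch.version.hip:
    # The knob is forwarded to the Triton compiler via Config.kwargs.
    assert cfg.kwargs.get("waves_per_eu") == 1
```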
@@ -2578,10 +2591,32 @@ def pointwise(
         ),
         *hinted_configs,
     ]
+    # Additional configs appended for ROCm builds
+    if torch.version.hip:
+        configs.extend(
+            [
+                triton_config_with_settings(
+                    size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
+                ),
+                triton_config_with_settings(
+                    size_hints,
+                    4096,  # wrt: better than the max_block for some kernel
+                ),
+                triton_config_with_settings(
+                    size_hints,
+                    2048,
+                    num_warps=8,
+                    num_stages=2,
+                    waves_per_eu=1,  # 20% improvement
+                ),
+            ]
+        )
     if len(size_hints) == 2:
+        # Only avoiding tuning on TileHint.SQUARE if not on ROCm builds
+        # ROCm has observed improvement by diverging here
         if (
             not inductor_meta.get("autotune_pointwise", True)
-            or tile_hint == TileHint.SQUARE
+            or (torch.version.hip is None and tile_hint == TileHint.SQUARE)
         ) and not (
             inductor_meta.get("max_autotune")
             or inductor_meta.get("max_autotune_pointwise")
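The reworked 2D gating is easiest to read as a truth table: previously TileHint.SQUARE always suppressed autotuning, whereas now it only does so when torch.version.hip is None. A self-contained sketch, with hip and the tile hint reduced to plain values for brevity:

```python
# Boolean sketch of the gating change above; not the real pointwise() code.
def skip_2d_tuning(autotune_pointwise, tile_hint, hip, max_autotune=False):
    return (
        not autotune_pointwise
        or (hip is None and tile_hint == "SQUARE")  # now a CUDA-only short-circuit
    ) and not max_autotune

assert skip_2d_tuning(True, "SQUARE", hip=None) is True    # CUDA: unchanged
assert skip_2d_tuning(True, "SQUARE", hip="6.2") is False  # ROCm: tunes anyway
```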
@@ -2597,6 +2632,24 @@ def pointwise(
             triton_config_with_settings(size_hints, 1, bs),
             *hinted_configs,
         ]
+        # Additional configs appended for ROCm builds
+        if torch.version.hip:
+            configs.extend(
+                [
+                    triton_config_with_settings(
+                        size_hints, 64, 32
+                    ),  # better for some kernels
+                    triton_config_with_settings(
+                        size_hints, 128, 16
+                    ),  # +10% for some kernels
+                    triton_config_with_settings(
+                        size_hints, 128, 32
+                    ),  # additional 10% more
+                    triton_config_with_settings(
+                        size_hints, 32, 512
+                    ),  # +30% for some kernels
+                ]
+            )
     if len(size_hints) == 3:
         if not inductor_meta.get("autotune_pointwise", True):
             configs = [triton_config_with_settings(size_hints, 16, 16, 16)]
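The four appended 2D shapes trade X extent against Y extent while keeping per-program work in a moderate range. A quick sketch of the element counts they imply; the 256-elements-per-warp figure is the triton_config default from the earlier hunk, not something this hunk sets:

```python
# Elements per program for the appended ROCm 2D configs.
for xblock, yblock in [(64, 32), (128, 16), (128, 32), (32, 512)]:
    elems = xblock * yblock
    print(f"XBLOCK={xblock:>3} YBLOCK={yblock:>3} -> {elems:>5} elems "
          f"(~{elems // 256} warps' worth at 256 elems/warp)")
```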