mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[ROCm][inductor] heuristic improvements for pointwise kernels (#163197)
Heuristic improvements for pointwise kernels for MI350. Contributions from several members of the AMD Inductor and Triton teams: @jataylo @AmdSampsa @iupaikov-amd @xiaohuguo2023

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163197
Approved by: https://github.com/PaulZhang12, https://github.com/eellison, https://github.com/jansel

Co-authored-by: AmdSampsa <sampsa.riikonen@amd.com>
Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com>
committed by: PyTorch MergeBot
parent: 24520b8386
commit: 0bbdd6b8db
@@ -6,13 +6,14 @@ import functools
 import typing
 from enum import auto, Enum
 
+import torch
 from torch.utils._triton import has_triton_package
 
 
 # The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values
 # NOTE: if these fail asserts submit a PR to increase them
 TRITON_MAX_BLOCK = {
-    "X": 4096,
+    "X": 8192 if torch.version.hip else 4096,
     "Y": 1024,
     "Z": 1024,
     "R0_": 4096 * 16,  # * 16 is multi-kernel only
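The NOTE above refers to assertions run during autotuning. A minimal sketch of the kind of check involved (simplified; the repo's actual check_max_block helper differs in detail):

# Sketch of a max-block assertion against the caps above; values and the
# helper body are illustrative, not the repo's exact implementation.
TRITON_MAX_BLOCK = {"X": 8192, "Y": 1024, "Z": 1024, "R0_": 4096 * 16}

def check_max_block(cfg: dict[str, int]) -> None:
    # Map entries like "XBLOCK" or "R0_BLOCK" back to their prefix and
    # assert the chosen block size does not exceed the cap.
    for var, val in cfg.items():
        if var.endswith("BLOCK"):
            prefix = var.removesuffix("BLOCK")
            max_block = TRITON_MAX_BLOCK[prefix]
            assert val <= max_block, f"{var}={val} exceeds cap {max_block}"

check_max_block({"XBLOCK": 4096, "YBLOCK": 512})  # passes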
@@ -2244,6 +2244,9 @@ def triton_config(
     num_stages=1,
     num_elements_per_warp=256,
     min_elem_per_thread=0,
+    num_warps=None,
+    matrix_instr=None,
+    waves_per_eu=None,
 ) -> Config:
     """
     Construct a pointwise triton config with some adjustment heuristics
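With the three new keyword arguments a caller can pin the warp count and attach ROCm-specific tuning knobs directly. A hypothetical call, with made-up size hints and block size:

# Hypothetical usage of the new keyword arguments; the numbers are
# invented for illustration, not taken from the PR.
cfg = triton_config(
    {"x": 65536},     # size_hints
    2048,             # XBLOCK
    num_warps=8,      # bypasses the _num_warps heuristic below
    waves_per_eu=1,   # ROCm only: forwarded as a config kwarg
    matrix_instr=16,  # ROCm only: forwarded as matrix_instr_nonkdim
)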
@@ -2300,9 +2303,11 @@ def triton_config(
     ):
         z *= 2
 
-    num_warps = _num_warps(
-        conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
-    )
+    # Calculate num_warps if they are not hard passed to config
+    if num_warps is None:
+        num_warps = _num_warps(
+            conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
+        )
     # we are going to arrive at 2 warps only if bs was too small due to
     # numel being too small. However to workaround some ptx bugs we still
     # want at least 4 warps if there's enough elements per thread
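For a concrete feel of the heuristic path (taken only when num_warps is None), a simplified worked example; the real _num_warps additionally applies an upper clamp:

# Worked example of the default path above, assuming conditional_product
# multiplies its non-None arguments.
x, y, z = 2048, 1, 1
num_elements_per_warp = 256
candidate = (x * y * z) // num_elements_per_warp  # 2048 // 256 = 8
num_warps = max(candidate, 1)  # min_num_warps=1 floors tiny kernels at 1
print(num_warps)  # 8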
@@ -2332,7 +2337,15 @@ def triton_config(
         cfg["ZBLOCK"] = z
     check_max_block(cfg)
     check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+    if torch.version.hip:
+        if matrix_instr is not None:
+            config.kwargs["matrix_instr_nonkdim"] = matrix_instr
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config
 
 
 def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:
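The two ROCm knobs ride along in Config.kwargs, the same dict that carries XBLOCK and friends, so Triton's AMD backend receives them as ordinary config options. A minimal sketch of the pattern, assuming a Triton version whose Config exposes a kwargs dict:

from triton import Config

# Mirror of the pattern above: build the config, then attach ROCm-only
# tuning knobs to its kwargs dict.
config = Config({"XBLOCK": 1024}, num_warps=4, num_stages=1)
config.kwargs["waves_per_eu"] = 2           # occupancy hint per execution unit
config.kwargs["matrix_instr_nonkdim"] = 16  # MFMA instruction shape hint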
@@ -2578,10 +2591,32 @@ def pointwise(
             ),
             *hinted_configs,
         ]
+        # Additional configs appended for ROCm builds
+        if torch.version.hip:
+            configs.extend(
+                [
+                    triton_config_with_settings(
+                        size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
+                    ),
+                    triton_config_with_settings(
+                        size_hints,
+                        4096,  # wrt: better than the max_block for some kernel
+                    ),
+                    triton_config_with_settings(
+                        size_hints,
+                        2048,
+                        num_warps=8,
+                        num_stages=2,
+                        waves_per_eu=1,  # 20% improvement
+                    ),
+                ]
+            )
     if len(size_hints) == 2:
+        # Only avoiding tuning on TileHint.SQUARE if not on ROCm builds
+        # ROCm has observed improvement by diverging here
         if (
             not inductor_meta.get("autotune_pointwise", True)
-            or tile_hint == TileHint.SQUARE
+            or (torch.version.hip is None and tile_hint == TileHint.SQUARE)
         ) and not (
             inductor_meta.get("max_autotune")
             or inductor_meta.get("max_autotune_pointwise")
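The reworked gate keeps autotuning alive for square tiles on ROCm, where tuning them was measured as a win. A self-contained sketch of the condition (stand-in TileHint enum, simplified flags):

from enum import Enum, auto

class TileHint(Enum):  # stand-in for Inductor's TileHint
    SQUARE = auto()
    DEFAULT = auto()

def skips_tuning(is_hip: bool, tile_hint: TileHint,
                 autotune: bool = True, max_autotune: bool = False) -> bool:
    # Mirrors the gate above: a SQUARE tile hint disables pointwise
    # autotuning only on non-ROCm builds.
    return (
        not autotune or (not is_hip and tile_hint == TileHint.SQUARE)
    ) and not max_autotune

assert skips_tuning(is_hip=False, tile_hint=TileHint.SQUARE)     # CUDA: skip
assert not skips_tuning(is_hip=True, tile_hint=TileHint.SQUARE)  # ROCm: tune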
@@ -2597,6 +2632,24 @@ def pointwise(
             triton_config_with_settings(size_hints, 1, bs),
             *hinted_configs,
         ]
+        # Additional configs appended for ROCm builds
+        if torch.version.hip:
+            configs.extend(
+                [
+                    triton_config_with_settings(
+                        size_hints, 64, 32
+                    ),  # better for some kernels
+                    triton_config_with_settings(
+                        size_hints, 128, 16
+                    ),  # +10% for some kernels
+                    triton_config_with_settings(
+                        size_hints, 128, 32
+                    ),  # additional 10% more
+                    triton_config_with_settings(
+                        size_hints, 32, 512
+                    ),  # +30% for some kernels
+                ]
+            )
     if len(size_hints) == 3:
         if not inductor_meta.get("autotune_pointwise", True):
             configs = [triton_config_with_settings(size_hints, 16, 16, 16)]
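Appending candidates this way is cheap because runtime autotuning benchmarks each config and keeps the winner; more candidates cost tuning time, not steady-state performance. A simplified model of that selection (Inductor's real autotuner uses device-side timing and caching rather than wall-clock loops):

import time

def pick_fastest(configs, run, reps=3):
    # Time each candidate config a few times and keep the fastest; a
    # wall-clock stand-in for Inductor's device-timed benchmarking.
    best, best_t = None, float("inf")
    for cfg in configs:
        start = time.perf_counter()
        for _ in range(reps):
            run(cfg)
        elapsed = time.perf_counter() - start
        if elapsed < best_t:
            best, best_t = cfg, elapsed
    return best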