[ROCm][inductor] autotune support for persistent reduction kernels (#163908)

After the removal of want_no_x_dim for persistent reduction kernels, we can improve their autotuning setup.

Currently, even with tuning enabled, filtering often leaves only a single config to try. This change avoids filtering when autotune mode is on and overrides the MAX_BLOCK limit. It also always includes the tiny configs when autotuning is enabled.
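
For context, a minimal usage sketch (not part of this PR) of how the wider search gets exercised: the new path is gated on the Inductor `max_autotune` / `max_autotune_pointwise` settings, which can be set in code or via the `TORCHINDUCTOR_MAX_AUTOTUNE=1` environment variable. The softmax function below is an arbitrary example of a workload that lowers to a persistent reduction.

```python
# Minimal sketch, assuming a ROCm or CUDA build of PyTorch with Inductor.
# Enabling max-autotune is what routes persistent reductions through the
# expanded config list added in this PR.
import torch

torch._inductor.config.max_autotune = True  # or TORCHINDUCTOR_MAX_AUTOTUNE=1

def softmax_like(x):
    # Softmax reduces over the last dim and typically lowers to a
    # persistent reduction kernel when the reduction fits in one block.
    return torch.softmax(x, dim=-1)

compiled = torch.compile(softmax_like)
out = compiled(torch.randn(4096, 512, device="cuda"))
```

With autotuning on, the benchmarked candidates now include the tiny configs and, on ROCm, the larger XBLOCK values (up to 256) rather than a single heuristic pick.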

Contributions from several members of the AMD Inductor and Triton teams: @jataylo @iupaikov-amd @AmdSampsa @xiaohuguo2023

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163908
Approved by: https://github.com/jansel, https://github.com/PaulZhang12
Author: Nichols A. Romero
Date: 2025-10-18 07:33:21 +00:00
Committed by: PyTorch MergeBot
Parent: 0bbdd6b8db
Commit: a0948d4d23


@@ -3222,6 +3222,15 @@ def _persistent_reduction_configs(
else:
raise NotImplementedError("native matmul only supports mm/bmm pattern")
max_autotune_enabled = inductor_meta.get("max_autotune") or inductor_meta.get(
"max_autotune_pointwise"
)
if torch.version.hip:
xblock_vals = [1, 4, 8, 16, 32, 64, 128, 256]
else:
xblock_vals = [1, 8, 32, 128]
if "y" not in size_hints:
configs = [
triton_config_reduction(
@@ -3231,7 +3240,7 @@ def _persistent_reduction_configs(
register_intensive=True,
reduction_hint=reduction_hint,
)
for xblock in (1, 8, 32, 128)
for xblock in xblock_vals
if xblock == 1
or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel)
]
@@ -3239,7 +3248,7 @@ def _persistent_reduction_configs(
configs = []
assert "tiling_scores" in inductor_meta
x_y_scores = {dim: inductor_meta["tiling_scores"][dim] for dim in ("x", "y")}
for target_block_size in (1, 8, 32, 64, 128):
for target_block_size in xblock_vals:
if target_block_size * rnumel > MAX_PERSISTENT_BLOCK_NUMEL:
continue
@@ -3252,11 +3261,20 @@ def _persistent_reduction_configs(
)
)
tiny_configs = [
triton_config_reduction(
size_hints,
2 * (256 // rnumel) if rnumel <= 256 else 1,
rnumel,
)
]
# defer to more autotuning, initially
if "y" in size_hints:
pass
# TODO(jansel): we should be able to improve these heuristics
elif reduction_hint == ReductionHint.INNER and rnumel >= 256:
elif not max_autotune_enabled: # Do not filter configs when tuning
if reduction_hint == ReductionHint.INNER and rnumel >= 256:
if rnumel > 1024:
configs = configs[:1]
else:
@@ -3270,21 +3288,20 @@ def _persistent_reduction_configs(
x_block,
rnumel,
register_intensive=True,
reduction_hint=reduction_hint,
)
]
elif reduction_hint == ReductionHint.OUTER:
configs = configs[-1:]
elif reduction_hint == ReductionHint.OUTER_TINY:
configs = [
triton_config_reduction(
size_hints,
2 * (256 // rnumel) if rnumel <= 256 else 1,
rnumel,
reduction_hint=reduction_hint,
)
]
configs = tiny_configs
else:
if torch.version.hip:
# If autotune is enabled append tiny configs
for conf in tiny_configs:
if conf not in configs:
configs.append(conf)
for c in configs:
# we don't need Rn_BLOCK for persistent reduction
for prefix in size_hints: