Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[ROCm][inductor] autotune support for persistent reduction kernels (#163908)
After the removal of want_no_x_dim for persistent reduction kernels, we can improve their autotuning setup. Currently, even with tuning enabled, config filtering leaves only a single config to try in many cases. Avoid that filtering when autotune mode is on, override the MAX_BLOCK limit, and always include tiny_config when autotuning is enabled. Contributions from several members of the AMD Inductor and Triton teams: @jataylo @iupaikov-amd @AmdSampsa @xiaohuguo2023
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163908
Approved by: https://github.com/jansel, https://github.com/PaulZhang12
committed by PyTorch MergeBot
parent 0bbdd6b8db
commit a0948d4d23
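For context on how the flags consulted in the first hunk are typically turned on: a minimal sketch using the public Inductor knobs, assuming the usual mapping of torch._inductor.config settings into the kernel's inductor_meta (that mapping is not part of this diff).

import torch
import torch._inductor.config as inductor_config

# Either knob below feeds the "max_autotune" / "max_autotune_pointwise" entries
# that _persistent_reduction_configs reads from inductor_meta.
inductor_config.max_autotune_pointwise = True  # tune pointwise/reduction kernels only
# inductor_config.max_autotune = True          # or: tune GEMM choices as well

def rms_norm_like(x):
    # A persistent-reduction-shaped workload: reduce over the last dimension.
    return x / torch.sqrt((x * x).mean(dim=-1, keepdim=True) + 1e-6)

# Alternatively, the "max-autotune" compile mode enables max_autotune under the hood.
compiled = torch.compile(rms_norm_like, mode="max-autotune")
out = compiled(torch.randn(1024, 512, device="cuda"))  # requires a GPU build

With either flag set, the heuristic filtering changed below is skipped and the full candidate list is benchmarked.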
@@ -3222,6 +3222,15 @@ def _persistent_reduction_configs(
         else:
             raise NotImplementedError("native matmul only supports mm/bmm pattern")
 
+    max_autotune_enabled = inductor_meta.get("max_autotune") or inductor_meta.get(
+        "max_autotune_pointwise"
+    )
+
+    if torch.version.hip:
+        xblock_vals = [1, 4, 8, 16, 32, 64, 128, 256]
+    else:
+        xblock_vals = [1, 8, 32, 128]
+
     if "y" not in size_hints:
         configs = [
             triton_config_reduction(
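The two additions in this hunk, restated as a standalone sketch (the helpers are illustrative; in the patch the flag lookup happens inline on inductor_meta and the backend check is torch.version.hip):

# Standalone restatement of the additions above, for illustration only.
def autotune_enabled(inductor_meta: dict) -> bool:
    # Either the global max-autotune flag or the pointwise-only flag puts
    # persistent reductions into "do not filter the candidate configs" mode.
    return bool(
        inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
    )

def candidate_xblocks(is_hip: bool) -> list[int]:
    # ROCm sweeps a denser set of XBLOCK sizes; other backends keep the original four.
    return [1, 4, 8, 16, 32, 64, 128, 256] if is_hip else [1, 8, 32, 128]

assert autotune_enabled({"max_autotune_pointwise": True})
assert candidate_xblocks(is_hip=False) == [1, 8, 32, 128]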
@@ -3231,7 +3240,7 @@ def _persistent_reduction_configs(
                 register_intensive=True,
                 reduction_hint=reduction_hint,
             )
-            for xblock in (1, 8, 32, 128)
+            for xblock in xblock_vals
             if xblock == 1
             or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel)
         ]
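The comprehension now iterates the backend-specific xblock_vals, but the survival condition is unchanged: keep an XBLOCK candidate only if the whole persistent block fits under MAX_PERSISTENT_BLOCK_NUMEL and XBLOCK does not exceed xnumel. A worked restatement (the cap value below is assumed for the example, not taken from this diff):

MAX_PERSISTENT_BLOCK_NUMEL = 131072  # assumed value, for illustration only

def keep_xblock(xblock: int, xnumel: int, rnumel: int) -> bool:
    # Mirrors the filter in the comprehension above.
    return xblock == 1 or (
        rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel
    )

# Example: rnumel=4096, xnumel=8192 -> only XBLOCK in {1, 4, 8, 16, 32} survives
# from the ROCm list, because 64 * 4096 already exceeds the persistent-block cap.
survivors = [
    xb for xb in [1, 4, 8, 16, 32, 64, 128, 256] if keep_xblock(xb, 8192, 4096)
]
assert survivors == [1, 4, 8, 16, 32]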
@@ -3239,7 +3248,7 @@ def _persistent_reduction_configs(
         configs = []
         assert "tiling_scores" in inductor_meta
         x_y_scores = {dim: inductor_meta["tiling_scores"][dim] for dim in ("x", "y")}
-        for target_block_size in (1, 8, 32, 64, 128):
+        for target_block_size in xblock_vals:
             if target_block_size * rnumel > MAX_PERSISTENT_BLOCK_NUMEL:
                 continue
 
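The tiled path ("y" in size_hints) gets the same treatment: each candidate target_block_size now comes from xblock_vals and is skipped when target_block_size * rnumel would blow the persistent-block budget. A small numeric illustration (again with an assumed cap):

MAX_PERSISTENT_BLOCK_NUMEL = 131072  # assumed value, for illustration only
rnumel = 2048
tried = [t for t in [1, 4, 8, 16, 32, 64, 128, 256] if t * rnumel <= MAX_PERSISTENT_BLOCK_NUMEL]
assert tried == [1, 4, 8, 16, 32, 64]  # 128 * 2048 = 262144 > cap, so it is skipped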
@@ -3252,39 +3261,47 @@ def _persistent_reduction_configs(
                 )
             )
 
+    tiny_configs = [
+        triton_config_reduction(
+            size_hints,
+            2 * (256 // rnumel) if rnumel <= 256 else 1,
+            rnumel,
+        )
+    ]
+
     # defer to more autotuning, initially
     if "y" in size_hints:
         pass
     # TODO(jansel): we should be able to improve these heuristics
-    elif reduction_hint == ReductionHint.INNER and rnumel >= 256:
-        if rnumel > 1024:
-            configs = configs[:1]
-        else:
-            x_block = 8
-            if xnumel // x_block < 128 or loads_and_stores >= 5:
-                x_block = 1
-
-            configs = [
-                triton_config_reduction(
-                    size_hints,
-                    x_block,
-                    rnumel,
-                    register_intensive=True,
-                    reduction_hint=reduction_hint,
-                )
-            ]
-
-    elif reduction_hint == ReductionHint.OUTER:
-        configs = configs[-1:]
-    elif reduction_hint == ReductionHint.OUTER_TINY:
-        configs = [
-            triton_config_reduction(
-                size_hints,
-                2 * (256 // rnumel) if rnumel <= 256 else 1,
-                rnumel,
-                reduction_hint=reduction_hint,
-            )
-        ]
+    elif not max_autotune_enabled:  # Do not filter configs when tuning
+        if reduction_hint == ReductionHint.INNER and rnumel >= 256:
+            if rnumel > 1024:
+                configs = configs[:1]
+            else:
+                x_block = 8
+                if xnumel // x_block < 128 or loads_and_stores >= 5:
+                    x_block = 1
+
+                configs = [
+                    triton_config_reduction(
+                        size_hints,
+                        x_block,
+                        rnumel,
+                        register_intensive=True,
+                    )
+                ]
+
+        elif reduction_hint == ReductionHint.OUTER:
+            configs = configs[-1:]
+        elif reduction_hint == ReductionHint.OUTER_TINY:
+            configs = tiny_configs
+    else:
+        if torch.version.hip:
+            # If autotune is enabled append tiny configs
+            for conf in tiny_configs:
+                if conf not in configs:
+                    configs.append(conf)
+
     for c in configs:
         # we don't need Rn_BLOCK for persistent reduction
         for prefix in size_hints:
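Putting the last hunk together: the restructured branch only applies the reduction-hint filtering when max-autotune is off; with it on, every generated config is kept and, on ROCm, the tiny config is appended as one more candidate. A simplified, self-contained restatement (the function name, signature, and ReductionHint stand-in are illustrative, and the INNER branch is condensed relative to the real code):

from enum import Enum

class ReductionHint(Enum):  # stand-in for the Inductor enum of the same name
    DEFAULT = 0
    INNER = 1
    OUTER = 2
    OUTER_TINY = 3

def select_persistent_configs(configs, tiny_configs, reduction_hint,
                              rnumel, max_autotune_enabled, is_hip):
    """Illustrative restatement of the restructured branch above.
    (The real code skips all of this when the kernel is tiled over y.)"""
    configs = list(configs)
    if not max_autotune_enabled:
        if reduction_hint == ReductionHint.INNER and rnumel >= 256:
            # Simplified: the real code either keeps one config or rebuilds
            # a single config with a chosen x_block.
            configs = configs[:1]
        elif reduction_hint == ReductionHint.OUTER:
            configs = configs[-1:]
        elif reduction_hint == ReductionHint.OUTER_TINY:
            configs = list(tiny_configs)
    elif is_hip:
        # Autotuning on: keep everything and add the tiny config as an extra candidate.
        for conf in tiny_configs:
            if conf not in configs:
                configs.append(conf)
    return configs

# With autotuning on, nothing is filtered and the tiny config joins the list on ROCm.
out = select_persistent_configs(["cfg_a", "cfg_b"], ["tiny"], ReductionHint.OUTER,
                                rnumel=4096, max_autotune_enabled=True, is_hip=True)
assert out == ["cfg_a", "cfg_b", "tiny"]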