diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index 12dc07fe3b1f..b49b9ac54228 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -3222,6 +3222,15 @@ def _persistent_reduction_configs(
     else:
         raise NotImplementedError("native matmul only supports mm/bmm pattern")
 
+    max_autotune_enabled = inductor_meta.get("max_autotune") or inductor_meta.get(
+        "max_autotune_pointwise"
+    )
+
+    if torch.version.hip:
+        xblock_vals = [1, 4, 8, 16, 32, 64, 128, 256]
+    else:
+        xblock_vals = [1, 8, 32, 128]
+
     if "y" not in size_hints:
         configs = [
             triton_config_reduction(
@@ -3231,7 +3240,7 @@ def _persistent_reduction_configs(
                 register_intensive=True,
                 reduction_hint=reduction_hint,
             )
-            for xblock in (1, 8, 32, 128)
+            for xblock in xblock_vals
             if xblock == 1
             or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel)
         ]
@@ -3239,7 +3248,7 @@ def _persistent_reduction_configs(
         configs = []
         assert "tiling_scores" in inductor_meta
         x_y_scores = {dim: inductor_meta["tiling_scores"][dim] for dim in ("x", "y")}
-        for target_block_size in (1, 8, 32, 64, 128):
+        for target_block_size in xblock_vals:
             if target_block_size * rnumel > MAX_PERSISTENT_BLOCK_NUMEL:
                 continue
 
@@ -3252,39 +3261,47 @@ def _persistent_reduction_configs(
                 )
             )
 
+    tiny_configs = [
+        triton_config_reduction(
+            size_hints,
+            2 * (256 // rnumel) if rnumel <= 256 else 1,
+            rnumel,
+        )
+    ]
+
     # defer to more autotuning, initially
     if "y" in size_hints:
         pass
     # TODO(jansel): we should be able to improve these heuristics
-    elif reduction_hint == ReductionHint.INNER and rnumel >= 256:
-        if rnumel > 1024:
-            configs = configs[:1]
-        else:
-            x_block = 8
-            if xnumel // x_block < 128 or loads_and_stores >= 5:
-                x_block = 1
+    elif not max_autotune_enabled:  # Do not filter configs when tuning
+        if reduction_hint == ReductionHint.INNER and rnumel >= 256:
+            if rnumel > 1024:
+                configs = configs[:1]
+            else:
+                x_block = 8
+                if xnumel // x_block < 128 or loads_and_stores >= 5:
+                    x_block = 1
 
-            configs = [
-                triton_config_reduction(
-                    size_hints,
-                    x_block,
-                    rnumel,
-                    register_intensive=True,
-                    reduction_hint=reduction_hint,
-                )
-            ]
+                configs = [
+                    triton_config_reduction(
+                        size_hints,
+                        x_block,
+                        rnumel,
+                        register_intensive=True,
+                    )
+                ]
+
+        elif reduction_hint == ReductionHint.OUTER:
+            configs = configs[-1:]
+        elif reduction_hint == ReductionHint.OUTER_TINY:
+            configs = tiny_configs
+    else:
+        if torch.version.hip:
+            # If autotune is enabled append tiny configs
+            for conf in tiny_configs:
+                if conf not in configs:
+                    configs.append(conf)
 
-    elif reduction_hint == ReductionHint.OUTER:
-        configs = configs[-1:]
-    elif reduction_hint == ReductionHint.OUTER_TINY:
-        configs = [
-            triton_config_reduction(
-                size_hints,
-                2 * (256 // rnumel) if rnumel <= 256 else 1,
-                rnumel,
-                reduction_hint=reduction_hint,
-            )
-        ]
     for c in configs:
         # we don't need Rn_BLOCK for persistent reduction
         for prefix in size_hints:
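
For reference, a minimal standalone sketch of the selection flow these hunks introduce, assuming a stand-in constant value and plain `(XBLOCK, rnumel)` tuples in place of `triton_config_reduction`; the `reduction_hint` branching is collapsed to a single filter, so this is an illustration of the max-autotune/ROCm behavior, not the real inductor API:

```python
# Simplified model of the new config selection. The constant's value, the
# tuple representation, and the single-step filter are all assumptions.

MAX_PERSISTENT_BLOCK_NUMEL = 131072  # assumed value, for illustration only


def candidate_xblocks(is_hip: bool) -> list[int]:
    # ROCm sweeps a denser set of XBLOCK candidates; CUDA keeps the old set.
    return [1, 4, 8, 16, 32, 64, 128, 256] if is_hip else [1, 8, 32, 128]


def pick_configs(
    is_hip: bool, max_autotune_enabled: bool, xnumel: int, rnumel: int
) -> list[tuple[int, int]]:
    # Build (XBLOCK, rnumel) pairs subject to the persistent-block budget,
    # mirroring the list comprehension in the first hunk.
    configs = [
        (xblock, rnumel)
        for xblock in candidate_xblocks(is_hip)
        if xblock == 1
        or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel)
    ]
    tiny = (2 * (256 // rnumel) if rnumel <= 256 else 1, rnumel)
    if not max_autotune_enabled:
        # Heuristic filtering (simplified here; the real code branches on
        # reduction_hint) only runs when max-autotune is off.
        configs = configs[:1]
    elif is_hip and tiny not in configs:
        # With max-autotune on, ROCm additionally tries the tiny config.
        configs.append(tiny)
    return configs


# E.g. ROCm with max-autotune keeps all eight XBLOCK candidates:
print(pick_configs(is_hip=True, max_autotune_enabled=True, xnumel=4096, rnumel=512))
```

The key design point the sketch mirrors: heuristic pruning now applies only when max-autotune is disabled, while on ROCm an enabled max-autotune both widens the XBLOCK sweep and appends the OUTER_TINY-style config so the autotuner can choose among them.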