[Reland][Inductor] Prune configs that require more shared memory than the hardware limit. (#161996)

Summary:
This is a re-land of [PR161040](https://github.com/pytorch/pytorch/pull/161040), which had previously caused test failures on AMD GPUs. The tests are now configured to target only NVIDIA GPUs.

This diff prunes configurations that require more shared memory than the hardware limit; such configurations otherwise cause the following compilation error:
```
No valid triton configs. OutOfMemoryError: out of resource: triton_mm Required: 327680 Hardware limit:232448 Reducing block sizes or `num_stages` may help.
```
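
For context, the pruning relies on the shared-memory estimate implemented in the new `_get_exceeding_shared_memory_checker` helper (see the `template_heuristics/triton.py` hunk below): a GEMM tile needs roughly `dtype_size * (BLOCK_M * BLOCK_K + BLOCK_N * BLOCK_K) * num_stages` bytes of shared memory per block. A minimal sketch of that arithmetic; the tile shape, stage count, and fp16 dtype below are illustrative choices, not values taken from the failing kernel:
```
import torch


def estimated_shared_mem(block_m, block_n, block_k, num_stages, dtype_size):
    # Same estimate as the new checker: bytes of shared memory per thread block.
    return dtype_size * (block_m * block_k + block_n * block_k) * num_stages


# Hardware limit: query the device when possible, otherwise fall back to the
# value reported in the error message above (232448 bytes).
limit = 232448
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    if hasattr(props, "shared_memory_per_block_optin"):  # NVIDIA GPUs
        limit = int(props.shared_memory_per_block_optin)

# An illustrative (128, 128, 128) tile with num_stages=5 at fp16 needs
# 2 * (128*128 + 128*128) * 5 = 327680 bytes and would be pruned here.
required = estimated_shared_mem(128, 128, 128, num_stages=5, dtype_size=2)
print(f"required={required} limit={limit} pruned={required > limit}")
```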

Test Plan:
```
pytest test/inductor/test_max_autotune.py
pytest test/inductor/test_triton_heuristics.py
```
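
The pruning is gated by a new inductor config (see the `torch/_inductor/config.py` hunk below) and is enabled by default. A minimal sketch of turning it off when debugging autotune choices:
```
from torch._inductor import config as inductor_config

# On by default; set to False to keep the previous behavior of considering
# every candidate config, including ones over the shared-memory limit.
inductor_config.max_autotune_prune_choices_based_on_shared_mem = False

# Equivalent environment toggle (read when the config module is imported):
#   TORCHINDUCTOR_MAX_AUTOTUNE_PRUNE_CHOICES_BASED_ON_SHARED_MEM=0
```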

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161996
Approved by: https://github.com/coconutruben
Authored by Wenyuan Chi on 2025-09-03 04:23:05 +00:00; committed by PyTorch MergeBot.
Commit 00636e0171 (parent 09d2f1b631).
5 changed files with 126 additions and 17 deletions.

File: test/inductor/test_max_autotune.py

@@ -18,7 +18,7 @@ from torch import multiprocessing as mp, nn
 from torch._dynamo import reset
 from torch._dynamo.exc import BackendCompilerFailed
 from torch._dynamo.testing import rand_strided, reset_rng_state
-from torch._dynamo.utils import same
+from torch._dynamo.utils import counters, same
 from torch._inductor import config
 from torch._inductor.autotune_process import (
     _TestBenchmarkRequest,
@@ -1683,6 +1683,26 @@ class TestMaxAutotune(TestCase):
         out, code = run_and_get_code(compiled_f, a, b)
         torch.testing.assert_close(out, mm(a, b), atol=1e-2, rtol=1e-2)

+    @config.patch(
+        max_autotune_gemm=True,
+        max_autotune_prune_choices_based_on_shared_mem=True,
+    )
+    def test_max_autotune_prune_choices(self):
+        def mm(x, y):
+            return x @ y
+
+        M, K, N = (3, 3, 3)
+
+        x = torch.rand([M, K], device=GPU_TYPE, dtype=torch.float32)
+        y = torch.rand([K, N], device=GPU_TYPE, dtype=torch.float32)
+
+        compiled_f = torch.compile(mm)
+        compiled_f(x, y)
+
+        self.assertEqual(
+            counters["inductor"]["select_algorithm_num_precompilation_exceptions"], 0
+        )
+

 class TestMaxAutotunePrecompile(TestCase):
     def test_precompilation_threads(self):

File: test/inductor/test_triton_heuristics.py

@@ -3,15 +3,24 @@
 import functools
 import sys
 import unittest
+from unittest import skipUnless
 from unittest.mock import MagicMock, patch

 import torch
 from torch._dynamo.testing import rand_strided
 from torch._inductor.runtime.triton_compat import HAS_WARP_SPEC
 from torch._inductor.utils import clone_preserve_strides
-from torch.testing._internal.common_utils import IS_LINUX, runOnRocm, skipIfXpu
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    IS_LINUX,
+    parametrize,
+    runOnRocm,
+    skipIfRocm,
+    skipIfXpu,
+)
+from torch.testing._internal.inductor_utils import (
+    GPU_TYPE,
+    HAS_CUDA_AND_TRITON,
+    HAS_GPU,
+    requires_cuda_with_enough_memory,
+)
@@ -67,6 +76,7 @@ def get_autotuned_amd_sqr_kernel():
 )(amd_sqr_kernel)


+@instantiate_parametrized_tests
 class TestTritonHeuristics(TestCase):
     device_type = GPU_TYPE
@@ -262,6 +272,34 @@ class TestTritonHeuristics(TestCase):
         res = torch.compile(fn)(x)
         self.assertEqual(ref, res)

+    @skipIfXpu
+    @skipIfRocm
+    @skipUnless(HAS_CUDA_AND_TRITON, "requires CUDA")
+    @parametrize("do_pruning", [False, True])
+    def test_prune_configs_over_shared_memory_limit(self, do_pruning):
+        from torch._inductor.template_heuristics.triton import (
+            CUDAConfigHeuristic,
+            GemmConfig,
+        )
+
+        expected_count = 1 if do_pruning else 2
+        mm_configs = [
+            GemmConfig(32, 32, 32, 1, 8, 8),
+            GemmConfig(
+                128, 128, 128, 100, 8, 4
+            ),  # intentionally large to exceed shared memory limit
+        ]
+
+        with config.patch(
+            {"max_autotune_prune_choices_based_on_shared_mem": do_pruning}
+        ):
+            config_heuristic = CUDAConfigHeuristic()
+            config_heuristic.should_scale_configs = False
+            config_heuristic.mm_configs = mm_configs
+            configs = list(
+                config_heuristic.get_mm_configs()(3, 3, 3, dtype_size=4, op_name="mm")
+            )
+            self.assertEqual(len(configs), expected_count)
+

 class TestArgumentCloneAndRestore(TestCase):
     # Our tensor is large enough. If a unexpected copy happens, the
File: torch/_inductor/config.py

@@ -448,6 +448,12 @@ max_autotune_report_choices_stats = (
     os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_REPORT_CHOICES_STATS", "1") == "1"
 )

+# Prune configs that require more shared memory than the hardware limit
+max_autotune_prune_choices_based_on_shared_mem = (
+    os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_PRUNE_CHOICES_BASED_ON_SHARED_MEM", "1")
+    == "1"
+)
+
 # enable inductor graph partition to allow multiple inductor graphs for the same dynamo graph
 graph_partition: bool = (
     os.environ.get("TORCHINDUCTOR_GRAPH_PARTITION", "1" if not is_fbcode() else "0")

File: torch/_inductor/select_algorithm.py

@@ -2775,6 +2775,9 @@ class AlgorithmSelectorCache(PersistentCache):
                 timeout=precompilation_timeout_seconds,
             ):
                 if e := future.exception():
+                    counters["inductor"][
+                        "select_algorithm_num_precompilation_exceptions"
+                    ] += 1
                    exceptions.append((futures[future], e))

                    from torch._inductor.codegen.cuda.cuda_kernel import (
                        CUDATemplateCaller,

File: torch/_inductor/template_heuristics/triton.py

@@ -551,34 +551,69 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
         return scaled_configs

+    def _get_exceeding_shared_memory_checker(
+        self,
+    ) -> Optional[Callable[[BaseConfig, int], bool]]:
+        """
+        Returns a function that checks whether a given configuration exceeds the available shared memory for the device.
+        If the device does not report available shared memory, returns None.
+        """
+        try:
+            device = torch.cuda.current_device()
+            props = torch.cuda.get_device_properties(device)
+            if not hasattr(props, "shared_memory_per_block_optin"):  # for NVidia GPUs
+                return None
+            sm_available = int(props.shared_memory_per_block_optin)
+        except Exception:
+            # If CUDA is not available or properties cannot be queried, return None
+            return None
+
+        # TODO make a BaseDeviceConfigHeuristics to handle different device configuration in its own implementation.
+        def exceeds(gemm_config: BaseConfig, dtype_size: int) -> bool:
+            shared_mem_accum = dtype_size * (
+                gemm_config.block_m * gemm_config.block_k
+                + gemm_config.block_n * gemm_config.block_k
+            )
+            return shared_mem_accum * gemm_config.num_stages > sm_available
+
+        return exceeds
+
+    def _prune_exceeding_max_shared_mem_configs(
+        self,
+        configs: list[BaseConfig],
+        dtype_size: int,
+    ) -> list[BaseConfig]:
+        if dtype_size <= 0:
+            return configs
+
+        is_exceeding_shared_memory = self._get_exceeding_shared_memory_checker()
+        if is_exceeding_shared_memory is None:
+            return configs
+
+        return [c for c in configs if not is_exceeding_shared_memory(c, dtype_size)]
+
     def _prune_exhaustive_configs(
         self,
         configs: list[BaseConfig],
         dtype_size: int,
     ) -> list[BaseConfig]:
-        import torch
+        is_exceeding_shared_memory = self._get_exceeding_shared_memory_checker()

         pruned_configs = []
         for gemm_config in configs:
-            device = torch.cuda.current_device()
-            props = torch.cuda.get_device_properties(device)
-            sm_available = props.shared_memory_per_block_optin  # type: ignore[attr-defined]
-            NUM_REG = 255
+            # Will use more shared memory than available
+            if is_exceeding_shared_memory and is_exceeding_shared_memory(
+                gemm_config, dtype_size
+            ):
+                continue

+            NUM_REG = 255
             acc_regs = math.ceil(
                 gemm_config.block_m * gemm_config.block_n / (gemm_config.num_warps * 32)
             )
-            shared_mem_accum = dtype_size * (
-                gemm_config.block_m * gemm_config.block_k
-                + gemm_config.block_n * gemm_config.block_k
-            )
-            # Will use more shared memory than available
-            if shared_mem_accum * gemm_config.num_stages > sm_available:
-                continue
             # Lower bound for register spillage, if exceeds the kernel will certainly spill
-            elif acc_regs > NUM_REG:
+            if acc_regs > NUM_REG:
                 continue

             pruned_configs.append(gemm_config)
@@ -610,6 +645,13 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
         scaled_configs = self._scale_mm_configs(
             m, n, k, configs, scale, has_int8_tensor, exclude
         )

+        # Filter out configs that require more shared memory than is available.
+        if config.max_autotune_prune_choices_based_on_shared_mem:
+            scaled_configs = self._prune_exceeding_max_shared_mem_configs(
+                scaled_configs, dtype_size
+            )
+
         if config.max_autotune_gemm_search_space == "EXHAUSTIVE":
             assert dtype_size > 0, "dtype_size must be provided for exhaustive search"
             scaled_configs = self._prune_exhaustive_configs(scaled_configs, dtype_size)