Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Reland][Inductor] Prune configs that require more shared memory than the hardware limit. (#161996)
Summary: This is a re-land of [PR161040](https://github.com/pytorch/pytorch/pull/161040), which had previously caused test failures on AMD GPUs; the tests are now configured to target only NVIDIA GPUs. This diff removes autotuning configurations that require more shared memory than the hardware provides. Such configurations otherwise fail compilation with:

```
No valid triton configs. OutOfMemoryError: out of resource: triton_mm
Required: 327680 Hardware limit: 232448
Reducing block sizes or `num_stages` may help.
```

Test Plan:

```
pytest test/inductor/test_max_autotune.py
pytest test/inductor/test_triton_heuristics.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161996
Approved by: https://github.com/coconutruben
Committed by: PyTorch MergeBot
Parent: 09d2f1b631
Commit: 00636e0171
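For context, the pruning this PR adds comes down to a per-config estimate of the shared memory a Triton GEMM kernel needs, `dtype_size * (block_m * block_k + block_n * block_k) * num_stages`, compared against the per-block limit reported by the device. The sketch below is a standalone illustration of that check, not the Inductor code itself; the 128x128x128 tiles, `num_stages=5`, and fp16 element size are assumptions, chosen only because they reproduce the `Required: 327680` figure in the error above.

```python
import torch


def estimated_shared_mem(block_m, block_n, block_k, num_stages, dtype_size):
    # Per-stage tiles of A (block_m x block_k) and B (block_k x block_n),
    # pipelined over num_stages -- the same estimate the new heuristic uses.
    return dtype_size * (block_m * block_k + block_n * block_k) * num_stages


# Per-block opt-in shared memory limit reported by the driver (None off-CUDA).
limit = None
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    limit = getattr(props, "shared_memory_per_block_optin", None)

# Hypothetical config: 128x128x128 tiles, 5 stages, fp16 (2 bytes per element).
required = estimated_shared_mem(128, 128, 128, num_stages=5, dtype_size=2)
print(required)  # 327680
if limit is not None and required > limit:
    print(f"would be pruned: requires {required} bytes, hardware limit {limit}")
```

A config that fails this comparison is dropped before benchmarking instead of surfacing the OutOfMemoryError above during compilation.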
```diff
@@ -18,7 +18,7 @@ from torch import multiprocessing as mp, nn
 from torch._dynamo import reset
 from torch._dynamo.exc import BackendCompilerFailed
 from torch._dynamo.testing import rand_strided, reset_rng_state
-from torch._dynamo.utils import same
+from torch._dynamo.utils import counters, same
 from torch._inductor import config
 from torch._inductor.autotune_process import (
     _TestBenchmarkRequest,
```
```diff
@@ -1683,6 +1683,26 @@ class TestMaxAutotune(TestCase):
         out, code = run_and_get_code(compiled_f, a, b)
         torch.testing.assert_close(out, mm(a, b), atol=1e-2, rtol=1e-2)
 
+    @config.patch(
+        max_autotune_gemm=True,
+        max_autotune_prune_choices_based_on_shared_mem=True,
+    )
+    def test_max_autotune_prune_choices(self):
+        def mm(x, y):
+            return x @ y
+
+        M, K, N = (3, 3, 3)
+
+        x = torch.rand([M, K], device=GPU_TYPE, dtype=torch.float32)
+        y = torch.rand([K, N], device=GPU_TYPE, dtype=torch.float32)
+
+        compiled_f = torch.compile(mm)
+        compiled_f(x, y)
+
+        self.assertEqual(
+            counters["inductor"]["select_algorithm_num_precompilation_exceptions"], 0
+        )
+
 
 class TestMaxAutotunePrecompile(TestCase):
     def test_precompilation_threads(self):
```
```diff
@@ -3,15 +3,24 @@
 import functools
 import sys
 import unittest
+from unittest import skipUnless
 from unittest.mock import MagicMock, patch
 
 import torch
 from torch._dynamo.testing import rand_strided
 from torch._inductor.runtime.triton_compat import HAS_WARP_SPEC
 from torch._inductor.utils import clone_preserve_strides
-from torch.testing._internal.common_utils import IS_LINUX, runOnRocm, skipIfXpu
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    IS_LINUX,
+    parametrize,
+    runOnRocm,
+    skipIfRocm,
+    skipIfXpu,
+)
 from torch.testing._internal.inductor_utils import (
     GPU_TYPE,
+    HAS_CUDA_AND_TRITON,
     HAS_GPU,
     requires_cuda_with_enough_memory,
 )
```
```diff
@@ -67,6 +76,7 @@ def get_autotuned_amd_sqr_kernel():
     )(amd_sqr_kernel)
 
 
+@instantiate_parametrized_tests
 class TestTritonHeuristics(TestCase):
     device_type = GPU_TYPE
 
```
```diff
@@ -262,6 +272,34 @@ class TestTritonHeuristics(TestCase):
         res = torch.compile(fn)(x)
         self.assertEqual(ref, res)
 
+    @skipIfXpu
+    @skipIfRocm
+    @skipUnless(HAS_CUDA_AND_TRITON, "requires CUDA")
+    @parametrize("do_pruning", [False, True])
+    def test_prune_configs_over_shared_memory_limit(self, do_pruning):
+        from torch._inductor.template_heuristics.triton import (
+            CUDAConfigHeuristic,
+            GemmConfig,
+        )
+
+        expected_count = 1 if do_pruning else 2
+        mm_configs = [
+            GemmConfig(32, 32, 32, 1, 8, 8),
+            GemmConfig(
+                128, 128, 128, 100, 8, 4
+            ),  # intentionally large to exceed shared memory limit
+        ]
+        with config.patch(
+            {"max_autotune_prune_choices_based_on_shared_mem": do_pruning}
+        ):
+            config_heuristic = CUDAConfigHeuristic()
+            config_heuristic.should_scale_configs = False
+            config_heuristic.mm_configs = mm_configs
+            configs = list(
+                config_heuristic.get_mm_configs()(3, 3, 3, dtype_size=4, op_name="mm")
+            )
+            self.assertEqual(len(configs), expected_count)
+
 
 class TestArgumentCloneAndRestore(TestCase):
     # Our tensor is large enough. If a unexpected copy happens, the
```
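To see where `expected_count` comes from: with `dtype_size=4`, and reading the first four `GemmConfig` positional arguments as `block_m`, `block_n`, `block_k`, `num_stages` (the fields the new shared-memory check consumes), the two configs in the test work out as below. The ~232 KB reference limit is the figure from the error in the summary and varies by GPU; the real check queries the device.

```python
# Shared-memory estimate applied by the pruning pass:
#   dtype_size * (block_m * block_k + block_n * block_k) * num_stages
def estimate(block_m, block_n, block_k, num_stages, dtype_size=4):
    return dtype_size * (block_m * block_k + block_n * block_k) * num_stages


print(estimate(32, 32, 32, 1))       # 8192 bytes      -> fits, always kept
print(estimate(128, 128, 128, 100))  # 13107200 bytes  -> pruned when the flag is on
```

So with `do_pruning=True` only the small config survives (`expected_count = 1`); with pruning disabled both are returned (`expected_count = 2`).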
```diff
@@ -448,6 +448,12 @@ max_autotune_report_choices_stats = (
     os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_REPORT_CHOICES_STATS", "1") == "1"
 )
 
+# Prune configs that require more shared memory than the hardware limit
+max_autotune_prune_choices_based_on_shared_mem = (
+    os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_PRUNE_CHOICES_BASED_ON_SHARED_MEM", "1")
+    == "1"
+)
+
 # enable inductor graph partition to allow multiple inductor graphs for the same dynamo graph
 graph_partition: bool = (
     os.environ.get("TORCHINDUCTOR_GRAPH_PARTITION", "1" if not is_fbcode() else "0")
```
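The flag reads its default from the environment when `torch._inductor.config` is imported, so it can be disabled either by exporting `TORCHINDUCTOR_MAX_AUTOTUNE_PRUNE_CHOICES_BASED_ON_SHARED_MEM=0` before that import, or at runtime via `config.patch`, which is how the new tests drive it. A minimal sketch of the runtime route (assumes a CUDA device; the tensor shapes are arbitrary):

```python
import torch
from torch._inductor import config as inductor_config


def mm(x, y):
    return x @ y


x = torch.rand(64, 64, device="cuda")
y = torch.rand(64, 64, device="cuda")

# Mirror the new test, but with the shared-memory pruning turned off
# (it defaults to enabled).
with inductor_config.patch(
    max_autotune_gemm=True,
    max_autotune_prune_choices_based_on_shared_mem=False,
):
    torch.compile(mm)(x, y)
```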
```diff
@@ -2775,6 +2775,9 @@ class AlgorithmSelectorCache(PersistentCache):
                 timeout=precompilation_timeout_seconds,
             ):
                 if e := future.exception():
+                    counters["inductor"][
+                        "select_algorithm_num_precompilation_exceptions"
+                    ] += 1
                     exceptions.append((futures[future], e))
                     from torch._inductor.codegen.cuda.cuda_kernel import (
                         CUDATemplateCaller,
```
```diff
@@ -551,34 +551,69 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
 
         return scaled_configs
 
+    def _get_exceeding_shared_memory_checker(
+        self,
+    ) -> Optional[Callable[[BaseConfig, int], bool]]:
+        """
+        Returns a function that checks whether a given configuration exceeds the available shared memory for the device.
+        If the device does not report available shared memory, returns None.
+        """
+
+        try:
+            device = torch.cuda.current_device()
+            props = torch.cuda.get_device_properties(device)
+            if not hasattr(props, "shared_memory_per_block_optin"):  # for NVidia GPUs
+                return None
+            sm_available = int(props.shared_memory_per_block_optin)
+        except Exception:
+            # If CUDA is not available or properties cannot be queried, return None
+            return None
+
+        # TODO make a BaseDeviceConfigHeuristics to handle different device configuration in its own implementation.
+        def exceeds(gemm_config: BaseConfig, dtype_size: int) -> bool:
+            shared_mem_accum = dtype_size * (
+                gemm_config.block_m * gemm_config.block_k
+                + gemm_config.block_n * gemm_config.block_k
+            )
+            return shared_mem_accum * gemm_config.num_stages > sm_available
+
+        return exceeds
+
+    def _prune_exceeding_max_shared_mem_configs(
+        self,
+        configs: list[BaseConfig],
+        dtype_size: int,
+    ) -> list[BaseConfig]:
+        if dtype_size <= 0:
+            return configs
+
+        is_exceeding_shared_memory = self._get_exceeding_shared_memory_checker()
+        if is_exceeding_shared_memory is None:
+            return configs
+
+        return [c for c in configs if not is_exceeding_shared_memory(c, dtype_size)]
+
     def _prune_exhaustive_configs(
         self,
         configs: list[BaseConfig],
         dtype_size: int,
     ) -> list[BaseConfig]:
-        import torch
+        is_exceeding_shared_memory = self._get_exceeding_shared_memory_checker()
 
         pruned_configs = []
         for gemm_config in configs:
-            device = torch.cuda.current_device()
-            props = torch.cuda.get_device_properties(device)
-            sm_available = props.shared_memory_per_block_optin  # type: ignore[attr-defined]
-            NUM_REG = 255
+            # Will use more shared memory than available
+            if is_exceeding_shared_memory and is_exceeding_shared_memory(
+                gemm_config, dtype_size
+            ):
+                continue
 
+            NUM_REG = 255
             acc_regs = math.ceil(
                 gemm_config.block_m * gemm_config.block_n / (gemm_config.num_warps * 32)
             )
-
-            shared_mem_accum = dtype_size * (
-                gemm_config.block_m * gemm_config.block_k
-                + gemm_config.block_n * gemm_config.block_k
-            )
-
-            # Will use more shared memory than available
-            if shared_mem_accum * gemm_config.num_stages > sm_available:
-                continue
             # Lower bound for register spillage, if exceeds the kernel will certainly spill
-            elif acc_regs > NUM_REG:
+            if acc_regs > NUM_REG:
                 continue
 
             pruned_configs.append(gemm_config)
```
```diff
@@ -610,6 +645,13 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
         scaled_configs = self._scale_mm_configs(
             m, n, k, configs, scale, has_int8_tensor, exclude
         )
+
+        # Filter out configs that require more shared memory than is available.
+        if config.max_autotune_prune_choices_based_on_shared_mem:
+            scaled_configs = self._prune_exceeding_max_shared_mem_configs(
+                scaled_configs, dtype_size
+            )
+
         if config.max_autotune_gemm_search_space == "EXHAUSTIVE":
             assert dtype_size > 0, "dtype_size must be provided for exhaustive search"
             scaled_configs = self._prune_exhaustive_configs(scaled_configs, dtype_size)
```