Add option to limit number of SMs used by matmul kernels (#147966)

Resubmission of #144974 which was reverted for unrelated reasons.

Newer matmul kernels, e.g. those targeting Hopper GPUs, sometimes use a "persistent" schedule which consists of launching as many CUDA blocks as there are SMs on the GPU, with each such block then working on multiple output tiles in a row. This eliminates the overhead of starting and finishing each tile, effectively pipelining across tiles. In previous generations these latencies could be hidden by having multiple CUDA blocks per SM but, with blocks becoming larger, only one can run at a time per SM, hence this now needs to be handled in software.

Persistent kernels become an issue when other kernels are running concurrently. The classic example is an NCCL communication kernel running in the background. In such cases the matmul expects to be able to use all the SMs but is prevented from doing so because some of them are busy. This can lead to its blocks being scheduled as two separate waves on the available SMs. This "wave quantization" can double the latency of the matmul kernels.
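To make the numbers concrete, here is a minimal back-of-the-envelope sketch (the 132-SM count matches the H100 constant used in the CUTLASS file below; the 8 busy SMs are an arbitrary assumption for illustration):

```python
import math

total_sms = 132        # H100 SM count (matches kNumSMsForH100 in the CUTLASS kernel)
busy_sms = 8           # hypothetical SMs occupied by a background NCCL kernel
persistent_blocks = total_sms  # a persistent matmul launches one block per SM

# With only (total_sms - busy_sms) SMs free, the blocks run in two waves
# instead of one, roughly doubling the matmul's latency.
waves = math.ceil(persistent_blocks / (total_sms - busy_sms))
print(waves)  # 2
```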

While we wait for smarter solutions, such as automatic load balancing among the blocks, an easy way to unblock ourselves is to tell the matmuls to only use a subset of the GPU's SMs. For this, I am introducing a global `sm_carveout` flag which can be used to specify how many SMs should be left available for other kernels.
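As a rough usage sketch, based on the Python bindings added in this PR (the carveout value of 8 is just an illustrative choice):

```python
import torch

# Reserve 8 SMs for other kernels (e.g. NCCL); matmuls that honor the flag
# will target multiProcessorCount - 8 SMs.
torch._C._set_sm_carveout_experimental(8)
assert torch._C._get_sm_carveout_experimental() == 8

# ... run the matmuls that should leave room for the comms kernel ...

# Restore the default behavior (use all SMs).
torch._C._set_sm_carveout_experimental(None)
```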

For now I only change the cuBLAS kernels and the scaled-mm CUTLASS kernel. More kernels can be opted in later.

I tested this change manually by using the Kineto profiler to look up the grid size of a scaled-mm kernel for different values of `sm_carveout` and checking that it changed accordingly. Suggestions for a more automated test are welcome.
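In case it helps reproduce the check, here is a condensed sketch of that manual verification. It mirrors the `test_honor_sm_carveout` test added below, but for brevity it uses tensor-wise scaling (handled by the cuBLAS path) instead of row-wise scaling, and an arbitrary carveout of 66:

```python
import json
import math
import tempfile

import torch

x_fp8 = torch.randn(8192, 2048, device="cuda").to(torch.float8_e4m3fn)
y_fp8 = torch.randn(8192, 2048, device="cuda").to(torch.float8_e4m3fn).t()
scale = torch.ones((), device="cuda")  # tensor-wise scaling, for simplicity

with tempfile.NamedTemporaryFile() as f:
    with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
        torch._scaled_mm(x_fp8, y_fp8, scale_a=scale, scale_b=scale, out_dtype=torch.bfloat16)
        torch._C._set_sm_carveout_experimental(66)
        torch._scaled_mm(x_fp8, y_fp8, scale_a=scale, scale_b=scale, out_dtype=torch.bfloat16)
        torch._C._set_sm_carveout_experimental(None)
    prof.export_chrome_trace(f.name)
    # Total grid size (x*y*z) of each kernel launch in the trace: the launch
    # issued with the carveout in place should be the smaller one.
    grids = [
        math.prod(evt.get("args", {}).get("grid", []))
        for evt in json.load(open(f.name))["traceEvents"]
        if evt.get("cat", "") == "kernel"
    ]
    print(grids)
```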

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147966
Approved by: https://github.com/danthe3rd
This commit is contained in:
Luca Wehrstedt
2025-02-26 10:38:07 +00:00
committed by PyTorch MergeBot
parent 7ffae2c028
commit 60d94ea22b
8 changed files with 127 additions and 5 deletions

View File

@ -433,6 +433,18 @@ void Context::setAllowFP16AccumulationCuBLAS(bool b) {
allow_fp16_accumulation_cublas = b;
}
std::optional<int32_t> Context::_SMCarveout_EXPERIMENTAL() const {
return sm_carveout;
}
void Context::_setSMCarveout_EXPERIMENTAL(std::optional<int32_t> c) {
if (c.has_value()) {
TORCH_WARN_ONCE(
"Setting the SM carveout for matmuls is a temporary experimental mitigation for performance issues, "
"while more robust solutions are developed. It may be removed at any moment without notice.");
}
sm_carveout = c;
}
bool Context::hasMKL() {
#if AT_MKL_ENABLED()

View File

@ -345,6 +345,19 @@ class TORCH_API Context {
void setAllowBF16ReductionCuBLAS(bool);
bool allowFP16AccumulationCuBLAS() const;
void setAllowFP16AccumulationCuBLAS(bool);
// Matmuls can use a so-called "persistent" kernel which launches one CUDA
// block for each SM on the GPU, and each block then iterates over multiple
// output tiles. This allows software pipelining to hide the begin/end
// latencies (e.g., epilogue), especially when only one tile fits per SM.
// However, if some SMs are busy (e.g., with a background NCCL kernel), the
// matmul's blocks will be scheduled in two waves and, in the absence of some
// smart load balancing, the kernel will take twice as long. This flag makes
// matmuls target only a subset of the SMs, so they can fully schedule even
// next to a comms kernel, and only be a few percent slower.
std::optional<int32_t> _SMCarveout_EXPERIMENTAL() const;
void _setSMCarveout_EXPERIMENTAL(std::optional<int32_t>);
at::QEngine qEngine() const;
void setQEngine(at::QEngine e);
static const std::vector<at::QEngine>& supportedQEngines();
@ -423,6 +436,7 @@ class TORCH_API Context {
bool allow_fp16_reduction_cublas = true;
bool allow_bf16_reduction_cublas = true;
bool allow_fp16_accumulation_cublas = false;
std::optional<int32_t> sm_carveout = std::nullopt;
bool enabled_mkldnn = true;
bool allow_tf32_onednn = false;
bool enabled_nnpack = true;

View File

@ -405,6 +405,14 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, opa);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, opb);
#ifndef USE_ROCM
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
computeDesc.setAttribute<int32_t>(
CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
}
#endif
CuBlasLtMatrixLayout Adesc(abcType, m, k, lda, opa == CUBLAS_OP_T);
CuBlasLtMatrixLayout Bdesc(abcType, k, n, ldb, opb == CUBLAS_OP_T);
CuBlasLtMatrixLayout Cdesc(abcType, m, n, ldc);
@ -1331,6 +1339,14 @@ void gemm_and_bias(
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
#ifndef USE_ROCM
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
computeDesc.setAttribute<int32_t>(
CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
}
#endif
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
if (activation == GEMMAndBiasActivationEpilogue::RELU) {
epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
@ -1541,6 +1557,14 @@ void scaled_gemm(
if (result_scale_ptr != nullptr) {
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
}
#ifndef USE_ROCM
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
computeDesc.setAttribute<int32_t>(
CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
}
#endif
#ifndef USE_ROCM
const int8_t fastAccuMode = use_fast_accum ? 1 : 0;
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fastAccuMode);
@ -1701,7 +1725,14 @@ void int8_gemm(
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
#ifndef USE_ROCM
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
computeDesc.setAttribute<int32_t>(
CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
}
#endif
CuBlasLtMatrixLayout Adesc(abType, m, k, mat1_ld, transpose_mat1);
CuBlasLtMatrixLayout Bdesc(abType, k, n, mat2_ld, transpose_mat2);

View File

@ -46,8 +46,6 @@ C10_DIAGNOSTIC_POP()
namespace {
constexpr int kNumSMsForH100 = 132;
using DtypeScale = float;
using DtypeAccum = float;
using DtypeEpilogue = float;
@ -263,6 +261,13 @@ void f8f8bf16_rowwise_impl(
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Ensure persistent kernels leave enough free SMs for NCCL background ops.
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
arguments.hw_info.sm_count =
at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount -
at::globalContext()._SMCarveout_EXPERIMENTAL().value();
}
// Set the swizzle size
arguments.scheduler.max_swizzle_size = swizzle;
@ -521,12 +526,17 @@ void dispatch_fp8_rowwise_kernel_on_tile_size(
int M = XQ.size(0);
int N = WQ.size(1);
int smTarget = at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount;
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
smTarget -= at::globalContext()._SMCarveout_EXPERIMENTAL().value();
}
// We prefer to use smaller tiles (less wasted compute in case of padding),
// but if this causes us to have more CUDA blocks than there are SMs on the
// GPU then we'll hit wave quantization, hence we'll switch to larger tiles.
if (ceildiv(M, 64 * cute::get<0>(ClusterShape{})) *
ceildiv(N, 128 * cute::get<1>(ClusterShape{})) <=
kNumSMsForH100 / cute::size(ClusterShape{})) {
smTarget / cute::size(ClusterShape{})) {
return f8f8bf16_rowwise_impl<
/*TileShape=*/cute::Shape<cute::_64, cute::_128, cute::_128>,
ClusterShape,

View File

@ -1,11 +1,14 @@
# Owner(s): ["module: linear algebra"]
import contextlib
import json
import math
import re
import tempfile
import unittest
from itertools import product
from functools import partial
from typing import Optional
import re
import torch
@ -18,6 +21,7 @@ from torch.testing import make_tensor
from torch.testing._internal.common_cuda import (
SM53OrLater,
SM89OrLater,
SM90OrLater,
_get_torch_cuda_version,
PLATFORM_SUPPORTS_FP8
)
@ -768,6 +772,45 @@ class TestFP8MatmulCuda(TestCase):
self.assertEqual(out_dtype, out_fp8.dtype)
self.assertEqual(out_fp32, out_fp8.to(torch.float))
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support row-wise scaling")
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support row-wise scaling")
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@unittest.skipIf(not SM90OrLater, "sm89 kernel isn't opted into carveout yet")
def test_honor_sm_carveout(self) -> None:
torch.manual_seed(42)
x = torch.randn(8192, 2048, device="cuda", dtype=torch.float32)
y = torch.randn(8192, 2048, device="cuda", dtype=torch.float32).t()
x_scales = tensor_to_scale(x, e4m3_type, dim=1).reciprocal()
y_scales = tensor_to_scale(y, e4m3_type, dim=0).reciprocal()
x_fp8 = to_fp8_saturated(x / x_scales, e4m3_type)
y_fp8 = to_fp8_saturated(y / y_scales, e4m3_type)
with tempfile.NamedTemporaryFile() as f:
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
self.assertIsNone(torch._C._get_sm_carveout_experimental())
torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
torch._C._set_sm_carveout_experimental(0)
self.assertEqual(torch._C._get_sm_carveout_experimental(), 0)
torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
torch._C._set_sm_carveout_experimental(66)
self.assertEqual(torch._C._get_sm_carveout_experimental(), 66)
torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
torch._C._set_sm_carveout_experimental(None)
self.assertIsNone(torch._C._get_sm_carveout_experimental())
torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
prof.export_chrome_trace(f.name)
no_carveout, carveout_0, carveout_66, no_carveout_again = [
math.prod(evt.get("args", {}).get("grid", []))
for evt in json.load(open(f.name))["traceEvents"]
if evt.get("cat", "") == "kernel"
]
self.assertEqual(no_carveout, no_carveout_again)
self.assertNotEqual(no_carveout, carveout_66)
self.assertNotEqual(carveout_66, carveout_0)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")

View File

@ -1216,6 +1216,8 @@ def _get_cublas_allow_fp16_accumulation() -> _bool: ... # THPModule_allowFP16Acc
def _set_cublas_allow_fp16_accumulation(
arg: _bool,
) -> None: ... # THPModule_setAllowFP16AccumulationCuBLAS
def _get_sm_carveout_experimental() -> Optional[_int]: ...
def _set_sm_carveout_experimental(arg: Optional[_int]) -> None: ...
def _set_conj(x: Tensor, conj: _bool) -> None: ...
def _set_neg(x: Tensor, neg: _bool) -> None: ...
def _set_meta_in_tls_dispatch_include(meta_in_tls: _bool) -> None: ...

View File

@ -639,6 +639,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._C._get_privateuse1_backend_name",
"torch._C._get_qengine",
"torch._C._get_schema",
"torch._C._get_sm_carveout_experimental",
"torch._C._get_nested_int",
"torch._C._get_tensor_metadata",
"torch._C._get_tracing_state",
@ -1157,6 +1158,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._C._set_math_sdp_allow_fp16_bf16_reduction",
"torch._C._set_sdp_use_mem_efficient",
"torch._C._set_should_use_format_with_string_table",
"torch._C._set_sm_carveout_experimental",
"torch._C._set_storage_access_error_msg",
"torch._C._set_tensor_metadata",
"torch._C._set_tracing_state",

View File

@ -2262,6 +2262,14 @@ Call this whenever a new thread is created in order to propagate values from
return at::globalContext().getROCmFAPreferredBackend();
});
py_module.def(
"_set_sm_carveout_experimental", [](std::optional<int32_t> val) {
at::globalContext()._setSMCarveout_EXPERIMENTAL(val);
});
py_module.def("_get_sm_carveout_experimental", []() {
return at::globalContext()._SMCarveout_EXPERIMENTAL();
});
py_module.def(
"_construct_storage_from_data_pointer",
[](int64_t data_ptr, c10::Device device, size_t size_bytes) {