Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Add option to limit number of SMs used by matmul kernels (#147966)
Resubmission of #144974, which was reverted for unrelated reasons.

Newer matmul kernels, e.g. those targeting Hopper GPUs, sometimes use a "persistent" schedule: they launch as many CUDA blocks as there are SMs on the GPU, and each block then works on multiple output tiles in a row. This eliminates the overhead of starting and finishing each tile, effectively pipelining across tiles. In previous generations these latencies could be hidden by running multiple CUDA blocks per SM but, with blocks becoming larger, only one can run at a time per SM, so this now has to be handled in software.

Persistent kernels become an issue when other kernels are running concurrently, the classical example being a NCCL communication kernel in the background. In such cases the matmul expects to be able to use all the SMs but is prevented from doing so because some of them are busy. This can lead to its blocks being scheduled as two separate waves on the available SMs. This "wave quantization" can double the latency of the matmul kernels.

While we wait for smarter solutions, such as automatic load balancing among the blocks, an easy way to unblock ourselves is to tell the matmuls to use only a subset of the GPU's SMs. For this, I am introducing a global `sm_carveout` flag which can be used to specify how many SMs should be left available for other kernels.

For now I only change the cuBLAS kernels and the scaled-mm CUTLASS kernel. More kernels can be opted in later.

I tested this change manually, by using the Kineto profiler to look up the grid size of a scaled-mm kernel with different values of `sm_carveout`, and making sure it changed. Suggestions are welcome for a more automated test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147966
Approved by: https://github.com/danthe3rd
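For context, here is a minimal sketch of how the new knob can be driven from Python, using the `torch._C._set_sm_carveout_experimental` / `torch._C._get_sm_carveout_experimental` bindings registered later in this commit; the tensor shapes and the carveout value of 8 are arbitrary assumptions, and only kernels that opted in honor the setting.

    import torch

    a = torch.randn(8192, 8192, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(8192, 8192, device="cuda", dtype=torch.bfloat16)

    assert torch._C._get_sm_carveout_experimental() is None  # default: no carveout

    # Leave 8 SMs free for background kernels (e.g. NCCL). Only opted-in paths
    # (certain cuBLASLt calls and the scaled-mm CUTLASS kernel) honor this.
    torch._C._set_sm_carveout_experimental(8)
    c = a @ b

    # Restore the default behavior.
    torch._C._set_sm_carveout_experimental(None)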
Committed by: PyTorch MergeBot
Parent: 7ffae2c028
Commit: 60d94ea22b
@@ -433,6 +433,18 @@ void Context::setAllowFP16AccumulationCuBLAS(bool b) {
  allow_fp16_accumulation_cublas = b;
}

std::optional<int32_t> Context::_SMCarveout_EXPERIMENTAL() const {
  return sm_carveout;
}

void Context::_setSMCarveout_EXPERIMENTAL(std::optional<int32_t> c) {
  if (c.has_value()) {
    TORCH_WARN_ONCE(
        "Setting the SM carveout for matmuls is a temporary experimental mitigation for performance issues, "
        "while more robust solutions are developed. It may be removed at any moment without notice.");
  }
  sm_carveout = c;
}

bool Context::hasMKL() {
#if AT_MKL_ENABLED()
@@ -345,6 +345,19 @@ class TORCH_API Context {
  void setAllowBF16ReductionCuBLAS(bool);
  bool allowFP16AccumulationCuBLAS() const;
  void setAllowFP16AccumulationCuBLAS(bool);

  // Matmuls can use a so-called "persistent" kernel which launches one CUDA
  // block for each SM on the GPU, and each block then iterates over multiple
  // output tiles. This allows to use software pipelining to hide the begin/end
  // latencies (e.g., epilogue), especially when only one tile fits per SM.
  // However, if some SMs are busy (e.g., with a background NCCL kernel), the
  // matmul's blocks will be scheduled in two waves and, in the absence of some
  // smart load balancing, the kernel will take twice as long. This flag allows
  // to make matmuls target only a subset of the SMs, so they can fully schedule
  // even next to a comms kernel, and only be a few percent slower.
  std::optional<int32_t> _SMCarveout_EXPERIMENTAL() const;
  void _setSMCarveout_EXPERIMENTAL(std::optional<int32_t>);

  at::QEngine qEngine() const;
  void setQEngine(at::QEngine e);
  static const std::vector<at::QEngine>& supportedQEngines();
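To make the wave-quantization argument in the comment above concrete, here is a small back-of-the-envelope sketch in Python; the SM count (132, as on an H100) and the number of SMs occupied by the communication kernel are illustrative assumptions.

    import math

    sms_total = 132   # e.g. an H100
    sms_busy = 8      # assumed to be occupied by a background NCCL kernel
    sms_free = sms_total - sms_busy

    # Without carveout: the persistent matmul still launches one block per SM,
    # so 132 blocks compete for 124 free SMs and run as two waves (~2x latency).
    waves_without = math.ceil(sms_total / sms_free)             # -> 2

    # With sm_carveout=8: only 124 blocks are launched; they all fit in a
    # single wave next to the comms kernel, costing only a few percent of peak.
    waves_with = math.ceil((sms_total - sms_busy) / sms_free)   # -> 1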
@@ -423,6 +436,7 @@ class TORCH_API Context {
  bool allow_fp16_reduction_cublas = true;
  bool allow_bf16_reduction_cublas = true;
  bool allow_fp16_accumulation_cublas = false;
  std::optional<int32_t> sm_carveout = std::nullopt;
  bool enabled_mkldnn = true;
  bool allow_tf32_onednn = false;
  bool enabled_nnpack = true;
@@ -405,6 +405,14 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
  CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, opa);
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, opb);
#ifndef USE_ROCM
  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
    computeDesc.setAttribute<int32_t>(
        CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
        at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
            at::globalContext()._SMCarveout_EXPERIMENTAL().value());
  }
#endif
  CuBlasLtMatrixLayout Adesc(abcType, m, k, lda, opa == CUBLAS_OP_T);
  CuBlasLtMatrixLayout Bdesc(abcType, k, n, ldb, opb == CUBLAS_OP_T);
  CuBlasLtMatrixLayout Cdesc(abcType, m, n, ldc);
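The same pattern repeats in the gemm_and_bias, scaled_gemm and int8_gemm hunks below: the carveout is not passed to cuBLASLt directly, it is subtracted from the device's SM count to form the CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET value. A rough Python rendering of that arithmetic (the helper name and the default device index 0 are assumptions for illustration):

    import torch

    def effective_sm_target(device_index: int = 0) -> int:
        # Mirrors the C++ above: with a carveout of c, persistent kernels are
        # told to target multiProcessorCount - c SMs; otherwise the full GPU.
        total = torch.cuda.get_device_properties(device_index).multi_processor_count
        carveout = torch._C._get_sm_carveout_experimental()
        return total if carveout is None else total - carveout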
@@ -1331,6 +1339,14 @@ void gemm_and_bias(
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
  cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
#ifndef USE_ROCM
  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
    computeDesc.setAttribute<int32_t>(
        CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
        at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
            at::globalContext()._SMCarveout_EXPERIMENTAL().value());
  }
#endif
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
  if (activation == GEMMAndBiasActivationEpilogue::RELU) {
    epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
@@ -1541,6 +1557,14 @@ void scaled_gemm(
  if (result_scale_ptr != nullptr) {
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
  }
#ifndef USE_ROCM
  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
    computeDesc.setAttribute<int32_t>(
        CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
        at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
            at::globalContext()._SMCarveout_EXPERIMENTAL().value());
  }
#endif
#ifndef USE_ROCM
  const int8_t fastAccuMode = use_fast_accum ? 1 : 0;
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fastAccuMode);
@@ -1701,7 +1725,14 @@ void int8_gemm(
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
  cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);

#ifndef USE_ROCM
  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
    computeDesc.setAttribute<int32_t>(
        CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
        at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
            at::globalContext()._SMCarveout_EXPERIMENTAL().value());
  }
#endif

  CuBlasLtMatrixLayout Adesc(abType, m, k, mat1_ld, transpose_mat1);
  CuBlasLtMatrixLayout Bdesc(abType, k, n, mat2_ld, transpose_mat2);
@@ -46,8 +46,6 @@ C10_DIAGNOSTIC_POP()

namespace {

-constexpr int kNumSMsForH100 = 132;
-
using DtypeScale = float;
using DtypeAccum = float;
using DtypeEpilogue = float;
@@ -263,6 +261,13 @@ void f8f8bf16_rowwise_impl(
  // multiplication computation
  size_t workspace_size = Gemm::get_workspace_size(arguments);

  // Ensure persistent kernels leave enough free SMs for NCCL background ops.
  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
    arguments.hw_info.sm_count =
        at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount -
        at::globalContext()._SMCarveout_EXPERIMENTAL().value();
  }

  // Set the swizzle size
  arguments.scheduler.max_swizzle_size = swizzle;
@@ -521,12 +526,17 @@ void dispatch_fp8_rowwise_kernel_on_tile_size(
  int M = XQ.size(0);
  int N = WQ.size(1);

  int smTarget = at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount;
  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
    smTarget -= at::globalContext()._SMCarveout_EXPERIMENTAL().value();
  }

  // We prefer to use smaller tiles (less wasted compute in case of padding),
  // but if this causes us to have more CUDA blocks than there are SMs on the
  // GPU then we'll hit wave quantization, hence we'll switch to larger tiles.
  if (ceildiv(M, 64 * cute::get<0>(ClusterShape{})) *
          ceildiv(N, 128 * cute::get<1>(ClusterShape{})) <=
-     kNumSMsForH100 / cute::size(ClusterShape{})) {
+     smTarget / cute::size(ClusterShape{})) {
    return f8f8bf16_rowwise_impl<
        /*TileShape=*/cute::Shape<cute::_64, cute::_128, cute::_128>,
        ClusterShape,
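The tile-size dispatch above now compares the number of output tiles against the runtime smTarget instead of the hard-coded H100 SM count. A hedged Python rendering of that heuristic (the helper name and the 2x1 cluster default are assumptions for illustration):

    from math import ceil

    def prefers_small_tiles(M: int, N: int, sm_target: int,
                            cluster_m: int = 2, cluster_n: int = 1) -> bool:
        # Prefer 64x128 tiles (less padding waste) unless they would produce
        # more CUDA blocks than fit in a single wave on the targeted SMs.
        tiles = ceil(M / (64 * cluster_m)) * ceil(N / (128 * cluster_n))
        return tiles <= sm_target // (cluster_m * cluster_n)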
@@ -1,11 +1,14 @@
# Owner(s): ["module: linear algebra"]

import contextlib
import json
import math
import re
import tempfile
import unittest
from itertools import product
from functools import partial
from typing import Optional
-import re

import torch
@@ -18,6 +21,7 @@ from torch.testing import make_tensor
from torch.testing._internal.common_cuda import (
    SM53OrLater,
    SM89OrLater,
    SM90OrLater,
    _get_torch_cuda_version,
    PLATFORM_SUPPORTS_FP8
)
@@ -768,6 +772,45 @@ class TestFP8MatmulCuda(TestCase):
        self.assertEqual(out_dtype, out_fp8.dtype)
        self.assertEqual(out_fp32, out_fp8.to(torch.float))

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support row-wise scaling")
    @unittest.skipIf(IS_WINDOWS, "Windows doesn't support row-wise scaling")
    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @unittest.skipIf(not SM90OrLater, "sm89 kernel isn't opted into carveout yet")
    def test_honor_sm_carveout(self) -> None:
        torch.manual_seed(42)

        x = torch.randn(8192, 2048, device="cuda", dtype=torch.float32)
        y = torch.randn(8192, 2048, device="cuda", dtype=torch.float32).t()
        x_scales = tensor_to_scale(x, e4m3_type, dim=1).reciprocal()
        y_scales = tensor_to_scale(y, e4m3_type, dim=0).reciprocal()
        x_fp8 = to_fp8_saturated(x / x_scales, e4m3_type)
        y_fp8 = to_fp8_saturated(y / y_scales, e4m3_type)

        with tempfile.NamedTemporaryFile() as f:
            with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
                self.assertIsNone(torch._C._get_sm_carveout_experimental())
                torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
                torch._C._set_sm_carveout_experimental(0)
                self.assertEqual(torch._C._get_sm_carveout_experimental(), 0)
                torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
                torch._C._set_sm_carveout_experimental(66)
                self.assertEqual(torch._C._get_sm_carveout_experimental(), 66)
                torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
                torch._C._set_sm_carveout_experimental(None)
                self.assertIsNone(torch._C._get_sm_carveout_experimental())
                torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)

            prof.export_chrome_trace(f.name)
            no_carveout, carveout_0, carveout_66, no_carveout_again = [
                math.prod(evt.get("args", {}).get("grid", []))
                for evt in json.load(open(f.name))["traceEvents"]
                if evt.get("cat", "") == "kernel"
            ]

            self.assertEqual(no_carveout, no_carveout_again)
            self.assertNotEqual(no_carveout, carveout_66)
            self.assertNotEqual(carveout_66, carveout_0)

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
    @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
@@ -1216,6 +1216,8 @@ def _get_cublas_allow_fp16_accumulation() -> _bool: ... # THPModule_allowFP16Acc
def _set_cublas_allow_fp16_accumulation(
    arg: _bool,
) -> None: ...  # THPModule_setAllowFP16AccumulationCuBLAS
def _get_sm_carveout_experimental() -> Optional[_int]: ...
def _set_sm_carveout_experimental(arg: Optional[_int]) -> None: ...
def _set_conj(x: Tensor, conj: _bool) -> None: ...
def _set_neg(x: Tensor, neg: _bool) -> None: ...
def _set_meta_in_tls_dispatch_include(meta_in_tls: _bool) -> None: ...
@@ -639,6 +639,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
        "torch._C._get_privateuse1_backend_name",
        "torch._C._get_qengine",
        "torch._C._get_schema",
        "torch._C._get_sm_carveout_experimental",
        "torch._C._get_nested_int",
        "torch._C._get_tensor_metadata",
        "torch._C._get_tracing_state",
@@ -1157,6 +1158,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
        "torch._C._set_math_sdp_allow_fp16_bf16_reduction",
        "torch._C._set_sdp_use_mem_efficient",
        "torch._C._set_should_use_format_with_string_table",
        "torch._C._set_sm_carveout_experimental",
        "torch._C._set_storage_access_error_msg",
        "torch._C._set_tensor_metadata",
        "torch._C._set_tracing_state",
@@ -2262,6 +2262,14 @@ Call this whenever a new thread is created in order to propagate values from
        return at::globalContext().getROCmFAPreferredBackend();
      });

  py_module.def(
      "_set_sm_carveout_experimental", [](std::optional<int32_t> val) {
        at::globalContext()._setSMCarveout_EXPERIMENTAL(val);
      });
  py_module.def("_get_sm_carveout_experimental", []() {
    return at::globalContext()._SMCarveout_EXPERIMENTAL();
  });

  py_module.def(
      "_construct_storage_from_data_pointer",
      [](int64_t data_ptr, c10::Device device, size_t size_bytes) {
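Since the carveout is process-global state, callers will likely want to scope it. A small sketch of a wrapper around the bindings above (the context manager itself is not part of this PR, it is only an illustration):

    import contextlib
    import torch

    @contextlib.contextmanager
    def sm_carveout(num_sms_reserved):
        # Temporarily reserve SMs for other kernels, then restore the old value.
        previous = torch._C._get_sm_carveout_experimental()
        torch._C._set_sm_carveout_experimental(num_sms_reserved)
        try:
            yield
        finally:
            torch._C._set_sm_carveout_experimental(previous)

    # Usage: run opted-in matmuls while leaving 16 SMs free for a NCCL kernel.
    # with sm_carveout(16):
    #     out = torch._scaled_mm(x_fp8, y_fp8, scale_a=sa, scale_b=sb, out_dtype=torch.bfloat16)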