[cuBLAS][cuBLASLt] Unify cuBLASLt workspaces with cuBLAS workspaces (#145130)

As `cuBLAS` workspaces are already per-stream, there shouldn't be kernel execution overlap with `cuBLASLt` kernels.

This PR reuses `cuBLAS` workspaces for `cuBLASLt` for the following benefits:

+ caching (`cuBLAS` workspaces were already cached, so now we get that for `cuBLASLt`)
+ "free" workspace size bump for `cuBLASLt` `cuBLASLt` workspace sizes were previously smaller than those for `cuBLAS` by default which potentially hurts performance, and we encountered difficulty in increasing the size due to downstream OOMs , see also #120925
+ fixes behavior broken behavior with the memtracker; https://github.com/pytorch/pytorch/pull/139442 attempted to handle peaky allocation behavior that broke memtracker equivalence tests but it didn't seem to fully work, here the cached/reused `cuBLAS` workspace seems to fix it
+ one environment variable to rule them all: `CUBLAS_WORKSPACE_CONFIG` applies directly to `cuBLASLt` without a confusing `CUBLASLT_WORKSPACE_SIZE` that users would also need to consider

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145130
Approved by: https://github.com/ngimel
This commit is contained in:
eqy
2025-03-22 05:50:11 +00:00
committed by PyTorch MergeBot
parent 51fa8fb0ff
commit 8f7fbe3d7d
4 changed files with 65 additions and 26 deletions

View File

@ -3,6 +3,7 @@
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContextLight.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/CUDADataType.h>
@ -221,6 +222,36 @@ static size_t _getWorkspaceSize() {
return workspace_size;
}
void* _getWorkspaceWithoutHandle() {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
auto stream = c10::cuda::getCurrentCUDAStream();
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key);
TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end());
return workspace_it->second.mutable_get();
}
void* _getWorkspace(size_t& workspaceSize) {
// #ifdef (defined(USE_ROCM) || defined(FBCODE_CAFFE2))
workspaceSize = _getWorkspaceSize();
auto cublasWorkspaceSize = at::cuda::getChosenWorkspaceSize();
if (cublasWorkspaceSize < workspaceSize) {
TORCH_WARN_ONCE("Requested CUBLASLT workspace size of ", workspaceSize,
" bytes exceeds CUBLAS workspace size of ", cublasWorkspaceSize,
" bytes. Please increase CUBLAS workspace size",
" via CUBLAS_WORKSPACE_CONFIG or decrease requested"
" CUBLASLT_WORKSPACE_SIZE. Otherwise CUBLASLT workspace"
" size will be limited to the CUBLAS workspace size.");
workspaceSize = cublasWorkspaceSize;
}
// #else
// workspaceSize = at::cuda::getChosenWorkspaceSize();
// #endif
auto workspace_ptr = _getWorkspaceWithoutHandle();
return workspace_ptr;
}
} // anonymous namespace
namespace at::cuda::blas {
@ -410,9 +441,8 @@ static inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
}
CuBlasLtMatmulPreference preference;
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
// setting this to 1M.
size_t workspaceSize = _getWorkspaceSize();
size_t workspaceSize = 0;
auto workspace_ptr = _getWorkspace(workspaceSize);
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
#ifndef USE_ROCM
@ -424,8 +454,6 @@ static inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
#endif
auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
@ -457,7 +485,7 @@ static inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
c,
Cdesc.descriptor(),
&heuristicResult.algo,
workspace.mutable_data_ptr(),
workspace_ptr,
workspaceSize,
at::cuda::getCurrentCUDAStream());
TORCH_CHECK(
@ -1357,9 +1385,8 @@ void gemm_and_bias(
CuBlasLtMatrixLayout Cdesc(abcType, m, n, result_ld);
CuBlasLtMatmulPreference preference;
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
// setting this to 1M.
size_t workspaceSize = _getWorkspaceSize();
size_t workspaceSize = 0;
auto workspace_ptr = _getWorkspace(workspaceSize);
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
#ifndef USE_ROCM
@ -1373,8 +1400,7 @@ void gemm_and_bias(
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, d_alignment);
#endif
auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
auto stream = c10::cuda::getCurrentCUDAStream();
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
@ -1407,9 +1433,9 @@ void gemm_and_bias(
result_ptr,
Cdesc.descriptor(),
&heuristicResult.algo,
workspace.mutable_data_ptr(),
workspace_ptr,
workspaceSize,
at::cuda::getCurrentCUDAStream());
stream);
TORCH_CHECK(
cublasStatus == CUBLAS_STATUS_SUCCESS,
"CUDA error: ",
@ -1586,9 +1612,9 @@ void scaled_gemm(
#endif // CUDA_VERSION >= 12080
}
size_t workspaceSize = _getWorkspaceSize();
auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
auto stream = c10::cuda::getCurrentCUDAStream();
size_t workspaceSize = 0;
auto workspace_ptr = _getWorkspace(workspaceSize);
CuBlasLtMatmulPreference preference;
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
cublasLtMatmulHeuristicResult_t heuristicResult = {};
@ -1671,9 +1697,9 @@ void scaled_gemm(
result_ptr,
Ddesc.descriptor(),
&heuristicResult.algo,
workspace.mutable_data_ptr(),
workspace_ptr,
workspaceSize,
at::cuda::getCurrentCUDAStream());
stream);
TORCH_CHECK(
cublasStatus == CUBLAS_STATUS_SUCCESS,
"CUDA error: ",
@ -1749,8 +1775,8 @@ void int8_gemm(
CuBlasLtMatmulPreference preference;
size_t workspaceSize = _getWorkspaceSize();
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
auto workspace = at::empty(workspaceSize, at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto workspace = allocator.allocate(workspaceSize);
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
@ -1788,7 +1814,7 @@ void int8_gemm(
nullptr, // Heuristics don't seem to work for int8
#endif
#ifdef USE_ROCM
workspace.mutable_data_ptr(),
workspace.mutable_get(),
#else
nullptr, // Non-zero workspace doesn't seem to work.
#endif

View File

@ -2,6 +2,7 @@
// Light-weight version of CUDAContext.h with fewer transitive includes
#include <cstdint>
#include <map>
#include <cuda_runtime_api.h>
#include <cusparse.h>
@ -87,6 +88,8 @@ TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
TORCH_CUDA_CPP_API void clearCublasWorkspaces();
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace();
TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize();
#if defined(CUDART_VERSION) || defined(USE_ROCM)
TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle();

View File

@ -83,11 +83,6 @@ static hipblasStatus_t hipblasSetWorkspace_replacement(hipblasHandle_t handle, v
#endif
std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
return instance;
}
void createCublasHandle(cublasHandle_t *handle) {
TORCH_CUDABLAS_CHECK(cublasCreate(handle));
}
@ -109,6 +104,11 @@ using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle
} // namespace
std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
return instance;
}
void clearCublasWorkspaces() {
cublas_handle_stream_to_workspace().clear();
}

View File

@ -3592,6 +3592,16 @@ def run(runner, args, original_dir=None):
# some of the models do not support use_deterministic_algorithms
torch.use_deterministic_algorithms(True)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# TODO(eqy): revisit when cuBLASLt workspace size is bumped
# if args.only is not None and args.only in {
# "DebertaForQuestionAnswering",
# "RobertaForQuestionAnswering",
# "nvidia_deeprecommender",
# "volo_d1_224",
# }:
# # These seem unhappy with numerics of larger cuBLASLt workspace
# # sizes following #145130 (due to enabling split-k?)
# torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.benchmark = False