[CUDA][cuBLAS][cuBLASLt] Opt-in unified cuBLAS + cuBLASLt workspaces (#151163)

opt-in version of https://github.com/pytorch/pytorch/pull/145130 as there was a lack of repro for the 70% forward issue
`TORCH_CUBLASLT_UNIFIED_WORKSPACE=1`

@izaitsevfb could you comment if it was repeatable per every forward pass, on startup, or something else?

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151163
Approved by: https://github.com/ngimel
This commit is contained in:
Eddie Yan
2025-04-23 15:24:22 +00:00
committed by PyTorch MergeBot
parent 7310049c42
commit dcc32ff5bf
4 changed files with 80 additions and 31 deletions

View File

@ -3,6 +3,7 @@
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContextLight.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/CUDADataType.h>
@ -221,6 +222,48 @@ static size_t _getWorkspaceSize() {
return workspace_size;
}
void* _getUnifiedWorkspaceWithoutHandle() {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
auto stream = c10::cuda::getCurrentCUDAStream();
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key);
TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end());
return workspace_it->second.mutable_get();
}
struct CublasLtWorkspace {
CublasLtWorkspace() {
size = _getWorkspaceSize();
#ifndef USE_ROCM
static bool unified = c10::utils::check_env("TORCH_CUBLASLT_UNIFIED_WORKSPACE") == true;
if (unified) {
auto cublasWorkspaceSize = at::cuda::getChosenWorkspaceSize();
if (cublasWorkspaceSize < size) {
TORCH_WARN_ONCE("Requested unified CUBLASLT workspace size of ", size,
" bytes exceeds CUBLAS workspace size of ", cublasWorkspaceSize,
" bytes. Please increase CUBLAS workspace size",
" via CUBLAS_WORKSPACE_CONFIG or decrease requested"
" CUBLASLT_WORKSPACE_SIZE. Otherwise CUBLASLT workspace"
" size will be limited to the CUBLAS workspace size.");
size = cublasWorkspaceSize;
}
ptr = _getUnifiedWorkspaceWithoutHandle();
} else {
auto allocator = c10::cuda::CUDACachingAllocator::get();
stashed_ptr_ = allocator->allocate(size);
ptr = stashed_ptr_.mutable_get();
}
#else
auto allocator = c10::cuda::CUDACachingAllocator::get();
stashed_ptr_ = allocator->allocate(size);
ptr = stashed_ptr_.mutable_get();
#endif
}
at::DataPtr stashed_ptr_;
void * ptr;
size_t size;
};
} // anonymous namespace
namespace at::cuda::blas {
@ -415,10 +458,6 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
}
CuBlasLtMatmulPreference preference;
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
// setting this to 1M.
size_t workspaceSize = _getWorkspaceSize();
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
#ifndef USE_ROCM
uint32_t a_alignment = _getAlignment(reinterpret_cast<uintptr_t>(a));
@ -429,7 +468,9 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
#endif
auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
auto ltworkspace = CublasLtWorkspace();
TORCH_CHECK(ltworkspace.ptr != nullptr, "OOM trying to allocate workspace for cublaslt");
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, ltworkspace.size);
cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulHeuristicResult_t heuristicResult = {};
@ -463,8 +504,8 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
c,
Cdesc.descriptor(),
&heuristicResult.algo,
workspace.mutable_data_ptr(),
workspaceSize,
ltworkspace.ptr,
ltworkspace.size,
at::cuda::getCurrentCUDAStream());
}
if (cublasStatus != CUBLAS_STATUS_SUCCESS) {
@ -1577,10 +1618,8 @@ bool gemm_and_bias(
CuBlasLtMatrixLayout Cdesc(cType, m, n, result_ld);
CuBlasLtMatmulPreference preference;
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
// setting this to 1M.
size_t workspaceSize = _getWorkspaceSize();
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
auto ltworkspace = CublasLtWorkspace();
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, ltworkspace.size);
#ifndef USE_ROCM
uint32_t a_alignment = _getAlignment(reinterpret_cast<uintptr_t>(mat1_ptr));
@ -1593,8 +1632,6 @@ bool gemm_and_bias(
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, d_alignment);
#endif
auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
@ -1628,8 +1665,8 @@ bool gemm_and_bias(
result_ptr,
Cdesc.descriptor(),
&heuristicResult.algo,
workspace.mutable_data_ptr(),
workspaceSize,
ltworkspace.ptr,
ltworkspace.size,
at::cuda::getCurrentCUDAStream());
}
if (cublasStatus != CUBLAS_STATUS_SUCCESS) {
@ -1855,11 +1892,10 @@ void scaled_gemm(
#endif // if CUDA_VERSION >= 12080
}
size_t workspaceSize = _getWorkspaceSize();
auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
auto stream = c10::cuda::getCurrentCUDAStream();
CuBlasLtMatmulPreference preference;
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
auto ltworkspace = CublasLtWorkspace();
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, ltworkspace.size);
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
@ -1912,7 +1948,7 @@ void scaled_gemm(
all_algos[i].algo,
ret_workspace_size);
if (is_valid_status == HIPBLAS_STATUS_SUCCESS) {
if (ret_workspace_size <= workspaceSize) {
if (ret_workspace_size <= ltworkspace.size) {
heuristicResult = all_algos[i];
found = true;
break;
@ -1940,9 +1976,9 @@ void scaled_gemm(
result_ptr,
Ddesc.descriptor(),
&heuristicResult.algo,
workspace.mutable_data_ptr(),
workspaceSize,
at::cuda::getCurrentCUDAStream());
ltworkspace.ptr,
ltworkspace.size,
stream);
TORCH_CHECK(
cublasStatus == CUBLAS_STATUS_SUCCESS,
"CUDA error: ",
@ -2018,8 +2054,8 @@ void int8_gemm(
CuBlasLtMatmulPreference preference;
size_t workspaceSize = _getWorkspaceSize();
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
auto workspace = at::empty(workspaceSize, at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto workspace = allocator.allocate(workspaceSize);
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
@ -2057,7 +2093,7 @@ void int8_gemm(
nullptr, // Heuristics don't seem to work for int8
#endif
#ifdef USE_ROCM
workspace.mutable_data_ptr(),
workspace.mutable_get(),
#else
nullptr, // Non-zero workspace doesn't seem to work.
#endif

View File

@ -2,6 +2,7 @@
// Light-weight version of CUDAContext.h with fewer transitive includes
#include <cstdint>
#include <map>
#include <cuda_runtime_api.h>
#include <cusparse.h>
@ -87,6 +88,8 @@ TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
TORCH_CUDA_CPP_API void clearCublasWorkspaces();
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace();
TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize();
#if defined(CUDART_VERSION) || defined(USE_ROCM)
TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle();

View File

@ -83,11 +83,6 @@ static hipblasStatus_t hipblasSetWorkspace_replacement(hipblasHandle_t handle, v
#endif
std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
return instance;
}
void createCublasHandle(cublasHandle_t *handle) {
TORCH_CUDABLAS_CHECK(cublasCreate(handle));
}
@ -109,6 +104,11 @@ using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle
} // namespace
std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
return instance;
}
void clearCublasWorkspaces() {
cublas_handle_stream_to_workspace().clear();
}

View File

@ -3556,6 +3556,16 @@ def run(runner, args, original_dir=None):
if args.devices == ["xpu"]:
torch.use_deterministic_algorithms(True, warn_only=True)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# TODO(eqy): revisit when cuBLASLt workspace size is bumped
# if args.only is not None and args.only in {
# "DebertaForQuestionAnswering",
# "RobertaForQuestionAnswering",
# "nvidia_deeprecommender",
# "volo_d1_224",
# }:
# # These seem unhappy with numerics of larger cuBLASLt workspace
# # sizes following #145130 (due to enabling split-k?)
# torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.benchmark = False