mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[CUDA][cuBLAS][cuBLASLt] Opt-in unified cuBLAS + cuBLASLt workspaces (#151163)
opt-in version of https://github.com/pytorch/pytorch/pull/145130 as there was a lack of repro for the 70% forward issue `TORCH_CUBLASLT_UNIFIED_WORKSPACE=1` @izaitsevfb could you comment if it was repeatable per every forward pass, on startup, or something else? Pull Request resolved: https://github.com/pytorch/pytorch/pull/151163 Approved by: https://github.com/ngimel
This commit is contained in:
committed by
PyTorch MergeBot
parent
7310049c42
commit
dcc32ff5bf
@ -3,6 +3,7 @@
|
||||
*/
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContextLight.h>
|
||||
#include <ATen/cuda/CUDABlas.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <ATen/cuda/CUDADataType.h>
|
||||
@ -221,6 +222,48 @@ static size_t _getWorkspaceSize() {
|
||||
return workspace_size;
|
||||
}
|
||||
|
||||
void* _getUnifiedWorkspaceWithoutHandle() {
|
||||
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
auto stream = c10::cuda::getCurrentCUDAStream();
|
||||
cudaStream_t _stream = stream;
|
||||
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
|
||||
auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key);
|
||||
TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end());
|
||||
return workspace_it->second.mutable_get();
|
||||
}
|
||||
|
||||
struct CublasLtWorkspace {
|
||||
CublasLtWorkspace() {
|
||||
size = _getWorkspaceSize();
|
||||
#ifndef USE_ROCM
|
||||
static bool unified = c10::utils::check_env("TORCH_CUBLASLT_UNIFIED_WORKSPACE") == true;
|
||||
if (unified) {
|
||||
auto cublasWorkspaceSize = at::cuda::getChosenWorkspaceSize();
|
||||
if (cublasWorkspaceSize < size) {
|
||||
TORCH_WARN_ONCE("Requested unified CUBLASLT workspace size of ", size,
|
||||
" bytes exceeds CUBLAS workspace size of ", cublasWorkspaceSize,
|
||||
" bytes. Please increase CUBLAS workspace size",
|
||||
" via CUBLAS_WORKSPACE_CONFIG or decrease requested"
|
||||
" CUBLASLT_WORKSPACE_SIZE. Otherwise CUBLASLT workspace"
|
||||
" size will be limited to the CUBLAS workspace size.");
|
||||
size = cublasWorkspaceSize;
|
||||
}
|
||||
ptr = _getUnifiedWorkspaceWithoutHandle();
|
||||
} else {
|
||||
auto allocator = c10::cuda::CUDACachingAllocator::get();
|
||||
stashed_ptr_ = allocator->allocate(size);
|
||||
ptr = stashed_ptr_.mutable_get();
|
||||
}
|
||||
#else
|
||||
auto allocator = c10::cuda::CUDACachingAllocator::get();
|
||||
stashed_ptr_ = allocator->allocate(size);
|
||||
ptr = stashed_ptr_.mutable_get();
|
||||
#endif
|
||||
}
|
||||
at::DataPtr stashed_ptr_;
|
||||
void * ptr;
|
||||
size_t size;
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
namespace at::cuda::blas {
|
||||
@ -415,10 +458,6 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
||||
}
|
||||
|
||||
CuBlasLtMatmulPreference preference;
|
||||
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
|
||||
// setting this to 1M.
|
||||
size_t workspaceSize = _getWorkspaceSize();
|
||||
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
uint32_t a_alignment = _getAlignment(reinterpret_cast<uintptr_t>(a));
|
||||
@ -429,7 +468,9 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
||||
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
|
||||
#endif
|
||||
|
||||
auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
|
||||
auto ltworkspace = CublasLtWorkspace();
|
||||
TORCH_CHECK(ltworkspace.ptr != nullptr, "OOM trying to allocate workspace for cublaslt");
|
||||
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, ltworkspace.size);
|
||||
|
||||
cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS;
|
||||
cublasLtMatmulHeuristicResult_t heuristicResult = {};
|
||||
@ -463,8 +504,8 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
||||
c,
|
||||
Cdesc.descriptor(),
|
||||
&heuristicResult.algo,
|
||||
workspace.mutable_data_ptr(),
|
||||
workspaceSize,
|
||||
ltworkspace.ptr,
|
||||
ltworkspace.size,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
}
|
||||
if (cublasStatus != CUBLAS_STATUS_SUCCESS) {
|
||||
@ -1577,10 +1618,8 @@ bool gemm_and_bias(
|
||||
CuBlasLtMatrixLayout Cdesc(cType, m, n, result_ld);
|
||||
|
||||
CuBlasLtMatmulPreference preference;
|
||||
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
|
||||
// setting this to 1M.
|
||||
size_t workspaceSize = _getWorkspaceSize();
|
||||
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
|
||||
auto ltworkspace = CublasLtWorkspace();
|
||||
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, ltworkspace.size);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
uint32_t a_alignment = _getAlignment(reinterpret_cast<uintptr_t>(mat1_ptr));
|
||||
@ -1593,8 +1632,6 @@ bool gemm_and_bias(
|
||||
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, d_alignment);
|
||||
#endif
|
||||
|
||||
auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
|
||||
|
||||
cublasLtMatmulHeuristicResult_t heuristicResult = {};
|
||||
int returnedResult = 0;
|
||||
cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
|
||||
@ -1628,8 +1665,8 @@ bool gemm_and_bias(
|
||||
result_ptr,
|
||||
Cdesc.descriptor(),
|
||||
&heuristicResult.algo,
|
||||
workspace.mutable_data_ptr(),
|
||||
workspaceSize,
|
||||
ltworkspace.ptr,
|
||||
ltworkspace.size,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
}
|
||||
if (cublasStatus != CUBLAS_STATUS_SUCCESS) {
|
||||
@ -1855,11 +1892,10 @@ void scaled_gemm(
|
||||
#endif // if CUDA_VERSION >= 12080
|
||||
}
|
||||
|
||||
size_t workspaceSize = _getWorkspaceSize();
|
||||
auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
|
||||
|
||||
auto stream = c10::cuda::getCurrentCUDAStream();
|
||||
CuBlasLtMatmulPreference preference;
|
||||
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
|
||||
auto ltworkspace = CublasLtWorkspace();
|
||||
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, ltworkspace.size);
|
||||
cublasLtMatmulHeuristicResult_t heuristicResult = {};
|
||||
int returnedResult = 0;
|
||||
cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
|
||||
@ -1912,7 +1948,7 @@ void scaled_gemm(
|
||||
all_algos[i].algo,
|
||||
ret_workspace_size);
|
||||
if (is_valid_status == HIPBLAS_STATUS_SUCCESS) {
|
||||
if (ret_workspace_size <= workspaceSize) {
|
||||
if (ret_workspace_size <= ltworkspace.size) {
|
||||
heuristicResult = all_algos[i];
|
||||
found = true;
|
||||
break;
|
||||
@ -1940,9 +1976,9 @@ void scaled_gemm(
|
||||
result_ptr,
|
||||
Ddesc.descriptor(),
|
||||
&heuristicResult.algo,
|
||||
workspace.mutable_data_ptr(),
|
||||
workspaceSize,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
ltworkspace.ptr,
|
||||
ltworkspace.size,
|
||||
stream);
|
||||
TORCH_CHECK(
|
||||
cublasStatus == CUBLAS_STATUS_SUCCESS,
|
||||
"CUDA error: ",
|
||||
@ -2018,8 +2054,8 @@ void int8_gemm(
|
||||
CuBlasLtMatmulPreference preference;
|
||||
size_t workspaceSize = _getWorkspaceSize();
|
||||
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
|
||||
auto workspace = at::empty(workspaceSize, at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
|
||||
|
||||
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
|
||||
auto workspace = allocator.allocate(workspaceSize);
|
||||
cublasLtMatmulHeuristicResult_t heuristicResult = {};
|
||||
int returnedResult = 0;
|
||||
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
|
||||
@ -2057,7 +2093,7 @@ void int8_gemm(
|
||||
nullptr, // Heuristics don't seem to work for int8
|
||||
#endif
|
||||
#ifdef USE_ROCM
|
||||
workspace.mutable_data_ptr(),
|
||||
workspace.mutable_get(),
|
||||
#else
|
||||
nullptr, // Non-zero workspace doesn't seem to work.
|
||||
#endif
|
||||
|
@ -2,6 +2,7 @@
|
||||
// Light-weight version of CUDAContext.h with fewer transitive includes
|
||||
|
||||
#include <cstdint>
|
||||
#include <map>
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cusparse.h>
|
||||
@ -87,6 +88,8 @@ TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
|
||||
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
|
||||
|
||||
TORCH_CUDA_CPP_API void clearCublasWorkspaces();
|
||||
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace();
|
||||
TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize();
|
||||
|
||||
#if defined(CUDART_VERSION) || defined(USE_ROCM)
|
||||
TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle();
|
||||
|
@ -83,11 +83,6 @@ static hipblasStatus_t hipblasSetWorkspace_replacement(hipblasHandle_t handle, v
|
||||
|
||||
#endif
|
||||
|
||||
std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
|
||||
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void createCublasHandle(cublasHandle_t *handle) {
|
||||
TORCH_CUDABLAS_CHECK(cublasCreate(handle));
|
||||
}
|
||||
@ -109,6 +104,11 @@ using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle
|
||||
|
||||
} // namespace
|
||||
|
||||
std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
|
||||
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void clearCublasWorkspaces() {
|
||||
cublas_handle_stream_to_workspace().clear();
|
||||
}
|
||||
|
@ -3556,6 +3556,16 @@ def run(runner, args, original_dir=None):
|
||||
if args.devices == ["xpu"]:
|
||||
torch.use_deterministic_algorithms(True, warn_only=True)
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
||||
# TODO(eqy): revisit when cuBLASLt workspace size is bumped
|
||||
# if args.only is not None and args.only in {
|
||||
# "DebertaForQuestionAnswering",
|
||||
# "RobertaForQuestionAnswering",
|
||||
# "nvidia_deeprecommender",
|
||||
# "volo_d1_224",
|
||||
# }:
|
||||
# # These seem unhappy with numerics of larger cuBLASLt workspace
|
||||
# # sizes following #145130 (due to enabling split-k?)
|
||||
# torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.allow_tf32 = False
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
Reference in New Issue
Block a user