diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index 13716736c577..6933099bb1f3 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -16,6 +16,8 @@
 #include
 #include
+#include <ATen/cuda/detail/BLASConstants.h>
+
 
 #ifdef USE_ROCM
 #include
 #include
@@ -1954,13 +1956,15 @@ void scaled_gemm(
     const void *result_scale_ptr,
     int64_t result_ld,
     ScalarType result_dtype,
-    bool use_fast_accum) {
+    bool use_fast_accum,
+    const std::optional<Tensor>& alpha) {
   // Note: see `cublasCommonArgs` for various non-intuitive manupulations
   // of input arguments to this function.
   const auto computeType = CUBLAS_COMPUTE_32F;
   const auto scaleType = CUDA_R_32F;
-  const float alpha_val = 1.0;
-  const float beta_val = 0.0;
+  // Note: alpha_val may change later depending on the user-passed argument
+  float alpha_val = 1.0;
+  float beta_val = 0.0;
   CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, _cublasOpFromChar(transa));
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
@@ -2031,6 +2035,33 @@ void scaled_gemm(
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_BIAS);
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
   }
+
+  // Handle user-passed alpha
+  float *alpha_ptr = &alpha_val;
+  float *beta_ptr = &beta_val;
+
+  if (alpha.has_value()) {
+    auto& a = alpha.value();
+
+    // if device-tensor
+    if (a.is_cuda()) {
+      // NOTE: there are lifetime requirements on device-side pointers for alpha/beta -- the value must
+      // remain valid and correct until the cublas call *finishes* (not merely until it is scheduled,
+      // as with host-side values). Thus we need allocations for alpha/beta with lifetime guarantees: a
+      // statically managed 4B buffer for alpha that we copy the passed alpha value into, and constant
+      // memory for beta, respectively.
+      float *user_alpha_ptr = at::cuda::detail::get_user_alpha_ptr();
+      at::Tensor user_alpha = at::from_blob(user_alpha_ptr, {1}, TensorOptions().device(kCUDA).dtype(kFloat));
+      user_alpha.copy_(a);
+      // Tell cublasLt we're using device-side pointers for alpha/beta
+      auto pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
+      computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_POINTER_MODE, pointer_mode);
+      alpha_ptr = user_alpha.data_ptr<float>();
+      beta_ptr = at::cuda::detail::get_cublas_device_zero();
+    } else {
+      alpha_val = a.item<float>();
+    }
+  }
   // For other data types, use the get_scale_mode function based on scaling type
   // The SCALE_MODE attrs only exist in cuBLAS 12.8+/ROCm 7.0 or in recent hipblaslt,
   // but we must invoke get_scale_mode anyways to trigger the version checks.
@@ -2048,6 +2079,7 @@ void scaled_gemm(
   cublasLtMatmulHeuristicResult_t heuristicResult = {};
   int returnedResult = 0;
   cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
+
   TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
       ltHandle,
       computeDesc.descriptor(),
@@ -2088,10 +2120,10 @@
       auto is_valid_status = hipblaslt_ext::matmulIsAlgoSupported(
           ltHandle,
           computeDesc.descriptor(),
-          &alpha_val,
+          alpha_ptr,
           Adesc.descriptor(),
           Bdesc.descriptor(),
-          &beta_val,
+          beta_ptr,
           Cdesc.descriptor(),
           Ddesc.descriptor(),
           all_algos[i].algo,
@@ -2110,17 +2142,14 @@
       cublasStatus_t cublasStatus = cublasLtMatmul(
           ltHandle,
           computeDesc.descriptor(),
-          &alpha_val,
+          alpha_ptr,
           mat1_ptr,
           Adesc.descriptor(),
           mat2_ptr,
           Bdesc.descriptor(),
-          &beta_val,
-#ifdef USE_ROCM
+          beta_ptr,
+          // NOTE: always use result_ptr here, because cuBLASLt with a device-side beta of 0 can't handle nullptr either
           result_ptr, // unused, since beta_val is 0, but hipblaslt can't handle nullptr
-#else
-          nullptr,
-#endif // ifdef USE_ROCM
           Cdesc.descriptor(),
           result_ptr,
           Ddesc.descriptor(),
diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h
index 6618658704a7..0295948311a5 100644
--- a/aten/src/ATen/cuda/CUDABlas.h
+++ b/aten/src/ATen/cuda/CUDABlas.h
@@ -161,7 +161,8 @@ void scaled_gemm(
     const void* result_scale_ptr,
     int64_t result_ld,
     ScalarType result_dtype,
-    bool use_fast_accum);
+    bool use_fast_accum,
+    const std::optional<Tensor>& alpha);

 #define CUDABLAS_BGEMM_ARGTYPES(Dtype) CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype)

diff --git a/aten/src/ATen/cuda/detail/BLASConstants.cu b/aten/src/ATen/cuda/detail/BLASConstants.cu
new file mode 100644
index 000000000000..967388044705
--- /dev/null
+++ b/aten/src/ATen/cuda/detail/BLASConstants.cu
@@ -0,0 +1,54 @@
+#include <ATen/cuda/detail/BLASConstants.h>
+#include <ATen/cuda/Exceptions.h>
+#include <c10/util/CallOnce.h>
+
+#include <cuda_runtime.h>
+
+namespace at {
+namespace cuda {
+namespace detail {
+
+__device__ __constant__ float cublas_one_device;
+__device__ __constant__ float cublas_zero_device;
+
+float *get_cublas_device_one() {
+  static c10::once_flag init_flag;
+
+  c10::call_once(init_flag, []() {
+    const float one = 1.f;
+    AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_one_device, &one, sizeof(float)));
+  });
+
+  float *ptr;
+  AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void **>(&ptr), cublas_one_device));
+  return ptr;
+}
+
+float *get_cublas_device_zero() {
+  static c10::once_flag init_flag;
+
+  c10::call_once(init_flag, []() {
+    const float zero = 0.f;
+    AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_zero_device, &zero, sizeof(float)));
+  });
+
+  float *ptr;
+  AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void **>(&ptr), cublas_zero_device));
+  return ptr;
+}
+
+float *get_user_alpha_ptr() {
+  static float *alpha_ptr;
+
+  static c10::once_flag init_flag;
+
+  c10::call_once(init_flag, []() {
+    AT_CUDA_CHECK(cudaMalloc(&alpha_ptr, sizeof(float)));
+  });
+
+  return alpha_ptr;
+}
+
+} // namespace detail
+} // namespace cuda
+} // namespace at
diff --git a/aten/src/ATen/cuda/detail/BLASConstants.h b/aten/src/ATen/cuda/detail/BLASConstants.h
new file mode 100644
index 000000000000..d62aaf1330ee
--- /dev/null
+++ b/aten/src/ATen/cuda/detail/BLASConstants.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include
+
+namespace at::cuda::detail {
+
+float *get_cublas_device_one();
+float *get_cublas_device_zero();
+float *get_user_alpha_ptr();
+
+} // namespace at::cuda::detail
diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h
index d941c230630c..c014d1ea569c 100644
--- a/aten/src/ATen/cuda/tunable/TunableGemm.h
+++ b/aten/src/ATen/cuda/tunable/TunableGemm.h
@@ -109,7 +109,8 @@ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
         params->c_scale_ptr,
         params->ldc,
         params->c_dtype,
-        params->use_fast_accum);
+        params->use_fast_accum,
+        std::nullopt /* alpha */);
     return OK;
   }
 };
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 1e7c4600efc5..4ee35013ab77 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -1359,7 +1359,8 @@ _scaled_gemm(
     const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
     const std::optional<Tensor>& bias,
     const bool use_fast_accum,
-    Tensor& out) {
+    Tensor& out,
+    const std::optional<Tensor>& alpha = std::nullopt) {
   cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b);
   const auto out_dtype_ = args.result->scalar_type();
   TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
@@ -1410,7 +1411,8 @@ _scaled_gemm(
        args.scale_result_ptr,
        args.result_ld,
        out_dtype_,
-        use_fast_accum);
+        use_fast_accum,
+        alpha);
    return out;
  }
 }
diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py
index 54442fe403e9..d1d9a08c71c5 100644
--- a/torch/utils/hipify/cuda_to_hip_mappings.py
+++ b/torch/utils/hipify/cuda_to_hip_mappings.py
@@ -7702,8 +7702,11 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
     ("CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", ("HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", CONV_MATH_FUNC, API_BLAS)),
     ("CUBLASLT_MATMUL_DESC_A_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_A_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)),
     ("CUBLASLT_MATMUL_DESC_B_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_B_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_MATMUL_DESC_POINTER_MODE", ("HIPBLASLT_MATMUL_DESC_POINTER_MODE", CONV_MATH_FUNC, API_BLAS)),
     ("CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0", ("HIPBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0", CONV_MATH_FUNC, API_BLAS)),
     ("CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3", ("HIPBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3", CONV_MATH_FUNC, API_BLAS)),
+    ("CUBLASLT_POINTER_MODE_DEVICE", ("HIPBLASLT_POINTER_MODE_DEVICE", CONV_NUMERIC_LITERAL, API_BLAS)),
+    ("CUBLASLT_POINTER_MODE_HOST", ("HIPBLASLT_POINTER_MODE_HOST", CONV_NUMERIC_LITERAL, API_BLAS)),
     ("cublasLtMatrixLayout_t", ("hipblasLtMatrixLayout_t", CONV_MATH_FUNC, API_BLAS)),
     ("cublasLtMatrixLayoutOpaque_t", ("hipblasLtMatrixLayoutOpaque_t", CONV_MATH_FUNC, API_BLAS)),
     ("cublasLtMatrixLayoutAttribute_t", ("hipblasLtMatrixLayoutAttribute_t", CONV_MATH_FUNC, API_BLAS)),
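
Reviewer note -- illustrative sketch, not part of the patch. The device-alpha path above copies the
user's alpha into a statically managed 4-byte device buffer and serves beta=0 from __constant__ memory,
so both pointers remain valid until cublasLtMatmul has actually finished reading them, as required by
CUBLASLT_POINTER_MODE_DEVICE. The snippet below restates that staging logic as a standalone helper;
the function name is hypothetical and it assumes an ATen build containing this change.

#include <ATen/ATen.h>
#include <ATen/cuda/detail/BLASConstants.h>

// Hypothetical helper mirroring the device-alpha staging done inside scaled_gemm().
void stage_device_alpha_example(const at::Tensor& alpha /* 1-element float CUDA tensor */) {
  // Copy the caller's alpha into the statically managed device slot; the slot's lifetime is
  // process-wide, which satisfies cublasLt's lifetime requirement for device-side scalars.
  float* slot = at::cuda::detail::get_user_alpha_ptr();
  at::Tensor staged = at::from_blob(
      slot, {1}, at::TensorOptions().device(at::kCUDA).dtype(at::kFloat));
  staged.copy_(alpha);  // device-to-device copy on the current stream

  // beta == 0 is served from __constant__ memory and needs no per-call staging.
  float* device_zero = at::cuda::detail::get_cublas_device_zero();
  (void)device_zero;  // would be passed to cublasLtMatmul together with CUBLASLT_POINTER_MODE_DEVICE
}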