[ROCm] Fix caffe2 build with hipblasv2 api (#116073)

Summary: We need this change, together with D52244365, to keep the caffe2 build working with the hipblas v2 API.
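
The incompatibility, in brief: with hipblas v2 the compute-type argument of hipblasGemmEx is a hipblasComputeType_t (e.g. HIPBLAS_COMPUTE_32F), whereas the v1 API still passes a hipblasDatatype_t (e.g. HIPBLAS_R_32F). A minimal sketch of the selection this change introduces, mirroring the diff below (the ROCM_VERSION guard is taken from that diff):

    #if ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
      // hipblas v2: compute type is a hipblasComputeType_t value
      auto compute_type = HIPBLAS_COMPUTE_32F;
    #else
      // hipblas v1: compute type is still expressed as a data type enum
      auto compute_type = HIPBLAS_R_32F;
    #endif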

Test Plan: OSS CI

Differential Revision: D52275058

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116073
Approved by: https://github.com/jeffdaily, https://github.com/malfet
Author: Xiaodong Wang
Date: 2023-12-20 04:02:29 +00:00
Committed by: PyTorch MergeBot
Parent: a597a00c87
Commit: c72bc61bcd

@@ -44,8 +44,10 @@
 // until we use hipblas v2
 // hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
 // however hipblas v1 is still using its custom type
+#ifndef HIPBLAS_V2
 #define HIP_R_16F HIPBLAS_R_16F
 #define HIP_R_32F HIPBLAS_R_32F
+#endif // HIPBLAS_V2
 #else // USE_ROCM
 #define CUBLAS_HALF_TYPE __half
 #endif // USE_ROCM
@@ -618,6 +620,11 @@ CAFFE2_CUDA_EXPORT void Gemm<at::Half, CUDAContext>(
 // It has more general hipblasGemmEx API which is more close to cublasGemmEx.
 // hipblasGemmEx does D = alpha*op( A )*op( B ) + beta*C,
 // whereas cublasSgemmEx does C = alpha*op( A )*op( B ) + beta*C
+#if ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
+auto compute_type = HIPBLAS_COMPUTE_32F;
+#else
+auto compute_type = HIPBLAS_R_32F;
+#endif
 HIPBLAS_ENFORCE(hipblasGemmEx(
 context->hipblas_handle(),
 cu_trans_B,
@@ -636,7 +643,7 @@ CAFFE2_CUDA_EXPORT void Gemm<at::Half, CUDAContext>(
 C,
 HIPBLAS_R_16F,
 N,
-HIPBLAS_R_32F, // compute type
+compute_type,
 HIPBLAS_GEMM_DEFAULT));
 #else
 CUBLAS_ENFORCE(cublasSgemmEx(
@@ -854,6 +861,11 @@ CAFFE2_CUDA_EXPORT void GemmBatched<at::Half, CUDAContext>(
 thrust::device_vector<void*> C_device(C, C + batch_size);
 CUBLAS_ENFORCE(cublasSetPointerMode(
 context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
+#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
+auto compute_type = HIPBLAS_COMPUTE_32F;
+#else
+auto compute_type = CUDA_R_32F;
+#endif
 CUBLAS_ENFORCE(cublasGemmBatchedEx(
 context->cublas_handle(),
 cu_trans_B,
@@ -873,7 +885,7 @@ CAFFE2_CUDA_EXPORT void GemmBatched<at::Half, CUDAContext>(
 CUDA_R_16F,
 ldc,
 batch_size,
-CUDA_R_32F,
+compute_type,
 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 } else if (math_type == TensorProto_DataType_FLOAT16) {
 // Convert alpha, beta from float -> __half
@@ -945,6 +957,11 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched<at::Half, CUDAContext>(
 if (math_type == TensorProto_DataType_FLOAT) {
 CUBLAS_ENFORCE(cublasSetPointerMode(
 context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
+#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
+auto compute_type = HIPBLAS_COMPUTE_32F;
+#else
+auto compute_type = CUDA_R_32F;
+#endif
 CUBLAS_ENFORCE(cublasGemmStridedBatchedEx(
 context->cublas_handle(),
 cu_trans_B,
@@ -967,7 +984,7 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched<at::Half, CUDAContext>(
 ldc,
 C_stride,
 batch_size,
-CUDA_R_32F,
+compute_type,
 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 } else if (math_type == TensorProto_DataType_FLOAT16) {
 // Convert alpha, beta from float -> __half
@@ -1059,6 +1076,11 @@ CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext>(
 // It has more general hipblasGemmEx API which is more close to cublasGemmEx.
 // hipblasGemmEx does D = alpha*op( A )*op( B ) + beta*C,
 // whereas cublasSgemmEx does C = alpha*op( A )*op( B ) + beta*C
+#if ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
+auto compute_type = HIPBLAS_COMPUTE_32F;
+#else
+auto compute_type = HIPBLAS_R_32F;
+#endif
 HIPBLAS_ENFORCE(hipblasGemmEx(
 context->hipblas_handle(),
 cu_trans_A,
@@ -1077,7 +1099,7 @@ CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext>(
 y,
 HIPBLAS_R_16F,
 ldc,
-HIPBLAS_R_32F, // compute type
+compute_type,
 HIPBLAS_GEMM_DEFAULT));
 #else
 CUBLAS_ENFORCE(cublasSgemmEx(