diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu
index 3a587638d337..98f090ab416b 100644
--- a/caffe2/utils/math_gpu.cu
+++ b/caffe2/utils/math_gpu.cu
@@ -44,8 +44,10 @@
 // until we use hipblas v2
 // hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
 // however hipblas v1 is still using its custom type
+#ifndef HIPBLAS_V2
 #define HIP_R_16F HIPBLAS_R_16F
 #define HIP_R_32F HIPBLAS_R_32F
+#endif // HIPBLAS_V2
 #else // USE_ROCM
 #define CUBLAS_HALF_TYPE __half
 #endif // USE_ROCM
@@ -618,6 +620,11 @@ CAFFE2_CUDA_EXPORT void Gemm(
     // It has more general hipblasGemmEx API which is more close to cublasGemmEx.
     // hipblasGemmEx does D = alpha*op( A )*op( B ) + beta*C,
     // whereas cublasSgemmEx does C = alpha*op( A )*op( B ) + beta*C
+#if ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
+    auto compute_type = HIPBLAS_COMPUTE_32F;
+#else
+    auto compute_type = HIPBLAS_R_32F;
+#endif
     HIPBLAS_ENFORCE(hipblasGemmEx(
         context->hipblas_handle(),
         cu_trans_B,
@@ -636,7 +643,7 @@ CAFFE2_CUDA_EXPORT void Gemm(
         C,
         HIPBLAS_R_16F,
         N,
-        HIPBLAS_R_32F, // compute type
+        compute_type,
         HIPBLAS_GEMM_DEFAULT));
 #else
     CUBLAS_ENFORCE(cublasSgemmEx(
@@ -854,6 +861,11 @@ CAFFE2_CUDA_EXPORT void GemmBatched(
     thrust::device_vector<void*> C_device(C, C + batch_size);
     CUBLAS_ENFORCE(cublasSetPointerMode(
         context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
+#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
+    auto compute_type = HIPBLAS_COMPUTE_32F;
+#else
+    auto compute_type = CUDA_R_32F;
+#endif
     CUBLAS_ENFORCE(cublasGemmBatchedEx(
         context->cublas_handle(),
         cu_trans_B,
@@ -873,7 +885,7 @@ CAFFE2_CUDA_EXPORT void GemmBatched(
         CUDA_R_16F,
         ldc,
         batch_size,
-        CUDA_R_32F,
+        compute_type,
         CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   } else if (math_type == TensorProto_DataType_FLOAT16) {
     // Convert alpha, beta from float -> __half
@@ -945,6 +957,11 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched(
   if (math_type == TensorProto_DataType_FLOAT) {
     CUBLAS_ENFORCE(cublasSetPointerMode(
         context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
+#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
+    auto compute_type = HIPBLAS_COMPUTE_32F;
+#else
+    auto compute_type = CUDA_R_32F;
+#endif
     CUBLAS_ENFORCE(cublasGemmStridedBatchedEx(
         context->cublas_handle(),
         cu_trans_B,
@@ -967,7 +984,7 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched(
         ldc,
         C_stride,
         batch_size,
-        CUDA_R_32F,
+        compute_type,
         CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   } else if (math_type == TensorProto_DataType_FLOAT16) {
     // Convert alpha, beta from float -> __half
@@ -1059,6 +1076,11 @@ CAFFE2_CUDA_EXPORT void Gemv(
     // It has more general hipblasGemmEx API which is more close to cublasGemmEx.
     // hipblasGemmEx does D = alpha*op( A )*op( B ) + beta*C,
     // whereas cublasSgemmEx does C = alpha*op( A )*op( B ) + beta*C
+#if ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
+    auto compute_type = HIPBLAS_COMPUTE_32F;
+#else
+    auto compute_type = HIPBLAS_R_32F;
+#endif
     HIPBLAS_ENFORCE(hipblasGemmEx(
         context->hipblas_handle(),
         cu_trans_A,
@@ -1077,7 +1099,7 @@ CAFFE2_CUDA_EXPORT void Gemv(
         y,
         HIPBLAS_R_16F,
         ldc,
-        HIPBLAS_R_32F, // compute type
+        compute_type,
         HIPBLAS_GEMM_DEFAULT));
 #else
     CUBLAS_ENFORCE(cublasSgemmEx(