Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[ROCm] Fix caffe2 build with hipblasv2 api (#116073)
Summary: We need this change, along with D52244365, to make the caffe2 build happy.

Test Plan: OSS CI

Differential Revision: D52275058

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116073
Approved by: https://github.com/jeffdaily, https://github.com/malfet
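The underlying API change: hipBLAS v2, which ships with ROCm 6.0, takes the GEMM compute type as a hipblasComputeType_t value such as HIPBLAS_COMPUTE_32F, while the v1 API took a hipblasDatatype_t value such as HIPBLAS_R_32F. A minimal sketch of the selection pattern this patch repeats at each GEMM call site (assuming a ROCm toolchain; ROCM_VERSION is the macro the caffe2 build defines, and kFloatComputeType is an illustrative name, not one from the patch):

#include <hipblas/hipblas.h>  // <hipblas.h> on older ROCm installs

// hipblas v2 expects hipblasComputeType_t for the compute-type argument,
// v1 expects hipblasDatatype_t, so the same hipblasGemmEx call site needs
// a version-dependent constant for FP32 accumulation.
#if defined(ROCM_VERSION) && ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
static constexpr auto kFloatComputeType = HIPBLAS_COMPUTE_32F;
#else
static constexpr auto kFloatComputeType = HIPBLAS_R_32F;
#endif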
commit c72bc61bcd
parent a597a00c87
committed by PyTorch MergeBot
@@ -44,8 +44,10 @@
 // until we use hipblas v2
 // hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
 // however hipblas v1 is still using its custom type
+#ifndef HIPBLAS_V2
 #define HIP_R_16F HIPBLAS_R_16F
 #define HIP_R_32F HIPBLAS_R_32F
+#endif // HIPBLAS_V2
 #else // USE_ROCM
 #define CUBLAS_HALF_TYPE __half
 #endif // USE_ROCM
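For context on the new #ifndef above: hipify rewrites CUDA_R_16F to HIP_R_16F in the source. Under hipblas v1 that name exists only through caffe2's macro alias to HIPBLAS_R_16F; under v2, HIP_R_16F is a real hipDataType enumerator, and leaving the old #define in place would clobber it. A hedged illustration of the two regimes (the alias blas_data_t is illustrative, not from the patch):

// Which enum carries the A/B/C data-type arguments of hipblasGemmEx:
#ifdef HIPBLAS_V2
using blas_data_t = hipDataType;        // v2: HIP_R_16F etc. are real enumerators
#else
using blas_data_t = hipblasDatatype_t;  // v1: only HIPBLAS_R_16F exists natively,
                                        // hence the HIP_R_16F -> HIPBLAS_R_16F define
#endif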
@@ -618,6 +620,11 @@ CAFFE2_CUDA_EXPORT void Gemm<at::Half, CUDAContext>(
   // It has more general hipblasGemmEx API which is more close to cublasGemmEx.
   // hipblasGemmEx does D = alpha*op( A )*op( B ) + beta*C,
   // whereas cublasSgemmEx does C = alpha*op( A )*op( B ) + beta*C
+#if ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
+  auto compute_type = HIPBLAS_COMPUTE_32F;
+#else
+  auto compute_type = HIPBLAS_R_32F;
+#endif
   HIPBLAS_ENFORCE(hipblasGemmEx(
       context->hipblas_handle(),
       cu_trans_B,
@@ -636,7 +643,7 @@ CAFFE2_CUDA_EXPORT void Gemm<at::Half, CUDAContext>(
       C,
       HIPBLAS_R_16F,
       N,
-      HIPBLAS_R_32F, // compute type
+      compute_type,
       HIPBLAS_GEMM_DEFAULT));
 #else
   CUBLAS_ENFORCE(cublasSgemmEx(
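With both Gemm hunks applied, the version-dependent compute_type lands in the penultimate hipblasGemmEx argument. Roughly what the resulting FP16-input, FP32-accumulate call shape looks like (a sketch under the patch's assumptions, not the file's exact code; the function and its parameters are illustrative, and kFloatComputeType is the constant sketched above):

// Row-major M x K (A) times K x N (B) into M x N (C), fp16 data, fp32 math.
// The column-major BLAS is asked for C^T = B^T * A^T, hence the swapped order.
void half_gemm_fp32_acc(hipblasHandle_t handle,
                        int M, int N, int K,
                        const void* A, const void* B,  // device fp16 buffers
                        float alpha, float beta,
                        void* C) {                     // device fp16 buffer
  hipblasStatus_t st = hipblasGemmEx(
      handle,
      HIPBLAS_OP_N,
      HIPBLAS_OP_N,
      N, M, K,
      &alpha,
      B, HIPBLAS_R_16F, N,  // HIPBLAS_R_16F stays valid under v2 (alias of HIP_R_16F)
      A, HIPBLAS_R_16F, K,
      &beta,
      C, HIPBLAS_R_16F, N,
      kFloatComputeType,    // HIPBLAS_COMPUTE_32F (v2) or HIPBLAS_R_32F (v1)
      HIPBLAS_GEMM_DEFAULT);
  (void)st;  // real code checks st; caffe2 wraps the call in HIPBLAS_ENFORCE
}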
@@ -854,6 +861,11 @@ CAFFE2_CUDA_EXPORT void GemmBatched<at::Half, CUDAContext>(
     thrust::device_vector<void*> C_device(C, C + batch_size);
     CUBLAS_ENFORCE(cublasSetPointerMode(
         context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
+#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
+    auto compute_type = HIPBLAS_COMPUTE_32F;
+#else
+    auto compute_type = CUDA_R_32F;
+#endif
     CUBLAS_ENFORCE(cublasGemmBatchedEx(
         context->cublas_handle(),
         cu_trans_B,
@@ -873,7 +885,7 @@ CAFFE2_CUDA_EXPORT void GemmBatched<at::Half, CUDAContext>(
         CUDA_R_16F,
         ldc,
         batch_size,
-        CUDA_R_32F,
+        compute_type,
         CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   } else if (math_type == TensorProto_DataType_FLOAT16) {
     // Convert alpha, beta from float -> __half
@@ -945,6 +957,11 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched<at::Half, CUDAContext>(
   if (math_type == TensorProto_DataType_FLOAT) {
     CUBLAS_ENFORCE(cublasSetPointerMode(
         context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
+#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
+    auto compute_type = HIPBLAS_COMPUTE_32F;
+#else
+    auto compute_type = CUDA_R_32F;
+#endif
     CUBLAS_ENFORCE(cublasGemmStridedBatchedEx(
         context->cublas_handle(),
         cu_trans_B,
@@ -967,7 +984,7 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched<at::Half, CUDAContext>(
         ldc,
         C_stride,
         batch_size,
-        CUDA_R_32F,
+        compute_type,
         CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   } else if (math_type == TensorProto_DataType_FLOAT16) {
     // Convert alpha, beta from float -> __half
@@ -1059,6 +1076,11 @@ CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext>(
   // It has more general hipblasGemmEx API which is more close to cublasGemmEx.
   // hipblasGemmEx does D = alpha*op( A )*op( B ) + beta*C,
   // whereas cublasSgemmEx does C = alpha*op( A )*op( B ) + beta*C
+#if ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
+  auto compute_type = HIPBLAS_COMPUTE_32F;
+#else
+  auto compute_type = HIPBLAS_R_32F;
+#endif
   HIPBLAS_ENFORCE(hipblasGemmEx(
       context->hipblas_handle(),
       cu_trans_A,
@@ -1077,7 +1099,7 @@ CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext>(
       y,
       HIPBLAS_R_16F,
       ldc,
-      HIPBLAS_R_32F, // compute type
+      compute_type,
       HIPBLAS_GEMM_DEFAULT));
 #else
   CUBLAS_ENFORCE(cublasSgemmEx(