[ROCm] use hipblas instead of rocblas (#105881)

- BatchLinearAlgebraLib.cpp is split into two files:
  - BatchLinearAlgebraLib.cpp now uses only cusolver APIs
  - the new BatchLinearAlgebraLibBlas.cpp uses only cublas APIs
  - this split is necessary because hipify operates at the file level and cannot mix cusolver and cublas APIs within the same file
- cmake changes to link against hipblas instead of rocblas
- hipify mapping changes to map cublas -> hipblas instead of rocblas (see the sketch below)
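As a rough sketch of what the mapping change means for a hipified call site (hypothetical snippet, not taken from this diff): a cuBLAS call written in ATen such as

    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc);

was previously rewritten for the ROCm build as

    rocblas_sgemm(handle, rocblas_operation_none, rocblas_operation_none, m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc);

and is now rewritten as

    hipblasSgemm(handle, HIPBLAS_OP_N, HIPBLAS_OP_N, m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc);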

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105881
Approved by: https://github.com/albanD
Author: Jeff Daily
Date: 2023-07-31 20:42:52 +00:00
Committed by: PyTorch MergeBot
Parent: c9c66819a1
Commit: 5379b5f927
13 changed files with 975 additions and 1057 deletions

View File

@ -499,11 +499,6 @@ endif()
endif(MSVC)
endif(USE_MAGMA)
# NB: We're relying on cmake/Dependencies.cmake to appropriately setup HIP dependencies.
# In principle we could duplicate them, but handling the rocblas
# dependency is nontrivial. So better not to copy-paste.
# Look for Note [rocblas cmake bug]
# Include CPU paths for CUDA/HIP as well
list(APPEND ATen_CUDA_INCLUDE ${ATen_CPU_INCLUDE})
list(APPEND ATen_HIP_INCLUDE ${ATen_CPU_INCLUDE})

View File

@ -16,8 +16,68 @@
#endif
#ifdef USE_ROCM
// until hipblas has an API to accept flags, we must use rocblas here
#include <rocblas/rocblas.h>
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
// helpers needed because we call the rocblas API directly instead of the hipblas API
static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op)
{
switch(op)
{
case HIPBLAS_OP_N:
return rocblas_operation_none;
case HIPBLAS_OP_T:
return rocblas_operation_transpose;
case HIPBLAS_OP_C:
return rocblas_operation_conjugate_transpose;
}
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
{
switch(error)
{
case rocblas_status_size_unchanged:
case rocblas_status_size_increased:
case rocblas_status_success:
return HIPBLAS_STATUS_SUCCESS;
case rocblas_status_invalid_handle:
return HIPBLAS_STATUS_NOT_INITIALIZED;
case rocblas_status_not_implemented:
return HIPBLAS_STATUS_NOT_SUPPORTED;
case rocblas_status_invalid_pointer:
case rocblas_status_invalid_size:
case rocblas_status_invalid_value:
return HIPBLAS_STATUS_INVALID_VALUE;
case rocblas_status_memory_error:
return HIPBLAS_STATUS_ALLOC_FAILED;
case rocblas_status_internal_error:
return HIPBLAS_STATUS_INTERNAL_ERROR;
}
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
// hipblas does not have hipblasSetMathMode
#define hipblasSetMathMode(handle, flags) HIPBLAS_STATUS_SUCCESS
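// (hipified cublasSetMathMode calls in this file therefore become a successful no-op on ROCm)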
// until we use hipblas v2
// hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
// however hipblas v1 is still using its custom type
#define HIP_R_16F HIPBLAS_R_16F
#define HIP_R_32F HIPBLAS_R_32F
#define HIP_R_64F HIPBLAS_R_64F
#define HIP_C_16F HIPBLAS_C_16F
#define HIP_C_32F HIPBLAS_C_32F
#define HIP_C_64F HIPBLAS_C_64F
#define HIP_R_8I HIPBLAS_R_8I
#define HIP_R_8U HIPBLAS_R_8U
#define HIP_R_32I HIPBLAS_R_32I
#define HIP_R_32U HIPBLAS_R_32U
#define HIP_C_8I HIPBLAS_C_8I
#define HIP_C_8U HIPBLAS_C_8U
#define HIP_C_32I HIPBLAS_C_32I
#define HIP_C_32U HIPBLAS_C_32U
#define HIP_R_16BF HIPBLAS_R_16B
#define HIP_C_16BF HIPBLAS_C_16B
#endif
#define CUDABLAS_POSINT_CHECK(FD, X) \
@ -193,6 +253,8 @@ cublasStatus_t cublasGemmStridedBatchedExFix(cublasHandle_t &handle,
return result;
}
#endif
#else // USE_ROCM
#define cublasGemmStridedBatchedExFix hipblasGemmStridedBatchedEx
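// on ROCm there is no equivalent workaround wrapper; the name is aliased directly to hipblasGemmStridedBatchedEx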
#endif
#define GEMM_CHECK_ARGVALUES(Dtype) \
@ -288,13 +350,15 @@ void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
flag = at::ROCmBackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_strided_batched_ex((rocblas_handle)handle,
hipOperationToRocOperation(opa),
hipOperationToRocOperation(opb), (int)m, (int)n, (int)k,
(void*)&falpha, a, rocblas_datatype_f16_r, (int)lda, stridea,
b, rocblas_datatype_f16_r, (int)ldb, strideb,
(void*)&fbeta, c, rocblas_datatype_f16_r, (int)ldc, stridec,
c, rocblas_datatype_f16_r, (int)ldc, stridec,
(int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
0, flag));
0, flag)));
#else
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 5){
@ -329,22 +393,12 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
const float fbeta = beta;
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if !defined(USE_ROCM)
TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix(handle,
opa, opb, (int)m, (int)n, (int)k,
(void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea,
b, CUDA_R_16BF, (int)ldb, strideb,
(void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
(int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
#else
TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k,
(void*)&falpha, a, rocblas_datatype_bf16_r, (int)lda, stridea,
b, rocblas_datatype_bf16_r, (int)ldb, strideb,
(void*)&fbeta, c, rocblas_datatype_bf16_r, (int)ldc, stridec,
c, rocblas_datatype_bf16_r, (int)ldc, stridec,
(int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
0, 0, NULL, NULL));
#endif // !defined(USE_ROCM)
TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix(handle,
opa, opb, (int)m, (int)n, (int)k,
(void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea,
b, CUDA_R_16BF, (int)ldb, strideb,
(void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
(int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
template <>
@ -373,39 +427,35 @@ void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
}
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
template <>
void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(c10::complex<double>);
TORCH_CUDABLAS_CHECK(cublasZgemm(
handle, opa, opb, m, n, k, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
lda, reinterpret_cast<const cuDoubleComplex*>(b), ldb, reinterpret_cast<const cuDoubleComplex*>(&beta),
reinterpret_cast<cuDoubleComplex*>(c), ldc));
}
#endif
template <>
void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(c10::complex<double>);
TORCH_CUDABLAS_CHECK(cublasZgemm(
handle, opa, opb, m, n, k, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
lda, reinterpret_cast<const cuDoubleComplex*>(b), ldb, reinterpret_cast<const cuDoubleComplex*>(&beta),
reinterpret_cast<cuDoubleComplex*>(c), ldc));
}
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
template <>
void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(c10::complex<float>);
TORCH_CUDABLAS_CHECK(cublasCgemm(
handle, opa, opb, m, n, k, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
lda, reinterpret_cast<const cuComplex*>(b), ldb, reinterpret_cast<const cuComplex*>(&beta),
reinterpret_cast<cuComplex*>(c), ldc));
}
#endif
template <>
void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(c10::complex<float>);
TORCH_CUDABLAS_CHECK(cublasCgemm(
handle, opa, opb, m, n, k, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
lda, reinterpret_cast<const cuComplex*>(b), ldb, reinterpret_cast<const cuComplex*>(&beta),
reinterpret_cast<cuComplex*>(c), ldc));
}
template <>
void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
@ -423,10 +473,10 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
flag = at::ROCmBackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
handle,
opa,
opb,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(
(rocblas_handle)handle,
hipOperationToRocOperation(opa),
hipOperationToRocOperation(opb),
m,
n,
k,
@ -447,14 +497,16 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
flag));
flag)));
#else
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 5) {
#ifndef USE_ROCM
cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
if (!at::globalContext().allowFP16ReductionCuBLAS()) {
cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
}
#endif
// Disallow fp16 reductions that could lead to unexpected overflow issues.
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
TORCH_CUDABLAS_CHECK(cublasGemmEx(
@ -501,45 +553,6 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
#endif
}
#ifdef USE_ROCM
template <>
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
float falpha = alpha;
float fbeta = beta;
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(at::BFloat16);
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
handle,
opa,
opb,
m,
n,
k,
&falpha,
a,
rocblas_datatype_bf16_r,
lda,
b,
rocblas_datatype_bf16_r,
ldb,
&fbeta,
c,
rocblas_datatype_bf16_r,
ldc,
c,
rocblas_datatype_bf16_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0));
}
#endif
#if !defined(USE_ROCM)
template <>
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
globalContext().alertCuBLASConfigNotDeterministic();
@ -550,10 +563,12 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
float fbeta = beta;
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(at::BFloat16);
#ifndef USE_ROCM
cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
if (!at::globalContext().allowBF16ReductionCuBLAS()) {
cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
}
#endif
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
TORCH_CUDABLAS_CHECK(cublasGemmEx(
handle,
@ -577,7 +592,6 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}
#endif // !defined(USE_ROCM)
#if !defined(USE_ROCM) && !defined(_MSC_VER)
@ -1113,23 +1127,20 @@ void trsmBatched<c10::complex<double>>(
CUDABLAS_POSINT_CHECK(gemv<Dtype>, incy); \
} while (0)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
template <>
void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t op = _cublasOpFromChar(trans);
_cublasAdjustLdLevel2(m, n, &lda);
GEMV_CHECK_ARGVALUES(c10::complex<double>);
TORCH_CUDABLAS_CHECK(
cublasZgemv(handle, op, m, n, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
lda, reinterpret_cast<const cuDoubleComplex*>(x), incx, reinterpret_cast<const cuDoubleComplex*>(&beta),
reinterpret_cast<cuDoubleComplex*>(y), incy));
}
#endif
template <>
void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t op = _cublasOpFromChar(trans);
_cublasAdjustLdLevel2(m, n, &lda);
GEMV_CHECK_ARGVALUES(c10::complex<double>);
TORCH_CUDABLAS_CHECK(
cublasZgemv(handle, op, m, n, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
lda, reinterpret_cast<const cuDoubleComplex*>(x), incx, reinterpret_cast<const cuDoubleComplex*>(&beta),
reinterpret_cast<cuDoubleComplex*>(y), incy));
}
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
template <>
void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>)) {
// gemv is bw bound, and does not benefit from TF32. But the precision
@ -1146,7 +1157,6 @@ void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>)) {
lda, reinterpret_cast<const cuComplex*>(x), incx, reinterpret_cast<const cuComplex*>(&beta),
reinterpret_cast<cuComplex*>(y), incy));
}
#endif
template <>
void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double)) {
@ -1247,7 +1257,6 @@ void dot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>)) {
template <>
void dot<at::Half>(CUDABLAS_DOT_ARGTYPES(at::Half)) {
#if !defined(USE_ROCM)
TORCH_CUDABLAS_CHECK(cublasDotEx(
handle,
n,
@ -1260,23 +1269,10 @@ void dot<at::Half>(CUDABLAS_DOT_ARGTYPES(at::Half)) {
result,
CUDA_R_16F,
CUDA_R_32F));
#elif defined(ROCM_VERSION) && ROCM_VERSION >= 21000
TORCH_CUDABLAS_CHECK(rocblas_hdot(
handle,
n,
reinterpret_cast<const rocblas_half*>(x),
incx,
reinterpret_cast<const rocblas_half*>(y),
incy,
reinterpret_cast<rocblas_half*>(result)));
#else
AT_ERROR("Cublas_Hdot requires CUDA 8.0+");
#endif
}
template <>
void dot<at::BFloat16>(CUDABLAS_DOT_ARGTYPES(at::BFloat16)) {
#if !defined(USE_ROCM)
TORCH_CUDABLAS_CHECK(cublasDotEx(
handle,
n,
@ -1289,18 +1285,6 @@ void dot<at::BFloat16>(CUDABLAS_DOT_ARGTYPES(at::BFloat16)) {
result,
CUDA_R_16BF,
CUDA_R_32F));
#elif defined(ROCM_VERSION) && ROCM_VERSION >= 21000
TORCH_CUDABLAS_CHECK(rocblas_bfdot(
handle,
n,
reinterpret_cast<const rocblas_bfloat16*>(x),
incx,
reinterpret_cast<const rocblas_bfloat16*>(y),
incy,
reinterpret_cast<rocblas_bfloat16*>(result)));
#else
AT_ERROR("Cublas_bfdot requires CUDA 11.0+");
#endif
}
template <>
@ -1317,9 +1301,6 @@ void vdot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>)) {
reinterpret_cast<cuDoubleComplex*>(result)));
}
// This guard blocks use of getrsBatched, geqrfBatched, getrfBatched on platforms other than cuda
#ifdef CUDART_VERSION
template <>
void getrsBatched<float>(CUDABLAS_GETRS_ARGTYPES(float)) {
TORCH_CUDABLAS_CHECK(cublasSgetrsBatched(
@ -1519,8 +1500,6 @@ void gelsBatched<c10::complex<float>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::comple
batchSize));
}
#endif // CUDART_VERSION
} // namespace blas
} // namespace cuda
} // namespace at

View File

@ -55,14 +55,10 @@ template <>
void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double));
template <>
void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float));
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
template <>
void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>));
#endif
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
template <>
void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>));
#endif
template <>
void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>));
template <>
void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>));
template <>
void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
template <>
@ -189,12 +185,10 @@ template <>
void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double));
template <>
void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float));
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
template <>
void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>));
template <>
void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>));
#endif
template <>
void gemv<at::Half>(CUDABLAS_GEMV_ARGTYPES(at::Half));
template <>
@ -234,9 +228,6 @@ void vdot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>));
template <>
void vdot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>));
// This guard blocks use of getrsBatched, geqrfBatched, getrfBatched on platforms other than cuda
#ifdef CUDART_VERSION
#define CUDABLAS_GETRS_ARGTYPES(Dtype) \
cublasHandle_t handle, cublasOperation_t trans, \
int n, int nrhs, Dtype** dA_array, int lda, int* ipiv_array, \
@ -311,8 +302,6 @@ TORCH_CUDA_CU_API void gelsBatched<c10::complex<double>>(CUDABLAS_GELS_BATCHED_A
template<>
TORCH_CUDA_CU_API void gelsBatched<c10::complex<float>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<float>));
#endif // CUDART_VERSION
} // namespace blas
} // namespace cuda
} // namespace at

View File

@ -125,14 +125,14 @@ cublasHandle_t getCurrentCUDABlasHandle() {
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}
#endif
#if defined(USE_ROCM) && ROCM_VERSION >= 30800
rocblas_atomics_mode rocblas_mode;
#if defined(USE_ROCM)
hipblasAtomicsMode_t hipblas_mode;
if (at::globalContext().deterministicAlgorithms()) {
rocblas_mode = rocblas_atomics_not_allowed;
hipblas_mode = HIPBLAS_ATOMICS_NOT_ALLOWED;
} else {
rocblas_mode = rocblas_atomics_allowed;
hipblas_mode = HIPBLAS_ATOMICS_ALLOWED;
}
TORCH_CUDABLAS_CHECK(rocblas_set_atomics_mode(handle, rocblas_mode));
TORCH_CUDABLAS_CHECK(hipblasSetAtomicsMode(handle, hipblas_mode));
#endif
return handle;
}

View File

@ -1,3 +1,4 @@
// See Note [BatchLinearAlgebraLib split implementation files]
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Context.h>
#include <ATen/cuda/CUDAContext.h>
@ -29,7 +30,7 @@
namespace at::native {
cublasOperation_t to_cublas(TransposeType trans) {
static cublasOperation_t to_cublas(TransposeType trans) {
switch (trans) {
case TransposeType::NoTranspose: return CUBLAS_OP_N;
case TransposeType::Transpose: return CUBLAS_OP_T;
@ -57,102 +58,6 @@ static Tensor get_device_pointers(const Tensor& input) {
input.options().dtype(at::kLong));
}
template <typename scalar_t>
void apply_geqrf_batched(const Tensor& input, const Tensor& tau) {
// AMD ROCm backend is implemented via rewriting all CUDA calls to HIP
// rocBLAS does not implement BLAS-like extensions of cuBLAS, they're in rocSOLVER
// rocSOLVER is currently not used in ATen, therefore we raise an error in this case
#ifndef CUDART_VERSION
TORCH_CHECK(false, "geqrf: Batched version is supported only with cuBLAS backend.")
#else
auto batch_size = cuda_int_cast(batchCount(input), "batch_size");
auto m = cuda_int_cast(input.size(-2), "m");
auto n = cuda_int_cast(input.size(-1), "n");
auto lda = std::max<int>(1, m);
// cuBLAS batched geqrf requires input to be the device array of pointers to device single matrices
Tensor input_ptr_array = get_device_pointers<scalar_t>(input);
Tensor tau_ptr_array = get_device_pointers<scalar_t>(tau.unsqueeze(-1));
auto input_ptr_array_data = reinterpret_cast<scalar_t**>(input_ptr_array.data_ptr());
auto tau_ptr_array_data = reinterpret_cast<scalar_t**>(tau_ptr_array.data_ptr());
int info;
auto handle = at::cuda::getCurrentCUDABlasHandle();
at::cuda::blas::geqrfBatched(handle, m, n, input_ptr_array_data, lda, tau_ptr_array_data, &info, batch_size);
// info only indicates wrong arguments to geqrfBatched call
// info is a host variable, we can check it without device synchronization
TORCH_INTERNAL_ASSERT(info == 0);
#endif
}
void geqrf_batched_cublas(const Tensor& input, const Tensor& tau) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "geqrf_batched_cuda", [&]{
apply_geqrf_batched<scalar_t>(input, tau);
});
}
template <typename scalar_t>
static void apply_lu_factor_batched_cublas(const Tensor& A, const Tensor& pivots, const Tensor& infos, bool get_pivots) {
#ifndef CUDART_VERSION
TORCH_CHECK(false, "linalg.lu_factor: cuBLAS backend for linalg.lu_factor is not available.")
#else
// This function just works with square matrices
TORCH_INTERNAL_ASSERT(A.size(-2) == A.size(-1));
auto batch_size = cuda_int_cast(batchCount(A), "batch_size");;
auto n = cuda_int_cast(A.size(-2), "n");
auto lda = cuda_int_cast(std::max<int>(1, n), "lda");
auto pivots_data = get_pivots ? pivots.data_ptr<int>() : nullptr;
auto infos_data = infos.data_ptr<int>();
Tensor a_ptr_array = get_device_pointers<scalar_t>(A);
auto a_ptr_array_data = reinterpret_cast<scalar_t**>(a_ptr_array.data_ptr());
at::cuda::blas::getrfBatched(n, a_ptr_array_data, lda, pivots_data, infos_data, batch_size);
#endif
}
void lu_factor_batched_cublas(const Tensor& A, const Tensor& pivots, const Tensor& infos, bool get_pivots) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "lu_factor_cublas", [&]{
apply_lu_factor_batched_cublas<scalar_t>(A, pivots, infos, get_pivots);
});
}
template <typename scalar_t>
static void apply_lu_solve_batched_cublas(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose) {
#ifndef CUDART_VERSION
TORCH_CHECK(false, "linalg.lu_solve: cuBLAS backend for linalg.lu_solve is not available.")
#else
TORCH_INTERNAL_ASSERT(batchCount(LU) == batchCount(B), "batch_size of LU and B must be the same");
TORCH_INTERNAL_ASSERT(batchCount(LU) == batchCount(pivots.unsqueeze(-1)), "batch_size of LU and pivots must be the same");
const auto trans = to_cublas(transpose);
auto pivots_data = pivots.data_ptr<int>();
auto batch_size = cuda_int_cast(batchCount(LU), "batch_size");;
auto m = cuda_int_cast(LU.size(-2), "m");
auto nrhs = cuda_int_cast(B.size(-1), "nrhs");
auto lda = cuda_int_cast(std::max<int>(1, m), "lda");
int info = 0;
Tensor lu_ptr_array = get_device_pointers<scalar_t>(LU);
Tensor b_ptr_array = get_device_pointers<scalar_t>(B);
auto lu_ptr_array_data = reinterpret_cast<scalar_t**>(lu_ptr_array.data_ptr());
auto b_ptr_array_data = reinterpret_cast<scalar_t**>(b_ptr_array.data_ptr());
auto handle = at::cuda::getCurrentCUDABlasHandle();
at::cuda::blas::getrsBatched(handle, trans, m, nrhs, lu_ptr_array_data,
lda, pivots_data, b_ptr_array_data, lda, &info, batch_size);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
#endif
}
void lu_solve_batched_cublas(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(LU.scalar_type(), "lu_solve_cublas", [&]{
apply_lu_solve_batched_cublas<scalar_t>(LU, pivots, B, trans);
});
}
namespace {
template <typename scalar_t>
@ -331,162 +236,6 @@ void ldl_solve_cusolver(
});
}
template <typename scalar_t>
static void apply_triangular_solve(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
#ifdef ROCM_VERSION
// Cannot auto-hipify this piece of code, because in other functions the uplo
// and other variables need to be hipSOLVER's type.
auto uplo = upper ? rocblas_fill_upper : rocblas_fill_lower;
const auto trans = (rocblas_operation)to_cublas(transpose);
rocblas_side side = left ? rocblas_side_left : rocblas_side_right;
#else
cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
const auto trans = to_cublas(transpose);
cublasSideMode_t side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
#endif
cublasDiagType_t diag = unitriangular ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
auto A_data = A.data_ptr<scalar_t>();
auto B_data = B.data_ptr<scalar_t>();
auto A_mat_stride = matrixStride(A);
auto B_mat_stride = matrixStride(B);
auto batch_size = batchCount(A);
// This allows to pass rectangular A and B when left = True
auto m = cuda_int_cast(left ? A.size(-1) : B.size(-2), "m");
auto n = cuda_int_cast(B.size(-1), "n");
auto lda = std::max<int>(1, cuda_int_cast(A.size(-2), "lda"));
auto ldb = std::max<int>(1, cuda_int_cast(B.size(-2), "ldb"));
auto alpha = scalar_t{1};
for (decltype(batch_size) i = 0; i < batch_size; i++) {
scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
scalar_t* B_working_ptr = &B_data[i * B_mat_stride];
auto handle = at::cuda::getCurrentCUDABlasHandle();
at::cuda::blas::trsm(handle, side, uplo, trans, diag, m, n, &alpha, A_working_ptr, lda, B_working_ptr, ldb);
}
}
void triangular_solve_cublas(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cuda", [&]{
apply_triangular_solve<scalar_t>(A, B, left, upper, transpose, unitriangular);
});
}
template <typename scalar_t>
static void apply_triangular_solve_batched(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
#ifdef ROCM_VERSION
// Cannot auto-hipify this piece of code, because in other functions the uplo
// and other variables need to be hipSOLVER's type.
auto uplo = upper ? rocblas_fill_upper : rocblas_fill_lower;
const auto trans = (rocblas_operation)to_cublas(transpose);
rocblas_side side = left ? rocblas_side_left : rocblas_side_right;
#else
cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
const auto trans = to_cublas(transpose);
cublasSideMode_t side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
#endif
cublasDiagType_t diag = unitriangular ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
auto batch_size = cuda_int_cast(batchCount(A), "batch_size");
// This allows to pass rectangular A and B when left = True
auto m = cuda_int_cast(left ? A.size(-1) : B.size(-2), "m");
auto n = cuda_int_cast(B.size(-1), "n");
auto lda = std::max<int>(1, cuda_int_cast(A.size(-2), "lda"));
auto ldb = std::max<int>(1, cuda_int_cast(B.size(-2), "ldb"));
auto alpha = scalar_t{1};
// cuBLAS batched trsm requires input to be the device array of pointers to device single matrices
Tensor A_ptr_array = get_device_pointers<scalar_t>(A);
Tensor B_ptr_array = get_device_pointers<scalar_t>(B);
auto A_ptr_array_data = reinterpret_cast<scalar_t**>(A_ptr_array.data_ptr());
auto B_ptr_array_data = reinterpret_cast<scalar_t**>(B_ptr_array.data_ptr());
auto handle = at::cuda::getCurrentCUDABlasHandle();
at::cuda::blas::trsmBatched(handle, side, uplo, trans, diag, m, n, &alpha, A_ptr_array_data, lda, B_ptr_array_data, ldb, batch_size);
}
void triangular_solve_batched_cublas(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cuda", [&]{
apply_triangular_solve_batched<scalar_t>(A, B, left, upper, transpose, unitriangular);
});
}
template <typename scalar_t>
inline void apply_gels_batched(const Tensor& A, Tensor& B, Tensor& infos) {
// AMD ROCm backend is implemented via rewriting all CUDA calls to HIP
// rocBLAS does not implement BLAS-like extensions of cuBLAS, they're in rocSOLVER
// rocSOLVER is currently not used in ATen, therefore we raise an error in this case
#ifndef CUDART_VERSION
TORCH_CHECK(false, "torch.linalg.lstsq: Batched version is supported only with cuBLAS backend.")
#else
#ifdef ROCM_VERSION
// Cannot auto-hipify this piece of code, because in other functions
// CUBLAS_OP_N must be translated to HIPSOLVER_OP_N
auto trans = rocblas_operation_none;
#else
auto trans = CUBLAS_OP_N;
#endif
auto m = cuda_int_cast(A.size(-2), "m");
auto n = cuda_int_cast(A.size(-1), "n");
auto nrhs = cuda_int_cast(B.size(-1), "nrhs");
// cuBLAS from cuda10 and older doesn't work with nrhs == 0 (cuda11 works)
// so we need to put this early return
if (nrhs == 0) {
return;
}
auto batch_size = cuda_int_cast(batchCount(B), "batch_size");
auto lda = std::max<int>(1, m);
auto ldb = std::max<int>(1, m);
// cuBLAS's requirement
TORCH_CHECK(
m >= n,
"torch.linalg.lstsq: only overdetermined systems (input.size(-2) >= input.size(-1)) are allowed on CUDA with cuBLAS backend.");
// cuBLAS documentation says:
// Matrices Aarray[i] should not overlap; otherwise, undefined behavior is expected.
// explicitly broadcast the batch dimensions of A
IntArrayRef A_batch_sizes(A.sizes().data(), A.dim() - 2);
IntArrayRef B_batch_sizes(B.sizes().data(), B.dim() - 2);
std::vector<int64_t> expand_batch_portion = at::infer_size(A_batch_sizes, B_batch_sizes);
expand_batch_portion.insert(expand_batch_portion.end(), {A.size(-2), A.size(-1)});
Tensor A_expanded = A.expand({expand_batch_portion});
Tensor A_broadcasted = cloneBatchedColumnMajor(A_expanded);
// cuBLAS batched gels requires input to be the device array of pointers to device single matrices
Tensor A_ptr_array = get_device_pointers<scalar_t>(A_broadcasted);
Tensor B_ptr_array = get_device_pointers<scalar_t>(B);
auto A_ptr_array_data = reinterpret_cast<scalar_t**>(A_ptr_array.data_ptr());
auto B_ptr_array_data = reinterpret_cast<scalar_t**>(B_ptr_array.data_ptr());
auto infos_data = infos.data_ptr<int>();
auto handle = at::cuda::getCurrentCUDABlasHandle();
int info;
at::cuda::blas::gelsBatched<scalar_t>(
handle, trans, m, n, nrhs,
A_ptr_array_data, lda,
B_ptr_array_data, ldb,
&info,
infos_data,
batch_size);
// negative info indicates that an argument to gelsBatched call is invalid
TORCH_INTERNAL_ASSERT(info == 0);
#endif
}
// This is a type dispatching helper function for 'apply_gels_batched'
void gels_batched_cublas(const Tensor& a, Tensor& b, Tensor& infos) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(a.scalar_type(), "gels_batched_cublas", [&]{
apply_gels_batched<scalar_t>(a, b, infos);
});
}
#if defined(USE_LINALG_SOLVER)
inline static Tensor column_major_identity_matrix_like(const Tensor& self) {

View File

@ -0,0 +1,287 @@
// Note [BatchLinearAlgebraLib split implementation files]
//
// There are two files that implement the interfaces found in
// BatchLinearAlgebraLib.h
// - BatchLinearAlgebraLib.cpp
// - BatchLinearAlgebraLibBlas.cpp (this file)
//
// In order to support the ROCm build target, the use of cublas and
// cusolver APIs needed to be split into separate source files to
// accommodate the hipify step of the ROCm build process.
//
// To create this current file, the original file
// BatchLinearAlgebraLib.cpp was copied to
// BatchLinearAlgebraLibBlas.cpp, then any functions that used cusolver
// APIs were removed. Similarly, in the original file
// BatchLinearAlgebraLib.cpp, any use of cublas APIs was removed.
// The net result is a split of the BatchLinearAlgebraLib
// implementation files. The original file BatchLinearAlgebraLib.cpp
// contains the full, original git history for both files.
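//
// For example (illustrative, not exhaustive): the cublas-backed entry points
// geqrf_batched_cublas, lu_factor_batched_cublas, lu_solve_batched_cublas,
// triangular_solve_cublas, triangular_solve_batched_cublas, and gels_batched_cublas
// are defined in this file, while cusolver-backed routines such as
// ldl_solve_cusolver remain in BatchLinearAlgebraLib.cpp.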
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Context.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/cuda/PinnedMemoryAllocator.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/irange.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/TransposeType.h>
#include <ATen/native/cuda/MiscUtils.h>
#include <ATen/native/cuda/linalg/CUDASolver.h>
#include <ATen/native/cuda/linalg/BatchLinearAlgebraLib.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/nan_to_num.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/where.h>
#include <ATen/ops/zeros.h>
#endif
namespace at::native {
static cublasOperation_t to_cublas(TransposeType trans) {
switch (trans) {
case TransposeType::NoTranspose: return CUBLAS_OP_N;
case TransposeType::Transpose: return CUBLAS_OP_T;
case TransposeType::ConjTranspose: return CUBLAS_OP_C;
}
TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
}
// Some cuBLAS and cuSOLVER batched routines require input to be a device array of pointers to device individual matrices
// 'input' must be a contiguous tensor
template <typename scalar_t>
static Tensor get_device_pointers(const Tensor& input) {
auto input_data = input.const_data_ptr<scalar_t>();
int64_t input_mat_stride = matrixStride(input);
// cublas/cusolver interface requires 'int'
int batch_size = cuda_int_cast(batchCount(input), "batch_size");
// if batch_size==0, then start=0 and end=0
// if input_mat_stride==0, then step=sizeof(scalar_t)
return at::arange(
/*start=*/reinterpret_cast<int64_t>(input_data),
/*end=*/reinterpret_cast<int64_t>(input_data + batch_size * input_mat_stride),
/*step=*/static_cast<int64_t>(std::max<int64_t>(input_mat_stride, 1) * sizeof(scalar_t)),
input.options().dtype(at::kLong));
}
template <typename scalar_t>
void apply_geqrf_batched(const Tensor& input, const Tensor& tau) {
auto batch_size = cuda_int_cast(batchCount(input), "batch_size");
auto m = cuda_int_cast(input.size(-2), "m");
auto n = cuda_int_cast(input.size(-1), "n");
auto lda = std::max<int>(1, m);
// cuBLAS batched geqrf requires input to be the device array of pointers to device single matrices
Tensor input_ptr_array = get_device_pointers<scalar_t>(input);
Tensor tau_ptr_array = get_device_pointers<scalar_t>(tau.unsqueeze(-1));
auto input_ptr_array_data = reinterpret_cast<scalar_t**>(input_ptr_array.data_ptr());
auto tau_ptr_array_data = reinterpret_cast<scalar_t**>(tau_ptr_array.data_ptr());
int info;
auto handle = at::cuda::getCurrentCUDABlasHandle();
at::cuda::blas::geqrfBatched(handle, m, n, input_ptr_array_data, lda, tau_ptr_array_data, &info, batch_size);
// info only indicates wrong arguments to geqrfBatched call
// info is a host variable, we can check it without device synchronization
TORCH_INTERNAL_ASSERT(info == 0);
}
void geqrf_batched_cublas(const Tensor& input, const Tensor& tau) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "geqrf_batched_cuda", [&]{
apply_geqrf_batched<scalar_t>(input, tau);
});
}
template <typename scalar_t>
static void apply_lu_factor_batched_cublas(const Tensor& A, const Tensor& pivots, const Tensor& infos, bool get_pivots) {
// This function just works with square matrices
TORCH_INTERNAL_ASSERT(A.size(-2) == A.size(-1));
auto batch_size = cuda_int_cast(batchCount(A), "batch_size");
auto n = cuda_int_cast(A.size(-2), "n");
auto lda = cuda_int_cast(std::max<int>(1, n), "lda");
auto pivots_data = get_pivots ? pivots.data_ptr<int>() : nullptr;
auto infos_data = infos.data_ptr<int>();
Tensor a_ptr_array = get_device_pointers<scalar_t>(A);
auto a_ptr_array_data = reinterpret_cast<scalar_t**>(a_ptr_array.data_ptr());
at::cuda::blas::getrfBatched(n, a_ptr_array_data, lda, pivots_data, infos_data, batch_size);
}
void lu_factor_batched_cublas(const Tensor& A, const Tensor& pivots, const Tensor& infos, bool get_pivots) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "lu_factor_cublas", [&]{
apply_lu_factor_batched_cublas<scalar_t>(A, pivots, infos, get_pivots);
});
}
template <typename scalar_t>
static void apply_lu_solve_batched_cublas(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose) {
TORCH_INTERNAL_ASSERT(batchCount(LU) == batchCount(B), "batch_size of LU and B must be the same");
TORCH_INTERNAL_ASSERT(batchCount(LU) == batchCount(pivots.unsqueeze(-1)), "batch_size of LU and pivots must be the same");
const auto trans = to_cublas(transpose);
auto pivots_data = pivots.data_ptr<int>();
auto batch_size = cuda_int_cast(batchCount(LU), "batch_size");
auto m = cuda_int_cast(LU.size(-2), "m");
auto nrhs = cuda_int_cast(B.size(-1), "nrhs");
auto lda = cuda_int_cast(std::max<int>(1, m), "lda");
int info = 0;
Tensor lu_ptr_array = get_device_pointers<scalar_t>(LU);
Tensor b_ptr_array = get_device_pointers<scalar_t>(B);
auto lu_ptr_array_data = reinterpret_cast<scalar_t**>(lu_ptr_array.data_ptr());
auto b_ptr_array_data = reinterpret_cast<scalar_t**>(b_ptr_array.data_ptr());
auto handle = at::cuda::getCurrentCUDABlasHandle();
at::cuda::blas::getrsBatched(handle, trans, m, nrhs, lu_ptr_array_data,
lda, pivots_data, b_ptr_array_data, lda, &info, batch_size);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
}
void lu_solve_batched_cublas(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(LU.scalar_type(), "lu_solve_cublas", [&]{
apply_lu_solve_batched_cublas<scalar_t>(LU, pivots, B, trans);
});
}
template <typename scalar_t>
static void apply_triangular_solve(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
const auto trans = to_cublas(transpose);
cublasSideMode_t side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
cublasDiagType_t diag = unitriangular ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
auto A_data = A.data_ptr<scalar_t>();
auto B_data = B.data_ptr<scalar_t>();
auto A_mat_stride = matrixStride(A);
auto B_mat_stride = matrixStride(B);
auto batch_size = batchCount(A);
// This allows to pass rectangular A and B when left = True
auto m = cuda_int_cast(left ? A.size(-1) : B.size(-2), "m");
auto n = cuda_int_cast(B.size(-1), "n");
auto lda = std::max<int>(1, cuda_int_cast(A.size(-2), "lda"));
auto ldb = std::max<int>(1, cuda_int_cast(B.size(-2), "ldb"));
auto alpha = scalar_t{1};
for (decltype(batch_size) i = 0; i < batch_size; i++) {
scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
scalar_t* B_working_ptr = &B_data[i * B_mat_stride];
auto handle = at::cuda::getCurrentCUDABlasHandle();
at::cuda::blas::trsm(handle, side, uplo, trans, diag, m, n, &alpha, A_working_ptr, lda, B_working_ptr, ldb);
}
}
void triangular_solve_cublas(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cuda", [&]{
apply_triangular_solve<scalar_t>(A, B, left, upper, transpose, unitriangular);
});
}
template <typename scalar_t>
static void apply_triangular_solve_batched(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
const auto trans = to_cublas(transpose);
cublasSideMode_t side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
cublasDiagType_t diag = unitriangular ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
auto batch_size = cuda_int_cast(batchCount(A), "batch_size");
// This allows to pass rectangular A and B when left = True
auto m = cuda_int_cast(left ? A.size(-1) : B.size(-2), "m");
auto n = cuda_int_cast(B.size(-1), "n");
auto lda = std::max<int>(1, cuda_int_cast(A.size(-2), "lda"));
auto ldb = std::max<int>(1, cuda_int_cast(B.size(-2), "ldb"));
auto alpha = scalar_t{1};
// cuBLAS batched trsm requires input to be the device array of pointers to device single matrices
Tensor A_ptr_array = get_device_pointers<scalar_t>(A);
Tensor B_ptr_array = get_device_pointers<scalar_t>(B);
auto A_ptr_array_data = reinterpret_cast<scalar_t**>(A_ptr_array.data_ptr());
auto B_ptr_array_data = reinterpret_cast<scalar_t**>(B_ptr_array.data_ptr());
auto handle = at::cuda::getCurrentCUDABlasHandle();
at::cuda::blas::trsmBatched(handle, side, uplo, trans, diag, m, n, &alpha, A_ptr_array_data, lda, B_ptr_array_data, ldb, batch_size);
}
void triangular_solve_batched_cublas(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cuda", [&]{
apply_triangular_solve_batched<scalar_t>(A, B, left, upper, transpose, unitriangular);
});
}
template <typename scalar_t>
inline void apply_gels_batched(const Tensor& A, Tensor& B, Tensor& infos) {
auto trans = CUBLAS_OP_N;
auto m = cuda_int_cast(A.size(-2), "m");
auto n = cuda_int_cast(A.size(-1), "n");
auto nrhs = cuda_int_cast(B.size(-1), "nrhs");
// cuBLAS from cuda10 and older doesn't work with nrhs == 0 (cuda11 works)
// so we need to put this early return
if (nrhs == 0) {
return;
}
auto batch_size = cuda_int_cast(batchCount(B), "batch_size");
auto lda = std::max<int>(1, m);
auto ldb = std::max<int>(1, m);
// cuBLAS's requirement
TORCH_CHECK(
m >= n,
"torch.linalg.lstsq: only overdetermined systems (input.size(-2) >= input.size(-1)) are allowed on CUDA with cuBLAS backend.");
// cuBLAS documentation says:
// Matrices Aarray[i] should not overlap; otherwise, undefined behavior is expected.
// explicitly broadcast the batch dimensions of A
IntArrayRef A_batch_sizes(A.sizes().data(), A.dim() - 2);
IntArrayRef B_batch_sizes(B.sizes().data(), B.dim() - 2);
std::vector<int64_t> expand_batch_portion = at::infer_size(A_batch_sizes, B_batch_sizes);
expand_batch_portion.insert(expand_batch_portion.end(), {A.size(-2), A.size(-1)});
Tensor A_expanded = A.expand({expand_batch_portion});
Tensor A_broadcasted = cloneBatchedColumnMajor(A_expanded);
// cuBLAS batched gels requires input to be the device array of pointers to device single matrices
Tensor A_ptr_array = get_device_pointers<scalar_t>(A_broadcasted);
Tensor B_ptr_array = get_device_pointers<scalar_t>(B);
auto A_ptr_array_data = reinterpret_cast<scalar_t**>(A_ptr_array.data_ptr());
auto B_ptr_array_data = reinterpret_cast<scalar_t**>(B_ptr_array.data_ptr());
auto infos_data = infos.data_ptr<int>();
auto handle = at::cuda::getCurrentCUDABlasHandle();
int info;
at::cuda::blas::gelsBatched<scalar_t>(
handle, trans, m, n, nrhs,
A_ptr_array_data, lda,
B_ptr_array_data, ldb,
&info,
infos_data,
batch_size);
// negative info indicates that an argument to gelsBatched call is invalid
TORCH_INTERNAL_ASSERT(info == 0);
}
// This is a type dispatching helper function for 'apply_gels_batched'
void gels_batched_cublas(const Tensor& a, Tensor& b, Tensor& infos) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(a.scalar_type(), "gels_batched_cublas", [&]{
apply_gels_batched<scalar_t>(a, b, infos);
});
}
} // namespace at::native

View File

@ -186,28 +186,15 @@ const char* cublasGetErrorString(cublasStatus_t error) {
return "CUBLAS_STATUS_ARCH_MISMATCH";
case CUBLAS_STATUS_INTERNAL_ERROR:
return "CUBLAS_STATUS_INTERNAL_ERROR";
#if !defined(USE_ROCM)
case CUBLAS_STATUS_MAPPING_ERROR:
return "CUBLAS_STATUS_MAPPING_ERROR";
case CUBLAS_STATUS_EXECUTION_FAILED:
return "CUBLAS_STATUS_EXECUTION_FAILED";
case CUBLAS_STATUS_NOT_SUPPORTED:
return "CUBLAS_STATUS_NOT_SUPPORTED";
#if !defined(USE_ROCM)
case CUBLAS_STATUS_LICENSE_ERROR:
return "CUBLAS_STATUS_LICENSE_ERROR";
#else
case rocblas_status_invalid_size:
return "rocblas_status_invalid_size";
case rocblas_status_perf_degraded:
return "rocblas_status_perf_degraded";
case rocblas_status_size_query_mismatch:
return "rocblas_status_size_query_mismatch";
case rocblas_status_size_increased:
return "rocblas_status_size_increased";
case rocblas_status_size_unchanged:
return "rocblas_status_size_unchanged";
default:
return "unrecognized_rocblas_error";
#endif
}
// To suppress compiler warning.

View File

@ -9,12 +9,13 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/utils/conversions.h"
#ifdef __HIPCC__
#if TORCH_HIP_VERSION < 210
// rocblas doesn't fully support fp16 yet
#define ROCBLAS_FP16 0
#endif
#endif
#if defined(USE_ROCM)
// until we use hipblas v2
// hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
// however hipblas v1 is still using its custom type
#define HIP_R_16F HIPBLAS_R_16F
#define HIP_R_32F HIPBLAS_R_32F
#endif // USE_ROCM
namespace caffe2 {
@ -122,30 +123,6 @@ void device_reduce<at::Half>(
(void)N; // Suppress unused variable warning
(void)buffer; // Suppress unused variable warning
(void)context; // Suppress unused variable warning
#if TORCH_HIP_VERSION >= 210
auto buffer_size = 1;
if (buffer->numel() != buffer_size) {
buffer->Resize(buffer_size);
math::Set<at::Half, CUDAContext>(
N,
convert::To<float, at::Half>(1.),
buffer->template mutable_data<at::Half>(),
context);
}
CUBLAS_ENFORCE(rocblas_hdot(
context->cublas_handle(),
N,
reinterpret_cast<const rocblas_half*>(in),
1,
reinterpret_cast<const rocblas_half*>(buffer->data<at::Half>()),
0,
reinterpret_cast<rocblas_half*>(out)));
#elif TORCH_HIP_VERSION < 210
CAFFE_THROW("HIP rocblas doesn't fully support fp16 device_reduce yet.");
#else
auto buffer_size = 1;
if (buffer->numel() != buffer_size) {
@ -170,7 +147,6 @@ void device_reduce<at::Half>(
out,
CUDA_R_16F,
CUDA_R_32F));
#endif
}
template <typename T, int BLOCK_THREADS>

View File

@ -39,9 +39,15 @@
#endif // USE_ROCM
#if defined(USE_ROCM)
using CUBLAS_HALF_TYPE = rocblas_half;
#define CUBLAS_HALF_TYPE hipblasHalf
#define HIPBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
// until we use hipblas v2
// hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
// however hipblas v1 is still using its custom type
#define HIP_R_16F HIPBLAS_R_16F
#define HIP_R_32F HIPBLAS_R_32F
#else // __HIP_PLATFORM_HCC
using CUBLAS_HALF_TYPE = __half;
#define CUBLAS_HALF_TYPE __half
#endif // __HIP_PLATFORM_HCC
#include "caffe2/utils/math/utils.h"
@ -608,12 +614,12 @@ CAFFE2_CUDA_EXPORT void Gemm<at::Half, CUDAContext>(
CUBLAS_ENFORCE(cublasSetPointerMode(
context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
#if defined(USE_ROCM)
// rocblas doesn't support cublasSgemmEx type API yet.
// It has more general rocblas_gemm_ex API which is more close to
// cublasGemmEx rocblas_gemm_ex does D = alpha*op( A )*op( B ) + beta*C,
// whereas cublasgemmEx does C = alpha*op( A )*op( B ) + beta*C
ROCBLAS_ENFORCE(rocblas_gemm_ex(
context->rocblashandle(),
// hipblas doesn't support a hipblasSgemmEx type API.
// It has the more general hipblasGemmEx API, which is closer to cublasGemmEx.
// hipblasGemmEx does D = alpha*op( A )*op( B ) + beta*C,
// whereas cublasSgemmEx does C = alpha*op( A )*op( B ) + beta*C
HIPBLAS_ENFORCE(hipblasGemmEx(
context->hipblas_handle(),
cu_trans_B,
cu_trans_A,
N,
@ -621,22 +627,17 @@ CAFFE2_CUDA_EXPORT void Gemm<at::Half, CUDAContext>(
K,
&alpha,
B,
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
ldb,
A,
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
lda,
&beta,
C,
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
N,
C, // D
rocblas_datatype_f16_r, // D type
N, // ldd
rocblas_datatype_f32_r, // compute type
rocblas_gemm_algo_standard, // rocblas_gemm_algo
0, // solution index, reserved for future use
0)); // flags, reserved for future use
HIPBLAS_R_32F, // compute type
HIPBLAS_GEMM_DEFAULT));
#else
CUBLAS_ENFORCE(cublasSgemmEx(
context->cublas_handle(),
@ -900,13 +901,13 @@ CAFFE2_CUDA_EXPORT void GemmBatched<at::Half, CUDAContext>(
N,
M,
K,
&alpha_fp16,
B_device.data().get(),
reinterpret_cast<const CUBLAS_HALF_TYPE*>(&alpha_fp16),
reinterpret_cast<const CUBLAS_HALF_TYPE* const*>(B_device.data().get()),
ldb,
A_device.data().get(),
reinterpret_cast<const CUBLAS_HALF_TYPE* const*>(A_device.data().get()),
lda,
&beta_fp16,
C_device.data().get(),
reinterpret_cast<const CUBLAS_HALF_TYPE*>(&beta_fp16),
reinterpret_cast<CUBLAS_HALF_TYPE* const*>(C_device.data().get()),
ldc,
batch_size));
} else {
@ -944,40 +945,6 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched<at::Half, CUDAContext>(
if (math_type == TensorProto_DataType_FLOAT) {
CUBLAS_ENFORCE(cublasSetPointerMode(
context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
#if defined(USE_ROCM)
// D[i*stride_d] = alpha*op(A[i*stride_a])*op(B[i*stride_b]) +
// beta*C[i*stride_c], for i in [0,batch_count-1]
ROCBLAS_ENFORCE(rocblas_gemm_strided_batched_ex(
context->rocblashandle(),
cu_trans_B,
cu_trans_A,
N,
M,
K,
&alpha,
B,
rocblas_datatype_f16_r,
ldb,
B_stride,
A,
rocblas_datatype_f16_r,
lda,
A_stride,
&beta,
C,
rocblas_datatype_f16_r,
ldc,
C_stride,
C, // D
rocblas_datatype_f16_r, // D type
ldc, // ldd
C_stride, // D stride
batch_size,
rocblas_datatype_f32_r, // compute type
rocblas_gemm_algo_standard, // rocblas_gemm_algo
0, // solution index, reserved for future use
0)); // flags, reserved for future use
#else
CUBLAS_ENFORCE(cublasGemmStridedBatchedEx(
context->cublas_handle(),
cu_trans_B,
@ -1002,7 +969,6 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched<at::Half, CUDAContext>(
batch_size,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
#endif // USE_ROCM
} else if (math_type == TensorProto_DataType_FLOAT16) {
// Convert alpha, beta from float -> __half
const __half alpha_fp16 = at::Half(alpha);
@ -1089,35 +1055,30 @@ CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext>(
CUBLAS_ENFORCE(cublasSetPointerMode(
context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
#if defined(USE_ROCM)
// rocblas doesn't support cublasSgemmEx type API yet.
// It has more general rocblas_gemm_ex API which is more close to
// cublasGemmEx rocblas_gemm_ex does D = alpha*op( A )*op( B ) + beta*C,
// whereas cublasgemmEx does C = alpha*op( A )*op( B ) + beta*C
ROCBLAS_ENFORCE(rocblas_gemm_ex(
context->rocblashandle(),
// hipblas doesn't support a hipblasSgemmEx type API.
// It has the more general hipblasGemmEx API, which is closer to cublasGemmEx.
// hipblasGemmEx does D = alpha*op( A )*op( B ) + beta*C,
// whereas cublasSgemmEx does C = alpha*op( A )*op( B ) + beta*C
HIPBLAS_ENFORCE(hipblasGemmEx(
context->hipblas_handle(),
cu_trans_A,
rocblas_operation_none,
HIPBLAS_OP_N,
m,
1,
k,
&alpha,
A,
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
lda,
x,
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
k,
&beta,
y,
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
ldc,
y, // D
rocblas_datatype_f16_r, // D type
ldc, // ldd
rocblas_datatype_f32_r, // compute type
rocblas_gemm_algo_standard, // rocblas_gemm_algo
0, // solution index, reserved for future use
0)); // flags, reserved for future use
HIPBLAS_R_32F, // compute type
HIPBLAS_GEMM_DEFAULT));
#else
CUBLAS_ENFORCE(cublasSgemmEx(
context->cublas_handle(),
@ -1635,20 +1596,6 @@ CAFFE2_CUDA_EXPORT void Dot<at::Half, CUDAContext>(
const at::Half* b,
at::Half* y,
CUDAContext* context) {
#if defined(USE_ROCM) && (TORCH_HIP_VERSION < 210)
CAFFE_THROW("HIP currently does not support FP16 completely yet.");
#elif defined(USE_ROCM) && (TORCH_HIP_VERSION >= 210)
CUBLAS_ENFORCE(cublasSetPointerMode(
context->cublas_handle(), CUBLAS_POINTER_MODE_DEVICE));
CUBLAS_ENFORCE(rocblas_hdot(
context->cublas_handle(),
n,
reinterpret_cast<const rocblas_half*>(a),
1,
reinterpret_cast<const rocblas_half*>(b),
1,
reinterpret_cast<rocblas_half*>(y)));
#else
// execute with 32-bit math
CUBLAS_ENFORCE(cublasSetPointerMode(
context->cublas_handle(), CUBLAS_POINTER_MODE_DEVICE));
@ -1664,7 +1611,6 @@ CAFFE2_CUDA_EXPORT void Dot<at::Half, CUDAContext>(
y,
CUDA_R_16F,
CUDA_R_32F));
#endif
}
// A previous version of caffe2 used Thrust but it turns out that thrust

View File

@ -1298,17 +1298,9 @@ if(USE_ROCM)
set(Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
${PYTORCH_HIP_HCC_LIBRARIES} ${PYTORCH_MIOPEN_LIBRARIES} ${hipcub_LIBRARIES} ${ROCM_HIPRTC_LIB} ${ROCM_ROCTX_LIB})
# Note [rocblas & rocfft cmake bug]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# TODO: There is a bug in rocblas's & rocfft's cmake files that exports the wrong targets name in ${rocblas_LIBRARIES}
# If you get this wrong, you'll get a complaint like 'ld: cannot find -lrocblas-targets'
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "4.1.0")
list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
roc::rocblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsolver)
else()
list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
roc::rocblas roc::rocfft hip::hiprand roc::hipsparse)
endif()
list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
roc::hipblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsolver)
else()
caffe2_update_option(USE_ROCM OFF)
endif()
@ -1319,15 +1311,10 @@ if(USE_ROCM AND ROCM_VERSION_DEV VERSION_LESS "5.2.0")
# We check again for USE_ROCM because it might have been set to OFF
# in the if above
include_directories(SYSTEM ${HIP_PATH}/include)
include_directories(SYSTEM ${ROCBLAS_PATH}/include)
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "4.1.0")
include_directories(SYSTEM ${HIPFFT_PATH}/include)
else()
include_directories(SYSTEM ${ROCFFT_PATH}/include)
endif()
include_directories(SYSTEM ${HIPBLAS_PATH}/include)
include_directories(SYSTEM ${HIPFFT_PATH}/include)
include_directories(SYSTEM ${HIPSPARSE_PATH}/include)
include_directories(SYSTEM ${HIPRAND_PATH}/include)
include_directories(SYSTEM ${ROCRAND_PATH}/include)
include_directories(SYSTEM ${THRUST_PATH})
endif()

View File

@ -42,6 +42,13 @@ else()
set(ROCBLAS_PATH $ENV{ROCBLAS_PATH})
endif()
# HIPBLAS_PATH
if(NOT DEFINED ENV{HIPBLAS_PATH})
set(HIPBLAS_PATH ${ROCM_PATH}/hipblas)
else()
set(HIPBLAS_PATH $ENV{HIPBLAS_PATH})
endif()
# ROCFFT_PATH
if(NOT DEFINED ENV{ROCFFT_PATH})
set(ROCFFT_PATH ${ROCM_PATH}/rocfft)
@ -246,6 +253,7 @@ if(HIP_FOUND)
set(rocrand_DIR ${ROCM_PATH}/lib/cmake/rocrand)
set(hiprand_DIR ${ROCM_PATH}/lib/cmake/hiprand)
set(rocblas_DIR ${ROCM_PATH}/lib/cmake/rocblas)
set(hipblas_DIR ${ROCM_PATH}/lib/cmake/hipblas)
set(miopen_DIR ${ROCM_PATH}/lib/cmake/miopen)
set(rocfft_DIR ${ROCM_PATH}/lib/cmake/rocfft)
set(hipfft_DIR ${ROCM_PATH}/lib/cmake/hipfft)
@ -263,6 +271,7 @@ if(HIP_FOUND)
set(rocrand_DIR ${ROCRAND_PATH}/lib/cmake/rocrand)
set(hiprand_DIR ${HIPRAND_PATH}/lib/cmake/hiprand)
set(rocblas_DIR ${ROCBLAS_PATH}/lib/cmake/rocblas)
set(hipblas_DIR ${HIPBLAS_PATH}/lib/cmake/hipblas)
set(miopen_DIR ${MIOPEN_PATH}/lib/cmake/miopen)
set(rocfft_DIR ${ROCFFT_PATH}/lib/cmake/rocfft)
set(hipfft_DIR ${HIPFFT_PATH}/lib/cmake/hipfft)
@ -280,6 +289,7 @@ if(HIP_FOUND)
find_package_and_print_version(rocrand REQUIRED)
find_package_and_print_version(hiprand REQUIRED)
find_package_and_print_version(rocblas REQUIRED)
find_package_and_print_version(hipblas REQUIRED)
find_package_and_print_version(miopen REQUIRED)
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "4.1.0")
find_package_and_print_version(hipfft REQUIRED)

File diff suppressed because it is too large

View File

@ -643,7 +643,12 @@ def is_cusparse_file(rel_filepath):
def is_special_file(rel_filepath):
if is_pytorch_file(rel_filepath):
return ("sparse" in rel_filepath.lower()) or ("linalg" in rel_filepath.lower())
if "sparse" in rel_filepath.lower():
return True
elif "linalg" in rel_filepath.lower():
if "batchlinearalgebralibblas" in rel_filepath.lower():
return False # don't use "special" mappings for this specific linalg cublas file
return True
return False
def is_caffe2_gpu_file(rel_filepath):
@ -745,7 +750,7 @@ for mapping in CUDA_TO_HIP_MAPPINGS:
PYTORCH_SPECIAL_MAP[src] = dst
else:
PYTORCH_MAP[src] = dst
if constants.API_PYTORCH not in meta_data:
if constants.API_PYTORCH not in meta_data and constants.API_SPECIAL not in meta_data:
CAFFE2_TRIE.add(src)
CAFFE2_MAP[src] = dst
RE_CAFFE2_PREPROCESSOR = re.compile(CAFFE2_TRIE.pattern())