Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
[ROCm] use hipblas instead of rocblas (#105881)
- BatchLinearAlgebraLib.cpp is now split into one additional file:
  - BatchLinearAlgebraLib.cpp uses only cusolver APIs
  - BatchLinearAlgebraLibBlas.cpp uses only cublas APIs
  - hipify operates at the file level and cannot mix cusolver and cublas APIs within the same file
- cmake changes to link against hipblas instead of rocblas
- hipify mapping changes to map cublas -> hipblas instead of rocblas

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105881
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: c9c66819a1
Commit: 5379b5f927
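The hipify mapping change described in the commit message above is the crux: PyTorch's hipify step textually rewrites CUDA library identifiers before the ROCm build, so pointing the cuBLAS tokens at hipBLAS instead of rocBLAS changes which backend the translated sources call. The sketch below is a minimal illustration of that substitution idea only; the token pairs and function name are assumptions chosen for the example, not the actual entries or code in PyTorch's hipify utilities.

```python
# Minimal sketch of a hipify-style token substitution (illustrative only; the
# real mapping table is far larger and matches on word boundaries rather than
# using a plain string replace).
CUBLAS_TO_HIPBLAS = {
    "cublasHandle_t": "hipblasHandle_t",        # previously mapped to rocblas_handle
    "cublasOperation_t": "hipblasOperation_t",  # previously rocblas_operation
    "cublasStatus_t": "hipblasStatus_t",        # previously rocblas_status
    "cublasSgemm": "hipblasSgemm",              # previously rocblas_sgemm
}

def hipify_source(text: str) -> str:
    """Rewrite cuBLAS identifiers in a source string to their hipBLAS counterparts."""
    for cuda_token, hip_token in CUBLAS_TO_HIPBLAS.items():
        text = text.replace(cuda_token, hip_token)
    return text

if __name__ == "__main__":
    # "cublasSgemm(...)" becomes "hipblasSgemm(...)"
    print(hipify_source("cublasStatus_t s = cublasSgemm(handle, opa, opb);"))
```

Because this rewriting happens per file, a single translation unit cannot mix cusolver and cublas mappings, which is why the PR splits BatchLinearAlgebraLib.cpp into two implementation files.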
@@ -499,11 +499,6 @@ endif()
endif(MSVC)
endif(USE_MAGMA)

# NB: We're relying on cmake/Dependencies.cmake to appropriately setup HIP dependencies.
# In principle we could duplicate them, but handling the rocblas
# dependency is nontrivial. So better not to copy-paste.
# Look for Note [rocblas cmake bug]

# Include CPU paths for CUDA/HIP as well
list(APPEND ATen_CUDA_INCLUDE ${ATen_CPU_INCLUDE})
list(APPEND ATen_HIP_INCLUDE ${ATen_CPU_INCLUDE})
@@ -16,8 +16,68 @@
#endif

#ifdef USE_ROCM
// until hipblas has an API to accept flags, we must use rocblas here
#include <rocblas/rocblas.h>
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
// needed to work around calling rocblas API instead of hipblas API
static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op)
{
  switch(op)
  {
    case HIPBLAS_OP_N:
      return rocblas_operation_none;
    case HIPBLAS_OP_T:
      return rocblas_operation_transpose;
    case HIPBLAS_OP_C:
      return rocblas_operation_conjugate_transpose;
  }
  AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
{
  switch(error)
  {
    case rocblas_status_size_unchanged:
    case rocblas_status_size_increased:
    case rocblas_status_success:
      return HIPBLAS_STATUS_SUCCESS;
    case rocblas_status_invalid_handle:
      return HIPBLAS_STATUS_NOT_INITIALIZED;
    case rocblas_status_not_implemented:
      return HIPBLAS_STATUS_NOT_SUPPORTED;
    case rocblas_status_invalid_pointer:
    case rocblas_status_invalid_size:
    case rocblas_status_invalid_value:
      return HIPBLAS_STATUS_INVALID_VALUE;
    case rocblas_status_memory_error:
      return HIPBLAS_STATUS_ALLOC_FAILED;
    case rocblas_status_internal_error:
      return HIPBLAS_STATUS_INTERNAL_ERROR;
  }
  AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
// hipblas does not have hipblasSetMathMode
#define hipblasSetMathMode(handle, flags) HIPBLAS_STATUS_SUCCESS
// until we use hipblas v2
// hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
// however hipblas v1 is still using its custom type
#define HIP_R_16F  HIPBLAS_R_16F
#define HIP_R_32F  HIPBLAS_R_32F
#define HIP_R_64F  HIPBLAS_R_64F
#define HIP_C_16F  HIPBLAS_C_16F
#define HIP_C_32F  HIPBLAS_C_32F
#define HIP_C_64F  HIPBLAS_C_64F
#define HIP_R_8I   HIPBLAS_R_8I
#define HIP_R_8U   HIPBLAS_R_8U
#define HIP_R_32I  HIPBLAS_R_32I
#define HIP_R_32U  HIPBLAS_R_32U
#define HIP_C_8I   HIPBLAS_C_8I
#define HIP_C_8U   HIPBLAS_C_8U
#define HIP_C_32I  HIPBLAS_C_32I
#define HIP_C_32U  HIPBLAS_C_32U
#define HIP_R_16BF HIPBLAS_R_16B
#define HIP_C_16BF HIPBLAS_C_16B
#endif

#define CUDABLAS_POSINT_CHECK(FD, X) \
@@ -193,6 +253,8 @@ cublasStatus_t cublasGemmStridedBatchedExFix(cublasHandle_t &handle,
  return result;
}
#endif
#else // USE_ROCM
#define cublasGemmStridedBatchedExFix hipblasGemmStridedBatchedEx
#endif

#define GEMM_CHECK_ARGVALUES(Dtype) \
@@ -288,13 +350,15 @@ void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
  #if USE_GEMM_FLAGS_FP16_ALT_IMPL
  flag = at::ROCmBackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
  #endif
  TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k,
  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_strided_batched_ex((rocblas_handle)handle,
                                   hipOperationToRocOperation(opa),
                                   hipOperationToRocOperation(opb), (int)m, (int)n, (int)k,
                                   (void*)&falpha, a, rocblas_datatype_f16_r, (int)lda, stridea,
                                   b, rocblas_datatype_f16_r, (int)ldb, strideb,
                                   (void*)&fbeta, c, rocblas_datatype_f16_r, (int)ldc, stridec,
                                   c, rocblas_datatype_f16_r, (int)ldc, stridec,
                                   (int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
                                   0, flag));
                                   0, flag)));
#else
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  if (prop->major >= 5){
@@ -329,22 +393,12 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
  const float fbeta = beta;
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);

#if !defined(USE_ROCM)
  TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix(handle,
                                  opa, opb, (int)m, (int)n, (int)k,
                                  (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea,
                                  b, CUDA_R_16BF, (int)ldb, strideb,
                                  (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
                                  (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
#else
  TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k,
                                  (void*)&falpha, a, rocblas_datatype_bf16_r, (int)lda, stridea,
                                  b, rocblas_datatype_bf16_r, (int)ldb, strideb,
                                  (void*)&fbeta, c, rocblas_datatype_bf16_r, (int)ldc, stridec,
                                  c, rocblas_datatype_bf16_r, (int)ldc, stridec,
                                  (int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
                                  0, 0, NULL, NULL));
#endif // !defined(USE_ROCM)
  TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix(handle,
                                  opa, opb, (int)m, (int)n, (int)k,
                                  (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea,
                                  b, CUDA_R_16BF, (int)ldb, strideb,
                                  (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
                                  (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}

template <>
@ -373,39 +427,35 @@ void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
|
||||
handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
|
||||
}
|
||||
|
||||
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
|
||||
template <>
|
||||
void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>)) {
|
||||
// See Note [Writing Nondeterministic Operations]
|
||||
globalContext().alertCuBLASConfigNotDeterministic();
|
||||
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
cublasOperation_t opa = _cublasOpFromChar(transa);
|
||||
cublasOperation_t opb = _cublasOpFromChar(transb);
|
||||
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
|
||||
GEMM_CHECK_ARGVALUES(c10::complex<double>);
|
||||
TORCH_CUDABLAS_CHECK(cublasZgemm(
|
||||
handle, opa, opb, m, n, k, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
|
||||
lda, reinterpret_cast<const cuDoubleComplex*>(b), ldb, reinterpret_cast<const cuDoubleComplex*>(&beta),
|
||||
reinterpret_cast<cuDoubleComplex*>(c), ldc));
|
||||
}
|
||||
#endif
|
||||
template <>
|
||||
void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>)) {
|
||||
// See Note [Writing Nondeterministic Operations]
|
||||
globalContext().alertCuBLASConfigNotDeterministic();
|
||||
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
cublasOperation_t opa = _cublasOpFromChar(transa);
|
||||
cublasOperation_t opb = _cublasOpFromChar(transb);
|
||||
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
|
||||
GEMM_CHECK_ARGVALUES(c10::complex<double>);
|
||||
TORCH_CUDABLAS_CHECK(cublasZgemm(
|
||||
handle, opa, opb, m, n, k, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
|
||||
lda, reinterpret_cast<const cuDoubleComplex*>(b), ldb, reinterpret_cast<const cuDoubleComplex*>(&beta),
|
||||
reinterpret_cast<cuDoubleComplex*>(c), ldc));
|
||||
}
|
||||
|
||||
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
|
||||
template <>
|
||||
void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>)) {
|
||||
// See Note [Writing Nondeterministic Operations]
|
||||
globalContext().alertCuBLASConfigNotDeterministic();
|
||||
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
cublasOperation_t opa = _cublasOpFromChar(transa);
|
||||
cublasOperation_t opb = _cublasOpFromChar(transb);
|
||||
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
|
||||
GEMM_CHECK_ARGVALUES(c10::complex<float>);
|
||||
TORCH_CUDABLAS_CHECK(cublasCgemm(
|
||||
handle, opa, opb, m, n, k, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
|
||||
lda, reinterpret_cast<const cuComplex*>(b), ldb, reinterpret_cast<const cuComplex*>(&beta),
|
||||
reinterpret_cast<cuComplex*>(c), ldc));
|
||||
}
|
||||
#endif
|
||||
template <>
|
||||
void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>)) {
|
||||
// See Note [Writing Nondeterministic Operations]
|
||||
globalContext().alertCuBLASConfigNotDeterministic();
|
||||
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
cublasOperation_t opa = _cublasOpFromChar(transa);
|
||||
cublasOperation_t opb = _cublasOpFromChar(transb);
|
||||
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
|
||||
GEMM_CHECK_ARGVALUES(c10::complex<float>);
|
||||
TORCH_CUDABLAS_CHECK(cublasCgemm(
|
||||
handle, opa, opb, m, n, k, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
|
||||
lda, reinterpret_cast<const cuComplex*>(b), ldb, reinterpret_cast<const cuComplex*>(&beta),
|
||||
reinterpret_cast<cuComplex*>(c), ldc));
|
||||
}
|
||||
|
||||
template <>
|
||||
void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
@ -423,10 +473,10 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
|
||||
flag = at::ROCmBackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
|
||||
#endif
|
||||
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
|
||||
handle,
|
||||
opa,
|
||||
opb,
|
||||
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(
|
||||
(rocblas_handle)handle,
|
||||
hipOperationToRocOperation(opa),
|
||||
hipOperationToRocOperation(opb),
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
@ -447,14 +497,16 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
rocblas_datatype_f32_r,
|
||||
rocblas_gemm_algo_standard,
|
||||
0,
|
||||
flag));
|
||||
flag)));
|
||||
#else
|
||||
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
|
||||
if (prop->major >= 5) {
|
||||
#ifndef USE_ROCM
|
||||
cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
|
||||
if (!at::globalContext().allowFP16ReductionCuBLAS()) {
|
||||
cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
|
||||
}
|
||||
#endif
|
||||
// Disallow fp16 reductions that could lead to unexpected overflow issues.
|
||||
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
|
||||
TORCH_CUDABLAS_CHECK(cublasGemmEx(
|
||||
@ -501,45 +553,6 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
template <>
|
||||
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
cublasOperation_t opa = _cublasOpFromChar(transa);
|
||||
cublasOperation_t opb = _cublasOpFromChar(transb);
|
||||
float falpha = alpha;
|
||||
float fbeta = beta;
|
||||
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
|
||||
GEMM_CHECK_ARGVALUES(at::BFloat16);
|
||||
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
|
||||
handle,
|
||||
opa,
|
||||
opb,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
&falpha,
|
||||
a,
|
||||
rocblas_datatype_bf16_r,
|
||||
lda,
|
||||
b,
|
||||
rocblas_datatype_bf16_r,
|
||||
ldb,
|
||||
&fbeta,
|
||||
c,
|
||||
rocblas_datatype_bf16_r,
|
||||
ldc,
|
||||
c,
|
||||
rocblas_datatype_bf16_r,
|
||||
ldc,
|
||||
rocblas_datatype_f32_r,
|
||||
rocblas_gemm_algo_standard,
|
||||
0,
|
||||
0));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if !defined(USE_ROCM)
|
||||
template <>
|
||||
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||
globalContext().alertCuBLASConfigNotDeterministic();
|
||||
@ -550,10 +563,12 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||
float fbeta = beta;
|
||||
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
|
||||
GEMM_CHECK_ARGVALUES(at::BFloat16);
|
||||
#ifndef USE_ROCM
|
||||
cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
|
||||
if (!at::globalContext().allowBF16ReductionCuBLAS()) {
|
||||
cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
|
||||
}
|
||||
#endif
|
||||
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
|
||||
TORCH_CUDABLAS_CHECK(cublasGemmEx(
|
||||
handle,
|
||||
@ -577,7 +592,6 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
|
||||
}
|
||||
#endif // !defined(USE_ROCM)
|
||||
|
||||
#if !defined(USE_ROCM) && !defined(_MSC_VER)
|
||||
|
||||
@ -1113,23 +1127,20 @@ void trsmBatched<c10::complex<double>>(
|
||||
CUDABLAS_POSINT_CHECK(gemv<Dtype>, incy); \
|
||||
} while (0)
|
||||
|
||||
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
|
||||
template <>
|
||||
void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>)) {
|
||||
// See Note [Writing Nondeterministic Operations]
|
||||
globalContext().alertCuBLASConfigNotDeterministic();
|
||||
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
cublasOperation_t op = _cublasOpFromChar(trans);
|
||||
_cublasAdjustLdLevel2(m, n, &lda);
|
||||
GEMV_CHECK_ARGVALUES(c10::complex<double>);
|
||||
TORCH_CUDABLAS_CHECK(
|
||||
cublasZgemv(handle, op, m, n, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
|
||||
lda, reinterpret_cast<const cuDoubleComplex*>(x), incx, reinterpret_cast<const cuDoubleComplex*>(&beta),
|
||||
reinterpret_cast<cuDoubleComplex*>(y), incy));
|
||||
}
|
||||
#endif
|
||||
template <>
|
||||
void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>)) {
|
||||
// See Note [Writing Nondeterministic Operations]
|
||||
globalContext().alertCuBLASConfigNotDeterministic();
|
||||
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
cublasOperation_t op = _cublasOpFromChar(trans);
|
||||
_cublasAdjustLdLevel2(m, n, &lda);
|
||||
GEMV_CHECK_ARGVALUES(c10::complex<double>);
|
||||
TORCH_CUDABLAS_CHECK(
|
||||
cublasZgemv(handle, op, m, n, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
|
||||
lda, reinterpret_cast<const cuDoubleComplex*>(x), incx, reinterpret_cast<const cuDoubleComplex*>(&beta),
|
||||
reinterpret_cast<cuDoubleComplex*>(y), incy));
|
||||
}
|
||||
|
||||
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
|
||||
template <>
|
||||
void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>)) {
|
||||
// gemv is bw bound, and does not benefit from TF32. But the precision
|
||||
@ -1146,7 +1157,6 @@ void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>)) {
|
||||
lda, reinterpret_cast<const cuComplex*>(x), incx, reinterpret_cast<const cuComplex*>(&beta),
|
||||
reinterpret_cast<cuComplex*>(y), incy));
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double)) {
|
||||
@ -1247,7 +1257,6 @@ void dot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>)) {
|
||||
|
||||
template <>
|
||||
void dot<at::Half>(CUDABLAS_DOT_ARGTYPES(at::Half)) {
|
||||
#if !defined(USE_ROCM)
|
||||
TORCH_CUDABLAS_CHECK(cublasDotEx(
|
||||
handle,
|
||||
n,
|
||||
@ -1260,23 +1269,10 @@ void dot<at::Half>(CUDABLAS_DOT_ARGTYPES(at::Half)) {
|
||||
result,
|
||||
CUDA_R_16F,
|
||||
CUDA_R_32F));
|
||||
#elif defined(ROCM_VERSION) && ROCM_VERSION >= 21000
|
||||
TORCH_CUDABLAS_CHECK(rocblas_hdot(
|
||||
handle,
|
||||
n,
|
||||
reinterpret_cast<const rocblas_half*>(x),
|
||||
incx,
|
||||
reinterpret_cast<const rocblas_half*>(y),
|
||||
incy,
|
||||
reinterpret_cast<rocblas_half*>(result)));
|
||||
#else
|
||||
AT_ERROR("Cublas_Hdot requires CUDA 8.0+");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
void dot<at::BFloat16>(CUDABLAS_DOT_ARGTYPES(at::BFloat16)) {
|
||||
#if !defined(USE_ROCM)
|
||||
TORCH_CUDABLAS_CHECK(cublasDotEx(
|
||||
handle,
|
||||
n,
|
||||
@ -1289,18 +1285,6 @@ void dot<at::BFloat16>(CUDABLAS_DOT_ARGTYPES(at::BFloat16)) {
|
||||
result,
|
||||
CUDA_R_16BF,
|
||||
CUDA_R_32F));
|
||||
#elif defined(ROCM_VERSION) && ROCM_VERSION >= 21000
|
||||
TORCH_CUDABLAS_CHECK(rocblas_bfdot(
|
||||
handle,
|
||||
n,
|
||||
reinterpret_cast<const rocblas_bfloat16*>(x),
|
||||
incx,
|
||||
reinterpret_cast<const rocblas_bfloat16*>(y),
|
||||
incy,
|
||||
reinterpret_cast<rocblas_bfloat16*>(result)));
|
||||
#else
|
||||
AT_ERROR("Cublas_bfdot requires CUDA 11.0+");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -1317,9 +1301,6 @@ void vdot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>)) {
|
||||
reinterpret_cast<cuDoubleComplex*>(result)));
|
||||
}
|
||||
|
||||
// This guards blocks use of getrsBatched, geqrfBatched, getrfBatched on platforms other than cuda
|
||||
#ifdef CUDART_VERSION
|
||||
|
||||
template <>
|
||||
void getrsBatched<float>(CUDABLAS_GETRS_ARGTYPES(float)) {
|
||||
TORCH_CUDABLAS_CHECK(cublasSgetrsBatched(
|
||||
@ -1519,8 +1500,6 @@ void gelsBatched<c10::complex<float>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::comple
|
||||
batchSize));
|
||||
}
|
||||
|
||||
#endif // CUDART_VERSION
|
||||
|
||||
} // namespace blas
|
||||
} // namespace cuda
|
||||
} // namespace at
|
||||
|
@ -55,14 +55,10 @@ template <>
|
||||
void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double));
|
||||
template <>
|
||||
void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float));
|
||||
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
|
||||
template <>
|
||||
void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>));
|
||||
#endif
|
||||
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
|
||||
template <>
|
||||
void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>));
|
||||
#endif
|
||||
template <>
|
||||
void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>));
|
||||
template <>
|
||||
void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>));
|
||||
template <>
|
||||
void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
|
||||
template <>
|
||||
@ -189,12 +185,10 @@ template <>
|
||||
void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double));
|
||||
template <>
|
||||
void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float));
|
||||
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
|
||||
template <>
|
||||
void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>));
|
||||
template <>
|
||||
void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>));
|
||||
#endif
|
||||
template <>
|
||||
void gemv<at::Half>(CUDABLAS_GEMV_ARGTYPES(at::Half));
|
||||
template <>
|
||||
@ -234,9 +228,6 @@ void vdot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>));
|
||||
template <>
|
||||
void vdot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>));
|
||||
|
||||
// This guards blocks use of getrsBatched, geqrfBatched, getrfBatched on platforms other than cuda
|
||||
#ifdef CUDART_VERSION
|
||||
|
||||
#define CUDABLAS_GETRS_ARGTYPES(Dtype) \
|
||||
cublasHandle_t handle, cublasOperation_t trans, \
|
||||
int n, int nrhs, Dtype** dA_array, int lda, int* ipiv_array, \
|
||||
@ -311,8 +302,6 @@ TORCH_CUDA_CU_API void gelsBatched<c10::complex<double>>(CUDABLAS_GELS_BATCHED_A
|
||||
template<>
|
||||
TORCH_CUDA_CU_API void gelsBatched<c10::complex<float>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<float>));
|
||||
|
||||
#endif // CUDART_VERSION
|
||||
|
||||
} // namespace blas
|
||||
} // namespace cuda
|
||||
} // namespace at
|
||||
|
@@ -125,14 +125,14 @@ cublasHandle_t getCurrentCUDABlasHandle() {
    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
  }
#endif
#if defined(USE_ROCM) && ROCM_VERSION >= 30800
  rocblas_atomics_mode rocblas_mode;
#if defined(USE_ROCM)
  hipblasAtomicsMode_t hipblas_mode;
  if (at::globalContext().deterministicAlgorithms()) {
    rocblas_mode = rocblas_atomics_not_allowed;
    hipblas_mode = HIPBLAS_ATOMICS_NOT_ALLOWED;
  } else {
    rocblas_mode = rocblas_atomics_allowed;
    hipblas_mode = HIPBLAS_ATOMICS_ALLOWED;
  }
  TORCH_CUDABLAS_CHECK(rocblas_set_atomics_mode(handle, rocblas_mode));
  TORCH_CUDABLAS_CHECK(hipblasSetAtomicsMode(handle, hipblas_mode));
#endif
  return handle;
}
@ -1,3 +1,4 @@
|
||||
// See Note [BatchLinearAlgebraLib split implementation files]
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/Context.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
@ -29,7 +30,7 @@
|
||||
|
||||
namespace at::native {
|
||||
|
||||
cublasOperation_t to_cublas(TransposeType trans) {
|
||||
static cublasOperation_t to_cublas(TransposeType trans) {
|
||||
switch (trans) {
|
||||
case TransposeType::NoTranspose: return CUBLAS_OP_N;
|
||||
case TransposeType::Transpose: return CUBLAS_OP_T;
|
||||
@ -57,102 +58,6 @@ static Tensor get_device_pointers(const Tensor& input) {
|
||||
input.options().dtype(at::kLong));
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void apply_geqrf_batched(const Tensor& input, const Tensor& tau) {
|
||||
// AMD ROCm backend is implemented via rewriting all CUDA calls to HIP
|
||||
// rocBLAS does not implement BLAS-like extensions of cuBLAS, they're in rocSOLVER
|
||||
// rocSOLVER is currently not used in ATen, therefore we raise an error in this case
|
||||
#ifndef CUDART_VERSION
|
||||
TORCH_CHECK(false, "geqrf: Batched version is supported only with cuBLAS backend.")
|
||||
#else
|
||||
auto batch_size = cuda_int_cast(batchCount(input), "batch_size");
|
||||
auto m = cuda_int_cast(input.size(-2), "m");
|
||||
auto n = cuda_int_cast(input.size(-1), "n");
|
||||
auto lda = std::max<int>(1, m);
|
||||
|
||||
// cuBLAS batched geqrf requires input to be the device array of pointers to device single matrices
|
||||
Tensor input_ptr_array = get_device_pointers<scalar_t>(input);
|
||||
Tensor tau_ptr_array = get_device_pointers<scalar_t>(tau.unsqueeze(-1));
|
||||
auto input_ptr_array_data = reinterpret_cast<scalar_t**>(input_ptr_array.data_ptr());
|
||||
auto tau_ptr_array_data = reinterpret_cast<scalar_t**>(tau_ptr_array.data_ptr());
|
||||
|
||||
int info;
|
||||
auto handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
at::cuda::blas::geqrfBatched(handle, m, n, input_ptr_array_data, lda, tau_ptr_array_data, &info, batch_size);
|
||||
|
||||
// info only indicates wrong arguments to geqrfBatched call
|
||||
// info is a host variable, we can check it without device synchronization
|
||||
TORCH_INTERNAL_ASSERT(info == 0);
|
||||
#endif
|
||||
}
|
||||
|
||||
void geqrf_batched_cublas(const Tensor& input, const Tensor& tau) {
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "geqrf_batched_cuda", [&]{
|
||||
apply_geqrf_batched<scalar_t>(input, tau);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void apply_lu_factor_batched_cublas(const Tensor& A, const Tensor& pivots, const Tensor& infos, bool get_pivots) {
|
||||
#ifndef CUDART_VERSION
|
||||
TORCH_CHECK(false, "linalg.lu_factor: cuBLAS backend for linalg.lu_factor is not available.")
|
||||
#else
|
||||
// This function just works with square matrices
|
||||
TORCH_INTERNAL_ASSERT(A.size(-2) == A.size(-1));
|
||||
|
||||
auto batch_size = cuda_int_cast(batchCount(A), "batch_size");;
|
||||
auto n = cuda_int_cast(A.size(-2), "n");
|
||||
auto lda = cuda_int_cast(std::max<int>(1, n), "lda");
|
||||
|
||||
auto pivots_data = get_pivots ? pivots.data_ptr<int>() : nullptr;
|
||||
auto infos_data = infos.data_ptr<int>();
|
||||
Tensor a_ptr_array = get_device_pointers<scalar_t>(A);
|
||||
auto a_ptr_array_data = reinterpret_cast<scalar_t**>(a_ptr_array.data_ptr());
|
||||
|
||||
at::cuda::blas::getrfBatched(n, a_ptr_array_data, lda, pivots_data, infos_data, batch_size);
|
||||
#endif
|
||||
}
|
||||
|
||||
void lu_factor_batched_cublas(const Tensor& A, const Tensor& pivots, const Tensor& infos, bool get_pivots) {
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "lu_factor_cublas", [&]{
|
||||
apply_lu_factor_batched_cublas<scalar_t>(A, pivots, infos, get_pivots);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void apply_lu_solve_batched_cublas(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose) {
|
||||
#ifndef CUDART_VERSION
|
||||
TORCH_CHECK(false, "linalg.lu_solve: cuBLAS backend for linalg.lu_solve is not available.")
|
||||
#else
|
||||
TORCH_INTERNAL_ASSERT(batchCount(LU) == batchCount(B), "batch_size of LU and B must be the same");
|
||||
TORCH_INTERNAL_ASSERT(batchCount(LU) == batchCount(pivots.unsqueeze(-1)), "batch_size of LU and pivots must be the same");
|
||||
const auto trans = to_cublas(transpose);
|
||||
|
||||
auto pivots_data = pivots.data_ptr<int>();
|
||||
auto batch_size = cuda_int_cast(batchCount(LU), "batch_size");;
|
||||
auto m = cuda_int_cast(LU.size(-2), "m");
|
||||
auto nrhs = cuda_int_cast(B.size(-1), "nrhs");
|
||||
auto lda = cuda_int_cast(std::max<int>(1, m), "lda");
|
||||
int info = 0;
|
||||
|
||||
Tensor lu_ptr_array = get_device_pointers<scalar_t>(LU);
|
||||
Tensor b_ptr_array = get_device_pointers<scalar_t>(B);
|
||||
auto lu_ptr_array_data = reinterpret_cast<scalar_t**>(lu_ptr_array.data_ptr());
|
||||
auto b_ptr_array_data = reinterpret_cast<scalar_t**>(b_ptr_array.data_ptr());
|
||||
|
||||
auto handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
at::cuda::blas::getrsBatched(handle, trans, m, nrhs, lu_ptr_array_data,
|
||||
lda, pivots_data, b_ptr_array_data, lda, &info, batch_size);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
|
||||
#endif
|
||||
}
|
||||
|
||||
void lu_solve_batched_cublas(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(LU.scalar_type(), "lu_solve_cublas", [&]{
|
||||
apply_lu_solve_batched_cublas<scalar_t>(LU, pivots, B, trans);
|
||||
});
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
@ -331,162 +236,6 @@ void ldl_solve_cusolver(
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void apply_triangular_solve(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
|
||||
#ifdef ROCM_VERSION
|
||||
// Cannot auto-hipifiy this piece of code, because in other functions the uplo
|
||||
// and other variables need to be hipSOLVER's type.
|
||||
auto uplo = upper ? rocblas_fill_upper : rocblas_fill_lower;
|
||||
const auto trans = (rocblas_operation)to_cublas(transpose);
|
||||
rocblas_side side = left ? rocblas_side_left : rocblas_side_right;
|
||||
#else
|
||||
cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
|
||||
const auto trans = to_cublas(transpose);
|
||||
cublasSideMode_t side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
|
||||
#endif
|
||||
cublasDiagType_t diag = unitriangular ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
|
||||
|
||||
auto A_data = A.data_ptr<scalar_t>();
|
||||
auto B_data = B.data_ptr<scalar_t>();
|
||||
auto A_mat_stride = matrixStride(A);
|
||||
auto B_mat_stride = matrixStride(B);
|
||||
auto batch_size = batchCount(A);
|
||||
// This allows to pass rectangular A and B when left = True
|
||||
auto m = cuda_int_cast(left ? A.size(-1) : B.size(-2), "m");
|
||||
auto n = cuda_int_cast(B.size(-1), "n");
|
||||
auto lda = std::max<int>(1, cuda_int_cast(A.size(-2), "lda"));
|
||||
auto ldb = std::max<int>(1, cuda_int_cast(B.size(-2), "ldb"));
|
||||
|
||||
auto alpha = scalar_t{1};
|
||||
|
||||
for (decltype(batch_size) i = 0; i < batch_size; i++) {
|
||||
scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
|
||||
scalar_t* B_working_ptr = &B_data[i * B_mat_stride];
|
||||
auto handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
at::cuda::blas::trsm(handle, side, uplo, trans, diag, m, n, &alpha, A_working_ptr, lda, B_working_ptr, ldb);
|
||||
}
|
||||
}
|
||||
|
||||
void triangular_solve_cublas(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cuda", [&]{
|
||||
apply_triangular_solve<scalar_t>(A, B, left, upper, transpose, unitriangular);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void apply_triangular_solve_batched(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
|
||||
#ifdef ROCM_VERSION
|
||||
// Cannot auto-hipifiy this piece of code, because in other functions the uplo
|
||||
// and other variables need to be hipSOLVER's type.
|
||||
auto uplo = upper ? rocblas_fill_upper : rocblas_fill_lower;
|
||||
const auto trans = (rocblas_operation)to_cublas(transpose);
|
||||
rocblas_side side = left ? rocblas_side_left : rocblas_side_right;
|
||||
#else
|
||||
cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
|
||||
const auto trans = to_cublas(transpose);
|
||||
cublasSideMode_t side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
|
||||
#endif
|
||||
cublasDiagType_t diag = unitriangular ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
|
||||
|
||||
auto batch_size = cuda_int_cast(batchCount(A), "batch_size");
|
||||
// This allows to pass rectangular A and B when left = True
|
||||
auto m = cuda_int_cast(left ? A.size(-1) : B.size(-2), "m");
|
||||
auto n = cuda_int_cast(B.size(-1), "n");
|
||||
auto lda = std::max<int>(1, cuda_int_cast(A.size(-2), "lda"));
|
||||
auto ldb = std::max<int>(1, cuda_int_cast(B.size(-2), "ldb"));
|
||||
|
||||
auto alpha = scalar_t{1};
|
||||
|
||||
// cuBLAS batched trsm requires input to be the device array of pointers to device single matrices
|
||||
Tensor A_ptr_array = get_device_pointers<scalar_t>(A);
|
||||
Tensor B_ptr_array = get_device_pointers<scalar_t>(B);
|
||||
auto A_ptr_array_data = reinterpret_cast<scalar_t**>(A_ptr_array.data_ptr());
|
||||
auto B_ptr_array_data = reinterpret_cast<scalar_t**>(B_ptr_array.data_ptr());
|
||||
|
||||
auto handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
at::cuda::blas::trsmBatched(handle, side, uplo, trans, diag, m, n, &alpha, A_ptr_array_data, lda, B_ptr_array_data, ldb, batch_size);
|
||||
}
|
||||
|
||||
void triangular_solve_batched_cublas(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cuda", [&]{
|
||||
apply_triangular_solve_batched<scalar_t>(A, B, left, upper, transpose, unitriangular);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void apply_gels_batched(const Tensor& A, Tensor& B, Tensor& infos) {
|
||||
// AMD ROCm backend is implemented via rewriting all CUDA calls to HIP
|
||||
// rocBLAS does not implement BLAS-like extensions of cuBLAS, they're in rocSOLVER
|
||||
// rocSOLVER is currently not used in ATen, therefore we raise an error in this case
|
||||
#ifndef CUDART_VERSION
|
||||
TORCH_CHECK(false, "torch.linalg.lstsq: Batched version is supported only with cuBLAS backend.")
|
||||
#else
|
||||
#ifdef ROCM_VERSION
|
||||
// Cannot auto-hipifiy this piece of code, because in other functions
|
||||
// CUBLAS_OP_N must be translated to HIPSOLVER_OP_N
|
||||
auto trans = rocblas_operation_none;
|
||||
#else
|
||||
auto trans = CUBLAS_OP_N;
|
||||
#endif
|
||||
auto m = cuda_int_cast(A.size(-2), "m");
|
||||
auto n = cuda_int_cast(A.size(-1), "n");
|
||||
|
||||
auto nrhs = cuda_int_cast(B.size(-1), "nrhs");
|
||||
// cuBLAS from cuda10 and older doesn't work with nrhs == 0 (cuda11 works)
|
||||
// so we need to put this early return
|
||||
if (nrhs == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto batch_size = cuda_int_cast(batchCount(B), "batch_size");
|
||||
auto lda = std::max<int>(1, m);
|
||||
auto ldb = std::max<int>(1, m);
|
||||
|
||||
// cuBLAS's requirement
|
||||
TORCH_CHECK(
|
||||
m >= n,
|
||||
"torch.linalg.lstsq: only overdetermined systems (input.size(-2) >= input.size(-1)) are allowed on CUDA with cuBLAS backend.");
|
||||
|
||||
// cuBLAS documentation says:
|
||||
// Matrices Aarray[i] should not overlap; otherwise, undefined behavior is expected.
|
||||
// explicitly broadcast the batch dimensions of A
|
||||
IntArrayRef A_batch_sizes(A.sizes().data(), A.dim() - 2);
|
||||
IntArrayRef B_batch_sizes(B.sizes().data(), B.dim() - 2);
|
||||
std::vector<int64_t> expand_batch_portion = at::infer_size(A_batch_sizes, B_batch_sizes);
|
||||
expand_batch_portion.insert(expand_batch_portion.end(), {A.size(-2), A.size(-1)});
|
||||
Tensor A_expanded = A.expand({expand_batch_portion});
|
||||
Tensor A_broadcasted = cloneBatchedColumnMajor(A_expanded);
|
||||
|
||||
// cuBLAS batched gels requires input to be the device array of pointers to device single matrices
|
||||
Tensor A_ptr_array = get_device_pointers<scalar_t>(A_broadcasted);
|
||||
Tensor B_ptr_array = get_device_pointers<scalar_t>(B);
|
||||
auto A_ptr_array_data = reinterpret_cast<scalar_t**>(A_ptr_array.data_ptr());
|
||||
auto B_ptr_array_data = reinterpret_cast<scalar_t**>(B_ptr_array.data_ptr());
|
||||
|
||||
auto infos_data = infos.data_ptr<int>();
|
||||
auto handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
int info;
|
||||
|
||||
at::cuda::blas::gelsBatched<scalar_t>(
|
||||
handle, trans, m, n, nrhs,
|
||||
A_ptr_array_data, lda,
|
||||
B_ptr_array_data, ldb,
|
||||
&info,
|
||||
infos_data,
|
||||
batch_size);
|
||||
|
||||
// negative info indicates that an argument to gelsBatched call is invalid
|
||||
TORCH_INTERNAL_ASSERT(info == 0);
|
||||
#endif
|
||||
}
|
||||
|
||||
// This is a type dispatching helper function for 'apply_gels_batched'
|
||||
void gels_batched_cublas(const Tensor& a, Tensor& b, Tensor& infos) {
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(a.scalar_type(), "gels_batched_cublas", [&]{
|
||||
apply_gels_batched<scalar_t>(a, b, infos);
|
||||
});
|
||||
}
|
||||
|
||||
#if defined(USE_LINALG_SOLVER)
|
||||
|
||||
inline static Tensor column_major_identity_matrix_like(const Tensor& self) {
|
||||
|
aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLibBlas.cpp (new file, 287 lines)
@@ -0,0 +1,287 @@
// Note [BatchLinearAlgebraLib split implementation files]
//
// There are two files that implement the interfaces found in
// BatchLinearAlgebraLib.h
// - BatchLinearAlgebraLib.cpp
// - BatchLinearAlgebraLibBlas.cpp (this file)
//
// In order to support the ROCm build target, the use of cublas and
// cusolver APIs needed to be split into separate source files to
// accommodate the hipify step of the ROCm build process.
//
// To create this current file, the original file
// BatchLinearAlgebraLib.cpp was copied to
// BatchLinearAlgebraLibBlas.cpp, then any functions that used cusolver
// APIs were removed. Similarly, in the original file
// BatchLinearAlgebraLib.cpp, any use of cublas APIs was removed.
// The net result is a split of the BatchLinearAlgebraLib
// implementation files. The original file BatchLinearAlgebraLib.cpp
// contains the full, original git history for both files.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Context.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/cuda/PinnedMemoryAllocator.h>
|
||||
#include <ATen/cuda/CUDABlas.h>
|
||||
#include <ATen/cuda/CUDAEvent.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
#include <ATen/native/LinearAlgebraUtils.h>
|
||||
#include <ATen/native/TransposeType.h>
|
||||
#include <ATen/native/cuda/MiscUtils.h>
|
||||
#include <ATen/native/cuda/linalg/CUDASolver.h>
|
||||
#include <ATen/native/cuda/linalg/BatchLinearAlgebraLib.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#else
|
||||
#include <ATen/ops/arange.h>
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <ATen/ops/nan_to_num.h>
|
||||
#include <ATen/ops/ones.h>
|
||||
#include <ATen/ops/scalar_tensor.h>
|
||||
#include <ATen/ops/where.h>
|
||||
#include <ATen/ops/zeros.h>
|
||||
#endif
|
||||
|
||||
namespace at::native {
|
||||
|
||||
static cublasOperation_t to_cublas(TransposeType trans) {
|
||||
switch (trans) {
|
||||
case TransposeType::NoTranspose: return CUBLAS_OP_N;
|
||||
case TransposeType::Transpose: return CUBLAS_OP_T;
|
||||
case TransposeType::ConjTranspose: return CUBLAS_OP_C;
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
|
||||
}
|
||||
|
||||
// Some cuBLAS and cuSOLVER batched routines require input to be a device array of pointers to device individual matrices
|
||||
// 'input' must be a contiguous tensor
|
||||
template <typename scalar_t>
|
||||
static Tensor get_device_pointers(const Tensor& input) {
|
||||
auto input_data = input.const_data_ptr<scalar_t>();
|
||||
int64_t input_mat_stride = matrixStride(input);
|
||||
|
||||
// cublas/cusolver interface requires 'int'
|
||||
int batch_size = cuda_int_cast(batchCount(input), "batch_size");
|
||||
|
||||
// if batch_size==0, then start=0 and end=0
|
||||
// if input_mat_stride==0, then step=sizeof(scalar_t)
|
||||
return at::arange(
|
||||
/*start=*/reinterpret_cast<int64_t>(input_data),
|
||||
/*end=*/reinterpret_cast<int64_t>(input_data + batch_size * input_mat_stride),
|
||||
/*step=*/static_cast<int64_t>(std::max<int64_t>(input_mat_stride, 1) * sizeof(scalar_t)),
|
||||
input.options().dtype(at::kLong));
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void apply_geqrf_batched(const Tensor& input, const Tensor& tau) {
|
||||
auto batch_size = cuda_int_cast(batchCount(input), "batch_size");
|
||||
auto m = cuda_int_cast(input.size(-2), "m");
|
||||
auto n = cuda_int_cast(input.size(-1), "n");
|
||||
auto lda = std::max<int>(1, m);
|
||||
|
||||
// cuBLAS batched geqrf requires input to be the device array of pointers to device single matrices
|
||||
Tensor input_ptr_array = get_device_pointers<scalar_t>(input);
|
||||
Tensor tau_ptr_array = get_device_pointers<scalar_t>(tau.unsqueeze(-1));
|
||||
auto input_ptr_array_data = reinterpret_cast<scalar_t**>(input_ptr_array.data_ptr());
|
||||
auto tau_ptr_array_data = reinterpret_cast<scalar_t**>(tau_ptr_array.data_ptr());
|
||||
|
||||
int info;
|
||||
auto handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
at::cuda::blas::geqrfBatched(handle, m, n, input_ptr_array_data, lda, tau_ptr_array_data, &info, batch_size);
|
||||
|
||||
// info only indicates wrong arguments to geqrfBatched call
|
||||
// info is a host variable, we can check it without device synchronization
|
||||
TORCH_INTERNAL_ASSERT(info == 0);
|
||||
}
|
||||
|
||||
void geqrf_batched_cublas(const Tensor& input, const Tensor& tau) {
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "geqrf_batched_cuda", [&]{
|
||||
apply_geqrf_batched<scalar_t>(input, tau);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void apply_lu_factor_batched_cublas(const Tensor& A, const Tensor& pivots, const Tensor& infos, bool get_pivots) {
|
||||
// This function just works with square matrices
|
||||
TORCH_INTERNAL_ASSERT(A.size(-2) == A.size(-1));
|
||||
|
||||
auto batch_size = cuda_int_cast(batchCount(A), "batch_size");;
|
||||
auto n = cuda_int_cast(A.size(-2), "n");
|
||||
auto lda = cuda_int_cast(std::max<int>(1, n), "lda");
|
||||
|
||||
auto pivots_data = get_pivots ? pivots.data_ptr<int>() : nullptr;
|
||||
auto infos_data = infos.data_ptr<int>();
|
||||
Tensor a_ptr_array = get_device_pointers<scalar_t>(A);
|
||||
auto a_ptr_array_data = reinterpret_cast<scalar_t**>(a_ptr_array.data_ptr());
|
||||
|
||||
at::cuda::blas::getrfBatched(n, a_ptr_array_data, lda, pivots_data, infos_data, batch_size);
|
||||
}
|
||||
|
||||
void lu_factor_batched_cublas(const Tensor& A, const Tensor& pivots, const Tensor& infos, bool get_pivots) {
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "lu_factor_cublas", [&]{
|
||||
apply_lu_factor_batched_cublas<scalar_t>(A, pivots, infos, get_pivots);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void apply_lu_solve_batched_cublas(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose) {
|
||||
TORCH_INTERNAL_ASSERT(batchCount(LU) == batchCount(B), "batch_size of LU and B must be the same");
|
||||
TORCH_INTERNAL_ASSERT(batchCount(LU) == batchCount(pivots.unsqueeze(-1)), "batch_size of LU and pivots must be the same");
|
||||
const auto trans = to_cublas(transpose);
|
||||
|
||||
auto pivots_data = pivots.data_ptr<int>();
|
||||
auto batch_size = cuda_int_cast(batchCount(LU), "batch_size");;
|
||||
auto m = cuda_int_cast(LU.size(-2), "m");
|
||||
auto nrhs = cuda_int_cast(B.size(-1), "nrhs");
|
||||
auto lda = cuda_int_cast(std::max<int>(1, m), "lda");
|
||||
int info = 0;
|
||||
|
||||
Tensor lu_ptr_array = get_device_pointers<scalar_t>(LU);
|
||||
Tensor b_ptr_array = get_device_pointers<scalar_t>(B);
|
||||
auto lu_ptr_array_data = reinterpret_cast<scalar_t**>(lu_ptr_array.data_ptr());
|
||||
auto b_ptr_array_data = reinterpret_cast<scalar_t**>(b_ptr_array.data_ptr());
|
||||
|
||||
auto handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
at::cuda::blas::getrsBatched(handle, trans, m, nrhs, lu_ptr_array_data,
|
||||
lda, pivots_data, b_ptr_array_data, lda, &info, batch_size);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
|
||||
}
|
||||
|
||||
void lu_solve_batched_cublas(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(LU.scalar_type(), "lu_solve_cublas", [&]{
|
||||
apply_lu_solve_batched_cublas<scalar_t>(LU, pivots, B, trans);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void apply_triangular_solve(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
|
||||
cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
|
||||
const auto trans = to_cublas(transpose);
|
||||
cublasSideMode_t side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
|
||||
cublasDiagType_t diag = unitriangular ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
|
||||
|
||||
auto A_data = A.data_ptr<scalar_t>();
|
||||
auto B_data = B.data_ptr<scalar_t>();
|
||||
auto A_mat_stride = matrixStride(A);
|
||||
auto B_mat_stride = matrixStride(B);
|
||||
auto batch_size = batchCount(A);
|
||||
// This allows to pass rectangular A and B when left = True
|
||||
auto m = cuda_int_cast(left ? A.size(-1) : B.size(-2), "m");
|
||||
auto n = cuda_int_cast(B.size(-1), "n");
|
||||
auto lda = std::max<int>(1, cuda_int_cast(A.size(-2), "lda"));
|
||||
auto ldb = std::max<int>(1, cuda_int_cast(B.size(-2), "ldb"));
|
||||
|
||||
auto alpha = scalar_t{1};
|
||||
|
||||
for (decltype(batch_size) i = 0; i < batch_size; i++) {
|
||||
scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
|
||||
scalar_t* B_working_ptr = &B_data[i * B_mat_stride];
|
||||
auto handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
at::cuda::blas::trsm(handle, side, uplo, trans, diag, m, n, &alpha, A_working_ptr, lda, B_working_ptr, ldb);
|
||||
}
|
||||
}
|
||||
|
||||
void triangular_solve_cublas(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cuda", [&]{
|
||||
apply_triangular_solve<scalar_t>(A, B, left, upper, transpose, unitriangular);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void apply_triangular_solve_batched(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
|
||||
cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
|
||||
const auto trans = to_cublas(transpose);
|
||||
cublasSideMode_t side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
|
||||
cublasDiagType_t diag = unitriangular ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
|
||||
|
||||
auto batch_size = cuda_int_cast(batchCount(A), "batch_size");
|
||||
// This allows to pass rectangular A and B when left = True
|
||||
auto m = cuda_int_cast(left ? A.size(-1) : B.size(-2), "m");
|
||||
auto n = cuda_int_cast(B.size(-1), "n");
|
||||
auto lda = std::max<int>(1, cuda_int_cast(A.size(-2), "lda"));
|
||||
auto ldb = std::max<int>(1, cuda_int_cast(B.size(-2), "ldb"));
|
||||
|
||||
auto alpha = scalar_t{1};
|
||||
|
||||
// cuBLAS batched trsm requires input to be the device array of pointers to device single matrices
|
||||
Tensor A_ptr_array = get_device_pointers<scalar_t>(A);
|
||||
Tensor B_ptr_array = get_device_pointers<scalar_t>(B);
|
||||
auto A_ptr_array_data = reinterpret_cast<scalar_t**>(A_ptr_array.data_ptr());
|
||||
auto B_ptr_array_data = reinterpret_cast<scalar_t**>(B_ptr_array.data_ptr());
|
||||
|
||||
auto handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
at::cuda::blas::trsmBatched(handle, side, uplo, trans, diag, m, n, &alpha, A_ptr_array_data, lda, B_ptr_array_data, ldb, batch_size);
|
||||
}
|
||||
|
||||
void triangular_solve_batched_cublas(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cuda", [&]{
|
||||
apply_triangular_solve_batched<scalar_t>(A, B, left, upper, transpose, unitriangular);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void apply_gels_batched(const Tensor& A, Tensor& B, Tensor& infos) {
|
||||
auto trans = CUBLAS_OP_N;
|
||||
auto m = cuda_int_cast(A.size(-2), "m");
|
||||
auto n = cuda_int_cast(A.size(-1), "n");
|
||||
|
||||
auto nrhs = cuda_int_cast(B.size(-1), "nrhs");
|
||||
// cuBLAS from cuda10 and older doesn't work with nrhs == 0 (cuda11 works)
|
||||
// so we need to put this early return
|
||||
if (nrhs == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto batch_size = cuda_int_cast(batchCount(B), "batch_size");
|
||||
auto lda = std::max<int>(1, m);
|
||||
auto ldb = std::max<int>(1, m);
|
||||
|
||||
// cuBLAS's requirement
|
||||
TORCH_CHECK(
|
||||
m >= n,
|
||||
"torch.linalg.lstsq: only overdetermined systems (input.size(-2) >= input.size(-1)) are allowed on CUDA with cuBLAS backend.");
|
||||
|
||||
// cuBLAS documentation says:
|
||||
// Matrices Aarray[i] should not overlap; otherwise, undefined behavior is expected.
|
||||
// explicitly broadcast the batch dimensions of A
|
||||
IntArrayRef A_batch_sizes(A.sizes().data(), A.dim() - 2);
|
||||
IntArrayRef B_batch_sizes(B.sizes().data(), B.dim() - 2);
|
||||
std::vector<int64_t> expand_batch_portion = at::infer_size(A_batch_sizes, B_batch_sizes);
|
||||
expand_batch_portion.insert(expand_batch_portion.end(), {A.size(-2), A.size(-1)});
|
||||
Tensor A_expanded = A.expand({expand_batch_portion});
|
||||
Tensor A_broadcasted = cloneBatchedColumnMajor(A_expanded);
|
||||
|
||||
// cuBLAS batched gels requires input to be the device array of pointers to device single matrices
|
||||
Tensor A_ptr_array = get_device_pointers<scalar_t>(A_broadcasted);
|
||||
Tensor B_ptr_array = get_device_pointers<scalar_t>(B);
|
||||
auto A_ptr_array_data = reinterpret_cast<scalar_t**>(A_ptr_array.data_ptr());
|
||||
auto B_ptr_array_data = reinterpret_cast<scalar_t**>(B_ptr_array.data_ptr());
|
||||
|
||||
auto infos_data = infos.data_ptr<int>();
|
||||
auto handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
int info;
|
||||
|
||||
at::cuda::blas::gelsBatched<scalar_t>(
|
||||
handle, trans, m, n, nrhs,
|
||||
A_ptr_array_data, lda,
|
||||
B_ptr_array_data, ldb,
|
||||
&info,
|
||||
infos_data,
|
||||
batch_size);
|
||||
|
||||
// negative info indicates that an argument to gelsBatched call is invalid
|
||||
TORCH_INTERNAL_ASSERT(info == 0);
|
||||
}
|
||||
|
||||
// This is a type dispatching helper function for 'apply_gels_batched'
|
||||
void gels_batched_cublas(const Tensor& a, Tensor& b, Tensor& infos) {
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(a.scalar_type(), "gels_batched_cublas", [&]{
|
||||
apply_gels_batched<scalar_t>(a, b, infos);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace at::native
|
@ -186,28 +186,15 @@ const char* cublasGetErrorString(cublasStatus_t error) {
|
||||
return "CUBLAS_STATUS_ARCH_MISMATCH";
|
||||
case CUBLAS_STATUS_INTERNAL_ERROR:
|
||||
return "CUBLAS_STATUS_INTERNAL_ERROR";
|
||||
#if !defined(USE_ROCM)
|
||||
case CUBLAS_STATUS_MAPPING_ERROR:
|
||||
return "CUBLAS_STATUS_MAPPING_ERROR";
|
||||
case CUBLAS_STATUS_EXECUTION_FAILED:
|
||||
return "CUBLAS_STATUS_EXECUTION_FAILED";
|
||||
case CUBLAS_STATUS_NOT_SUPPORTED:
|
||||
return "CUBLAS_STATUS_NOT_SUPPORTED";
|
||||
#if !defined(USE_ROCM)
|
||||
case CUBLAS_STATUS_LICENSE_ERROR:
|
||||
return "CUBLAS_STATUS_LICENSE_ERROR";
|
||||
#else
|
||||
case rocblas_status_invalid_size:
|
||||
return "rocblas_status_invalid_size";
|
||||
case rocblas_status_perf_degraded:
|
||||
return "rocblas_status_perf_degraded";
|
||||
case rocblas_status_size_query_mismatch:
|
||||
return "rocblas_status_size_query_mismatch";
|
||||
case rocblas_status_size_increased:
|
||||
return "rocblas_status_size_increased";
|
||||
case rocblas_status_size_unchanged:
|
||||
return "rocblas_status_size_unchanged";
|
||||
default:
|
||||
return "unrecognized_rocblas_error";
|
||||
#endif
|
||||
}
|
||||
// To suppress compiler warning.
|
||||
|
@@ -9,12 +9,13 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/utils/conversions.h"

#ifdef __HIPCC__
#if TORCH_HIP_VERSION < 210
// rocblas doesn't fully support fp16 yet
#define ROCBLAS_FP16 0
#endif
#endif
#if defined(USE_ROCM)
// until we use hipblas v2
// hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
// however hipblas v1 is still using its custom type
#define HIP_R_16F HIPBLAS_R_16F
#define HIP_R_32F HIPBLAS_R_32F
#endif // USE_ROCM

namespace caffe2 {

@ -122,30 +123,6 @@ void device_reduce<at::Half>(
|
||||
(void)N; // Suppress unused variable warning
|
||||
(void)buffer; // Suppress unused variable warning
|
||||
(void)context; // Suppress unused variable warning
|
||||
#if TORCH_HIP_VERSION >= 210
|
||||
auto buffer_size = 1;
|
||||
|
||||
if (buffer->numel() != buffer_size) {
|
||||
buffer->Resize(buffer_size);
|
||||
|
||||
math::Set<at::Half, CUDAContext>(
|
||||
N,
|
||||
convert::To<float, at::Half>(1.),
|
||||
buffer->template mutable_data<at::Half>(),
|
||||
context);
|
||||
}
|
||||
|
||||
CUBLAS_ENFORCE(rocblas_hdot(
|
||||
context->cublas_handle(),
|
||||
N,
|
||||
reinterpret_cast<const rocblas_half*>(in),
|
||||
1,
|
||||
reinterpret_cast<const rocblas_half*>(buffer->data<at::Half>()),
|
||||
0,
|
||||
reinterpret_cast<rocblas_half*>(out)));
|
||||
#elif TORCH_HIP_VERSION < 210
|
||||
CAFFE_THROW("HIP rocblas doesn't fully support fp16 device_reduce yet.");
|
||||
#else
|
||||
auto buffer_size = 1;
|
||||
|
||||
if (buffer->numel() != buffer_size) {
|
||||
@ -170,7 +147,6 @@ void device_reduce<at::Half>(
|
||||
out,
|
||||
CUDA_R_16F,
|
||||
CUDA_R_32F));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, int BLOCK_THREADS>
|
||||
|
@ -39,9 +39,15 @@
|
||||
#endif // USE_ROCM
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
using CUBLAS_HALF_TYPE = rocblas_half;
|
||||
#define CUBLAS_HALF_TYPE hipblasHalf
|
||||
#define HIPBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
|
||||
// until we use hipblas v2
|
||||
// hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
|
||||
// however hipblas v1 is still using its custom type
|
||||
#define HIP_R_16F HIPBLAS_R_16F
|
||||
#define HIP_R_32F HIPBLAS_R_32F
|
||||
#else // __HIP_PLATFORM_HCC
|
||||
using CUBLAS_HALF_TYPE = __half;
|
||||
#define CUBLAS_HALF_TYPE __half
|
||||
#endif // __HIP_PLATFORM_HCC
|
||||
|
||||
#include "caffe2/utils/math/utils.h"
|
||||
@ -608,12 +614,12 @@ CAFFE2_CUDA_EXPORT void Gemm<at::Half, CUDAContext>(
|
||||
CUBLAS_ENFORCE(cublasSetPointerMode(
|
||||
context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
|
||||
#if defined(USE_ROCM)
|
||||
// rocblas doesn't support cublasSgemmEx type API yet.
|
||||
// It has more general rocblas_gemm_ex API which is more close to
|
||||
// cublasGemmEx rocblas_gemm_ex does D = alpha*op( A )*op( B ) + beta*C,
|
||||
// whereas cublasgemmEx does C = alpha*op( A )*op( B ) + beta*C
|
||||
ROCBLAS_ENFORCE(rocblas_gemm_ex(
|
||||
context->rocblashandle(),
|
||||
// hipblas doesn't support hipblasSgemmEx type API.
|
||||
// It has more general hipblasGemmEx API which is more close to cublasGemmEx.
|
||||
// hipblasGemmEx does D = alpha*op( A )*op( B ) + beta*C,
|
||||
// whereas cublasSgemmEx does C = alpha*op( A )*op( B ) + beta*C
|
||||
HIPBLAS_ENFORCE(hipblasGemmEx(
|
||||
context->hipblas_handle(),
|
||||
cu_trans_B,
|
||||
cu_trans_A,
|
||||
N,
|
||||
@ -621,22 +627,17 @@ CAFFE2_CUDA_EXPORT void Gemm<at::Half, CUDAContext>(
|
||||
K,
|
||||
&alpha,
|
||||
B,
|
||||
rocblas_datatype_f16_r,
|
||||
HIPBLAS_R_16F,
|
||||
ldb,
|
||||
A,
|
||||
rocblas_datatype_f16_r,
|
||||
HIPBLAS_R_16F,
|
||||
lda,
|
||||
&beta,
|
||||
C,
|
||||
rocblas_datatype_f16_r,
|
||||
HIPBLAS_R_16F,
|
||||
N,
|
||||
C, // D
|
||||
rocblas_datatype_f16_r, // D type
|
||||
N, // ldd
|
||||
rocblas_datatype_f32_r, // compute type
|
||||
rocblas_gemm_algo_standard, // rocblas_gemm_algo
|
||||
0, // solution index, reserved for future use
|
||||
0)); // flags, reserved for future use
|
||||
HIPBLAS_R_32F, // compute type
|
||||
HIPBLAS_GEMM_DEFAULT));
|
||||
#else
|
||||
CUBLAS_ENFORCE(cublasSgemmEx(
|
||||
context->cublas_handle(),
|
||||
@ -900,13 +901,13 @@ CAFFE2_CUDA_EXPORT void GemmBatched<at::Half, CUDAContext>(
|
||||
N,
|
||||
M,
|
||||
K,
|
||||
&alpha_fp16,
|
||||
B_device.data().get(),
|
||||
reinterpret_cast<const CUBLAS_HALF_TYPE*>(&alpha_fp16),
|
||||
reinterpret_cast<const CUBLAS_HALF_TYPE* const*>(B_device.data().get()),
|
||||
ldb,
|
||||
A_device.data().get(),
|
||||
reinterpret_cast<const CUBLAS_HALF_TYPE* const*>(A_device.data().get()),
|
||||
lda,
|
||||
&beta_fp16,
|
||||
C_device.data().get(),
|
||||
reinterpret_cast<const CUBLAS_HALF_TYPE*>(&beta_fp16),
|
||||
reinterpret_cast<CUBLAS_HALF_TYPE* const*>(C_device.data().get()),
|
||||
ldc,
|
||||
batch_size));
|
||||
} else {
|
||||
@ -944,40 +945,6 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched<at::Half, CUDAContext>(
if (math_type == TensorProto_DataType_FLOAT) {
CUBLAS_ENFORCE(cublasSetPointerMode(
context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
#if defined(USE_ROCM)
// D[i*stride_d] = alpha*op(A[i*stride_a])*op(B[i*stride_b]) +
// beta*C[i*stride_c], for i in [0,batch_count-1]
ROCBLAS_ENFORCE(rocblas_gemm_strided_batched_ex(
context->rocblashandle(),
cu_trans_B,
cu_trans_A,
N,
M,
K,
&alpha,
B,
rocblas_datatype_f16_r,
ldb,
B_stride,
A,
rocblas_datatype_f16_r,
lda,
A_stride,
&beta,
C,
rocblas_datatype_f16_r,
ldc,
C_stride,
C, // D
rocblas_datatype_f16_r, // D type
ldc, // ldd
C_stride, // D stride
batch_size,
rocblas_datatype_f32_r, // compute type
rocblas_gemm_algo_standard, // rocblas_gemm_algo
0, // solution index, reserved for future use
0)); // flags, reserved for future use
#else
CUBLAS_ENFORCE(cublasGemmStridedBatchedEx(
context->cublas_handle(),
cu_trans_B,
@ -1002,7 +969,6 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched<at::Half, CUDAContext>(
batch_size,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
#endif // USE_ROCM
} else if (math_type == TensorProto_DataType_FLOAT16) {
// Convert alpha, beta from float -> __half
const __half alpha_fp16 = at::Half(alpha);
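With the rocblas_gemm_strided_batched_ex branch above removed, the ROCm build follows the same cublasGemmStridedBatchedEx call as CUDA, which hipify rewrites to hipblasGemmStridedBatchedEx. The sketch below is not part of this patch; it only shows the shape of that hipified call under the hipBLAS v1 API. half_gemm_strided_batched and its parameters are illustrative names, not caffe2 functions, and the no-transpose case is assumed.

// Illustrative helper, assuming hipBLAS v1; not caffe2 code.
#include <hip/hip_fp16.h>
#include <hipblas/hipblas.h>

hipblasStatus_t half_gemm_strided_batched(
    hipblasHandle_t handle,
    int M, int N, int K,
    float alpha, float beta,
    const __half* A, int lda, long long A_stride,
    const __half* B, int ldb, long long B_stride,
    __half* C, int ldc, long long C_stride,
    int batch_size) {
  // Same row-major trick as the single GEMM above: pass (B, A) with
  // dimensions (N, M, K), FP16 storage, FP32 accumulation.
  return hipblasGemmStridedBatchedEx(
      handle,
      HIPBLAS_OP_N, HIPBLAS_OP_N,
      N, M, K,
      &alpha,
      B, HIPBLAS_R_16F, ldb, B_stride,
      A, HIPBLAS_R_16F, lda, A_stride,
      &beta,
      C, HIPBLAS_R_16F, ldc, C_stride,
      batch_size,
      HIPBLAS_R_32F,
      HIPBLAS_GEMM_DEFAULT);
}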
@ -1089,35 +1055,30 @@ CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext>(
CUBLAS_ENFORCE(cublasSetPointerMode(
context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
#if defined(USE_ROCM)
// rocblas doesn't support cublasSgemmEx type API yet.
// It has more general rocblas_gemm_ex API which is more close to
// cublasGemmEx rocblas_gemm_ex does D = alpha*op( A )*op( B ) + beta*C,
// whereas cublasgemmEx does C = alpha*op( A )*op( B ) + beta*C
ROCBLAS_ENFORCE(rocblas_gemm_ex(
context->rocblashandle(),
// hipblas doesn't support hipblasSgemmEx type API.
// It has more general hipblasGemmEx API which is more close to cublasGemmEx.
// hipblasGemmEx does D = alpha*op( A )*op( B ) + beta*C,
// whereas cublasSgemmEx does C = alpha*op( A )*op( B ) + beta*C
HIPBLAS_ENFORCE(hipblasGemmEx(
context->hipblas_handle(),
cu_trans_A,
rocblas_operation_none,
HIPBLAS_OP_N,
m,
1,
k,
&alpha,
A,
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
lda,
x,
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
k,
&beta,
y,
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
ldc,
y, // D
rocblas_datatype_f16_r, // D type
ldc, // ldd
rocblas_datatype_f32_r, // compute type
rocblas_gemm_algo_standard, // rocblas_gemm_algo
0, // solution index, reserved for future use
0)); // flags, reserved for future use
HIPBLAS_R_32F, // compute type
HIPBLAS_GEMM_DEFAULT));
#else
CUBLAS_ENFORCE(cublasSgemmEx(
context->cublas_handle(),
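The FP16 GEMV above is expressed as a GEMM with a single output column (the middle dimension is 1), which is why the same hipblasGemmEx entry point can serve both. Below is a small self-contained sketch of that trick, not taken from caffe2: gemv_via_gemm_ex is an illustrative helper that computes y = alpha * A * x + beta * y for a row-major m x k matrix A, assuming hipBLAS v1.

// Illustrative helper, assuming hipBLAS v1; not caffe2 code.
#include <hip/hip_fp16.h>
#include <hipblas/hipblas.h>

hipblasStatus_t gemv_via_gemm_ex(
    hipblasHandle_t handle,
    int m, int k,              // A is row-major m x k, x has k elements, y has m
    float alpha, float beta,
    const __half* A,
    const __half* x,
    __half* y) {
  // A stored row-major (m x k) is a column-major k x m matrix, so op = T
  // recovers A; a single output column turns the GEMM into a GEMV.
  return hipblasGemmEx(
      handle,
      HIPBLAS_OP_T,            // transpose the column-major view back to A
      HIPBLAS_OP_N,
      m, 1, k,
      &alpha,
      A, HIPBLAS_R_16F, k,     // lda = k for the column-major k x m view
      x, HIPBLAS_R_16F, k,
      &beta,
      y, HIPBLAS_R_16F, m,
      HIPBLAS_R_32F,
      HIPBLAS_GEMM_DEFAULT);
}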
@ -1635,20 +1596,6 @@ CAFFE2_CUDA_EXPORT void Dot<at::Half, CUDAContext>(
const at::Half* b,
at::Half* y,
CUDAContext* context) {
#if defined(USE_ROCM) && (TORCH_HIP_VERSION < 210)
CAFFE_THROW("HIP currently does not support FP16 completely yet.");
#elif defined(USE_ROCM) && (TORCH_HIP_VERSION >= 210)
CUBLAS_ENFORCE(cublasSetPointerMode(
context->cublas_handle(), CUBLAS_POINTER_MODE_DEVICE));
CUBLAS_ENFORCE(rocblas_hdot(
context->cublas_handle(),
n,
reinterpret_cast<const rocblas_half*>(a),
1,
reinterpret_cast<const rocblas_half*>(b),
1,
reinterpret_cast<rocblas_half*>(y)));
#else
// execute with 32-bit math
CUBLAS_ENFORCE(cublasSetPointerMode(
context->cublas_handle(), CUBLAS_POINTER_MODE_DEVICE));
@ -1664,7 +1611,6 @@ CAFFE2_CUDA_EXPORT void Dot<at::Half, CUDAContext>(
y,
CUDA_R_16F,
CUDA_R_32F));
#endif
}

// A previous version of caffe2 used Thrust but it turns out that thrust
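Dropping the rocblas_hdot branch above means the ROCm build also takes the "execute with 32-bit math" path, i.e. cublasDotEx, which hipify maps to hipblasDotEx. The sketch below is not part of this patch and assumes hipBLAS v1 exposes hipblasDotEx with the cublasDotEx-style signature; half_dot is an illustrative helper, and the result pointer must live on the device because the pointer mode is set to device.

// Illustrative helper, assuming hipBLAS v1 provides hipblasDotEx; not caffe2 code.
#include <hip/hip_fp16.h>
#include <hipblas/hipblas.h>

hipblasStatus_t half_dot(
    hipblasHandle_t handle,
    int n,
    const __half* a,   // device pointer, n elements
    const __half* b,   // device pointer, n elements
    __half* result) {  // device pointer, 1 element
  hipblasStatus_t status =
      hipblasSetPointerMode(handle, HIPBLAS_POINTER_MODE_DEVICE);
  if (status != HIPBLAS_STATUS_SUCCESS) {
    return status;
  }
  return hipblasDotEx(
      handle,
      n,
      a, HIPBLAS_R_16F, 1,
      b, HIPBLAS_R_16F, 1,
      result, HIPBLAS_R_16F,
      HIPBLAS_R_32F);  // accumulate in FP32
}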
@ -1298,17 +1298,9 @@ if(USE_ROCM)
set(Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
${PYTORCH_HIP_HCC_LIBRARIES} ${PYTORCH_MIOPEN_LIBRARIES} ${hipcub_LIBRARIES} ${ROCM_HIPRTC_LIB} ${ROCM_ROCTX_LIB})

# Note [rocblas & rocfft cmake bug]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# TODO: There is a bug in rocblas's & rocfft's cmake files that exports the wrong targets name in ${rocblas_LIBRARIES}
# If you get this wrong, you'll get a complaint like 'ld: cannot find -lrocblas-targets'
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "4.1.0")
list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
roc::rocblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsolver)
else()
list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
roc::rocblas roc::rocfft hip::hiprand roc::hipsparse)
endif()
list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
roc::hipblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsolver)

else()
caffe2_update_option(USE_ROCM OFF)
endif()
@ -1319,15 +1311,10 @@ if(USE_ROCM AND ROCM_VERSION_DEV VERSION_LESS "5.2.0")
# We check again for USE_ROCM because it might have been set to OFF
# in the if above
include_directories(SYSTEM ${HIP_PATH}/include)
include_directories(SYSTEM ${ROCBLAS_PATH}/include)
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "4.1.0")
include_directories(SYSTEM ${HIPFFT_PATH}/include)
else()
include_directories(SYSTEM ${ROCFFT_PATH}/include)
endif()
include_directories(SYSTEM ${HIPBLAS_PATH}/include)
include_directories(SYSTEM ${HIPFFT_PATH}/include)
include_directories(SYSTEM ${HIPSPARSE_PATH}/include)
include_directories(SYSTEM ${HIPRAND_PATH}/include)
include_directories(SYSTEM ${ROCRAND_PATH}/include)
include_directories(SYSTEM ${THRUST_PATH})
endif()
@ -42,6 +42,13 @@ else()
set(ROCBLAS_PATH $ENV{ROCBLAS_PATH})
endif()

# HIPBLAS_PATH
if(NOT DEFINED ENV{HIPBLAS_PATH})
set(HIPBLAS_PATH ${ROCM_PATH}/hipblas)
else()
set(HIPBLAS_PATH $ENV{HIPBLAS_PATH})
endif()

# ROCFFT_PATH
if(NOT DEFINED ENV{ROCFFT_PATH})
set(ROCFFT_PATH ${ROCM_PATH}/rocfft)
@ -246,6 +253,7 @@ if(HIP_FOUND)
set(rocrand_DIR ${ROCM_PATH}/lib/cmake/rocrand)
set(hiprand_DIR ${ROCM_PATH}/lib/cmake/hiprand)
set(rocblas_DIR ${ROCM_PATH}/lib/cmake/rocblas)
set(hipblas_DIR ${ROCM_PATH}/lib/cmake/hipblas)
set(miopen_DIR ${ROCM_PATH}/lib/cmake/miopen)
set(rocfft_DIR ${ROCM_PATH}/lib/cmake/rocfft)
set(hipfft_DIR ${ROCM_PATH}/lib/cmake/hipfft)
@ -263,6 +271,7 @@ if(HIP_FOUND)
set(rocrand_DIR ${ROCRAND_PATH}/lib/cmake/rocrand)
set(hiprand_DIR ${HIPRAND_PATH}/lib/cmake/hiprand)
set(rocblas_DIR ${ROCBLAS_PATH}/lib/cmake/rocblas)
set(hipblas_DIR ${HIPBLAS_PATH}/lib/cmake/hipblas)
set(miopen_DIR ${MIOPEN_PATH}/lib/cmake/miopen)
set(rocfft_DIR ${ROCFFT_PATH}/lib/cmake/rocfft)
set(hipfft_DIR ${HIPFFT_PATH}/lib/cmake/hipfft)
@ -280,6 +289,7 @@ if(HIP_FOUND)
find_package_and_print_version(rocrand REQUIRED)
find_package_and_print_version(hiprand REQUIRED)
find_package_and_print_version(rocblas REQUIRED)
find_package_and_print_version(hipblas REQUIRED)
find_package_and_print_version(miopen REQUIRED)
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "4.1.0")
find_package_and_print_version(hipfft REQUIRED)

File diff suppressed because it is too large
@ -643,7 +643,12 @@ def is_cusparse_file(rel_filepath):

def is_special_file(rel_filepath):
if is_pytorch_file(rel_filepath):
return ("sparse" in rel_filepath.lower()) or ("linalg" in rel_filepath.lower())
if "sparse" in rel_filepath.lower():
return True
elif "linalg" in rel_filepath.lower():
if "batchlinearalgebralibblas" in rel_filepath.lower():
return False # don't use "special" mappings for this specific linalg cublas file
return True
return False

def is_caffe2_gpu_file(rel_filepath):
@ -745,7 +750,7 @@ for mapping in CUDA_TO_HIP_MAPPINGS:
PYTORCH_SPECIAL_MAP[src] = dst
else:
PYTORCH_MAP[src] = dst
if constants.API_PYTORCH not in meta_data:
if constants.API_PYTORCH not in meta_data and constants.API_SPECIAL not in meta_data:
CAFFE2_TRIE.add(src)
CAFFE2_MAP[src] = dst
RE_CAFFE2_PREPROCESSOR = re.compile(CAFFE2_TRIE.pattern())