Revert "[ROCm] add hipblaslt support (#114329)"
This reverts commit b062ea38039234c80404a8f5f4d5a93c4cb9832d. Reverted https://github.com/pytorch/pytorch/pull/114329 on behalf of https://github.com/jeanschmidt due to Reverting due to inconsistencies on internal diff ([comment](https://github.com/pytorch/pytorch/pull/114329#issuecomment-1861933267))
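Background for the diff below (not part of the commit message): the reverted change mostly revolved around how ATen obtains a cuBLASLt handle. On CUDA, cuBLAS documents that a cublasHandle_t encapsulates a cuBLASLt handle, so the pre-hipblaslt code simply casts one to the other; hipBLASLt gives no such guarantee, which is why the reverted change added a dedicated ROCm handle pool (see the removed Note [hipblaslt handles] further down). A minimal C++ sketch of the two strategies, written against the public cuBLAS/cuBLASLt API only; the pooled helper actually used by PyTorch appears in the diff, this is just an illustration:

```cpp
// Illustrative sketch, not PyTorch source: the two ways an Lt handle can be
// obtained, corresponding to the code paths touched by this revert.
#include <cublas_v2.h>
#include <cublasLt.h>

cublasLtHandle_t lt_handle_from_cublas(cublasHandle_t handle) {
  // CUDA path that the revert restores everywhere: cuBLAS documents that a
  // cublasHandle_t encapsulates a cuBLASLt handle, so a plain cast is valid.
  return reinterpret_cast<cublasLtHandle_t>(handle);
}

cublasLtHandle_t lt_handle_standalone() {
  // hipBLASLt has no such guarantee, which is why the reverted change created
  // a separate handle (pooled per thread in the removed
  // getCurrentCUDABlasLtHandle()). Shown here without pooling or error
  // checking for brevity; the real code also never destroys the handle.
  cublasLtHandle_t handle = nullptr;
  cublasLtCreate(&handle);
  return handle;
}
```

The revert keeps only the cast-based path, which is why every `getCurrentCUDABlasLtHandle()` call site in the hunks below goes back to `reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle())`.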
@@ -11,9 +11,14 @@
 #include <c10/macros/Export.h>
 #include <c10/util/irange.h>
 
+// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
+// added bf16 support
+#if !defined(USE_ROCM) && !defined(_MSC_VER)
+#include <cublasLt.h>
+#endif
+
 #ifdef USE_ROCM
 // until hipblas has an API to accept flags, we must use rocblas here
-#include <hipblas/hipblas.h>
 #include <rocblas/rocblas.h>
 #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
 #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
@@ -59,7 +64,6 @@ static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
 // until we use hiblas v2
 // hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
 // however hipblas v1 is still using its custom type
-#ifndef HIPBLAS_V2
 #define HIP_R_16F HIPBLAS_R_16F
 #define HIP_R_32F HIPBLAS_R_32F
 #define HIP_R_64F HIPBLAS_R_64F
@@ -77,7 +81,6 @@ static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
 #define HIP_R_16BF HIPBLAS_R_16B
 #define HIP_C_16BF HIPBLAS_C_16B
 #endif
-#endif
 
 #define CUDABLAS_POSINT_CHECK(FD, X) \
   TORCH_CHECK( \
@@ -164,7 +167,6 @@ static void _cublasAdjustLdLevel3(
   }
 }
 
-#ifndef USE_ROCM
 uint32_t _getAlignment(uintptr_t address) {
   // alignment are in bytes
   uint32_t alignment = 256;
@@ -174,25 +176,18 @@ uint32_t _getAlignment(uintptr_t address) {
     }
   }
 }
-#endif
 
 static size_t _parseChosenWorkspaceSize() {
   const char * val = getenv("CUBLASLT_WORKSPACE_SIZE");
-#ifdef USE_ROCM
-  if (!val) {
-    // accept either env var
-    val = getenv("HIPBLASLT_WORKSPACE_SIZE");
-  }
-#endif
   size_t workspace_size = 1024; /* default size in KiB according to #73328 */
   if (val) {
     try {
       workspace_size = std::stoi(val);
     } catch(std::invalid_argument const& e) {
-      TORCH_WARN("invalid CUBLASLT_WORKSPACE_SIZE,",
+      TORCH_WARN("invalid CUBLAS_LT_WORKSPACE_SIZE,",
                  " using default workspace size of ", workspace_size, " bytes.");
     } catch(std::out_of_range const& e) {
-      TORCH_WARN("CUBLASLT_WORKSPACE_SIZE out of range,",
+      TORCH_WARN("CUBLAS_LT_WORKSPACE_SIZE out of range,",
                  " using default workspace size of ", workspace_size, " bytes.");
     }
   }
@@ -346,19 +341,12 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
   const float fbeta = beta;
   _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
 
-#if defined(USE_ROCM) && ROCM_VERSION >= 60000
-  auto compute_type = CUBLAS_COMPUTE_32F;
-#else
-  auto compute_type = CUDA_R_32F;
-#endif
   TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedEx(handle,
                                   opa, opb, (int)m, (int)n, (int)k,
                                   (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea,
                                   b, CUDA_R_16BF, (int)ldb, strideb,
                                   (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
-                                  (int)num_batches,
-                                  compute_type,
-                                  CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                                  (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 }
 
 template <>
@@ -528,11 +516,6 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
   if (!at::globalContext().allowBF16ReductionCuBLAS()) {
     cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
   }
-#endif
-#if defined(USE_ROCM) && ROCM_VERSION >= 60000
-  auto compute_type = CUBLAS_COMPUTE_32F;
-#else
-  auto compute_type = CUDA_R_32F;
 #endif
   TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
   TORCH_CUDABLAS_CHECK(cublasGemmEx(
@@ -553,62 +536,12 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
       c,
       CUDA_R_16BF,
       ldc,
-      compute_type,
+      CUDA_R_32F,
       CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
 }
 
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if !defined(USE_ROCM) && !defined(_MSC_VER)
 
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000
-// only for rocm 5.7 where we first supported hipblaslt, it was difficult
-// to hipify correctly without this change.
-#define hipDataType hipblasDatatype_t
-#endif
-
-// hipblaslt custom types were a temporary work-around
-#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && HIPBLASLT_CUSTOM_DATA_TYPE
-hipblasltDatatype_t hipToLt(hipDataType type) {
-  switch (type) {
-    case HIP_R_32F: return HIPBLASLT_R_32F;
-    case HIP_R_64F: return HIPBLASLT_R_64F;
-    case HIP_R_16F: return HIPBLASLT_R_16F;
-    case HIP_R_8I: return HIPBLASLT_R_8I;
-    case HIP_C_32F: return HIPBLASLT_C_32F;
-    case HIP_C_64F: return HIPBLASLT_C_64F;
-    case HIP_C_16F: return HIPBLASLT_C_16F;
-    case HIP_C_8I: return HIPBLASLT_C_8I;
-    case HIP_R_8U: return HIPBLASLT_R_8U;
-    case HIP_C_8U: return HIPBLASLT_C_8U;
-    case HIP_R_32I: return HIPBLASLT_R_32I;
-    case HIP_C_32I: return HIPBLASLT_C_32I;
-    case HIP_R_32U: return HIPBLASLT_R_32U;
-    case HIP_C_32U: return HIPBLASLT_C_32U;
-    case HIP_R_16BF: return HIPBLASLT_R_16B;
-    case HIP_C_16BF: return HIPBLASLT_C_16B;
-    default: TORCH_CHECK(false);
-  }
-}
-#define HIPTOLT(type) hipToLt(type)
-#else
-#define HIPTOLT(type) type
-#endif
-
-#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && HIPBLASLT_CUSTOM_COMPUTE_TYPE
-hipblasLtComputeType_t hipblasToLt(hipblasComputeType_t type) {
-  switch (type) {
-    case HIPBLAS_COMPUTE_32F: return HIPBLASLT_COMPUTE_F32;
-    case HIPBLAS_COMPUTE_32F_FAST_16F: return HIPBLASLT_COMPUTE_F32_FAST_F16;
-    case HIPBLAS_COMPUTE_32F_FAST_TF32: return HIPBLASLT_COMPUTE_F32_FAST_XF32;
-    case HIPBLAS_COMPUTE_64F: return HIPBLASLT_COMPUTE_F64;
-    case HIPBLAS_COMPUTE_32I: return HIPBLASLT_COMPUTE_I32;
-    default: TORCH_CHECK(false);
-  }
-}
-#define HIPCOMPTOLT(type) hipblasToLt(type)
-#else
-#define HIPCOMPTOLT(type) type
-#endif
-
 namespace {
 // Following the pattern of CuSparseDescriptor
@@ -647,7 +580,7 @@ class CuBlasLtMatmulDescriptor : public CuBlasLtDescriptor<
       cudaDataType_t scale_type) {
     cublasLtMatmulDesc_t raw_descriptor = nullptr;
     TORCH_CUDABLAS_CHECK(
-        cublasLtMatmulDescCreate(&raw_descriptor, HIPCOMPTOLT(compute_type), HIPTOLT(scale_type)));
+        cublasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
     descriptor_.reset(raw_descriptor);
   }
   template <typename T>
@@ -668,7 +601,7 @@ class CuBlasLtMatrixLayout : public CuBlasLtDescriptor<
       bool t = false) {
     cublasLtMatrixLayout_t raw_descriptor = nullptr;
     TORCH_CUDABLAS_CHECK(
-        cublasLtMatrixLayoutCreate(&raw_descriptor, HIPTOLT(type), t ? cols : rows, t ? rows : cols, ld));
+        cublasLtMatrixLayoutCreate(&raw_descriptor, type, t ? cols : rows, t ? rows : cols, ld));
     descriptor_.reset(raw_descriptor);
   }
 };
@@ -712,19 +645,13 @@ void gemm_and_bias(
   cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
   cudaDataType_t scaleType = CUDA_R_32F;
   if constexpr (std::is_same_v<Dtype, double>) {
-#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
     abcType = CUDA_R_64F;
     computeType = CUBLAS_COMPUTE_64F;
     scaleType = CUDA_R_64F;
-#else
-    TORCH_CHECK(false, "gemm_and_bias is only supported for double type on ROCm 6.0 and above");
-#endif
   } else if constexpr (std::is_same_v<Dtype, float>) {
-#ifndef USE_ROCM
     if (at::globalContext().allowTF32CuBLAS()) {
       computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
     }
-#endif
     abcType = CUDA_R_32F;
   } else if constexpr (std::is_same_v<Dtype, at::Half>) {
     abcType = CUDA_R_16F;
@@ -741,7 +668,7 @@ void gemm_and_bias(
   if (activation == GEMMAndBiasActivationEpilogue::RELU) {
     epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
   } else if (activation == GEMMAndBiasActivationEpilogue::GELU) {
-#if CUDA_VERSION >= 11040 || defined(USE_ROCM)
+#if CUDA_VERSION >= 11040
     epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
 #endif
   }
@@ -758,7 +685,6 @@ void gemm_and_bias(
   size_t workspaceSize = _getWorkspaceSize();
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
 
-#ifndef USE_ROCM
   uint32_t a_alignment = _getAlignment(reinterpret_cast<uintptr_t>(mat1_ptr));
   uint32_t b_alignment = _getAlignment(reinterpret_cast<uintptr_t>(mat2_ptr));
   uint32_t c_alignment = _getAlignment(reinterpret_cast<uintptr_t>(result_ptr));
@@ -767,14 +693,14 @@ void gemm_and_bias(
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, b_alignment);
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, d_alignment);
-#endif
 
   auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
   auto workspace = allocator.allocate(workspaceSize);
 
   cublasLtMatmulHeuristicResult_t heuristicResult = {};
   int returnedResult = 0;
-  cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
+  cublasLtHandle_t ltHandle =
+      reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
   TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
       ltHandle,
       computeDesc.descriptor(),
@@ -950,7 +876,8 @@ void scaled_gemm(
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
   cublasLtMatmulHeuristicResult_t heuristicResult = {};
   int returnedResult = 0;
-  cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
+  cublasLtHandle_t ltHandle =
+      reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
   TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
       ltHandle,
       computeDesc.descriptor(),
@@ -1025,7 +952,6 @@ void int8_gemm(
     int64_t mat2_ld,
     int32_t* result_ptr,
     int64_t result_ld) {
-#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
 
   cublasComputeType_t computeType = CUBLAS_COMPUTE_32I;
   cudaDataType_t scaleType = CUDA_R_32I;
@@ -1044,7 +970,8 @@ void int8_gemm(
   CuBlasLtMatrixLayout Bdesc(abType, k, n, mat2_ld, transpose_mat2);
   CuBlasLtMatrixLayout Cdesc(cType, m, n, result_ld);
 
-  cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
+  cublasLtHandle_t ltHandle =
+      reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
 
   // cublas team: alpha and beta need to be the same dtype as of scaleType
   at::opmath_type<int32_t> alpha_val = 1;
@@ -1095,14 +1022,11 @@ void int8_gemm(
       computeType,
       " scaleType ",
       scaleType);
-#else
-  TORCH_CHECK(false, "int8_gemm is only supported for ROCm 6.0 and above");
-#endif // !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
 }
-#endif // (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#endif // !defined(USE_ROCM) && !defined(_MSC_VER)
 
 // ROCm 5.6 hipblas matches the const Dtype *A API, but prior hipblas does not.
-#if defined(USE_ROCM) && ROCM_VERSION <= 50600
+#if defined(USE_ROCM) && ROCM_VERSION <= 56000
 #define ROCM_CONST_BUG
 #else
 #define ROCM_CONST_BUG const
@@ -62,7 +62,7 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
 template <>
 void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
 
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if !defined(USE_ROCM) && !defined(_MSC_VER)
 enum GEMMAndBiasActivationEpilogue {
   None,
   RELU,
@@ -149,7 +149,7 @@ void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
 template <>
 void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
 
-#if defined(USE_ROCM) && ROCM_VERSION <= 50500
+#if defined(USE_ROCM) && ROCM_VERSION <= 55000
 // ROCm 5.6 hipblas matches the const Dtype *A API, but prior hipblas does not.
 #define CUDABLAS_TRSM_ARGTYPES(Dtype) \
   hipblasHandle_t handle, hipblasSideMode_t side, hipblasFillMode_t uplo, \
@@ -7,12 +7,6 @@
 #include <cusparse.h>
 #include <cublas_v2.h>
 
-// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
-// added bf16 support
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
-#include <cublasLt.h>
-#endif
-
 #ifdef CUDART_VERSION
 #include <cusolverDn.h>
 #endif
@@ -82,9 +76,6 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
 /* Handles */
 TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
 TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
-TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
-#endif
 
 TORCH_CUDA_CPP_API void clearCublasWorkspaces();
 
@@ -9,46 +9,10 @@
 #include <string>
 #include <tuple>
 
-/**
- * Note [hipblaslt handles]
- * ~~~~~~~~~~~~~~~~~~~~~~~~
- * The cublas documentation states:
- * cuBLAS handle (cublasHandle_t) encapsulates a cuBLASLt handle.
- * Any valid cublasHandle_t can be used in place of cublasLtHandle_t with a simple cast.
- *
- * hipblaslt does not behave in this way.
- * A hipblas handle does not encapsulate a hipblaslt handle.
- *
- * To work around this difference in behavior, a separate handle pool is available for ROCm builds.
- * For CUDA builds, getCurrentCUDABlasLtHandle will alias for getCurrentCUDABlasHandle,
- * whereas for ROCm builds, it is a distinct function.
- */
-
 namespace at::cuda {
 
 namespace {
 
-#ifdef USE_ROCM
-void createCublasLtHandle(cublasLtHandle_t *handle) {
-  TORCH_CUDABLAS_CHECK(cublasLtCreate(handle));
-}
-
-void destroyCublasLtHandle(cublasLtHandle_t handle) {
-// this is because of something dumb in the ordering of
-// destruction. Sometimes atexit, the cuda context (or something)
-// would already be destroyed by the time this gets destroyed. It
-// happens in fbcode setting. @colesbury and @soumith decided to not destroy
-// the handle as a workaround.
-// - Comments of @soumith copied from cuDNN handle pool implementation
-#ifdef NO_CUDNN_DESTROY_HANDLE
-#else
-    cublasLtDestroy(handle);
-#endif
-}
-
-using CuBlasLtPoolType = DeviceThreadHandlePool<cublasLtHandle_t, createCublasLtHandle, destroyCublasLtHandle>;
-#endif
-
 std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
   static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
   return instance;
@@ -177,33 +141,4 @@ cublasHandle_t getCurrentCUDABlasHandle() {
   return handle;
 }
 
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
-cublasLtHandle_t getCurrentCUDABlasLtHandle() {
-#ifdef USE_ROCM
-  int device;
-  AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
-
-  // Thread local PoolWindows are lazily-initialized
-  // to avoid initialization issues that caused hangs on Windows.
-  // See: https://github.com/pytorch/pytorch/pull/22405
-  // This thread local unique_ptrs will be destroyed when the thread terminates,
-  // releasing its reserved handles back to the pool.
-
-  // Use a leaky singleton for the pool following standard practice around
-  // singletons: https://isocpp.org/wiki/faq/ctors#construct-on-first-use-v2
-  static auto pool = std::shared_ptr<CuBlasLtPoolType>(
-      new CuBlasLtPoolType(), [](CuBlasLtPoolType* p) {
-        // Leak the memory.
-      });
-  thread_local std::unique_ptr<CuBlasLtPoolType::PoolWindow> myPoolWindow(
-      pool->newPoolWindow());
-
-  auto handle = myPoolWindow->reserve(device);
-  return handle;
-#else
-  return reinterpret_cast<cublasLtHandle_t>(getCurrentCUDABlasHandle());
-#endif
-}
-#endif
-
 } // namespace at::cuda
@@ -153,7 +153,7 @@ enum class Activation {
   GELU,
 };
 
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if !defined(USE_ROCM) && !defined(_MSC_VER)
 cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) {
   switch (a) {
     case Activation::None:
@@ -171,40 +171,12 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa
 
 static bool getDisableAddmmCudaLt() {
     static const char* env_value = std::getenv("DISABLE_ADDMM_CUDA_LT");
-#ifdef USE_ROCM
-    // allow both CUDA and HIP env var names for ROCm builds
-    // also, current default for ROCm builds is disable by default
-    if (env_value == nullptr) {
-        env_value = std::getenv("DISABLE_ADDMM_HIP_LT");
-    }
-    if (env_value != nullptr && strcmp(env_value, "0") == 0) {
-      return false;
-    }
-    return true;
-#else
     if (env_value != nullptr && strcmp(env_value, "1") == 0) {
       return true;
     }
     return false;
-#endif
 }
 
-#ifdef USE_ROCM
-static bool isSupportedHipLtROCmArch(int index) {
-    hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
-    std::string device_arch = prop->gcnArchName;
-    static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
-    for (std::string arch : archs) {
-        size_t substring = device_arch.find(arch);
-        if (substring != std::string::npos) {
-            return true;
-        }
-    }
-    TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!");
-    return false;
-}
-#endif
-
 Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None) {
   // Make sure to keep addmm_cuda below in sync with this code; it
   // preflights a check to try to avoid actually needing to call
@@ -226,7 +198,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
   at::ScalarType scalar_type = self.scalar_type();
   c10::MaybeOwned<Tensor> self_;
   if (&result != &self) {
-#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)
     // Strangely, if mat2 has only 1 row or column, we get
     // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
     // self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
@@ -239,17 +211,10 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
     useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 &&
         result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
         self.is_contiguous() && result.is_contiguous() &&
-#ifdef USE_ROCM
-        isSupportedHipLtROCmArch(self.device().index()) &&
-        (scalar_type == at::ScalarType::Float ||
-         scalar_type == at::ScalarType::Half ||
-         scalar_type == at::ScalarType::BFloat16) &&
-#else
         (scalar_type == at::ScalarType::Double ||
          scalar_type == at::ScalarType::Float ||
         scalar_type == at::ScalarType::Half ||
         scalar_type == at::ScalarType::BFloat16) &&
-#endif
         mat2_sizes[0] > 1 && mat2_sizes[1] > 1 &&
         mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
         mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
@@ -269,14 +234,6 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
     }
     self__sizes = self_->sizes();
   } else {
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
-    useLtInterface = !disable_addmm_cuda_lt &&
-        result.dim() == 2 && result.is_contiguous() &&
-        isSupportedHipLtROCmArch(self.device().index()) &&
-        (scalar_type == at::ScalarType::Float ||
-          scalar_type == at::ScalarType::Half ||
-          scalar_type == at::ScalarType::BFloat16);
-#endif
     self_ = c10::MaybeOwned<Tensor>::borrowed(self);
     self__sizes = self_->sizes();
     TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
@@ -320,7 +277,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
 
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
 
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if !defined(USE_ROCM) && !defined(_MSC_VER)
   if (useLtInterface) {
     AT_DISPATCH_FLOATING_TYPES_AND2(
         at::ScalarType::Half,
@@ -342,7 +299,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
               self.const_data_ptr<scalar_t>(),
              args.result->data_ptr<scalar_t>(),
              args.result_ld,
-#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11080) || defined(USE_ROCM)
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
              activation_to_gemm_and_blas_arg(activation)
 #else
              // GELU is not supported (and does not compile!) prior
@@ -400,7 +357,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
   // gating activation_to_gemm_and_blas_arg above; here we are manually
   // performing a post-GELU because we weren't able to use the GELU
   // epilogue above.
-#if !(defined(CUDA_VERSION) && CUDA_VERSION >= 11080) && !defined(USE_ROCM)
+#if !defined(CUDA_VERSION) || CUDA_VERSION < 11080
   if (useLtInterface && activation == Activation::GELU) {
     at::gelu_(const_cast<Tensor&>(*args.result), "tanh");
   }
@@ -1257,15 +1257,6 @@ if(USE_ROCM)
     list(APPEND HIP_CXX_FLAGS -DCAFFE2_USE_MIOPEN)
     list(APPEND HIP_CXX_FLAGS -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP)
     list(APPEND HIP_CXX_FLAGS -std=c++17)
-    if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "6.0.0")
-      list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2)
-    endif()
-    if(HIPBLASLT_CUSTOM_DATA_TYPE)
-      list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_CUSTOM_DATA_TYPE)
-    endif()
-    if(HIPBLASLT_CUSTOM_COMPUTE_TYPE)
-      list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_CUSTOM_COMPUTE_TYPE)
-    endif()
     add_definitions(-DROCM_VERSION=${ROCM_VERSION_DEV_INT})
     add_definitions(-DTORCH_HIP_VERSION=${TORCH_HIP_VERSION})
     message("TORCH_HIP_VERSION=${TORCH_HIP_VERSION} is added as a compiler defines")
@@ -1291,9 +1282,6 @@ if(USE_ROCM)
 
     set(Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
       ${PYTORCH_HIP_LIBRARIES} ${PYTORCH_MIOPEN_LIBRARIES} ${hipcub_LIBRARIES} ${ROCM_HIPRTC_LIB} ${ROCM_ROCTX_LIB})
-    if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
-      list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS ${hipblaslt_LIBRARIES})
-    endif()
 
     list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
       roc::hipblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsolver)
@@ -136,7 +136,6 @@ if(HIP_FOUND)
   set(hiprand_DIR ${ROCM_PATH}/lib/cmake/hiprand)
   set(rocblas_DIR ${ROCM_PATH}/lib/cmake/rocblas)
   set(hipblas_DIR ${ROCM_PATH}/lib/cmake/hipblas)
-  set(hipblaslt_DIR ${ROCM_PATH}/lib/cmake/hipblaslt)
   set(miopen_DIR ${ROCM_PATH}/lib/cmake/miopen)
   set(rocfft_DIR ${ROCM_PATH}/lib/cmake/rocfft)
   set(hipfft_DIR ${ROCM_PATH}/lib/cmake/hipfft)
@@ -155,9 +154,6 @@ if(HIP_FOUND)
   find_package_and_print_version(hiprand REQUIRED)
   find_package_and_print_version(rocblas REQUIRED)
   find_package_and_print_version(hipblas REQUIRED)
-  if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
-    find_package_and_print_version(hipblaslt REQUIRED)
-  endif()
   find_package_and_print_version(miopen REQUIRED)
   if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "4.1.0")
     find_package_and_print_version(hipfft REQUIRED)
@@ -191,57 +187,4 @@ if(HIP_FOUND)
   find_library(ROCM_HIPRTC_LIB amdhip64 HINTS ${ROCM_PATH}/lib)
   # roctx is part of roctracer
   find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)
-
-  if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
-    # check whether hipblaslt is using its own datatype
-    set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_data_type.cc")
-    file(WRITE ${file} ""
-      "#include <hipblaslt/hipblaslt.h>\n"
-      "int main() {\n"
-      " hipblasltDatatype_t bar = HIPBLASLT_R_16F;\n"
-      " return 0;\n"
-      "}\n"
-      )
-
-    try_compile(hipblaslt_compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file}
-      CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
-      COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
-      OUTPUT_VARIABLE hipblaslt_compile_output)
-
-    if(hipblaslt_compile_result)
-      set(HIPBLASLT_CUSTOM_DATA_TYPE ON)
-      #message("hipblaslt is using custom data type: ${hipblaslt_compile_output}")
-      message("hipblaslt is using custom data type")
-    else()
-      set(HIPBLASLT_CUSTOM_DATA_TYPE OFF)
-      #message("hipblaslt is NOT using custom data type: ${hipblaslt_compile_output}")
-      message("hipblaslt is NOT using custom data type")
-    endif()
-
-    # check whether hipblaslt is using its own compute type
-    set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_compute_type.cc")
-    file(WRITE ${file} ""
-      "#include <hipblaslt/hipblaslt.h>\n"
-      "int main() {\n"
-      " hipblasLtComputeType_t baz = HIPBLASLT_COMPUTE_F32;\n"
-      " return 0;\n"
-      "}\n"
-      )
-
-    try_compile(hipblaslt_compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file}
-      CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
-      COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
-      OUTPUT_VARIABLE hipblaslt_compile_output)
-
-    if(hipblaslt_compile_result)
-      set(HIPBLASLT_CUSTOM_COMPUTE_TYPE ON)
-      #message("hipblaslt is using custom compute type: ${hipblaslt_compile_output}")
-      message("hipblaslt is using custom compute type")
-    else()
-      set(HIPBLASLT_CUSTOM_COMPUTE_TYPE OFF)
-      #message("hipblaslt is NOT using custom compute type: ${hipblaslt_compile_output}")
-      message("hipblaslt is NOT using custom compute type")
-    endif()
-  endif()
-
 endif()
@@ -237,9 +237,6 @@ COMMON_HIP_FLAGS = [
     '-DUSE_ROCM=1',
 ]
 
-if ROCM_VERSION is not None and ROCM_VERSION >= (6, 0):
-    COMMON_HIP_FLAGS.append('-DHIPBLAS_V2')
-
 COMMON_HIPCC_FLAGS = [
     '-DCUDA_HAS_FP16=1',
     '-D__HIP_NO_HALF_OPERATORS__=1',
@@ -611,7 +611,6 @@ CUDA_INCLUDE_MAP = collections.OrderedDict(
         ("vector_types.h", ("hip/hip_vector_types.h", CONV_INCLUDE, API_RUNTIME)),
         ("cublas.h", ("hipblas/hipblas.h", CONV_INCLUDE_CUDA_MAIN_H, API_BLAS)),
         ("cublas_v2.h", ("hipblas/hipblas.h", CONV_INCLUDE_CUDA_MAIN_H, API_BLAS)),
-        ("cublasLt.h", ("hipblaslt/hipblaslt.h", CONV_INCLUDE_CUDA_MAIN_H, API_BLAS)),
         ("curand.h", ("hiprand/hiprand.h", CONV_INCLUDE_CUDA_MAIN_H, API_RAND)),
         ("curand_kernel.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
         ("curand_discrete.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
@@ -3852,7 +3851,7 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
                 HIP_UNSUPPORTED,
             ),
         ),
-        ("cudaDataType_t", ("hipDataType", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
+        ("cudaDataType_t", ("hipDataType_t", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
         ("cudaDataType", ("hipDataType", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
         ("CUDA_R_16BF", ("HIP_R_16BF", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
         ("CUDA_C_16BF", ("HIP_C_16BF", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)),
@@ -7272,65 +7271,6 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
             "cublasDrotmg_v2",
             ("hipblasDrotmg_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED),
         ),
-        (
-            "cublasComputeType_t",
-            ("hipblasComputeType_t" if rocm_version >= (6, 0, 0) else "hipblasLtComputeType_t",
-            CONV_MATH_FUNC, API_BLAS)
-        ),
-        (
-            "CUBLAS_COMPUTE_32I",
-            ("HIPBLAS_COMPUTE_32I" if rocm_version >= (6, 0, 0) else "HIPBLASLT_COMPUTE_I32", CONV_MATH_FUNC, API_BLAS)
-        ),
-        (
-            "CUBLAS_COMPUTE_32F",
-            ("HIPBLAS_COMPUTE_32F" if rocm_version >= (6, 0, 0) else "HIPBLASLT_COMPUTE_F32", CONV_MATH_FUNC, API_BLAS)
-        ),
-        (
-            "CUBLAS_COMPUTE_64F",
-            ("HIPBLAS_COMPUTE_64F" if rocm_version >= (6, 0, 0) else "HIPBLASLT_COMPUTE_F64", CONV_MATH_FUNC, API_BLAS)
-        ),
-        ("cublasLtEpilogue_t", ("hipblasLtEpilogue_t", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_EPILOGUE_DEFAULT", ("HIPBLASLT_EPILOGUE_DEFAULT", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_EPILOGUE_RELU", ("HIPBLASLT_EPILOGUE_RELU", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_EPILOGUE_BIAS", ("HIPBLASLT_EPILOGUE_BIAS", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_EPILOGUE_RELU_BIAS", ("HIPBLASLT_EPILOGUE_RELU_BIAS", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_EPILOGUE_GELU", ("HIPBLASLT_EPILOGUE_GELU", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_EPILOGUE_GELU_BIAS", ("HIPBLASLT_EPILOGUE_GELU_BIAS", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtHandle_t", ("hipblasLtHandle_t", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulDesc_t", ("hipblasLtMatmulDesc_t", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulDescOpaque_t", ("hipblasLtMatmulDescOpaque_t", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulDescAttributes_t", ("hipblasLtMatmulDescAttributes_t", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_MATMUL_DESC_TRANSA", ("HIPBLASLT_MATMUL_DESC_TRANSA", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_MATMUL_DESC_TRANSB", ("HIPBLASLT_MATMUL_DESC_TRANSB", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_MATMUL_DESC_EPILOGUE", ("HIPBLASLT_MATMUL_DESC_EPILOGUE", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_MATMUL_DESC_BIAS_POINTER", ("HIPBLASLT_MATMUL_DESC_BIAS_POINTER", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_MATMUL_DESC_A_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_MATMUL_DESC_B_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_MATMUL_DESC_D_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_MATMUL_DESC_AMAX_D_POINTER", ("HIPBLASLT_MATMUL_DESC_AMAX_D_POINTER", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", ("HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatrixLayout_t", ("hipblasLtMatrixLayout_t", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatrixLayoutOpaque_t", ("hipblasLtMatrixLayoutOpaque_t", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatrixLayoutAttribute_t", ("hipblasLtMatrixLayoutAttribute_t", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulPreference_t", ("hipblasLtMatmulPreference_t", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulPreferenceOpaque_t", ("hipblasLtMatmulPreferenceOpaque_t", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulPreferenceAttributes_t", ("hipblasLtMatmulPreferenceAttributes_t", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_MATMUL_PREF_SEARCH_MODE", ("HIPBLASLT_MATMUL_PREF_SEARCH_MODE", CONV_MATH_FUNC, API_BLAS)),
-        ("CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES", ("HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulAlgo_t", ("hipblasLtMatmulAlgo_t", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulHeuristicResult_t", ("hipblasLtMatmulHeuristicResult_t", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatrixLayoutCreate", ("hipblasLtMatrixLayoutCreate", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatrixLayoutDestroy", ("hipblasLtMatrixLayoutDestroy", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtCreate", ("hipblasLtCreate", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtDestroy", ("hipblasLtDestroy", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulDescCreate", ("hipblasLtMatmulDescCreate", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulDescDestroy", ("hipblasLtMatmulDescDestroy", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulDescSetAttribute", ("hipblasLtMatmulDescSetAttribute", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulPreferenceCreate", ("hipblasLtMatmulPreferenceCreate", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulPreferenceDestroy", ("hipblasLtMatmulPreferenceDestroy", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulPreferenceSetAttribute", ("hipblasLtMatmulPreferenceSetAttribute", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmulAlgoGetHeuristic", ("hipblasLtMatmulAlgoGetHeuristic", CONV_MATH_FUNC, API_BLAS)),
-        ("cublasLtMatmul", ("hipblasLtMatmul", CONV_MATH_FUNC, API_BLAS)),
         (
             "CURAND_STATUS_SUCCESS",
             ("HIPRAND_STATUS_SUCCESS", CONV_NUMERIC_LITERAL, API_RAND),
@@ -7737,14 +7677,8 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
                 HIP_UNSUPPORTED,
             ),
         ),
-        (
-            "cuComplex",
-            ("hipComplex" if rocm_version >= (6, 0, 0) else "hipblasComplex", CONV_TYPE, API_BLAS)
-        ),
-        (
-            "cuDoubleComplex",
-            ("hipDoubleComplex" if rocm_version >= (6, 0, 0) else "hipblasDoubleComplex", CONV_TYPE, API_BLAS),
-        ),
+        ("cuComplex", ("hipblasComplex", CONV_TYPE, API_BLAS)),
+        ("cuDoubleComplex", ("hipblasDoubleComplex", CONV_TYPE, API_BLAS)),
         ("cufftResult_t", ("hipfftResult_t", CONV_TYPE, API_FFT)),
         ("cufftResult", ("hipfftResult", CONV_TYPE, API_FFT)),
         ("CUFFT_SUCCESS", ("HIPFFT_SUCCESS", CONV_NUMERIC_LITERAL, API_FFT)),