Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)

[ROCm] enforce ROCM_VERSION >= 6.0 (#125646)

Remove any code relying on ROCM_VERSION < 6.0.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125646
Approved by: https://github.com/albanD, https://github.com/eqy

Committed by: PyTorch MergeBot
Parent: 0116ffae7f
Commit: ae9a4fa63c
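Nearly every hunk below is the same mechanical simplification: with ROCm 6.0 enforced as the build-time minimum, any guard of the form ROCM_VERSION >= X for X at or below 60000 is always true whenever USE_ROCM is defined, so the version clause can be dropped. A minimal sketch of the pattern, assuming a hypothetical function name for illustration:

    // Before: hipBLASLt paths were additionally gated on ROCm >= 5.7.
    #if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
    void use_hipblaslt();  // hypothetical declaration, not from the diff
    #endif

    // After: the version clause is vacuously true on every supported ROCm,
    // so only the platform check remains...
    #if !defined(USE_ROCM) || defined(USE_ROCM)
    void use_hipblaslt();
    #endif
    // ...and a guard that is always true is usually deleted outright,
    // which is why so many #if/#else/#endif triples disappear below.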
@@ -6,9 +6,6 @@ ver() {
   printf "%3d%03d%03d%03d" $(echo "$1" | tr '.' ' ');
 }
 
-# Map ROCm version to AMDGPU version
-declare -A AMDGPU_VERSIONS=( ["5.0"]="21.50" ["5.1.1"]="22.10.1" ["5.2"]="22.20" )
-
 install_ubuntu() {
   apt-get update
   if [[ $UBUNTU_VERSION == 18.04 ]]; then
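The ver helper kept as context above packs a dotted version into a fixed-width number so that plain shell integer comparison orders versions correctly. A quick sketch of what it yields (expected outputs shown as comments, not captured from a real run):

    ver() {
      printf "%3d%03d%03d%03d" $(echo "$1" | tr '.' ' ');
    }

    ver 5.3    # -> "  5003000000"
    ver 5.11   # -> "  5011000000"  (greater than 5.3 numerically, unlike a string compare)
    ver 6.0    # -> "  6000000000"

    # This is what makes guards such as the following work:
    [[ $(ver $ROCM_VERSION) -ge $(ver 6.0) ]] && echo "ROCm 6.0 or newer"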
@@ -26,31 +23,14 @@ install_ubuntu() {
   apt-get install -y libc++1
   apt-get install -y libc++abi1
 
-  if [[ $(ver $ROCM_VERSION) -ge $(ver 4.5) ]]; then
-    # Add amdgpu repository
-    UBUNTU_VERSION_NAME=`cat /etc/os-release | grep UBUNTU_CODENAME | awk -F= '{print $2}'`
-    local amdgpu_baseurl
-    if [[ $(ver $ROCM_VERSION) -ge $(ver 5.3) ]]; then
-      amdgpu_baseurl="https://repo.radeon.com/amdgpu/${ROCM_VERSION}/ubuntu"
-    else
-      amdgpu_baseurl="https://repo.radeon.com/amdgpu/${AMDGPU_VERSIONS[$ROCM_VERSION]}/ubuntu"
-    fi
-    echo "deb [arch=amd64] ${amdgpu_baseurl} ${UBUNTU_VERSION_NAME} main" > /etc/apt/sources.list.d/amdgpu.list
-  fi
-
-  ROCM_REPO="ubuntu"
-  if [[ $(ver $ROCM_VERSION) -lt $(ver 4.2) ]]; then
-    ROCM_REPO="xenial"
-  fi
-
-  if [[ $(ver $ROCM_VERSION) -ge $(ver 5.3) ]]; then
-    ROCM_REPO="${UBUNTU_VERSION_NAME}"
-  fi
+  # Add amdgpu repository
+  UBUNTU_VERSION_NAME=`cat /etc/os-release | grep UBUNTU_CODENAME | awk -F= '{print $2}'`
+  echo "deb [arch=amd64] https://repo.radeon.com/amdgpu/${ROCM_VERSION}/ubuntu ${UBUNTU_VERSION_NAME} main" > /etc/apt/sources.list.d/amdgpu.list
 
   # Add rocm repository
   wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -
   local rocm_baseurl="http://repo.radeon.com/rocm/apt/${ROCM_VERSION}"
-  echo "deb [arch=amd64] ${rocm_baseurl} ${ROCM_REPO} main" > /etc/apt/sources.list.d/rocm.list
+  echo "deb [arch=amd64] ${rocm_baseurl} ${UBUNTU_VERSION_NAME} main" > /etc/apt/sources.list.d/rocm.list
   apt-get update --allow-insecure-repositories
 
   DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
@@ -68,29 +48,18 @@ install_ubuntu() {
   # precompiled miopen kernels added in ROCm 3.5, renamed in ROCm 5.5
   # search for all unversioned packages
   # if search fails it will abort this script; use true to avoid case where search fails
-  if [[ $(ver $ROCM_VERSION) -ge $(ver 5.5) ]]; then
-    MIOPENHIPGFX=$(apt-cache search --names-only miopen-hip-gfx | awk '{print $1}' | grep -F -v . || true)
-    if [[ "x${MIOPENHIPGFX}" = x ]]; then
-      echo "miopen-hip-gfx package not available" && exit 1
-    else
-      DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ${MIOPENHIPGFX}
-    fi
-  else
-    MIOPENKERNELS=$(apt-cache search --names-only miopenkernels | awk '{print $1}' | grep -F -v . || true)
-    if [[ "x${MIOPENKERNELS}" = x ]]; then
-      echo "miopenkernels package not available" && exit 1
-    else
-      DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ${MIOPENKERNELS}
-    fi
+  MIOPENHIPGFX=$(apt-cache search --names-only miopen-hip-gfx | awk '{print $1}' | grep -F -v . || true)
+  if [[ "x${MIOPENHIPGFX}" = x ]]; then
+    echo "miopen-hip-gfx package not available" && exit 1
+  else
+    DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ${MIOPENHIPGFX}
   fi
 
   # ROCm 6.0 had a regression where journal_mode was enabled on the kdb files resulting in permission errors at runtime
-  if [[ $(ver $ROCM_VERSION) -ge $(ver 6.0) ]]; then
-    for kdb in /opt/rocm/share/miopen/db/*.kdb
-    do
-      sqlite3 $kdb "PRAGMA journal_mode=off; PRAGMA VACUUM;"
-    done
-  fi
+  for kdb in /opt/rocm/share/miopen/db/*.kdb
+  do
+    sqlite3 $kdb "PRAGMA journal_mode=off; PRAGMA VACUUM;"
+  done
 
   # Cleanup
   apt-get autoclean && apt-get clean
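A note on the kdb cleanup retained above: MIOpen's precompiled kernel databases are SQLite files, and the hunk's comment says ROCm 6.0 shipped them with journaling enabled, which makes SQLite try to create sidecar journal files next to the read-only databases under /opt/rocm and fail at runtime. A hedged way to inspect one database by hand (the gfx90a filename is a made-up example; the two pragmas are exactly what the loop runs):

    # Inspect the current journal mode of one MIOpen kernel database
    # (filename is hypothetical).
    sqlite3 /opt/rocm/share/miopen/db/gfx90a.kdb "PRAGMA journal_mode;"

    # Turn journaling off and compact, mirroring the script's loop body.
    sqlite3 /opt/rocm/share/miopen/db/gfx90a.kdb "PRAGMA journal_mode=off; PRAGMA VACUUM;"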
@@ -107,25 +76,19 @@ install_centos() {
   yum install -y epel-release
   yum install -y dkms kernel-headers-`uname -r` kernel-devel-`uname -r`
 
-  if [[ $(ver $ROCM_VERSION) -ge $(ver 4.5) ]]; then
-    # Add amdgpu repository
-    local amdgpu_baseurl
-    if [[ $OS_VERSION == 9 ]]; then
-      amdgpu_baseurl="https://repo.radeon.com/amdgpu/${AMDGPU_VERSIONS[$ROCM_VERSION]}/rhel/9.0/main/x86_64"
-    else
-      if [[ $(ver $ROCM_VERSION) -ge $(ver 5.3) ]]; then
-        amdgpu_baseurl="https://repo.radeon.com/amdgpu/${ROCM_VERSION}/rhel/7.9/main/x86_64"
-      else
-        amdgpu_baseurl="https://repo.radeon.com/amdgpu/${AMDGPU_VERSIONS[$ROCM_VERSION]}/rhel/7.9/main/x86_64"
-      fi
-    fi
-    echo "[AMDGPU]" > /etc/yum.repos.d/amdgpu.repo
-    echo "name=AMDGPU" >> /etc/yum.repos.d/amdgpu.repo
-    echo "baseurl=${amdgpu_baseurl}" >> /etc/yum.repos.d/amdgpu.repo
-    echo "enabled=1" >> /etc/yum.repos.d/amdgpu.repo
-    echo "gpgcheck=1" >> /etc/yum.repos.d/amdgpu.repo
-    echo "gpgkey=http://repo.radeon.com/rocm/rocm.gpg.key" >> /etc/yum.repos.d/amdgpu.repo
-  fi
+  # Add amdgpu repository
+  local amdgpu_baseurl
+  if [[ $OS_VERSION == 9 ]]; then
+    amdgpu_baseurl="https://repo.radeon.com/amdgpu/${ROCM_VERSION}/rhel/9.0/main/x86_64"
+  else
+    amdgpu_baseurl="https://repo.radeon.com/amdgpu/${ROCM_VERSION}/rhel/7.9/main/x86_64"
+  fi
+  echo "[AMDGPU]" > /etc/yum.repos.d/amdgpu.repo
+  echo "name=AMDGPU" >> /etc/yum.repos.d/amdgpu.repo
+  echo "baseurl=${amdgpu_baseurl}" >> /etc/yum.repos.d/amdgpu.repo
+  echo "enabled=1" >> /etc/yum.repos.d/amdgpu.repo
+  echo "gpgcheck=1" >> /etc/yum.repos.d/amdgpu.repo
+  echo "gpgkey=http://repo.radeon.com/rocm/rocm.gpg.key" >> /etc/yum.repos.d/amdgpu.repo
 
   local rocm_baseurl="http://repo.radeon.com/rocm/yum/${ROCM_VERSION}"
   echo "[ROCm]" > /etc/yum.repos.d/rocm.repo
@@ -147,29 +110,18 @@ install_centos() {
 
   # precompiled miopen kernels; search for all unversioned packages
   # if search fails it will abort this script; use true to avoid case where search fails
-  if [[ $(ver $ROCM_VERSION) -ge $(ver 5.5) ]]; then
-    MIOPENHIPGFX=$(yum -q search miopen-hip-gfx | grep miopen-hip-gfx | awk '{print $1}'| grep -F kdb. || true)
-    if [[ "x${MIOPENHIPGFX}" = x ]]; then
-      echo "miopen-hip-gfx package not available" && exit 1
-    else
-      yum install -y ${MIOPENHIPGFX}
-    fi
-  else
-    MIOPENKERNELS=$(yum -q search miopenkernels | grep miopenkernels- | awk '{print $1}'| grep -F kdb. || true)
-    if [[ "x${MIOPENKERNELS}" = x ]]; then
-      echo "miopenkernels package not available" && exit 1
-    else
-      yum install -y ${MIOPENKERNELS}
-    fi
+  MIOPENHIPGFX=$(yum -q search miopen-hip-gfx | grep miopen-hip-gfx | awk '{print $1}'| grep -F kdb. || true)
+  if [[ "x${MIOPENHIPGFX}" = x ]]; then
+    echo "miopen-hip-gfx package not available" && exit 1
+  else
+    yum install -y ${MIOPENHIPGFX}
   fi
 
   # ROCm 6.0 had a regression where journal_mode was enabled on the kdb files resulting in permission errors at runtime
-  if [[ $(ver $ROCM_VERSION) -ge $(ver 6.0) ]]; then
-    for kdb in /opt/rocm/share/miopen/db/*.kdb
-    do
-      sqlite3 $kdb "PRAGMA journal_mode=off; PRAGMA VACUUM;"
-    done
-  fi
+  for kdb in /opt/rocm/share/miopen/db/*.kdb
+  do
+    sqlite3 $kdb "PRAGMA journal_mode=off; PRAGMA VACUUM;"
+  done
 
   # Cleanup
   yum clean all
RELEASE.md (16 lines changed)

@@ -48,14 +48,14 @@
 
 Following is the Release Compatibility Matrix for PyTorch releases:
 
-| PyTorch version | Python | Stable CUDA | Experimental CUDA |
-| --- | --- | --- | --- |
-| 2.3 | >=3.8, <=3.11, (3.12 experimental) | CUDA 11.8, CUDNN 8.7.0.84 | CUDA 12.1, CUDNN 8.9.2.26 |
-| 2.2 | >=3.8, <=3.11, (3.12 experimental) | CUDA 11.8, CUDNN 8.7.0.84 | CUDA 12.1, CUDNN 8.9.2.26 |
-| 2.1 | >=3.8, <=3.11 | CUDA 11.8, CUDNN 8.7.0.84 | CUDA 12.1, CUDNN 8.9.2.26 |
-| 2.0 | >=3.8, <=3.11 | CUDA 11.7, CUDNN 8.5.0.96 | CUDA 11.8, CUDNN 8.7.0.84 |
-| 1.13 | >=3.7, <=3.10 | CUDA 11.6, CUDNN 8.3.2.44 | CUDA 11.7, CUDNN 8.5.0.96 |
-| 1.12 | >=3.7, <=3.10 | CUDA 11.3, CUDNN 8.3.2.44 | CUDA 11.6, CUDNN 8.3.2.44 |
+| PyTorch version | Python | Stable CUDA | Experimental CUDA | Stable ROCm |
+| --- | --- | --- | --- | --- |
+| 2.3 | >=3.8, <=3.11, (3.12 experimental) | CUDA 11.8, CUDNN 8.7.0.84 | CUDA 12.1, CUDNN 8.9.2.26 | ROCm 6.0 |
+| 2.2 | >=3.8, <=3.11, (3.12 experimental) | CUDA 11.8, CUDNN 8.7.0.84 | CUDA 12.1, CUDNN 8.9.2.26 | ROCm 5.7 |
+| 2.1 | >=3.8, <=3.11 | CUDA 11.8, CUDNN 8.7.0.84 | CUDA 12.1, CUDNN 8.9.2.26 | ROCm 5.6 |
+| 2.0 | >=3.8, <=3.11 | CUDA 11.7, CUDNN 8.5.0.96 | CUDA 11.8, CUDNN 8.7.0.84 | ROCm 5.4 |
+| 1.13 | >=3.7, <=3.10 | CUDA 11.6, CUDNN 8.3.2.44 | CUDA 11.7, CUDNN 8.5.0.96 | ROCm 5.2 |
+| 1.12 | >=3.7, <=3.10 | CUDA 11.3, CUDNN 8.3.2.44 | CUDA 11.6, CUDNN 8.3.2.44 | ROCm 5.0 |
 
 ## Release Cadence
@@ -14,9 +14,7 @@
 #include <c10/util/irange.h>
 
 #ifdef USE_ROCM
-#if ROCM_VERSION >= 60000
 #include <hipblaslt/hipblaslt-ext.hpp>
-#endif
 // until hipblas has an API to accept flags, we must use rocblas here
 #include <hipblas/hipblas.h>
 #include <rocblas/rocblas.h>

@@ -236,57 +234,6 @@ namespace at::cuda::blas {
     CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, num_batches); \
   } while (0)
 
-#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
-
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000
-// only for rocm 5.7 where we first supported hipblaslt, it was difficult
-// to hipify correctly without this change.
-#define hipDataType hipblasDatatype_t
-#endif
-
-// hipblaslt custom types were a temporary work-around
-#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && defined(HIPBLASLT_CUSTOM_DATA_TYPE)
-hipblasltDatatype_t hipToLt(hipDataType type) {
-  switch (type) {
-    case HIP_R_32F: return HIPBLASLT_R_32F;
-    case HIP_R_64F: return HIPBLASLT_R_64F;
-    case HIP_R_16F: return HIPBLASLT_R_16F;
-    case HIP_R_8I: return HIPBLASLT_R_8I;
-    case HIP_C_32F: return HIPBLASLT_C_32F;
-    case HIP_C_64F: return HIPBLASLT_C_64F;
-    case HIP_C_16F: return HIPBLASLT_C_16F;
-    case HIP_C_8I: return HIPBLASLT_C_8I;
-    case HIP_R_8U: return HIPBLASLT_R_8U;
-    case HIP_C_8U: return HIPBLASLT_C_8U;
-    case HIP_R_32I: return HIPBLASLT_R_32I;
-    case HIP_C_32I: return HIPBLASLT_C_32I;
-    case HIP_R_32U: return HIPBLASLT_R_32U;
-    case HIP_C_32U: return HIPBLASLT_C_32U;
-    case HIP_R_16BF: return HIPBLASLT_R_16B;
-    case HIP_C_16BF: return HIPBLASLT_C_16B;
-    default: TORCH_CHECK(false, "unknown hipDataType");
-  }
-}
-#define HIPTOLT(type) hipToLt(type)
-#else
-#define HIPTOLT(type) type
-#endif
-
-#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && defined(HIPBLASLT_CUSTOM_COMPUTE_TYPE)
-hipblasLtComputeType_t hipblasToLt(hipblasComputeType_t type) {
-  switch (type) {
-    case HIPBLAS_COMPUTE_32F: return HIPBLASLT_COMPUTE_F32;
-    case HIPBLAS_COMPUTE_32F_FAST_16F: return HIPBLASLT_COMPUTE_F32_FAST_F16;
-    case HIPBLAS_COMPUTE_32F_FAST_TF32: return HIPBLASLT_COMPUTE_F32_FAST_XF32;
-    case HIPBLAS_COMPUTE_64F: return HIPBLASLT_COMPUTE_F64;
-    case HIPBLAS_COMPUTE_32I: return HIPBLASLT_COMPUTE_I32;
-    default: TORCH_CHECK(false, "unknown hipblasComputeType_t");
-  }
-}
-#define HIPCOMPTOLT(type) hipblasToLt(type)
-#else
-#define HIPCOMPTOLT(type) type
-#endif
-
 namespace {
 // Following the pattern of CuSparseDescriptor

@@ -325,7 +272,7 @@ class CuBlasLtMatmulDescriptor : public CuBlasLtDescriptor<
       cudaDataType_t scale_type) {
     cublasLtMatmulDesc_t raw_descriptor = nullptr;
     TORCH_CUDABLAS_CHECK(
-        cublasLtMatmulDescCreate(&raw_descriptor, HIPCOMPTOLT(compute_type), HIPTOLT(scale_type)));
+        cublasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
     descriptor_.reset(raw_descriptor);
   }
   template <typename T>

@@ -346,7 +293,7 @@ class CuBlasLtMatrixLayout : public CuBlasLtDescriptor<
       bool t = false) {
     cublasLtMatrixLayout_t raw_descriptor = nullptr;
     TORCH_CUDABLAS_CHECK(
-        cublasLtMatrixLayoutCreate(&raw_descriptor, HIPTOLT(type), t ? cols : rows, t ? rows : cols, ld));
+        cublasLtMatrixLayoutCreate(&raw_descriptor, type, t ? cols : rows, t ? rows : cols, ld));
     descriptor_.reset(raw_descriptor);
   }
   template <typename T>

@@ -371,11 +318,9 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
 };
 } // namespace
 
-#endif
-
 template <typename Dtype>
 inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
-#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
   cudaDataType_t abcType = CUDA_R_32F;
   cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
   cudaDataType_t scaleType = CUDA_R_32F;

@@ -506,9 +451,6 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
       computeType,
       " scaleType ",
       scaleType);
-#else
-  AT_ERROR("at::cuda::blas::bgemm_internal_cublaslt: not implemented for ", typeid(Dtype).name());
-#endif
 }
 
 

@@ -632,7 +574,7 @@ void bgemm_internal_cublas<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16))
   const float fbeta = beta;
   _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
 
-#if defined(USE_ROCM) && ROCM_VERSION >= 60000
+#if defined(USE_ROCM)
   auto compute_type = CUBLAS_COMPUTE_32F;
 #else
   auto compute_type = CUDA_R_32F;

@@ -1018,7 +960,7 @@ void gemm_internal_cublas<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
     cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
   }
 #endif
-#if defined(USE_ROCM) && ROCM_VERSION >= 60000
+#if defined(USE_ROCM)
   auto compute_type = CUBLAS_COMPUTE_32F;
 #else
   auto compute_type = CUDA_R_32F;

@@ -1235,7 +1177,6 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
   }
 }
 
-#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
 
 template <typename Dtype>
 void gemm_and_bias(

@@ -1260,13 +1201,9 @@ void gemm_and_bias(
   cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
   cudaDataType_t scaleType = CUDA_R_32F;
   if constexpr (std::is_same_v<Dtype, double>) {
-#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
     abcType = CUDA_R_64F;
     computeType = CUBLAS_COMPUTE_64F;
     scaleType = CUDA_R_64F;
-#else
-    TORCH_CHECK(false, "gemm_and_bias is only supported for double type on ROCm 6.0 and above");
-#endif
   } else if constexpr (std::is_same_v<Dtype, float>) {
 #ifndef USE_ROCM
     if (at::globalContext().allowTF32CuBLAS()) {

@@ -1473,7 +1410,7 @@ void scaled_gemm(
     ScalarType result_dtype,
     void* amax_ptr,
     bool use_fast_accum) {
-#if CUDA_VERSION >= 11080 || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
+#if CUDA_VERSION >= 11080 || defined(USE_ROCM)
   const auto computeType = CUBLAS_COMPUTE_32F;
   const auto scaleType = CUDA_R_32F;
   const int8_t fastAccuMode = use_fast_accum ? 1 : 0;

@@ -1537,13 +1474,13 @@ if (isFloat8Type(result_dtype)) {
       hipblaslt_ext::GemmType::HIPBLASLT_GEMM,
       _cublasOpFromChar(transa),
       _cublasOpFromChar(transb),
-      HIPTOLT(ScalarTypeToCudaDataType(mat1_dtype)),
-      HIPTOLT(ScalarTypeToCudaDataType(mat2_dtype)),
+      ScalarTypeToCudaDataType(mat1_dtype),
+      ScalarTypeToCudaDataType(mat2_dtype),
       // C is nullptr and beta=0, so set to something reasonable. See above.
-      //HIPTOLT(ScalarTypeToCudaDataType(bias_dtype)),
-      HIPTOLT(ScalarTypeToCudaDataType(result_dtype)),
-      HIPTOLT(ScalarTypeToCudaDataType(result_dtype)),
-      HIPCOMPTOLT(CUBLAS_COMPUTE_32F),
+      //ScalarTypeToCudaDataType(bias_dtype),
+      ScalarTypeToCudaDataType(result_dtype),
+      ScalarTypeToCudaDataType(result_dtype),
+      CUBLAS_COMPUTE_32F,
       all_algos));
   if (all_algos.size() == 0) {
     TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);

@@ -1620,7 +1557,7 @@ if (isFloat8Type(result_dtype)) {
       " scaleType ",
       scaleType);
   return;
-#endif // CUDA_VERSION >= 11080 || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
+#endif // CUDA_VERSION >= 11080 || defined(USE_ROCM)
   TORCH_CHECK(false, "scaled_gemm is only supported for CUDA 11.8 and above");
 }
 

@@ -1636,7 +1573,6 @@ void int8_gemm(
     int64_t mat2_ld,
     int32_t* result_ptr,
     int64_t result_ld) {
-#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
 
   cublasComputeType_t computeType = CUBLAS_COMPUTE_32I;
   cudaDataType_t scaleType = CUDA_R_32I;

@@ -1741,18 +1677,7 @@ void int8_gemm(
       computeType,
       " scaleType ",
       scaleType);
-#else
-  TORCH_CHECK(false, "int8_gemm is only supported for ROCm 6.0 and above");
-#endif // !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
 }
-#endif // !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
-
-// ROCm 5.6 hipblas matches the const Dtype *A API, but prior hipblas does not.
-#if defined(USE_ROCM) && ROCM_VERSION < 50600
-#define ROCM_CONST_BUG_CAST(Type, Input) const_cast<Type>(reinterpret_cast<const Type>(Input))
-#else
-#define ROCM_CONST_BUG_CAST(Type, Input) reinterpret_cast<const Type>(Input)
-#endif
 
 template <>
 void trsm<float>(CUDABLAS_TRSM_ARGTYPES(float)) {

@@ -1777,7 +1702,7 @@ void trsm<c10::complex<float>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<float>)) {
       m,
       n,
       reinterpret_cast<const cuComplex*>(alpha),
-      ROCM_CONST_BUG_CAST(cuComplex*, A),
+      reinterpret_cast<const cuComplex*>(A),
       lda,
       reinterpret_cast<cuComplex*>(B),
       ldb));

@@ -1794,7 +1719,7 @@ void trsm<c10::complex<double>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<double>)) {
       m,
       n,
       reinterpret_cast<const cuDoubleComplex*>(alpha),
-      ROCM_CONST_BUG_CAST(cuDoubleComplex*, A),
+      reinterpret_cast<const cuDoubleComplex*>(A),
       lda,
       reinterpret_cast<cuDoubleComplex*>(B),
       ldb));
@@ -82,7 +82,6 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
 template <>
 void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
 
-#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
 enum GEMMAndBiasActivationEpilogue {
   None,
   RELU,

@@ -143,7 +142,6 @@ void scaled_gemm(
     ScalarType result_dtype,
     void* amax_ptr,
     bool use_fast_accum);
-#endif
 
 #define CUDABLAS_BGEMM_ARGTYPES(Dtype) \
   char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha, \

@@ -190,18 +188,10 @@ void bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
 template <>
 void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
 
-#if defined(USE_ROCM) && ROCM_VERSION <= 50500
-// ROCm 5.6 hipblas matches the const Dtype *A API, but prior hipblas does not.
-#define CUDABLAS_TRSM_ARGTYPES(Dtype) \
-  hipblasHandle_t handle, hipblasSideMode_t side, hipblasFillMode_t uplo, \
-  hipblasOperation_t trans, hipblasDiagType_t diag, int m, int n, \
-  const Dtype *alpha, Dtype *A, int lda, Dtype *B, int ldb
-#else
 #define CUDABLAS_TRSM_ARGTYPES(Dtype) \
   cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, \
   cublasOperation_t trans, cublasDiagType_t diag, int m, int n, \
   const Dtype *alpha, const Dtype *A, int lda, Dtype *B, int ldb
-#endif
 
 template <typename Dtype>
 inline void trsm(CUDABLAS_TRSM_ARGTYPES(Dtype)) {
@@ -9,15 +9,13 @@
 
 // cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
 // added bf16 support
-#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
 #include <cublasLt.h>
-#endif
 
 #ifdef CUDART_VERSION
 #include <cusolverDn.h>
 #endif
 
-#if defined(USE_ROCM) && ROCM_VERSION >= 50300
+#if defined(USE_ROCM)
 #include <hipsolver/hipsolver.h>
 #endif
 

@@ -82,13 +80,11 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
 /* Handles */
 TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
 TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
-#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
 TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
-#endif
 
 TORCH_CUDA_CPP_API void clearCublasWorkspaces();
 
-#if defined(CUDART_VERSION) || defined(USE_ROCM) && ROCM_VERSION >= 50300
+#if defined(CUDART_VERSION) || defined(USE_ROCM)
 TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle();
 #endif
 
@@ -17,16 +17,11 @@ static bool _cuda_graphs_debug = false;
 constexpr int kSynchronizeBusyWaitMillis = 10;
 
 MempoolId_t graph_pool_handle() {
-#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
   // uuid count starts at 1. 0 is reserved to mean "wasn't set by graph_pool_handle".
   static std::atomic<CaptureId_t> uid{1};
   // Sets just the second value, to distinguish it from MempoolId_ts created from
   // cudaStreamGetCaptureInfo id_s in capture_begin.
   return {0, uid++};
-#else
-  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 or ROCM >= 5.3")
-  return {0, 0};
-#endif
 }
 
 

@@ -84,9 +79,6 @@ int CUDAGraph::num_pending_event_queries() {
 CUDAGraph::CUDAGraph()
   // CUDAStreams may not be default-constructed.
   : capture_stream_(at::cuda::getCurrentCUDAStream()) {
-#if (defined(USE_ROCM) && ROCM_VERSION < 50300)
-  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 or ROCM >= 5.3");
-#endif
 }
 
 void CUDAGraph::register_generator_state(

@@ -102,7 +94,6 @@ void CUDAGraph::register_generator_state(const at::Generator& generator) {
 }
 
 void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capture_mode) {
-#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
   TORCH_CHECK(!has_graph_exec_,
               "This CUDAGraph instance already owns a captured graph. "
               "To capture a new graph, create a new instance.");

@@ -171,13 +162,9 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capt
   TORCH_INTERNAL_ASSERT(status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive);
 
   TORCH_INTERNAL_ASSERT(id_ > 0);
-#else
-  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 or ROCM >= 5.3")
-#endif
 }
 
 void CUDAGraph::capture_end() {
-#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
   auto stream = at::cuda::getCurrentCUDAStream();
 
   TORCH_CHECK(stream == capture_stream_,

@@ -245,13 +232,9 @@ void CUDAGraph::capture_end() {
   } else {
     TORCH_WARN("DEBUG: TORCH_CUDAGRAPHS_DEBUG_PATH detected. graph_ will not be freed until debug_dump is called.");
   }
-#else
-  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 or ROCM >= 5.3")
-#endif
 }
 
 void CUDAGraph::replay() {
-#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
   TORCH_CHECK(has_graph_exec_,
               "Called CUDAGraph::replay without a preceding successful capture.");
 

@@ -273,22 +256,14 @@ void CUDAGraph::replay() {
     // The bug is fixed in CUDA 11.4+.
     AT_CUDA_CHECK(cudaDeviceSynchronize());
   }
-#else
-  TORCH_CHECK(false, "CUDA graphs is not yet supported on ROCM");
-#endif
 }
 
 void CUDAGraph::enable_debug_mode() {
-#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
   _cuda_graphs_debug = true;
-#else
-  TORCH_CHECK(false, "CUDA graphs is not yet supported on ROCM");
-#endif
-
 }
 
 void CUDAGraph::debug_dump(const std::string& debug_path) {
-#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11030)|| (defined(USE_ROCM) && ROCM_VERSION >= 50600)
+#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11030)|| defined(USE_ROCM)
   if (_cuda_graphs_debug) {
     TORCH_WARN("DEBUG: calling debug_dump()");
     if (has_graph_) {

@@ -305,7 +280,6 @@ void CUDAGraph::debug_dump(const std::string& debug_path) {
 }
 
 void CUDAGraph::reset() {
-#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
   // I'd prefer these checks throw exceptions, not print warnings,
   // but the destructor calls reset(), and at least one CI build
   // refuses to compile with a throwing destructor.

@@ -337,19 +311,12 @@ void CUDAGraph::reset() {
     C10_CUDA_CHECK_WARN(cudaGraphExecDestroy(graph_exec_));
    has_graph_exec_ = false;
   }
-#else
-  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 or ROCM >= 5.3")
-#endif
 }
 
 // Returns an id another graph's capture_begin can use to share the same memory pool as this graph.
 MempoolId_t CUDAGraph::pool() {
-#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
   TORCH_CHECK(has_graph_exec_,
               "Called CUDAGraph::pool() without a preceding successful capture.");
-#else
-  TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 or ROCM >= 5.3")
-#endif
   return mempool_id_;
 }
 
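For context on what these gates protected: at::cuda::CUDAGraph is the C++ object behind CUDA graph capture, and after this change it compiles unconditionally on ROCm. A hedged sketch of the lifecycle using only the methods visible in this diff (capture_begin, capture_end, replay, reset); the kernel launches are a placeholder, and real callers capture on a non-default stream:

    #include <ATen/cuda/CUDAGraph.h>

    void capture_and_replay() {
      at::cuda::CUDAGraph graph;

      // Work issued on the capture stream between begin/end is recorded
      // into the graph instead of executing immediately.
      graph.capture_begin();
      // ... enqueue the kernels to be captured here (placeholder) ...
      graph.capture_end();

      graph.replay();  // re-executes the recorded work
      graph.reset();   // frees graph_ / graph_exec_, as reset() above does
    }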
@@ -39,10 +39,8 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
   void debug_dump(const std::string& debug_path);
 
  protected:
-#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
   cudaGraph_t graph_ = NULL;
   cudaGraphExec_t graph_exec_ = NULL;
-#endif
 
   static std::atomic<int> pending_event_queries;
 
@@ -19,16 +19,12 @@ using CaptureStatus = c10::cuda::CaptureStatus;
 
 // Use this version where you don't want to create a CUDA context if none exists.
 inline CaptureStatus currentStreamCaptureStatus() {
-#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
   // don't create a context if we don't have to
   if (c10::cuda::hasPrimaryContext(c10::cuda::current_device())) {
     return c10::cuda::currentStreamCaptureStatusMayInitCtx();
   } else {
     return CaptureStatus::None;
   }
-#else
-  return CaptureStatus::None;
-#endif
 }
 
 inline void assertNotCapturing(std::string attempt) {
@@ -68,8 +68,7 @@
 #endif
 
 // BSR triangular solve functions were added in hipSPARSE 1.11.2 (ROCm 4.5.0)
-#if defined(CUDART_VERSION) || \
-    (defined(USE_ROCM) && ROCM_VERSION >= 40500 )
+#if defined(CUDART_VERSION) || defined(USE_ROCM)
 #define AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 1
 #else
 #define AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 0
@@ -273,7 +273,6 @@ class TORCH_CUDA_CPP_API CuSparseSpSMDescriptor
 };
 #endif
 
-#if (defined(USE_ROCM) && ROCM_VERSION >= 50200) || !defined(USE_ROCM)
 class TORCH_CUDA_CPP_API CuSparseSpGEMMDescriptor
     : public CuSparseDescriptor<cusparseSpGEMMDescr, &cusparseSpGEMM_destroyDescr> {
  public:

@@ -283,7 +282,6 @@ class TORCH_CUDA_CPP_API CuSparseSpGEMMDescriptor
     descriptor_.reset(raw_descriptor);
   }
 };
-#endif
 
 #endif // AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
 
@@ -29,7 +29,7 @@ namespace at::cuda {
 
 namespace {
 
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM)
 void createCublasLtHandle(cublasLtHandle_t *handle) {
   TORCH_CUDABLAS_CHECK(cublasLtCreate(handle));
 }

@@ -191,7 +191,6 @@ cublasHandle_t getCurrentCUDABlasHandle() {
   return handle;
 }
 
-#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
 cublasLtHandle_t getCurrentCUDABlasLtHandle() {
 #ifdef USE_ROCM
   c10::DeviceIndex device = 0;

@@ -218,6 +217,5 @@ cublasLtHandle_t getCurrentCUDABlasLtHandle() {
   return reinterpret_cast<cublasLtHandle_t>(getCurrentCUDABlasHandle());
 #endif
 }
-#endif
 
 } // namespace at::cuda
@@ -80,10 +80,6 @@ AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTANTIATE_SORT_PAIRS_8)
 AT_INSTANTIATE_SORT_PAIRS(uint16_t, 8)
 AT_INSTANTIATE_SORT_PAIRS(uint32_t, 8)
 AT_INSTANTIATE_SORT_PAIRS(uint64_t, 8)
 
-// BFloat16 Radix sort is supported from ROCm 4.5 onwards
-#if !AT_ROCM_ENABLED() || (AT_ROCM_ENABLED() && ROCM_VERSION >= 40500)
 AT_INSTANTIATE_SORT_PAIRS(c10::BFloat16, 8)
-#endif
 
 } // namespace at::cuda::cub::detail
@@ -51,8 +51,7 @@
 #define ROCM_HIPCUB(x) x
 #endif
 
-#if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || \
-    (defined(USE_ROCM) && ROCM_VERSION >= 40500)
+#if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || defined(USE_ROCM)
 
 #if !defined(USE_ROCM)
 namespace at_cuda_detail {

@@ -110,7 +109,7 @@ struct cuda_type<c10::BFloat16> {
   using type = __nv_bfloat16;
 };
 
-#elif (defined(USE_ROCM) && ROCM_VERSION >= 40500)
+#elif defined(USE_ROCM)
 
 template<>
 struct cuda_type<c10::BFloat16> {

@@ -234,7 +233,7 @@ constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30
 // so split at int_max/2
 template<typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, int max_cub_size=impl::max_cub_size>
 inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, int64_t num_items) {
-#if defined(USE_ROCM) && (ROCM_VERSION >= 50000)
+#if defined(USE_ROCM)
   //For ROCm, use hipCUB chained iterators
   CUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::InclusiveScan,
       input,

@@ -301,7 +300,7 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
 
 template<typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, typename InitValueT, int max_cub_size=impl::max_cub_size>
 inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, InitValueT init_value, int64_t num_items) {
-#if defined(USE_ROCM) && (ROCM_VERSION >= 50000)
+#if defined(USE_ROCM)
   //For ROCm, use hipCUB chained iterators
   CUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::ExclusiveScan,
      input,
@@ -151,11 +151,7 @@ bool CUDAHooks::isPinnedPtr(const void* data) const {
     return false;
   }
 #endif
-#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
   return attr.type == cudaMemoryTypeHost;
-#else
-  return attr.memoryType == cudaMemoryTypeHost;
-#endif
 }
 
 bool CUDAHooks::hasCUDA() const {

@@ -177,7 +173,7 @@ bool CUDAHooks::hasCuDNN() const {
 bool CUDAHooks::hasCuSOLVER() const {
 #if defined(CUDART_VERSION) && defined(CUSOLVER_VERSION)
   return true;
-#elif AT_ROCM_ENABLED() && defined(ROCM_VERSION) && ROCM_VERSION >= 50300
+#elif AT_ROCM_ENABLED()
   return true;
 #else
   return false;

@@ -187,7 +183,7 @@ bool CUDAHooks::hasCuSOLVER() const {
 bool CUDAHooks::hasCuBLASLt() const {
 #if defined(CUDART_VERSION)
   return true;
-#elif AT_ROCM_ENABLED() && defined(ROCM_VERSION) && ROCM_VERSION >= 50700
+#elif AT_ROCM_ENABLED()
   return true;
 #else
   return false;
@@ -24,64 +24,6 @@
 
 namespace at::cuda::tunable {
 
-#ifdef HIPBLASLT_HAS_GETINDEXFROMALGO
-#define GETINDEXFROMALGO(algo) hipblaslt_ext::getIndexFromAlgo(algo)
-#else
-static int getIndexFromAlgo(hipblasLtMatmulAlgo_t& algo) {
-  int* algo_ptr = (int*)algo.data;
-  if(*algo_ptr < 0) {
-    return -1;
-  }
-  return *algo_ptr;
-}
-#define GETINDEXFROMALGO(algo) getIndexFromAlgo(algo)
-#endif
-
-#ifdef HIPBLASLT_CUSTOM_COMPUTE_TYPE
-#define COMPUTE_TYPE_32 HIPBLASLT_COMPUTE_F32
-#else
-#define COMPUTE_TYPE_32 HIPBLAS_COMPUTE_32F
-#endif
-
-#ifdef HIPBLASLT_CUSTOM_DATA_TYPE
-
-template <typename T>
-constexpr hipblasltDatatype_t HipBlasDataTypeFor();
-
-template <>
-constexpr hipblasltDatatype_t HipBlasDataTypeFor<float>() {
-  return HIPBLASLT_R_32F;
-}
-
-template <>
-constexpr hipblasltDatatype_t HipBlasDataTypeFor<Half>() {
-  return HIPBLASLT_R_16F;
-}
-
-template <>
-constexpr hipblasltDatatype_t HipBlasDataTypeFor<BFloat16>() {
-  return HIPBLASLT_R_16B;
-}
-
-template <>
-constexpr hipblasltDatatype_t HipBlasDataTypeFor<double>() {
-  return HIPBLASLT_R_64F;
-}
-
-template <>
-constexpr hipblasltDatatype_t HipBlasDataTypeFor<c10::Float8_e4m3fnuz>() {
-  return HIPBLASLT_R_8F_E4M3;
-}
-
-template <>
-constexpr hipblasltDatatype_t HipBlasDataTypeFor<c10::Float8_e5m2fnuz>() {
-  return HIPBLASLT_R_8F_E5M3;
-}
-
-#define DATA_TYPE_R_32 HIPBLASLT_R_32F
-
-#else
-
 template <typename T>
 constexpr hipblasDatatype_t HipBlasDataTypeFor();
 

@@ -115,14 +57,6 @@ constexpr hipblasDatatype_t HipBlasDataTypeFor<c10::Float8_e5m2fnuz>() {
   return HIP_R_8F_E5M2_FNUZ;
 }
 
-#ifdef HIPBLAS_V2
-#define DATA_TYPE_R_32 HIP_R_32F
-#else
-#define DATA_TYPE_R_32 HIPBLAS_R_32F
-#endif
-
-#endif
-
 template <typename T>
 int GetBatchFromParams(const GemmParams<T>* params) {
   return 1;

@@ -439,7 +373,7 @@ class HipblasltGemmOp : public Callable<ParamsT> {
           mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c)));
     }
 
-    HipBlasLtMatmulDescriptor matmul(COMPUTE_TYPE_32, DATA_TYPE_R_32);
+    HipBlasLtMatmulDescriptor matmul(HIPBLAS_COMPUTE_32F, HIP_R_32F);
     matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, opa);
     matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, opb);
 

@@ -543,7 +477,7 @@ auto GetHipBlasLtTypeStringAndOps() {
       b_datatype,
       in_out_datatype,
       in_out_datatype,
-      COMPUTE_TYPE_32,
+      HIPBLAS_COMPUTE_32F,
      heuristic_result));
   TORCH_HIPBLASLT_CHECK(hipblasLtDestroy(handle));
 

@@ -551,14 +485,14 @@ auto GetHipBlasLtTypeStringAndOps() {
   std::sort(heuristic_result.begin(),
             heuristic_result.end(),
             [](hipblasLtMatmulHeuristicResult_t& a, hipblasLtMatmulHeuristicResult_t& b) {
-              return GETINDEXFROMALGO(a.algo) < GETINDEXFROMALGO(b.algo);
+              return hipblaslt_ext::getIndexFromAlgo(a.algo) < hipblaslt_ext::getIndexFromAlgo(b.algo);
             });
 
   int returned_algo_count = heuristic_result.size();
   std::vector<std::pair<std::string, std::unique_ptr<Callable<ParamsT>>>> ret;
   for (int i = 0; i < returned_algo_count; i++) {
     auto algo = heuristic_result[i].algo;
-    int algo_index = GETINDEXFROMALGO(algo);
+    int algo_index = hipblaslt_ext::getIndexFromAlgo(algo);
     auto callable = std::make_unique<HipblasltGemmOp<AT, BT, CT, ALayout, BLayout, ParamsT>>(algo);
     std::string type_string = c10::str(
         "Gemm_Hipblaslt_", _charFromhipblasOp(transa_outer), _charFromhipblasOp(transb_outer), "_", algo_index);

@@ -584,8 +518,5 @@ auto GetHipBlasLtScaledGemmTypeStringAndOps() {
 }
 
 #undef TORCH_HIPBLASLT_CHECK
-#undef GETINDEXFROMALGO
-#undef COMPUTE_TYPE_32
-#undef DATA_TYPE_R_32
 
 } // namespace at::cuda::tunable
@@ -11,9 +11,7 @@
 
 #include <ATen/cuda/tunable/GemmCommon.h>
 #ifdef USE_ROCM
-#if ROCM_VERSION >= 50700
 #include <ATen/cuda/tunable/GemmHipblaslt.h>
-#endif
 #include <ATen/cuda/tunable/GemmRocblas.h>
 #endif
 #include <ATen/cuda/tunable/StreamTimer.h>

@@ -220,7 +218,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
   }
 #endif
 
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM)
   static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
   if (env == nullptr || strcmp(env, "1") == 0) {
     // disallow tuning of hipblaslt with c10::complex

@@ -294,7 +292,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
   }
 #endif
 
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM)
   static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
   if (env == nullptr || strcmp(env, "1") == 0) {
     // disallow tuning of hipblaslt with c10::complex

@@ -334,7 +332,7 @@ class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer>
 
   auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators();
 
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM)
   for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps<AT, BT, CT, ALayout, BLayout>()) {
     this->RegisterOp(std::move(name), std::move(op));
   }
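All three TunableOp registration sites above read the same switch: PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED, where unset or "1" registers the hipBLASLt candidate ops and any other value skips them, leaving only the rocBLAS ops. A usage sketch (the script name is a placeholder):

    # default: unset behaves like "1", so hipBLASLt solutions join the tuning pool
    PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED=1 python train.py

    # opt out of hipBLASLt candidates; rocBLAS ops are still registered
    PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED=0 python train.py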
@@ -104,10 +104,6 @@ struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplI
     int deviceCnt;
     hipError_t _err;
     _err = hipGetDeviceCount(&deviceCnt);
-#if defined(USE_ROCM) && (ROCM_VERSION < 50201)
-    if(_err == hipErrorInvalidDevice)
-        return 0;
-#endif
     if(_err != hipErrorNoDevice && _err != hipSuccess)
         C10_HIP_CHECK(_err);
     return deviceCnt;
@@ -157,7 +157,6 @@ enum class Activation {
   GELU,
 };
 
-#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
 cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) {
   switch (a) {
     case Activation::None:

@@ -171,7 +170,6 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa
       return cuda::blas::GEMMAndBiasActivationEpilogue::None;
   }
 }
-#endif
 
 static bool getDisableAddmmCudaLt() {
     static const char* env_value = std::getenv("DISABLE_ADDMM_CUDA_LT");

@@ -236,7 +234,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
   at::ScalarType scalar_type = self.scalar_type();
   c10::MaybeOwned<Tensor> self_;
   if (&result != &self) {
-#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11040)) || (defined(USE_ROCM) && (ROCM_VERSION >= 50700))
+#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11040)) || defined(USE_ROCM)
     // Strangely, if mat2 has only 1 row or column, we get
     // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
     // self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]

@@ -283,7 +281,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
     }
     self__sizes = self_->sizes();
   } else {
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM)
     useLtInterface = !disable_addmm_cuda_lt &&
         result.dim() == 2 && result.is_contiguous() &&
         isSupportedHipLtROCmArch(self.device().index()) &&

@@ -334,7 +332,6 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
 
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
 
-#if !defined(USE_ROCM) || (defined(USE_ROCM) && (ROCM_VERSION >= 50700))
   if (useLtInterface) {
 #if defined(USE_ROCM)
     AT_DISPATCH_FLOATING_TYPES_AND2(

@@ -398,7 +395,6 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
         });
 #endif
   } else
-#endif
   {
     AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
         at::ScalarType::Half,

@@ -770,7 +766,7 @@ Tensor& _int_mm_out_cuda(const Tensor& self, const Tensor& mat2, Tensor& result)
 
   TORCH_CHECK(result.is_contiguous(), "Expected result to be contiguous.");
 
-#if (!defined(USE_ROCM) && defined(CUDA_VERSION) && (CUDA_VERSION >= 11070)) || (defined(USE_ROCM) && (ROCM_VERSION >= 60000))
+#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11070)) || defined(USE_ROCM)
   cublasCommonArgs args(self, mat2, result);
 
   at::cuda::blas::int8_gemm(

@@ -910,7 +906,6 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
   at::native::resize_output(amax, {});
 
-#if !defined(USE_ROCM) || (defined(USE_ROCM) && (ROCM_VERSION >= 60000))
   cublasCommonArgs args(mat1, mat2, out);
   const auto out_dtype_ = args.result->scalar_type();
   TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");

@@ -1016,11 +1011,8 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
         amax.data_ptr(),
         use_fast_accum);
   }
-#else
-  TORCH_CHECK(false, "_scaled_mm_out_cuda is not compiled for this platform.");
-#endif
 
-#if defined(USE_ROCM) && ROCM_VERSION >= 60000
+#if defined(USE_ROCM)
   // rocm's hipblaslt does not yet support amax, so calculate separately
   amax = at::max(at::abs(out.to(kFloat)));
 #endif
@@ -233,18 +233,6 @@ void launch_stable_sort_kernel(
   TORCH_CHECK(nbatch > 0, "Cannot sort dimension of length ", nsort);
   int64_t* indices_ptr = indices.mutable_data_ptr<int64_t>();
 
-#if (defined(USE_ROCM) && ROCM_VERSION < 40500)
-  constexpr bool is_rocm_bf16_sort_unsupported = true;
-#else
-  constexpr bool is_rocm_bf16_sort_unsupported = false;
-#endif
-
-  if constexpr (is_rocm_bf16_sort_unsupported) {
-    if (self.scalar_type() == kBFloat16) {
-      TORCH_CHECK(false, "BFloat16 is not supported on ROCm < 4.5");
-    }
-  }
-
   AT_DISPATCH_ALL_TYPES_AND3(
       kBool, kHalf, kBFloat16, self.scalar_type(), "sort", [&] {
         const scalar_t* self_ptr = self.const_data_ptr<scalar_t>();
@@ -328,7 +328,7 @@ inline static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Te
   // gesvd just knows how to handle m >= n, so in the other case we need to transpose A
   const auto not_A_H = A.size(-2) >= A.size(-1);
   Tensor Vcopy = V; // Shallow copy
-#ifdef ROCM_VERSION
+#ifdef USE_ROCM
   // Similar to the case in svd_magma(), experiments have shown Vh tensor is
   // not guaranteed to be column major on ROCM, we have to create a copy to
   // deal with this

@@ -347,7 +347,7 @@ inline static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Te
         infos,
         full_matrices, compute_uv, calculate_all_batches, batches);
   });
-#ifdef ROCM_VERSION
+#ifdef USE_ROCM
   if (!not_A_H) {
     V.copy_(Vcopy);
   }

@@ -661,7 +661,7 @@ void svd_cusolver(const Tensor& A,
   static const char* check_svd_doc = "Check doc at https://pytorch.org/docs/stable/generated/torch.linalg.svd.html";
 
   // The default heuristic is to use gesvdj driver
-#ifdef ROCM_VERSION
+#ifdef USE_ROCM
   const auto driver_v = c10::string_view("gesvdj");
 #else
   const auto driver_v = driver.value_or("gesvdj");
@@ -8,7 +8,7 @@
 #include <ATen/native/TransposeType.h>
 #include <ATen/native/cuda/MiscUtils.h>
 
-#if (defined(CUDART_VERSION) && defined(CUSOLVER_VERSION)) || (defined(USE_ROCM) && ROCM_VERSION >= 50300)
+#if (defined(CUDART_VERSION) && defined(CUSOLVER_VERSION)) || defined(USE_ROCM)
 #define USE_LINALG_SOLVER
 #endif
 
@@ -4,7 +4,7 @@
 #include <c10/cuda/CUDACachingAllocator.h>
 #include <c10/macros/Export.h>
 
-#if defined(CUDART_VERSION) || defined(ROCM_VERSION) && ROCM_VERSION >= 50300
+#if defined(CUDART_VERSION) || defined(USE_ROCM)
 
 namespace at::cuda::solver {
 
@@ -7,7 +7,7 @@
 #define USE_CUSOLVER_64_BIT
 #endif
 
-#if defined(CUDART_VERSION) || defined(ROCM_VERSION) && ROCM_VERSION >= 50300
+#if defined(CUDART_VERSION) || defined(USE_ROCM)
 
 namespace at {
 namespace cuda {
@@ -1,7 +1,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/detail/DeviceThreadHandles.h>
 
-#if defined(CUDART_VERSION) || defined(ROCM_VERSION) && ROCM_VERSION >= 50300
+#if defined(CUDART_VERSION) || defined(USE_ROCM)
 
 namespace at::cuda {
 namespace {
@@ -692,13 +692,6 @@ void spgemm(
     const Scalar& beta,
     const Scalar& alpha,
     const at::sparse_csr::SparseCsrTensor& C) {
-#if defined(USE_ROCM) && ROCM_VERSION < 50200
-  TORCH_CHECK(
-      false,
-      "Calling addmm with sparse GPU tensors requires compiling ",
-      "PyTorch with ROCm 5.2+. ",
-      "Please use PyTorch built with newer ROCm version.");
-#else
   // older versions of cusparse on Windows segfault for complex128 dtype
 #if defined(_WIN32) && defined(CUSPARSE_VERSION) && CUSPARSE_VERSION < 11400
   TORCH_CHECK(

@@ -834,7 +827,6 @@ void spgemm(
         CUSPARSE_SPGEMM_DEFAULT,
         spgemm_desc.descriptor()));
   });
-#endif
 }
 
 } // anonymous namespace
@@ -19,7 +19,7 @@
 #define IS_SPMM_AVAILABLE() 0
 #endif
 
-#if defined(USE_ROCM) && ROCM_VERSION >= 50200
+#if defined(USE_ROCM)
 #define IS_SPMM_HIP_AVAILABLE() 1
 #else
 #define IS_SPMM_HIP_AVAILABLE() 0
@@ -778,12 +778,9 @@ struct MempoolIdHash {
 };
 
 cudaError_t cudaMallocMaybeCapturing(void** p, size_t size) {
-#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
   if (at::cuda::currentStreamCaptureStatusMayInitCtx() ==
       at::cuda::CaptureStatus::None) {
-#endif
     return C10_CUDA_ERROR_HANDLED(cudaMalloc(p, size));
-#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
   } else {
     // It's ok to capture cudaMallocs, as long as we never cudaFree those
     // addresses before replay.

@@ -792,7 +789,6 @@ cudaError_t cudaMallocMaybeCapturing(void** p, size_t size) {
     at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeRelaxed};
     return C10_CUDA_ERROR_HANDLED(cudaMalloc(p, size));
   }
-#endif
 }
 
 } // anonymous namespace

@@ -2183,7 +2179,6 @@ class DeviceCachingAllocator {
   }
 
   BlockPool& get_pool(size_t size, cudaStream_t stream) {
-#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
     // captures_underway is a conservative guess that the current stream may be
     // capturing. It's only non-empty if some thread has begun and not yet ended
     // a capture, so it's usually 0, and we can short-circuit

@@ -2201,7 +2196,6 @@ class DeviceCachingAllocator {
       }
     }
   }
-#endif
   if (size <= kSmallSize) {
     return small_blocks;
   } else {
@@ -17,7 +17,6 @@ using MempoolId_t = std::pair<CaptureId_t, CaptureId_t>;
 
 // RAII guard for "cudaStreamCaptureMode", a thread-local value
 // that controls the error-checking strictness of a capture.
-#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
 struct C10_CUDA_API CUDAStreamCaptureModeGuard {
   CUDAStreamCaptureModeGuard(cudaStreamCaptureMode desired)
       : strictness_(desired) {

@@ -30,9 +29,7 @@ struct C10_CUDA_API CUDAStreamCaptureModeGuard {
  private:
   cudaStreamCaptureMode strictness_;
 };
-#endif
 
-#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
 // Protects against enum cudaStreamCaptureStatus implementation changes.
 // Some compilers seem not to like static_assert without the messages.
 static_assert(

@@ -44,16 +41,11 @@ static_assert(
 static_assert(
     int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) == 2,
     "unexpected int(cudaStreamCaptureStatusInvalidated) value");
-#endif
 
 enum class CaptureStatus : int {
-#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
   None = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone),
   Active = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive),
   Invalidated = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated)
-#else
-  None = 0
-#endif
 };
 
 inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) {

@@ -61,14 +53,12 @@ inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) {
     case CaptureStatus::None:
       os << "cudaStreamCaptureStatusNone";
       break;
-#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
     case CaptureStatus::Active:
       os << "cudaStreamCaptureStatusActive";
      break;
     case CaptureStatus::Invalidated:
       os << "cudaStreamCaptureStatusInvalidated";
       break;
-#endif
     default:
       TORCH_INTERNAL_ASSERT(
           false, "Unknown CUDA graph CaptureStatus", int(status));

@@ -78,14 +68,10 @@ inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) {
 
 // Use this version where you're sure a CUDA context exists already.
 inline CaptureStatus currentStreamCaptureStatusMayInitCtx() {
-#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
   cudaStreamCaptureStatus is_capturing{cudaStreamCaptureStatusNone};
   C10_CUDA_CHECK(
       cudaStreamIsCapturing(c10::cuda::getCurrentCUDAStream(), &is_capturing));
   return CaptureStatus(is_capturing);
-#else
-  return CaptureStatus::None;
-#endif
 }
 
 } // namespace c10::cuda
@@ -323,8 +323,7 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;
 // CUDA_KERNEL_ASSERT checks the assertion
 // even when NDEBUG is defined. This is useful for important assertions in CUDA
 // code that would otherwise be suppressed when building Release.
-#if defined(__ANDROID__) || defined(__APPLE__) || defined(__FreeBSD__) || \
-    (defined(USE_ROCM) && ROCM_VERSION < 40100)
+#if defined(__ANDROID__) || defined(__APPLE__) || defined(__FreeBSD__)
 // Those platforms do not support assert()
 #define CUDA_KERNEL_ASSERT(cond)
 #define SYCL_KERNEL_ASSERT(cond)
@@ -34,8 +34,7 @@ TEST(ExceptionTest, TORCH_INTERNAL_ASSERT_DEBUG_ONLY) {
 }
 
 // On these platforms there's no assert
-#if !defined(__ANDROID__) && !defined(__APPLE__) && \
-    !(defined(USE_ROCM) && ROCM_VERSION < 40100)
+#if !defined(__ANDROID__) && !defined(__APPLE__)
 TEST(ExceptionTest, CUDA_KERNEL_ASSERT) {
   // This function always throws even in NDEBUG mode
   ASSERT_DEATH_IF_SUPPORTED({ CUDA_KERNEL_ASSERT(false); }, "Assert");
@@ -70,7 +70,7 @@ int GetGPUIDForPointer(const void* ptr) {
   // Otherwise, there must be no error
   CUDA_ENFORCE(err);
 
-  if (attr.CAFFE2_CUDA_PTRATTR_MEMTYPE == cudaMemoryTypeHost) {
+  if (attr.type == cudaMemoryTypeHost) {
     return -1;
   }
 
@@ -86,12 +86,6 @@ namespace caffe2 {
 class TensorCoreEngine {};
 #endif // USE_ROCM
 
-#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
-#define CAFFE2_CUDA_PTRATTR_MEMTYPE type
-#else
-#define CAFFE2_CUDA_PTRATTR_MEMTYPE memoryType
-#endif
-
 /**
  * A runtime function to report the cuda version that Caffe2 is built with.
  */
@@ -48,7 +48,7 @@ TEST(CUDAContextTest, MemoryPoolAllocateDealloc) {
   EXPECT_NE(allocated, nullptr);
   cudaPointerAttributes attr;
   CUDA_ENFORCE(cudaPointerGetAttributes(&attr, allocated.get()));
-  EXPECT_EQ(attr.CAFFE2_CUDA_PTRATTR_MEMTYPE, cudaMemoryTypeDevice);
+  EXPECT_EQ(attr.type, cudaMemoryTypeDevice);
   EXPECT_EQ(attr.device, i);
   void* prev_allocated = allocated.get();
   allocated.clear();
@ -620,11 +620,6 @@ CAFFE2_CUDA_EXPORT void Gemm<at::Half, CUDAContext>(
|
||||
// It has more general hipblasGemmEx API which is more close to cublasGemmEx.
|
||||
// hipblasGemmEx does D = alpha*op( A )*op( B ) + beta*C,
|
||||
// whereas cublasSgemmEx does C = alpha*op( A )*op( B ) + beta*C
|
||||
#if ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
|
||||
auto compute_type = HIPBLAS_COMPUTE_32F;
|
||||
#else
|
||||
auto compute_type = HIPBLAS_R_32F;
|
||||
#endif
|
||||
HIPBLAS_ENFORCE(hipblasGemmEx(
|
||||
context->hipblas_handle(),
|
||||
cu_trans_B,
|
||||
@ -643,7 +638,7 @@ CAFFE2_CUDA_EXPORT void Gemm<at::Half, CUDAContext>(
|
||||
C,
|
||||
HIPBLAS_R_16F,
|
||||
N,
|
||||
compute_type,
|
||||
HIPBLAS_COMPUTE_32F,
|
||||
HIPBLAS_GEMM_DEFAULT));
|
||||
#else
|
||||
CUBLAS_ENFORCE(cublasSgemmEx(
|
||||
@ -861,7 +856,7 @@ CAFFE2_CUDA_EXPORT void GemmBatched<at::Half, CUDAContext>(
|
||||
thrust::device_vector<void*> C_device(C, C + batch_size);
|
||||
CUBLAS_ENFORCE(cublasSetPointerMode(
|
||||
context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
|
||||
#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
|
||||
#if defined(USE_ROCM)
|
||||
auto compute_type = HIPBLAS_COMPUTE_32F;
|
||||
#else
|
||||
auto compute_type = CUDA_R_32F;
|
||||
@ -957,7 +952,7 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched<at::Half, CUDAContext>(
|
||||
if (math_type == TensorProto_DataType_FLOAT) {
|
||||
CUBLAS_ENFORCE(cublasSetPointerMode(
|
||||
context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
|
||||
#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
|
||||
#if defined(USE_ROCM)
|
||||
auto compute_type = HIPBLAS_COMPUTE_32F;
|
||||
#else
|
||||
auto compute_type = CUDA_R_32F;
|
||||
@ -1076,11 +1071,6 @@ CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext>(
|
||||
// It has more general hipblasGemmEx API which is more close to cublasGemmEx.
|
||||
// hipblasGemmEx does D = alpha*op( A )*op( B ) + beta*C,
|
||||
// whereas cublasSgemmEx does C = alpha*op( A )*op( B ) + beta*C
|
||||
#if ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
|
||||
auto compute_type = HIPBLAS_COMPUTE_32F;
|
||||
#else
|
||||
auto compute_type = HIPBLAS_R_32F;
|
||||
#endif
|
||||
HIPBLAS_ENFORCE(hipblasGemmEx(
|
||||
context->hipblas_handle(),
|
||||
cu_trans_A,
|
||||
@ -1099,7 +1089,7 @@ CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext>(
|
||||
y,
|
||||
HIPBLAS_R_16F,
|
||||
ldc,
|
||||
compute_type,
|
||||
HIPBLAS_COMPUTE_32F,
|
||||
HIPBLAS_GEMM_DEFAULT));
|
||||
#else
|
||||
CUBLAS_ENFORCE(cublasSgemmEx(
|
||||
|
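With ROCm >= 6.0 enforced and -DHIPBLAS_V2 always set (see the CMake hunk below), the compute-type argument in the FP16 GEMM/GEMV paths above is always the hipblasComputeType_t value HIPBLAS_COMPUTE_32F; the legacy hipblasDatatype_t spelling HIPBLAS_R_32F is dead code, which is why the version gates collapse. A hedged sketch of such a call, assuming handle, device buffers A/B/C, and dimensions m/n/k with leading dims lda/ldb/ldc are set up elsewhere (host pointer mode, as in the hunks above):

// FP16 inputs/outputs with FP32 accumulation. The call updates C in place,
// so hipblasGemmEx's D = alpha*op(A)*op(B) + beta*C coincides with the
// cublasSgemmEx behavior C = alpha*op(A)*op(B) + beta*C noted in the
// comments above.
float alpha = 1.0f, beta = 0.0f;
HIPBLAS_ENFORCE(hipblasGemmEx(
    handle,
    HIPBLAS_OP_N,
    HIPBLAS_OP_N,
    m, n, k,
    &alpha,
    A, HIPBLAS_R_16F, lda,
    B, HIPBLAS_R_16F, ldb,
    &beta,
    C, HIPBLAS_R_16F, ldc,
    HIPBLAS_COMPUTE_32F,  // hipblasComputeType_t under HIPBLAS_V2
    HIPBLAS_GEMM_DEFAULT));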
@ -1213,18 +1213,7 @@ if(USE_ROCM)
  list(APPEND HIP_CXX_FLAGS -DCAFFE2_USE_MIOPEN)
  list(APPEND HIP_CXX_FLAGS -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP)
  list(APPEND HIP_CXX_FLAGS -std=c++17)
  if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "6.0.0")
    list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2)
  endif()
  if(HIPBLASLT_CUSTOM_DATA_TYPE)
    list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_CUSTOM_DATA_TYPE)
  endif()
  if(HIPBLASLT_CUSTOM_COMPUTE_TYPE)
    list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_CUSTOM_COMPUTE_TYPE)
  endif()
  if(HIPBLASLT_HAS_GETINDEXFROMALGO)
    list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_HAS_GETINDEXFROMALGO)
  endif()
  list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2)
  if(HIP_NEW_TYPE_ENUMS)
    list(APPEND HIP_CXX_FLAGS -DHIP_NEW_TYPE_ENUMS)
  endif()

@ -1256,9 +1245,7 @@ if(USE_ROCM)

  set(Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
    ${PYTORCH_HIP_LIBRARIES} ${PYTORCH_MIOPEN_LIBRARIES} ${hipcub_LIBRARIES} ${ROCM_HIPRTC_LIB} ${ROCM_ROCTX_LIB})
  if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
    list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS ${hipblaslt_LIBRARIES})
  endif()
  list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS ${hipblaslt_LIBRARIES})

  list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
    roc::hipblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsolver)

@ -1281,18 +1268,6 @@ if(USE_ROCM)
  endif()
endif()

# ---[ ROCm
if(USE_ROCM AND ROCM_VERSION_DEV VERSION_LESS "5.2.0")
  # We check again for USE_ROCM because it might have been set to OFF
  # in the if above
  include_directories(SYSTEM ${HIP_PATH}/include)
  include_directories(SYSTEM ${HIPBLAS_PATH}/include)
  include_directories(SYSTEM ${HIPFFT_PATH}/include)
  include_directories(SYSTEM ${HIPSPARSE_PATH}/include)
  include_directories(SYSTEM ${HIPRAND_PATH}/include)
  include_directories(SYSTEM ${THRUST_PATH})
endif()

# ---[ NCCL
if(USE_NCCL)
  if(NOT (USE_CUDA OR USE_ROCM))
@ -155,15 +155,9 @@ if(HIP_FOUND)
  find_package_and_print_version(hiprand REQUIRED)
  find_package_and_print_version(rocblas REQUIRED)
  find_package_and_print_version(hipblas REQUIRED)
  if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
    find_package_and_print_version(hipblaslt REQUIRED)
  endif()
  find_package_and_print_version(hipblaslt REQUIRED)
  find_package_and_print_version(miopen REQUIRED)
  if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "4.1.0")
    find_package_and_print_version(hipfft REQUIRED)
  else()
    find_package_and_print_version(rocfft REQUIRED)
  endif()
  find_package_and_print_version(hipfft REQUIRED)
  find_package_and_print_version(hipsparse REQUIRED)
  find_package_and_print_version(rccl)
  find_package_and_print_version(rocprim REQUIRED)

@ -191,88 +185,6 @@ if(HIP_FOUND)
  # roctx is part of roctracer
  find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)

  if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
    # check whether hipblaslt is using its own datatype
    set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_data_type.cc")
    file(WRITE ${file} ""
      "#include <hipblaslt/hipblaslt.h>\n"
      "int main() {\n"
      "    hipblasltDatatype_t bar = HIPBLASLT_R_16F;\n"
      "    return 0;\n"
      "}\n"
      )

    try_compile(hipblaslt_compile_result_custom_datatype ${PROJECT_RANDOM_BINARY_DIR} ${file}
      CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
      COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
      OUTPUT_VARIABLE hipblaslt_compile_output)

    if(hipblaslt_compile_result_custom_datatype)
      set(HIPBLASLT_CUSTOM_DATA_TYPE ON)
      #message("hipblaslt is using custom data type: ${hipblaslt_compile_output}")
      message("hipblaslt is using custom data type")
    else()
      set(HIPBLASLT_CUSTOM_DATA_TYPE OFF)
      #message("hipblaslt is NOT using custom data type: ${hipblaslt_compile_output}")
      message("hipblaslt is NOT using custom data type")
    endif()

    # check whether hipblaslt is using its own compute type
    set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_compute_type.cc")
    file(WRITE ${file} ""
      "#include <hipblaslt/hipblaslt.h>\n"
      "int main() {\n"
      "    hipblasLtComputeType_t baz = HIPBLASLT_COMPUTE_F32;\n"
      "    return 0;\n"
      "}\n"
      )

    try_compile(hipblaslt_compile_result_custom_compute_type ${PROJECT_RANDOM_BINARY_DIR} ${file}
      CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
      COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
      OUTPUT_VARIABLE hipblaslt_compile_output)

    if(hipblaslt_compile_result_custom_compute_type)
      set(HIPBLASLT_CUSTOM_COMPUTE_TYPE ON)
      #message("hipblaslt is using custom compute type: ${hipblaslt_compile_output}")
      message("hipblaslt is using custom compute type")
    else()
      set(HIPBLASLT_CUSTOM_COMPUTE_TYPE OFF)
      #message("hipblaslt is NOT using custom compute type: ${hipblaslt_compile_output}")
      message("hipblaslt is NOT using custom compute type")
    endif()

    # check whether hipblaslt provides getIndexFromAlgo
    set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_getIndexFromAlgo.cc")
    file(WRITE ${file} ""
      "#include <hipblaslt/hipblaslt.h>\n"
      "#include <hipblaslt/hipblaslt-ext.hpp>\n"
      "int main() {\n"
      "    hipblasLtMatmulAlgo_t algo;\n"
      "    return hipblaslt_ext::getIndexFromAlgo(algo);\n"
      "    return 0;\n"
      "}\n"
      )

    try_compile(hipblaslt_compile_result_getindexfromalgo ${PROJECT_RANDOM_BINARY_DIR} ${file}
      CMAKE_FLAGS
      "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
      "-DLINK_DIRECTORIES=${ROCM_PATH}/lib"
      LINK_LIBRARIES ${hipblaslt_LIBRARIES}
      COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
      OUTPUT_VARIABLE hipblaslt_compile_output)

    if(hipblaslt_compile_result_getindexfromalgo)
      set(HIPBLASLT_HAS_GETINDEXFROMALGO ON)
      #message("hipblaslt provides getIndexFromAlgo: ${hipblaslt_compile_output}")
      message("hipblaslt provides getIndexFromAlgo")
    else()
      set(HAS_GETINDEXFROMALGO OFF)
      #message("hipblaslt does not provide getIndexFromAlgo: ${hipblaslt_compile_output}")
      message("hipblaslt does not provide getIndexFromAlgo")
    endif()
  endif()

  # check whether HIP declares new types
  set(file "${PROJECT_BINARY_DIR}/hip_new_types.cc")
  file(WRITE ${file} ""

@ -283,18 +195,18 @@ if(HIP_FOUND)
    "}\n"
    )

  try_compile(hipblaslt_compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file}
  try_compile(hip_compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file}
    CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
    COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
    OUTPUT_VARIABLE hipblaslt_compile_output)
    OUTPUT_VARIABLE hip_compile_output)

  if(hipblaslt_compile_result)
  if(hip_compile_result)
    set(HIP_NEW_TYPE_ENUMS ON)
    #message("HIP is using new type enums: ${hipblaslt_compile_output}")
    #message("HIP is using new type enums: ${hip_compile_output}")
    message("HIP is using new type enums")
  else()
    set(HIP_NEW_TYPE_ENUMS OFF)
    #message("HIP is NOT using new type enums: ${hipblaslt_compile_output}")
    #message("HIP is NOT using new type enums: ${hip_compile_output}")
    message("HIP is NOT using new type enums")
  endif()
@ -1550,17 +1550,8 @@ void GraphTask::stash_current_streams() {
  caller_current_streams_.resize(num_devices);
  if (num_devices > 0) {
    for (c10::DeviceIndex idx = 0; idx < num_devices; idx++) {
#if defined(USE_ROCM) && (ROCM_VERSION < 50000)
      // If the build targets ROCM, stash streams for all visible devices
      // unconditionally, to work around
      // https://github.com/pytorch/pytorch/issues/59750.
      // TODO: Remove ROCM-specific behavior when
      // https://github.com/pytorch/pytorch/issues/59750 is fixed.
      if (true) {
#else
      if (at::globalContext().getAcceleratorHooksInterface().hasPrimaryContext(
              idx)) {
#endif
        caller_current_streams_[idx] = guard.getStream({accelerator, idx});
      } else {
        caller_current_streams_[idx] = c10::nullopt;
@ -821,7 +821,7 @@ void all2all_single_equal_split(
  const auto* sendbuff = reinterpret_cast<const char*>(input.const_data_ptr());
  auto* recvbuff = reinterpret_cast<char*>(output.data_ptr());
  auto comm = to_nccl_comm(_comm);
#if defined(USE_ROCM) && ROCM_VERSION >= 50000
#if defined(USE_ROCM)
  NCCL_CHECK(ncclAllToAll(sendbuff, recvbuff, count, type, comm, stream));
#else
  NCCL_CHECK(ncclCommCount(comm, &numranks));
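RCCL ships a native ncclAllToAll, which the USE_ROCM branch above now calls unconditionally; stock NCCL has no such collective, so the #else path composes the equal-split all-to-all from grouped point-to-point operations. A hedged sketch of that fallback (offset arithmetic simplified; `type_size` is an assumed per-element byte count, not a name from the original):

int numranks = 0;
NCCL_CHECK(ncclCommCount(comm, &numranks));
const size_t rankdiff = count * type_size;  // bytes exchanged with each peer
NCCL_CHECK(ncclGroupStart());
for (int r = 0; r < numranks; r++) {
  // Every rank sends its r-th slice to peer r and receives r's slice back.
  NCCL_CHECK(ncclSend(sendbuff + r * rankdiff, count, type, r, comm, stream));
  NCCL_CHECK(ncclRecv(recvbuff + r * rankdiff, count, type, r, comm, stream));
}
NCCL_CHECK(ncclGroupEnd());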
@ -659,25 +659,6 @@ std::string generateKernel(
    env.s("RandInit", "");
  }

  // HIP headers must be included until precompiled header feature is available
  // clang-format off
#if defined(USE_ROCM)
#if ROCM_VERSION < 40200
  if (use_cuda && has_half_tensor) {
    env.s("RuntimeHeader", R"(
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
)");
  } else if (use_cuda) {
    env.s("RuntimeHeader", R"(
#include <hip/hip_runtime.h>
)");
  }
#else
  // Still need the key defined, but empty.
  env.s("RuntimeHeader", R"()");
#endif
#endif
  // clang-format on

  // Instantiates the CUDA or CPU-specific templates

@ -128,9 +128,7 @@ FusedKernelCUDA::FusedKernelCUDA(

#if defined(USE_ROCM)
  std::vector<const char*> args = {"--std=c++17"};
#if ROCM_VERSION >= 40200
  args.push_back("-hip-pch");
#endif
#else
  const std::string compute = std::string("--gpu-architecture=") +
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010

@ -190,16 +188,8 @@ FusedKernelCUDA::FusedKernelCUDA(
      nvrtc().cuModuleGetFunction(&function_, module_, name_.c_str()));

  // Computes max blocks
#if defined(USE_ROCM) && ROCM_VERSION < 30500
  // HIP function signature is not compatible yet
  uint32_t max_blocks;
  AT_CUDA_DRIVER_CHECK(nvrtc().hipOccupancyMaxActiveBlocksPerMultiprocessor(
      &max_blocks, function_, 128, 0));
  maxBlocks_ = max_blocks;
#else
  AT_CUDA_DRIVER_CHECK(nvrtc().cuOccupancyMaxActiveBlocksPerMultiprocessor(
      &maxBlocks_, function_, 128, 0));
#endif
  maxBlocks_ *= prop_->multiProcessorCount;

  // Resets device (end of hacked at::DeviceGuard)

@ -15,7 +15,6 @@ cases*/

#if defined(USE_ROCM)
static auto type_declarations_template = at::jit::CodeTemplate(R"(
${RuntimeHeader}
${HalfHeader}
${BFloat16Header}
${RandHeader}

@ -897,14 +897,6 @@ void CudaCodeGen::Initialize() {
  HalfChecker halfChecker(buffer_args());
  stmt_v->accept(&halfChecker);

#if defined(USE_ROCM)
#if ROCM_VERSION < 40200
  os() << "#include <hip/hip_runtime.h>" << std::endl;
  if (halfChecker.hasHalf()) {
    os() << "#include <hip/hip_fp16.h>" << std::endl;
  }
#endif
#endif
  os() << device_resource_string << shared_resource_string;

  if (has_random_) {

@ -1319,9 +1311,7 @@ void CudaCodeGen::CompileToNVRTC(

#if defined(USE_ROCM)
  std::vector<const char*> args = {"--std=c++17"};
#if ROCM_VERSION >= 40200
  args.push_back("-hip-pch");
#endif
#else
  const std::string compute = std::string("--gpu-architecture=") +
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
@ -235,11 +235,9 @@ COMMON_HIP_FLAGS = [
    '-fPIC',
    '-D__HIP_PLATFORM_AMD__=1',
    '-DUSE_ROCM=1',
    '-DHIPBLAS_V2',
]

if ROCM_VERSION is not None and ROCM_VERSION >= (6, 0):
    COMMON_HIP_FLAGS.append('-DHIPBLAS_V2')

COMMON_HIPCC_FLAGS = [
    '-DCUDA_HAS_FP16=1',
    '-D__HIP_NO_HALF_OPERATORS__=1',

@ -1083,8 +1081,7 @@ def CUDAExtension(name, sources, *args, **kwargs):
    libraries.append('torch_cpu')
    libraries.append('torch_python')
    if IS_HIP_EXTENSION:
        assert ROCM_VERSION is not None
        libraries.append('amdhip64' if ROCM_VERSION >= (3, 5) else 'hip_hcc')
        libraries.append('amdhip64')
        libraries.append('c10_hip')
        libraries.append('torch_hip')
    else:

@ -1907,9 +1904,8 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone):
        if CUDNN_HOME is not None:
            extra_ldflags.append(f'-L{os.path.join(CUDNN_HOME, "lib64")}')
    elif IS_HIP_EXTENSION:
        assert ROCM_VERSION is not None
        extra_ldflags.append(f'-L{_join_rocm_home("lib")}')
        extra_ldflags.append('-lamdhip64' if ROCM_VERSION >= (3, 5) else '-lhip_hcc')
        extra_ldflags.append('-lamdhip64')
    return extra_ldflags
@ -1,7 +1,4 @@
import collections
import os
import re
import subprocess

from .constants import (API_BLAS, API_C10, API_CAFFE2, API_DRIVER, API_FFT,
                        API_PYTORCH, API_RAND, API_ROCTX, API_RTC, API_RUNTIME,

@ -27,40 +24,6 @@ ROCm/HIP string, a type and API annotation and - optionally - an annotation if i
supported in ROCm/HIP yet.
"""

# We need to know the ROCm version so we can conditionalize some of the mappings later.
# As of ROCm 5.0, the version is found in rocm_version.h header file under /opt/rocm/include.
rocm_path = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH') or "/opt/rocm"
try:
    rocm_path = subprocess.check_output(["hipconfig", "--rocmpath"]).decode("utf-8")
except subprocess.CalledProcessError:
    print(f"Warning: hipconfig --rocmpath failed, assuming {rocm_path}")
except (FileNotFoundError, PermissionError, NotADirectoryError):
    # Do not print warning. This is okay. This file can also be imported for non-ROCm builds.
    pass

rocm_version = (0, 0, 0)
rocm_version_h = f"{rocm_path}/include/rocm-core/rocm_version.h"
if not os.path.isfile(rocm_version_h):
    rocm_version_h = f"{rocm_path}/include/rocm_version.h"

# The file could be missing due to 1) ROCm version < 5.2, or 2) no ROCm install.
if os.path.isfile(rocm_version_h):
    RE_MAJOR = re.compile(r"#define\s+ROCM_VERSION_MAJOR\s+(\d+)")
    RE_MINOR = re.compile(r"#define\s+ROCM_VERSION_MINOR\s+(\d+)")
    RE_PATCH = re.compile(r"#define\s+ROCM_VERSION_PATCH\s+(\d+)")
    major, minor, patch = 0, 0, 0
    for line in open(rocm_version_h):
        match = RE_MAJOR.search(line)
        if match:
            major = int(match.group(1))
        match = RE_MINOR.search(line)
        if match:
            minor = int(match.group(1))
        match = RE_PATCH.search(line)
        if match:
            patch = int(match.group(1))
    rocm_version = (major, minor, patch)
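The block removed above existed only to scrape the installed ROCm version out of rocm_version.h so later mappings could be conditionalized; with ROCm >= 6.0 enforced, those conditionals collapse and the parser goes away. For reference, a hedged sketch of what the header provides (the include path varies by release; the macro names are the ones the removed regexes matched):

// rocm_version.h defines the installed version as plain preprocessor macros:
// ROCM_VERSION_MAJOR, ROCM_VERSION_MINOR, ROCM_VERSION_PATCH.
#include <rocm-core/rocm_version.h>  // assumed location, e.g. under /opt/rocm/include

static_assert(ROCM_VERSION_MAJOR >= 6,
              "PyTorch now assumes ROCm >= 6.0 on HIP builds");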
# List of math functions that should be replaced inside device code only.
MATH_TRANSPILATIONS = collections.OrderedDict(
    [

@ -7304,20 +7267,19 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
        ),
        (
            "cublasComputeType_t",
            ("hipblasComputeType_t" if rocm_version >= (6, 0, 0) else "hipblasLtComputeType_t",
            CONV_MATH_FUNC, API_BLAS)
            ("hipblasComputeType_t", CONV_MATH_FUNC, API_BLAS)
        ),
        (
            "CUBLAS_COMPUTE_32I",
            ("HIPBLAS_COMPUTE_32I" if rocm_version >= (6, 0, 0) else "HIPBLASLT_COMPUTE_I32", CONV_MATH_FUNC, API_BLAS)
            ("HIPBLAS_COMPUTE_32I", CONV_MATH_FUNC, API_BLAS)
        ),
        (
            "CUBLAS_COMPUTE_32F",
            ("HIPBLAS_COMPUTE_32F" if rocm_version >= (6, 0, 0) else "HIPBLASLT_COMPUTE_F32", CONV_MATH_FUNC, API_BLAS)
            ("HIPBLAS_COMPUTE_32F", CONV_MATH_FUNC, API_BLAS)
        ),
        (
            "CUBLAS_COMPUTE_64F",
            ("HIPBLAS_COMPUTE_64F" if rocm_version >= (6, 0, 0) else "HIPBLASLT_COMPUTE_F64", CONV_MATH_FUNC, API_BLAS)
            ("HIPBLAS_COMPUTE_64F", CONV_MATH_FUNC, API_BLAS)
        ),
        ("cublasLtEpilogue_t", ("hipblasLtEpilogue_t", CONV_MATH_FUNC, API_BLAS)),
        ("CUBLASLT_EPILOGUE_DEFAULT", ("HIPBLASLT_EPILOGUE_DEFAULT", CONV_MATH_FUNC, API_BLAS)),

@ -7770,14 +7732,8 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
                HIP_UNSUPPORTED,
            ),
        ),
        (
            "cuComplex",
            ("hipComplex" if rocm_version >= (6, 0, 0) else "hipblasComplex", CONV_TYPE, API_BLAS)
        ),
        (
            "cuDoubleComplex",
            ("hipDoubleComplex" if rocm_version >= (6, 0, 0) else "hipblasDoubleComplex", CONV_TYPE, API_BLAS),
        ),
        ("cuComplex", ("hipComplex", CONV_TYPE, API_BLAS)),
        ("cuDoubleComplex", ("hipDoubleComplex", CONV_TYPE, API_BLAS)),
        ("cufftResult_t", ("hipfftResult_t", CONV_TYPE, API_FFT)),
        ("cufftResult", ("hipfftResult", CONV_TYPE, API_FFT)),
        ("CUFFT_SUCCESS", ("HIPFFT_SUCCESS", CONV_NUMERIC_LITERAL, API_FFT)),
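With the rocm_version probe gone, the identifier map above always emits the ROCm 6 spellings. Illustrative only (not from the original sources), a before/after of what hipify now produces for these entries:

// CUDA source:
//   cublasComputeType_t ct = CUBLAS_COMPUTE_32F;
//   cuComplex z;
// hipified output (before this change, ROCm < 6.0 would have produced
// HIPBLASLT_COMPUTE_F32 and hipblasComplex instead):
hipblasComputeType_t ct = HIPBLAS_COMPUTE_32F;
hipComplex z;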