[ROCm] Changes not to rely on CUDA_VERSION or HIP_VERSION (#65610)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65610

- Replace __HIP_PLATFORM_HCC__ with USE_ROCM.
- Don't rely on CUDA_VERSION or HIP_VERSION; use USE_ROCM and ROCM_VERSION instead, and guard every version comparison with defined(CUDA_VERSION) / defined(ROCM_VERSION) so an undefined macro is never compared (see the sketch below).

- In the next PR:
   - Remove the mapping from CUDA_VERSION to HIP_VERSION and from CUDA to HIP in hipify.
   - HIP_PLATFORM_HCC is deprecated, so add HIP_PLATFORM_AMD to support HIP host-code compilation with gcc.
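
Below is a minimal, hypothetical sketch of the guard pattern this PR adopts; it is not taken from the diff. The macros USE_ROCM, CUDA_VERSION, and ROCM_VERSION are the ones used throughout the changes, while the helper warp_shfl_down_f32 and the version thresholds are illustrative only. USE_ROCM is assumed to be set by the build system for ROCm builds, and CUDA_VERSION / ROCM_VERSION may be entirely undefined, so every comparison tests defined(...) first.

// Hypothetical device helper; assumes compilation with nvcc (CUDA) or hipcc (ROCm).
__device__ __forceinline__ float warp_shfl_down_f32(float value, unsigned int delta) {
#if !defined(USE_ROCM)
  // CUDA path: the *_sync shuffle intrinsics take an explicit lane mask.
  return __shfl_down_sync(0xffffffff, value, delta);
#else
  // ROCm path: no mask argument.
  return __shfl_down(value, delta);
#endif
}

// Version gates never evaluate a possibly-undefined macro:
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11000) || \
    (defined(ROCM_VERSION) && ROCM_VERSION >= 30800)
// ... feature that needs CUDA 11+ or ROCm 3.8+ ...
#endif

Note that the diff compares ROCM_VERSION against values such as 21000 and 30800 where the old TORCH_HIP_VERSION checks used 210 and 308, i.e. ROCM_VERSION appears to encode major*10000 + minor*100 + patch (ROCm 2.1 -> 21000, ROCm 3.8 -> 30800).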

cc jeffdaily sunway513 jithunnair-amd ROCmSupport amathews-amd

Reviewed By: jbschlosser

Differential Revision: D30909053

Pulled By: ezyang

fbshipit-source-id: 224a966ebf1aaec79beccbbd686fdf3d49267e06
Authored by: Pruthvi Madugundu
Date: 2021-09-29 09:53:51 -07:00
Committed by: Facebook GitHub Bot
Parent: 9b40eaaaab
Commit: 085e2f7bdd
131 changed files with 415 additions and 398 deletions

View File

@ -76,11 +76,11 @@ TORCH_API void record_kernel_function_dtype(std::string name);
// Workaround for C10_UNUSED because CUDA 10.1 and below fails to handle unused
// attribute in the type aliasing context. Keep name long and verbose to avoid
// macro collisions.
#if defined(__CUDACC__) && CUDA_VERSION <= 10100
#if defined(__CUDACC__) && defined(CUDA_VERSION) && CUDA_VERSION <= 10100
#define C10_UNUSED_DISPATCH_CUDA_WORKAROUND
#else
#define C10_UNUSED_DISPATCH_CUDA_WORKAROUND C10_UNUSED
#endif // defined(__CUDACC__) && CUDA_VERSION <= 10100
#endif // defined(__CUDACC__) && defined(CUDA_VERSION) && CUDA_VERSION <= 10100
#if defined __cpp_if_constexpr
#define AT_QINT_PRIVATE_CASE_TYPE( \

View File

@ -17,7 +17,7 @@ struct Array {
C10_HOST_DEVICE T& operator[](int i) {
return data[i];
}
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
C10_HOST_DEVICE Array() = default;
C10_HOST_DEVICE Array(const Array&) = default;
C10_HOST_DEVICE Array& operator=(const Array&) = default;

View File

@ -167,7 +167,7 @@ static inline __device__ int32_t gpuAtomicAdd(int32_t *address, int32_t val) {
}
static inline __device__ void gpuAtomicAdd(int64_t *address, int64_t val) {
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
__atomic_fetch_add(address, val, __ATOMIC_RELAXED);
#else
AtomicAddIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
@ -179,7 +179,7 @@ static inline __device__ void gpuAtomicAdd(bool *address, bool val) {
}
static inline __device__ at::Half gpuAtomicAdd(at::Half *address, at::Half val) {
#if ((CUDA_VERSION < 10000) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
#if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION < 10000) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
return AtomicFPOp<at::Half>()(address, val,
[](at::Half hsum, at::Half val) {
return hsum + val;
@ -196,7 +196,7 @@ static inline __device__ at::BFloat16 gpuAtomicAdd(at::BFloat16 *address, at::BF
});
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
#if defined(CUDA_VERSION) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
// from CUDA C Programming Guide
static inline __device__ double atomicAdd(double* address, double val)
#if defined(__clang__) && defined(__CUDA__)
@ -212,7 +212,7 @@ static inline __device__ double atomicAdd(double* address, double val)
return __double_as_longlong(val + __longlong_as_double(assumed));
});
}
#elif !defined(__CUDA_ARCH__) && (CUDA_VERSION < 8000) || defined(__HIP_PLATFORM_HCC__)
#elif defined(USE_ROCM) || !(defined(__CUDA_ARCH__) && (defined(CUDA_VERSION) && CUDA_VERSION < 8000))
/* Note [hip-clang differences to hcc]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -298,7 +298,7 @@ static inline __device__ void gpuAtomicAddNoReturn(at::BFloat16 *address, at::BF
static inline __device__ void gpuAtomicAddNoReturn(double *address, double val) { gpuAtomicAdd(address, val); }
/* Special case fp32 atomic. */
#if defined(__HIP_PLATFORM_HCC__) && defined(__gfx908__)
#if defined(USE_ROCM) && defined(__gfx908__)
static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { atomicAddNoRet(address, val); }
#else
static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { gpuAtomicAdd(address, val); }

View File

@ -274,7 +274,7 @@ template <typename Op,
typename IndexType,
int ADims,
int step>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
#endif
__global__ void kernelPointwiseApply1(detail::TensorInfo<scalar, IndexType> a,
@ -360,7 +360,7 @@ template <typename Op,
int step,
int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK,
int min_blocks_per_sm=AT_APPLY_BLOCKS_PER_SM>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm)
#endif
__global__ void

View File

@ -133,7 +133,7 @@ const char* _cublasGetErrorEnum(cublasStatus_t error) {
/* LEVEL 3 BLAS FUNCTIONS */
#ifndef __HIP_PLATFORM_HCC__
#ifndef USE_ROCM
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11200
#define cublasGemmStridedBatchedExFix cublasGemmStridedBatchedEx
#else
@ -271,7 +271,7 @@ void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
BGEMM_CHECK_ARGVALUES(at::Half);
float falpha = alpha;
float fbeta = beta;
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k,
(void*)&falpha, a, rocblas_datatype_f16_r, (int)lda, stridea,
b, rocblas_datatype_f16_r, (int)ldb, strideb,
@ -284,7 +284,7 @@ void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
// On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
// manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#endif // CUDA_VERSION < 11000
#endif // defined(CUDA_VERSION) && CUDA_VERSION < 11000
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 5){
@ -308,11 +308,11 @@ void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
// On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
// manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
#endif // CUDA_VERSION < 11000
#endif // __HIP_PLATFORM_HCC__
#endif // defined(CUDA_VERSION) && CUDA_VERSION < 11000
#endif // USE_ROCM
}
#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
#if defined(USE_ROCM) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
template <>
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
// See Note [Writing Nondeterministic Operations]
@ -332,7 +332,7 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
b, CUDA_R_16BF, (int)ldb, strideb,
(void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
(int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
#elif defined(__HIP_PLATFORM_HCC__)
#elif defined(USE_ROCM)
TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k,
(void*)&falpha, a, rocblas_datatype_bf16_r, (int)lda, stridea,
b, rocblas_datatype_bf16_r, (int)ldb, strideb,
@ -344,7 +344,7 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
TORCH_CHECK(false, "CUDA BFloat16 bgemm requires CUDA 11 or later");
#endif // defined(CUDA_VERSION) && CUDA_VERSION >= 11000
}
#endif // __HIP_PLATFORM_HCC__
#endif // USE_ROCM
template <>
void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
@ -372,7 +372,7 @@ void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
}
#if !defined(__HIP_PLATFORM_HCC__) || (defined(__HIP_PLATFORM_HCC__) && TORCH_HIP_VERSION >= 210)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
template <>
void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>)) {
// See Note [Writing Nondeterministic Operations]
@ -389,7 +389,7 @@ void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
}
#endif
#if !defined(__HIP_PLATFORM_HCC__) || (defined(__HIP_PLATFORM_HCC__) && TORCH_HIP_VERSION >= 210)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
template <>
void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>)) {
// See Note [Writing Nondeterministic Operations]
@ -417,7 +417,7 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
float fbeta = beta;
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(at::Half);
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
handle,
opa,
@ -450,7 +450,7 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
// On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
// manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#endif // CUDA_VERSION < 11000
#endif // defined(CUDA_VERSION) && CUDA_VERSION < 11000
TORCH_CUDABLAS_CHECK(cublasGemmEx(
handle,
opa,
@ -475,7 +475,7 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
// On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
// manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
#endif // CUDA_VERSION < 11000
#endif // defined(CUDA_VERSION) && CUDA_VERSION < 11000
} else {
TORCH_CUDABLAS_CHECK(cublasSgemmEx(
handle,
@ -499,7 +499,7 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
#endif
}
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
template <>
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
@ -569,7 +569,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
CUDA_R_32F,
CUBLAS_GEMM_DFALT_TENSOR_OP));
}
#endif
#endif // defined(CUDA_VERSION) && CUDA_VERSION >= 11000
template <>
void trsm<float>(CUDABLAS_TRSM_ARGTYPES(float)) {
@ -702,7 +702,7 @@ void trsmBatched<c10::complex<double>>(
CUDABLAS_POSINT_CHECK(gemv<Dtype>, incy); \
} while (0)
#if !defined(__HIP_PLATFORM_HCC__) || (defined(__HIP_PLATFORM_HCC__) && TORCH_HIP_VERSION >= 210)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
template <>
void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>)) {
// See Note [Writing Nondeterministic Operations]
@ -718,7 +718,7 @@ void trsmBatched<c10::complex<double>>(
}
#endif
#if !defined(__HIP_PLATFORM_HCC__) || (defined(__HIP_PLATFORM_HCC__) && TORCH_HIP_VERSION >= 210)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
template <>
void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>)) {
// gemv is bw bound, and does not benefit from TF32. But the precision
@ -797,7 +797,7 @@ void gemv<at::Half>(CUDABLAS_GEMV_ARGTYPES(at::Half)) {
'n', trans_flipped, 1, m, n, alpha, x, incx, a, lda, beta, y, incy);
}
#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
#if defined(USE_ROCM) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
template <>
void gemv<at::BFloat16>(CUDABLAS_GEMV_ARGTYPES(at::BFloat16)) {
bool trans_bool = (_cublasOpFromChar(trans) != CUBLAS_OP_N);
@ -838,7 +838,7 @@ void dot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>)) {
template <>
void dot<at::Half>(CUDABLAS_DOT_ARGTYPES(at::Half)) {
#if CUDA_VERSION >= 8000
#if defined(CUDA_VERSION) && CUDA_VERSION >= 8000
TORCH_CUDABLAS_CHECK(cublasDotEx(
handle,
n,
@ -851,7 +851,7 @@ void dot<at::Half>(CUDABLAS_DOT_ARGTYPES(at::Half)) {
result,
CUDA_R_16F,
CUDA_R_32F));
#elif TORCH_HIP_VERSION >= 210
#elif defined(ROCM_VERSION) && ROCM_VERSION >= 21000
TORCH_CUDABLAS_CHECK(rocblas_hdot(
handle,
n,
@ -867,7 +867,7 @@ void dot<at::Half>(CUDABLAS_DOT_ARGTYPES(at::Half)) {
template <>
void dot<at::BFloat16>(CUDABLAS_DOT_ARGTYPES(at::BFloat16)) {
#if CUDA_VERSION >= 11000
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
TORCH_CUDABLAS_CHECK(cublasDotEx(
handle,
n,
@ -880,7 +880,7 @@ void dot<at::BFloat16>(CUDABLAS_DOT_ARGTYPES(at::BFloat16)) {
result,
CUDA_R_16BF,
CUDA_R_32F));
#elif TORCH_HIP_VERSION >= 210
#elif defined(ROCM_VERSION) && ROCM_VERSION >= 21000
TORCH_CUDABLAS_CHECK(rocblas_bfdot(
handle,
n,

View File

@ -54,17 +54,17 @@ template <>
void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double));
template <>
void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float));
#if !defined(__HIP_PLATFORM_HCC__) || (defined(__HIP_PLATFORM_HCC__) && TORCH_HIP_VERSION >= 210)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
template <>
void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>));
#endif
#if !defined(__HIP_PLATFORM_HCC__) || (defined(__HIP_PLATFORM_HCC__) && TORCH_HIP_VERSION >= 210)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
template <>
void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>));
#endif
template <>
void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
#if defined(USE_ROCM) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
template <>
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
#endif
@ -90,7 +90,7 @@ template <>
void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>));
template <>
void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
#if defined(USE_ROCM) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
template <>
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
#endif
@ -152,7 +152,7 @@ template <>
void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double));
template <>
void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float));
#if !defined(__HIP_PLATFORM_HCC__) || (defined(__HIP_PLATFORM_HCC__) && TORCH_HIP_VERSION >= 210)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 21000)
template <>
void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>));
template <>
@ -160,7 +160,7 @@ void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>));
#endif
template <>
void gemv<at::Half>(CUDABLAS_GEMV_ARGTYPES(at::Half));
#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
#if defined(USE_ROCM) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
template <>
void gemv<at::BFloat16>(CUDABLAS_GEMV_ARGTYPES(at::BFloat16));
#endif

View File

@ -32,7 +32,7 @@ struct TORCH_CUDA_CPP_API CUDAEvent {
CUDAEvent(
DeviceIndex device_index, const cudaIpcEventHandle_t* handle) {
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
device_index_ = device_index;
CUDAGuard guard(device_index_);
@ -148,7 +148,7 @@ struct TORCH_CUDA_CPP_API CUDAEvent {
// Note: cudaIpcGetEventHandle must be called on the same device as the event
void ipc_handle(cudaIpcEventHandle_t * handle) {
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
if (!is_created_) {
// this CUDAEvent object was initially constructed from flags but event_
// is not created yet.

View File

@ -9,14 +9,14 @@ namespace at {
namespace cuda {
MempoolId_t graph_pool_handle() {
#if CUDA_VERSION >= 11000
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// uuid count starts at 1. 0 is reserved to mean "wasn't set by graph_pool_handle".
static std::atomic<CaptureId_t> uuid{1};
// Sets just the second value, to distinguish it from MempoolId_ts created from
// cudaStreamGetCaptureInfo id_s in capture_begin.
return {0, uuid++};
#else
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0");
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
return {0, 0};
#endif
}
@ -45,13 +45,13 @@ MempoolId_t graph_pool_handle() {
CUDAGraph::CUDAGraph()
// CUDAStreams may not be default-constructed.
: capture_stream_(at::cuda::getCurrentCUDAStream()) {
#if CUDA_VERSION < 11000
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0");
#if (defined(CUDA_VERSION) && CUDA_VERSION < 11000) || defined(USE_ROCM)
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
#endif
}
void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/) {
#if CUDA_VERSION >= 11000
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
TORCH_CHECK(!has_graph_exec_,
"This CUDAGraph instance already owns a captured graph. "
"To capture a new graph, create a new instance.");
@ -120,12 +120,12 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/) {
// kernel will end up as part of the capture or not.
c10::cuda::CUDACachingAllocator::notifyCaptureBegin(capture_dev_, id_, mempool_id_);
#else
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0");
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
#endif
}
void CUDAGraph::capture_end() {
#if CUDA_VERSION >= 11000
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
auto stream = at::cuda::getCurrentCUDAStream();
TORCH_CHECK(stream == capture_stream_,
@ -156,12 +156,12 @@ void CUDAGraph::capture_end() {
AT_CUDA_CHECK(cudaGraphDestroy(graph_));
has_graph_ = false;
#else
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0");
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
#endif
}
void CUDAGraph::replay() {
#if CUDA_VERSION >= 11000
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
TORCH_CHECK(has_graph_exec_,
"Called CUDAGraph::replay without a preceding successful capture.");
@ -190,12 +190,12 @@ void CUDAGraph::replay() {
cudaDeviceSynchronize();
}
#else
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0");
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
#endif
}
void CUDAGraph::reset() {
#if CUDA_VERSION >= 11000
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// I'd prefer these checks throw exceptions, not print warnings,
// but the destructor calls reset(), and at least one CI build
// refuses to compile with a throwing destructor.
@ -226,17 +226,17 @@ void CUDAGraph::reset() {
C10_CUDA_CHECK_WARN(cudaGraphExecDestroy(graph_exec_));
}
#else
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0");
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
#endif
}
// Returns an id another graph's capture_begin can use to share the same memory pool as this graph.
MempoolId_t CUDAGraph::pool() {
#if CUDA_VERSION >= 11000
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
TORCH_CHECK(has_graph_exec_,
"Called CUDAGraph::pool() without a preceding successful capture.");
#else
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0");
TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM");
#endif
return mempool_id_;
}

View File

@ -26,7 +26,7 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
MempoolId_t pool();
protected:
#if CUDA_VERSION >= 11000
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
cudaGraph_t graph_ = NULL;
cudaGraphExec_t graph_exec_ = NULL;
#endif

View File

@ -47,7 +47,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
auto handle = myPoolWindow->reserve(device);
auto stream = c10::cuda::getCurrentCUDAStream();
TORCH_CUDABLAS_CHECK(cublasSetStream(handle, stream));
#if CUDA_VERSION >= 11000
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
// FP32 data type calculations based on the value of the allow_tf32 flag.
// To enable TF32, set the math mode of the handle to CUBLAS_TF32_TENSOR_OP_MATH.
@ -57,7 +57,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}
#endif
#if defined(__HIP_PLATFORM_HCC__) && TORCH_HIP_VERSION >= 308
#if defined(USE_ROCM) && ROCM_VERSION >= 30800
rocblas_atomics_mode rocblas_mode;
if (at::globalContext().deterministicAlgorithms()) {
rocblas_mode = rocblas_atomics_not_allowed;

View File

@ -6,7 +6,7 @@
__device__ __forceinline__ unsigned int ACTIVE_MASK()
{
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
return __activemask();
#else
// will be ignored anyway
@ -14,7 +14,7 @@ __device__ __forceinline__ unsigned int ACTIVE_MASK()
#endif
}
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
__device__ __forceinline__ unsigned long long int WARP_BALLOT(int predicate)
{
return __ballot(predicate);
@ -22,7 +22,7 @@ return __ballot(predicate);
#else
__device__ __forceinline__ unsigned int WARP_BALLOT(int predicate, unsigned int mask = 0xffffffff)
{
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
return __ballot_sync(mask, predicate);
#else
return __ballot(predicate);
@ -33,7 +33,7 @@ __device__ __forceinline__ unsigned int WARP_BALLOT(int predicate, unsigned int
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
@ -43,7 +43,7 @@ __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = wa
template <typename T>
__device__ __forceinline__ T WARP_SHFL(T value, int srcLane, int width = warpSize, unsigned int mask = 0xffffffff)
{
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
return __shfl_sync(mask, value, srcLane, width);
#else
return __shfl(value, srcLane, width);
@ -53,7 +53,7 @@ __device__ __forceinline__ T WARP_SHFL(T value, int srcLane, int width = warpSiz
template <typename T>
__device__ __forceinline__ T WARP_SHFL_UP(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
{
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
return __shfl_up_sync(mask, value, delta, width);
#else
return __shfl_up(value, delta, width);
@ -63,14 +63,14 @@ __device__ __forceinline__ T WARP_SHFL_UP(T value, unsigned int delta, int width
template <typename T>
__device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
{
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
return __shfl_down_sync(mask, value, delta, width);
#else
return __shfl_down(value, delta, width);
#endif
}
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
template<>
__device__ __forceinline__ int64_t WARP_SHFL_DOWN<int64_t>(int64_t value, unsigned int delta, int width , unsigned int mask)
{
@ -91,7 +91,7 @@ __device__ __forceinline__ c10::Half WARP_SHFL_DOWN<c10::Half>(c10::Half value,
template <typename T>
__device__ __forceinline__ c10::complex<T> WARP_SHFL_DOWN(c10::complex<T> value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
{
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
return c10::complex<T>(
__shfl_down_sync(mask, value.real_, delta, width),
__shfl_down_sync(mask, value.imag_, delta, width));
@ -107,7 +107,7 @@ __device__ __forceinline__ c10::complex<T> WARP_SHFL_DOWN(c10::complex<T> value,
*/
template <typename T>
__device__ __forceinline__ T doLdg(const T* p) {
#if __CUDA_ARCH__ >= 350 && !defined __HIP_PLATFORM_HCC__
#if __CUDA_ARCH__ >= 350 && !defined(USE_ROCM)
return __ldg(p);
#else
return *p;

View File

@ -89,7 +89,7 @@ const char* cusolverGetErrorMessage(cusolverStatus_t status);
// This is here instead of in c10 because NVRTC is loaded dynamically via a stub
// in ATen, and we need to use its nvrtcGetErrorString.
// See NOTE [ USE OF NVRTC AND DRIVER API ].
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
#define AT_CUDA_DRIVER_CHECK(EXPR) \
do { \

View File

@ -29,7 +29,7 @@
AT_CUDA_CHECK(cudaGetLastError()); \
} while (false)
#ifdef __HIP_PLATFORM_HCC__
#ifdef USE_ROCM
#define NO_ROCM(x)
#else
#define NO_ROCM(x) x
@ -67,7 +67,7 @@ struct cuda_type<c10::BFloat16> {
using type = __nv_bfloat16;
};
#elif !defined(__HIP_PLATFORM_HCC__)
#elif !defined(USE_ROCM)
// backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16

View File

@ -25,7 +25,7 @@
#include <magma_v2.h>
#endif
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
#include <miopen/version.h>
#endif
@ -93,7 +93,7 @@ bool CUDAHooks::isPinnedPtr(void* data) const {
}
cudaPointerAttributes attr;
cudaError_t err = cudaPointerGetAttributes(&attr, data);
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
if (err == cudaErrorInvalidValue) {
cudaGetLastError();
return false;
@ -106,7 +106,7 @@ bool CUDAHooks::isPinnedPtr(void* data) const {
return false;
}
#endif
#if CUDA_VERSION >= 10000
#if defined(CUDA_VERSION) && CUDA_VERSION >= 10000
return attr.type == cudaMemoryTypeHost;
#else
return attr.memoryType == cudaMemoryTypeHost;
@ -287,7 +287,7 @@ std::string CUDAHooks::showConfig() const {
}
};
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
oss << " - CUDA Runtime ";
#else
oss << " - HIP Runtime ";
@ -296,7 +296,7 @@ std::string CUDAHooks::showConfig() const {
oss << "\n";
// TODO: Make HIPIFY understand CUDART_VERSION macro
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
if (runtimeVersion != CUDART_VERSION) {
oss << " - Built with CUDA Runtime ";
printCudaStyleVersion(CUDART_VERSION);
@ -305,7 +305,7 @@ std::string CUDAHooks::showConfig() const {
oss << " - NVCC architecture flags: " << NVCC_FLAGS_EXTRA << "\n";
#endif
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
#if AT_CUDNN_ENABLED()

View File

@ -147,7 +147,7 @@ nvrtcResult nvrtcCreateProgram(nvrtcProgram *prog,
NVRTC_STUB1(nvrtcDestroyProgram, nvrtcProgram *);
NVRTC_STUB2(nvrtcGetPTXSize, nvrtcProgram, size_t *);
NVRTC_STUB2(nvrtcGetPTX, nvrtcProgram, char *);
#if CUDA_VERSION >= 11010
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
NVRTC_STUB2(nvrtcGetCUBINSize, nvrtcProgram, size_t *);
NVRTC_STUB2(nvrtcGetCUBIN, nvrtcProgram, char *);
#endif

View File

@ -13,7 +13,7 @@
// Operands that share the same shape, but may have different strides.
// OffsetCalculator iterates the tensor in a column-major order
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
constexpr int MAX_DIMS = 16;
#else
constexpr int MAX_DIMS = 25;

View File

@ -29,7 +29,7 @@ namespace at { namespace cuda {
// and edit ATen/cuda/detail/LazyNVRTC.cpp accordingly (e.g., via one of the stub
// macros).
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
#define AT_FORALL_NVRTC_BASE(_) \
_(nvrtcVersion) \
@ -56,7 +56,7 @@ namespace at { namespace cuda {
_(cuLinkAddData) \
_(cuLinkComplete)
#if CUDA_VERSION >= 11010
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
#define AT_FORALL_NVRTC(_) \
AT_FORALL_NVRTC_BASE(_) \
_(nvrtcGetCUBINSize) \

View File

@ -120,7 +120,7 @@ bool check_fast_path_restrictions(
bool can_use_fast_route(ArrayRef<TensorList> tensorLists,
ArrayRef<Scalar> scalarList = {},
bool does_op_promote_integer_inputs_to_float = false) {
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
return false;
#else
return check_fast_path_restrictions(tensorLists, scalarList, does_op_promote_integer_inputs_to_float);
@ -128,7 +128,7 @@ bool can_use_fast_route(ArrayRef<TensorList> tensorLists,
}
bool can_use_fast_route(TensorList tensors1, TensorList tensors2, bool does_op_promote_integer_inputs_to_float = false) {
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
return false;
#else
return can_use_fast_route({tensors1, tensors2}, {}, does_op_promote_integer_inputs_to_float);

View File

@ -432,7 +432,7 @@ std::tuple<Tensor, Tensor> prelu_backward_cuda(const Tensor& grad_out_, const Te
// rrelu
// -----------------------------------
template <typename scalar_t, int unroll_factor, typename F>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(256, 4)
#endif
__global__ void rrelu_with_noise_cuda_kernel(

View File

@ -112,7 +112,7 @@ public:
~CuFFTHandle() {
// Not using fftDestroy() for rocFFT to work around double freeing of handles
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
cufftDestroy(handle_);
#endif
}
@ -123,7 +123,7 @@ static bool is_pow_of_two(int64_t x) {
return (x & (x - 1)) == 0;
}
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
using cufft_size_type = int;
#else
using cufft_size_type = long long int;
@ -258,7 +258,7 @@ public:
// use a flag to keep track throughout this function to see if we need to
// input = input.clone();
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
// clone input to avoid issues with hipfft clobbering the input and failing tests
clone_input = true;
#else
@ -300,7 +300,7 @@ public:
const bool simple_layout = in_layout.simple && out_layout.simple;
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
hipfftType exec_type = [&]{
if (dtype == kFloat) {
switch (fft_type) {
@ -350,7 +350,7 @@ public:
// by assuming istride = ostride = 1.
//
// See NOTE [ cuFFT Embedded Strides ] in native/cuda/SpectralOps.cu.
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
CUFFT_CHECK(hipfftMakePlanMany(plan(), signal_ndim, signal_sizes.data(),
/* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1,
/* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1,
@ -362,7 +362,7 @@ public:
batch, &ws_size_t, exec_type));
#endif
} else {
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
CUFFT_CHECK(hipfftMakePlanMany(plan(), signal_ndim, signal_sizes.data(),
in_layout.embed.data(), in_layout.stride, in_layout.dist,
out_layout.embed.data(), out_layout.stride, out_layout.dist,
@ -392,7 +392,7 @@ private:
ScalarType value_type_;
};
#if CUDA_VERSION < 10000
#if (defined(CUDA_VERSION) && CUDA_VERSION < 10000) || defined(USE_ROCM)
// Note that the max plan number for CUDA version < 10 has to be 1023
// due to a bug that fails on the 1024th plan
constexpr int64_t CUFFT_MAX_PLAN_NUM = 1023;

View File

@ -49,7 +49,7 @@ static inline std::string _cudaGetErrorEnum(cufftResult error)
return "CUFFT_NO_WORKSPACE";
case CUFFT_NOT_IMPLEMENTED:
return "CUFFT_NOT_IMPLEMENTED";
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
case CUFFT_LICENSE_ERROR:
return "CUFFT_LICENSE_ERROR";
#endif

View File

@ -69,11 +69,11 @@ __global__ void conv_depthwise2d_forward_kernel(
acc_t value = biasEnabled ? static_cast<acc_t>(bias.data()[c]) : acc_t(0);
const index_t offset0 = (n * inputChannels + inputChannel) * inputHeight * inputWidth;
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
#pragma unroll
#endif
for (int kH = 0; kH < KH_LIMIT; ++kH) {
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
#pragma unroll
#endif
for (int kW = 0; kW < KW_LIMIT; ++kW) {
@ -125,17 +125,17 @@ __global__ void conv_depthwise2d_backward_kernel(
acc_t value(0);
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
#pragma unroll
#endif
for (int multiplier = 0; multiplier < depthwiseMultiplier; ++multiplier) {
int och = (c * depthwiseMultiplier) + multiplier;
int weightOffset = och * kernelHeight * kernelWidth;
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
#pragma unroll
#endif
for (int kh = 0; kh < KH_LIMIT; ++kh) {
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
#pragma unroll
#endif
for (int kw = 0; kw < KW_LIMIT; ++kw) {

View File

@ -1,7 +1,7 @@
#pragma once
namespace at { namespace native {
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
// take these out when ROCm implements std:: math functions
#include <math.h>
template <typename scalar_t>

View File

@ -160,7 +160,7 @@ __global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nba
static const int BLOCK_THREADS = 256;
template <typename scalar_t, typename accscalar_t>
#if defined (__HIP_PLATFORM_HCC__)
#if defined (USE_ROCM)
C10_LAUNCH_BOUNDS_2(BLOCK_THREADS, 4)
#else
C10_LAUNCH_BOUNDS_2(BLOCK_THREADS, 8)

View File

@ -87,7 +87,7 @@ void binomial_cuda_kernel(
at::native::distribution_binary_kernel(iter, philox_args,
[philox_args] GPU_LAMBDA (curandStatePhilox4_32_10_t& state, scalar_t count, scalar_t prob) {
#if defined(__CUDA_ARCH__) || defined(__HIP_PLATFORM_HCC__)
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
auto uniform_lambda = curand_uniform_wrapper(state);
BaseSampler<accscalar_t, decltype(uniform_lambda)> standard_uniform(uniform_lambda);
auto sample = sample_binomial<scalar_t, accscalar_t, decltype(uniform_lambda)>(count, prob, standard_uniform);

View File

@ -29,7 +29,7 @@ template <
typename IndexType,
int ADims,
int VEC>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(256, 4)
#endif
__global__ void fused_dropout_kernel_vec(
@ -118,7 +118,7 @@ template <
typename IndexType,
int ADims,
int BDims = ADims>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(256, 4)
#endif
__global__ void fused_dropout_kernel(

View File

@ -16,7 +16,7 @@ namespace at { namespace native {
namespace {
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
static const int BLOCKDIMY = 16;
#else
static const int BLOCKDIMY = 32;
@ -83,7 +83,7 @@ __global__ void embedding_backward_feature_kernel
(dst_row == indices_batch[chunk_start - batch_start + threadIdx.x]);
if(threadIdx.x >= n_this_chunk)
match_found_this_thread = 0;
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
unsigned long long int matchmask = WARP_BALLOT(match_found_this_thread);
int first_remaining_peer = __ffsll(matchmask) - 1;
#else
@ -96,7 +96,7 @@ __global__ void embedding_backward_feature_kernel
matchmask ^= (1 << first_remaining_peer);
while(matchmask)
{
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
first_remaining_peer = __ffsll(matchmask) - 1;
#else
first_remaining_peer = __ffs(matchmask) - 1;

View File

@ -237,7 +237,7 @@ Tensor embedding_bag_backward_cuda_max(const Tensor &grad,
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
dim3 block = dim3(64, 4);
#else
dim3 block = dim3(32, 8);
@ -335,7 +335,7 @@ _embedding_bag_cuda(const Tensor &weight, const Tensor &indices_,
max_indices = at::empty({0}, indices.options());
}
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
dim3 block = dim3(64, 4);
#else
dim3 block = dim3(32, 8);

View File

@ -229,7 +229,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<c10::optional<Ten
// cub on CUDA <= 11.2 have a bug that for small sizes
// cub's sort can be much slower than thrust's merge sort
// this bug is fixed in CUDA 11.3
#if defined(CUDA_VERSION) && CUDA_VERSION < 11030
#if (defined(CUDA_VERSION) && CUDA_VERSION < 11030) || defined(USE_ROCM)
if (num_indices < 50000) {
index_put_with_sort_kernel_thrust_helper(linearIndex, orig_indices, sorted_indices, num_indices);
} else

View File

@ -33,8 +33,9 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
index_t index,
const index_t numel,
scalar_t value) {
#if ( \
(CUDA_VERSION < 10000) || \
#if ( \
(defined(USE_ROCM)) || \
(defined(CUDA_VERSION) && (CUDA_VERSION < 10000)) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
gpuAtomicAddNoReturn(
reinterpret_cast<at::Half*>(tensor) + index,

View File

@ -81,7 +81,7 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
// Because for some reason trying to enable vectorized
// memory access introduce regression on ROCm.
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
#include <ATen/native/cuda/CUDALoops.cuh>
#else
#include <ATen/native/cuda/ROCmLoops.cuh>

View File

@ -57,7 +57,7 @@ __device__ static inline int64_t get_target_prime(
// computed when we start a new block_s. This is why we have our own for loop here.
template<typename scalar_t, typename target_t>
__global__ void
#if defined (__HIP_PLATFORM_HCC__)
#if defined (USE_ROCM)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
@ -413,7 +413,7 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
// alphabets the inplace nature is a considerable advantage.
template<typename scalar_t, typename target_t>
__global__ void
#if defined (__HIP_PLATFORM_HCC__)
#if defined (USE_ROCM)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restrict__ gradient_data,
@ -465,7 +465,7 @@ ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restrict__ gradient_da
// It appears to be faster than the above method for small batch sizes.
template<typename scalar_t, typename target_t>
__global__ void
#if defined (__HIP_PLATFORM_HCC__)
#if defined (USE_ROCM)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,
@ -537,7 +537,7 @@ ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,
// elements are padded
template<typename scalar_t>
__global__ void
#if defined (__HIP_PLATFORM_HCC__)
#if defined (USE_ROCM)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_loss_zero_padded_gradients(

View File

@ -25,7 +25,7 @@ struct MAGMAQueue {
// Constructor
explicit MAGMAQueue(int64_t device_id) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
#if CUDA_VERSION >= 11000
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// Magma operations is numerically sensitive, so TF32 should be off
// regardless of the global flag.
TORCH_CUDABLAS_CHECK(cublasGetMathMode(handle, &original_math_mode));
@ -44,7 +44,7 @@ struct MAGMAQueue {
// Destructor
~MAGMAQueue() {
#if CUDA_VERSION >= 11000
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// We've manually set the math mode to CUBLAS_DEFAULT_MATH, now we
// should restore the original math mode back
cublasHandle_t handle = magma_queue_get_cublas_handle(magma_queue_);
@ -55,7 +55,7 @@ struct MAGMAQueue {
private:
magma_queue_t magma_queue_;
#if CUDA_VERSION >= 11000
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
cublasMath_t original_math_mode;
#endif
};

View File

@ -206,7 +206,7 @@ void slow_conv_dilated_all_cuda_template(
output.zero_();
}
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
/* When using ROCm, the sum evaluation is inaccurate for double
tensors. The reason is currently unknown. Hence, we use gemv for
computing `grad_output_n.sum(dims)` until the ROCm-sum issue is

View File

@ -12,7 +12,7 @@
namespace at { namespace native {
// The maximum number of threads in a block
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
constexpr int MAX_BLOCK_SIZE = 256;
#else
constexpr int MAX_BLOCK_SIZE = 512;
@ -22,7 +22,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u;
// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int getNumThreads(int nElem) {
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
#else
int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };

View File

@ -81,7 +81,7 @@ T sigmoid(T in) {
namespace kernel {
template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(512, 4)
#endif
__global__ void lstm_cell_forward(
@ -168,7 +168,7 @@ __global__ void lstm_cell_forward(
}
template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(512, 4)
#endif
__global__ void lstm_cell_backward(
@ -233,7 +233,7 @@ __global__ void lstm_cell_backward(
}
template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(512, 4)
#endif
__global__ void gru_cell_forward(
@ -303,7 +303,7 @@ __global__ void gru_cell_forward(
}
template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(512, 4)
#endif
__global__ void gru_cell_backward(

View File

@ -14,7 +14,7 @@
namespace at {
namespace native {
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
constexpr int CAT_ARRAY_BATCH_SIZE = 1024;
#else
constexpr int CAT_ARRAY_BATCH_SIZE = 128;
@ -546,7 +546,7 @@ Tensor& cat_out_cuda(TensorList inputs, int64_t dimension, Tensor& out) {
});
allSameType = allSameType && (out.scalar_type() == firstType);
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
if (inputs.size() > 1 &&
out.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
at::cuda::detail::canUse32BitIndexMath(out) &&

View File

@ -125,7 +125,7 @@ void SpatialSoftMax_getLaunchSizes(
uint32_t block_threads = block.x * block.y;
smem_size = block.x == 1 ? 0 : block_threads * sizeof(accscalar_t);
int max_active_blocks;
#if defined(__HIP_PLATFORM_HCC__) && TORCH_HIP_VERSION < 305
#if defined(USE_ROCM) && TORCH_HIP_VERSION < 305
// HIP function signature is not compatible yet.
uint32_t max_blocks;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks,
@ -358,7 +358,7 @@ blockReduce(AccumT* smem, AccumT val,
for (int i = 0; i < C10_WARP_SIZE; ++i) {
warpVal = r(warpVal, smem[lane * C10_WARP_SIZE + i]);
}
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
__syncwarp(mask);
#endif
smem[lane] = warpVal;

View File

@ -366,7 +366,7 @@ void sort_cuda_kernel(
int64_t numel_or_intmax = std::min(numel, static_cast<int64_t>(std::numeric_limits<int>::max()));
int64_t nbatch = (numel_or_intmax / nsort) * nsort;
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
constexpr bool is_rocm = true;
#else
constexpr bool is_rocm = false;

View File

@ -37,13 +37,13 @@ __device__ inline void bitonicSort(K keys[Power2SortSize],
V values[Power2SortSize],
bool valid[Power2SortSize],
const Comparator& comp) {
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
#pragma unroll
#endif
for (unsigned int size = 2; size < Power2SortSize; size *= 2) {
bool flag = ((threadIdx.x & (size / 2)) != 0);
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
#pragma unroll
#endif
for (unsigned int stride = size / 2; stride > 0; stride /= 2) {
@ -58,7 +58,7 @@ __device__ inline void bitonicSort(K keys[Power2SortSize],
}
}
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
#pragma unroll
#endif
for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) {

View File

@ -13,7 +13,7 @@ namespace at {
namespace native {
// Is this questionable namespace pollution?
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
constexpr int MAX_BLOCK_SIZE = 256;
#else

View File

@ -127,7 +127,7 @@ struct TopKTypeConfig<at::Half> {
typedef uint32_t RadixType;
static inline __device__ RadixType convert(at::Half v) {
#if defined(__CUDA_ARCH__) || defined(__HIP_PLATFORM_HCC__)
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
RadixType x = __half_as_ushort(v);
RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000;
return (v == v) ? (x ^ mask) : 0xffff;
@ -138,7 +138,7 @@ struct TopKTypeConfig<at::Half> {
}
static inline __device__ at::Half deconvert(RadixType v) {
#if defined(__CUDA_ARCH__) || defined(__HIP_PLATFORM_HCC__)
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff;
return __ushort_as_half(v ^ mask);
#else
@ -211,7 +211,7 @@ __device__ void countRadixUsingMask(
#pragma unroll
for (uint32_t j = 0; j < RadixSize; ++j) {
bool vote = hasVal && (digitInRadix == j);
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
counts[j] += __popcll(WARP_BALLOT(vote));
#else
counts[j] += __popc(WARP_BALLOT(vote, ACTIVE_MASK()));

View File

@ -28,7 +28,7 @@ using namespace at::native::detail;
static void exec_cufft_plan(
const CuFFTConfig &config, void* in_data, void* out_data, bool forward) {
auto& plan = config.plan();
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
auto value_type = config.data_type();
if (value_type == kFloat) {
switch (config.transform_type()) {

View File

@ -353,7 +353,7 @@ Tensor _histc_cuda_template(
maxvalue = maxvalue + 1;
}
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
TORCH_CHECK(
!(THCNumerics<input_t>::isinf(minvalue) ||
THCNumerics<input_t>::isinf(maxvalue) ||

View File

@ -224,7 +224,7 @@ inline void get_coordinate_in_triu_trapezoid(
template <typename scalar_t>
__global__
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
C10_LAUNCH_BOUNDS_1(512)
#endif
void tril_indices_kernel(scalar_t * tensor,

View File

@ -143,13 +143,13 @@ __device__ inline void bitonicSortKeys(
K keys[Power2SortSize],
bool valid[Power2SortSize],
const Comparator& comp) {
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
#pragma unroll
#endif
for (unsigned int size = 2; size < Power2SortSize; size *= 2) {
bool flag = ((threadIdx.x & (size / 2)) != 0);
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
#pragma unroll
#endif
for (unsigned int stride = size / 2; stride > 0; stride /= 2) {
@ -166,7 +166,7 @@ __device__ inline void bitonicSortKeys(
}
}
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
#pragma unroll
#endif
for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) {

View File

@ -33,7 +33,7 @@ __global__ void gatherTopK(at::cuda::detail::TensorInfo<T, IndexType> input,
IndexType indicesWithinSliceStride) {
// Indices are limited to integer fp precision, so counts can fit in
// int32, regardless of IndexType
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
__shared__ int smem[64];
#else
__shared__ int smem[32]; // one per each warp, up to warp limit

View File

@ -13,7 +13,7 @@ namespace at {
namespace native {
template <typename scalar_t, typename IndexType>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(cuda::getApplyBlockSize(), cuda::getApplyBlocksPerSM())
#endif
__global__ void kernel_pointwise_flip_apply2(

View File

@ -155,7 +155,7 @@ void nan_to_num_kernel_cuda(
}
void frexp_kernel_cuda(TensorIteratorBase& iter) {
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
// Reference: https://rocmdocs.amd.com/en/latest/ROCm_API_References/HIP-MATH.html
// https://github.com/ROCm-Developer-Tools/HIP/issues/2169
// ROCm does not support frexp function yet

View File

@ -45,7 +45,7 @@ __device__ __forceinline__ void reduce_block_into_lanes
__syncthreads();
}
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
#pragma unroll
#endif
for(int i = (blockSize >> 1); i >= 64; i >>= 1)
@ -64,7 +64,7 @@ __device__ __forceinline__ void reduce_block_into_lanes
final = val;
// __SYNCWARP();
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
#pragma unroll
#endif
for(int i = 16; i >= lanes; i >>= 1)

View File

@ -1406,7 +1406,7 @@ struct DropoutState {
at::Tensor buffer;
c10::optional<cuda::CUDAEvent> event;
std::mutex mutex;
#if CUDA_VERSION >= 11000
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// cudaStreamGetCaptureInfo will never give back a capture id of 0, so 0 can serve
// as a sentinel value that capture was not underway.
cuda::CaptureId_t capture_id_last_lock = 0;
@ -1424,7 +1424,7 @@ struct DropoutState {
// could then define it before we get to unlock().
mutex.lock();
if (event) {
#if CUDA_VERSION >= 11000
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// See Note [DropoutState and CUDA graph capture]
cudaStreamCaptureStatus status;
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(cuda::getCurrentCUDAStream(),
@ -1445,7 +1445,7 @@ struct DropoutState {
void unlock() {
if (event) {
event->record();
#if CUDA_VERSION >= 11000
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// See Note [DropoutState and CUDA graph capture]
cudaStreamCaptureStatus status;
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(cuda::getCurrentCUDAStream(),

View File

@ -51,7 +51,7 @@ namespace {
// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int getNumThreads(int nElem) {
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
int threadSizes[5] = {16, 32, 64, 128, 256};
#else
int threadSizes[5] = {32, 64, 128, 256, 512};

View File

@ -41,7 +41,7 @@ __device__ void applyOp3(
// Assume both dense and values are contiguous.
// Currently only used in add_out_dense_sparse_cuda: add(dense, sparse, scalar).
template <typename Op, typename IndexType, typename Real>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(cuda::getApplyBlockSize(), cuda::getApplyBlocksPerSM())
#endif
__global__ void sparseElementwiseKernel(
@ -71,7 +71,7 @@ __global__ void sparseElementwiseKernel(
// Assume dense is contiguous.
// Currently only used in add_out_dense_sparse_cuda: add(dense, sparse, scalar).
template <typename Op, typename IndexType, typename Real>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(cuda::getApplyBlockSize(), cuda::getApplyBlocksPerSM())
#endif
__global__ void sparseElementwiseKernelScalar(
@ -95,7 +95,7 @@ __global__ void sparseElementwiseKernelScalar(
}
template <typename OpBoth, typename OpLeft, typename OpRight, typename IndexType, typename Real>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(cuda::getApplyBlockSize(), cuda::getApplyBlocksPerSM())
#endif
__global__ void valueSparseUnionKernel(
@ -142,7 +142,7 @@ __global__ void valueSparseUnionKernel(
// TODO find a way to parallelize this...
template <typename IndexType, typename Real>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(cuda::getApplyBlockSize(), cuda::getApplyBlocksPerSM())
#endif
__global__ void indexSparseUnionKernel(
@ -192,7 +192,7 @@ __global__ void indexSparseUnionKernel(
}
template <typename Op, typename IndexType, typename Real>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(cuda::getApplyBlockSize(), cuda::getApplyBlocksPerSM())
#endif
__global__ void valueSparseIntersectionKernel(
@ -231,7 +231,7 @@ __global__ void valueSparseIntersectionKernel(
// TODO find a way to parallelize this...
template <typename IndexType, typename Real>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(cuda::getApplyBlockSize(), cuda::getApplyBlocksPerSM())
#endif
__global__ void indexSparseIntersectionKernel(

View File

@ -517,7 +517,7 @@ SparseTensor& mul_out_sparse_cuda(const SparseTensor& t_, const SparseTensor& sr
// see NOTE [ sparse.sum() backward ]
// --------------------------------------------------------------------
template <typename scalar_t>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(cuda::getApplyBlockSize(), cuda::getApplyBlocksPerSM())
#endif
__global__ void _sparse_sum_backward_cuda_kernel(
@ -683,7 +683,7 @@ Tensor bmm_sparse_cuda(const SparseTensor& self, const Tensor& mat2) {
return bmm_out_sparse_cuda(self, mat2, result);
}
#if !(defined(__HIP_PLATFORM_HCC__) || (defined(_MSC_VER) && CUSPARSE_VERSION < 11000))
#if !(defined(USE_ROCM) || (defined(_MSC_VER) && CUSPARSE_VERSION < 11000))
__global__ void search_end_matrix_indices_cuda_kernel(
int64_t* mat_el_end_indices,
int64_t num_matrices,
@ -764,7 +764,7 @@ cudaDataType getTensorCudaDataType(Tensor self) {
#endif
Tensor& bmm_out_sparse_cuda(const SparseTensor& self, const Tensor& mat2, Tensor& result) {
#if defined __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
TORCH_CHECK(false, "bmm sparse-dense is not supported on HIP");
#elif defined(_MSC_VER) && (CUSPARSE_VERSION < 11000)
TORCH_CHECK(false, "bmm sparse-dense CUDA is not supported on Windows with cuda before 11.0");

View File

@ -25,7 +25,7 @@ void reset_buffers() {
}
}
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
TEST(TestLoops, HasSameArgTypes) {
// This is a compile-time unit test. If this file compiles without error,
// then the test passes and during runtime, we just need to return.

View File

@ -10,7 +10,7 @@ template <>
struct Bitfield<unsigned int> {
static __device__ __forceinline__
unsigned int getBitfield(unsigned int val, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
pos &= 0xff;
len &= 0xff;
@ -25,7 +25,7 @@ struct Bitfield<unsigned int> {
static __device__ __forceinline__
unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
pos &= 0xff;
len &= 0xff;
@ -48,7 +48,7 @@ template <>
struct Bitfield<uint64_t> {
static __device__ __forceinline__
uint64_t getBitfield(uint64_t val, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
pos &= 0xff;
len &= 0xff;
@ -63,7 +63,7 @@ struct Bitfield<uint64_t> {
static __device__ __forceinline__
uint64_t setBitfield(uint64_t val, uint64_t toInsert, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
pos &= 0xff;
len &= 0xff;
@ -83,7 +83,7 @@ struct Bitfield<uint64_t> {
};
__device__ __forceinline__ int getLaneId() {
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
return __lane_id();
#else
int laneId;
@ -92,7 +92,7 @@ __device__ __forceinline__ int getLaneId() {
#endif
}
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
__device__ __forceinline__ unsigned long long int getLaneMaskLt() {
const std::uint64_t m = (1ull << getLaneId()) - 1ull;
return m;
@ -105,7 +105,7 @@ __device__ __forceinline__ unsigned getLaneMaskLt() {
}
#endif
#if defined (__HIP_PLATFORM_HCC__)
#if defined (USE_ROCM)
__device__ __forceinline__ unsigned long long int getLaneMaskLe() {
std::uint64_t m = UINT64_MAX >> (sizeof(std::uint64_t) * CHAR_BIT - (getLaneId() + 1));
return m;
@ -118,7 +118,7 @@ __device__ __forceinline__ unsigned getLaneMaskLe() {
}
#endif
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
__device__ __forceinline__ unsigned long long int getLaneMaskGt() {
const std::uint64_t m = getLaneMaskLe();
return m ? ~m : m;
@ -131,7 +131,7 @@ __device__ __forceinline__ unsigned getLaneMaskGt() {
}
#endif
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
__device__ __forceinline__ unsigned long long int getLaneMaskGe() {
const std::uint64_t m = getLaneMaskLt();
return ~m;

View File

@ -182,7 +182,7 @@ template <typename T, int Dim,
__host__ __device__ THCDeviceTensor<T, Dim, IndexT, PtrTraits>
THCDeviceTensor<T, Dim, IndexT, PtrTraits>::transpose(int dim1,
int dim2) const {
#if defined(__CUDA_ARCH__) || defined(__HIP_PLATFORM_HCC__)
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
// Device code
assert(dim1 >= 0 && dim1 < Dim);
assert(dim1 >= 0 && dim2 < Dim);
@ -285,7 +285,7 @@ THCDeviceTensor<T, Dim, IndexT, PtrTraits>::downcastOuter() {
// in all of the dimensions we are collapsing (no padding in
// them).
bool cont = isContiguousRange(0, Dim - NewDim);
#if defined(__CUDA_ARCH__) || defined(__HIP_PLATFORM_HCC__)
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
// Device code
assert(cont);
#else
@ -336,7 +336,7 @@ THCDeviceTensor<T, Dim, IndexT, PtrTraits>::downcastInner() {
// in all of the dimensions we are collapsing (no padding in
// them).
bool cont = isContiguousRange(NewDim, Dim);
#if defined(__CUDA_ARCH__) || defined(__HIP_PLATFORM_HCC__)
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
// Device code
assert(cont);
#else
@ -404,7 +404,7 @@ template <typename T, int Dim,
typename IndexT, template <typename U> class PtrTraits>
void
THCDeviceTensor<T, Dim, IndexT, PtrTraits>::zero(cudaStream_t stream) {
#if defined(__CUDA_ARCH__) || defined(__HIP_PLATFORM_HCC__)
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
assert(isContiguous());
#else
if (!isContiguous()) {

View File

@ -219,7 +219,7 @@ void __THCublasCheck(cublasStatus_t status, const char *file, const int line)
errmsg = "an absent device architectural feature is required";
break;
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
case CUBLAS_STATUS_MAPPING_ERROR:
errmsg = "an access to GPU memory space failed";
break;

View File

@ -14,7 +14,7 @@
#cmakedefine USE_MAGMA
/* Needed for hipMAGMA to correctly identify implementation */
#if defined(USE_MAGMA) && defined(__HIP_PLATFORM_HCC__)
#if defined(USE_MAGMA) && defined(USE_ROCM)
#define HAVE_HIP 1
#endif

View File

@ -96,7 +96,7 @@ __device__ void exclusivePrefixScan(T* smem, T in, T* out, T* carry, BinaryFunct
template <typename T, bool KillWARDependency, class BinaryFunction>
__device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFunction binop) {
// Within-warp, we use warp voting.
#if defined (__HIP_PLATFORM_HCC__)
#if defined (USE_ROCM)
unsigned long long int vote = WARP_BALLOT(in);
T index = __popcll(getLaneMaskLe() & vote);
T carry = __popcll(vote);
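The warp-voting step above packs each lane's boolean into a single ballot word and turns the inclusive prefix count into a popcount of that word masked to the lanes at or below the current one (64-bit ballot/__popcll on ROCm, 32-bit __ballot_sync/__popc on CUDA). A host-side sketch of the same arithmetic for a 32-lane warp, for illustration only (none of these names come from the patch):

    #include <bitset>
    #include <cassert>
    #include <cstdint>

    int main() {
      const int kWarp = 32;
      bool in[kWarp] = {};
      in[0] = in[3] = in[4] = in[31] = true;

      uint32_t vote = 0;                    // what __ballot_sync would return
      for (int lane = 0; lane < kWarp; ++lane)
        if (in[lane]) vote |= (1u << lane);

      int running = 0;
      for (int lane = 0; lane < kWarp; ++lane) {
        running += in[lane];
        const uint32_t maskLe = 0xFFFFFFFFu >> (kWarp - lane - 1);  // getLaneMaskLe()
        const int index = static_cast<int>(std::bitset<32>(vote & maskLe).count());
        assert(index == running);           // inclusive prefix count for this lane
      }
      return 0;
    }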

View File

@ -3,7 +3,7 @@
#include <ATen/cuda/ThrustAllocator.h>
#include <thrust/device_ptr.h>
#include <thrust/fill.h>
#if CUDA_VERSION >= 7000 || defined(__HIP_PLATFORM_HCC__)
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 7000) || defined(USE_ROCM)
#include <thrust/system/cuda/execution_policy.h>
#endif
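The guard above only makes the stream-aware Thrust execution policy available; call sites (such as the THCStorage fill in the next file) then pass thrust::cuda::par(...).on(stream) so the work runs on the current stream rather than the default one. A minimal sketch of that call pattern, with a hypothetical free function and without the caching allocator:

    #include <cstddef>
    #include <cuda_runtime.h>
    #include <thrust/device_ptr.h>
    #include <thrust/fill.h>
    #include <thrust/system/cuda/execution_policy.h>

    void fill_on_stream(float* data, std::size_t n, float value, cudaStream_t stream) {
      thrust::device_ptr<float> begin(data);
      // Run the fill on the caller's stream instead of the legacy default stream.
      thrust::fill(thrust::cuda::par.on(stream), begin, begin + n, value);
    }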

View File

@ -7,7 +7,7 @@ void THCStorage_(fill)(THCState *state, THCStorage *self, scalar_t value)
at::cuda::ThrustAllocator thrustAlloc;
thrust::device_ptr<scalar_t> self_data(THCStorage_(data)(state, self));
thrust::fill(
#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 7000) || defined(USE_ROCM)
thrust::cuda::par(thrustAlloc).on(c10::cuda::getCurrentCUDAStream()),
#endif
self_data,

View File

@ -17,7 +17,7 @@ using MempoolId_t = std::pair<CaptureId_t, CaptureId_t>;
// RAII guard for "cudaStreamCaptureMode", a thread-local value
// that controls the error-checking strictness of a capture.
#if CUDA_VERSION >= 11000
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
struct C10_CUDA_API CUDAStreamCaptureModeGuard {
CUDAStreamCaptureModeGuard(cudaStreamCaptureMode desired) {
strictness_ = desired;
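For context, the guard whose declaration is touched here is a plain RAII wrapper: it exchanges the thread-local capture mode on construction and restores the previous mode on destruction. A rough sketch of that shape, assuming the cudaThreadExchangeStreamCaptureMode runtime call (CUDA 10.1+); the class name below is ours, not the header's:

    #include <cuda_runtime.h>

    struct StreamCaptureModeGuardSketch {
      explicit StreamCaptureModeGuardSketch(cudaStreamCaptureMode desired)
          : previous_(desired) {
        // Sets the thread's capture mode to 'desired' and stores the old mode.
        cudaThreadExchangeStreamCaptureMode(&previous_);
      }
      ~StreamCaptureModeGuardSketch() {
        // Swap the original mode back in when the guard leaves scope.
        cudaThreadExchangeStreamCaptureMode(&previous_);
      }

     private:
      cudaStreamCaptureMode previous_;
    };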

View File

@ -295,13 +295,13 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;
#define C10_DEVICE
#endif
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
#define C10_HIP_HOST_DEVICE __host__ __device__
#else
#define C10_HIP_HOST_DEVICE
#endif
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
#define C10_WARP_SIZE warpSize // = 64 or 32 (Defined in hip_runtime.h)
#else
#define C10_WARP_SIZE 32
@ -315,7 +315,7 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;
// even when NDEBUG is defined. This is useful for important assertions in CUDA
// code that would otherwise be suppressed when building Release.
#if defined(__ANDROID__) || defined(__APPLE__) || \
(defined(__HIP_PLATFORM_HCC__) && ROCM_VERSION < 40100)
(defined(USE_ROCM) && ROCM_VERSION < 40100)
// Those platforms do not support assert()
#define CUDA_KERNEL_ASSERT(cond)
#elif defined(_MSC_VER)

View File

@ -433,7 +433,7 @@ C10_HOST_DEVICE void test_arithmetic_assign_complex() {
// this test is skipped due to a bug in constexpr evaluation
// in nvcc. This bug was fixed in CUDA 11.2
#if !defined(__CUDACC__) || CUDA_VERSION >= 11020
#if !defined(__CUDACC__) || (defined(CUDA_VERSION) && CUDA_VERSION >= 11020)
static_assert(x3.imag() == scalar_t(3), "");
#endif
@ -445,7 +445,7 @@ C10_HOST_DEVICE void test_arithmetic_assign_complex() {
// this test is skipped due to a bug in constexpr evaluation
// in nvcc. This bug was fixed in CUDA 11.2
#if !defined(__CUDACC__) || CUDA_VERSION >= 11020
#if !defined(__CUDACC__) || (defined(CUDA_VERSION) && CUDA_VERSION >= 11020)
static_assert(y3.imag() == scalar_t(1), "");
#endif

View File

@ -37,7 +37,7 @@ TEST(ExceptionTest, TORCH_INTERNAL_ASSERT_DEBUG_ONLY) {
// On these platforms there's no assert
#if !defined(__ANDROID__) && !defined(__APPLE__) && \
!(defined(__HIP_PLATFORM_HCC__) && ROCM_VERSION < 40100)
!(defined(USE_ROCM) && ROCM_VERSION < 40100)
TEST(ExceptionTest, CUDA_KERNEL_ASSERT) {
// This function always throws even in NDEBUG mode
ASSERT_DEATH_IF_SUPPORTED({ CUDA_KERNEL_ASSERT(false); }, "Assert");
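This death test exercises the macro guarded above: CUDA_KERNEL_ASSERT still fires in Release builds (unlike plain assert under NDEBUG), except on the platforms the #if excludes. A hypothetical kernel showing the intended use; the kernel itself is not from the patch and assumes c10/macros/Macros.h is included:

    __global__ void checked_copy(const float* src, float* dst, int n) {
      const int i = blockIdx.x * blockDim.x + threadIdx.x;
      CUDA_KERNEL_ASSERT(n >= 0);  // traps the kernel if the contract is violated
      if (i < n) {
        dst[i] = src[i];
      }
    }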

View File

@ -19,7 +19,7 @@ inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) {
uint32_t tmp = src;
tmp <<= 16;
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
float* tempRes;
// We should be using memcpy in order to respect the strict aliasing rule
@ -36,7 +36,7 @@ inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) {
inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) {
uint32_t res = 0;
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
// We should be using memcpy in order to respect the strict aliasing rule
// but it fails in the HIP environment.
uint32_t* tempRes = reinterpret_cast<uint32_t*>(&src);
@ -49,7 +49,7 @@ inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) {
}
inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) {
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
if (src != src) {
#elif defined(_MSC_VER)
if (isnan(src)) {
@ -74,7 +74,7 @@ struct alignas(2) BFloat16 {
uint16_t x;
// HIP wants __host__ __device__ tag, CUDA does not
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
C10_HOST_DEVICE BFloat16() = default;
#else
BFloat16() = default;
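The comments in these hunks note that a memcpy-based type pun is the strict-aliasing-safe way to move bits between float and its integer representation, but that it misbehaved under HIP, hence the reinterpret_cast kept in the ROCm branch. A small host-side sketch of the memcpy form, for reference only:

    #include <cstdint>
    #include <cstring>

    inline std::uint32_t bits_of_f32(float src) {
      std::uint32_t res;
      std::memcpy(&res, &src, sizeof(res));  // well-defined type pun
      return res;
    }

    inline float f32_of_bits(std::uint32_t src) {
      float res;
      std::memcpy(&res, &src, sizeof(res));
      return res;
    }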

View File

@ -107,7 +107,7 @@ using void_t = typename make_void<Ts...>::type;
#endif
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
// rocm doesn't like the C10_HOST_DEVICE
#define CUDA_HOST_DEVICE
#else

View File

@ -372,7 +372,7 @@ struct alignas(2) Half {
}
// HIP wants __host__ __device__ tag, CUDA does not
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
C10_HOST_DEVICE Half() = default;
#else
Half() = default;

View File

@ -541,7 +541,7 @@ C10_HOST_DEVICE T abs(const c10::complex<T>& z) {
#endif
}
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
#define ROCm_Bug(x)
#else
#define ROCm_Bug(x) x

View File

@ -63,7 +63,7 @@ private:
at::TensorOptions optionsFor(const Tensor& ten) {
at::Device device = ten.GetDevice();
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
if (backend() == at::Backend::HIP) {
device = at::Device(kCUDA, device.index());
}
@ -107,7 +107,7 @@ private:
auto at_sizes = src.sizes();
caffe2::TypeMeta type_meta = typeMetaFor(src);
at::Device device = src.device();
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
if (device.is_cuda()) {
device = at::Device(at::DeviceType::HIP, device.index());
}

View File

@ -117,7 +117,7 @@ void DeviceQuery(const int device) {
<< std::endl;
ss << "Total registers per block: " << prop.regsPerBlock << std::endl;
ss << "Warp size: " << prop.warpSize << std::endl;
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
ss << "Maximum memory pitch: " << prop.memPitch << std::endl;
#endif
ss << "Maximum threads per block: " << prop.maxThreadsPerBlock
@ -130,14 +130,14 @@ void DeviceQuery(const int device) {
<< prop.maxGridSize[2] << std::endl;
ss << "Clock rate: " << prop.clockRate << std::endl;
ss << "Total constant memory: " << prop.totalConstMem << std::endl;
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
ss << "Texture alignment: " << prop.textureAlignment << std::endl;
ss << "Concurrent copy and execution: "
<< (prop.deviceOverlap ? "Yes" : "No") << std::endl;
#endif
ss << "Number of multiprocessors: " << prop.multiProcessorCount
<< std::endl;
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
ss << "Kernel execution timeout: "
<< (prop.kernelExecTimeoutEnabled ? "Yes" : "No") << std::endl;
#endif
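These hunks keep the cudaDeviceProp fields that exist on both runtimes unguarded and fence off the CUDA-only ones (memPitch, textureAlignment, deviceOverlap, kernelExecTimeoutEnabled). A reduced sketch of the same query pattern; the function name is hypothetical:

    #include <cstdio>
    #include <cuda_runtime.h>

    void PrintCommonProps(int device) {
      cudaDeviceProp prop;
      cudaGetDeviceProperties(&prop, device);
      std::printf("Warp size: %d\n", prop.warpSize);
      std::printf("Multiprocessors: %d\n", prop.multiProcessorCount);
    #if !defined(USE_ROCM)
      std::printf("Maximum memory pitch: %zu\n", prop.memPitch);
    #endif
    }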
@ -186,7 +186,7 @@ const char* cublasGetErrorString(cublasStatus_t error) {
return "CUBLAS_STATUS_ARCH_MISMATCH";
case CUBLAS_STATUS_INTERNAL_ERROR:
return "CUBLAS_STATUS_INTERNAL_ERROR";
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
case CUBLAS_STATUS_MAPPING_ERROR:
return "CUBLAS_STATUS_MAPPING_ERROR";
case CUBLAS_STATUS_EXECUTION_FAILED:
@ -240,7 +240,7 @@ const char* curandGetErrorString(curandStatus_t error) {
return "CURAND_STATUS_ARCH_MISMATCH";
case CURAND_STATUS_INTERNAL_ERROR:
return "CURAND_STATUS_INTERNAL_ERROR";
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
case HIPRAND_STATUS_NOT_IMPLEMENTED:
return "HIPRAND_STATUS_NOT_IMPLEMENTED";
#endif

View File

@ -5,14 +5,14 @@
#include <cuda.h>
#include <cuda_runtime.h>
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
#ifdef __GNUC__
#if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)
#pragma GCC diagnostic push
#endif
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif // __GNUC__
#endif // __HIP_PLATFORM_HCC__
#endif // USE_ROCM
#include <cublas_v2.h>
#include <curand.h>
@ -30,10 +30,11 @@
// CAFFE2_CUDA_API gets translated to CAFFE2_HIP_API in hipify script, which
// causes a macro redefinition issue with the later definition of
// CAFFE2_HIP_API, so we exclude this definition when HIP is specified
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
#define CAFFE2_CUDA_API TORCH_CUDA_CPP_API
#endif // __HIP_PLATFORM_HCC__
#endif // USE_ROCM
// TODO: [ROCm] Need to remove this after CUDA->HIP mapping is updated.
#define CAFFE2_HIP_EXPORT C10_EXPORT
#define CAFFE2_HIP_API TORCH_HIP_API
@ -52,20 +53,20 @@
#endif
// CUDA major revision number below which fp16 compute is not supported
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
constexpr int kFp16CUDADevicePropMajor = 6;
#else
constexpr int kFp16CUDADevicePropMajor = 3;
#endif
// Re-enable strict aliasing diagnostic if it was disabled.
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
#ifdef __GNUC__
#if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)
#pragma GCC diagnostic pop
#endif
#endif // __GNUC__
#endif // __HIP_PLATFORM_HCC__
#endif // USE_ROCM
/**
* The maximum number of peers that each gpu can have when doing p2p setup.
@ -78,14 +79,14 @@ constexpr int kFp16CUDADevicePropMajor = 3;
namespace caffe2 {
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
/**
* Empty class to identify TensorCore-based math
*/
class TensorCoreEngine {};
#endif // __HIP_PLATFORM_HCC__
#endif // USE_ROCM
#if CUDA_VERSION >= 10000
#if defined(CUDA_VERSION) && CUDA_VERSION >= 10000
#define CAFFE2_CUDA_PTRATTR_MEMTYPE type
#else
#define CAFFE2_CUDA_PTRATTR_MEMTYPE memoryType
@ -95,7 +96,11 @@ class TensorCoreEngine {};
* A runtime function to report the cuda version that Caffe2 is built with.
*/
inline int CudaVersion() {
#if defined(USE_ROCM)
return ROCM_VERSION;
#else
return CUDA_VERSION;
#endif
}
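With this change CudaVersion() reports ROCM_VERSION on ROCm builds and CUDA_VERSION on CUDA builds, so callers no longer need to probe the platform macro themselves. A hypothetical caller, assuming the surrounding caffe2 header is included:

    #include <cstdio>

    void LogToolkitVersion() {
    #if defined(USE_ROCM)
      std::printf("Built against ROCm %d\n", caffe2::CudaVersion());
    #else
      std::printf("Built against CUDA %d\n", caffe2::CudaVersion());
    #endif
    }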
/**

View File

@ -57,7 +57,9 @@ static_assert(
{"BLAS_INFO", "${BLAS_INFO}"}, \
{"LAPACK_INFO", "${LAPACK_INFO}"}, \
{"USE_CUDA", "${USE_CUDA}"}, \
{"USE_ROCM", "${USE_ROCM}"}, \
{"CUDA_VERSION", "${CUDA_VERSION}"}, \
{"ROCM_VERSION", "${ROCM_VERSION}"}, \
{"USE_CUDNN", "${USE_CUDNN}"}, \
{"CUDNN_VERSION", "${CUDNN_VERSION}"}, \
{"USE_NCCL", "${USE_NCCL}"}, \

View File

@ -1,6 +1,6 @@
#include "caffe2/distributed/file_store_handler_op.h"
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
#include <caffe2/core/context_gpu.h>
#else
#include <caffe2/core/hip/context_gpu.h>
@ -8,7 +8,7 @@
namespace caffe2 {
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
REGISTER_CUDA_OPERATOR(
FileStoreHandlerCreate,
FileStoreHandlerCreateOp<CUDAContext>);

View File

@ -1,6 +1,6 @@
#include "caffe2/distributed/redis_store_handler_op.h"
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
#include <caffe2/core/context_gpu.h>
#else
#include <caffe2/core/hip/context_gpu.h>
@ -8,7 +8,7 @@
namespace caffe2 {
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
REGISTER_CUDA_OPERATOR(
RedisStoreHandlerCreate,
RedisStoreHandlerCreateOp<CUDAContext>);

View File

@ -12,7 +12,7 @@ bool BatchMatMulOp<CUDAContext, DefaultEngine>::RunOnDevice() {
REGISTER_CUDA_OPERATOR(BatchMatMul, BatchMatMulOp<CUDAContext>);
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
template <>
bool BatchMatMulOp<CUDAContext, TensorCoreEngine>::RunOnDevice() {

View File

@ -25,7 +25,7 @@ __global__ void ChannelStatsNCHWCUDAKernel(
for (int n = threadIdx.x; n < N; n += blockDim.x) {
for (int hw = threadIdx.y; hw < HxW; hw += blockDim.y) {
const int index = (n * C + c) * HxW + hw;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
m_val += __ldg(X + index);
v_val += __ldg(X + index) * __ldg(X + index);
#else
@ -58,7 +58,7 @@ __global__ void ChannelStatsNHWCCUDAKernel(
T v_val = 0;
for (int i = threadIdx.x; i < inner_size; i += blockDim.x) {
const int index = i * C + c;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
m_val += __ldg(X + index);
v_val += __ldg(X + index) * __ldg(X + index);
#else

View File

@ -139,7 +139,7 @@ bool FullyConnectedGradientOp<
}
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
// Require these to be defined otherwise TensorCore FC ops will end
// up calling the default FC implementation which doesn't have
@ -191,7 +191,7 @@ REGISTER_CUDA_OPERATOR(
DefaultEngine,
false /* don't transpose weight */>);
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
REGISTER_CUDA_OPERATOR_WITH_ENGINE(
FC,

View File

@ -6,7 +6,7 @@
#include "caffe2/operators/generate_proposals_op_util_nms.h"
#include "caffe2/operators/generate_proposals_op_util_nms_gpu.h"
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
#include <cfloat>
#endif

View File

@ -7,7 +7,7 @@ namespace utils {
namespace {
// Helper data structure used locally
struct
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
__align__(16)
#endif
Box {

View File

@ -44,7 +44,7 @@ __global__ void ComputeFusedParamsCUDAKernel<float>(
if (index < N * C) {
const int ng = index / K;
const int c = index % C;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
const float scale_val = __ldg(gamma + c) * __ldg(rsig + ng);
scale[index] = scale_val;
bias[index] = fmaf(-scale_val, __ldg(mu + ng), __ldg(beta + c));
@ -78,7 +78,7 @@ __global__ void GroupNormForwardCUDAKernel<float, StorageOrder::NCHW>(
const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
if (index < N * C * HxW) {
const int nc = index / HxW;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
Y[index] = fmaf(__ldg(X + index), __ldg(scale + nc), __ldg(bias + nc));
#else
Y[index] = fmaf(X[index], scale[nc], bias[nc]);
@ -98,7 +98,7 @@ __global__ void GroupNormForwardCUDAKernel<float, StorageOrder::NHWC>(
const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
if (index < N * C * HxW) {
const int nc = index / (HxW * C) * C + index % C;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
Y[index] = fmaf(__ldg(X + index), __ldg(scale + nc), __ldg(bias + nc));
#else
Y[index] = fmaf(X[index], scale[nc], bias[nc]);
@ -120,7 +120,7 @@ __global__ void ComputeInternalGradientsNCHWCUDAKernel(
T db_sum = 0;
for (int i = threadIdx.x; i < HxW; i += blockDim.x) {
const int index = nc * HxW + i;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
ds_sum += __ldg(dY + index) * __ldg(X + index);
db_sum += __ldg(dY + index);
#else
@ -160,7 +160,7 @@ __global__ void ComputeYGradientScaleCUDAKernel(
if (index < N * C) {
const int ng = index / K;
const int c = index % C;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
dY_scale[index] = __ldg(gamma + c) * __ldg(rsig + ng);
#else
dY_scale[index] = gamma[c] * rsig[ng];
@ -203,7 +203,7 @@ __global__ void ComputeXScaleAndBiasCUDAKernel<float>(
for (int i = threadIdx.x; i < K; i += blockDim.x) {
const int index = ng * K + i;
const int c = g * K + i;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
ds_sum += __ldg(ds + index) * __ldg(gamma + c);
db_sum += __ldg(db + index) * __ldg(gamma + c);
#else
@ -214,7 +214,7 @@ __global__ void ComputeXScaleAndBiasCUDAKernel<float>(
ds_sum = BlockReduce<float>(ds_storage).Sum(ds_sum);
db_sum = BlockReduce<float>(db_storage).Sum(db_sum);
if (threadIdx.x == 0) {
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
const float x = fmaf(db_sum, __ldg(mu + ng), -ds_sum) *
math::utils::Cube<float>(__ldg(rsig + ng)) * alpha;
X_scale[ng] = x;
@ -258,7 +258,7 @@ __global__ void GroupNormBackwardCUDAKernel<float, StorageOrder::NCHW>(
if (index < N * C * HxW) {
const int nc = index / HxW;
const int ng = nc / K;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
dX[index] = fmaf(
__ldg(dY_scale + nc),
__ldg(dY + index),
@ -287,7 +287,7 @@ __global__ void GroupNormBackwardCUDAKernel<float, StorageOrder::NHWC>(
if (index < N * C * HxW) {
const int nc = index / (HxW * C) * C + index % C;
const int ng = nc / K;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
dX[index] = fmaf(
__ldg(dY_scale + nc),
__ldg(dY + index),
@ -333,7 +333,7 @@ __global__ void GammaBetaBackwardCUDAKernel<float>(
for (int i = threadIdx.x; i < N; i += blockDim.x) {
const int nc = i * C + c;
const int ng = i * G + g;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
dg_sum += fmaf(-__ldg(db + nc), __ldg(mu + ng), __ldg(ds + nc)) *
__ldg(rsig + ng);
db_sum += __ldg(db + nc);

View File

@ -21,7 +21,7 @@ __global__ void ComputeFusedParamsCUDAKernel(
const int64_t index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < N * C) {
const int64_t c = index % C;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
const T scale_val = __ldg(gamma + c) * __ldg(rstd + index);
scale[index] = scale_val;
bias[index] = __ldg(beta + c) - scale_val * __ldg(mean + index);
@ -47,7 +47,7 @@ __global__ void InstanceNormForwardCUDAKernel(
const int64_t nc = kOrder == StorageOrder::NCHW
? (index / HxW)
: (index / (HxW * C) * C + index % C);
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
Y[index] = __ldg(scale + nc) * __ldg(X + index) + __ldg(bias + nc);
#else
Y[index] = scale[nc] * X[index] + bias[nc];
@ -69,7 +69,7 @@ __global__ void ComputeInternalGradientsNCHWCUDAKernel(
T db_sum = 0;
for (int64_t j = threadIdx.x; j < HxW; j += blockDim.x) {
const int64_t index = i * HxW + j;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
ds_sum += __ldg(dY + index) * __ldg(X + index);
db_sum += __ldg(dY + index);
#else
@ -101,7 +101,7 @@ __global__ void ComputeFusedParams(
const int64_t index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < N * C) {
const int64_t c = index % C;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
T x = __ldg(ds + index) * __ldg(gamma + c);
T y = __ldg(db + index) * __ldg(gamma + c);
x = (y * __ldg(mean + index) - x) *
@ -136,7 +136,7 @@ __global__ void InstanceNormBackwardCUDAKernel(
const int64_t c = kOrder == StorageOrder::NCHW
? (index / HxW)
: (index / (HxW * C) * C + index % C);
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
dX[index] = __ldg(c1 + c) * __ldg(dY + index) +
__ldg(c2 + c) * __ldg(X + index) + __ldg(c3 + c);
#else
@ -162,7 +162,7 @@ __global__ void GammaBetaBackwardCUDAKernel(
T sum2 = 0;
for (int64_t i = threadIdx.x; i < N; i += blockDim.x) {
const int64_t index = i * C + c;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
sum1 += (__ldg(ds + index) - __ldg(db + index) * __ldg(mean + index)) *
__ldg(rstd + index);
sum2 += __ldg(db + index);

View File

@ -16,7 +16,7 @@ __global__ void SelectGradientCUDAKernel(
T* dX) {
const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
if (i < N) {
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
dX[i] = __ldg(X + i) == __ldg(Y + i) ? __ldg(dY + i) : T(0);
#else
dX[i] = X[i] == Y[i] ? dY[i] : T(0);
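The same guard recurs across these kernels (ChannelStats, GroupNorm, InstanceNorm, and the select-gradient kernel here): route loads through the read-only data cache via __ldg where it exists (sm_35+ and ROCm) and fall back to a plain dereference otherwise. A sketch of that pattern factored into a helper; the helper name is ours and is not introduced by the patch:

    template <typename T>
    __device__ __forceinline__ T ldg_or_load(const T* ptr) {
    #if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
      return __ldg(ptr);   // read-only cache load
    #else
      return *ptr;         // plain global load
    #endif
    }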

View File

@ -2,7 +2,7 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/reduce_front_back_max_ops.h"
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
#include <cfloat>
#endif

View File

@ -3,11 +3,11 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/rmac_regions_op.h"
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
#include <cfloat>
#endif
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
namespace rocprim {
#else
namespace cub {

View File

@ -41,7 +41,7 @@ void inclusive_scan_wrapper(
}
template <typename T, bool ExactBlock = false, bool Average = false>
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(1024, SEGREDUCE_MINBLOCKS)
#endif
__global__ void length_sum_kernel(
@ -85,7 +85,7 @@ __global__ void length_sum_kernel(
}
template <typename T, bool ExactBlock = false, bool Average = false>
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(1024, SEGREDUCE_MINBLOCKS)
#endif
__global__ void length_sum_gradient_kernel(
@ -126,7 +126,7 @@ __global__ void length_sum_gradient_kernel(
}
template <typename T, bool ExactBlock = false>
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(1024, SEGREDUCE_MINBLOCKS)
#endif
__global__ void length_max_kernel(
@ -172,7 +172,7 @@ __global__ void length_max_kernel(
}
template <typename T, bool ExactBlock = false>
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(1024, SEGREDUCE_MINBLOCKS)
#endif
__global__ void length_weighted_sum_gradient_kernel(
@ -209,7 +209,7 @@ __global__ void length_weighted_sum_gradient_kernel(
}
template <typename T, typename IndexType, int NumThreads>
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(1024, SEGREDUCE_MINBLOCKS)
#endif
__global__ void length_weighted_sum_with_main_input_gradient_kernel(
@ -252,7 +252,7 @@ __global__ void length_weighted_sum_with_main_input_gradient_kernel(
}
template <typename T, typename IndexType, bool ExactBlock = false>
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(1024, SEGREDUCE_MINBLOCKS)
#endif
__global__ void sparse_length_max_kernel(
@ -313,7 +313,7 @@ __global__ void sparse_length_max_kernel(
}
template <typename T, typename IndexType, bool ExactBlock = false>
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(1024, SEGREDUCE_MINBLOCKS)
#endif
__global__ void sparse_length_weighted_sum_kernel(

View File

@ -4,7 +4,7 @@
#include "caffe2/core/context_gpu.h"
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
#define SEGREDUCE_MINBLOCKS 8
#else
#define SEGREDUCE_MINBLOCKS 16
@ -56,7 +56,7 @@ template <
typename IndexType,
bool ExactBlock = false,
bool Average = false>
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(1024,SEGREDUCE_MINBLOCKS)
#endif
__global__ void sparse_length_sum_kernel(

View File

@ -20,7 +20,7 @@ __global__ void TileCopyCUDAKernel(
if (x < total_size) {
const int r = x / inner_size / tiles;
const int c = x % inner_size;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
Y[x] = __ldg(X + r * inner_size + c);
#else
Y[x] = X[r * inner_size + c];

View File

@ -71,7 +71,7 @@ __device__ inline void warpHeapInsert(K k, V v, K* keyHeap, V* valueHeap) {
// (0 12 3456)
// log2(8 / 2) = 2 levels of interior nodes for heap size 8 (0 and 12)
int i = 0;
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
#pragma unroll
#endif
for (int levels = 0; levels < math::IntegerLog2(HeapSize / 2); ++levels) {
@ -114,7 +114,7 @@ warpHeap(K k, V v, K& keyHeapHead, K* keyHeap, V* valueHeap) {
bool wantInsert = Dir ? (k > keyHeapHead) : (k < keyHeapHead);
// Find out all the lanes that have elements to add to the heap
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
unsigned long long int vote = __ballot(wantInsert);
if (!vote) {
@ -138,7 +138,7 @@ warpHeap(K k, V v, K& keyHeapHead, K* keyHeap, V* valueHeap) {
// that have elements
int index = __popc(getLaneMaskLt() & vote);
int total = __popc(vote);
#endif // __HIP_PLATFORM_HCC__
#endif // USE_ROCM
// FIXME: try switch statement and explicitly handle cases
// FIXME: how do cases work?
@ -261,14 +261,14 @@ __global__ void selectRowsViaHeap(
V vals[Unroll];
for (int i = threadIdx.x; i < n; i += blockDim.x * Unroll) {
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
#pragma unroll
#endif
for (int j = 0; j < Unroll; ++j) {
vals[j] = inputStart[i + j * blockDim.x];
}
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
#pragma unroll
#endif
for (int j = 0; j < Unroll; ++j) {

View File

@ -170,11 +170,11 @@ __device__ void countRadixUsingMask(CountType counts[RadixSize],
#pragma unroll
for (unsigned int j = 0; j < RadixSize; ++j) {
bool vote = hasVal && (digitInRadix == j);
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
counts[j] += __popcll(__ballot(vote));
#else
counts[j] += __popc(__ballot_sync(__activemask(), vote));
#endif // __HIP_PLATFORM_HCC__
#endif // USE_ROCM
}
}

View File

@ -1064,15 +1064,15 @@ void addGlobalMethods(py::module& m) {
#endif // CAFFE2_USE_MKLDNN
);
// if the binary is built with __HIP_PLATFORM_HCC__, this is a ROCm build
// if the binary is built with USE_ROCM, this is a ROCm build
// and therefore we need to ignore dyndep failures (because the module
// may not have a ROCm equivalent yet e.g. nccl)
m.attr("use_rocm") = py::bool_(
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
true
#else // __HIP_PLATFORM_HCC__
#else // USE_ROCM
false
#endif // __HIP_PLATFORM_HCC__
#endif // USE_ROCM
);
m.attr("use_trt") = py::bool_(

View File

@ -86,7 +86,7 @@ void sort_pairs_wrapper(
}
template <typename T>
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(1024, SEGREDUCE_MINBLOCKS)
#endif
__global__ void gradient_mean_kernel(
@ -104,7 +104,7 @@ __global__ void gradient_mean_kernel(
}
template <typename SIndex, typename TParam, typename T, bool ExactBlock = false>
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(1024, SEGREDUCE_MINBLOCKS)
#endif
__global__ void sparse_adagrad_fused_length_sum_gradient_kernel(
@ -171,7 +171,7 @@ __global__ void sparse_adagrad_fused_length_sum_gradient_kernel(
}
template <typename SIndex, typename TParam, typename T, int NumThreads>
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(1024, SEGREDUCE_MINBLOCKS)
#endif
__global__ void sparse_adagrad_fused_length_weighted_sum_gradient_kernel(
@ -252,7 +252,7 @@ __global__ void sparse_adagrad_fused_length_weighted_sum_gradient_kernel(
// Construct a reverse map of offset_of_idx -> segment_id.
template <typename SIndex>
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(1024, SEGREDUCE_MINBLOCKS)
#endif
__global__ void linear_index_weight_offsets_dedup_kernel(
@ -279,7 +279,7 @@ template <
typename T,
bool ExactBlock = false,
roundOption roundOpt = NEAREST>
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(1024, SEGREDUCE_MINBLOCKS)
#endif
__global__ void rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel(
@ -343,7 +343,7 @@ __global__ void rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel(
sorted_linear_ind_data[sorted_linear_indice_id + num_dup + threadIdx.x] ==
index;
}
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
int32_t num_dup_incr = __popc(__ballot_sync(0xFFFFFFFF, segment_continue));
#else
int32_t num_dup_incr = __popc(__ballot(segment_continue));
@ -438,7 +438,7 @@ __global__ void rowwise_sparse_adagrad_fused_length_sum_gradient_dedup_kernel(
}
template <typename SIndex, typename TParam, typename T, int NumThreads>
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(1024, SEGREDUCE_MINBLOCKS)
#endif
__global__

View File

@ -10,7 +10,7 @@
#include "caffe2/core/operator.h"
#include "caffe2/utils/GpuAtomics.cuh"
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
#define SEGREDUCE_MINBLOCKS 8
#else
#define SEGREDUCE_MINBLOCKS 16
@ -31,7 +31,7 @@ constexpr int kWarpSize = 32;
template <typename T>
inline __device__ T shfl_xor(const T val, int laneMask, int width = kWarpSize) {
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
return __shfl_xor_sync(0xffffffff, val, laneMask, width);
#else
return __shfl_xor(val, laneMask, width);
@ -108,8 +108,9 @@ static inline __device__ void gpuAtomicAdd(float* address, float val) {
}
static inline __device__ void gpuAtomicAdd(c10::Half* address, c10::Half val) {
#if ( \
(CUDA_VERSION < 10000) || \
#if ( \
(defined(USE_ROCM)) || \
(defined(CUDA_VERSION) && (CUDA_VERSION < 10000)) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
unsigned int* address_as_ui =
(unsigned int*)((char*)address - ((size_t)address & 2));
@ -136,7 +137,7 @@ template <
typename T,
bool ExactBlock = false,
roundOption roundOpt = NEAREST>
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(1024, SEGREDUCE_MINBLOCKS)
#endif
__global__ void rowwise_sparse_adagrad_fused_length_sum_gradient_kernel(

View File

@ -22,7 +22,7 @@ __global__ void FP16MomentumSGDKernel(
bool nesterov,
const float wd,
half2* param) {
#if __CUDA_ARCH__ >= 530 || defined(__HIP_PLATFORM_HCC__)
#if __CUDA_ARCH__ >= 530 || defined(USE_ROCM)
const float lr2 = lr[0];
const half2 LR = __float2half2_rn(lr2);
const half2 momentum = __float2half2_rn(mom);
@ -109,7 +109,7 @@ __global__ void FP16MomentumSGDFP32Kernel(
bool nesterov,
const float wd,
half2* param) {
#if __CUDA_ARCH__ >= 530 || defined(__HIP_PLATFORM_HCC__)
#if __CUDA_ARCH__ >= 530 || defined(USE_ROCM)
const float lr2 = lr[0];
const float LR = lr2;
const float momentum = mom;

View File

@ -14,7 +14,7 @@ inline __device__ void gpu_atomic_add(T* address, const T val) {
template <>
inline __device__ void gpu_atomic_add(float* address, const float val) {
#if defined(__HIP_PLATFORM_HCC__) && defined(__gfx908__)
#if defined(USE_ROCM) && defined(__gfx908__)
atomicAddNoRet(address, val);
#else
atomicAdd(address, val);

View File

@ -7,7 +7,7 @@ namespace caffe2 {
// Static definition of GPU warp size for unrolling and code generation
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
constexpr int kWarpSize = warpSize; // = 64 (Defined in hip_runtime.h)
#else
constexpr int kWarpSize = 32;
@ -25,7 +25,7 @@ template <>
struct Bitfield<unsigned int> {
static __device__ __forceinline__
unsigned int getBitfield(unsigned int val, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
pos &= 0xff;
len &= 0xff;
@ -35,12 +35,12 @@ struct Bitfield<unsigned int> {
unsigned int ret;
asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
return ret;
#endif // __HIP_PLATFORM_HCC__
#endif // USE_ROCM
}
static __device__ __forceinline__
unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
pos &= 0xff;
len &= 0xff;
@ -55,7 +55,7 @@ struct Bitfield<unsigned int> {
asm("bfi.b32 %0, %1, %2, %3, %4;" :
"=r"(ret) : "r"(toInsert), "r"(val), "r"(pos), "r"(len));
return ret;
#endif // __HIP_PLATFORM_HCC__
#endif // USE_ROCM
}
};
@ -63,7 +63,7 @@ template <>
struct Bitfield<unsigned long long int> {
static __device__ __forceinline__
unsigned long long int getBitfield(unsigned long long int val, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
pos &= 0xff;
len &= 0xff;
@ -73,12 +73,12 @@ struct Bitfield<unsigned long long int> {
unsigned long long int ret;
asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len));
return ret;
#endif // __HIP_PLATFORM_HCC__
#endif // USE_ROCM
}
static __device__ __forceinline__
unsigned long long int setBitfield(unsigned long long int val, unsigned long long int toInsert, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
pos &= 0xff;
len &= 0xff;
@ -93,21 +93,21 @@ struct Bitfield<unsigned long long int> {
asm("bfi.b64 %0, %1, %2, %3, %4;" :
"=l"(ret) : "l"(toInsert), "l"(val), "r"(pos), "r"(len));
return ret;
#endif // __HIP_PLATFORM_HCC__
#endif // USE_ROCM
}
};
__device__ __forceinline__ int getLaneId() {
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
return __lane_id();
#else
int laneId;
asm("mov.s32 %0, %%laneid;" : "=r"(laneId) );
return laneId;
#endif // __HIP_PLATFORM_HCC__
#endif // USE_ROCM
}
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
__device__ __forceinline__ unsigned long long int getLaneMaskLt() {
unsigned long long int m = (1ull << getLaneId()) - 1ull;
return m;
@ -151,7 +151,7 @@ __device__ __forceinline__ unsigned getLaneMaskGe() {
asm("mov.u32 %0, %%lanemask_ge;" : "=r"(mask));
return mask;
}
#endif // __HIP_PLATFORM_HCC__
#endif // USE_ROCM
} // namespace caffe2

View File

@ -62,7 +62,7 @@ __device__ void exclusivePrefixScan(T* smem, T in, T* out, T* carry, BinaryFunct
template <typename T, bool KillWARDependency, class BinaryFunction>
__device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFunction binop) {
// Within-warp, we use warp voting.
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
unsigned long long int vote = __ballot(in);
T index = __popcll(getLaneMaskLe() & vote);
@ -71,7 +71,7 @@ __device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFuncti
T vote = __ballot_sync(__activemask(), in);
T index = __popc(getLaneMaskLe() & vote);
T carry = __popc(vote);
#endif // __HIP_PLATFORM_HCC__
#endif // USE_ROCM
int warp = threadIdx.x / kWarpSize;
@ -117,11 +117,11 @@ __device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, Bi
*out -= (T) in;
// The outgoing carry for all threads is the last warp's sum
#if defined(__HIP_PLATFORM_HCC__)
#if defined(USE_ROCM)
*carry = smem[math::DivUp<int>(blockDim.x, kWarpSize) - 1];
#else
*carry = smem[(blockDim.x / kWarpSize) - 1];
#endif // __HIP_PLATFORM_HCC__
#endif // USE_ROCM
if (KillWARDependency) {
__syncthreads();

View File

@ -30,16 +30,16 @@ class FixedDivisor<std::int32_t> {
FixedDivisor() = default;
explicit FixedDivisor(const std::int32_t d) : d_(d) {
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
CalcSignedMagic();
#endif // __HIP_PLATFORM_HCC__
#endif // USE_ROCM
}
FIXED_DIVISOR_DECL std::int32_t d() const {
return d_;
}
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
FIXED_DIVISOR_DECL std::uint64_t magic() const {
return magic_;
}
@ -47,17 +47,17 @@ class FixedDivisor<std::int32_t> {
FIXED_DIVISOR_DECL int shift() const {
return shift_;
}
#endif // __HIP_PLATFORM_HCC__
#endif // USE_ROCM
/// Calculates `q = n / d`.
FIXED_DIVISOR_DECL std::int32_t Div(const std::int32_t n) const {
#ifdef __HIP_PLATFORM_HCC__
#if defined(USE_ROCM)
return n / d_;
#else // __HIP_PLATFORM_HCC__
#else // USE_ROCM
// In lieu of a mulhi instruction being available, perform the
// work in uint64
return (int32_t)((magic_ * (uint64_t)n) >> shift_);
#endif // __HIP_PLATFORM_HCC__
#endif // USE_ROCM
}
/// Calculates `r = n % d`.
@ -73,7 +73,7 @@ class FixedDivisor<std::int32_t> {
}
private:
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
// Calculates magic multiplicative value and shift amount for calculating `q =
// n / d` for signed 32-bit integers.
// Implementation taken from Hacker's Delight section 10.
@ -117,14 +117,14 @@ class FixedDivisor<std::int32_t> {
shift_ = p;
magic_ = (std::uint64_t)(std::uint32_t)magic;
}
#endif // __HIP_PLATFORM_HCC__
#endif // USE_ROCM
std::int32_t d_ = 1;
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
std::uint64_t magic_;
int shift_;
#endif // __HIP_PLATFORM_HCC__
#endif // USE_ROCM
};
} // namespace caffe2
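FixedDivisor keeps the multiply-and-shift path only on CUDA and, as these hunks show, falls back to the hardware '/' on ROCm. For reference, the multiply-and-shift identity it relies on looks like this in the unsigned case (divisor 3, the classic 0xAAAAAAAB magic); this is an illustration, not the signed Hacker's Delight routine used by the class:

    #include <cassert>
    #include <cstdint>

    int main() {
      const uint64_t magic = 0xAAAAAAABull;  // floor(2^33 / 3) + 1
      const int shift = 33;
      for (uint32_t n : {0u, 1u, 2u, 3u, 7u, 1000u, 4294967295u}) {
        const uint32_t q = static_cast<uint32_t>((magic * n) >> shift);
        assert(q == n / 3);                  // multiply + shift reproduces n / 3
      }
      return 0;
    }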

View File

@ -17,7 +17,7 @@ void CompareDivMod(int32_t v, int32_t divisor) {
int fixed_q = fixed.Div(v);
int fixed_r = fixed.Mod(v);
#ifndef __HIP_PLATFORM_HCC__
#if !defined(USE_ROCM)
EXPECT_EQ(native_q, fixed_q)
<< v << " / " << divisor << " magic " << fixed.magic() << " shift "
<< fixed.shift() << " quot " << fixed_q << " " << native_q;

Some files were not shown because too many files have changed in this diff