CUDA 9 support
Summary: Adds support for the CUDA 9 toolkit. Includes new fp16 data type fixes and changes to warp-synchronous programming. Also updates the CUB third-party repo for CUDA 9 support.

Closes https://github.com/caffe2/caffe2/pull/853

Differential Revision: D5548507
Pulled By: Yangqing
fbshipit-source-id: c7fd2edb623f2aa8c67b9a1000efc8f71e6832ab
Commit e97c04118e (parent 4d8a8c2e1e), committed by Facebook Github Bot.
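The warp-synchronous changes in the hunks below all follow one pattern: each implicit __ballot(pred) warp vote is wrapped in a CUDA_VERSION guard so that CUDA 9 builds call __ballot_sync(__activemask(), pred) with an explicit lane-participation mask. As a rough, self-contained illustration of that pattern (the kernel, its name, and the host driver are illustrative examples, not code from this commit):

    // Sketch: CUDA_VERSION-guarded warp vote, as used throughout this commit.
    // Counts how many lanes of a single warp hold a value above a threshold.
    #include <cstdio>
    #include <cuda.h>
    #include <cuda_runtime.h>

    __global__ void countAboveThreshold(const float* data, float threshold, int* out) {
      bool pred = data[threadIdx.x] > threshold;
    #if CUDA_VERSION >= 9000
      // CUDA 9: warp votes take an explicit mask of participating lanes;
      // __activemask() reproduces the old "currently active lanes" behavior.
      unsigned int vote = __ballot_sync(__activemask(), pred);
    #else
      unsigned int vote = __ballot(pred);
    #endif
      if (threadIdx.x == 0) {
        *out = __popc(vote);  // number of set ballot bits == number of voting lanes
      }
    }

    int main() {
      float h_data[32];
      for (int i = 0; i < 32; ++i) h_data[i] = static_cast<float>(i);
      float* d_data;
      int* d_out;
      cudaMalloc(&d_data, sizeof(h_data));
      cudaMalloc(&d_out, sizeof(int));
      cudaMemcpy(d_data, h_data, sizeof(h_data), cudaMemcpyHostToDevice);
      countAboveThreshold<<<1, 32>>>(d_data, 15.5f, d_out);
      int h_out = 0;
      cudaMemcpy(&h_out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
      printf("lanes above threshold: %d\n", h_out);  // expect 16
      cudaFree(d_data);
      cudaFree(d_out);
      return 0;
    }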
@@ -112,7 +112,11 @@ warpHeap(K k, V v, K& keyHeapHead, K* keyHeap, V* valueHeap) {
   bool wantInsert = Dir ? (k > keyHeapHead) : (k < keyHeapHead);
 
   // Find out all the lanes that have elements to add to the heap
+#if CUDA_VERSION >= 9000
+  unsigned int vote = __ballot_sync(__activemask(), wantInsert);
+#else
   unsigned int vote = __ballot(wantInsert);
+#endif
 
   if (!vote) {
     // Everything the warp has is smaller than our heap
@@ -167,7 +167,11 @@ __device__ void countRadixUsingMask(CountType counts[RadixSize],
 #pragma unroll
   for (unsigned int j = 0; j < RadixSize; ++j) {
     bool vote = hasVal && (digitInRadix == j);
+#if CUDA_VERSION >= 9000
+    counts[j] += __popc(__ballot_sync(__activemask(), vote));
+#else
     counts[j] += __popc(__ballot(vote));
+#endif
   }
 }
 
@@ -8,7 +8,7 @@ namespace caffe2 {
 // Static definition of GPU warp size for unrolling and code generation
 
 #ifdef __CUDA_ARCH__
-#if __CUDA_ARCH__ <= 620
+#if __CUDA_ARCH__ <= 700
 constexpr int kWarpSize = 32;
 #else
 #error Unknown __CUDA_ARCH__; please define parameters for compute capability
@@ -62,7 +62,12 @@ __device__ void exclusivePrefixScan(T* smem, T in, T* out, T* carry, BinaryFunct
 template <typename T, bool KillWARDependency, class BinaryFunction>
 __device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFunction binop) {
   // Within-warp, we use warp voting.
+#if CUDA_VERSION >= 9000
+  T vote = __ballot_sync(__activemask(), in);
+#else
   T vote = __ballot(in);
+#endif
+
   T index = __popc(getLaneMaskLe() & vote);
   T carry = __popc(vote);
 
@@ -108,41 +108,67 @@ inline float cpu_half2float(float16 h) {
}

}; // anonymous

#if __CUDACC__

#if CUDA_VERSION >= 9000
inline float16 halfToFloat16(half x) {
  float16 r = *reinterpret_cast<float16*>(&x);
  return r;
}

inline half float16ToHalf(const float16 x) {
  __half_raw hr;
  hr.x = x.x;
  half r(hr);
  return r;
}

inline half floatToHalf(const float x) {
  float16 xh = cpu_float2half_rn(x);
  return float16ToHalf(xh);
}

#else
inline float16 halfToFloat16(__half x) {
  float16 r;
  r.x = x.x;
  return r;
}

inline __half float16ToHalf(const float16 x) {
  __half r;
  r.x = x.x;
  return r;
}

inline half floatToHalf(const float x) {
  float16 xh = cpu_float2half_rn(x);
  return float16ToHalf(xh);
}
#endif // CUDA_VERSION

#endif // __CUDACC__

// general version: defer to static_cast
template <typename IN, typename OUT>
CONVERSIONS_DECL OUT To(const IN in) {
  return static_cast<OUT>(in);
}

#if __CUDA_ARCH__
__device__ __inline__ __half inf_clip(__half h) {
  int isi = __hisinf(h);
  if (isi > 0) {
    // Exponent all ones except LSB (0x1e), mantissa is all ones (0x3ff)
    h.x = 0x7bffU;
  } else if (isi < 0) {
    // As above, negated
    h.x = 0x7bffU ^ 0x8000;
  }
  return h;
}
#endif

// explicit for fp16
template <>
CONVERSIONS_DECL float16 To(const float in) {
#if __CUDA_ARCH__
  // hacky interface between C2 fp16 and CUDA
  float16 ret;
#if 0
  // alternative truncation scheme
  __half r;
  r.x = __float2half_rn(in);
  ret.x = inf_clip(r).x;
#if CUDA_VERSION >= 9000
  half rh = static_cast<half>(in);
  return halfToFloat16(rh);
#else
  float16 ret;
  ret.x = __float2half(in).x;
#endif
  return ret;
#endif // CUDA_VERSION >= 9000
#else
  return cpu_float2half_rn(in);
#endif
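The new halfToFloat16 / float16ToHalf / floatToHalf helpers above exist because CUDA 9's half is a proper class whose raw bits are only reachable through __half_raw, whereas the older __half exposed a public .x field. Below is a small standalone sketch of the same round trip, assuming CUDA 9 headers; the float16 struct in the sketch is a stand-in for caffe2::float16, not the real type:

    // Sketch of the CUDA 9 branch of the helpers above (stand-alone, not Caffe2 code).
    #include <cstdint>
    #include <cstdio>
    #include <cuda_fp16.h>  // half, __half_raw (CUDA 9+)

    struct float16 {      // stand-in for caffe2::float16: a bare 16-bit container
      uint16_t x;
    };

    inline float16 halfToFloat16(half h) {
      // half no longer has a public .x member, so copy the raw bits out.
      float16 r = *reinterpret_cast<float16*>(&h);
      return r;
    }

    inline half float16ToHalf(const float16 f) {
      __half_raw hr;       // raw-bits view that half can be constructed from
      hr.x = f.x;
      return half(hr);
    }

    int main() {
      half h = __float2half(1.5f);     // host-callable conversion in CUDA 9
      float16 f = halfToFloat16(h);    // 1.5 in fp16 is 0x3e00
      half back = float16ToHalf(f);
      printf("bits = 0x%04x, value = %.1f\n", f.x, __half2float(back));
      return 0;
    }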
@@ -151,7 +177,11 @@ CONVERSIONS_DECL float16 To(const float in) {
 template <>
 CONVERSIONS_DECL float To(const float16 in) {
 #if __CUDA_ARCH__
+#if CUDA_VERSION >= 9000
+  __half_raw tmp;
+#else
   __half tmp;
+#endif
   tmp.x = in.x;
   return __half2float(tmp);
 #else
@@ -178,11 +178,10 @@ void Gemm<float16, CUDAContext>(
         N));
 
   } else if (math_type == TensorProto_DataType_FLOAT16) {
-    // convert alpha, beta from caffe2::float16 -> __half
-    __half alpha_fp16;
-    alpha_fp16.x = convert::To<float, float16>(alpha).x;
-    __half beta_fp16;
-    beta_fp16.x = convert::To<float, float16>(beta).x;
+    // convert alpha, beta from float -> __half
+    auto alpha_fp16 = convert::floatToHalf(alpha);
+    auto beta_fp16 = convert::floatToHalf(beta);
 
     // call cublasHgemm
     CUBLAS_CHECK(cublasHgemm(
         context->cublas_handle(),
@@ -353,10 +352,8 @@ void Gemv<float16, CUDAContext>(
         CUDA_R_16F,
         LDC));
   } else if (math_type == TensorProto_DataType_FLOAT16) {
-    __half alpha_fp16;
-    alpha_fp16.x = convert::To<float, float16>(alpha).x;
-    __half beta_fp16;
-    beta_fp16.x = convert::To<float, float16>(beta).x;
+    auto alpha_fp16 = convert::floatToHalf(alpha);
+    auto beta_fp16 = convert::floatToHalf(beta);
 
     CUBLAS_CHECK(cublasHgemm(
         context->cublas_handle(),
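In both the Gemm and Gemv hunks above, the float alpha/beta scalars are narrowed to __half once, up front, and then handed to cublasHgemm by pointer. A minimal sketch of that calling pattern outside of Caffe2 follows; the halfGemm wrapper, its plain column-major layout, and the use of __float2half in place of convert::floatToHalf are assumptions of this sketch, not the commit's Gemm/Gemv code, which also handles transposes and wraps the call in CUBLAS_CHECK:

    // Sketch: fp16 GEMM with float alpha/beta narrowed to __half for cuBLAS.
    #include <cublas_v2.h>
    #include <cuda_fp16.h>

    // C (M x N) = alpha * A (M x K) * B (K x N) + beta * C, all fp16, column-major.
    cublasStatus_t halfGemm(
        cublasHandle_t handle,
        const __half* A, const __half* B, __half* C,
        int M, int N, int K,
        float alpha, float beta) {
      // cublasHgemm wants alpha/beta as __half; __float2half plays the role
      // that convert::floatToHalf plays inside Caffe2.
      const __half alpha_fp16 = __float2half(alpha);
      const __half beta_fp16 = __float2half(beta);
      return cublasHgemm(
          handle,
          CUBLAS_OP_N, CUBLAS_OP_N,
          M, N, K,
          &alpha_fp16,
          A, M,      // lda
          B, K,      // ldb
          &beta_fp16,
          C, M);     // ldc
    }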
@@ -3,7 +3,9 @@
-set(Caffe2_known_gpu_archs8 "20 21(20) 30 35 50 52 60 61") # for CUDA 8.x
-set(Caffe2_known_gpu_archs7 "20 21(20) 30 35 50 52") # for CUDA 7.x
 # This list will be used for CUDA_ARCH_NAME = All option
-set(Caffe2_known_gpu_archs ${Caffe2_known_gpu_archs8}) # latest supported CUDA version is 8.x
+set(Caffe2_known_gpu_archs "20 21(20) 30 35 50 52 60 61 70")
+set(Caffe2_known_gpu_archs7 "20 21(20) 30 35 50 52")
+set(Caffe2_known_gpu_archs8 "20 21(20) 30 35 50 52 60 61")
 
 ################################################################################################
 # A function for automatic detection of GPUs installed (if autodetection is enabled)
@@ -88,6 +90,8 @@ function(caffe2_select_nvcc_arch_flags out_variable)
     set(__cuda_arch_bin "50")
   elseif(${CUDA_ARCH_NAME} STREQUAL "Pascal")
     set(__cuda_arch_bin "60 61")
+  elseif(${CUDA_ARCH_NAME} STREQUAL "Volta")
+    set(__cuda_arch_bin "70")
   elseif(${CUDA_ARCH_NAME} STREQUAL "All")
     set(__cuda_arch_bin ${Caffe2_known_gpu_archs})
   elseif(${CUDA_ARCH_NAME} STREQUAL "Auto")
@@ -187,11 +191,12 @@ elseif (${CUDA_VERSION} LESS 8.0) # CUDA 7.x
   list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
   list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
 elseif (${CUDA_VERSION} LESS 9.0) # CUDA 8.x
+  set(Caffe2_known_gpu_archs ${Caffe2_known_gpu_archs8})
   list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
   list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
   # CUDA 8 may complain that sm_20 is no longer supported. Suppress the
   # warning for now.
   list(APPEND CUDA_NVCC_FLAGS "-Wno-deprecated-gpu-targets")
 else() # CUDA 9.x or later version
   message(STATUS "The CUDA version is not offcially supported yet, cmake process will continue")
 endif()
 
 include_directories(SYSTEM ${CUDA_INCLUDE_DIRS})
Submodule third_party/cub updated: 89de7ab201...01347a797c