From 1b45f6839789f3849fbde01c4fd88e5034172c44 Mon Sep 17 00:00:00 2001 From: Syed Tousif Ahmed Date: Wed, 26 Sep 2018 15:19:10 -0700 Subject: [PATCH] Use atomicAdd from cuda_fp16 header when building with CUDA 10 (#12108) Summary: An efficient atomicAdd for halfs has been added in `cuda_fp16.h` in CUDA 10: ```__CUDA_FP16_DECL__ __half atomicAdd(__half *address, __half val);``` Through this change, PyTorch will be able to utilize efficient atomicAdd when building with CUDA 10. Pull Request resolved: https://github.com/pytorch/pytorch/pull/12108 Differential Revision: D10053385 Pulled By: soumith fbshipit-source-id: 946c90691a8f6bdcf6d6e367a507ac3c9970b750 --- aten/src/THC/THCAtomics.cuh | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/aten/src/THC/THCAtomics.cuh b/aten/src/THC/THCAtomics.cuh index 756fa0f905ac..8752b0df458c 100644 --- a/aten/src/THC/THCAtomics.cuh +++ b/aten/src/THC/THCAtomics.cuh @@ -96,19 +96,24 @@ static inline __device__ void atomicAdd(int64_t *address, int64_t val) { } static inline __device__ void atomicAdd(at::Half *address, at::Half val) { - unsigned int * address_as_ui = - (unsigned int *) ((char *)address - ((size_t)address & 2)); - unsigned int old = *address_as_ui; - unsigned int assumed; + #if ((CUDA_VERSION < 10000) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700))) + unsigned int * address_as_ui = + (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do { + assumed = old; + at::Half hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + hsum = THCNumerics::add(hsum, val); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } while (assumed != old); + #else + atomicAdd(reinterpret_cast<__half*>(address), val); + #endif - do { - assumed = old; - at::Half hsum; - hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); - hsum = THCNumerics::add(hsum, val); - old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; - old = atomicCAS(address_as_ui, assumed, old); - } while (assumed != old); } #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)