Use atomicAdd from cuda_fp16 header when building with CUDA 10 (#12108)

Summary:
An efficient atomicAdd for half-precision values was added to `cuda_fp16.h` in CUDA 10:
```cpp
__CUDA_FP16_DECL__ __half atomicAdd(__half *address, __half val);
```

With this change, PyTorch uses this efficient atomicAdd when built with CUDA 10 and targeting compute capability 7.0 or higher, and keeps the existing atomicCAS-based emulation otherwise.
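
For illustration only (not part of this PR), here is a minimal sketch of calling the new overload directly. It assumes a CUDA 10 toolkit and a Volta-or-newer GPU (compile with e.g. `nvcc -arch=sm_70 half_atomic_demo.cu`); the file, kernel, and variable names are made up:

```cuda
// half_atomic_demo.cu -- hypothetical demo, not PyTorch code.
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>

// Every thread adds one element into a single half-precision accumulator
// using the atomicAdd(__half*, __half) overload from cuda_fp16.h (CUDA 10, sm_70+).
__global__ void accumulate(__half *acc, const __half *data, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    atomicAdd(acc, data[i]);
  }
}

int main() {
  const int n = 1024;
  __half *acc = nullptr, *data = nullptr;
  cudaMallocManaged(&acc, sizeof(__half));
  cudaMallocManaged(&data, n * sizeof(__half));

  *acc = __float2half(0.f);
  for (int i = 0; i < n; ++i) data[i] = __float2half(1.f);

  accumulate<<<(n + 255) / 256, 256>>>(acc, data, n);
  cudaDeviceSynchronize();

  printf("sum = %f\n", __half2float(*acc));  // expect 1024.0 (exact in fp16)
  cudaFree(data);
  cudaFree(acc);
  return 0;
}
```
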
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12108

Differential Revision: D10053385

Pulled By: soumith

fbshipit-source-id: 946c90691a8f6bdcf6d6e367a507ac3c9970b750
Commit: 1b45f68397 (parent: 6ff568df4d)
Author: Syed Tousif Ahmed
Date: 2018-09-26 15:19:10 -07:00
Committed by: Facebook Github Bot

```diff
@@ -96,19 +96,24 @@ static inline __device__ void atomicAdd(int64_t *address, int64_t val) {
 }
 
 static inline __device__ void atomicAdd(at::Half *address, at::Half val) {
-  unsigned int * address_as_ui =
-    (unsigned int *) ((char *)address - ((size_t)address & 2));
-  unsigned int old = *address_as_ui;
-  unsigned int assumed;
+#if ((CUDA_VERSION < 10000) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
+  unsigned int * address_as_ui =
+    (unsigned int *) ((char *)address - ((size_t)address & 2));
+  unsigned int old = *address_as_ui;
+  unsigned int assumed;
 
-  do {
-    assumed = old;
-    at::Half hsum;
-    hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
-    hsum = THCNumerics<at::Half>::add(hsum, val);
-    old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
-    old = atomicCAS(address_as_ui, assumed, old);
-  } while (assumed != old);
+  do {
+    assumed = old;
+    at::Half hsum;
+    hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
+    hsum = THCNumerics<at::Half>::add(hsum, val);
+    old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
+    old = atomicCAS(address_as_ui, assumed, old);
+  } while (assumed != old);
+#else
+  atomicAdd(reinterpret_cast<__half*>(address), val);
+#endif
 }
 
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
```
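
The `__CUDA_ARCH__ < 700` half of the guard reflects that the hardware half-precision atomic add requires compute capability 7.0 (Volta) or newer, so older GPUs keep the atomicCAS-based emulation even when building with CUDA 10. For readers outside the PyTorch tree, here is a self-contained sketch of that same emulation technique using plain `__half`/`__half_raw` in place of `at::Half` and `THCNumerics`; the function names and the float round-trip used for the addition are illustrative choices, not the PR's code:

```cuda
#include <cuda_fp16.h>

// Hypothetical helper mirroring the guarded fallback above: a 16-bit half is
// updated by looping atomicCAS on the aligned 32-bit word that contains it.
static __device__ void atomicAddHalfEmulated(__half *address, __half val) {
  // Round the pointer down to the enclosing 4-byte word.
  unsigned int *word_ptr =
      (unsigned int *)((char *)address - ((size_t)address & 2));
  unsigned int old = *word_ptr;
  unsigned int assumed;

  do {
    assumed = old;
    // Extract the 16 bits that belong to *address (upper or lower half-word).
    __half_raw h;
    h.x = ((size_t)address & 2) ? (unsigned short)(old >> 16)
                                : (unsigned short)(old & 0xffff);
    // Do the add in float so the sketch does not depend on __hadd (sm_53+ only).
    __half sum = __float2half(__half2float(__half(h)) + __half2float(val));
    __half_raw hsum = sum;  // view the result as raw 16-bit storage
    // Splice the updated 16 bits back into the 32-bit word.
    unsigned int updated = ((size_t)address & 2)
        ? ((old & 0x0000ffffu) | ((unsigned int)hsum.x << 16))
        : ((old & 0xffff0000u) | hsum.x);
    // Publish only if no other thread changed the word since we read it.
    old = atomicCAS(word_ptr, assumed, updated);
  } while (assumed != old);
}

// Example use: every thread accumulates into one half-precision cell.
__global__ void accumulateEmulated(__half *acc, const __half *data, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) atomicAddHalfEmulated(acc, data[i]);
}
```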