mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
Fix compile error for nv_bloat162 (#7248)
some systems seem not to have the __nv_bfloat162 definition so a placeholder was introduced. newer CUDA libs have that definition, which breaks the compile process. this patch adds the official cuda_bf16.h guard while keeping the old code and a safety assert in case the definition should change in the future. see #7190 for reference --------- Signed-off-by: LosCrossos <165311345+loscrossos@users.noreply.github.com> Signed-off-by: LosCrossos <165311345+mytait@users.noreply.github.com> Co-authored-by: LosCrossos <165311345+mytait@users.noreply.github.com> Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
This commit is contained in:
@ -13,8 +13,14 @@ namespace cg = cooperative_groups;
|
||||
|
||||
// only used to avoid compilation error due to lack of definition.
|
||||
#ifndef BF16_AVAILABLE
|
||||
#if defined(__CUDA_BF16_H__)
|
||||
static_assert(sizeof(__nv_bfloat162) == sizeof(__half2),
|
||||
"CUDA's __nv_bfloat162 doesn't match __half2 size");
|
||||
#else
|
||||
// Fallback to simple typedef only if CUDA doesn't provide it
|
||||
using __nv_bfloat162 = __half2;
|
||||
#endif
|
||||
#endif
|
||||
|
||||
inline __device__ float gelu(const float x)
|
||||
{
|
||||
|
@ -12,8 +12,14 @@ namespace cg = cooperative_groups;
|
||||
|
||||
// only used to avoid compilation error due to lack of definition.
|
||||
#ifndef BF16_AVAILABLE
|
||||
#if defined(__CUDA_BF16_H__)
|
||||
static_assert(sizeof(__nv_bfloat162) == sizeof(__half2),
|
||||
"CUDA's __nv_bfloat162 doesn't match __half2 size");
|
||||
#else
|
||||
// Fallback to simple typedef only if CUDA doesn't provide it
|
||||
using __nv_bfloat162 = __half2;
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// Bias add
|
||||
|
||||
|
Reference in New Issue
Block a user