Fix compile error for __nv_bfloat162 (#7248)

Some systems do not seem to provide the __nv_bfloat162 definition, so a
placeholder alias was introduced. Newer CUDA libraries do provide that
definition, and the placeholder then breaks compilation. This patch checks
the official cuda_bf16.h include guard (__CUDA_BF16_H__) while keeping the old
fallback, and adds a static_assert as a safety check in case the
definition changes in the future. See #7190 for reference.

---------

Signed-off-by: LosCrossos <165311345+loscrossos@users.noreply.github.com>
Signed-off-by: LosCrossos <165311345+mytait@users.noreply.github.com>
Co-authored-by: LosCrossos <165311345+mytait@users.noreply.github.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
This commit is contained in:
LosCrossos
2025-04-27 07:16:34 +02:00
committed by GitHub
parent fff77bd293
commit ee492c30a7
2 changed files with 12 additions and 0 deletions

View File

@ -13,8 +13,14 @@ namespace cg = cooperative_groups;
// only used to avoid compilation error due to lack of definition. // only used to avoid compilation error due to lack of definition.
#ifndef BF16_AVAILABLE #ifndef BF16_AVAILABLE
#if defined(__CUDA_BF16_H__)
static_assert(sizeof(__nv_bfloat162) == sizeof(__half2),
"CUDA's __nv_bfloat162 doesn't match __half2 size");
#else
// Fallback to simple typedef only if CUDA doesn't provide it
using __nv_bfloat162 = __half2; using __nv_bfloat162 = __half2;
#endif #endif
#endif
inline __device__ float gelu(const float x) inline __device__ float gelu(const float x)
{ {

View File

@ -12,8 +12,14 @@ namespace cg = cooperative_groups;
// only used to avoid compilation error due to lack of definition. // only used to avoid compilation error due to lack of definition.
#ifndef BF16_AVAILABLE #ifndef BF16_AVAILABLE
#if defined(__CUDA_BF16_H__)
static_assert(sizeof(__nv_bfloat162) == sizeof(__half2),
"CUDA's __nv_bfloat162 doesn't match __half2 size");
#else
// Fallback to simple typedef only if CUDA doesn't provide it
using __nv_bfloat162 = __half2; using __nv_bfloat162 = __half2;
#endif #endif
#endif
// Bias add // Bias add