diff --git a/c10/util/Half.h b/c10/util/Half.h
index c7c17485ba8d..2bfbc168c66f 100644
--- a/c10/util/Half.h
+++ b/c10/util/Half.h
@@ -51,6 +51,22 @@
 #include
 #endif
 
+// Use the x86 F16C hardware conversion instructions when the host compiler
+// has them enabled (__F16C__). The intrinsics are host-only, so the fast path
+// is switched off whenever CUDA/HIP code may be compiled: one of the helpers
+// below that uses it is marked C10_HOST_DEVICE.
+#if defined(__GNUC__) || defined(__clang__)
+#if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || \
+    defined(_M_IX86)
+#if defined(__F16C__) &&                               \
+    !(defined(__CUDA_ARCH__) || defined(__CUDACC__) || \
+      defined(__HIP_DEVICE_COMPILE__))
+#define C10_X86_F16 1
+#include <immintrin.h> // import conversion ops from f16cintrin.h
+#endif // __F16C__ && !(CUDA / HIP compilation)
+#endif // __x86_64__ || _M_X64 || __i386 || _M_IX86
+#endif // __GNUC__ || __clang__
+
 namespace c10 {
 
 namespace detail {
@@ -161,6 +177,10 @@ inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) {
  * between integer and floating-point variables.
  */
 C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) {
+#ifdef C10_X86_F16
+  // F16C hardware conversion; the portable bit-manipulation path is below.
+  return _cvtsh_ss(h);
+#else
   /*
    * Extend the half-precision floating-point number to 32 bits and shift to the
    * upper part of the 32-bit word:
@@ -283,6 +303,7 @@ C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) {
       (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value)
                                    : fp32_to_bits(normalized_value));
   return fp32_from_bits(result);
+#endif // C10_X86_F16
 }
 
 /*
@@ -295,6 +316,10 @@ C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) {
  * between integer and floating-point variables.
  */
 inline uint16_t fp16_ieee_from_fp32_value(float f) {
+#ifdef C10_X86_F16
+  // F16C hardware conversion, rounding to nearest.
+  return _cvtss_sh(f, _MM_FROUND_TO_NEAREST_INT);
+#else
   // const float scale_to_inf = 0x1.0p+112f;
   // const float scale_to_zero = 0x1.0p-110f;
   constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23;
@@ -328,8 +353,14 @@ inline uint16_t fp16_ieee_from_fp32_value(float f) {
   return static_cast<uint16_t>(
       (sign >> 16) |
       (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign));
+#endif // C10_X86_F16
 }
 
+// The helper macro is only needed by the two conversions above.
+#ifdef C10_X86_F16
+#undef C10_X86_F16
+#endif // C10_X86_F16
+
 #if defined(__aarch64__) && !defined(__CUDACC__)
 inline float16_t fp16_from_bits(uint16_t h) {
   return c10::bit_cast<float16_t>(h);