Port X86_F16 from executorch half to PyTorch half (#140720)

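The X86_F16 fast path (F16C intrinsics for half/float conversion on x86) was added to ExecuTorch's Half in https://github.com/pytorch/executorch/pull/1789. I'm working on sharing Half.h with ExecuTorch, and PyTorch's Half.h is missing this feature.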

Differential Revision: [D65949409](https://our.internmc.facebook.com/intern/diff/D65949409/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140720
Approved by: https://github.com/malfet
ghstack dependencies: #140564, #140565, #140566, #140567
This commit is contained in:
Scott Wolchok
2024-11-15 11:13:49 -08:00
committed by PyTorch MergeBot
parent 43de32d948
commit 17bb78a3d3

View File

@ -51,6 +51,16 @@
#include <arm_neon.h>
#endif
// Detect hardware fp16 <-> fp32 conversion support (F16C) on x86 targets built
// with GCC or Clang. When it is available, X86_F16 is defined and
// <immintrin.h> is included so that the scalar conversion helpers can use
// _cvtsh_ss / _cvtss_sh instead of the portable bit-manipulation fallback.
//
// The F16C intrinsics are host-only, while at least one consumer of X86_F16
// (fp16_ieee_to_fp32_value) is declared C10_HOST_DEVICE. The fast path is
// therefore disabled whenever a CUDA/HIP compiler is processing this header,
// so device-callable functions never reference a host-only intrinsic.
#if defined(__GNUC__) || defined(__clang__)
#if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || \
    defined(_M_IX86)
#if defined(__F16C__) &&                               \
    !(defined(__CUDA_ARCH__) || defined(__CUDACC__) || \
      defined(__HIP_DEVICE_COMPILE__))
#define X86_F16 1
#include <immintrin.h> // import conversion ops from f16cintrin.h
#endif // defined(__F16C__) && !(__CUDA_ARCH__ || __CUDACC__ ||
       // __HIP_DEVICE_COMPILE__)
#endif // __x86_64__ || _M_X64 || __i386 || _M_IX86
#endif // __GNUC__ || __clang__
namespace c10 {
namespace detail {
@ -161,6 +171,9 @@ inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) {
* between integer and floating-point variables.
*/
C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) {
#ifdef X86_F16
return _cvtsh_ss(h);
#else
/*
* Extend the half-precision floating-point number to 32 bits and shift to the
* upper part of the 32-bit word:
@ -283,6 +296,7 @@ C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) {
(two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value)
: fp32_to_bits(normalized_value));
return fp32_from_bits(result);
#endif // X86_F16
}
/*
@ -295,6 +309,9 @@ C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) {
* between integer and floating-point variables.
*/
inline uint16_t fp16_ieee_from_fp32_value(float f) {
#ifdef X86_F16
return _cvtss_sh(f, _MM_FROUND_TO_NEAREST_INT);
#else
// const float scale_to_inf = 0x1.0p+112f;
// const float scale_to_zero = 0x1.0p-110f;
constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23;
@ -328,8 +345,13 @@ inline uint16_t fp16_ieee_from_fp32_value(float f) {
return static_cast<uint16_t>(
(sign >> 16) |
(shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign));
#endif // X86_F16
}
// X86_F16 is an implementation detail of the conversion helpers above; drop it
// so it does not leak to code that includes this header. #undef on a name that
// is not currently defined is ignored, so no #ifdef guard is required.
#undef X86_F16
#if defined(__aarch64__) && !defined(__CUDACC__)
inline float16_t fp16_from_bits(uint16_t h) {
return c10::bit_cast<float16_t>(h);