mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Port X86_F16 from executorch half to PyTorch half (#140720)
This was added in https://github.com/pytorch/executorch/pull/1789 . I'm working on sharing Half.h with ExecuTorch, and this is a missing feature. Differential Revision: [D65949409](https://our.internmc.facebook.com/intern/diff/D65949409/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/140720 Approved by: https://github.com/malfet ghstack dependencies: #140564, #140565, #140566, #140567
This commit is contained in:
committed by
PyTorch MergeBot
parent
43de32d948
commit
17bb78a3d3
@ -51,6 +51,16 @@
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
||||
#if defined(__GNUC__) || defined(__clang__)
|
||||
#if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || \
|
||||
defined(_M_IX86)
|
||||
#if defined(__F16C__)
|
||||
#define X86_F16 1
|
||||
#include <immintrin.h> // import conversion ops from f16cintrin.h
|
||||
#endif // defined(__F16C__)
|
||||
#endif // __x86_64__ || _M_X64 || __i386 || _M_IX86
|
||||
#endif // __GNUC__ || __clang__
|
||||
|
||||
namespace c10 {
|
||||
|
||||
namespace detail {
|
||||
@ -161,6 +171,9 @@ inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) {
|
||||
* between integer and floating-point variables.
|
||||
*/
|
||||
C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) {
|
||||
#ifdef X86_F16
|
||||
return _cvtsh_ss(h);
|
||||
#else
|
||||
/*
|
||||
* Extend the half-precision floating-point number to 32 bits and shift to the
|
||||
* upper part of the 32-bit word:
|
||||
@ -283,6 +296,7 @@ C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) {
|
||||
(two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value)
|
||||
: fp32_to_bits(normalized_value));
|
||||
return fp32_from_bits(result);
|
||||
#endif // X86_F16
|
||||
}
|
||||
|
||||
/*
|
||||
@ -295,6 +309,9 @@ C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) {
|
||||
* between integer and floating-point variables.
|
||||
*/
|
||||
inline uint16_t fp16_ieee_from_fp32_value(float f) {
|
||||
#ifdef X86_F16
|
||||
return _cvtss_sh(f, _MM_FROUND_TO_NEAREST_INT);
|
||||
#else
|
||||
// const float scale_to_inf = 0x1.0p+112f;
|
||||
// const float scale_to_zero = 0x1.0p-110f;
|
||||
constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23;
|
||||
@ -328,8 +345,13 @@ inline uint16_t fp16_ieee_from_fp32_value(float f) {
|
||||
return static_cast<uint16_t>(
|
||||
(sign >> 16) |
|
||||
(shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign));
|
||||
#endif // X86_F16
|
||||
}
|
||||
|
||||
#ifdef X86_F16
|
||||
#undef X86_F16
|
||||
#endif // X86_F16
|
||||
|
||||
#if defined(__aarch64__) && !defined(__CUDACC__)
|
||||
inline float16_t fp16_from_bits(uint16_t h) {
|
||||
return c10::bit_cast<float16_t>(h);
|
||||
|
Reference in New Issue
Block a user