mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
use fp16<->fp32 intrinsic (#17496)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17496 As title. Reviewed By: hyuen Differential Revision: D14222907 fbshipit-source-id: d5d6c032e725ca8b52aca2be7401ec3c59f6a242
This commit is contained in:
committed by
Facebook Github Bot
parent
f8778aef78
commit
1d522598fb
@ -104,9 +104,12 @@ void adagrad_fp16_update_prefetch__avx_f16c(
|
|||||||
|
|
||||||
for (; i < N; ++i) {
|
for (; i < N; ++i) {
|
||||||
float gi = g[i];
|
float gi = g[i];
|
||||||
float hi = h[i] + gi * gi;
|
float nhi =
|
||||||
nh[i] = hi;
|
_cvtsh_ss(reinterpret_cast<const unsigned short*>(h)[i]) + gi * gi;
|
||||||
nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon);
|
reinterpret_cast<unsigned short*>(nh)[i] = _cvtss_sh(nhi, 0);
|
||||||
|
float nwi = _cvtsh_ss(reinterpret_cast<const unsigned short*>(w)[i]) +
|
||||||
|
lr * gi / (std::sqrt(nhi) + epsilon);
|
||||||
|
reinterpret_cast<unsigned short*>(nw)[i] = _cvtss_sh(nwi, 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
|
|
||||||
#if __APPLE_NEED_FIX || __CLANG_NEED_FIX
|
#if __APPLE_NEED_FIX || __CLANG_NEED_FIX
|
||||||
|
|
||||||
|
#include <c10/util/Half.h>
|
||||||
#include <emmintrin.h>
|
#include <emmintrin.h>
|
||||||
|
|
||||||
// This version of clang has a bug that _cvtsh_ss is not defined, see
|
// This version of clang has a bug that _cvtsh_ss is not defined, see
|
||||||
@ -25,6 +26,14 @@ _cvtsh_ss(unsigned short a)
|
|||||||
return r[0];
|
return r[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static __inline unsigned short
|
||||||
|
__attribute__((__always_inline__, __nodebug__, __target__("f16c")))
|
||||||
|
_cvtss_sh(float a, int imm8) {
|
||||||
|
unsigned short ret;
|
||||||
|
*reinterpret_cast<at::Half*>(&ret) = a;
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
#endif // __APPLE_NEED_FIX || __CLANG_NEED_FIX
|
#endif // __APPLE_NEED_FIX || __CLANG_NEED_FIX
|
||||||
|
|
||||||
#undef __APPLE_NEED_FIX
|
#undef __APPLE_NEED_FIX
|
||||||
@ -32,6 +41,7 @@ _cvtsh_ss(unsigned short a)
|
|||||||
|
|
||||||
#ifdef _MSC_VER
|
#ifdef _MSC_VER
|
||||||
|
|
||||||
|
#include <c10/util/Half.h>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
// It seems that microsoft msvc does not have a _cvtsh_ss implementation so
|
// It seems that microsoft msvc does not have a _cvtsh_ss implementation so
|
||||||
@ -54,4 +64,10 @@ static inline float _cvtsh_ss(unsigned short x) {
|
|||||||
return t1.floatval;
|
return t1.floatval;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static inline unsigned short _cvtss_sh(float x, int imm8) {
|
||||||
|
unsigned short ret;
|
||||||
|
*reinterpret_cast<at::Half*>(&ret) = x;
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
#endif // _MSC_VER
|
#endif // _MSC_VER
|
||||||
|
Reference in New Issue
Block a user