pytorch/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h
#pragma once
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>
#if defined(CPU_CAPABILITY_AVX512)
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#endif
namespace at {
namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_AVX512)
#ifndef SLEEF_CONST
#if (defined(__GNUC__) || defined(__CLANG__)) && !defined(__INTEL_COMPILER)
#define SLEEF_CONST const
#else
#define SLEEF_CONST
#endif
#define SLEEF_CONST_OLD SLEEF_CONST
#else
#define SLEEF_CONST_OLD
#endif
// bfloat16 conversion
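// bf16 is the upper 16 bits of an IEEE-754 fp32 value, so widening is a
// zero-extend followed by a 16-bit left shift. Narrowing below uses
// round-to-nearest-even via the lsb/bias trick (scalar reference):
//   uint32_t lsb = (bits >> 16) & 1;
//   bits += 0x7fff + lsb;            // rounding bias
//   bf16 = uint16_t(bits >> 16);     // NaN inputs are forced to 0xffff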
static inline void cvtbf16_fp32(const __m256i& a, __m512& o) {
o = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16));
}
static inline void cvtbf16_fp32(const __m512i& a, __m512& o1, __m512& o2) {
__m256i lo = _mm512_extracti32x8_epi32(a, 0);
__m256i hi = _mm512_extracti32x8_epi32(a, 1);
cvtbf16_fp32(lo, o1);
cvtbf16_fp32(hi, o2);
}
static inline __m256i cvtfp32_bf16(const __m512& src) {
__m512i value = _mm512_castps_si512(src);
__m512i nan = _mm512_set1_epi32(0xffff);
auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q);
__m512i ones = _mm512_set1_epi32(0x1);
__m512i vec_bias = _mm512_set1_epi32(0x7fff);
// uint32_t lsb = (input >> 16) & 1;
auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones);
// uint32_t rounding_bias = 0x7fff + lsb;
t_value = _mm512_add_epi32(t_value, vec_bias);
// input += rounding_bias;
t_value = _mm512_add_epi32(t_value, value);
// input = input >> 16;
t_value = _mm512_srli_epi32(t_value, 16);
// Check NaN before converting back to bf16
t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value);
return _mm512_cvtusepi32_epi16(t_value);
}
static inline __m512i cvtfp32_bf16(const __m512& a, const __m512& b) {
__m512i lo = _mm512_castps_si512(a);
__m512i hi = _mm512_castps_si512(b);
__m512i nan = _mm512_set1_epi32(0xffff);
auto mask_lo = _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q);
auto mask_hi = _mm512_cmp_ps_mask(b, b, _CMP_ORD_Q);
__m512i ones = _mm512_set1_epi32(0x1);
__m512i vec_bias = _mm512_set1_epi32(0x7fff);
// uint32_t lsb = (input >> 16) & 1;
auto t_lo = _mm512_and_si512(_mm512_srli_epi32(lo, 16), ones);
auto t_hi = _mm512_and_si512(_mm512_srli_epi32(hi, 16), ones);
// uint32_t rounding_bias = 0x7fff + lsb;
t_lo = _mm512_add_epi32(t_lo, vec_bias);
t_hi = _mm512_add_epi32(t_hi, vec_bias);
// input += rounding_bias;
t_lo = _mm512_add_epi32(t_lo, lo);
t_hi = _mm512_add_epi32(t_hi, hi);
// input = input >> 16;
t_lo = _mm512_srli_epi32(t_lo, 16);
t_hi = _mm512_srli_epi32(t_hi, 16);
// Check NaN before converting back to bf16
t_lo = _mm512_mask_blend_epi32(mask_lo, nan, t_lo);
t_hi = _mm512_mask_blend_epi32(mask_hi, nan, t_hi);
t_lo = _mm512_packus_epi32(t_lo, t_hi); // per 128-bit lane: t_lo[0-3] t_hi[0-3] | t_lo[4-7] t_hi[4-7] | ...
__m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
return _mm512_permutexvar_epi64(idx, t_lo);
}
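// For comparison results every fp32 lane is either all zeros or all ones,
// so the bf16 result can be produced by plain truncation; this also skips
// the rounding path above.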
static inline __m512i merge_compare_result(const __m512& a, const __m512& b) {
__m512i lo = _mm512_castps_si512(a);
__m512i hi = _mm512_castps_si512(b);
lo = _mm512_srli_epi32(lo, 16);
hi = _mm512_srli_epi32(hi, 16);
auto out = _mm512_packus_epi32(lo, hi);
__m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
return _mm512_permutexvar_epi64(idx, out);
}
// float16 conversion
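// fp16 widening/narrowing maps directly onto the vcvtph2ps/vcvtps2ph
// instructions, with explicit round-to-nearest on the narrowing side.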
static inline void cvtfp16_fp32(const __m256i& a, __m512& o) {
o = _mm512_cvtph_ps(a);
}
static inline void cvtfp16_fp32(const __m512i& a, __m512& o1, __m512& o2) {
__m256i lo = _mm512_extracti32x8_epi32(a, 0);
__m256i hi = _mm512_extracti32x8_epi32(a, 1);
cvtfp16_fp32(lo, o1);
cvtfp16_fp32(hi, o2);
}
static inline __m256i cvtfp32_fp16(const __m512& src) {
return _mm512_cvtps_ph(
src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}
static inline __m512i cvtfp32_fp16(const __m512& a, const __m512& b) {
__m256i lo = _mm512_cvtps_ph(
a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
__m256i hi = _mm512_cvtps_ph(
b, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
__m512 t_lo = _mm512_castsi512_ps(_mm512_castsi256_si512(lo));
__m256 t_hi = _mm256_castsi256_ps(hi);
return _mm512_castps_si512(_mm512_insertf32x8(t_lo, t_hi, 1));
}
// dtype conversion between float16/bfloat16 and float32
template <typename T, typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline void cvt_to_fp32(const __m256i& a, __m512& o);
template <> inline void cvt_to_fp32<BFloat16>(const __m256i& a, __m512& o) {
cvtbf16_fp32(a, o);
}
template <> inline void cvt_to_fp32<Half>(const __m256i& a, __m512& o) {
cvtfp16_fp32(a, o);
}
template <typename T, typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline void cvt_to_fp32(const __m512i& a, __m512& o1, __m512& o2);
template <> inline void cvt_to_fp32<BFloat16>(const __m512i& a, __m512& o1, __m512& o2) {
cvtbf16_fp32(a, o1, o2);
}
template <> inline void cvt_to_fp32<Half>(const __m512i& a, __m512& o1, __m512& o2) {
cvtfp16_fp32(a, o1, o2);
}
template <typename T, bool is_compare_op = false,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline __m512i cvt_from_fp32(const __m512& a, const __m512& b);
template <> inline __m512i cvt_from_fp32<BFloat16, false>(const __m512& a, const __m512& b) {
return cvtfp32_bf16(a, b);
}
template <> inline __m512i cvt_from_fp32<BFloat16, true>(const __m512& a, const __m512& b) {
return merge_compare_result(a, b);
}
template <> inline __m512i cvt_from_fp32<Half, false>(const __m512& a, const __m512& b) {
return cvtfp32_fp16(a, b);
}
template <> inline __m512i cvt_from_fp32<Half, true>(const __m512& a, const __m512& b) {
return cvtfp32_fp16(a, b);
}
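// Vectorized16<T> stores 32 reduced-precision (bf16/fp16) elements in one
// __m512i. Most math ops widen to two __m512 fp32 vectors via cvt_to_fp32,
// compute in fp32, and narrow back with cvt_from_fp32.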
template <typename T>
class Vectorized16 {
static_assert(
is_reduced_floating_point_v<T>,
"Support only float16 and bfloat16.");
private:
__m512i values;
public:
using value_type = uint16_t;
using size_type = int;
static constexpr size_type size() {
return 32;
}
Vectorized16() {}
Vectorized16(__m512i v) : values(v) {}
Vectorized16(T val) {
value_type uw = val.x;
values = _mm512_set1_epi16(uw);
}
Vectorized16(T val1, T val2, T val3, T val4,
T val5, T val6, T val7, T val8,
T val9, T val10, T val11, T val12,
T val13, T val14, T val15, T val16,
T val17, T val18, T val19, T val20,
T val21, T val22, T val23, T val24,
T val25, T val26, T val27, T val28,
T val29, T val30, T val31, T val32) {
values = _mm512_set_epi16(
val32.x, val31.x, val30.x, val29.x, val28.x, val27.x, val26.x, val25.x,
val24.x, val23.x, val22.x, val21.x, val20.x, val19.x, val18.x, val17.x,
val16.x, val15.x, val14.x, val13.x, val12.x, val11.x, val10.x, val9.x,
val8.x, val7.x, val6.x, val5.x, val4.x, val3.x, val2.x, val1.x);
}
operator __m512i() const {
return values;
}
T& operator[](int idx) = delete;
const T& operator[](int idx) const = delete;
int zero_mask() const {
// returns an integer mask where each zero element maps to a 1 bit and each non-zero element maps to a 0 bit
return _mm512_cmpeq_epi16_mask(values, _mm512_set1_epi16(0));
}
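// Partial loads/stores build a __mmask32 with the low `count` bits set so
// only the first `count` 16-bit lanes are read or written; the remaining
// lanes are zeroed (load) or left untouched (store).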
static Vectorized<T> loadu(const void* ptr, int16_t count = size()) {
if (count == size())
return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(ptr));
__mmask32 mask = (1ULL << count) - 1;
return _mm512_maskz_loadu_epi16(mask, ptr);
}
void store(void* ptr, int count = size()) const {
if (count == size()) {
_mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values);
} else if (count > 0) {
__mmask32 mask = (1ULL << count) - 1;
_mm512_mask_storeu_epi16(ptr, mask, values);
}
}
template <int64_t mask>
static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
return _mm512_mask_blend_epi16(mask, a.values, b.values);
}
static Vectorized<T> blendv(const Vectorized<T>& a,
const Vectorized<T>& b, const Vectorized<T>& mask) {
auto all_ones = _mm512_set1_epi16(0xFFFF);
auto mask_ = _mm512_cmp_epi16_mask(mask, all_ones, _MM_CMPINT_EQ);
return _mm512_mask_blend_epi16(mask_, a.values, b.values);
}
template<typename step_t>
static Vectorized<T> arange(T base = 0.f, step_t step = static_cast<step_t>(1)) {
return Vectorized<T>(
base, base + step, base + 2 * step, base + 3 * step,
base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step,
base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step,
base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step,
base + 16 * step, base + 17 * step, base + 18 * step, base + 19 * step,
base + 20 * step, base + 21 * step, base + 22 * step, base + 23 * step,
base + 24 * step, base + 25 * step, base + 26 * step, base + 27 * step,
base + 28 * step, base + 29 * step, base + 30 * step, base + 31 * step);
}
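// blend() takes the mask as a template (immediate) argument, so set()
// dispatches the runtime `count` through a switch to the matching
// compile-time mask: bit i selects lane i from b, otherwise from a.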
static Vectorized<T> set(const Vectorized<T>& a,
const Vectorized<T>& b, int64_t count = size()) {
switch (count) {
case 0:
return a;
case 1:
return blend<1>(a, b);
case 2:
return blend<3>(a, b);
case 3:
return blend<7>(a, b);
case 4:
return blend<15>(a, b);
case 5:
return blend<31>(a, b);
case 6:
return blend<63>(a, b);
case 7:
return blend<127>(a, b);
case 8:
return blend<255>(a, b);
case 9:
return blend<511>(a, b);
case 10:
return blend<1023>(a, b);
case 11:
return blend<2047>(a, b);
case 12:
return blend<4095>(a, b);
case 13:
return blend<8191>(a, b);
case 14:
return blend<16383>(a, b);
case 15:
return blend<32767>(a, b);
case 16:
return blend<65535>(a, b);
case 17:
return blend<131071>(a, b);
case 18:
return blend<262143>(a, b);
case 19:
return blend<524287>(a, b);
case 20:
return blend<1048575>(a, b);
case 21:
return blend<2097151>(a, b);
case 22:
return blend<4194303>(a, b);
case 23:
return blend<8388607>(a, b);
case 24:
return blend<16777215>(a, b);
case 25:
return blend<33554431>(a, b);
case 26:
return blend<67108863>(a, b);
case 27:
return blend<134217727>(a, b);
case 28:
return blend<268435455>(a, b);
case 29:
return blend<536870911>(a, b);
case 30:
return blend<1073741823>(a, b);
case 31:
return blend<2147483647>(a, b);
}
return b;
}
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wignored-qualifiers"
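// map() lifts a Sleef fp32 vector routine (e.g. Sleef_expf16_u10) to the
// reduced-precision type: widen to two fp32 halves, apply the routine to
// each, then narrow back. The SLEEF_CONST/SLEEF_CONST_OLD qualifiers keep
// the function-pointer signature compatible across Sleef versions and
// compilers (hence the -Wignored-qualifiers suppression above).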
Vectorized<T> map(SLEEF_CONST __m512 (*SLEEF_CONST_OLD vop)(__m512)) const {
__m512 lo, hi;
cvt_to_fp32<T>(values, lo, hi);
const auto o1 = vop(lo);
const auto o2 = vop(hi);
return cvt_from_fp32<T>(o1, o2);
}
Vectorized<T> isnan() const {
__m512 lo, hi;
cvt_to_fp32<T>(values, lo, hi);
__mmask16 lo_mask, hi_mask;
__m512 zero = _mm512_set1_ps(0.0);
__m512i zeroi = _mm512_castps_si512(zero);
lo_mask = _mm512_cmp_ps_mask(lo, zero, _CMP_UNORD_Q);
lo = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zeroi, lo_mask, 0xFFFF'FFFF));
hi_mask = _mm512_cmp_ps_mask(hi, zero, _CMP_UNORD_Q);
hi = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zeroi, hi_mask, 0xFFFF'FFFF));
return merge_compare_result(lo, hi);
}
#pragma clang diagnostic pop
Vectorized<T> abs() const {
return _mm512_andnot_si512(_mm512_set1_epi16(0x8000), values);
}
Vectorized<T> angle() const {
__m512 lo, hi;
cvt_to_fp32<T>(values, lo, hi);
auto angle_lambda = [](__m512 values) {
const auto zero_vec = _mm512_set1_ps(0.f);
const auto nan_vec = _mm512_set1_ps(NAN);
const auto not_nan_mask = _mm512_cmp_ps_mask(values, values, _CMP_EQ_OQ);
const auto non_nan_mask_vec = _mm512_mask_set1_epi32(_mm512_castps_si512(zero_vec),
not_nan_mask, 0xFFFFFFFF);
const auto nan_mask = _mm512_cmp_ps_mask(_mm512_castsi512_ps(non_nan_mask_vec),
zero_vec, _CMP_EQ_OQ);
const auto pi = _mm512_set1_ps(c10::pi<float>);
const auto neg_mask = _mm512_cmp_ps_mask(values, zero_vec, _CMP_LT_OQ);
auto angle = _mm512_mask_blend_ps(neg_mask, zero_vec, pi);
angle = _mm512_mask_blend_ps(nan_mask, angle, nan_vec);
return angle;
};
auto o1 = angle_lambda(lo);
auto o2 = angle_lambda(hi);
return cvt_from_fp32<T>(o1, o2);
}
Vectorized<T> real() const {
return *this;
}
Vectorized<T> imag() const {
return _mm512_set1_epi16(0);
}
Vectorized<T> conj() const {
return *this;
}
Vectorized<T> acos() const {
return map(Sleef_acosf16_u10);
}
Vectorized<T> acosh() const {
return map(Sleef_acoshf16_u10);
}
Vectorized<T> asin() const {
return map(Sleef_asinf16_u10);
}
Vectorized<T> atan() const {
return map(Sleef_atanf16_u10);
}
Vectorized<T> atanh() const {
return map(Sleef_atanhf16_u10);
}
Vectorized<T> atan2(const Vectorized<T> &b) const {
__m512 lo, hi;
__m512 b1, b2;
cvt_to_fp32<T>(values, lo, hi);
cvt_to_fp32<T>(b.values, b1, b2);
auto o1 = Sleef_atan2f16_u10(lo, b1);
auto o2 = Sleef_atan2f16_u10(hi, b2);
return cvt_from_fp32<T>(o1, o2);
}
Vectorized<T> copysign(const Vectorized<T> &sign) const {
// copy sign bit (0x8000) from sign and remaining bits from values
__m512i mask_value = _mm512_set1_epi32(~0x80008000);
__m512i mask_signbit = _mm512_set1_epi32(0x80008000);
return Vectorized<T>(
_mm512_or_si512(
_mm512_and_si512(values, mask_value),
_mm512_and_si512(sign, mask_signbit)));
}
Vectorized<T> erf() const {
return map(Sleef_erff16_u10);
}
Vectorized<T> erfc() const {
return map(Sleef_erfcf16_u15);
}
Vectorized<T> erfinv() const {
__m512 lo, hi;
cvt_to_fp32<T>(values, lo, hi);
__at_align__ float tmp1[size() / 2], tmp2[size() / 2];
_mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
_mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
for (int64_t i = 0; i < size() / 2; i++) {
tmp1[i] = calc_erfinv(tmp1[i]);
tmp2[i] = calc_erfinv(tmp2[i]);
}
auto o1 = _mm512_loadu_ps(tmp1);
auto o2 = _mm512_loadu_ps(tmp2);
return cvt_from_fp32<T>(o1, o2);
}
Vectorized<T> exp() const {
return map(Sleef_expf16_u10);
}
Vectorized<T> exp2() const {
return map(Sleef_exp2f16_u10);
}
Vectorized<T> expm1() const {
return map(Sleef_expm1f16_u10);
}
Vectorized<T> exp_u20() const {
return exp();
}
Vectorized<T> fmod(const Vectorized<T> & q) const {
__m512 x_lo, x_hi;
cvt_to_fp32<T>(values, x_lo, x_hi);
__m512 q_lo, q_hi;
cvt_to_fp32<T>(q.values, q_lo, q_hi); // dtype-aware conversion so Half operands are widened correctly
auto o1 = Sleef_fmodf16(x_lo, q_lo);
auto o2 = Sleef_fmodf16(x_hi, q_hi);
return cvt_from_fp32<T>(o1, o2);
}
Vectorized<T> hypot(const Vectorized<T> &b) const {
__m512 lo, hi;
__m512 b1, b2;
cvt_to_fp32<T>(values, lo, hi);
cvt_to_fp32<T>(b.values, b1, b2);
auto o1 = Sleef_hypotf16_u05(lo, b1);
auto o2 = Sleef_hypotf16_u05(hi, b2);
return cvt_from_fp32<T>(o1, o2);
}
Vectorized<T> i0() const {
__m512 lo, hi;
cvt_to_fp32<T>(values, lo, hi);
__at_align__ float tmp1[size() / 2], tmp2[size() / 2];
_mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
_mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
for (int64_t i = 0; i < size() / 2; i++) {
tmp1[i] = calc_i0(tmp1[i]);
tmp2[i] = calc_i0(tmp2[i]);
}
auto o1 = _mm512_loadu_ps(tmp1);
auto o2 = _mm512_loadu_ps(tmp2);
return cvt_from_fp32<T>(o1, o2);
}
Vectorized<T> i0e() const {
__m512 lo, hi;
cvt_to_fp32<T>(values, lo, hi);
constexpr auto sz = size();
__at_align__ float tmp1[sz / 2], tmp2[sz / 2];
_mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
_mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
for (auto i = decltype(sz){0}; i < sz / 2; i++) {
tmp1[i] = calc_i0e(tmp1[i]);
tmp2[i] = calc_i0e(tmp2[i]);
}
const auto o1 = _mm512_loadu_ps(tmp1);
const auto o2 = _mm512_loadu_ps(tmp2);
return cvt_from_fp32<T>(o1, o2);
}
Vectorized<T> digamma() const {
__m512 lo, hi;
cvt_to_fp32<T>(values, lo, hi);
constexpr auto sz = size();
__at_align__ float tmp1[sz / 2], tmp2[sz / 2];
_mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
_mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
for (auto i = decltype(sz){0}; i < sz / 2; i++) {
tmp1[i] = calc_digamma(tmp1[i]);
tmp2[i] = calc_digamma(tmp2[i]);
}
const auto o1 = _mm512_loadu_ps(tmp1);
const auto o2 = _mm512_loadu_ps(tmp2);
return cvt_from_fp32<T>(o1, o2);
}
Vectorized<T> igamma(const Vectorized<T> &x) const {
__m512 lo, hi;
__m512 xlo, xhi;
cvt_to_fp32<T>(values, lo, hi);
cvt_to_fp32<T>(x.values, xlo, xhi);
__at_align__ float tmp1[size() / 2], tmp2[size() / 2];
_mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
_mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
__at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
_mm512_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo);
_mm512_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi);
for (int64_t i = 0; i < size() / 2; ++i) {
tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]);
tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]);
}
auto o1 = _mm512_loadu_ps(tmp1);
auto o2 = _mm512_loadu_ps(tmp2);
return cvt_from_fp32<T>(o1, o2);
}
Vectorized<T> igammac(const Vectorized<T> &x) const {
__m512 lo, hi;
__m512 xlo, xhi;
cvt_to_fp32<T>(values, lo, hi);
cvt_to_fp32<T>(x.values, xlo, xhi);
__at_align__ float tmp1[size() / 2], tmp2[size() / 2];
_mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
_mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
__at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
_mm512_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo);
_mm512_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi);
for (int64_t i = 0; i < size() / 2; ++i) {
tmp1[i] = calc_igammac(tmp1[i], tmpx1[i]);
tmp2[i] = calc_igammac(tmp2[i], tmpx2[i]);
}
auto o1 = _mm512_loadu_ps(tmp1);
auto o2 = _mm512_loadu_ps(tmp2);
return cvt_from_fp32<T>(o1, o2);
}
Vectorized<T> log() const {
return map(Sleef_logf16_u10);
}
Vectorized<T> log2() const {
return map(Sleef_log2f16_u10);
}
Vectorized<T> log10() const {
return map(Sleef_log10f16_u10);
}
Vectorized<T> log1p() const {
return map(Sleef_log1pf16_u10);
}
Vectorized<T> sin() const {
return map(Sleef_sinf16_u10);
}
Vectorized<T> sinh() const {
return map(Sleef_sinhf16_u10);
}
Vectorized<T> cos() const {
return map(Sleef_cosf16_u10);
}
Vectorized<T> cosh() const {
return map(Sleef_coshf16_u10);
}
Vectorized<T> ceil() const {
__m512 lo, hi;
cvt_to_fp32<T>(values, lo, hi);
auto o1 = _mm512_ceil_ps(lo);
auto o2 = _mm512_ceil_ps(hi);
return cvt_from_fp32<T>(o1, o2);
}
Vectorized<T> floor() const {
__m512 lo, hi;
cvt_to_fp32<T>(values, lo, hi);
auto o1 = _mm512_floor_ps(lo);
auto o2 = _mm512_floor_ps(hi);
return cvt_from_fp32<T>(o1, o2);
}
Vectorized<T> neg() const {
return _mm512_xor_si512(values, _mm512_set1_epi16(0x8000));
}
Vectorized<T> round() const {
__m512 lo, hi;
cvt_to_fp32<T>(values, lo, hi);
auto o1 = _mm512_roundscale_ps(lo, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
auto o2 = _mm512_roundscale_ps(hi, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
return cvt_from_fp32<T>(o1, o2);
}
Vectorized<T> tan() const {
return map(Sleef_tanf16_u10);
}
Vectorized<T> tanh() const {
return map(Sleef_tanhf16_u10);
}
Vectorized<T> trunc() const {
__m512 lo, hi;
cvt_to_fp32<T>(values, lo, hi);
auto o1 = _mm512_roundscale_ps(lo, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
auto o2 = _mm512_roundscale_ps(hi, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
return cvt_from_fp32<T>(o1, o2);
}
Vectorized<T> lgamma() const {
return map(Sleef_lgammaf16_u10);
}
Vectorized<T> sqrt() const {
__m512 lo, hi;
cvt_to_fp32<T>(values, lo, hi);
auto o1 = _mm512_sqrt_ps(lo);
auto o2 = _mm512_sqrt_ps(hi);
return cvt_from_fp32<T>(o1, o2);
}
Vectorized<T> reciprocal() const {
__m512 lo, hi;
cvt_to_fp32<T>(values, lo, hi);
auto ones = _mm512_set1_ps(1);
auto o1 = _mm512_div_ps(ones, lo);
auto o2 = _mm512_div_ps(ones, hi);
return cvt_from_fp32<T>(o1, o2);
}
Vectorized<T> rsqrt() const {
__m512 lo, hi;
cvt_to_fp32<T>(values, lo, hi);
auto ones = _mm512_set1_ps(1);
auto o1 = _mm512_div_ps(ones, _mm512_sqrt_ps(lo));
auto o2 = _mm512_div_ps(ones, _mm512_sqrt_ps(hi));
return cvt_from_fp32<T>(o1, o2);
}
Vectorized<T> pow(const Vectorized<T> &b) const {
__m512 lo, hi;
__m512 b1, b2;
cvt_to_fp32<T>(values, lo, hi);
cvt_to_fp32<T>(b.values, b1, b2);
auto o1 = Sleef_powf16_u10(lo, b1);
auto o2 = Sleef_powf16_u10(hi, b2);
return cvt_from_fp32<T>(o1, o2);
}
private:
template<typename Op>
Vectorized<T> inline binary_compare(const Vectorized<T>& b, Op op) const {
__m512 a_lo, a_hi;
__m512 b_lo, b_hi;
cvt_to_fp32<T>(values, a_lo, a_hi);
cvt_to_fp32<T>(b.values, b_lo, b_hi);
auto o1 = op(a_lo, b_lo);
auto o2 = op(a_hi, b_hi);
return cvt_from_fp32<T, /*is_compare_op*/true>(o1, o2);
}
public:
Vectorized<T> inline operator>(const Vectorized<T>& other) const {
return binary_compare(other, [](__m512 x, __m512 y) {
auto zero_vec = _mm512_set1_epi32(0);
auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GT_OQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
});
}
Vectorized<T> inline operator<(const Vectorized<T>& other) const {
return binary_compare(other, [](__m512 x, __m512 y) {
auto zero_vec = _mm512_set1_epi32(0);
auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LT_OQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
});
}
Vectorized<T> inline operator>=(const Vectorized<T>& other) const {
return binary_compare(other, [](__m512 x, __m512 y) {
auto zero_vec = _mm512_set1_epi32(0);
auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
});
}
Vectorized<T> inline operator<=(const Vectorized<T>& other) const {
return binary_compare(other, [](__m512 x, __m512 y) {
auto zero_vec = _mm512_set1_epi32(0);
auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LE_OQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
});
}
Vectorized<T> inline operator==(const Vectorized<T>& other) const {
return binary_compare(other, [](__m512 x, __m512 y) {
auto zero_vec = _mm512_set1_epi32(0);
auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_EQ_OQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
});
}
Vectorized<T> inline operator!=(const Vectorized<T>& other) const {
return binary_compare(other, [](__m512 x, __m512 y) {
auto zero_vec = _mm512_set1_epi32(0);
auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_NEQ_UQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
});
}
};
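// Shared widen/compute/narrow helper used by the arithmetic operators below:
// both operands are converted to fp32 pairs, `op` runs on each half, and the
// result is converted back to the reduced-precision type.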
template<typename T, typename Op>
static inline Vectorized<T> binary_op_as_fp32(const Vectorized<T>& a, const Vectorized<T>& b, Op op) {
__m512 a_lo, a_hi;
__m512 b_lo, b_hi;
cvt_to_fp32<T>(__m512i(a), a_lo, a_hi);
cvt_to_fp32<T>(__m512i(b), b_lo, b_hi);
auto o1 = op(a_lo, b_lo);
auto o2 = op(a_hi, b_hi);
return cvt_from_fp32<T>(o1, o2);
}
template <>
class Vectorized<BFloat16>: public Vectorized16<BFloat16> {
public:
using Vectorized16::Vectorized16;
using value_type = BFloat16;
Vectorized<BFloat16> frac() const;
Vectorized<BFloat16> eq(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> ne(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> gt(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> ge(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> lt(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> le(const Vectorized<BFloat16>& other) const;
};
Vectorized<BFloat16> inline operator+(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_add_ps(x, y); });
}
Vectorized<BFloat16> inline operator-(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_sub_ps(x, y); });
}
Vectorized<BFloat16> inline operator*(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_mul_ps(x, y); });
}
Vectorized<BFloat16> inline operator/(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_div_ps(x, y); });
}
Vectorized<BFloat16> inline operator&(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
return _mm512_and_si512(a, b);
}
Vectorized<BFloat16> inline operator|(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
return _mm512_or_si512(a, b);
}
Vectorized<BFloat16> inline operator^(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
return _mm512_xor_si512(a, b);
}
inline Vectorized<BFloat16> Vectorized<BFloat16>::eq(const Vectorized<BFloat16>& other) const {
return (*this == other) & Vectorized<BFloat16>(1.0f);
}
inline Vectorized<BFloat16> Vectorized<BFloat16>::ne(const Vectorized<BFloat16>& other) const {
return (*this != other) & Vectorized<BFloat16>(1.0f);
}
inline Vectorized<BFloat16> Vectorized<BFloat16>::gt(const Vectorized<BFloat16>& other) const {
return (*this > other) & Vectorized<BFloat16>(1.0f);
}
inline Vectorized<BFloat16> Vectorized<BFloat16>::ge(const Vectorized<BFloat16>& other) const {
return (*this >= other) & Vectorized<BFloat16>(1.0f);
}
inline Vectorized<BFloat16> Vectorized<BFloat16>::lt(const Vectorized<BFloat16>& other) const {
return (*this < other) & Vectorized<BFloat16>(1.0f);
}
inline Vectorized<BFloat16> Vectorized<BFloat16>::le(const Vectorized<BFloat16>& other) const {
return (*this <= other) & Vectorized<BFloat16>(1.0f);
}
// frac. Implement this here so we can use subtraction
inline Vectorized<BFloat16> Vectorized<BFloat16>::frac() const {
return *this - this->trunc();
}
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<BFloat16> inline maximum(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
__m512 a_lo, a_hi;
__m512 b_lo, b_hi;
cvtbf16_fp32(__m512i(a), a_lo, a_hi);
cvtbf16_fp32(__m512i(b), b_lo, b_hi);
auto max_lo = _mm512_max_ps(a_lo, b_lo);
auto max_hi = _mm512_max_ps(a_hi, b_hi);
auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q);
auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q);
auto nan_lo = _mm512_castsi512_ps(_mm512_set1_epi32(nan_lo_mask));
auto nan_hi = _mm512_castsi512_ps(_mm512_set1_epi32(nan_hi_mask));
// Exploit the fact that all-ones is a NaN.
auto o1 = _mm512_or_ps(max_lo, nan_lo);
auto o2 = _mm512_or_ps(max_hi, nan_hi);
return cvtfp32_bf16(o1, o2);
}
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<BFloat16> inline minimum(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
__m512 a_lo, a_hi;
__m512 b_lo, b_hi;
__m512i zero_vec = _mm512_set1_epi32(0);
cvtbf16_fp32(__m512i(a), a_lo, a_hi);
cvtbf16_fp32(__m512i(b), b_lo, b_hi);
auto min_lo = _mm512_min_ps(a_lo, b_lo);
auto min_hi = _mm512_min_ps(a_hi, b_hi);
auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q);
auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q);
auto nan_lo = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_lo_mask,
0xFFFFFFFF));
auto nan_hi = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_hi_mask,
0xFFFFFFFF));
// Exploit the fact that all-ones is a NaN.
auto o1 = _mm512_or_ps(min_lo, nan_lo);
auto o2 = _mm512_or_ps(min_hi, nan_hi);
return cvtfp32_bf16(o1, o2);
}
template <>
Vectorized<BFloat16> inline clamp(const Vectorized<BFloat16>& a,
const Vectorized<BFloat16>& min, const Vectorized<BFloat16>& max) {
__m512 a_lo, a_hi;
__m512 min_lo, min_hi;
__m512 max_lo, max_hi;
cvtbf16_fp32(__m512i(a), a_lo, a_hi);
cvtbf16_fp32(__m512i(min), min_lo, min_hi);
cvtbf16_fp32(__m512i(max), max_lo, max_hi);
auto o1 = _mm512_min_ps(max_lo, _mm512_max_ps(min_lo, a_lo));
auto o2 = _mm512_min_ps(max_hi, _mm512_max_ps(min_hi, a_hi));
return cvtfp32_bf16(o1, o2);
}
template <>
Vectorized<BFloat16> inline clamp_max(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& max) {
__m512 a_lo, a_hi;
__m512 max_lo, max_hi;
cvtbf16_fp32(__m512i(a), a_lo, a_hi);
cvtbf16_fp32(__m512i(max), max_lo, max_hi);
auto o1 = _mm512_min_ps(max_lo, a_lo);
auto o2 = _mm512_min_ps(max_hi, a_hi);
return cvtfp32_bf16(o1, o2);
}
template <>
Vectorized<BFloat16> inline clamp_min(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& min) {
__m512 a_lo, a_hi;
__m512 min_lo, min_hi;
cvtbf16_fp32(__m512i(a), a_lo, a_hi);
cvtbf16_fp32(__m512i(min), min_lo, min_hi);
auto o1 = _mm512_max_ps(min_lo, a_lo);
auto o2 = _mm512_max_ps(min_hi, a_hi);
return cvtfp32_bf16(o1, o2);
}
template <>
inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) {
int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
for (i = 0; i <= (n - Vectorized<BFloat16>::size()); i += Vectorized<BFloat16>::size()) {
auto vsrc = _mm512_loadu_si512(reinterpret_cast<__m512i*>((void*)(src + i)));
_mm512_storeu_si512(reinterpret_cast<__m512i*>((void*)(dst + i)), vsrc);
}
#ifndef __msvc_cl__
#pragma unroll
#endif
for (; i < n; i++) {
dst[i] = src[i];
}
}
template <>
inline void convert(const float* src, BFloat16* dst, int64_t n) {
int64_t i;
for (i = 0; i + Vectorized<BFloat16>::size() <= n; i += Vectorized<BFloat16>::size()) {
__m512 a = _mm512_loadu_ps(&src[i]);
__m512 b = _mm512_loadu_ps(&src[i + 16]);
__m512i bf = cvtfp32_bf16(a, b);
_mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf);
}
for (; i < n; i++) {
dst[i] = c10::convert<BFloat16>(src[i]);
}
}
template <>
inline void convert(const double* src, BFloat16* dst, int64_t n) {
auto load_float = [](const double *src) -> __m512 {
// Load one float vector from an array of doubles
__m256 a = _mm512_cvtpd_ps(_mm512_loadu_pd(src));
__m256 b = _mm512_cvtpd_ps(_mm512_loadu_pd(src + 8));
return _mm512_insertf32x8(_mm512_castps256_ps512(a), b, 1);
};
int64_t i;
for (i = 0; i + Vectorized<BFloat16>::size() <= n; i += Vectorized<BFloat16>::size()) {
__m512 a = load_float(&src[i]);
__m512 b = load_float(&src[i + 16]);
__m512i bf = cvtfp32_bf16(a, b);
_mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf);
}
for (; i < n; i++) {
dst[i] = c10::convert<BFloat16>(src[i]);
}
}
template <>
Vectorized<BFloat16> inline fmadd(const Vectorized<BFloat16>& a,
const Vectorized<BFloat16>& b, const Vectorized<BFloat16>& c) {
__m512 a_lo, a_hi;
__m512 b_lo, b_hi;
__m512 c_lo, c_hi;
cvtbf16_fp32(__m512i(a), a_lo, a_hi);
cvtbf16_fp32(__m512i(b), b_lo, b_hi);
cvtbf16_fp32(__m512i(c), c_lo, c_hi);
auto o1 = _mm512_fmadd_ps(a_lo, b_lo, c_lo);
auto o2 = _mm512_fmadd_ps(a_hi, b_hi, c_hi);
return cvtfp32_bf16(o1, o2);
}
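// Transposes a 16x16 block of 16-bit elements. Input rows are t[0..15]
// (one __m256i per row); the result is written to u[0..7], where u[i] holds
// transposed row 2*i in its low 256 bits and row 2*i+1 in its high 256 bits.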
static inline void _transpose_mxn_half_16_16(__m256i t[], __m512i u[]) {
__m512i r[8];
// a0a1 a2a3 a4a5 a6a7 a8a9 a10a11 a12a13 a14a15 e0e1 e2e3 e4e5 e6e7 e8e9 e10e11 e12e13 e14e15
// b0-b15 f0-f15
// c0-c15 g0-g15
// d0-d15 h0-h15
// i0-i15 m0-m15
// j0-j15 n0-n15
// k0-k15 o0-o15
// l0-l15 p0-p15
#ifndef __msvc_cl__
#pragma unroll(4)
#endif
for (int i = 0; i < 4; i++) {
r[i] = _mm512_inserti64x4(_mm512_castsi256_si512(t[i]), t[i + 4], 0x01);
r[i + 4] = _mm512_inserti64x4(_mm512_castsi256_si512(t[i + 8]), t[i + 12], 0x01);
}
// u0: a0a1 b0b1 a2a3 b2b3 a8a9 b8b9 a10a11 b10b11 e0e1 f0f1 e2e3 f2f3 e8e9 f8f9 e10e11 f10f11
// u1: a4a5 b4b5 a6a7 b6b7 a12a13 b12b13 a14a15 b14b15 e4e5 f4f5 e6e7 f6f7 e12e13 f12f13 e14e15 f14f15
// u2: c0c1 d0d1 c2c3 d2d3 c8c9 d8d9 c10c11 d10d11 g0g1 h0h1 g2g3 h2h3 g8g9 h8h9 g10g11 h10h11
// u3: c4c5 d4d5 c6c7 d6d7 c12c13 d12d13 c14c15 d14d15 g4g5 h4h5 g6g7 h6h7 g12g13 h12h13 g14g15 h14h15
// i j m n
// k l o p
#ifndef __msvc_cl__
#pragma unroll(4)
#endif
for (int i = 0; i < 8; i += 2) {
u[i] = _mm512_unpacklo_epi32(r[i], r[i + 1]);
u[i + 1] = _mm512_unpackhi_epi32(r[i], r[i + 1]);
}
// r0: a0a1 b0b1 c0c1 d0d1 a8a9 b8b9 c8c9 d8d9 e0e1 f0f1 g0g1 h0h1 e8e9 f8f9 g8g9 h8h9
// r1: a2a3 b2b3 c2c3 d2d3 a10a11 b10b11 c10c11 d10d11 e2e3 f2f3 g2g3 h2h3 e10e11 f10f11 g10g11 h10h11
// r2: a4a5 b4b5 c4c5 d4d5 a12a13 b12b13 c12c13 d12d13
// r3: a6a7 b6b7 c6c7 d6d7 a14a15 b14b15 c14c15 d14d15
// r4: i j k l m n o p
r[0] = _mm512_unpacklo_epi64(u[0], u[2]);
r[1] = _mm512_unpackhi_epi64(u[0], u[2]);
r[2] = _mm512_unpacklo_epi64(u[1], u[3]);
r[3] = _mm512_unpackhi_epi64(u[1], u[3]);
r[4] = _mm512_unpacklo_epi64(u[4], u[6]);
r[5] = _mm512_unpackhi_epi64(u[4], u[6]);
r[6] = _mm512_unpacklo_epi64(u[5], u[7]);
r[7] = _mm512_unpackhi_epi64(u[5], u[7]);
__m512i const1 = _mm512_set_epi32(
0x00370035,
0x00330031,
0x00270025,
0x00230021,
0x00170015,
0x00130011,
0x00070005,
0x00030001,
0x00360034,
0x00320030,
0x00260024,
0x00220020,
0x00160014,
0x00120010,
0x00060004,
0x00020000);
__m512i const2 = _mm512_set_epi32(
0x003f003d,
0x003b0039,
0x002f002d,
0x002b0029,
0x001f001d,
0x001b0019,
0x000f000d,
0x000b0009,
0x003e003c,
0x003a0038,
0x002e002c,
0x002a0028,
0x001e001c,
0x001a0018,
0x000e000c,
0x000a0008);
// merge values from two regs
// 0-- 1--
// 8-- 9--
// 2-- 3--
// 10-- 11--
// 4-- 5--
// 12-- 13--
// 6-- 7--
// 14-- 15--
#ifndef __msvc_cl__
#pragma unroll(4)
#endif
for (int i = 0; i < 4; i++) {
u[i] = _mm512_permutex2var_epi16(r[i], const1, r[i + 4]);
u[i + 4] = _mm512_permutex2var_epi16(r[i], const2, r[i + 4]);
}
}
// TODO(Leslie): Add the AVX2 Version of transpose_mxn for BFloat16 and Float16
// Code referred to FBGEMM:
// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L1483-L1607
template<>
inline void transpose_mxn<BFloat16, 16, 16>(
const BFloat16* src,
int64_t ld_src,
BFloat16* dst,
int64_t ld_dst) {
__m256i t[16];
// load from src to registers
// a: a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11 a12 a13 a14 a15
// b: b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 b10 b11 b12 b13 b14 b15
// c: c0 c1 c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 c13 c14 c15
// d: d0 d1 d2 d3 d4 d5 d6 d7 d8 d9 d10 d11 d12 d13 d14 d15
// e: e0 e1 e2 e3 e4 e5 e6 e7 e8 e9 e10 e11 e12 e13 e14 e15
// f: f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 f10 f11 f12 f13 f14 f15
// g: g0 g1 g2 g3 g4 g5 g6 g7 g8 g9 g10 g11 g12 g13 g14 g15
// h: h0 h1 h2 h3 h4 h5 h6 h7 h8 h9 h10 h11 h12 h13 h14 h15
// i: i0 i1 i2 i3 i4 i5 i6 i7 i8 i9 i10 i11 i12 i13 i14 i15
// j: j0 j1 j2 j3 j4 j5 j6 j7 j8 j9 j10 j11 j12 j13 j14 j15
// k: k0 k1 k2 k3 k4 k5 k6 k7 k8 k9 k10 k11 k12 k13 k14 k15
// l: l0 l1 l2 l3 l4 l5 l6 l7 l8 l9 l10 l11 l12 l13 l14 l15
// m: m0 m1 m2 m3 m4 m5 m6 m7 m8 m9 m10 m11 m12 m13 m14 m15
// n: n0 n1 n2 n3 n4 n5 n6 n7 n8 n9 n10 n11 n12 n13 n14 n15
// o: o0 o1 o2 o3 o4 o5 o6 o7 o8 o9 o10 o11 o12 o13 o14 o15
// p: p0 p1 p2 p3 p4 p5 p6 p7 p8 p9 p10 p11 p12 p13 p14 p15
#ifndef __msvc_cl__
#pragma unroll(16)
#endif
for (int i = 0; i < 16; i++) {
t[i] = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i * ld_src));
}
__m512i u[8];
_transpose_mxn_half_16_16(t, u);
#ifndef __msvc_cl__
#pragma unroll(8)
#endif
for (int i = 0; i < 8; i++) {
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(dst + (i * 2) * ld_dst),
_mm512_extracti32x8_epi32(u[i], 0x0));
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(dst + (i * 2 + 1) * ld_dst),
_mm512_extracti32x8_epi32(u[i], 0x01));
}
}
// Code referred to FBGEMM:
// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L1483-L1607
template<>
inline void transpose_mxn<Half, 16, 16>(
const Half* src,
int64_t ld_src,
Half* dst,
int64_t ld_dst) {
__m256i t[16];
// load from src to registers
// Same matrix indices as above transpose_mxn<BFloat16, 16, 16>
#ifndef __msvc_cl__
#pragma unroll(16)
#endif
for (int i = 0; i < 16; i++) {
t[i] = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i * ld_src));
}
__m512i u[8];
_transpose_mxn_half_16_16(t, u);
#ifndef __msvc_cl__
#pragma unroll(8)
#endif
for (int i = 0; i < 8; i++) {
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(dst + (i * 2) * ld_dst),
_mm512_extracti32x8_epi32(u[i], 0x0));
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(dst + (i * 2 + 1) * ld_dst),
_mm512_extracti32x8_epi32(u[i], 0x01));
}
}
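// Transposes a 32x32 block of 16-bit elements held one row per __m512i in
// r[0..31], writing the transposed rows to d[0..31]. The transpose proceeds
// through 16-, 32- and 64-bit unpack stages followed by two cross-lane
// permute (_mm512_permutex2var_epi64) stages; the interleaved comments track
// element positions after each stage.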
static inline void _transpose_mxn_half_32_32(__m512i r[], __m512i d[]) {
// t[0]: 0 32 1 33 2 34 3 35 8 40 9 41 10 42 11 43 16 ... 59
// t[1]: 4 36 5 37 6 38 7 39 12 44 13 45 14 46 15 47 20 ... 63
// t[2]: 64 96 65 97 66 98 67 99 72 104 73 105 74 106 75 ... 123
// t[3]: 68 100 69 101 70 102 71 103 76 108 77 109 78 110 79 111 84 ... 127
// t[4]: 128 160 129 161 130 162 131 163 136 168 137 169 138 170 139 171 144 ... 187
// t[5]: 132 164 133 165 134 166 135 167 140 172 141 173 142 174 143 175 148 ... 191
// t[6]: 192 224 193 225 194 226 195 227 200 232 201 233 202 234 203 235 208 ... 251
// t[7]: 196 228 197 229 198 230 199 231 204 236 205 237 206 238 207 239 212 ... 255
// t[8]: 256 288 257 289 258 290 259 291 264 296 265 297 266 298 267 299 272 ... 315
// t[9]: 260 292 261 293 262 294 263 295 268 300 269 301 270 302 271 303 276 ... 319
// t[10]: 320 352 321 353 322 354 323 355 328 360 329 361 330 362 331 363 336 ... 379
// t[11]: 324 356 325 357 326 358 327 359 332 364 333 365 334 366 335 367 340 ... 383
// t[12]: 384 416 385 417 386 418 387 419 392 424 393 425 394 426 395 427 400 ... 443
// t[13]: 388 420 389 421 390 422 391 423 396 428 397 429 398 430 399 431 404 ... 447
// t[14]: 448 480 449 481 450 482 451 483 456 488 457 489 458 490 459 491 464 ... 507
// t[15]: 452 484 453 485 454 486 455 487 460 492 461 493 462 494 463 495 468 ... 511
// t[16]: 512 544 513 545 514 546 515 547 520 552 521 553 522 554 523 555 528 ... 571
// ...
// t[31]: 964 996 965 997 966 998 967 999 972 1004 973 1005 974 1006 975 1007 980 ... 1023
#ifndef __msvc_cl__
#pragma unroll(16)
#endif
for (int i = 0; i < 16; ++i) {
d[i * 2] = _mm512_unpacklo_epi16(r[i * 2], r[i * 2 + 1]);
d[i * 2 + 1] = _mm512_unpackhi_epi16(r[i * 2], r[i * 2 + 1]);
}
// t[0]: 0 32 64 96 1 33 65 97 8 40 72 104 9 41 73 105 16 ... 121
// t[1]: 2 34 66 98 3 35 67 99 10 42 74 106 11 43 75 107 18 ... 123
// t[2]: 4 36 68 100 5 37 69 101 12 44 76 108 13 45 77 109 20 ... 125
// t[3]: 6 38 70 102 7 39 71 103 14 46 78 110 15 47 79 111 22 ... 127
// t[4]: 128 160 192 224 129 161 193 225 136 168 200 232 137 169 201 233 144 ... 249
// t[5]: 130 162 194 226 131 163 195 227 138 170 202 234 139 171 203 235 146 ... 251
// t[6]: 132 164 196 228 133 165 197 229 140 172 204 236 141 173 205 237 148 ... 253
// t[7]: 134 166 198 230 135 167 199 231 142 174 206 238 143 175 207 239 150 ... 255
// t[8]: 256 288 320 352 257 289 321 353 264 296 328 360 265 297 329 361 272 ... 377
// t[9]: 258 290 322 354 259 291 323 355 266 298 330 362 267 299 331 363 274 ... 379
// t[10]: 260 292 324 356 261 293 325 357 268 300 332 364 269 301 333 365 276 ... 381
// t[11]: 262 294 326 358 263 295 327 359 270 302 334 366 271 303 335 367 278 ... 383
// t[12]: 384 416 448 480 385 417 449 481 392 424 456 488 393 425 457 489 400 ... 505
// t[13]: 386 418 450 482 387 419 451 483 394 426 458 490 395 427 459 491 402 ... 507
// t[14]: 388 420 452 484 389 421 453 485 396 428 460 492 397 429 461 493 404 ... 509
// t[15]: 390 422 454 486 391 423 455 487 398 430 462 494 399 431 463 495 406 ... 511
// t[16]: 512 544 576 608 513 545 577 609 520 552 584 616 521 553 585 617 528 ... 633
// ...
// t[31]: 902 934 966 998 903 935 967 999 910 942 974 1006 911 943 975 1007 918 ... 1023
#ifndef __msvc_cl__
#pragma unroll(8)
#endif
for (int i = 0; i < 8; ++i) {
r[i * 4] = _mm512_unpacklo_epi32(d[i * 4], d[i * 4 + 2]);
r[i * 4 + 1] = _mm512_unpackhi_epi32(d[i * 4], d[i * 4 + 2]);
r[i * 4 + 2] = _mm512_unpacklo_epi32(d[i * 4 + 1], d[i * 4 + 3]);
r[i * 4 + 3] = _mm512_unpackhi_epi32(d[i * 4 + 1], d[i * 4 + 3]);
}
// t[0]: 0 32 64 96 128 160 192 224 8 40 72 104 136 168 200 232 16 ... 248
// t[1]: 1 33 65 97 129 161 193 225 9 41 73 105 137 169 201 233 17 ... 249
// t[2]: 2 34 66 98 130 162 194 226 10 42 74 106 138 170 202 234 18 ... 250
// t[3]: 3 35 67 99 131 163 195 227 11 43 75 107 139 171 203 235 19 ... 251
// t[4]: 4 36 68 100 132 164 196 228 12 44 76 108 140 172 204 236 20 ... 252
// t[5]: 5 37 69 101 133 165 197 229 13 45 77 109 141 173 205 237 21 ... 253
// t[6]: 6 38 70 102 134 166 198 230 14 46 78 110 142 174 206 238 22 ... 254
// t[7]: 7 39 71 103 135 167 199 231 15 47 79 111 143 175 207 239 23 ... 255
// t[8]: 256 288 320 352 384 416 448 480 264 296 328 360 392 424 456 488 272 ... 504
// t[9]: 257 289 321 353 385 417 449 481 265 297 329 361 393 425 457 489 273 ... 505
// t[10]: 258 290 322 354 386 418 450 482 266 298 330 362 394 426 458 490 274 ... 506
// t[11]: 259 291 323 355 387 419 451 483 267 299 331 363 395 427 459 491 275 ... 507
// t[12]: 260 292 324 356 388 420 452 484 268 300 332 364 396 428 460 492 276 ... 508
// t[13]: 261 293 325 357 389 421 453 485 269 301 333 365 397 429 461 493 277 ... 509
// t[14]: 262 294 326 358 390 422 454 486 270 302 334 366 398 430 462 494 278 ... 510
// t[15]: 263 295 327 359 391 423 455 487 271 303 335 367 399 431 463 495 279 ... 511
// t[16]: 512 544 576 608 640 672 704 736 520 552 584 616 648 680 712 744 528 ... 760
// ...
// t[31]: 775 807 839 871 903 935 967 999 783 815 847 879 911 943 975 1007 791 ... 1023
#ifndef __msvc_cl__
#pragma unroll(4)
#endif
for (int i = 0; i < 4; ++i) {
d[i * 8] = _mm512_unpacklo_epi64(r[i * 8], r[i * 8 + 4]);
d[i * 8 + 1] = _mm512_unpackhi_epi64(r[i * 8], r[i * 8 + 4]);
d[i * 8 + 2] = _mm512_unpacklo_epi64(r[i * 8 + 1], r[i * 8 + 5]);
d[i * 8 + 3] = _mm512_unpackhi_epi64(r[i * 8 + 1], r[i * 8 + 5]);
d[i * 8 + 4] = _mm512_unpacklo_epi64(r[i * 8 + 2], r[i * 8 + 6]);
d[i * 8 + 5] = _mm512_unpackhi_epi64(r[i * 8 + 2], r[i * 8 + 6]);
d[i * 8 + 6] = _mm512_unpacklo_epi64(r[i * 8 + 3], r[i * 8 + 7]);
d[i * 8 + 7] = _mm512_unpackhi_epi64(r[i * 8 + 3], r[i * 8 + 7]);
}
// t[0]: 0 32 64 96 128 160 192 224 256 288 320 352 384 416 448 480 16 ... 496
// t[1]: 1 33 65 97 129 161 193 225 257 289 321 353 385 417 449 481 17 ... 497
// t[2]: 2 34 66 98 130 162 194 226 258 290 322 354 386 418 450 482 18 ... 498
// t[3]: 3 35 67 99 131 163 195 227 259 291 323 355 387 419 451 483 19 ... 499
// t[4]: 4 36 68 100 132 164 196 228 260 292 324 356 388 420 452 484 20 ... 500
// t[5]: 5 37 69 101 133 165 197 229 261 293 325 357 389 421 453 485 21 ... 501
// t[6]: 6 38 70 102 134 166 198 230 262 294 326 358 390 422 454 486 22 ... 502
// t[7]: 7 39 71 103 135 167 199 231 263 295 327 359 391 423 455 487 23 ... 503
// t[8]: 8 40 72 104 136 168 200 232 264 296 328 360 392 424 456 488 24 ... 504
// t[9]: 9 41 73 105 137 169 201 233 265 297 329 361 393 425 457 489 25 ... 505
// t[10]: 10 42 74 106 138 170 202 234 266 298 330 362 394 426 458 490 26 ... 506
// t[11]: 11 43 75 107 139 171 203 235 267 299 331 363 395 427 459 491 27 ... 507
// t[12]: 12 44 76 108 140 172 204 236 268 300 332 364 396 428 460 492 28 ... 508
// t[13]: 13 45 77 109 141 173 205 237 269 301 333 365 397 429 461 493 29 ... 509
// t[14]: 14 46 78 110 142 174 206 238 270 302 334 366 398 430 462 494 30 ... 510
// t[15]: 15 47 79 111 143 175 207 239 271 303 335 367 399 431 463 495 31 ... 511
// t[16]: 512 544 576 608 640 672 704 736 768 800 832 864 896 928 960 992 528 ... 1008
// ...
// t[31]: 527 559 591 623 655 687 719 751 783 815 847 879 911 943 975 1007 543 ... 1023
__m512i const1 = _mm512_set_epi64(
0x000000000000000d,
0x000000000000000c,
0x0000000000000005,
0x0000000000000004,
0x0000000000000009,
0x0000000000000008,
0x0000000000000001,
0x0000000000000000);
__m512i const2 = _mm512_set_epi64(
0x000000000000000f,
0x000000000000000e,
0x0000000000000007,
0x0000000000000006,
0x000000000000000b,
0x000000000000000a,
0x0000000000000003,
0x0000000000000002);
#ifndef __msvc_cl__
#pragma unroll(8)
#endif
for (int i = 0; i < 8; ++i) {
r[i] = _mm512_permutex2var_epi64(d[i], /*idx*/const1, d[i + 8]);
r[i + 8] = _mm512_permutex2var_epi64(d[i], /*idx*/const2, d[i + 8]);
r[i + 16] = _mm512_permutex2var_epi64(d[i + 16], /*idx*/const1, d[i + 24]);
r[i + 24] = _mm512_permutex2var_epi64(d[i + 16], /*idx*/const2, d[i + 24]);
}
// t[0]: 0 32 64 96 128 160 192 224 256 288 320 352 384 416 448 480 512 544 ... 992
// t[1]: 1 33 65 97 129 161 193 225 257 289 321 353 385 417 449 481 513 545 ... 993
// t[2]: 2 34 66 98 130 162 194 226 258 290 322 354 386 418 450 482 514 546 ... 994
// t[3]: 3 35 67 99 131 163 195 227 259 291 323 355 387 419 451 483 515 547 ... 995
// t[4]: 4 36 68 100 132 164 196 228 260 292 324 356 388 420 452 484 516 548 ... 996
// t[5]: 5 37 69 101 133 165 197 229 261 293 325 357 389 421 453 485 517 549 ... 997
// t[6]: 6 38 70 102 134 166 198 230 262 294 326 358 390 422 454 486 518 550 ... 998
// t[7]: 7 39 71 103 135 167 199 231 263 295 327 359 391 423 455 487 519 551 ... 999
// t[8]: 8 40 72 104 136 168 200 232 264 296 328 360 392 424 456 488 520 552 ... 1000
// t[9]: 9 41 73 105 137 169 201 233 265 297 329 361 393 425 457 489 521 553 ... 1001
// t[10]: 10 42 74 106 138 170 202 234 266 298 330 362 394 426 458 490 522 554 ... 1002
// t[11]: 11 43 75 107 139 171 203 235 267 299 331 363 395 427 459 491 523 555 ... 1003
// t[12]: 12 44 76 108 140 172 204 236 268 300 332 364 396 428 460 492 524 556 ... 1004
// t[13]: 13 45 77 109 141 173 205 237 269 301 333 365 397 429 461 493 525 557 ... 1005
// t[14]: 14 46 78 110 142 174 206 238 270 302 334 366 398 430 462 494 526 558 ... 1006
// t[15]: 15 47 79 111 143 175 207 239 271 303 335 367 399 431 463 495 527 559 ... 1007
// t[16]: 16 48 80 112 144 176 208 240 272 304 336 368 400 432 464 496 528 560 ... 1008
// ...
// t[31]: 31 63 95 127 159 191 223 255 287 319 351 383 415 447 479 511 543 575 ... 1023
__m512i const3 = _mm512_set_epi64(
0x000000000000000b,
0x000000000000000a,
0x0000000000000009,
0x0000000000000008,
0x0000000000000003,
0x0000000000000002,
0x0000000000000001,
0x0000000000000000);
__m512i const4 = _mm512_set_epi64(
0x000000000000000f,
0x000000000000000e,
0x000000000000000d,
0x000000000000000c,
0x0000000000000007,
0x0000000000000006,
0x0000000000000005,
0x0000000000000004);
#ifndef __msvc_cl__
#pragma unroll(16)
#endif
for (int i = 0; i < 16; ++i) {
d[i] = _mm512_permutex2var_epi64(r[i], /*idx*/const3, r[i + 16]);
d[i + 16] = _mm512_permutex2var_epi64(r[i], /*idx*/const4, r[i + 16]);
}
}
// Code referred to FBGEMM:
// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#LL19C6-L19C6
template<>
inline void transpose_mxn<BFloat16>(const BFloat16* src, int64_t ld_src, BFloat16* dst, int64_t ld_dst, int M, int N) {
// load from src
TORCH_CHECK(M <= 32 && N <= 32, "transpose_mxn<BFloat16> expects M, N <= 32.");
__m512i r[32];
int i;
if (N == 32) {
for (i = 0; i < M; ++i) {
r[i] = _mm512_loadu_si512(&src[i * ld_src]);
}
} else {
__mmask32 src_mask = (1 << N) - 1;
for (i = 0; i < M; ++i) {
r[i] = _mm512_maskz_loadu_epi16(src_mask, &src[i * ld_src]);
}
}
for (; i < 32; ++i) {
r[i] = _mm512_setzero_si512();
}
__m512i d[32];
_transpose_mxn_half_32_32(r, d);
// store to dst
if (M == 32) {
for (i = 0; i < N; ++i) {
_mm512_storeu_si512(&dst[i * ld_dst], d[i]);
}
} else {
__mmask32 dst_mask = (1 << M) - 1;
for (i = 0; i < N; ++i) {
_mm512_mask_storeu_epi16(&dst[i * ld_dst], dst_mask, d[i]);
}
}
}
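// Compile-time shapes other than the dedicated 16x16 specialization are
// forwarded to the runtime 32x32 kernel above (rows/columns beyond M or N
// are zero-padded by the masked loads and simply not stored back).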
template <typename T, int M, int N,
typename std::enable_if_t<std::is_same_v<T, BFloat16> && ((M <= 32 && M != 16) || (N <= 32 && N != 16)), int> = 0>
inline void transpose_mxn(const BFloat16* src, int64_t ld_src, BFloat16* dst, int64_t ld_dst) {
transpose_mxn<BFloat16>(src, ld_src, dst, ld_dst, M, N);
}
template<>
inline void transpose_mxn<Half>(const Half* src, int64_t ld_src, Half* dst, int64_t ld_dst, int M, int N) {
TORCH_CHECK(M <= 32 && N <= 32, "transpose_mxn<Half> expects M, N <= 32.");
// load from src
__m512i r[32];
int i;
if (N == 32) {
for (i = 0; i < M; ++i) {
r[i] = _mm512_loadu_si512(&src[i * ld_src]);
}
} else {
__mmask32 src_mask = (1 << N) - 1;
for (i = 0; i < M; ++i) {
r[i] = _mm512_maskz_loadu_epi16(src_mask, &src[i * ld_src]);
}
}
for (; i < 32; ++i) {
r[i] = _mm512_setzero_si512();
}
__m512i d[32];
_transpose_mxn_half_32_32(r, d);
// store to dst
if (M == 32) {
for (i = 0; i < N; ++i) {
_mm512_storeu_si512(&dst[i * ld_dst], d[i]);
}
} else {
__mmask32 dst_mask = (1 << M) - 1;
for (i = 0; i < N; ++i) {
_mm512_mask_storeu_epi16(&dst[i * ld_dst], dst_mask, d[i]);
}
}
}
template <typename T, int M, int N,
typename std::enable_if_t<std::is_same_v<T, Half> && ((M <= 32 && M != 16) || (N <= 32 && N != 16)), int> = 0>
inline void transpose_mxn(const Half* src, int64_t ld_src, Half* dst, int64_t ld_dst) {
transpose_mxn<Half>(src, ld_src, dst, ld_dst, M, N);
}
template <>
class Vectorized<Half>: public Vectorized16<Half> {
public:
using Vectorized16::Vectorized16;
using value_type = Half;
Vectorized<Half> frac() const;
Vectorized<Half> eq(const Vectorized<Half>& other) const;
Vectorized<Half> ne(const Vectorized<Half>& other) const;
Vectorized<Half> gt(const Vectorized<Half>& other) const;
Vectorized<Half> ge(const Vectorized<Half>& other) const;
Vectorized<Half> lt(const Vectorized<Half>& other) const;
Vectorized<Half> le(const Vectorized<Half>& other) const;
};
Vectorized<Half> inline operator+(const Vectorized<Half>& a, const Vectorized<Half>& b) {
return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_add_ps(x, y); });
}
Vectorized<Half> inline operator-(const Vectorized<Half>& a, const Vectorized<Half>& b) {
return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_sub_ps(x, y); });
}
Vectorized<Half> inline operator*(const Vectorized<Half>& a, const Vectorized<Half>& b) {
return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_mul_ps(x, y); });
}
Vectorized<Half> inline operator/(const Vectorized<Half>& a, const Vectorized<Half>& b) {
return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_div_ps(x, y); });
}
Vectorized<Half> inline operator&(const Vectorized<Half>& a, const Vectorized<Half>& b) {
return _mm512_and_si512(a, b);
}
Vectorized<Half> inline operator|(const Vectorized<Half>& a, const Vectorized<Half>& b) {
return _mm512_or_si512(a, b);
}
Vectorized<Half> inline operator^(const Vectorized<Half>& a, const Vectorized<Half>& b) {
return _mm512_xor_si512(a, b);
}
inline Vectorized<Half> Vectorized<Half>::eq(const Vectorized<Half>& other) const {
return (*this == other) & Vectorized<Half>(1.0f);
}
inline Vectorized<Half> Vectorized<Half>::ne(const Vectorized<Half>& other) const {
return (*this != other) & Vectorized<Half>(1.0f);
}
inline Vectorized<Half> Vectorized<Half>::gt(const Vectorized<Half>& other) const {
return (*this > other) & Vectorized<Half>(1.0f);
}
inline Vectorized<Half> Vectorized<Half>::ge(const Vectorized<Half>& other) const {
return (*this >= other) & Vectorized<Half>(1.0f);
}
inline Vectorized<Half> Vectorized<Half>::lt(const Vectorized<Half>& other) const {
return (*this < other) & Vectorized<Half>(1.0f);
}
inline Vectorized<Half> Vectorized<Half>::le(const Vectorized<Half>& other) const {
return (*this <= other) & Vectorized<Half>(1.0f);
}
// frac. Implement this here so we can use subtraction
inline Vectorized<Half> Vectorized<Half>::frac() const {
return *this - this->trunc();
}
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<Half> inline maximum(const Vectorized<Half>& a, const Vectorized<Half>& b) {
__m512 a_lo, a_hi;
__m512 b_lo, b_hi;
cvtfp16_fp32(__m512i(a), a_lo, a_hi);
cvtfp16_fp32(__m512i(b), b_lo, b_hi);
auto max_lo = _mm512_max_ps(a_lo, b_lo);
auto max_hi = _mm512_max_ps(a_hi, b_hi);
auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q);
auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q);
auto nan_lo = _mm512_castsi512_ps(_mm512_set1_epi32(nan_lo_mask));
auto nan_hi = _mm512_castsi512_ps(_mm512_set1_epi32(nan_hi_mask));
// Exploit the fact that all-ones is a NaN.
auto o1 = _mm512_or_ps(max_lo, nan_lo);
auto o2 = _mm512_or_ps(max_hi, nan_hi);
return cvtfp32_fp16(o1, o2);
}
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<Half> inline minimum(const Vectorized<Half>& a, const Vectorized<Half>& b) {
__m512 a_lo, a_hi;
__m512 b_lo, b_hi;
__m512i zero_vec = _mm512_set1_epi32(0);
cvtfp16_fp32(__m512i(a), a_lo, a_hi);
cvtfp16_fp32(__m512i(b), b_lo, b_hi);
auto min_lo = _mm512_min_ps(a_lo, b_lo);
auto min_hi = _mm512_min_ps(a_hi, b_hi);
auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q);
auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q);
auto nan_lo = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_lo_mask,
0xFFFFFFFF));
auto nan_hi = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_hi_mask,
0xFFFFFFFF));
// Exploit the fact that all-ones is a NaN.
auto o1 = _mm512_or_ps(min_lo, nan_lo);
auto o2 = _mm512_or_ps(min_hi, nan_hi);
return cvtfp32_fp16(o1, o2);
}
template <>
Vectorized<Half> inline clamp(const Vectorized<Half>& a,
const Vectorized<Half>& min, const Vectorized<Half>& max) {
__m512 a_lo, a_hi;
__m512 min_lo, min_hi;
__m512 max_lo, max_hi;
cvtfp16_fp32(__m512i(a), a_lo, a_hi);
cvtfp16_fp32(__m512i(min), min_lo, min_hi);
cvtfp16_fp32(__m512i(max), max_lo, max_hi);
auto o1 = _mm512_min_ps(max_lo, _mm512_max_ps(min_lo, a_lo));
auto o2 = _mm512_min_ps(max_hi, _mm512_max_ps(min_hi, a_hi));
return cvtfp32_fp16(o1, o2);
}
template <>
Vectorized<Half> inline clamp_max(const Vectorized<Half>& a, const Vectorized<Half>& max) {
__m512 a_lo, a_hi;
__m512 max_lo, max_hi;
cvtfp16_fp32(__m512i(a), a_lo, a_hi);
cvtfp16_fp32(__m512i(max), max_lo, max_hi);
auto o1 = _mm512_min_ps(max_lo, a_lo);
auto o2 = _mm512_min_ps(max_hi, a_hi);
return cvtfp32_fp16(o1, o2);
}
template <>
Vectorized<Half> inline clamp_min(const Vectorized<Half>& a, const Vectorized<Half>& min) {
__m512 a_lo, a_hi;
__m512 min_lo, min_hi;
cvtfp16_fp32(__m512i(a), a_lo, a_hi);
cvtfp16_fp32(__m512i(min), min_lo, min_hi);
auto o1 = _mm512_max_ps(min_lo, a_lo);
auto o2 = _mm512_max_ps(min_hi, a_hi);
return cvtfp32_fp16(o1, o2);
}
template <>
inline void convert(const Half* src, Half* dst, int64_t n) {
int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
for (i = 0; i <= (n - Vectorized<Half>::size()); i += Vectorized<Half>::size()) {
auto vsrc = _mm512_loadu_si512(reinterpret_cast<__m512i*>((void*)(src + i)));
_mm512_storeu_si512(reinterpret_cast<__m512i*>((void*)(dst + i)), vsrc);
}
#ifndef __msvc_cl__
#pragma unroll
#endif
for (; i < n; i++) {
dst[i] = src[i];
}
}
template <>
inline void convert(const float* src, Half* dst, int64_t n) {
int64_t i;
for (i = 0; i + Vectorized<Half>::size() <= n; i += Vectorized<Half>::size()) {
__m512 a = _mm512_loadu_ps(&src[i]);
__m512 b = _mm512_loadu_ps(&src[i + 16]);
__m512i bf = cvtfp32_fp16(a, b);
_mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf);
}
for (; i < n; i++) {
dst[i] = c10::convert<Half>(src[i]);
}
}
template <>
inline void convert(const double* src, Half* dst, int64_t n) {
auto load_float = [](const double *src) -> __m512 {
// Load one float vector from an array of doubles
__m256 a = _mm512_cvtpd_ps(_mm512_loadu_pd(src));
__m256 b = _mm512_cvtpd_ps(_mm512_loadu_pd(src + 8));
return _mm512_insertf32x8(_mm512_castps256_ps512(a), b, 1);
};
int64_t i;
for (i = 0; i + Vectorized<Half>::size() <= n; i += Vectorized<Half>::size()) {
__m512 a = load_float(&src[i]);
__m512 b = load_float(&src[i + 16]);
__m512i bf = cvtfp32_fp16(a, b);
_mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf);
}
for (; i < n; i++) {
dst[i] = c10::convert<Half>(src[i]);
}
}
template <>
Vectorized<Half> inline fmadd(const Vectorized<Half>& a,
const Vectorized<Half>& b, const Vectorized<Half>& c) {
__m512 a_lo, a_hi;
__m512 b_lo, b_hi;
__m512 c_lo, c_hi;
cvtfp16_fp32(__m512i(a), a_lo, a_hi);
cvtfp16_fp32(__m512i(b), b_lo, b_hi);
cvtfp16_fp32(__m512i(c), c_lo, c_hi);
auto o1 = _mm512_fmadd_ps(a_lo, b_lo, c_lo);
auto o2 = _mm512_fmadd_ps(a_hi, b_hi, c_hi);
return cvtfp32_fp16(o1, o2);
}
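// The macros below generate convert_<name>_float / convert_float_<name>
// helpers: one Vectorized<bf16/fp16> (32 elements) maps to a pair of
// Vectorized<float> (16 elements each), using the same cvt_to_fp32 /
// cvt_from_fp32 primitives as above.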
#define CONVERT_VECTORIZED_INIT(type, name) \
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_##name##_float(const Vectorized<type>& a) { \
__m512 o1, o2; \
cvt_to_fp32<type>(__m512i(a), o1, o2); \
return std::make_tuple(o1, o2); \
} \
\
inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const Vectorized<float>& b) { \
return cvt_from_fp32<type>(__m512(a), __m512(b)); \
}
CONVERT_VECTORIZED_INIT(BFloat16, bfloat16);
CONVERT_VECTORIZED_INIT(Half, half);
#else //defined(CPU_CAPABILITY_AVX512)
#define CONVERT_NON_VECTORIZED_INIT(type, name) \
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_##name##_float(const Vectorized<type>& a) { \
constexpr int64_t K = Vectorized<type>::size(); \
__at_align__ float arr[K]; \
__at_align__ type arr2[K]; \
a.store(arr2); \
for (const auto k : c10::irange(K)) { \
arr[k] = c10::convert<float>(arr2[k]); \
} \
return std::make_tuple( \
Vectorized<float>::loadu(arr), \
Vectorized<float>::loadu(arr + Vectorized<float>::size())); \
} \
\
inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const Vectorized<float>& b) { \
constexpr int64_t K = Vectorized<type>::size(); \
__at_align__ float arr[K]; \
__at_align__ type arr2[K]; \
a.store(arr); \
b.store(arr + Vectorized<float>::size()); \
for (const auto k : c10::irange(K)) { \
arr2[k] = c10::convert<type>(arr[k]); \
} \
return Vectorized<type>::loadu(arr2); \
}
CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16);
CONVERT_NON_VECTORIZED_INIT(Half, half);
#endif // defined(CPU_CAPABILITY_AVX512)
#if defined(CPU_CAPABILITY_AVX512)
#define LOAD_FP32_VECTORIZED_INIT(type, name) \
inline void load_fp32_from_##name(const type *data, Vectorized<float>& out) { \
auto values = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(data)); \
__m512 out_values; \
cvt_to_fp32<type>(values, out_values); \
out = out_values; \
} \
\
inline void load_fp32_from_##name(const type *data, Vectorized<float>& out1, Vectorized<float>& out2) { \
auto vec = Vectorized<type>::loadu(data); \
__m512 out1_values, out2_values; \
cvt_to_fp32<type>(vec, out1_values, out2_values); \
out1 = out1_values; \
out2 = out2_values; \
}
LOAD_FP32_VECTORIZED_INIT(BFloat16, bf16)
LOAD_FP32_VECTORIZED_INIT(Half, fp16)
#else // defined(CPU_CAPABILITY_AVX512)
#define LOAD_FP32_NON_VECTORIZED_INIT(type, name) \
inline void load_fp32_from_##name(const type *data, Vectorized<float>& out) { \
__at_align__ float values[Vectorized<float>::size()]; \
for (const auto k : c10::irange(Vectorized<float>::size())) { \
values[k] = data[k]; \
} \
out = Vectorized<float>::loadu(values); \
} \
\
inline void load_fp32_from_##name(const type *data, Vectorized<float>& out1, Vectorized<float>& out2) { \
load_fp32_from_##name(data, out1); \
data += Vectorized<float>::size(); \
load_fp32_from_##name(data, out2); \
}
LOAD_FP32_NON_VECTORIZED_INIT(BFloat16, bf16);
LOAD_FP32_NON_VECTORIZED_INIT(Half, fp16);
#endif
}}} // namespace at::vec::CPU_CAPABILITY