Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00.
Implement fast exp for AVX2 and AVX512 for the flash attention (#151441)
**Implement fexp for AVX2 and AVX512**

Malossi et al. propose a clever exp that exploits the IEEE floating-point representation with fine control of the precision, which is especially useful for the mixed-precision computation of flash attention.

- Implements "Fast Exponential Computation on SIMD Architectures", A. Cristiano I. Malossi, Yves Ineichen, Costas Bekas, and Alessandro Curioni.
- AVX2 and AVX512, float only; up to 20% faster for mixed-precision flash attention than the current implementation.
- The other types keep the legacy implementation.

**Precision**

Accuracy is within 1 ULP and is only valid in hybrid mode (fp32 -> fp16), due to the cast during the store operation in flash attention.

**Benchmark**

Machine: Xeon 6972P. Results in TOPs, Python forward pass of flash attention.

Num. heads 16, head dimension 64

|Seq. L.| PT | fexp |
|-------|------|------|
| 512 | 0.8 | 1.3 |
| 1024 | 1.7 | 1.7 |
| 2048 | 6 | 6.1 |
| 4096 | 16 | 16.8 |
| 8192 | 30.6 | 32.3 |
| 16384 | 40 | 40.8 |
| 32768 | 44.9 | 51.4 |
| 65536 | 45.8 | 54.4 |

Num. heads 16, head dimension 128

|Seq. L.| PT | fexp |
|-------|------|------|
| 512 | 2.5 | 4.1 |
| 1024 | 3.3 | 4 |
| 2048 | 11.4 | 10.5 |
| 4096 | 27.4 | 28.4 |
| 8192 | 44.4 | 46 |
| 16384 | 64.2 | 68.1 |
| 32768 | 77.8 | 83 |
| 65536 | 82.1 | 88.1 |

Num. heads 16, head dimension 256

|Seq. L.| PT | fexp |
|-------|------|------|
| 512 | 1.7 | 3.4 |
| 1024 | 4.2 | 6.5 |
| 2048 | 14.6 | 16.1 |
| 4096 | 30.1 | 31.1 |
| 8192 | 60 | 62 |
| 16384 | 83.3 | 87.3 |
| 32768 | 98.7 | 106 |
| 65536 | 102.2 | 107.1 |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151441
Approved by: https://github.com/mingfeima
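As a rough illustration of the approach (not part of the patch; `fast_exp_sketch` is a hypothetical name), a scalar C++ sketch using the same coefficients and bit constants as the new `fexp_u20` kernels below could look like this:

```cpp
#include <cfloat>
#include <cmath>
#include <cstdint>
#include <cstring>

// Scalar sketch of the Malossi-style fast exp used by fexp_u20:
// exp(x) = 2^(x * log2(e)); the integer part of the exponent goes into the
// IEEE-754 exponent field, and the fractional part is corrected by a
// degree-3 polynomial.
float fast_exp_sketch(float x) {
  constexpr float log2e = 1.442695040f;
  constexpr float c0 = 0.00010703434948458272f;
  constexpr float c1 = 0.30354260500649682f;
  constexpr float c2 = -0.22433836478672356f;
  constexpr float c3 = -0.079204240219773236f;

  if (x < std::log(FLT_MIN)) return 0.0f;      // underflow -> 0
  if (x > std::log(FLT_MAX)) return INFINITY;  // overflow  -> +inf

  float t = x * log2e;
  float f = t - std::floor(t);                 // fractional part in [0, 1)
  // Horner evaluation of the correction polynomial.
  float p = c0 + f * (c1 + f * (c2 + f * c3));
  // Scale (t - p) into the exponent field and add the bias:
  // reinterpret trunc(2^23 * (t - p) + 2^23 * 127) as a float.
  float y = (t - p) * 8388608.0f + 8388608.0f * 127.0f;  // 2^23 = 8388608
  int32_t bits = static_cast<int32_t>(y);
  float r;
  std::memcpy(&r, &bits, sizeof(r));
  return r;
}
```

The two boundary checks mirror the `vec_ln_flt_min`/`vec_ln_flt_max` masks in the vectorized kernels.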
Committed by: PyTorch MergeBot
Parent: 9222552572
Commit: b7860c7863
@@ -163,6 +163,9 @@ class Vectorized<BFloat16> {
  Vectorized<BFloat16> exp_u20() const {
    return exp();
  }
  Vectorized<BFloat16> fexp_u20() const {
    return exp();
  }
  Vectorized<BFloat16> fmod(const Vectorized<BFloat16>& q) const;
  Vectorized<BFloat16> hypot(const Vectorized<BFloat16>& b) const;
  Vectorized<BFloat16> i0() const;
@@ -249,6 +249,9 @@ class Vectorized<double> {
  Vectorized<double> exp_u20() const {
    return exp();
  }
  Vectorized<double> fexp_u20() const {
    return exp();
  }
  Vectorized<double> fmod(const Vectorized<double>& q) const {USE_SLEEF(
      { return Vectorized<double>(Sleef_fmoddx_sve(values, q)); },
      {
@@ -314,6 +314,9 @@ class Vectorized<float> {
  Vectorized<float> exp_u20() const {
    return exp();
  }
  Vectorized<float> fexp_u20() const {
    return exp();
  }
  Vectorized<float> fmod(const Vectorized<float>& q) const {USE_SLEEF(
      { return Vectorized<float>(Sleef_fmodfx_sve(values, q)); },
      {
@@ -308,6 +308,9 @@ class Vectorized<float> {
  Vectorized<float> exp_u20() const {
    return exp();
  }
  Vectorized<float> fexp_u20() const {
    return exp();
  }
  DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(
      fmod,
      Sleef_fmodf4)
@@ -206,6 +206,10 @@ struct Vectorized16 {
    return static_cast<const Derived*>(this)->map_with_vec_float_method(
        &Vectorized<float>::exp_u20);
  }
  Derived fexp_u20() const {
    return static_cast<const Derived*>(this)->map_with_vec_float_method(
        &Vectorized<float>::exp_u20);
  }
  Derived fmod(const Derived& q) const {
    // This function is questionable with a conversion, so we use map2
    return map2(q, std::fmod);
@@ -488,6 +488,9 @@ class Vectorized16 {
  Vectorized<T> expm1() const {
    return map(Sleef_expm1f8_u10);
  }
  Vectorized<T> fexp_u20() const {
    return exp();
  }
  Vectorized<T> exp_u20() const {
    return exp();
  }
@@ -198,6 +198,9 @@ class Vectorized<double> {
  Vectorized<double> exp_u20() const {
    return exp();
  }
  Vectorized<double> fexp_u20() const {
    return exp();
  }
  Vectorized<double> fmod(const Vectorized<double>& q) const {
    return Vectorized<double>(Sleef_fmodd4(values, q));
  }
@@ -1,5 +1,4 @@
#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

@@ -256,6 +255,63 @@ class Vectorized<float> {
  Vectorized<float> expm1() const {
    return Vectorized<float>(Sleef_expm1f8_u10(values));
  }
  Vectorized<float> fexp_u20() const {
    const __m256 vec_c0 = _mm256_set1_ps(0.00010703434948458272f);
    const __m256 vec_c1 = _mm256_set1_ps(0.30354260500649682f);
    const __m256 vec_c2 = _mm256_set1_ps(-0.22433836478672356);
    const __m256 vec_c3 = _mm256_set1_ps(-0.079204240219773236);

    const __m256 vec_exp_log2ef =
        _mm256_castsi256_ps(_mm256_set1_epi32(0x3fb8aa3b)); // log2(e)

    const __m256 vec_a = _mm256_set1_ps(std::pow(2, 23) / std::log2(2));
    const __m256 vec_b = _mm256_set1_ps(std::pow(2, 23) * 127.f);

    const __m256 vec_ln_flt_min =
        _mm256_castsi256_ps(_mm256_set1_epi32(0xc2aeac50));
    const __m256 vec_ln_flt_max =
        _mm256_castsi256_ps(_mm256_set1_epi32(0x42b17218));
    const __m256 vec_inf = _mm256_set1_ps(INFINITY);
    const __m256 zero = _mm256_setzero_ps();

    // exp(x) = 2**(x * log2(e))
    //        = 2**xi * 2**xf
    // The trick: use the IEEE floating point representation, mapping xi to
    // the exponent field and 2**xf to the mantissa.
    // 2**xf is approximated by a polynomial of degree 3 computed with the
    // Horner method.
    // Compute the min/max masks for the boundary conditions.
    __m256 mask_too_small =
        _mm256_cmp_ps(values, vec_ln_flt_min, _CMP_LT_OS); // x < ln(FLT_MIN)
    __m256 mask_too_large =
        _mm256_cmp_ps(values, vec_ln_flt_max, _CMP_GT_OS); // x > ln(FLT_MAX)

    // transformation with log2(e)
    auto vec_src = _mm256_mul_ps(values, vec_exp_log2ef);
    auto vec_fractional = _mm256_sub_ps(vec_src, _mm256_floor_ps(vec_src));

    // compute the polynomial using the Horner scheme
    auto vec_res = _mm256_fmadd_ps(vec_fractional, vec_c3, vec_c2);
    vec_res = _mm256_fmadd_ps(vec_fractional, vec_res, vec_c1);
    vec_res = _mm256_fmadd_ps(vec_fractional, vec_res, vec_c0);

    vec_src = _mm256_sub_ps(vec_src, vec_res);
    // the trick is here, headache in perspective
    auto tmp = _mm256_fmadd_ps(vec_a, vec_src, vec_b);
    // headache bis
    __m256i casted_integer = _mm256_cvttps_epi32(tmp);
    // bitwise reinterpretation to float for the final transformation
    auto result = _mm256_castsi256_ps(casted_integer);
    // boundary conditions
    // set to 0 where x < ln(FLT_MIN)
    result = _mm256_blendv_ps(result, zero, mask_too_small);
    // set to +inf where x > ln(FLT_MAX)
    result = _mm256_blendv_ps(result, vec_inf, mask_too_large);
    return result;
  }

  Vectorized<float> exp_u20() const {
    // A faster version of exp with ULP=20
    const __m256 vec_factorial_1 =
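A side note for readers checking the `vec_a`/`vec_b` constants above: they scale `t - p` into the IEEE-754 exponent field, relying on the fact that reinterpreting the integer `2^23 * (k + 127)` as a float yields `2^k`. A standalone sanity check (illustration only, not part of the patch):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  for (int k = -3; k <= 3; ++k) {
    int32_t bits = (1 << 23) * (k + 127);  // exponent field holds k + bias
    float v;
    std::memcpy(&v, &bits, sizeof(v));
    std::printf("k=%d -> %g\n", k, v);     // prints 2^k: 0.125 ... 8
  }
  return 0;
}
```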
@@ -273,6 +273,9 @@ class Vectorized<double> {
  Vectorized<double> C10_ALWAYS_INLINE exp_u20() const {
    return exp();
  }
  Vectorized<double> C10_ALWAYS_INLINE fexp_u20() const {
    return exp();
  }

  Vectorized<double> lgamma() const __ubsan_ignore_undefined__ {
    return {Sleef_lgammad2_u10(_vec0), Sleef_lgammad2_u10(_vec1)};
@@ -352,6 +352,9 @@ class Vectorized<float> {
  Vectorized<float> C10_ALWAYS_INLINE exp_u20() const {
    return exp();
  }
  Vectorized<float> C10_ALWAYS_INLINE fexp_u20() const {
    return exp();
  }

  Vectorized<float> C10_ALWAYS_INLINE log() const {
    return {Sleef_logf4_u10(_vec0), Sleef_logf4_u10(_vec1)};
@@ -1023,6 +1023,9 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented<T>()>> {
  Vectorized<T> exp_u20() const {
    return exp();
  }
  Vectorized<T> fexp_u20() const {
    return exp();
  }

  Vectorized<T> log() const {
    return mapSleef(Sleef_logf4_u10, Sleef_logd2_u10);
@@ -535,6 +535,9 @@ class Vectorized16 {
  Vectorized<T> expm1() const {
    return map(Sleef_expm1f16_u10);
  }
  Vectorized<T> fexp_u20() const {
    return exp();
  }
  Vectorized<T> exp_u20() const {
    return exp();
  }
@@ -221,6 +221,9 @@ class Vectorized<double> {
  Vectorized<double> exp_u20() const {
    return exp();
  }
  Vectorized<double> fexp_u20() const {
    return exp();
  }
  Vectorized<double> fmod(const Vectorized<double>& q) const {
    return Vectorized<double>(Sleef_fmodd8(values, q));
  }
@@ -310,6 +310,60 @@ class Vectorized<float> {
  Vectorized<float> expm1() const {
    return Vectorized<float>(Sleef_expm1f16_u10(values));
  }
  Vectorized<float> fexp_u20() const {
    const __m512 vec_c0 = _mm512_set1_ps(0.00010703434948458272f);
    const __m512 vec_c1 = _mm512_set1_ps(0.30354260500649682f);
    const __m512 vec_c2 = _mm512_set1_ps(-0.22433836478672356);
    const __m512 vec_c3 = _mm512_set1_ps(-0.079204240219773236);

    const __m512 vec_exp_log2ef =
        _mm512_castsi512_ps(_mm512_set1_epi32(0x3fb8aa3b)); // log2(e)

    const __m512 vec_a = _mm512_set1_ps(std::pow(2, 23) / std::log2(2));
    const __m512 vec_b = _mm512_set1_ps(std::pow(2, 23) * 127.f);

    const __m512 vec_ln_flt_min =
        _mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50));
    const __m512 vec_ln_flt_max =
        _mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218));
    __m512i vec_infinity = _mm512_set1_epi32(0x7F800000);
    __m512i vec_zero = _mm512_setzero_epi32();

    // Fast Exponential Computation on SIMD Architectures
    // A. Cristiano I. Malossi, Yves Ineichen, Costas Bekas, and Alessandro
    // Curioni
    // exp(x) = 2**(x * log2(e))
    //        = 2**xi * 2**xf
    // The trick: use the IEEE floating point representation, mapping xi to
    // the exponent field and 2**xf to the mantissa.
    // 2**xf is approximated by a polynomial of degree 3 computed with the
    // Horner method.
    // masks for the boundary conditions
    auto min_mask = _mm512_cmp_ps_mask(values, vec_ln_flt_min, _CMP_LT_OS);
    auto max_mask = _mm512_cmp_ps_mask(values, vec_ln_flt_max, _CMP_GT_OS);

    // transformation with log2(e)
    auto vec_src = _mm512_mul_ps(values, vec_exp_log2ef);
    auto vec_fractional = _mm512_sub_ps(vec_src, _mm512_floor_ps(vec_src));

    // compute the polynomial using the Horner scheme, friendly for a
    // superscalar processor
    auto vec_res = _mm512_fmadd_ps(vec_fractional, vec_c3, vec_c2);
    vec_res = _mm512_fmadd_ps(vec_fractional, vec_res, vec_c1);
    vec_res = _mm512_fmadd_ps(vec_fractional, vec_res, vec_c0);

    vec_src = _mm512_sub_ps(vec_src, vec_res);
    // the trick is here, headache in perspective
    auto tmp = _mm512_fmadd_ps(vec_a, vec_src, vec_b);
    // headache bis - we lose precision with the cast, but it is fine after
    // the later f32 -> f16 conversion
    __m512i casted_integer = _mm512_cvttps_epi32(tmp);
    // boundary condition: lower than the min -> 0
    casted_integer = _mm512_mask_mov_epi32(casted_integer, min_mask, vec_zero);
    // boundary condition: larger than the max -> +inf
    casted_integer =
        _mm512_mask_mov_epi32(casted_integer, max_mask, vec_infinity);
    // final reinterpretation to float
    return _mm512_castsi512_ps(casted_integer);
  }
  Vectorized<float> exp_u20() const {
    // A faster version of exp with ULP=20
    const __m512 vec_factorial_1 =
@@ -547,6 +547,9 @@ struct Vectorized {
  Vectorized<T> exp_u20() const {
    return map(std::exp);
  }
  Vectorized<T> fexp_u20() const {
    return map(std::exp);
  }
  Vectorized<T> frac() const {
    return *this - this->trunc();
  }
@@ -263,6 +263,7 @@ class VectorizedN {
  VECTORIZEDN_DEFINE_UNARY_OP(exp2)
  VECTORIZEDN_DEFINE_UNARY_OP(expm1)
  VECTORIZEDN_DEFINE_UNARY_OP(exp_u20)
  VECTORIZEDN_DEFINE_UNARY_OP(fexp_u20)
  VECTORIZEDN_DEFINE_UNARY_OP(frac)
  VECTORIZEDN_DEFINE_BINARY_OP(fmod)
  VECTORIZEDN_DEFINE_UNARY_OP(log)
@@ -96,7 +96,14 @@ inline void _exp_reduce_sum_fusion_kernel(
   for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
     auto tmp0 = vec::Vectorized<T1>::loadu(a + i);
     auto tmp1 = tmp0 - vec_max;
-    auto tmp2 = tmp1.exp_u20();
+    Vectorized<T1> tmp2;
+    if constexpr (std::is_same_v<T1, float> &&
+        (std::is_same_v<T2, at::BFloat16> || std::is_same_v<T2, at::Half>))
+    {
+      tmp2 = tmp1.fexp_u20();
+    } else {
+      tmp2 = tmp1.exp_u20();
+    }
     vec_tmp_sum += tmp2;
     _store(out + i, tmp2);
   }
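The fast path above is only taken when the accumulation type is `float` and the output type is `at::BFloat16` or `at::Half`; as the commit message notes, the fp32 -> 16-bit cast at store time absorbs the roughly 1-ULP error of `fexp_u20`. A standalone sketch of that rounding effect (illustration only; `truncate_to_bf16` is a hypothetical helper, and round-toward-zero is used for simplicity):

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Truncate a float to bfloat16 precision by dropping the low 16 mantissa
// bits (round-toward-zero; enough for this illustration).
static float truncate_to_bf16(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFF0000u;
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

int main() {
  float exact = std::exp(-2.5f);
  // Emulate a result that is off by one fp32 ULP, as fexp_u20 may be.
  float off_by_one_ulp = std::nextafterf(exact, INFINITY);
  std::printf("fp32: %.9g vs %.9g\n", exact, off_by_one_ulp);
  // After reducing to bf16 precision, the two values usually coincide.
  std::printf("bf16: %.9g vs %.9g\n",
              truncate_to_bf16(exact), truncate_to_bf16(off_by_one_ulp));
  return 0;
}
```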