mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[PyTorch] Pull ARM's box-cox (#164152)
Summary: ARM has provided with an SVE128 box-cox implementation. It uses the same underlying algorithm as the previous version, but it has better log and exp implementations. These supplied mathematical functions have switches to adjust the precision/speed trade-off. We've noted a slight precision improvement, while also about a 5% peroformance increase Before: ZeroLambda1 61.66ns 16.22M NonZeroLambda1 125.73ns 7.95M NonZeroLambdaManyColumns 1.84ms 542.11 NonZeroLambdaEigenColumnar 262.31us 3.81K NonZeroLambdaEigenRowMajor 275.17us 3.63K NonZeroLambdaWithPyTorchColumnar 97.43us 10.26K NonZeroLambdaWithPyTorchRowMajor 90.82us 11.01K NonZeroLambdaWithPyTorchRowMajorFullBatch 96.96us 10.31K NonZeroLambdaBatch 151.84us 6.59K After: ZeroLambda1 57.85ns 17.29M NonZeroLambda1 118.85ns 8.41M NonZeroLambdaManyColumns 1.82ms 548.16 NonZeroLambdaEigenColumnar 261.67us 3.82K NonZeroLambdaEigenRowMajor 274.53us 3.64K NonZeroLambdaWithPyTorchColumnar 89.12us 11.22K NonZeroLambdaWithPyTorchRowMajor 83.49us 11.98K NonZeroLambdaWithPyTorchRowMajorFullBatch 88.79us 11.26K NonZeroLambdaBatch 144.74us 6.91K Test Plan: Correctness: buck2 test @//mode/opt //koski/functions_contrib/df4ai/tests:batch_box_cox_test Performance: buck2 run @//mode/opt //koski/functions_contrib/df4ai/benchmark:boxcox_benchmark Differential Revision: D83485704 Privacy Context Container: L1196524 Pull Request resolved: https://github.com/pytorch/pytorch/pull/164152 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
e901866dd7
commit
31681bcacc
@ -2,175 +2,126 @@
|
||||
#include <arm_neon.h>
|
||||
#include <arm_neon_sve_bridge.h>
|
||||
#include <arm_sve.h>
|
||||
#include <cfloat>
|
||||
#include <cmath>
|
||||
|
||||
#include "c10/macros/Macros.h"
|
||||
|
||||
// Log and exp approximations inspired from ACL implementation
|
||||
/// Select `svlog` accuracy:
|
||||
/// - 0: original.
|
||||
/// - 1: more accurate, similar performance.
|
||||
/// - 2: very high accuracy, a bit lower speed.
|
||||
#define SVLOG_ACCURACY 2
|
||||
|
||||
inline float32x4_t vtaylor_polyq_for_log_f32(float32x4_t x) {
|
||||
const float32x4_t log_tab_1 = vdupq_n_f32(-2.29561495781f);
|
||||
const float32x4_t log_tab_2 = vdupq_n_f32(-2.47071170807f);
|
||||
const float32x4_t log_tab_3 = vdupq_n_f32(-5.68692588806f);
|
||||
const float32x4_t log_tab_4 = vdupq_n_f32(-0.165253549814f);
|
||||
const float32x4_t log_tab_5 = vdupq_n_f32(5.17591238022f);
|
||||
const float32x4_t log_tab_6 = vdupq_n_f32(0.844007015228f);
|
||||
const float32x4_t log_tab_7 = vdupq_n_f32(4.58445882797f);
|
||||
const float32x4_t log_tab_8 = vdupq_n_f32(0.0141278216615f);
|
||||
/// Handle special cases in `svexp`:
|
||||
/// - 0: original.
|
||||
/// - 1: use clamp, better performance.
|
||||
/// - 2: no special case handling.
|
||||
#define SVEXP_SPECIAL_CLAMP 1
|
||||
|
||||
float32x4_t A = vmlaq_f32(log_tab_1, log_tab_5, x);
|
||||
float32x4_t B = vmlaq_f32(log_tab_3, log_tab_7, x);
|
||||
float32x4_t C = vmlaq_f32(log_tab_2, log_tab_6, x);
|
||||
float32x4_t x2 = vmulq_f32(x, x);
|
||||
float32x4_t D = svget_neonq(svmad_f32_x(
|
||||
svptrue_b8(),
|
||||
svset_neonq(svundef_f32(), x),
|
||||
svset_neonq(svundef_f32(), log_tab_8),
|
||||
svset_neonq(svundef_f32(), log_tab_4)));
|
||||
float32x4_t x4 = vmulq_f32(x2, x2);
|
||||
float32x4_t res = vmlaq_f32(vmlaq_f32(A, B, x2), vmlaq_f32(C, D, x2), x4);
|
||||
return res;
|
||||
#if SVLOG_ACCURACY == 2
|
||||
static inline svfloat32_t svlog(svfloat32_t x) {
|
||||
const svbool_t ptrue = svptrue_b8();
|
||||
|
||||
svint32_t u = svreinterpret_s32(x) - 0x3F2AAAAB;
|
||||
|
||||
svfloat32_t r = svreinterpret_f32((u & 0x007FFFFF) + 0x3F2AAAAB) - 1.0f;
|
||||
svfloat32_t n = svcvt_f32_x(ptrue, u >> 23);
|
||||
asm("" : "+w"(r)); // NOTE: can improve instruction scheduling.
|
||||
|
||||
svfloat32_t r2 = r * r;
|
||||
svfloat32_t p = -0x1.4F9934p-3f + r * 0x1.5A9AA2p-3f;
|
||||
svfloat32_t q = -0x1.00187Cp-2f + r * 0x1.961348p-3f;
|
||||
svfloat32_t y = -0x1.FFFFC8p-2f + r * 0x1.555D7Cp-2f;
|
||||
return (r + n * 0x1.62E43p-1f) +
|
||||
(y + (q + (p + -0x1.3E737Cp-3f * r2) * r2) * r2) * r2;
|
||||
}
|
||||
#elif SVLOG_ACCURACY == 1
|
||||
static inline svfloat32_t svlog(svfloat32_t x) {
|
||||
const svbool_t ptrue = svptrue_b8();
|
||||
|
||||
inline float32x4_t vlogq_f32(float32x4_t x) {
|
||||
const float32x4_t CONST_LN2 = vdupq_n_f32(0.6931471805f); // ln(2)
|
||||
svint32_t u = svreinterpret_s32(x) - 0x3F2AAAAB;
|
||||
|
||||
// Extract exponent
|
||||
int32x4_t m = svget_neonq(svsub_n_s32_x(
|
||||
svptrue_b8(),
|
||||
svset_neonq(
|
||||
svundef_s32(),
|
||||
vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_f32(x), 23))),
|
||||
127));
|
||||
float32x4_t val = vreinterpretq_f32_s32(
|
||||
vsubq_s32(vreinterpretq_s32_f32(x), vshlq_n_s32(m, 23)));
|
||||
svfloat32_t r = svreinterpret_f32((u & 0x007FFFFF) + 0x3F2AAAAB) - 1.0f;
|
||||
svfloat32_t n = svcvt_f32_x(ptrue, u >> 23);
|
||||
asm("" : "+w"(r)); // NOTE: can improve instruction scheduling.
|
||||
|
||||
// Polynomial Approximation
|
||||
float32x4_t poly = vtaylor_polyq_for_log_f32(val);
|
||||
svfloat32_t r2 = r * r;
|
||||
svfloat32_t A = -0x1.923814p-3f + r * 0x1.689E5Ep-3f;
|
||||
svfloat32_t B = -0x1.FC0968p-3f + r * 0x1.93BF0Cp-3f;
|
||||
svfloat32_t C = -0x1.000478p-1f + r * 0x1.556906p-2f;
|
||||
|
||||
// Reconstruct
|
||||
poly = vmlaq_f32(poly, vcvtq_f32_s32(m), CONST_LN2);
|
||||
return (r + n * 0x1.62E43p-1f) + (C + (B + A * r2) * r2) * r2;
|
||||
}
|
||||
#elif SVLOG_ACCURACY == 0
|
||||
static inline svfloat32_t svlog(svfloat32_t x) {
|
||||
const svbool_t ptrue = svptrue_b8();
|
||||
|
||||
svint32_t u = svsra_n_s32(svdup_n_s32(-127), svreinterpret_s32(x), 23);
|
||||
|
||||
svfloat32_t n = svcvt_f32_x(ptrue, u);
|
||||
svfloat32_t r = svreinterpret_f32(svreinterpret_s32(x) - (u << 23));
|
||||
|
||||
svfloat32_t D = -0.165253549814f + r * 0.0141278216615f;
|
||||
svfloat32_t C = -2.47071170807f + r * 0.844007015228f;
|
||||
svfloat32_t B = -5.68692588806f + r * 4.58445882797f;
|
||||
svfloat32_t A = -2.29561495781f + r * 5.17591238022f;
|
||||
|
||||
svfloat32_t r2 = r * r;
|
||||
return (A + n * 0.6931471805f) + (B + (C + D * r2) * r2) * r2;
|
||||
}
|
||||
#endif
|
||||
|
||||
static inline svfloat32_t svexp(svfloat32_t x) {
|
||||
// Clamp interval set to prevent denormals!
|
||||
const svfloat32_t max_input = svdup_n_f32(88.722839f);
|
||||
const svfloat32_t min_input = svdup_n_f32(-87.33654f);
|
||||
const svfloat32_t shift = svdup_n_f32(0x1.0000FEp+23f);
|
||||
const svbool_t ptrue = svptrue_b8();
|
||||
|
||||
#if SVEXP_SPECIAL_CLAMP == 1
|
||||
x = svmax_x(ptrue, svmin_x(ptrue, x, max_input), min_input);
|
||||
#endif
|
||||
|
||||
svfloat32_t z = svmla_n_f32_x(ptrue, shift, x, 0x1.715476p+0f);
|
||||
svfloat32_t n = z - shift;
|
||||
svfloat32_t scale = svreinterpret_f32(svreinterpret_u32(z) << 23);
|
||||
|
||||
svfloat32_t r_hi = x - n * 0x1.62E400p-1f;
|
||||
svfloat32_t r = r_hi - n * 0x1.7F7D1Cp-20f;
|
||||
svfloat32_t r2 = r * r;
|
||||
|
||||
svfloat32_t C = 0x1.573E2Ep-5f + r * 0x1.0E4020p-7f;
|
||||
svfloat32_t B = 0x1.FFFDB6p-2f + r * 0x1.555E66p-3f;
|
||||
svfloat32_t A = r * 0x1.FFFFECp-1f;
|
||||
|
||||
svfloat32_t poly = scale + (A + (B + C * r2) * r2) * scale;
|
||||
|
||||
#if SVEXP_SPECIAL_CLAMP == 0
|
||||
const svfloat32_t inf = svdup_n_f32(std::numeric_limits<float>::infinity());
|
||||
poly = svsel_f32(svcmplt_f32(ptrue, x, min_input), svdup_n_f32(0.0f), poly);
|
||||
poly = svsel_f32(svcmpgt_f32(ptrue, x, max_input), inf, poly);
|
||||
#endif
|
||||
|
||||
return poly;
|
||||
}
|
||||
|
||||
inline float32x4_t vexpq_f32(float32x4_t x) {
|
||||
const auto c1 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3f7ffff6)));
|
||||
const auto c2 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3efffedb)));
|
||||
const auto c3 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3e2aaf33)));
|
||||
const auto c4 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3d2b9f17)));
|
||||
const auto c5 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3c072010)));
|
||||
|
||||
const auto shift = vreinterpretq_f32_u32(
|
||||
svget_neonq(svdup_n_u32(0x4b00007f))); // 2^23 + 127 = 0x1.0000fep23f
|
||||
const auto inv_ln2 = vreinterpretq_f32_u32(
|
||||
svget_neonq(svdup_n_u32(0x3fb8aa3b))); // 1 / ln(2) = 0x1.715476p+0f
|
||||
const auto neg_ln2_hi = vreinterpretq_f32_u32(svget_neonq(
|
||||
svdup_n_u32(0xbf317200))); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f
|
||||
const auto neg_ln2_lo = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(
|
||||
0xb5bfbe8e))); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f
|
||||
|
||||
const auto inf = svdup_n_f32(std::numeric_limits<float>::infinity());
|
||||
const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5)
|
||||
const auto zero = svdup_n_f32(0.f);
|
||||
const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125)
|
||||
|
||||
// Range reduction:
|
||||
// e^x = 2^n * e^r
|
||||
// where:
|
||||
// n = floor(x / ln(2))
|
||||
// r = x - n * ln(2)
|
||||
//
|
||||
// By adding x / ln(2) with 2^23 + 127 (shift):
|
||||
// * As FP32 fraction part only has 23-bits, the addition of 2^23 + 127
|
||||
// forces decimal part
|
||||
// of x / ln(2) out of the result. The integer part of x / ln(2) (i.e. n)
|
||||
// + 127 will occupy the whole fraction part of z in FP32 format.
|
||||
// Subtracting 2^23 + 127 (shift) from z will result in the integer part
|
||||
// of x / ln(2) (i.e. n) because the decimal part has been pushed out and
|
||||
// lost.
|
||||
// * The addition of 127 makes the FP32 fraction part of z ready to be used
|
||||
// as the exponent
|
||||
// in FP32 format. Left shifting z by 23 bits will result in 2^n.
|
||||
const auto z = vfmaq_f32(shift, x, inv_ln2);
|
||||
const auto n = z - shift;
|
||||
const auto scale =
|
||||
vreinterpretq_f32_u32(vreinterpretq_u32_f32(z) << 23); // 2^n
|
||||
|
||||
// The calculation of n * ln(2) is done using 2 steps to achieve accuracy
|
||||
// beyond FP32. This outperforms longer Taylor series (3-4 tabs) both in term
|
||||
// of accuracy and performance.
|
||||
const auto r_hi = vfmaq_f32(x, n, neg_ln2_hi);
|
||||
const auto r = vfmaq_f32(r_hi, n, neg_ln2_lo);
|
||||
|
||||
// Compute the truncated Taylor series of e^r.
|
||||
// poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5)
|
||||
const auto r2 = r * r;
|
||||
|
||||
const auto p1 = c1 * r;
|
||||
const auto p23 = vfmaq_f32(c2, c3, r);
|
||||
const auto p45 = vfmaq_f32(c4, c5, r);
|
||||
const auto p2345 = vfmaq_f32(p23, p45, r2);
|
||||
const auto p12345 = vfmaq_f32(p1, p2345, r2);
|
||||
|
||||
auto poly = svset_neonq(svundef_f32(), vfmaq_f32(scale, p12345, scale));
|
||||
|
||||
auto pHigh = svcmpgt_f32(svptrue_b8(), svset_neonq(svundef_f32(), x), max_input);
|
||||
auto pLow = svcmplt_f32(svptrue_b8(), svset_neonq(svundef_f32(), x), min_input);
|
||||
|
||||
auto bound = svsel_f32(
|
||||
pHigh,
|
||||
inf,
|
||||
zero);
|
||||
|
||||
auto pCombined = svorr_b_z(svptrue_b8(), pLow, pHigh);
|
||||
|
||||
// Handle underflow and overflow.
|
||||
poly = svsel_f32(
|
||||
pCombined,
|
||||
bound,
|
||||
poly);
|
||||
|
||||
return svget_neonq(poly);
|
||||
}
|
||||
|
||||
// ln(x) = log2(x) * ln(2)
|
||||
// pow(x, n) = exp(n * ln(x))
|
||||
inline float32x4_t compute_batch_box_cox_vec_sve128_float(
|
||||
static inline svfloat32_t compute_batch_box_cox_vec_sve128_float(
|
||||
svfloat32_t lambda1_v,
|
||||
svfloat32_t lambda2_v,
|
||||
svfloat32_t data_v,
|
||||
svfloat32_t k_eps) {
|
||||
// sum_v = lambda2_v + data_v
|
||||
float32x4_t sum_v = vaddq_f32(svget_neonq(data_v), svget_neonq(lambda2_v));
|
||||
const svbool_t ptrue = svptrue_b8();
|
||||
|
||||
// test lambda1_v: predNZ == 1 iff lambda1_v != 0
|
||||
svbool_t predNZ = svcmpne_n_f32(svptrue_b8(), lambda1_v, 0.0f);
|
||||
|
||||
// clamp sum_v: sum_v = max(sum_v, k_eps)
|
||||
sum_v = vmaxq_f32(sum_v, svget_neonq(k_eps));
|
||||
|
||||
// lnData = log(sum_v)
|
||||
svfloat32_t lnData = svset_neonq(svundef_f32(), vlogq_f32(sum_v));
|
||||
|
||||
// if any lambda1 != 0, compute pow(sum_v, lambda1) using lnData
|
||||
// pow(sum_v, lambda1) == exp(lambda1 * ln(sum_v))
|
||||
svfloat32_t lnData = svlog(svmax_x(ptrue, data_v + lambda2_v, k_eps));
|
||||
svbool_t predNZ = svcmpne_n_f32(ptrue, lambda1_v, 0.0f);
|
||||
if (C10_LIKELY(svptest_any(predNZ, predNZ))) {
|
||||
// mult = lambda1 * ln(sum_v)
|
||||
float32x4_t mult = vmulq_f32(svget_neonq(lnData), svget_neonq(lambda1_v));
|
||||
|
||||
// lambda1_r = 1 / lambda1
|
||||
svfloat32_t lambda1_r = svdivr_f32_m(predNZ, lambda1_v, svdup_n_f32(1.0f));
|
||||
|
||||
// pow = exp(mult)
|
||||
float32x4_t pow = vexpq_f32(mult);
|
||||
|
||||
// merge results
|
||||
// lnData if lambda1 == 0, (lambda1_r * pow - lambda1_r) if lambda1 != 0
|
||||
svfloat32_t pow = svexp(lnData * lambda1_v);
|
||||
lnData = svsel_f32(predNZ, lambda1_r, lnData);
|
||||
lnData =
|
||||
svnmsb_f32_m(predNZ, lnData, svset_neonq(svundef_f32(), pow), lnData);
|
||||
lnData = svnmsb_f32_m(predNZ, lnData, pow, lnData);
|
||||
}
|
||||
return svget_neonq(lnData);
|
||||
return lnData;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -186,11 +137,11 @@ template <>
|
||||
void compute_batch_box_cox_vec_sve128(
|
||||
std::size_t N,
|
||||
std::size_t D,
|
||||
const float* data_ptr,
|
||||
const float* __restrict lambda1_ptr,
|
||||
const float* __restrict lambda2_ptr,
|
||||
float* output_ptr) {
|
||||
svfloat32_t k_eps = svdup_n_f32(static_cast<float>(1e-6));
|
||||
const float *data_ptr,
|
||||
const float *__restrict lambda1_ptr,
|
||||
const float *__restrict lambda2_ptr,
|
||||
float *output_ptr) {
|
||||
const svfloat32_t k_eps = svdup_n_f32(static_cast<float>(1e-6));
|
||||
|
||||
std::size_t remainder = D % 4;
|
||||
std::size_t loopBound = D - remainder;
|
||||
@ -204,17 +155,17 @@ void compute_batch_box_cox_vec_sve128(
|
||||
svfloat32_t lambda2_v =
|
||||
svset_neonq(svundef_f32(), vld1q_f32(lambda2_ptr + j));
|
||||
svfloat32_t data_v = svset_neonq(svundef_f32(), vld1q_f32(data_ptr));
|
||||
float32x4_t result = compute_batch_box_cox_vec_sve128_float(
|
||||
svfloat32_t result = compute_batch_box_cox_vec_sve128_float(
|
||||
lambda1_v, lambda2_v, data_v, k_eps);
|
||||
vst1q_f32(output_ptr, result);
|
||||
vst1q_f32(output_ptr, svget_neonq(result));
|
||||
}
|
||||
if (C10_LIKELY(remainder > 0)) {
|
||||
svfloat32_t lambda1_v = svld1_f32(remainderPred, lambda1_ptr + loopBound);
|
||||
svfloat32_t lambda2_v = svld1_f32(remainderPred, lambda2_ptr + loopBound);
|
||||
svfloat32_t data_v = svld1_f32(remainderPred, data_ptr);
|
||||
float32x4_t result = compute_batch_box_cox_vec_sve128_float(
|
||||
svfloat32_t result = compute_batch_box_cox_vec_sve128_float(
|
||||
lambda1_v, lambda2_v, data_v, k_eps);
|
||||
svst1_f32(remainderPred, output_ptr, svset_neonq(svundef_f32(), result));
|
||||
svst1_f32(remainderPred, output_ptr, result);
|
||||
data_ptr += remainder;
|
||||
output_ptr += remainder;
|
||||
}
|
||||
|
Reference in New Issue
Block a user