Add vec_reduce_all specialization for std::plus on AArch64 (#152388)

AArch64 has a horizontal-add intrinsic (vaddvq_f32) for this reduction.
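
For reference, a minimal sketch of the intrinsic involved (NEON arm_neon.h, AArch64 only); the helper name below is illustrative and not part of this change:

#include <arm_neon.h>

// vaddvq_f32 sums all four float lanes of a 128-bit vector with a single
// horizontal add, e.g. {1.f, 2.f, 3.f, 4.f} -> 10.f.
inline float horizontal_sum(float32x4_t v) {
  return vaddvq_f32(v);
}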

Differential Revision: [D73817183](https://our.internmc.facebook.com/intern/diff/D73817183/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152388
Approved by: https://github.com/Skylion007
ghstack dependencies: #152365, #152366
Scott Wolchok
2025-05-13 16:04:13 -07:00
committed by PyTorch MergeBot
parent b972435158
commit ea17cd067d
2 changed files with 9 additions and 8 deletions


@@ -131,6 +131,15 @@ struct VecReduceAllSIMD<float, Op> {
     return v[0];
   }
 };
+
+template <>
+struct VecReduceAllSIMD<float, std::plus<Vectorized<float>>> {
+  static inline float apply(
+      const std::plus<Vectorized<float>>& vec_fun,
+      const Vectorized<float>& acc_vec) {
+    return vaddvq_f32(acc_vec);
+  }
+};
 #endif // defined(__aarch64__)
 
 #if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
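
A rough usage sketch (not part of the diff; the ATen include paths and the at::vec namespace spelling are assumed): with this specialization in place, callers can use the generic reduction and still get the single-instruction path on AArch64.

#include <functional>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>

// Illustrative helper, not part of this change: on AArch64 the call below
// now resolves to the VecReduceAllSIMD<float, std::plus<...>> specialization
// above and reduces to a vaddvq_f32.
inline float sum_lanes(const at::vec::Vectorized<float>& v) {
  return at::vec::vec_reduce_all<float>(
      std::plus<at::vec::Vectorized<float>>(), v);
}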


@@ -74,13 +74,9 @@ float reduce(vec::VectorizedN<Half, kF16RegistersPerIteration>& x) {
     }
   });
   const auto [t0, t1] = vec::convert_half_float(x[0]);
-#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE)
-  return vaddvq_f32(t0 + t1);
-#else
   return vec::vec_reduce_all<float>(
       std::plus<vec::Vectorized<float>>(),
       t0 + t1);
-#endif
 }
 
 float fp16_dot_with_fp16_arith(const Half* x, const Half* a, int len) {
@@ -130,13 +126,9 @@ static void fp16_gemv_trans_fp16_arith_by_dot_products(const int m, const int n,
 #endif // !defined(__aarch64__) || defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC)
 
 float reduce(vec::Vectorized<float> x) {
-#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE)
-  return vaddvq_f32(x);
-#else
   return vec::vec_reduce_all<float>(
       std::plus<vec::Vectorized<float>>(),
       x);
-#endif
 }
 
 // The below reduce overload and fp16_dot_with_fp32_arith are adapted