mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add vec_reduce_all specialization for std::plus on AArch64 (#152388)
AArch64 has an instruction for this. Differential Revision: [D73817183](https://our.internmc.facebook.com/intern/diff/D73817183/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/152388 Approved by: https://github.com/Skylion007 ghstack dependencies: #152365, #152366
This commit (ea17cd067d, parent b972435158) was committed by PyTorch MergeBot.
@ -131,6 +131,15 @@ struct VecReduceAllSIMD<float, Op> {
|
||||
return v[0];
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct VecReduceAllSIMD<float, std::plus<Vectorized<float>>> {
|
||||
static inline float apply(
|
||||
const std::plus<Vectorized<float>>& vec_fun,
|
||||
const Vectorized<float>& acc_vec) {
|
||||
return vaddvq_f32(acc_vec);
|
||||
}
|
||||
};
|
||||
#endif // defined(__aarch64__)
|
||||
|
||||
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
|
||||
|
@ -74,13 +74,9 @@ float reduce(vec::VectorizedN<Half, kF16RegistersPerIteration>& x) {
|
||||
}
|
||||
});
|
||||
const auto [t0, t1] = vec::convert_half_float(x[0]);
|
||||
#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE)
|
||||
return vaddvq_f32(t0 + t1);
|
||||
#else
|
||||
return vec::vec_reduce_all<float>(
|
||||
std::plus<vec::Vectorized<float>>(),
|
||||
t0 + t1);
|
||||
#endif
|
||||
}
|
||||
|
||||
float fp16_dot_with_fp16_arith(const Half* x, const Half* a, int len) {
|
||||
@ -130,13 +126,9 @@ static void fp16_gemv_trans_fp16_arith_by_dot_products(const int m, const int n,
|
||||
#endif // !defined(__aarch64__) || defined( __ARM_FEATURE_FP16_SCALAR_ARITHMETIC)
|
||||
|
||||
float reduce(vec::Vectorized<float> x) {
|
||||
#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE)
|
||||
return vaddvq_f32(x);
|
||||
#else
|
||||
return vec::vec_reduce_all<float>(
|
||||
std::plus<vec::Vectorized<float>>(),
|
||||
x);
|
||||
#endif
|
||||
}
|
||||
|
||||
// The below reduce overload and fp16_dot_with_fp32_arith are adapted
|
||||
|
Reference in New Issue
Block a user