efficient zero_mask implementation for vec128_*_neon (#155766)

Differential Revision: [D76481039](https://our.internmc.facebook.com/intern/diff/D76481039/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155766
Approved by: https://github.com/malfet
This commit is contained in:
Scott Wolchok
2025-07-03 09:44:05 -07:00
committed by PyTorch MergeBot
parent b359571c60
commit ad86c05b78
2 changed files with 33 additions and 13 deletions

View File

@ -202,18 +202,14 @@ class Vectorized<float> {
store(tmp);
return tmp[idx];
}
// For boolean version where we want to if any 1/all zero
// etc. can be done faster in a different way.
int zero_mask() const {
__at_align__ float tmp[size()];
store(tmp);
int mask = 0;
for (int i = 0; i < size(); ++i) {
if (tmp[i] == 0.f) {
mask |= (1 << i);
}
}
return mask;
uint32x4_t is_zero_vec = vceqzq_f32(values);
const int32x4_t shift = vcombine_s32(
vcreate_s32(0x0 | (int64_t(0x1) << 32)),
vcreate_s32(0x2 | (int64_t(0x3) << 32)));
uint32x4_t bits_vec =
vshlq_u32(vandq_u32(is_zero_vec, vdupq_n_u32(1)), shift);
return vaddvq_u32(bits_vec);
}
Vectorized<float> isnan() const {
return vreinterpretq_f32_u32(vmvnq_u32(vceqq_f32(values, values)));

View File

@ -220,8 +220,32 @@ class Vectorized<c10::Half> : public Vectorized16<
std::memcpy(ptr, tmp_values, count * sizeof(float16_t));
}
}
// For boolean version where we want to if any 1/all zero
// etc. can be done faster in a different way.
int zero_mask() const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
uint16x8_t is_zero_vec = vceqzq_f16(values);
const int16x8_t shift = vcombine_s16(
vcreate_s16(
0x0 | (int64_t(0x1) << 16) | (int64_t(0x2) << 32) |
(int64_t(0x3) << 48)),
vcreate_s16(
0x4 | (int64_t(0x5) << 16) | (int64_t(0x6) << 32) |
(int64_t(0x7) << 48)));
uint16x8_t bits_vec =
vshlq_u16(vandq_u16(is_zero_vec, vdupq_n_u16(1)), shift);
return vaddvq_u16(bits_vec);
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
// use known working implmentation.
__at_align__ value_type tmp[size()];
store(tmp);
int mask = 0;
for (int i = 0; i < size(); ++i) {
if (tmp[i] == 0) {
mask |= (1 << i);
}
}
return mask;
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
}
Vectorized<c10::Half> isnan() const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
return vreinterpretq_f16_u16(vmvnq_u16(vceqq_f16(values, values)));