mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
efficient zero_mask implementation for vec128_*_neon (#155766)
Differential Revision: [D76481039](https://our.internmc.facebook.com/intern/diff/D76481039/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/155766 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
b359571c60
commit
ad86c05b78
@ -202,18 +202,14 @@ class Vectorized<float> {
|
||||
store(tmp);
|
||||
return tmp[idx];
|
||||
}
|
||||
// For boolean version where we want to if any 1/all zero
|
||||
// etc. can be done faster in a different way.
|
||||
int zero_mask() const {
|
||||
__at_align__ float tmp[size()];
|
||||
store(tmp);
|
||||
int mask = 0;
|
||||
for (int i = 0; i < size(); ++i) {
|
||||
if (tmp[i] == 0.f) {
|
||||
mask |= (1 << i);
|
||||
}
|
||||
}
|
||||
return mask;
|
||||
uint32x4_t is_zero_vec = vceqzq_f32(values);
|
||||
const int32x4_t shift = vcombine_s32(
|
||||
vcreate_s32(0x0 | (int64_t(0x1) << 32)),
|
||||
vcreate_s32(0x2 | (int64_t(0x3) << 32)));
|
||||
uint32x4_t bits_vec =
|
||||
vshlq_u32(vandq_u32(is_zero_vec, vdupq_n_u32(1)), shift);
|
||||
return vaddvq_u32(bits_vec);
|
||||
}
|
||||
Vectorized<float> isnan() const {
|
||||
return vreinterpretq_f32_u32(vmvnq_u32(vceqq_f32(values, values)));
|
||||
|
@ -220,8 +220,32 @@ class Vectorized<c10::Half> : public Vectorized16<
|
||||
std::memcpy(ptr, tmp_values, count * sizeof(float16_t));
|
||||
}
|
||||
}
|
||||
// For boolean version where we want to if any 1/all zero
|
||||
// etc. can be done faster in a different way.
|
||||
// Returns an int bitmask with bit i set iff lane i of `values` equals zero.
// Callers use this for boolean-style queries (any lane set / all lanes
// zero). Uses native fp16 NEON arithmetic when available; otherwise falls
// back to a scalar store-and-test loop.
int zero_mask() const {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  // Lane-wise compare against zero: each 16-bit lane becomes all-ones
  // when values[i] == 0, else all-zeros.
  uint16x8_t is_zero_vec = vceqzq_f16(values);
  // Per-lane left-shift amounts {0, 1, ..., 7}: lane i contributes bit i.
  const int16x8_t shift = vcombine_s16(
      vcreate_s16(
          0x0 | (int64_t(0x1) << 16) | (int64_t(0x2) << 32) |
          (int64_t(0x3) << 48)),
      vcreate_s16(
          0x4 | (int64_t(0x5) << 16) | (int64_t(0x6) << 32) |
          (int64_t(0x7) << 48)));
  // Reduce each lane to 0/1, shift it to its bit position, then
  // horizontally add across lanes to form the final mask.
  uint16x8_t bits_vec =
      vshlq_u16(vandq_u16(is_zero_vec, vdupq_n_u16(1)), shift);
  return vaddvq_u16(bits_vec);
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  // No native fp16 vector arithmetic: use the known working implementation
  // (spill the vector to a stack buffer and test each lane in turn).
  __at_align__ value_type tmp[size()];
  store(tmp);
  int mask = 0;
  for (int i = 0; i < size(); ++i) {
    if (tmp[i] == 0) {
      mask |= (1 << i);
    }
  }
  return mask;
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
}
|
||||
Vectorized<c10::Half> isnan() const {
|
||||
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
return vreinterpretq_f16_u16(vmvnq_u16(vceqq_f16(values, values)));
|
||||
|
Reference in New Issue
Block a user