Fix compilation on aarch64 with gcc (#124511)

gcc is more stringent than clang when equivalently sized NEON registers are cast to each other: at one point a `uint16x4_t` was cast to an `int16x4_t`, which gcc does not allow. Added `vreinterpret_s16_u16` (a no-op at runtime) to fix this; verified at https://godbolt.org/z/sYb4ThM6M
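
A minimal sketch of the incompatibility (hypothetical helper name, assuming an aarch64 target with `<arm_neon.h>` available):

```cpp
#include <arm_neon.h>

int16x4_t to_signed(uint16x4_t v, int16x4_t offs) {
  // gcc rejects the implicit uint16x4_t -> int16x4_t conversion that clang tolerates,
  // e.g. passing `v` directly as the first argument of vsub_s16.
  // vreinterpret_s16_u16 reinterprets the same 64-bit register, so it costs nothing at runtime.
  return vsub_s16(vreinterpret_s16_u16(v), offs);
}
```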

Test plan: Build aarch64 wheels
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124511
Approved by: https://github.com/mikekgfb
Author: Nikita Shulga
Date: 2024-04-19 19:53:19 +00:00
Committed by: PyTorch MergeBot
parent 179108f14d
commit e6a788ac26

@@ -354,8 +354,8 @@ inline void tinygemm_kernel(
     int BLOCK_K) {
   int16_t shift_vals[4] = {0, -4, -8, -12};
   int16x4_t shifts = vld1_s16(shift_vals);
-  int16x4_t mask = vdup_n_s16(0x0F);
   int16x4_t offs = vdup_n_s16(8);
+  uint16x4_t mask = vdup_n_u16(0x0F);
   for (const auto m : c10::irange(BLOCK_M)) {
     for (int n = 0; n < BLOCK_N; n+= 16) {
       float32x4_t c_val[4];
@@ -375,7 +375,8 @@
       }
       c10::ForcedUnroll<4>{}([&](auto i) {
         uint16_t b_pack = reinterpret_cast<const uint16_t*>(B + k * ldb + n / 2)[i];
-        int16x4_t b_ints = vsub_s16(vand_u16(vshl_u16(vdup_n_u16(b_pack), shifts), mask), offs);
+        uint16x4_t b_masked = vand_u16(vshl_u16(vdup_n_u16(b_pack), shifts), mask);
+        int16x4_t b_ints = vsub_s16(vreinterpret_s16_u16(b_masked), offs);
         float32x4_t b_vals = vcvtq_f32_s32(vmovl_s16(b_ints));
         b_vals = vaddq_f32(zeros[i], vmulq_f32(scales[i], b_vals));
         c_val[i] = vfmaq_f32(c_val[i], b_vals, a_val);
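
For reference, a scalar sketch (not part of the patch; the helper name is hypothetical) of what each of the four lanes computes above: extract one 4-bit value from the packed `uint16_t` and shift it from [0, 15] into the signed range [-8, 7], matching the mask-then-reinterpret-then-subtract sequence in the new code.

```cpp
#include <cstdint>

// Lane i of the NEON code: vshl_u16 with shifts {0, -4, -8, -12} (negative counts
// shift right), vand_u16 with 0x0F, then subtract the offset 8 as a signed value.
inline int16_t unpack_int4_lane(uint16_t b_pack, int lane /* 0..3 */) {
  uint16_t masked = static_cast<uint16_t>((b_pack >> (4 * lane)) & 0x0F);
  return static_cast<int16_t>(masked) - 8;
}
```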