mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[aarch64] Fix ATen-cpu aarch64 builds (#84294)
Summary: Fix ATen-cpu aarch64 builds and hook up cpukernel_neon. Test Plan: CI. Differential Revision: D39142670. Pull Request resolved: https://github.com/pytorch/pytorch/pull/84294. Approved by: https://github.com/ajtulloch
This commit is contained in:
committed by
PyTorch MergeBot
parent
5e5c610a58
commit
0fb1495512
@ -1,14 +1,20 @@
|
||||
#ifndef ATOMIC_ADD_FLOAT
|
||||
#define ATOMIC_ADD_FLOAT
|
||||
|
||||
#if (defined(__x86_64__) || defined(__i386__))
|
||||
#if (defined(__x86_64__) || defined(__i386__) || defined(__aarch64__))
|
||||
#include <ATen/native/cpu/Intrinsics.h>
|
||||
#else
|
||||
#define _mm_pause()
|
||||
#endif
|
||||
|
||||
#include <atomic>
|
||||
|
||||
#ifdef __aarch64__
|
||||
static __inline void _mm_pause() {
|
||||
__asm__ __volatile__("yield;" : : : "memory");
|
||||
}
|
||||
#else
|
||||
#define _mm_pause()
|
||||
#endif
|
||||
|
||||
static inline void cpu_atomic_add_float(float* dst, float fvalue)
|
||||
{
|
||||
typedef union {
|
||||
|
@ -3790,8 +3790,8 @@ void quantize_tensor_per_channel_impl<c10::quint8>(
|
||||
// channels_last contig.
|
||||
// If axis = 0 and channels_last contig, implementation for channels
|
||||
// first (NCHW) works.
|
||||
for (const auto b : c10::irange(batches)) {
|
||||
for (const auto e : c10::irange(elements_per_channel)) {
|
||||
for (const auto b C10_UNUSED : c10::irange(batches)) {
|
||||
for (const auto e C10_UNUSED : c10::irange(elements_per_channel)) {
|
||||
uint32_t c = 0;
|
||||
while (c + 8 < channels) {
|
||||
const int16x8_t vzero_point = vld1q_s16(&zero_points_int16t[c]);
|
||||
@ -3821,7 +3821,7 @@ void quantize_tensor_per_channel_impl<c10::quint8>(
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (const auto b : c10::irange(batches)) {
|
||||
for (const auto b C10_UNUSED : c10::irange(batches)) {
|
||||
for (const auto c : c10::irange(channels)) {
|
||||
uint32_t e = 0;
|
||||
const int16x8_t vzero_point = vdupq_n_s16(zero_points_int16t[c]);
|
||||
|
Reference in New Issue
Block a user