mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)

Move some of vec into headeronly in preparation for Half.h (#158976)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158976
Approved by: https://github.com/albanD, https://github.com/desertfire

committed by PyTorch MergeBot
parent 6de24135e5
commit 222fa451a2
@@ -462,7 +462,7 @@ test_inductor_aoti() {
   # rebuild with the build cache with `BUILD_AOT_INDUCTOR_TEST` enabled
   /usr/bin/env CMAKE_FRESH=1 BUILD_AOT_INDUCTOR_TEST=1 "${BUILD_COMMAND[@]}"

-  /usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference -dist=loadfile
+  /usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference cpp/test_vec_half_AVX2 -dist=loadfile
 }

 test_inductor_cpp_wrapper_shard() {
@@ -1,55 +1 @@
(The 55 lines removed here are identical to the new torch/headeronly/cpu/vec/intrinsics.h shown later in this commit.)
+#include <torch/headeronly/cpu/vec/intrinsics.h>
@@ -1,396 +1 @@
(The 396 lines removed here, the gcc-7 vld1_*_x2/vst1_*_x2 workaround intrinsics, are identical to the new torch/headeronly/cpu/vec/vec256/missing_vld1_neon.h shown later in this commit.)
+#include <torch/headeronly/cpu/vec/vec256/missing_vld1_neon.h>
@@ -1,7 +1 @@
(The 7 lines removed here, the gcc-8 vst1q_f32_x2 workaround, are identical to the new torch/headeronly/cpu/vec/vec256/missing_vst1_neon.h shown later in this commit.)
+#include <torch/headeronly/cpu/vec/vec256/missing_vst1_neon.h>
@@ -3,50 +3,12 @@
 #include <ATen/cpu/vec/intrinsics.h>
 #include <c10/util/Exception.h>

+#include <torch/headeronly/cpu/vec/vec_half.h>
+
 namespace at::vec {
 // See Note [CPU_CAPABILITY namespace]
 inline namespace CPU_CAPABILITY {

(The block removed here, the AVX2/AVX512 float2half_scalar() and half2float_scalar() definitions, matches the new torch/headeronly/cpu/vec/vec_half.h shown later in this commit.)

 // Transpose a [2, 32] matrix to [32, 2]
 // Note: the output leading dimension should be 2,
 // that is, the output must be contiguous
c10/util/Half.h (193 changed lines)
@@ -13,6 +13,7 @@
 #include <c10/macros/Macros.h>
 #include <c10/util/bit_cast.h>
 #include <c10/util/floating_point_utils.h>
+#include <torch/headeronly/util/Half.h>
 #include <type_traits>

 #if defined(__cplusplus)
@@ -162,198 +163,6 @@ inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) {
       ~zero_mask);
 }

-/*
- * Convert a 16-bit floating-point number in IEEE half-precision format, in bit
- * representation, to a 32-bit floating-point number in IEEE single-precision
- * format.
- *
- * @note The implementation relies on IEEE-like (no assumption about rounding
- * mode and no operations on denormals) floating-point operations and bitcasts
- * between integer and floating-point variables.
- */
-C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) {
-#ifdef C10_X86_F16
-  return _cvtsh_ss(h);
-#else
-  /*
-   * Extend the half-precision floating-point number to 32 bits and shift to
-   * the upper part of the 32-bit word:
-   * +---+-----+------------+-------------------+
-   * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
-   * +---+-----+------------+-------------------+
-   * Bits  31  26-30  16-25             0-15
-   *
-   * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
-   * - zero bits.
-   */
-  const uint32_t w = (uint32_t)h << 16;
-  /*
-   * Extract the sign of the input number into the high bit of the 32-bit word:
-   *
-   * +---+----------------------------------+
-   * | S |0000000 00000000 00000000 00000000|
-   * +---+----------------------------------+
-   * Bits  31                0-31
-   */
-  const uint32_t sign = w & UINT32_C(0x80000000);
-  /*
-   * Extract mantissa and biased exponent of the input number into the high
-   * bits of the 32-bit word:
-   *
-   * +-----+------------+---------------------+
-   * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000|
-   * +-----+------------+---------------------+
-   * Bits  27-31  17-26              0-16
-   */
-  const uint32_t two_w = w + w;
-
-  /*
-   * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become
-   * mantissa and exponent of a single-precision floating-point number:
-   *
-   *       S|Exponent |          Mantissa
-   * +-+---+-----+------------+----------------+
-   * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000|
-   * +-+---+-----+------------+----------------+
-   * Bits   | 23-31 |           0-22
-   *
-   * Next, there are some adjustments to the exponent:
-   * - The exponent needs to be corrected by the difference in exponent bias
-   * between single-precision and half-precision formats (0x7F - 0xF = 0x70)
-   * - Inf and NaN values in the inputs should become Inf and NaN values after
-   * conversion to the single-precision number. Therefore, if the biased
-   * exponent of the half-precision input was 0x1F (max possible value), the
-   * biased exponent of the single-precision output must be 0xFF (max possible
-   * value). We do this correction in two steps:
-   * - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset
-   * below) rather than by 0x70 suggested by the difference in the exponent bias
-   * (see above).
-   * - Then we multiply the single-precision result of exponent adjustment by
-   * 2**(-112) to reverse the effect of exponent adjustment by 0xE0 less the
-   * necessary exponent adjustment by 0x70 due to difference in exponent bias.
-   * The floating-point multiplication hardware would ensure that Inf and
-   * NaN would retain their value on at least partially IEEE754-compliant
-   * implementations.
-   *
-   * Note that the above operations do not handle denormal inputs (where biased
-   * exponent == 0). However, they also do not operate on denormal inputs, and
-   * do not produce denormal results.
-   */
-  constexpr uint32_t exp_offset = UINT32_C(0xE0) << 23;
-  // const float exp_scale = 0x1.0p-112f;
-  constexpr uint32_t scale_bits = (uint32_t)15 << 23;
-  float exp_scale_val = 0;
-#if defined(_MSC_VER) && defined(__clang__)
-  __builtin_memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val));
-#else
-  std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val));
-#endif
-
-  const float exp_scale = exp_scale_val;
-  const float normalized_value =
-      fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
-
-  /*
-   * Convert denormalized half-precision inputs into single-precision results
-   * (always normalized). Zero inputs are also handled here.
-   *
-   * In a denormalized number the biased exponent is zero, and mantissa has
-   * non-zero bits. First, we shift mantissa into bits 0-9 of the 32-bit word.
-   *
-   *                  zeros           |  mantissa
-   * +---------------------------+------------+
-   * |0000 0000 0000 0000 0000 00|MM MMMM MMMM|
-   * +---------------------------+------------+
-   * Bits            10-31              0-9
-   *
-   * Now, remember that denormalized half-precision numbers are represented as:
-   * FP16 = mantissa * 2**(-24).
-   * The trick is to construct a normalized single-precision number with the
-   * same mantissa as the half-precision input and with an exponent which would
-   * scale the corresponding mantissa bits to 2**(-24). A normalized
-   * single-precision floating-point number is represented as: FP32 = (1 +
-   * mantissa * 2**(-23)) * 2**(exponent - 127). Therefore, when the biased
-   * exponent is 126, a unit change in the mantissa of the input denormalized
-   * half-precision number causes a change of the constructed single-precision
-   * number by 2**(-24), i.e. the same amount.
-   *
-   * The last step is to adjust the bias of the constructed single-precision
-   * number. When the input half-precision number is zero, the constructed
-   * single-precision number has the value of FP32 = 1 * 2**(126 - 127) =
-   * 2**(-1) = 0.5. Therefore, we need to subtract 0.5 from the constructed
-   * single-precision number to get the numerical equivalent of the input
-   * half-precision number.
-   */
-  constexpr uint32_t magic_mask = UINT32_C(126) << 23;
-  constexpr float magic_bias = 0.5f;
-  const float denormalized_value =
-      fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
-
-  /*
-   * - Choose either results of conversion of input as a normalized number, or
-   * as a denormalized number, depending on the input exponent. The variable
-   * two_w contains input exponent in bits 27-31, therefore if it is smaller
-   * than 2**27, the input is either a denormal number, or zero.
-   * - Combine the result of conversion of exponent and mantissa with the sign
-   * of the input number.
-   */
-  constexpr uint32_t denormalized_cutoff = UINT32_C(1) << 27;
-  const uint32_t result = sign |
-      (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value)
-                                   : fp32_to_bits(normalized_value));
-  return fp32_from_bits(result);
-#endif // C10_X86_F16
-}
-
-/*
- * Convert a 32-bit floating-point number in IEEE single-precision format to a
- * 16-bit floating-point number in IEEE half-precision format, in bit
- * representation.
- *
- * @note The implementation relies on IEEE-like (no assumption about rounding
- * mode and no operations on denormals) floating-point operations and bitcasts
- * between integer and floating-point variables.
- */
-inline uint16_t fp16_ieee_from_fp32_value(float f) {
-#ifdef C10_X86_F16
-  return _cvtss_sh(f, _MM_FROUND_TO_NEAREST_INT);
-#else
-  // const float scale_to_inf = 0x1.0p+112f;
-  // const float scale_to_zero = 0x1.0p-110f;
-  constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23;
-  constexpr uint32_t scale_to_zero_bits = (uint32_t)17 << 23;
-  float scale_to_inf_val = 0, scale_to_zero_val = 0;
-  std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val));
-  std::memcpy(
-      &scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val));
-  const float scale_to_inf = scale_to_inf_val;
-  const float scale_to_zero = scale_to_zero_val;
-
-#if defined(_MSC_VER) && _MSC_VER == 1916
-  float base = ((signbit(f) != 0 ? -f : f) * scale_to_inf) * scale_to_zero;
-#else
-  float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
-#endif
-
-  const uint32_t w = fp32_to_bits(f);
-  const uint32_t shl1_w = w + w;
-  const uint32_t sign = w & UINT32_C(0x80000000);
-  uint32_t bias = shl1_w & UINT32_C(0xFF000000);
-  if (bias < UINT32_C(0x71000000)) {
-    bias = UINT32_C(0x71000000);
-  }
-
-  base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
-  const uint32_t bits = fp32_to_bits(base);
-  const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
-  const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
-  const uint32_t nonsign = exp_bits + mantissa_bits;
-  return static_cast<uint16_t>(
-      (sign >> 16) |
-      (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign));
-#endif // C10_X86_F16
-}
-
-#ifdef C10_X86_F16
-#undef C10_X86_F16
-#endif // C10_X86_F16

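To make the exponent-rebiasing trick in the removed comments concrete, here is a small standalone sketch (not part of the commit; plain C++, with std::memcpy standing in for fp32_from_bits) that traces h = 0x3C00, i.e. 1.0 in fp16, through the normalized path:

// Trace of the normalized path above for h = 0x3C00 (1.0 in fp16).
#include <cstdint>
#include <cstdio>
#include <cstring>

static float f32_from_bits(uint32_t w) {  // stand-in for fp32_from_bits
  float f;
  std::memcpy(&f, &w, sizeof(f));
  return f;
}

int main() {
  const uint16_t h = 0x3C00;             // sign 0, biased exponent 0x0F, mantissa 0
  const uint32_t w = (uint32_t)h << 16;  // 0x3C000000
  const uint32_t two_w = w + w;          // 0x78000000: exponent/mantissa in high bits
  const uint32_t exp_offset = UINT32_C(0xE0) << 23;
  const float exp_scale = f32_from_bits(UINT32_C(15) << 23);  // 2**(-112)
  // (two_w >> 4) + exp_offset == 0x77800000, which decodes to 2**112;
  // multiplying by 2**(-112) undoes the 0xE0 overshoot, leaving 1.0f.
  std::printf("%f\n", f32_from_bits((two_w >> 4) + exp_offset) * exp_scale);
  return 0;
}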
@@ -1,46 +1 @@
-#pragma once
-
-#include <cstring>
-#include <type_traits>
-
-#include <c10/macros/Macros.h>
-
-#if __has_include(<bit>) && (defined(__cpp_lib_bit_cast) && __cpp_lib_bit_cast >= 201806L)
-#include <bit>
-#define C10_HAVE_STD_BIT_CAST 1
-#else
-#define C10_HAVE_STD_BIT_CAST 0
-#endif // __has_include(<bit>) && (defined(__cpp_lib_bit_cast) &&
-       // __cpp_lib_bit_cast >= 201806L)
-
-namespace c10 {
-
-#if C10_HAVE_STD_BIT_CAST
-using std::bit_cast;
-#else
-// Implementation of std::bit_cast() from C++20.
-//
-// This is a less sketchy version of reinterpret_cast.
-//
-// See https://en.cppreference.com/w/cpp/numeric/bit_cast for more
-// information as well as the source of our implementations.
-template <class To, class From>
-C10_HOST_DEVICE std::enable_if_t<
-    sizeof(To) == sizeof(From) && std::is_trivially_copyable_v<From> &&
-        std::is_trivially_copyable_v<To>,
-    To>
-// constexpr support needs compiler magic
-bit_cast(const From& src) noexcept {
-  static_assert(
-      std::is_trivially_constructible_v<To>,
-      "This implementation additionally requires "
-      "destination type to be trivially constructible");
-
-  To dst;
-  std::memcpy(&dst, &src, sizeof(To));
-  return dst;
-}
-#endif // C10_HAVE_STD_BIT_CAST
-#undef C10_HAVE_STD_BIT_CAST
-
-} // namespace c10
+#include <torch/headeronly/util/bit_cast.h>
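For reference, a minimal sketch (not part of the commit) of the memcpy fallback semantics shown above; any C++17 compiler suffices:

#include <cassert>
#include <cstdint>
#include <cstring>

// Same shape as the fallback above, without the c10 trait constraints.
template <class To, class From>
To bit_cast_sketch(const From& src) noexcept {
  static_assert(sizeof(To) == sizeof(From), "sizes must match");
  To dst;
  std::memcpy(&dst, &src, sizeof(To));
  return dst;
}

int main() {
  const float f = 1.0f;
  const uint32_t bits = bit_cast_sketch<uint32_t>(f);
  assert(bits == 0x3F800000u);                // IEEE-754 bits of 1.0f
  assert(bit_cast_sketch<float>(bits) == f);  // lossless round trip
  return 0;
}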
@@ -1,33 +1 @@
-#pragma once
-
-#include <c10/macros/Macros.h>
-#include <c10/util/bit_cast.h>
-#include <cstdint>
-
-namespace c10::detail {
-
-C10_HOST_DEVICE inline float fp32_from_bits(uint32_t w) {
-#if defined(__OPENCL_VERSION__)
-  return as_float(w);
-#elif defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
-  return __uint_as_float((unsigned int)w);
-#elif defined(__INTEL_COMPILER)
-  return _castu32_f32(w);
-#else
-  return c10::bit_cast<float>(w);
-#endif
-}
-
-C10_HOST_DEVICE inline uint32_t fp32_to_bits(float f) {
-#if defined(__OPENCL_VERSION__)
-  return as_uint(f);
-#elif defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
-  return (uint32_t)__float_as_uint(f);
-#elif defined(__INTEL_COMPILER)
-  return _castf32_u32(f);
-#else
-  return c10::bit_cast<uint32_t>(f);
-#endif
-}
-
-} // namespace c10::detail
+#include <torch/headeronly/util/floating_point_utils.h>
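As a quick illustration of what these helpers are used for in Half.h, a hedged standalone sketch (std::memcpy stands in for the bit_cast branch above): bits (126 << 23) decode to 2**(-1), which is exactly the magic_bias = 0.5f constant in the denormal path of fp16_ieee_to_fp32_value.

#include <cstdint>
#include <cstdio>
#include <cstring>

static float f32_from_bits(uint32_t w) {  // stand-in for c10::detail::fp32_from_bits
  float f;
  std::memcpy(&f, &w, sizeof(f));
  return f;
}

int main() {
  // Biased exponent 126, mantissa 0: value = 2**(126 - 127) = 0.5.
  std::printf("%f\n", f32_from_bits(UINT32_C(126) << 23));
  return 0;
}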
@@ -9,6 +9,13 @@ set(AOTI_ABI_CHECK_TEST_SRCS
   ${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp
   ${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp
   ${AOTI_ABI_CHECK_TEST_ROOT}/test_vec.cpp
+  ${AOTI_ABI_CHECK_TEST_ROOT}/test_vec_half.cpp
 )

+# The below are tests that require CPU_CAPABILITY setup
+# You may think test_vec.cpp needs to be in there, but it does not.
+set(AOTI_ABI_CHECK_VEC_TEST_SRCS
+  ${AOTI_ABI_CHECK_TEST_ROOT}/test_vec_half.cpp
+)
+
 add_executable(test_aoti_abi_check
@@ -23,6 +30,23 @@ target_compile_definitions(test_aoti_abi_check PRIVATE USE_GTEST)
 target_link_libraries(test_aoti_abi_check PRIVATE gtest_main)
 target_include_directories(test_aoti_abi_check PRIVATE ${ATen_CPU_INCLUDE})

+foreach(test_src ${AOTI_ABI_CHECK_VEC_TEST_SRCS})
+  foreach(i RANGE ${NUM_CPU_CAPABILITY_NAMES})
+    get_filename_component(test_name ${test_src} NAME_WE)
+    list(GET CPU_CAPABILITY_NAMES ${i} CPU_CAPABILITY)
+    list(GET CPU_CAPABILITY_FLAGS ${i} FLAGS)
+    separate_arguments(FLAGS UNIX_COMMAND "${FLAGS}")
+    add_executable(${test_name}_${CPU_CAPABILITY} "${test_src}")
+
+    target_link_libraries(${test_name}_${CPU_CAPABILITY} PRIVATE gtest_main)
+    target_include_directories(${test_name}_${CPU_CAPABILITY} PRIVATE ${ATen_CPU_INCLUDE})
+
+    # Define CPU_CAPABILITY and CPU_CAPABILITY_XXX macros for conditional compilation
+    target_compile_definitions(${test_name}_${CPU_CAPABILITY} PRIVATE CPU_CAPABILITY=${CPU_CAPABILITY} CPU_CAPABILITY_${CPU_CAPABILITY})
+    target_compile_options(${test_name}_${CPU_CAPABILITY} PRIVATE ${FLAGS})
+  endforeach()
+endforeach()
+
 if(INSTALL_TEST)
   install(TARGETS test_aoti_abi_check DESTINATION bin)
   # Install PDB files for MSVC builds
test/cpp/aoti_abi_check/test_vec_half.cpp (new file, 22 lines)
@@ -0,0 +1,22 @@
#include <gtest/gtest.h>
#include <torch/headeronly/cpu/vec/vec_half.h>
#include <torch/headeronly/util/Half.h>

TEST(TestVecHalf, TestConversion) {
  float f32s[100];
  for (int i = 0; i < 100; i++) {
    f32s[i] = static_cast<float>(i + 0.3);
  }
  for (int i = 0; i < 100; i++) {
#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
    !defined(__APPLE__)
    uint16_t u16 = torch::headeronly::vec::float2half_scalar(f32s[i]);
    float x = torch::headeronly::vec::half2float_scalar(u16);
    EXPECT_EQ(
        u16, torch::headeronly::detail::fp16_ieee_from_fp32_value(f32s[i]))
        << "Test failed for float to uint16 " << f32s[i] << "\n";
    EXPECT_EQ(x, torch::headeronly::detail::fp16_ieee_to_fp32_value(u16))
        << "Test failed for uint16 to float " << u16 << "\n";
#endif
  }
}
@@ -6,7 +6,7 @@
 # c10/util/TypeCast.h
 convert

-# c10/util/bit_cast.h
+# c10/util/bit_cast.h, torch/headeronly/util/bit_cast.h
 bit_cast

 # c10/util/BFloat16-math.h, c10/util/BFloat16.h
@@ -27,6 +27,14 @@ Float8_e5m2fnuz
 # c10/util/Half.h
 Half

+# torch/headeronly/util/Half.h
+fp16_ieee_from_fp32_value
+fp16_ieee_to_fp32_value
+
+# torch/headeronly/util/floating_point_utils.h
+# fp32_from_bits called from fp16_ieee_to_fp32_value
+# fp32_to_bits called from fp16_ieee_from_fp32_value
+
 # c10/util/complex.h
 complex

@@ -48,6 +56,10 @@ maximum
 minimum
 size

+# torch/headeronly/cpu/vec/vec_half.h
+float2half_scalar
+half2float_scalar
+
 # torch/headeronly/macros/Export.h
 C10_API

@@ -20,6 +20,7 @@ configure_file(

 file(GLOB HEADERONLY_HEADERS
   *.h
+  cpu/**/*.h
   macros/*.h
   util/*.h
 )
torch/headeronly/cpu/vec/intrinsics.h (new file, 55 lines)
@@ -0,0 +1,55 @@
#pragma once
#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
/* GCC or clang-compatible compiler, targeting x86/x86-64 */
#include <x86intrin.h>
#elif defined(__clang__) && (defined(__ARM_NEON__) || defined(__aarch64__))
/* Clang-compatible compiler, targeting arm neon */
#include <arm_neon.h>
#if defined(__ARM_FEATURE_SVE)
/* CLANG-compatible compiler, targeting ARM with SVE */
#include <arm_sve.h>
#endif
#elif defined(_MSC_VER)
/* Microsoft C/C++-compatible compiler */
#include <intrin.h>
#if _MSC_VER <= 1900
#define _mm256_extract_epi64(X, Y) \
  (_mm_extract_epi64(_mm256_extractf128_si256(X, Y >> 1), Y % 2))
#define _mm256_extract_epi32(X, Y) \
  (_mm_extract_epi32(_mm256_extractf128_si256(X, Y >> 2), Y % 4))
#define _mm256_extract_epi16(X, Y) \
  (_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8))
#define _mm256_extract_epi8(X, Y) \
  (_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16))
#endif
#elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__))
/* GCC-compatible compiler, targeting ARM with NEON */
#include <arm_neon.h>
#if defined(__ARM_FEATURE_SVE)
/* GCC-compatible compiler, targeting ARM with SVE */
#include <arm_sve.h>
#endif
#if defined(MISSING_ARM_VLD1)
#include <torch/headeronly/cpu/vec/vec256/missing_vld1_neon.h>
#elif defined(MISSING_ARM_VST1)
#include <torch/headeronly/cpu/vec/vec256/missing_vst1_neon.h>
#endif
#elif defined(__GNUC__) && defined(__IWMMXT__)
/* GCC-compatible compiler, targeting ARM with WMMX */
#include <mmintrin.h>
#elif defined(__s390x__)
// targets Z/architecture
// we will include vecintrin later
#elif (defined(__GNUC__) || defined(__xlC__)) && \
    (defined(__VEC__) || defined(__ALTIVEC__))
/* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */
#include <altivec.h>
/* We need to undef those tokens defined by <altivec.h> to avoid conflicts
   with the C++ types. => Can still use __bool/__vector */
#undef bool
#undef vector
#undef pixel
#elif defined(__GNUC__) && defined(__SPE__)
/* GCC-compatible compiler, targeting PowerPC with SPE */
#include <spe.h>
#endif
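As a sketch of what this dispatch header buys its includers (not part of the commit; assumes an x86-64 GCC/Clang build with AVX2 enabled, e.g. via -mavx2), feature-guarded code can use the intrinsics without caring which vendor header they came from:

#include <torch/headeronly/cpu/vec/intrinsics.h>

#include <cstdio>

int main() {
#if defined(__AVX2__)
  __m256 a = _mm256_set1_ps(2.0f);  // eight lanes of 2.0f
  float out[8];
  _mm256_storeu_ps(out, _mm256_add_ps(a, a));
  std::printf("%f\n", out[0]);  // 4.000000
#else
  std::printf("AVX2 not enabled on this target\n");
#endif
  return 0;
}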
torch/headeronly/cpu/vec/vec256/missing_vld1_neon.h (new file, 396 lines)
@@ -0,0 +1,396 @@
/* Workaround for missing vld1_*_x2 and vst1_*_x2 intrinsics in gcc-7. */

__extension__ extern __inline uint8x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_u8_x2(const uint8_t* __a) {
  uint8x8x2_t ret;
  asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline int8x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_s8_x2(const int8_t* __a) {
  int8x8x2_t ret;
  asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline uint16x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_u16_x2(const uint16_t* __a) {
  uint16x4x2_t ret;
  asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline int16x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_s16_x2(const int16_t* __a) {
  int16x4x2_t ret;
  asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline uint32x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_u32_x2(const uint32_t* __a) {
  uint32x2x2_t ret;
  asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline int32x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_s32_x2(const int32_t* __a) {
  int32x2x2_t ret;
  asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline uint64x1x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_u64_x2(const uint64_t* __a) {
  uint64x1x2_t ret;
  asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline int64x1x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_s64_x2(const int64_t* __a) {
  int64x1x2_t ret;
  asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline float16x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_f16_x2(const float16_t* __a) {
  float16x4x2_t ret;
  asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline float32x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_f32_x2(const float32_t* __a) {
  float32x2x2_t ret;
  asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline float64x1x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_f64_x2(const float64_t* __a) {
  float64x1x2_t ret;
  asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline poly8x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_p8_x2(const poly8_t* __a) {
  poly8x8x2_t ret;
  asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline poly16x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_p16_x2(const poly16_t* __a) {
  poly16x4x2_t ret;
  asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline poly64x1x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_p64_x2(const poly64_t* __a) {
  poly64x1x2_t ret;
  asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline uint8x16x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_u8_x2(const uint8_t* __a) {
  uint8x16x2_t ret;
  asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline int8x16x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_s8_x2(const int8_t* __a) {
  int8x16x2_t ret;
  asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline uint16x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_u16_x2(const uint16_t* __a) {
  uint16x8x2_t ret;
  asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline int16x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_s16_x2(const int16_t* __a) {
  int16x8x2_t ret;
  asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline uint32x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_u32_x2(const uint32_t* __a) {
  uint32x4x2_t ret;
  asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline int32x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_s32_x2(const int32_t* __a) {
  int32x4x2_t ret;
  asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline uint64x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_u64_x2(const uint64_t* __a) {
  uint64x2x2_t ret;
  asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline int64x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_s64_x2(const int64_t* __a) {
  int64x2x2_t ret;
  asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline float16x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_f16_x2(const float16_t* __a) {
  float16x8x2_t ret;
  asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline float32x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_f32_x2(const float32_t* __a) {
  float32x4x2_t ret;
  asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline float64x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_f64_x2(const float64_t* __a) {
  float64x2x2_t ret;
  asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline poly8x16x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_p8_x2(const poly8_t* __a) {
  poly8x16x2_t ret;
  asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline poly16x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_p16_x2(const poly16_t* __a) {
  poly16x8x2_t ret;
  asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

__extension__ extern __inline poly64x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_p64_x2(const poly64_t* __a) {
  poly64x2x2_t ret;
  asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a));
  return ret;
}

/* vst1x2 */

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_s64_x2(int64_t* __a, int64x1x2_t val) {
  asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_u64_x2(uint64_t* __a, uint64x1x2_t val) {
  asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_f64_x2(float64_t* __a, float64x1x2_t val) {
  asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_s8_x2(int8_t* __a, int8x8x2_t val) {
  asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_p8_x2(poly8_t* __a, poly8x8x2_t val) {
  asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_s16_x2(int16_t* __a, int16x4x2_t val) {
  asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_p16_x2(poly16_t* __a, poly16x4x2_t val) {
  asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_s32_x2(int32_t* __a, int32x2x2_t val) {
  asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_u8_x2(uint8_t* __a, uint8x8x2_t val) {
  asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_u16_x2(uint16_t* __a, uint16x4x2_t val) {
  asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_u32_x2(uint32_t* __a, uint32x2x2_t val) {
  asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_f16_x2(float16_t* __a, float16x4x2_t val) {
  asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_f32_x2(float32_t* __a, float32x2x2_t val) {
  asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_p64_x2(poly64_t* __a, poly64x1x2_t val) {
  asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_s8_x2(int8_t* __a, int8x16x2_t val) {
  asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_p8_x2(poly8_t* __a, poly8x16x2_t val) {
  asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_s16_x2(int16_t* __a, int16x8x2_t val) {
  asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_p16_x2(poly16_t* __a, poly16x8x2_t val) {
  asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_s32_x2(int32_t* __a, int32x4x2_t val) {
  asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_s64_x2(int64_t* __a, int64x2x2_t val) {
  asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_u8_x2(uint8_t* __a, uint8x16x2_t val) {
  asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_u16_x2(uint16_t* __a, uint16x8x2_t val) {
  asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_u32_x2(uint32_t* __a, uint32x4x2_t val) {
  asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_u64_x2(uint64_t* __a, uint64x2x2_t val) {
  asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_f16_x2(float16_t* __a, float16x8x2_t val) {
  asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_f32_x2(float32_t* __a, float32x4x2_t val) {
  asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_f64_x2(float64_t* __a, float64x2x2_t val) {
  asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val));
}

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_p64_x2(poly64_t* __a, poly64x2x2_t val) {
  asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val));
}
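A hedged usage sketch for these shims (not part of the commit; aarch64 with gcc-7/8 is the intended audience, since newer toolchains provide the same names in arm_neon.h directly):

#include <arm_neon.h>

#include <cstdio>

int main() {
  float src[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  float dst[8] = {0};
  float32x4x2_t v = vld1q_f32_x2(src);     // x2 load shimmed on gcc-7
  vst1q_f32_x2(dst, v);                    // x2 store shimmed on gcc-8
  std::printf("%f %f\n", dst[0], dst[7]);  // 0.000000 7.000000
  return 0;
}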
7
torch/headeronly/cpu/vec/vec256/missing_vst1_neon.h
Normal file
7
torch/headeronly/cpu/vec/vec256/missing_vst1_neon.h
Normal file
@ -0,0 +1,7 @@
/* Workaround for missing vst1q_f32_x2 in gcc-8. */

__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_f32_x2(float32_t* __a, float32x4x2_t val) {
  asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val));
}
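For reference (not part of the diff): a minimal sketch of how the vst1q_*_x2 workarounds above, both the gcc-7 set and this gcc-8 single-function variant, are used. It assumes an AArch64 toolchain with <arm_neon.h>; the _x2 store of two adjacent quad registers is semantically equivalent to two plain vst1q_f32 stores.

#include <arm_neon.h>

// Stores eight floats: `lo` to dst[0..3], `hi` to dst[4..7].
inline void store_two_quads(float* dst, float32x4_t lo, float32x4_t hi) {
  float32x4x2_t pair = {{lo, hi}};
  vst1q_f32_x2(dst, pair);  // single st1 of both registers
  // Equivalent fallback on toolchains without the _x2 intrinsic:
  //   vst1q_f32(dst, lo);
  //   vst1q_f32(dst + 4, hi);
}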
torch/headeronly/cpu/vec/vec_half.h (new file)
@ -0,0 +1,58 @@
#pragma once

#include <torch/headeronly/cpu/vec/intrinsics.h>

namespace torch::headeronly::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {

#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
    !defined(__APPLE__)
static inline uint16_t float2half_scalar(float val) {
#if defined(CPU_CAPABILITY_AVX2)
#if defined(_MSC_VER)
  __m256 v = _mm256_set1_ps(val);
  __m128i o =
      _mm256_cvtps_ph(v, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
  return static_cast<std::uint16_t>(_mm_cvtsi128_si32(o));
#else
  return _cvtss_sh(val, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
#endif
#elif defined(CPU_CAPABILITY_AVX512)
  __m512 v = _mm512_set1_ps(val);
  __m256i o =
      _mm512_cvtps_ph(v, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
  return static_cast<std::uint16_t>(
      _mm_cvtsi128_si32(_mm256_castsi256_si128(o)));
#endif
}

static inline float half2float_scalar(uint16_t val) {
#if defined(CPU_CAPABILITY_AVX2)
#if defined(_MSC_VER)
  __m128i v = _mm_cvtsi32_si128(val);
  __m256 o = _mm256_cvtph_ps(v);
  return _mm256_cvtss_f32(o);
#else
  return _cvtsh_ss(val);
#endif
#elif defined(CPU_CAPABILITY_AVX512)
  __m256i v =
      _mm256_setr_epi16(val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
  __m512 o = _mm512_cvtph_ps(v);
  return _mm512_cvtss_f32(o);
#endif
}

#endif

} // namespace CPU_CAPABILITY
} // namespace torch::headeronly::vec

namespace at::vec {
#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
    !defined(__APPLE__)
using torch::headeronly::vec::float2half_scalar;
using torch::headeronly::vec::half2float_scalar;
#endif
} // namespace at::vec
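A hedged usage sketch (not part of the diff) for the helpers above, assuming a build where CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512 is defined and the CPU supports F16C, i.e. the same conditions the header's #if guard checks:

#include <cstdint>
#include <cstdio>
#include <torch/headeronly/cpu/vec/vec_half.h>

int main() {
  const float x = 0.333333f;
  const uint16_t h = torch::headeronly::vec::float2half_scalar(x);
  const float y = torch::headeronly::vec::half2float_scalar(h);
  // fp16 carries ~11 significand bits, so expect |x - y| on the order of 1e-4.
  std::printf("x=%.6f half=0x%04x back=%.6f\n", x, static_cast<unsigned>(h), y);
  return 0;
}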
@ -29,6 +29,7 @@ def define_torch_headeronly_ovrsource(name, is_mobile):
        public_include_directories = ["../.."],
        public_preprocessor_flags = pp_flags,
        public_raw_headers = native.glob([
            "cpu/**/*.h",
            "macros/*.h",
            "util/*.h",
        ]),
torch/headeronly/util/Half.h (new file)
@ -0,0 +1,249 @@
#pragma once

#include <torch/headeronly/macros/Macros.h>
#include <torch/headeronly/util/bit_cast.h>
#include <torch/headeronly/util/floating_point_utils.h>

#if defined(__cplusplus)
#include <cmath>
#elif !defined(__OPENCL_VERSION__)
#include <math.h>
#endif

#ifdef _MSC_VER
#include <intrin.h>
#endif

#include <cstdint>
#include <cstring>

#ifdef __CUDACC__
#include <cuda_fp16.h>
#endif

#ifdef __HIPCC__
#include <hip/hip_fp16.h>
#endif

#if defined(CL_SYCL_LANGUAGE_VERSION)
#include <CL/sycl.hpp> // for SYCL 1.2.1
#elif defined(SYCL_LANGUAGE_VERSION)
#include <sycl/sycl.hpp> // for SYCL 2020
#endif

#if defined(__aarch64__) && !defined(__CUDACC__)
#include <arm_neon.h>
#endif

#if defined(__GNUC__) || defined(__clang__)
#if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || \
    defined(_M_IX86)
#if defined(__F16C__) && \
    !(defined(__CUDA_ARCH__) || defined(__CUDACC__) || \
      defined(__HIP_DEVICE_COMPILE__))
#define C10_X86_F16 1
#include <immintrin.h> // import conversion ops from f16cintrin.h
#endif // defined(__F16C__) && !(defined(__CUDA_ARCH__) || defined(__CUDACC__)
       // || defined(__HIP_DEVICE_COMPILE__))
#endif // __x86_64__ || _M_X64 || __i386 || _M_IX86
#endif // __GNUC__ || __clang__

namespace torch::headeronly::detail {

/*
 * Convert a 16-bit floating-point number in IEEE half-precision format, in bit
 * representation, to a 32-bit floating-point number in IEEE single-precision
 * format.
 *
 * @note The implementation relies on IEEE-like (no assumption about rounding
 * mode and no operations on denormals) floating-point operations and bitcasts
 * between integer and floating-point variables.
 */
C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) {
#ifdef C10_X86_F16
  return _cvtsh_ss(h);
#else
  /*
   * Extend the half-precision floating-point number to 32 bits and shift to
   * the upper part of the 32-bit word:
   *      +---+-----+------------+-------------------+
   *      | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
   *      +---+-----+------------+-------------------+
   * Bits  31  26-30    16-25            0-15
   *
   * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa,
   * 0 - zero bits.
   */
  const uint32_t w = (uint32_t)h << 16;
  /*
   * Extract the sign of the input number into the high bit of the 32-bit
   * word:
   *      +---+----------------------------------+
   *      | S |0000000 00000000 00000000 00000000|
   *      +---+----------------------------------+
   * Bits  31                 0-31
   */
  const uint32_t sign = w & UINT32_C(0x80000000);
  /*
   * Extract mantissa and biased exponent of the input number into the high
   * bits of the 32-bit word:
   *      +-----+------------+---------------------+
   *      |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000|
   *      +-----+------------+---------------------+
   * Bits  27-31    17-26            0-16
   */
  const uint32_t two_w = w + w;

  /*
   * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become
   * mantissa and exponent of a single-precision floating-point number:
   *
   *       S|Exponent |          Mantissa
   *      +-+---+-----+------------+----------------+
   *      |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000|
   *      +-+---+-----+------------+----------------+
   * Bits   |  23-31  |           0-22
   *
   * Next, there are some adjustments to the exponent:
   * - The exponent needs to be corrected by the difference in exponent bias
   *   between single-precision and half-precision formats (0x7F - 0xF = 0x70)
   * - Inf and NaN values in the inputs should become Inf and NaN values after
   *   conversion to the single-precision number. Therefore, if the biased
   *   exponent of the half-precision input was 0x1F (max possible value), the
   *   biased exponent of the single-precision output must be 0xFF (max
   *   possible value). We do this correction in two steps:
   *   - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset
   *     below) rather than by 0x70 suggested by the difference in the
   *     exponent bias (see above).
   *   - Then we multiply the single-precision result of exponent adjustment
   *     by 2**(-112) to reverse the effect of exponent adjustment by 0xE0
   *     less the necessary exponent adjustment by 0x70 due to the difference
   *     in exponent bias. The floating-point multiplication hardware would
   *     ensure that Inf and NaN retain their value on at least partially
   *     IEEE754-compliant implementations.
   *
   * Note that the above operations do not handle denormal inputs (where the
   * biased exponent == 0). However, they also do not operate on denormal
   * inputs, and do not produce denormal results.
   */
  constexpr uint32_t exp_offset = UINT32_C(0xE0) << 23;
  // const float exp_scale = 0x1.0p-112f;
  constexpr uint32_t scale_bits = (uint32_t)15 << 23;
  float exp_scale_val = 0;
#if defined(_MSC_VER) && defined(__clang__)
  __builtin_memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val));
#else
  std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val));
#endif
  const float exp_scale = exp_scale_val;
  const float normalized_value =
      fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;

  /*
   * Convert denormalized half-precision inputs into single-precision results
   * (always normalized). Zero inputs are also handled here.
   *
   * In a denormalized number the biased exponent is zero, and the mantissa
   * has non-zero bits. First, we shift the mantissa into bits 0-9 of the
   * 32-bit word.
   *
   *                  zeros           |  mantissa
   *      +---------------------------+------------+
   *      |0000 0000 0000 0000 0000 00|MM MMMM MMMM|
   *      +---------------------------+------------+
   * Bits             10-31                0-9
   *
   * Now, remember that denormalized half-precision numbers are represented as
   *    FP16 = mantissa * 2**(-24).
   * The trick is to construct a normalized single-precision number with the
   * same mantissa as the half-precision input and with an exponent which
   * would scale the corresponding mantissa bits to 2**(-24). A normalized
   * single-precision floating-point number is represented as
   *    FP32 = (1 + mantissa * 2**(-23)) * 2**(exponent - 127).
   * Therefore, when the biased exponent is 126, a unit change in the mantissa
   * of the input denormalized half-precision number causes a change of the
   * constructed single-precision number by 2**(-24), i.e. the same amount.
   *
   * The last step is to adjust the bias of the constructed single-precision
   * number. When the input half-precision number is zero, the constructed
   * single-precision number has the value of
   *    FP32 = 1 * 2**(126 - 127) = 2**(-1) = 0.5.
   * Therefore, we need to subtract 0.5 from the constructed single-precision
   * number to get the numerical equivalent of the input half-precision
   * number.
   */
  constexpr uint32_t magic_mask = UINT32_C(126) << 23;
  constexpr float magic_bias = 0.5f;
  const float denormalized_value =
      fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;

  /*
   * - Choose either the result of conversion of the input as a normalized
   *   number, or as a denormalized number, depending on the input exponent.
   *   The variable two_w contains the input exponent in bits 27-31, therefore
   *   if it is smaller than 2**27, the input is either a denormal number, or
   *   zero.
   * - Combine the result of conversion of exponent and mantissa with the sign
   *   of the input number.
   */
  constexpr uint32_t denormalized_cutoff = UINT32_C(1) << 27;
  const uint32_t result = sign |
      (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value)
                                   : fp32_to_bits(normalized_value));
  return fp32_from_bits(result);
#endif // C10_X86_F16
}
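A worked sanity check (a sketch, hand-checked against the code above; the include path assumes the headeronly layout introduced by this commit): h = 0x3C00 (fp16 1.0) takes the normalized branch, since two_w = 0x78000000 >= 2**27; (two_w >> 4) + exp_offset = 0x07800000 + 0x70000000 = 0x77800000, which is 2**112 as a float, and multiplying by exp_scale = 2**(-112) yields exactly 1.0f. h = 0x0001, the smallest fp16 denormal, takes the denormal branch and comes out as 2**(-24).

#include <cassert>
#include <cstdint>
#include <torch/headeronly/util/Half.h>

int main() {
  using torch::headeronly::detail::fp16_ieee_to_fp32_value;
  assert(fp16_ieee_to_fp32_value(0x3C00) == 1.0f);       // normalized path
  assert(fp16_ieee_to_fp32_value(0x0001) == 0x1.0p-24f); // denormal path
  return 0;
}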
/*
 * Convert a 32-bit floating-point number in IEEE single-precision format to a
 * 16-bit floating-point number in IEEE half-precision format, in bit
 * representation.
 *
 * @note The implementation relies on IEEE-like (no assumption about rounding
 * mode and no operations on denormals) floating-point operations and bitcasts
 * between integer and floating-point variables.
 */
inline uint16_t fp16_ieee_from_fp32_value(float f) {
#ifdef C10_X86_F16
  return _cvtss_sh(f, _MM_FROUND_TO_NEAREST_INT);
#else
  // const float scale_to_inf = 0x1.0p+112f;
  // const float scale_to_zero = 0x1.0p-110f;
  constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23;
  constexpr uint32_t scale_to_zero_bits = (uint32_t)17 << 23;
  float scale_to_inf_val = 0, scale_to_zero_val = 0;
  std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val));
  std::memcpy(
      &scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val));
  const float scale_to_inf = scale_to_inf_val;
  const float scale_to_zero = scale_to_zero_val;

#if defined(_MSC_VER) && _MSC_VER == 1916
  float base = ((signbit(f) != 0 ? -f : f) * scale_to_inf) * scale_to_zero;
#else
  float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
#endif

  const uint32_t w = fp32_to_bits(f);
  const uint32_t shl1_w = w + w;
  const uint32_t sign = w & UINT32_C(0x80000000);
  uint32_t bias = shl1_w & UINT32_C(0xFF000000);
  if (bias < UINT32_C(0x71000000)) {
    bias = UINT32_C(0x71000000);
  }

  base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
  const uint32_t bits = fp32_to_bits(base);
  const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
  const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
  const uint32_t nonsign = exp_bits + mantissa_bits;
  return static_cast<uint16_t>(
      (sign >> 16) |
      (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign));
#endif // C10_X86_F16
}

} // namespace torch::headeronly::detail

namespace c10::detail {
using torch::headeronly::detail::fp16_ieee_from_fp32_value;
using torch::headeronly::detail::fp16_ieee_to_fp32_value;
} // namespace c10::detail
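One property worth noting (sketch only, same include-path assumption as above): because half-to-float conversion is exact and float-to-half rounds to nearest, every non-NaN fp16 bit pattern should survive a round trip through the two functions unchanged. A brute-force check over all 65536 patterns:

#include <cstdint>
#include <torch/headeronly/util/Half.h>

int main() {
  using namespace torch::headeronly::detail;
  for (uint32_t b = 0; b <= 0xFFFF; ++b) {
    const uint16_t h = static_cast<uint16_t>(b);
    const float f = fp16_ieee_to_fp32_value(h);
    if (f != f) continue;  // NaN: payload bits are not preserved
    if (fp16_ieee_from_fp32_value(f) != h) return 1;
  }
  return 0;  // all non-NaN patterns round-tripped exactly
}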
torch/headeronly/util/bit_cast.h (new file)
@ -0,0 +1,50 @@
#pragma once

#include <cstring>
#include <type_traits>

#include <torch/headeronly/macros/Macros.h>

#if __has_include(<bit>) && (defined(__cpp_lib_bit_cast) && __cpp_lib_bit_cast >= 201806L)
#include <bit>
#define C10_HAVE_STD_BIT_CAST 1
#else
#define C10_HAVE_STD_BIT_CAST 0
#endif // __has_include(<bit>) && (defined(__cpp_lib_bit_cast) &&
       // __cpp_lib_bit_cast >= 201806L)

namespace torch::headeronly {

#if C10_HAVE_STD_BIT_CAST
using std::bit_cast;
#else
// Implementation of std::bit_cast() from C++20.
//
// This is a less sketchy version of reinterpret_cast.
//
// See https://en.cppreference.com/w/cpp/numeric/bit_cast for more
// information as well as the source of our implementation.
template <class To, class From>
C10_HOST_DEVICE std::enable_if_t<
    sizeof(To) == sizeof(From) && std::is_trivially_copyable_v<From> &&
        std::is_trivially_copyable_v<To>,
    To>
// constexpr support needs compiler magic
bit_cast(const From& src) noexcept {
  static_assert(
      std::is_trivially_constructible_v<To>,
      "This implementation additionally requires "
      "destination type to be trivially constructible");

  To dst;
  std::memcpy(&dst, &src, sizeof(To));
  return dst;
}
#endif // C10_HAVE_STD_BIT_CAST
#undef C10_HAVE_STD_BIT_CAST

} // namespace torch::headeronly

namespace c10 {
using torch::headeronly::bit_cast;
} // namespace c10
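Usage sketch (not part of the diff): bit_cast gives a well-defined way to inspect object representations, e.g. the IEEE-754 bits of a float, without the undefined behavior of a pointer reinterpret_cast:

#include <cstdint>
#include <torch/headeronly/util/bit_cast.h>

int main() {
  const auto bits = torch::headeronly::bit_cast<std::uint32_t>(1.0f);
  return bits == 0x3F800000u ? 0 : 1;  // 1.0f is 0x3F800000 in binary32
}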
torch/headeronly/util/floating_point_utils.h (new file)
@ -0,0 +1,38 @@
#pragma once

#include <torch/headeronly/macros/Macros.h>
#include <torch/headeronly/util/bit_cast.h>
#include <cstdint>

namespace torch::headeronly::detail {

C10_HOST_DEVICE inline float fp32_from_bits(uint32_t w) {
#if defined(__OPENCL_VERSION__)
  return as_float(w);
#elif defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
  return __uint_as_float((unsigned int)w);
#elif defined(__INTEL_COMPILER)
  return _castu32_f32(w);
#else
  return torch::headeronly::bit_cast<float>(w);
#endif
}

C10_HOST_DEVICE inline uint32_t fp32_to_bits(float f) {
#if defined(__OPENCL_VERSION__)
  return as_uint(f);
#elif defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
  return (uint32_t)__float_as_uint(f);
#elif defined(__INTEL_COMPILER)
  return _castf32_u32(f);
#else
  return torch::headeronly::bit_cast<uint32_t>(f);
#endif
}

} // namespace torch::headeronly::detail

namespace c10::detail {
using torch::headeronly::detail::fp32_from_bits;
using torch::headeronly::detail::fp32_to_bits;
} // namespace c10::detail
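Sketch only (same include-path assumption): since both helpers merely move bits between integer and float representations, fp32_to_bits is the exact inverse of fp32_from_bits for any 32-bit pattern:

#include <cstdint>
#include <torch/headeronly/util/floating_point_utils.h>

int main() {
  using namespace torch::headeronly::detail;
  const std::uint32_t w = 0x40490FDBu;  // bit pattern of 3.14159274f
  return fp32_to_bits(fp32_from_bits(w)) == w ? 0 : 1;
}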