Move some of vec into headeronly in preparation for Half.h (#158976)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158976
Approved by: https://github.com/albanD, https://github.com/desertfire
This commit is contained in:
Jane Xu
2025-07-28 11:43:16 -07:00
committed by PyTorch MergeBot
parent 6de24135e5
commit 222fa451a2
20 changed files with 923 additions and 771 deletions

View File

@ -462,7 +462,7 @@ test_inductor_aoti() {
# rebuild with the build cache with `BUILD_AOT_INDUCTOR_TEST` enabled
/usr/bin/env CMAKE_FRESH=1 BUILD_AOT_INDUCTOR_TEST=1 "${BUILD_COMMAND[@]}"
/usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference -dist=loadfile
/usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference cpp/test_vec_half_AVX2 -dist=loadfile
}
test_inductor_cpp_wrapper_shard() {

View File

@ -1,55 +1 @@
#pragma once
#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
/* GCC or clang-compatible compiler, targeting x86/x86-64 */
#include <x86intrin.h>
#elif defined(__clang__) && (defined(__ARM_NEON__) || defined(__aarch64__))
/* Clang-compatible compiler, targeting arm neon */
#include <arm_neon.h>
#if defined(__ARM_FEATURE_SVE)
/* CLANG-compatible compiler, targeting ARM with SVE */
#include <arm_sve.h>
#endif
#elif defined(_MSC_VER)
/* Microsoft C/C++-compatible compiler */
#include <intrin.h>
#if _MSC_VER <= 1900
#define _mm256_extract_epi64(X, Y) \
(_mm_extract_epi64(_mm256_extractf128_si256(X, Y >> 1), Y % 2))
#define _mm256_extract_epi32(X, Y) \
(_mm_extract_epi32(_mm256_extractf128_si256(X, Y >> 2), Y % 4))
#define _mm256_extract_epi16(X, Y) \
(_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8))
#define _mm256_extract_epi8(X, Y) \
(_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16))
#endif
#elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__))
/* GCC-compatible compiler, targeting ARM with NEON */
#include <arm_neon.h>
#if defined(__ARM_FEATURE_SVE)
/* GCC-compatible compiler, targeting ARM with SVE */
#include <arm_sve.h>
#endif
#if defined(MISSING_ARM_VLD1)
#include <ATen/cpu/vec/vec256/missing_vld1_neon.h>
#elif defined(MISSING_ARM_VST1)
#include <ATen/cpu/vec/vec256/missing_vst1_neon.h>
#endif
#elif defined(__GNUC__) && defined(__IWMMXT__)
/* GCC-compatible compiler, targeting ARM with WMMX */
#include <mmintrin.h>
#elif defined(__s390x__)
// targets Z/architecture
// we will include vecintrin later
#elif (defined(__GNUC__) || defined(__xlC__)) && \
(defined(__VEC__) || defined(__ALTIVEC__))
/* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */
#include <altivec.h>
/* We need to undef those tokens defined by <altivec.h> to avoid conflicts
with the C++ types. => Can still use __bool/__vector */
#undef bool
#undef vector
#undef pixel
#elif defined(__GNUC__) && defined(__SPE__)
/* GCC-compatible compiler, targeting PowerPC with SPE */
#include <spe.h>
#endif
#include <torch/headeronly/cpu/vec/intrinsics.h>

View File

@ -1,396 +1 @@
/* Workaround for missing vld1_*_x2 and vst1_*_x2 intrinsics in gcc-7. */
__extension__ extern __inline uint8x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_u8_x2(const uint8_t* __a) {
uint8x8x2_t ret;
asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int8x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_s8_x2(const int8_t* __a) {
int8x8x2_t ret;
asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint16x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_u16_x2(const uint16_t* __a) {
uint16x4x2_t ret;
asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int16x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_s16_x2(const int16_t* __a) {
int16x4x2_t ret;
asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint32x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_u32_x2(const uint32_t* __a) {
uint32x2x2_t ret;
asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int32x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_s32_x2(const int32_t* __a) {
int32x2x2_t ret;
asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint64x1x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_u64_x2(const uint64_t* __a) {
uint64x1x2_t ret;
asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int64x1x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_s64_x2(const int64_t* __a) {
int64x1x2_t ret;
__builtin_aarch64_simd_oi __o;
asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline float16x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_f16_x2(const float16_t* __a) {
float16x4x2_t ret;
asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline float32x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_f32_x2(const float32_t* __a) {
float32x2x2_t ret;
asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline float64x1x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_f64_x2(const float64_t* __a) {
float64x1x2_t ret;
asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline poly8x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_p8_x2(const poly8_t* __a) {
poly8x8x2_t ret;
asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline poly16x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_p16_x2(const poly16_t* __a) {
poly16x4x2_t ret;
asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline poly64x1x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_p64_x2(const poly64_t* __a) {
poly64x1x2_t ret;
asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint8x16x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_u8_x2(const uint8_t* __a) {
uint8x16x2_t ret;
asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int8x16x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_s8_x2(const int8_t* __a) {
int8x16x2_t ret;
asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint16x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_u16_x2(const uint16_t* __a) {
uint16x8x2_t ret;
asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int16x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_s16_x2(const int16_t* __a) {
int16x8x2_t ret;
asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint32x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_u32_x2(const uint32_t* __a) {
uint32x4x2_t ret;
asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int32x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_s32_x2(const int32_t* __a) {
int32x4x2_t ret;
asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint64x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_u64_x2(const uint64_t* __a) {
uint64x2x2_t ret;
asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int64x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_s64_x2(const int64_t* __a) {
int64x2x2_t ret;
asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline float16x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_f16_x2(const float16_t* __a) {
float16x8x2_t ret;
asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline float32x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_f32_x2(const float32_t* __a) {
float32x4x2_t ret;
asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline float64x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_f64_x2(const float64_t* __a) {
float64x2x2_t ret;
asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline poly8x16x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_p8_x2(const poly8_t* __a) {
poly8x16x2_t ret;
asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline poly16x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_p16_x2(const poly16_t* __a) {
poly16x8x2_t ret;
asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline poly64x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_p64_x2(const poly64_t* __a) {
poly64x2x2_t ret;
asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
/* vst1x2 */
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_s64_x2(int64_t* __a, int64x1x2_t val) {
asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_u64_x2(uint64_t* __a, uint64x1x2_t val) {
asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_f64_x2(float64_t* __a, float64x1x2_t val) {
asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_s8_x2(int8_t* __a, int8x8x2_t val) {
asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_p8_x2(poly8_t* __a, poly8x8x2_t val) {
asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_s16_x2(int16_t* __a, int16x4x2_t val) {
asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_p16_x2(poly16_t* __a, poly16x4x2_t val) {
asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_s32_x2(int32_t* __a, int32x2x2_t val) {
asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_u8_x2(uint8_t* __a, uint8x8x2_t val) {
asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_u16_x2(uint16_t* __a, uint16x4x2_t val) {
asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_u32_x2(uint32_t* __a, uint32x2x2_t val) {
asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_f16_x2(float16_t* __a, float16x4x2_t val) {
asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_f32_x2(float32_t* __a, float32x2x2_t val) {
asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_p64_x2(poly64_t* __a, poly64x1x2_t val) {
asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_s8_x2(int8_t* __a, int8x16x2_t val) {
asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_p8_x2(poly8_t* __a, poly8x16x2_t val) {
asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_s16_x2(int16_t* __a, int16x8x2_t val) {
asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_p16_x2(poly16_t* __a, poly16x8x2_t val) {
asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_s32_x2(int32_t* __a, int32x4x2_t val) {
asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_s64_x2(int64_t* __a, int64x2x2_t val) {
asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_u8_x2(uint8_t* __a, uint8x16x2_t val) {
asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_u16_x2(uint16_t* __a, uint16x8x2_t val) {
asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_u32_x2(uint32_t* __a, uint32x4x2_t val) {
asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_u64_x2(uint64_t* __a, uint64x2x2_t val) {
asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_f16_x2(float16_t* __a, float16x8x2_t val) {
asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_f32_x2(float32_t* __a, float32x4x2_t val) {
asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_f64_x2(float64_t* __a, float64x2x2_t val) {
asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_p64_x2(poly64_t* __a, poly64x2x2_t val) {
asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val));
}
#include <torch/headeronly/cpu/vec/vec256/missing_vld1_neon.h>

View File

@ -1,7 +1 @@
/* Workaround for missing vst1q_f32_x2 in gcc-8. */
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_f32_x2(float32_t* __a, float32x4x2_t val) {
asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val));
}
#include <torch/headeronly/cpu/vec/vec256/missing_vst1_neon.h>

View File

@ -3,50 +3,12 @@
#include <ATen/cpu/vec/intrinsics.h>
#include <c10/util/Exception.h>
#include <torch/headeronly/cpu/vec/vec_half.h>
namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
!defined(__APPLE__)
static inline uint16_t float2half_scalar(float val) {
#if defined(CPU_CAPABILITY_AVX2)
#if defined(_MSC_VER)
__m256 v = _mm256_set1_ps(val);
__m128i o =
_mm256_cvtps_ph(v, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
return static_cast<std::uint16_t>(_mm_cvtsi128_si32(o));
#else
return _cvtss_sh(val, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
#endif
#elif defined(CPU_CAPABILITY_AVX512)
__m512 v = _mm512_set1_ps(val);
__m256i o =
_mm512_cvtps_ph(v, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
return static_cast<std::uint16_t>(
_mm_cvtsi128_si32(_mm256_castsi256_si128(o)));
#endif
}
static inline float half2float_scalar(uint16_t val) {
#if defined(CPU_CAPABILITY_AVX2)
#if defined(_MSC_VER)
__m128i v = _mm_cvtsi32_si128(val);
__m256 o = _mm256_cvtph_ps(v);
return _mm256_cvtss_f32(o);
#else
return _cvtsh_ss(val);
#endif
#elif defined(CPU_CAPABILITY_AVX512)
__m256i v =
_mm256_setr_epi16(val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
__m512 o = _mm512_cvtph_ps(v);
return _mm512_cvtss_f32(o);
#endif
}
#endif
// Transpose a [2, 32] matrix to [32, 2]
// Note: the output leading dimension should be 2,
// that is, the output must be contiguous

View File

@ -13,6 +13,7 @@
#include <c10/macros/Macros.h>
#include <c10/util/bit_cast.h>
#include <c10/util/floating_point_utils.h>
#include <torch/headeronly/util/Half.h>
#include <type_traits>
#if defined(__cplusplus)
@ -162,198 +163,6 @@ inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) {
~zero_mask);
}
/*
* Convert a 16-bit floating-point number in IEEE half-precision format, in bit
* representation, to a 32-bit floating-point number in IEEE single-precision
* format.
*
* @note The implementation relies on IEEE-like (no assumption about rounding
* mode and no operations on denormals) floating-point operations and bitcasts
* between integer and floating-point variables.
*/
C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) {
#ifdef C10_X86_F16
return _cvtsh_ss(h);
#else
/*
* Extend the half-precision floating-point number to 32 bits and shift to the
* upper part of the 32-bit word:
* +---+-----+------------+-------------------+
* | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
* +---+-----+------------+-------------------+
* Bits 31 26-30 16-25 0-15
*
* S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
* - zero bits.
*/
const uint32_t w = (uint32_t)h << 16;
/*
* Extract the sign of the input number into the high bit of the 32-bit word:
*
* +---+----------------------------------+
* | S |0000000 00000000 00000000 00000000|
* +---+----------------------------------+
* Bits 31 0-31
*/
const uint32_t sign = w & UINT32_C(0x80000000);
/*
* Extract mantissa and biased exponent of the input number into the high bits
* of the 32-bit word:
*
* +-----+------------+---------------------+
* |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000|
* +-----+------------+---------------------+
* Bits 27-31 17-26 0-16
*/
const uint32_t two_w = w + w;
/*
* Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become
* mantissa and exponent of a single-precision floating-point number:
*
* S|Exponent | Mantissa
* +-+---+-----+------------+----------------+
* |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000|
* +-+---+-----+------------+----------------+
* Bits | 23-31 | 0-22
*
* Next, there are some adjustments to the exponent:
* - The exponent needs to be corrected by the difference in exponent bias
* between single-precision and half-precision formats (0x7F - 0xF = 0x70)
* - Inf and NaN values in the inputs should become Inf and NaN values after
* conversion to the single-precision number. Therefore, if the biased
* exponent of the half-precision input was 0x1F (max possible value), the
* biased exponent of the single-precision output must be 0xFF (max possible
* value). We do this correction in two steps:
* - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset
* below) rather than by 0x70 suggested by the difference in the exponent bias
* (see above).
* - Then we multiply the single-precision result of exponent adjustment by
* 2**(-112) to reverse the effect of exponent adjustment by 0xE0 less the
* necessary exponent adjustment by 0x70 due to difference in exponent bias.
* The floating-point multiplication hardware would ensure than Inf and
* NaN would retain their value on at least partially IEEE754-compliant
* implementations.
*
* Note that the above operations do not handle denormal inputs (where biased
* exponent == 0). However, they also do not operate on denormal inputs, and
* do not produce denormal results.
*/
constexpr uint32_t exp_offset = UINT32_C(0xE0) << 23;
// const float exp_scale = 0x1.0p-112f;
constexpr uint32_t scale_bits = (uint32_t)15 << 23;
float exp_scale_val = 0;
#if defined(_MSC_VER) && defined(__clang__)
__builtin_memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val));
#else
std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val));
#endif
const float exp_scale = exp_scale_val;
const float normalized_value =
fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
/*
* Convert denormalized half-precision inputs into single-precision results
* (always normalized). Zero inputs are also handled here.
*
* In a denormalized number the biased exponent is zero, and mantissa has
* on-zero bits. First, we shift mantissa into bits 0-9 of the 32-bit word.
*
* zeros | mantissa
* +---------------------------+------------+
* |0000 0000 0000 0000 0000 00|MM MMMM MMMM|
* +---------------------------+------------+
* Bits 10-31 0-9
*
* Now, remember that denormalized half-precision numbers are represented as:
* FP16 = mantissa * 2**(-24).
* The trick is to construct a normalized single-precision number with the
* same mantissa and thehalf-precision input and with an exponent which would
* scale the corresponding mantissa bits to 2**(-24). A normalized
* single-precision floating-point number is represented as: FP32 = (1 +
* mantissa * 2**(-23)) * 2**(exponent - 127) Therefore, when the biased
* exponent is 126, a unit change in the mantissa of the input denormalized
* half-precision number causes a change of the constructed single-precision
* number by 2**(-24), i.e. the same amount.
*
* The last step is to adjust the bias of the constructed single-precision
* number. When the input half-precision number is zero, the constructed
* single-precision number has the value of FP32 = 1 * 2**(126 - 127) =
* 2**(-1) = 0.5 Therefore, we need to subtract 0.5 from the constructed
* single-precision number to get the numerical equivalent of the input
* half-precision number.
*/
constexpr uint32_t magic_mask = UINT32_C(126) << 23;
constexpr float magic_bias = 0.5f;
const float denormalized_value =
fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
/*
* - Choose either results of conversion of input as a normalized number, or
* as a denormalized number, depending on the input exponent. The variable
* two_w contains input exponent in bits 27-31, therefore if its smaller than
* 2**27, the input is either a denormal number, or zero.
* - Combine the result of conversion of exponent and mantissa with the sign
* of the input number.
*/
constexpr uint32_t denormalized_cutoff = UINT32_C(1) << 27;
const uint32_t result = sign |
(two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value)
: fp32_to_bits(normalized_value));
return fp32_from_bits(result);
#endif // C10_X86_F16
}
/*
* Convert a 32-bit floating-point number in IEEE single-precision format to a
* 16-bit floating-point number in IEEE half-precision format, in bit
* representation.
*
* @note The implementation relies on IEEE-like (no assumption about rounding
* mode and no operations on denormals) floating-point operations and bitcasts
* between integer and floating-point variables.
*/
inline uint16_t fp16_ieee_from_fp32_value(float f) {
#ifdef C10_X86_F16
return _cvtss_sh(f, _MM_FROUND_TO_NEAREST_INT);
#else
// const float scale_to_inf = 0x1.0p+112f;
// const float scale_to_zero = 0x1.0p-110f;
constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23;
constexpr uint32_t scale_to_zero_bits = (uint32_t)17 << 23;
float scale_to_inf_val = 0, scale_to_zero_val = 0;
std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val));
std::memcpy(
&scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val));
const float scale_to_inf = scale_to_inf_val;
const float scale_to_zero = scale_to_zero_val;
#if defined(_MSC_VER) && _MSC_VER == 1916
float base = ((signbit(f) != 0 ? -f : f) * scale_to_inf) * scale_to_zero;
#else
float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
#endif
const uint32_t w = fp32_to_bits(f);
const uint32_t shl1_w = w + w;
const uint32_t sign = w & UINT32_C(0x80000000);
uint32_t bias = shl1_w & UINT32_C(0xFF000000);
if (bias < UINT32_C(0x71000000)) {
bias = UINT32_C(0x71000000);
}
base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
const uint32_t bits = fp32_to_bits(base);
const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
const uint32_t nonsign = exp_bits + mantissa_bits;
return static_cast<uint16_t>(
(sign >> 16) |
(shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign));
#endif // C10_X86_F16
}
#ifdef C10_X86_F16
#undef C10_X86_F16
#endif // C10_X86_F16

View File

@ -1,46 +1 @@
#pragma once
#include <cstring>
#include <type_traits>
#include <c10/macros/Macros.h>
#if __has_include(<bit>) && (defined(__cpp_lib_bit_cast) && __cpp_lib_bit_cast >= 201806L)
#include <bit>
#define C10_HAVE_STD_BIT_CAST 1
#else
#define C10_HAVE_STD_BIT_CAST 0
#endif // __has_include(<bit>) && (__cplusplus >= 202002L ||
// (defined(__cpp_lib_bit_cast) && __cpp_lib_bit_cast >= 201806L))
namespace c10 {
#if C10_HAVE_STD_BIT_CAST
using std::bit_cast;
#else
// Implementations of std::bit_cast() from C++ 20.
//
// This is a less sketchy version of reinterpret_cast.
//
// See https://en.cppreference.com/w/cpp/numeric/bit_cast for more
// information as well as the source of our implementations.
template <class To, class From>
C10_HOST_DEVICE std::enable_if_t<
sizeof(To) == sizeof(From) && std::is_trivially_copyable_v<From> &&
std::is_trivially_copyable_v<To>,
To>
// constexpr support needs compiler magic
bit_cast(const From& src) noexcept {
static_assert(
std::is_trivially_constructible_v<To>,
"This implementation additionally requires "
"destination type to be trivially constructible");
To dst;
std::memcpy(&dst, &src, sizeof(To));
return dst;
}
#endif // C10_HAVE_STD_BIT_CAST
#undef C10_HAVE_STD_BIT_CAST
} // namespace c10
#include <torch/headeronly/util/bit_cast.h>

View File

@ -1,33 +1 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/bit_cast.h>
#include <cstdint>
namespace c10::detail {
C10_HOST_DEVICE inline float fp32_from_bits(uint32_t w) {
#if defined(__OPENCL_VERSION__)
return as_float(w);
#elif defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return __uint_as_float((unsigned int)w);
#elif defined(__INTEL_COMPILER)
return _castu32_f32(w);
#else
return c10::bit_cast<float>(w);
#endif
}
C10_HOST_DEVICE inline uint32_t fp32_to_bits(float f) {
#if defined(__OPENCL_VERSION__)
return as_uint(f);
#elif defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return (uint32_t)__float_as_uint(f);
#elif defined(__INTEL_COMPILER)
return _castf32_u32(f);
#else
return c10::bit_cast<uint32_t>(f);
#endif
}
} // namespace c10::detail
#include <torch/headeronly/util/floating_point_utils.h>

View File

@ -9,6 +9,13 @@ set(AOTI_ABI_CHECK_TEST_SRCS
${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_vec.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_vec_half.cpp
)
# The below are tests that require CPU_CAPABILITY setup
# You may think test_vec.cpp needs to be in there, but it does not.
set(AOTI_ABI_CHECK_VEC_TEST_SRCS
${AOTI_ABI_CHECK_TEST_ROOT}/test_vec_half.cpp
)
add_executable(test_aoti_abi_check
@ -23,6 +30,23 @@ target_compile_definitions(test_aoti_abi_check PRIVATE USE_GTEST)
target_link_libraries(test_aoti_abi_check PRIVATE gtest_main)
target_include_directories(test_aoti_abi_check PRIVATE ${ATen_CPU_INCLUDE})
foreach(test_src ${AOTI_ABI_CHECK_VEC_TEST_SRCS})
foreach(i RANGE ${NUM_CPU_CAPABILITY_NAMES})
get_filename_component(test_name ${test_src} NAME_WE)
list(GET CPU_CAPABILITY_NAMES ${i} CPU_CAPABILITY)
list(GET CPU_CAPABILITY_FLAGS ${i} FLAGS)
separate_arguments(FLAGS UNIX_COMMAND "${FLAGS}")
add_executable(${test_name}_${CPU_CAPABILITY} "${test_src}")
target_link_libraries(${test_name}_${CPU_CAPABILITY} PRIVATE gtest_main)
target_include_directories(${test_name}_${CPU_CAPABILITY} PRIVATE ${ATen_CPU_INCLUDE})
# Define CPU_CAPABILITY and CPU_CAPABILITY_XXX macros for conditional compilation
target_compile_definitions(${test_name}_${CPU_CAPABILITY} PRIVATE CPU_CAPABILITY=${CPU_CAPABILITY} CPU_CAPABILITY_${CPU_CAPABILITY})
target_compile_options(${test_name}_${CPU_CAPABILITY} PRIVATE ${FLAGS})
endforeach()
endforeach()
if(INSTALL_TEST)
install(TARGETS test_aoti_abi_check DESTINATION bin)
# Install PDB files for MSVC builds

View File

@ -0,0 +1,22 @@
#include <gtest/gtest.h>
#include <torch/headeronly/cpu/vec/vec_half.h>
#include <torch/headeronly/util/Half.h>
TEST(TestVecHalf, TestConversion) {
float f32s[100];
for (int i = 0; i < 100; i++) {
f32s[i] = static_cast<float>(i + 0.3);
}
for (int i = 0; i < 100; i++) {
#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
!defined(__APPLE__)
uint16_t u16 = torch::headeronly::vec::float2half_scalar(f32s[i]);
float x = torch::headeronly::vec::half2float_scalar(u16);
EXPECT_EQ(
u16, torch::headeronly::detail::fp16_ieee_from_fp32_value(f32s[i]))
<< "Test failed for float to uint16 " << f32s[i] << "\n";
EXPECT_EQ(x, torch::headeronly::detail::fp16_ieee_to_fp32_value(u16))
<< "Test failed for uint16 to float " << u16 << "\n";
#endif
}
}

View File

@ -6,7 +6,7 @@
# c10/util/TypeCast.h
convert
# c10/util/bit_cast.h
# c10/util/bit_cast.h, torch/headeronly/util/bit_cast.h
bit_cast
# c10/util/BFloat16-math.h, c10/util/BFloat16.h
@ -27,6 +27,14 @@ Float8_e5m2fnuz
# c10/util/Half.h
Half
# torch/headeronly/util/Half.h
fp16_ieee_from_fp32_value
fp16_ieee_to_fp32_value
# torch/headeronly/util/floating_point_utils.h
# fp32_from_bits called from fp16_ieee_to_fp32_value
# fp32_to_bits called from fp16_ieee_from_fp32_value
# c10/util/complex.h
complex
@ -48,6 +56,10 @@ maximum
minimum
size
# torch/headeronly/cpu/vec/vec_half.h
float2half_scalar
half2float_scalar
# torch/headeronly/macros/Export.h
C10_API

View File

@ -20,6 +20,7 @@ configure_file(
file(GLOB HEADERONLY_HEADERS
*.h
cpu/**/*.h
macros/*.h
util/*.h
)

View File

@ -0,0 +1,55 @@
#pragma once
#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
/* GCC or clang-compatible compiler, targeting x86/x86-64 */
#include <x86intrin.h>
#elif defined(__clang__) && (defined(__ARM_NEON__) || defined(__aarch64__))
/* Clang-compatible compiler, targeting arm neon */
#include <arm_neon.h>
#if defined(__ARM_FEATURE_SVE)
/* CLANG-compatible compiler, targeting ARM with SVE */
#include <arm_sve.h>
#endif
#elif defined(_MSC_VER)
/* Microsoft C/C++-compatible compiler */
#include <intrin.h>
#if _MSC_VER <= 1900
#define _mm256_extract_epi64(X, Y) \
(_mm_extract_epi64(_mm256_extractf128_si256(X, Y >> 1), Y % 2))
#define _mm256_extract_epi32(X, Y) \
(_mm_extract_epi32(_mm256_extractf128_si256(X, Y >> 2), Y % 4))
#define _mm256_extract_epi16(X, Y) \
(_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8))
#define _mm256_extract_epi8(X, Y) \
(_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16))
#endif
#elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__))
/* GCC-compatible compiler, targeting ARM with NEON */
#include <arm_neon.h>
#if defined(__ARM_FEATURE_SVE)
/* GCC-compatible compiler, targeting ARM with SVE */
#include <arm_sve.h>
#endif
#if defined(MISSING_ARM_VLD1)
#include <torch/headeronly/cpu/vec/vec256/missing_vld1_neon.h>
#elif defined(MISSING_ARM_VST1)
#include <torch/headeronly/cpu/vec/vec256/missing_vst1_neon.h>
#endif
#elif defined(__GNUC__) && defined(__IWMMXT__)
/* GCC-compatible compiler, targeting ARM with WMMX */
#include <mmintrin.h>
#elif defined(__s390x__)
// targets Z/architecture
// we will include vecintrin later
#elif (defined(__GNUC__) || defined(__xlC__)) && \
(defined(__VEC__) || defined(__ALTIVEC__))
/* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */
#include <altivec.h>
/* We need to undef those tokens defined by <altivec.h> to avoid conflicts
with the C++ types. => Can still use __bool/__vector */
#undef bool
#undef vector
#undef pixel
#elif defined(__GNUC__) && defined(__SPE__)
/* GCC-compatible compiler, targeting PowerPC with SPE */
#include <spe.h>
#endif

View File

@ -0,0 +1,396 @@
/* Workaround for missing vld1_*_x2 and vst1_*_x2 intrinsics in gcc-7. */
__extension__ extern __inline uint8x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_u8_x2(const uint8_t* __a) {
uint8x8x2_t ret;
asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int8x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_s8_x2(const int8_t* __a) {
int8x8x2_t ret;
asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint16x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_u16_x2(const uint16_t* __a) {
uint16x4x2_t ret;
asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int16x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_s16_x2(const int16_t* __a) {
int16x4x2_t ret;
asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint32x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_u32_x2(const uint32_t* __a) {
uint32x2x2_t ret;
asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int32x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_s32_x2(const int32_t* __a) {
int32x2x2_t ret;
asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint64x1x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_u64_x2(const uint64_t* __a) {
uint64x1x2_t ret;
asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int64x1x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_s64_x2(const int64_t* __a) {
int64x1x2_t ret;
__builtin_aarch64_simd_oi __o;
asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline float16x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_f16_x2(const float16_t* __a) {
float16x4x2_t ret;
asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline float32x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_f32_x2(const float32_t* __a) {
float32x2x2_t ret;
asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline float64x1x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_f64_x2(const float64_t* __a) {
float64x1x2_t ret;
asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline poly8x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_p8_x2(const poly8_t* __a) {
poly8x8x2_t ret;
asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline poly16x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_p16_x2(const poly16_t* __a) {
poly16x4x2_t ret;
asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline poly64x1x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1_p64_x2(const poly64_t* __a) {
poly64x1x2_t ret;
asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint8x16x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_u8_x2(const uint8_t* __a) {
uint8x16x2_t ret;
asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int8x16x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_s8_x2(const int8_t* __a) {
int8x16x2_t ret;
asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint16x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_u16_x2(const uint16_t* __a) {
uint16x8x2_t ret;
asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int16x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_s16_x2(const int16_t* __a) {
int16x8x2_t ret;
asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint32x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_u32_x2(const uint32_t* __a) {
uint32x4x2_t ret;
asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int32x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_s32_x2(const int32_t* __a) {
int32x4x2_t ret;
asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint64x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_u64_x2(const uint64_t* __a) {
uint64x2x2_t ret;
asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int64x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_s64_x2(const int64_t* __a) {
int64x2x2_t ret;
asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline float16x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_f16_x2(const float16_t* __a) {
float16x8x2_t ret;
asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline float32x4x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_f32_x2(const float32_t* __a) {
float32x4x2_t ret;
asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline float64x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_f64_x2(const float64_t* __a) {
float64x2x2_t ret;
asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline poly8x16x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_p8_x2(const poly8_t* __a) {
poly8x16x2_t ret;
asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline poly16x8x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_p16_x2(const poly16_t* __a) {
poly16x8x2_t ret;
asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline poly64x2x2_t
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vld1q_p64_x2(const poly64_t* __a) {
poly64x2x2_t ret;
asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a));
return ret;
}
/* vst1x2 */
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_s64_x2(int64_t* __a, int64x1x2_t val) {
asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_u64_x2(uint64_t* __a, uint64x1x2_t val) {
asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_f64_x2(float64_t* __a, float64x1x2_t val) {
asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_s8_x2(int8_t* __a, int8x8x2_t val) {
asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_p8_x2(poly8_t* __a, poly8x8x2_t val) {
asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_s16_x2(int16_t* __a, int16x4x2_t val) {
asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_p16_x2(poly16_t* __a, poly16x4x2_t val) {
asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_s32_x2(int32_t* __a, int32x2x2_t val) {
asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_u8_x2(uint8_t* __a, uint8x8x2_t val) {
asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_u16_x2(uint16_t* __a, uint16x4x2_t val) {
asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_u32_x2(uint32_t* __a, uint32x2x2_t val) {
asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_f16_x2(float16_t* __a, float16x4x2_t val) {
asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_f32_x2(float32_t* __a, float32x2x2_t val) {
asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1_p64_x2(poly64_t* __a, poly64x1x2_t val) {
asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_s8_x2(int8_t* __a, int8x16x2_t val) {
asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_p8_x2(poly8_t* __a, poly8x16x2_t val) {
asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_s16_x2(int16_t* __a, int16x8x2_t val) {
asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_p16_x2(poly16_t* __a, poly16x8x2_t val) {
asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_s32_x2(int32_t* __a, int32x4x2_t val) {
asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_s64_x2(int64_t* __a, int64x2x2_t val) {
asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_u8_x2(uint8_t* __a, uint8x16x2_t val) {
asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_u16_x2(uint16_t* __a, uint16x8x2_t val) {
asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_u32_x2(uint32_t* __a, uint32x4x2_t val) {
asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_u64_x2(uint64_t* __a, uint64x2x2_t val) {
asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_f16_x2(float16_t* __a, float16x8x2_t val) {
asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_f32_x2(float32_t* __a, float32x4x2_t val) {
asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_f64_x2(float64_t* __a, float64x2x2_t val) {
asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val));
}
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_p64_x2(poly64_t* __a, poly64x2x2_t val) {
asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val));
}

View File

@ -0,0 +1,7 @@
/* Workaround for missing vst1q_f32_x2 in gcc-8. */
__extension__ extern __inline void
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
vst1q_f32_x2(float32_t* __a, float32x4x2_t val) {
asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val));
}

View File

@ -0,0 +1,58 @@
#pragma once
#include <torch/headeronly/cpu/vec/intrinsics.h>
namespace torch::headeronly::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
!defined(__APPLE__)
static inline uint16_t float2half_scalar(float val) {
#if defined(CPU_CAPABILITY_AVX2)
#if defined(_MSC_VER)
__m256 v = _mm256_set1_ps(val);
__m128i o =
_mm256_cvtps_ph(v, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
return static_cast<std::uint16_t>(_mm_cvtsi128_si32(o));
#else
return _cvtss_sh(val, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
#endif
#elif defined(CPU_CAPABILITY_AVX512)
__m512 v = _mm512_set1_ps(val);
__m256i o =
_mm512_cvtps_ph(v, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
return static_cast<std::uint16_t>(
_mm_cvtsi128_si32(_mm256_castsi256_si128(o)));
#endif
}
static inline float half2float_scalar(uint16_t val) {
#if defined(CPU_CAPABILITY_AVX2)
#if defined(_MSC_VER)
__m128i v = _mm_cvtsi32_si128(val);
__m256 o = _mm256_cvtph_ps(v);
return _mm256_cvtss_f32(o);
#else
return _cvtsh_ss(val);
#endif
#elif defined(CPU_CAPABILITY_AVX512)
__m256i v =
_mm256_setr_epi16(val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
__m512 o = _mm512_cvtph_ps(v);
return _mm512_cvtss_f32(o);
#endif
}
#endif
} // namespace CPU_CAPABILITY
} // namespace torch::headeronly::vec
namespace at::vec {
#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
!defined(__APPLE__)
using torch::headeronly::vec::float2half_scalar;
using torch::headeronly::vec::half2float_scalar;
#endif
} // namespace at::vec

View File

@ -29,6 +29,7 @@ def define_torch_headeronly_ovrsource(name, is_mobile):
public_include_directories = ["../.."],
public_preprocessor_flags = pp_flags,
public_raw_headers = native.glob([
"cpu/**/*.h",
"macros/*.h",
"util/*.h",
]),

View File

@ -0,0 +1,249 @@
#pragma once
#include <torch/headeronly/macros/Macros.h>
#include <torch/headeronly/util/bit_cast.h>
#include <torch/headeronly/util/floating_point_utils.h>
#if defined(__cplusplus)
#include <cmath>
#elif !defined(__OPENCL_VERSION__)
#include <math.h>
#endif
#ifdef _MSC_VER
#include <intrin.h>
#endif
#include <cstdint>
#include <cstring>
#ifdef __CUDACC__
#include <cuda_fp16.h>
#endif
#ifdef __HIPCC__
#include <hip/hip_fp16.h>
#endif
#if defined(CL_SYCL_LANGUAGE_VERSION)
#include <CL/sycl.hpp> // for SYCL 1.2.1
#elif defined(SYCL_LANGUAGE_VERSION)
#include <sycl/sycl.hpp> // for SYCL 2020
#endif
#if defined(__aarch64__) && !defined(__CUDACC__)
#include <arm_neon.h>
#endif
#if defined(__GNUC__) || defined(__clang__)
#if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || \
defined(_M_IX86)
#if defined(__F16C__) && \
!(defined(__CUDA_ARCH__) || defined(__CUDACC__) || \
defined(__HIP_DEVICE_COMPILE__))
#define C10_X86_F16 1
#include <immintrin.h> // import conversion ops from f16cintrin.h
#endif // defined(__F16C__) && !(defined(__CUDA_ARCH__) || defined(__CUDACC__)
// || defined(__HIP_DEVICE_COMPILE__))
#endif // __x86_64__ || _M_X64 || __i386 || _M_IX86
#endif // __GNUC__ || __clang__
namespace torch::headeronly::detail {
/*
* Convert a 16-bit floating-point number in IEEE half-precision format, in bit
* representation, to a 32-bit floating-point number in IEEE single-precision
* format.
*
* @note The implementation relies on IEEE-like (no assumption about rounding
* mode and no operations on denormals) floating-point operations and bitcasts
* between integer and floating-point variables.
*/
C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) {
#ifdef C10_X86_F16
return _cvtsh_ss(h);
#else
/*
* Extend the half-precision floating-point number to 32 bits and shift to the
* upper part of the 32-bit word:
* +---+-----+------------+-------------------+
* | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
* +---+-----+------------+-------------------+
* Bits 31 26-30 16-25 0-15
*
* S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
* - zero bits.
*/
const uint32_t w = (uint32_t)h << 16;
/*
* Extract the sign of the input number into the high bit of the 32-bit word:
*
* +---+----------------------------------+
* | S |0000000 00000000 00000000 00000000|
* +---+----------------------------------+
* Bits 31 0-31
*/
const uint32_t sign = w & UINT32_C(0x80000000);
/*
* Extract mantissa and biased exponent of the input number into the high bits
* of the 32-bit word:
*
* +-----+------------+---------------------+
* |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000|
* +-----+------------+---------------------+
* Bits 27-31 17-26 0-16
*/
const uint32_t two_w = w + w;
/*
* Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become
* mantissa and exponent of a single-precision floating-point number:
*
* S|Exponent | Mantissa
* +-+---+-----+------------+----------------+
* |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000|
* +-+---+-----+------------+----------------+
* Bits | 23-31 | 0-22
*
* Next, there are some adjustments to the exponent:
* - The exponent needs to be corrected by the difference in exponent bias
* between single-precision and half-precision formats (0x7F - 0xF = 0x70)
* - Inf and NaN values in the inputs should become Inf and NaN values after
* conversion to the single-precision number. Therefore, if the biased
* exponent of the half-precision input was 0x1F (max possible value), the
* biased exponent of the single-precision output must be 0xFF (max possible
* value). We do this correction in two steps:
* - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset
* below) rather than by 0x70 suggested by the difference in the exponent bias
* (see above).
* - Then we multiply the single-precision result of exponent adjustment by
* 2**(-112) to reverse the effect of exponent adjustment by 0xE0 less the
* necessary exponent adjustment by 0x70 due to difference in exponent bias.
* The floating-point multiplication hardware would ensure than Inf and
* NaN would retain their value on at least partially IEEE754-compliant
* implementations.
*
* Note that the above operations do not handle denormal inputs (where biased
* exponent == 0). However, they also do not operate on denormal inputs, and
* do not produce denormal results.
*/
constexpr uint32_t exp_offset = UINT32_C(0xE0) << 23;
// const float exp_scale = 0x1.0p-112f;
constexpr uint32_t scale_bits = (uint32_t)15 << 23;
float exp_scale_val = 0;
#if defined(_MSC_VER) && defined(__clang__)
__builtin_memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val));
#else
std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val));
#endif
const float exp_scale = exp_scale_val;
const float normalized_value =
fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
/*
* Convert denormalized half-precision inputs into single-precision results
* (always normalized). Zero inputs are also handled here.
*
* In a denormalized number the biased exponent is zero, and mantissa has
* on-zero bits. First, we shift mantissa into bits 0-9 of the 32-bit word.
*
* zeros | mantissa
* +---------------------------+------------+
* |0000 0000 0000 0000 0000 00|MM MMMM MMMM|
* +---------------------------+------------+
* Bits 10-31 0-9
*
* Now, remember that denormalized half-precision numbers are represented as:
* FP16 = mantissa * 2**(-24).
* The trick is to construct a normalized single-precision number with the
* same mantissa and thehalf-precision input and with an exponent which would
* scale the corresponding mantissa bits to 2**(-24). A normalized
* single-precision floating-point number is represented as: FP32 = (1 +
* mantissa * 2**(-23)) * 2**(exponent - 127) Therefore, when the biased
* exponent is 126, a unit change in the mantissa of the input denormalized
* half-precision number causes a change of the constructed single-precision
* number by 2**(-24), i.e. the same amount.
*
* The last step is to adjust the bias of the constructed single-precision
* number. When the input half-precision number is zero, the constructed
* single-precision number has the value of FP32 = 1 * 2**(126 - 127) =
* 2**(-1) = 0.5 Therefore, we need to subtract 0.5 from the constructed
* single-precision number to get the numerical equivalent of the input
* half-precision number.
*/
constexpr uint32_t magic_mask = UINT32_C(126) << 23;
constexpr float magic_bias = 0.5f;
const float denormalized_value =
fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
/*
* - Choose either results of conversion of input as a normalized number, or
* as a denormalized number, depending on the input exponent. The variable
* two_w contains input exponent in bits 27-31, therefore if its smaller than
* 2**27, the input is either a denormal number, or zero.
* - Combine the result of conversion of exponent and mantissa with the sign
* of the input number.
*/
constexpr uint32_t denormalized_cutoff = UINT32_C(1) << 27;
const uint32_t result = sign |
(two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value)
: fp32_to_bits(normalized_value));
return fp32_from_bits(result);
#endif // C10_X86_F16
}
/*
* Convert a 32-bit floating-point number in IEEE single-precision format to a
* 16-bit floating-point number in IEEE half-precision format, in bit
* representation.
*
* @note The implementation relies on IEEE-like (no assumption about rounding
* mode and no operations on denormals) floating-point operations and bitcasts
* between integer and floating-point variables.
*/
inline uint16_t fp16_ieee_from_fp32_value(float f) {
#ifdef C10_X86_F16
return _cvtss_sh(f, _MM_FROUND_TO_NEAREST_INT);
#else
// const float scale_to_inf = 0x1.0p+112f;
// const float scale_to_zero = 0x1.0p-110f;
constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23;
constexpr uint32_t scale_to_zero_bits = (uint32_t)17 << 23;
float scale_to_inf_val = 0, scale_to_zero_val = 0;
std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val));
std::memcpy(
&scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val));
const float scale_to_inf = scale_to_inf_val;
const float scale_to_zero = scale_to_zero_val;
#if defined(_MSC_VER) && _MSC_VER == 1916
float base = ((signbit(f) != 0 ? -f : f) * scale_to_inf) * scale_to_zero;
#else
float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
#endif
const uint32_t w = fp32_to_bits(f);
const uint32_t shl1_w = w + w;
const uint32_t sign = w & UINT32_C(0x80000000);
uint32_t bias = shl1_w & UINT32_C(0xFF000000);
if (bias < UINT32_C(0x71000000)) {
bias = UINT32_C(0x71000000);
}
base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
const uint32_t bits = fp32_to_bits(base);
const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
const uint32_t nonsign = exp_bits + mantissa_bits;
return static_cast<uint16_t>(
(sign >> 16) |
(shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign));
#endif // C10_X86_F16
}
} // namespace torch::headeronly::detail
namespace c10::detail {
using torch::headeronly::detail::fp16_ieee_from_fp32_value;
using torch::headeronly::detail::fp16_ieee_to_fp32_value;
} // namespace c10::detail

View File

@ -0,0 +1,50 @@
#pragma once
#include <cstring>
#include <type_traits>
#include <torch/headeronly/macros/Macros.h>
#if __has_include(<bit>) && (defined(__cpp_lib_bit_cast) && __cpp_lib_bit_cast >= 201806L)
#include <bit>
#define C10_HAVE_STD_BIT_CAST 1
#else
#define C10_HAVE_STD_BIT_CAST 0
#endif // __has_include(<bit>) && (__cplusplus >= 202002L ||
// (defined(__cpp_lib_bit_cast) && __cpp_lib_bit_cast >= 201806L))
namespace torch::headeronly {
#if C10_HAVE_STD_BIT_CAST
using std::bit_cast;
#else
// Implementations of std::bit_cast() from C++ 20.
//
// This is a less sketchy version of reinterpret_cast.
//
// See https://en.cppreference.com/w/cpp/numeric/bit_cast for more
// information as well as the source of our implementations.
template <class To, class From>
C10_HOST_DEVICE std::enable_if_t<
sizeof(To) == sizeof(From) && std::is_trivially_copyable_v<From> &&
std::is_trivially_copyable_v<To>,
To>
// constexpr support needs compiler magic
bit_cast(const From& src) noexcept {
static_assert(
std::is_trivially_constructible_v<To>,
"This implementation additionally requires "
"destination type to be trivially constructible");
To dst;
std::memcpy(&dst, &src, sizeof(To));
return dst;
}
#endif // C10_HAVE_STD_BIT_CAST
#undef C10_HAVE_STD_BIT_CAST
} // namespace torch::headeronly
namespace c10 {
using torch::headeronly::bit_cast;
} // namespace c10

View File

@ -0,0 +1,38 @@
#pragma once
#include <torch/headeronly/macros/Macros.h>
#include <torch/headeronly/util/bit_cast.h>
#include <cstdint>
namespace torch::headeronly::detail {
C10_HOST_DEVICE inline float fp32_from_bits(uint32_t w) {
#if defined(__OPENCL_VERSION__)
return as_float(w);
#elif defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return __uint_as_float((unsigned int)w);
#elif defined(__INTEL_COMPILER)
return _castu32_f32(w);
#else
return torch::headeronly::bit_cast<float>(w);
#endif
}
C10_HOST_DEVICE inline uint32_t fp32_to_bits(float f) {
#if defined(__OPENCL_VERSION__)
return as_uint(f);
#elif defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return (uint32_t)__float_as_uint(f);
#elif defined(__INTEL_COMPILER)
return _castf32_u32(f);
#else
return torch::headeronly::bit_cast<uint32_t>(f);
#endif
}
} // namespace torch::headeronly::detail
namespace c10::detail {
using torch::headeronly::detail::fp32_from_bits;
using torch::headeronly::detail::fp32_to_bits;
} // namespace c10::detail