Extending the Pytorch vec backend for SVE (ARM) (#119571)

**Motivation:**
In Pytorch, Aten vectorization supports multiple platforms, including x86 and Arm, as well as multiple data types. It provides a generic implementation of Vector (Vec) type that allows the programmer to write code packing various primitives (such as floats) within 256bit & 512bits registers. It can be extended to support other ISAs easily by adding more VecISA sub-classes.

**Reference Link:** https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cpu/vec

**This PR:**

* Our goal with this contribution is to add support for SVE backend for Vec in the Aten vectorization for CPU backend which can be benefitted by any ARM architecture supported CPU's that supports SVE.

* More about SVE ISA for ARM: [https://developer.arm.com/Architectures/Scalable Vector Extensions](https://developer.arm.com/Architectures/Scalable%20Vector%20Extensions)

* We are using the ARM C Language Extensions for SVE (https://developer.arm.com/documentation/102699/0100/Optimizing-with-intrinsics ) to accelerate performance for various operators in the SVE backend for Vec.

* Currently we are adding support only for SVE ISA with the vector length of 256 bits (SVE 256). In future, we plan to extend this SVE support for other vector lengths as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119571
Approved by: https://github.com/malfet, https://github.com/snadampal

Co-authored-by: Divya Kotadiya <divya.kotadiya@fujitsu.com>
This commit is contained in:
maajidkhann
2024-09-18 18:59:10 +00:00
committed by PyTorch MergeBot
parent bad69044d8
commit 5a6ddbcc3b
29 changed files with 2554 additions and 9 deletions

View File

@ -54,7 +54,7 @@ if(NOT BUILD_LITE_INTERPRETER)
endif()
EXCLUDE(ATen_CORE_SRCS "${ATen_CORE_SRCS}" ${ATen_CORE_TEST_SRCS})
file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec256/*.h" "cpu/vec/vec256/vsx/*.h" "cpu/vec/vec256/zarch/*.h" "cpu/vec/*.h" "quantized/*.h" "functorch/*.h")
file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec256/*.h" "cpu/vec/vec256/vsx/*.h" "cpu/vec/vec256/zarch/*.h" "cpu/vec/sve/*.h" "cpu/vec/*.h" "quantized/*.h" "functorch/*.h")
file(GLOB base_cpp "*.cpp" "detail/*.cpp" "cpu/*.cpp" "functorch/*.cpp")
file(GLOB cuda_h "cuda/*.h" "cuda/detail/*.h" "cuda/*.cuh" "cuda/detail/*.cuh" "cuda/tunable/*.cuh" "cuda/tunable/*.h")
file(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp" "cuda/tunable/*.cpp")

View File

@ -105,6 +105,11 @@ std::string get_cpu_capability() {
return "DEFAULT";
case native::CPUCapability::ZVECTOR:
return "Z VECTOR";
#elif defined(HAVE_SVE_CPU_DEFINITION)
case native::CPUCapability::DEFAULT:
return "DEFAULT";
case native::CPUCapability::SVE256:
return "SVE256";
#else
case native::CPUCapability::DEFAULT:
return "NO AVX";

View File

@ -78,7 +78,7 @@ struct VecReduceAllSIMD<float, Op> {
#endif // defined(CPU_CAPABILITY_AVX512)
#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && !defined(CPU_CAPABILITY_SVE)
template <typename Op>
struct VecReduceAllSIMD<float, Op> {
static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {

View File

@ -5,6 +5,10 @@
#elif defined(__clang__) && (defined(__ARM_NEON__) || defined(__aarch64__))
/* Clang-compatible compiler, targeting arm neon */
#include <arm_neon.h>
#if defined(__ARM_FEATURE_SVE)
/* CLANG-compatible compiler, targeting ARM with SVE */
#include <arm_sve.h>
#endif
#elif defined(_MSC_VER)
/* Microsoft C/C++-compatible compiler */
#include <intrin.h>
@ -17,6 +21,10 @@
#elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__))
/* GCC-compatible compiler, targeting ARM with NEON */
#include <arm_neon.h>
#if defined(__ARM_FEATURE_SVE)
/* GCC-compatible compiler, targeting ARM with SVE */
#include <arm_sve.h>
#endif
#if defined (MISSING_ARM_VLD1)
#include <ATen/cpu/vec/vec256/missing_vld1_neon.h>
#elif defined (MISSING_ARM_VST1)

View File

@ -0,0 +1,63 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#if defined(CPU_CAPABILITY_SVE)
// Define the data type of VLS(vector-length specific).
typedef svbool_t vls_pred_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svint8_t vls_int8_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svint16_t vls_int16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svint32_t vls_int32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svint64_t vls_int64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svuint8_t vls_uint8_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svuint16_t vls_uint16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svuint32_t vls_uint32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svuint64_t vls_uint64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svfloat16_t vls_float16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svfloat32_t vls_float32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svfloat64_t vls_float64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
#define ptrue svptrue_b8()
#define ZERO_S8 svdup_n_s8(0)
#define ZERO_S16 svdup_n_s16(0)
#define ZERO_S32 svdup_n_s32(0)
#define ZERO_S64 svdup_n_s64(0)
#define ZERO_U8 svdup_n_u8(0)
#define ZERO_U16 svdup_n_u16(0)
#define ZERO_U32 svdup_n_u32(0)
#define ZERO_U64 svdup_n_u64(0)
#define ZERO_F16 svdup_n_f16(0.f)
#define ZERO_F32 svdup_n_f32(0.f)
#define ZERO_F64 svdup_n_f64(0.0)
#define ONE_S8 svdup_n_s8(1)
#define ONE_S16 svdup_n_s16(1)
#define ONE_S32 svdup_n_s32(1)
#define ONE_S64 svdup_n_s64(1)
#define ONE_U8 svdup_n_u8(1)
#define ONE_U16 svdup_n_u16(1)
#define ONE_U32 svdup_n_u32(1)
#define ONE_U64 svdup_n_u64(1)
#define ONE_F16 svdup_n_f16(1.f)
#define ONE_F32 svdup_n_f32(1.f)
#define ONE_F64 svdup_n_f64(1.0)
#define ALL_S8_TRUE_MASK svdup_n_s8(0xff)
#define ALL_S8_FALSE_MASK svdup_n_s8(0x0)
#define ALL_S16_TRUE_MASK svdup_n_s16(0xffff)
#define ALL_S16_FALSE_MASK svdup_n_s16(0x0)
#define ALL_S32_TRUE_MASK svdup_n_s32(0xffffffff)
#define ALL_S32_FALSE_MASK svdup_n_s32(0x0)
#define ALL_S64_TRUE_MASK svdup_n_s64(0xffffffffffffffff)
#define ALL_S64_FALSE_MASK svdup_n_s64(0x0)
#define ALL_U8_TRUE_MASK svdup_n_u8(0x01)
#define ALL_U8_FALSE_MASK svdup_n_u8(0x00)
#define ALL_F16_TRUE_MASK svreinterpret_f16_s16(ALL_S16_TRUE_MASK)
#define ALL_F16_FALSE_MASK svreinterpret_f16_s16(ALL_S16_FALSE_MASK)
#define ALL_F32_TRUE_MASK svreinterpret_f32_s32(ALL_S32_TRUE_MASK)
#define ALL_F32_FALSE_MASK svreinterpret_f32_s32(ALL_S32_FALSE_MASK)
#define ALL_F64_TRUE_MASK svreinterpret_f64_s64(ALL_S64_TRUE_MASK)
#define ALL_F64_FALSE_MASK svreinterpret_f64_s64(ALL_S64_FALSE_MASK)
#endif // defined(CPU_CAPABILITY_SVE)

View File

@ -0,0 +1,176 @@
#pragma once
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with SVE]
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/sve/sve_helper.h>
#if defined(CPU_CAPABILITY_SVE)
#include <ATen/cpu/vec/sve/vec_float.h>
#include <ATen/cpu/vec/sve/vec_double.h>
#include <ATen/cpu/vec/sve/vec_int.h>
#include <ATen/cpu/vec/sve/vec_qint.h>
#endif
namespace at {
namespace vec {
// Note [CPU_CAPABILITY namespace]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This header, and all of its subheaders, will be compiled with
// different architecture flags for each supported set of vector
// intrinsics. So we need to make sure they aren't inadvertently
// linked together. We do this by declaring objects in an `inline
// namespace` which changes the name mangling, but can still be
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_SVE)
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template<>
inline Vectorized<float> cast<float, double>(const Vectorized<double>& src) {
return svreinterpret_f32_f64(src);
}
template<>
inline Vectorized<double> cast<double, float>(const Vectorized<float>& src) {
return svreinterpret_f64_f32(src);
}
#define DEFINE_FLOAT_INT_CAST(int_t, int_bit, float_t, float_bit) \
template<> \
inline Vectorized<int_t> cast<int_t, float_t>(const Vectorized<float_t>& src) { \
return svreinterpret_s##int_bit##_f##float_bit(src); \
} \
template<> \
inline Vectorized<float_t> cast<float_t, int_t>(const Vectorized<int_t>& src) { \
return svreinterpret_f##float_bit##_s##int_bit(src); \
}
DEFINE_FLOAT_INT_CAST(int64_t, 64, double, 64)
DEFINE_FLOAT_INT_CAST(int32_t, 32, double, 64)
DEFINE_FLOAT_INT_CAST(int16_t, 16, double, 64)
DEFINE_FLOAT_INT_CAST(int64_t, 64, float, 32)
DEFINE_FLOAT_INT_CAST(int32_t, 32, float, 32)
DEFINE_FLOAT_INT_CAST(int16_t, 16, float, 32)
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline gather(const double* base_addr, const Vectorized<int64_t>& vindex_) {
svint64_t vindex = svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3);
return svld1_gather_s64index_f64(ptrue, base_addr, vindex);
}
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
inline gather(const float* base_addr, const Vectorized<int32_t>& vindex_) {
svint32_t vindex = svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2);
return svld1_gather_s32index_f32(ptrue, base_addr, vindex);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline mask_gather(const Vectorized<double>& src, const double* base_addr,
const Vectorized<int64_t>& vindex_, const Vectorized<double>& mask_) {
svbool_t mask = svcmpeq_s64(ptrue, svreinterpret_s64_f64(mask_),
ALL_S64_TRUE_MASK);
svint64_t vindex = svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3);
return svsel_f64(mask, svld1_gather_s64index_f64(mask, base_addr, vindex), src);
}
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
inline mask_gather(const Vectorized<float>& src, const float* base_addr,
const Vectorized<int32_t>& vindex_, const Vectorized<float>& mask_) {
svbool_t mask = svcmpeq_s32(ptrue, svreinterpret_s32_f32(mask_),
ALL_S32_TRUE_MASK);
svint32_t vindex = svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2);
return svsel_f32(mask, svld1_gather_s32index_f32(mask, base_addr, vindex), src);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Only works for inputs in the range: [-2^51, 2^51]
// From: https://stackoverflow.com/a/41148578
template<>
Vectorized<int64_t>
inline convert_to_int_of_same_size<double>(const Vectorized<double> &src) {
svfloat64_t x = svadd_f64_x(ptrue, src, svdup_n_f64(0x0018000000000000));
return svsub_s64_x(ptrue,
svreinterpret_s64_f64(x),
svreinterpret_s64_f64(svdup_n_f64(0x0018000000000000)));
}
template<>
Vectorized<int32_t>
inline convert_to_int_of_same_size<float>(const Vectorized<float> &src) {
return svcvt_s32_f32_x(ptrue, src);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <>
std::pair<Vectorized<double>, Vectorized<double>>
inline interleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
// inputs:
// a = {a0, a1, a3, a3}
// b = {b0, b1, b2, b3}
// group cols crossing lanes:
// return {a0, b0, a1, b1}
// {a2, b2, a3, b3}
return std::make_pair(Vectorized<double>(svzip1_f64(a, b)),
Vectorized<double>(svzip2_f64(a, b)));
}
template <>
std::pair<Vectorized<float>, Vectorized<float>>
inline interleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
// inputs:
// a = {a0, a1, a2, a3, a4, a5, a6, a7}
// b = {b0, b1, b2, b3, b4, b5, b6, b7}
// group cols crossing lanes:
// return {a0, b0, a1, b1, a2, b2, a3, b3}
// {a4, b4, a5, b5, a6, b6, a7, b7}
return std::make_pair(Vectorized<float>(svzip1_f32(a, b)),
Vectorized<float>(svzip2_f32(a, b)));
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <>
std::pair<Vectorized<double>, Vectorized<double>>
inline deinterleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
// inputs:
// a = {a0, b0, a1, b1}
// b = {a2, b2, a3, b3}
// swap lanes:
// return {a0, a1, a2, a3}
// {b0, b1, b2, b3}
return std::make_pair(Vectorized<double>(svuzp1_f64(a, b)),
Vectorized<double>(svuzp2_f64(a, b)));
}
template <>
std::pair<Vectorized<float>, Vectorized<float>>
inline deinterleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
// inputs:
// a = {a0, b0, a1, b1, a2, b2, a3, b3}
// b = {a4, b4, a5, b5, a6, b6, a7, b7}
// swap lanes:
// return {a0, a1, a2, a3, a4, a5, a6, a7}
// {b0, b1, b2, b3, b4, b5, b6, b7}
return std::make_pair(Vectorized<float>(svuzp1_f32(a, b)),
Vectorized<float>(svuzp2_f32(a, b)));
}
#endif // defined(CPU_CAPABILITY_SVE)
}}}

View File

@ -0,0 +1,505 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/sve/sve_helper.h>
#include <cmath>
#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF)
#include <sleef.h>
#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code
#else
#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code
#endif
namespace at {
namespace vec {
// Note [CPU_CAPABILITY namespace]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This header, and all of its subheaders, will be compiled with
// different architecture flags for each supported set of vector
// intrinsics. So we need to make sure they aren't inadvertently
// linked together. We do this by declaring objects in an `inline
// namespace` which changes the name mangling, but can still be
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_SVE)
template <> class Vectorized<double> {
private:
vls_float64_t values;
public:
using value_type = double;
using size_type = int;
static constexpr size_type size() {
return VECTOR_WIDTH / sizeof(double);
}
Vectorized() {}
Vectorized(svfloat64_t v) : values(v) {}
Vectorized(double val) {
values = svdup_n_f64(val);
}
template<typename... Args,
typename = std::enable_if_t<(sizeof...(Args) == size())>>
Vectorized(Args... vals) {
__at_align__ double buffer[size()] = { vals... };
values = svld1_f64(ptrue, buffer);
}
operator svfloat64_t() const {
return values;
}
static Vectorized<double> blendv(const Vectorized<double>& a, const Vectorized<double>& b,
const Vectorized<double>& mask_) {
svbool_t mask = svcmpeq_s64(ptrue, svreinterpret_s64_f64(mask_),
ALL_S64_TRUE_MASK);
return svsel_f64(mask, b, a);
}
template<typename step_t>
static Vectorized<double> arange(double base = 0., step_t step = static_cast<step_t>(1)) {
__at_align__ double buffer[size()];
for (int64_t i = 0; i < size(); i++) {
buffer[i] = base + i * step;
}
return svld1_f64(ptrue, buffer);
}
static Vectorized<double> set(const Vectorized<double>& a, const Vectorized<double>& b,
int64_t count = size()) {
if (count == 0) {
return a;
} else if (count < size()) {
return svsel_f64(svwhilelt_b64(0ull, count), b, a);
}
return b;
}
static Vectorized<double> loadu(const void* ptr, int64_t count = size()) {
if (count == size())
return svld1_f64(ptrue, reinterpret_cast<const double*>(ptr));
svbool_t pg = svwhilelt_b64(0ull, count);
return svld1_f64(pg, reinterpret_cast<const double*>(ptr));
}
void store(void* ptr, int64_t count = size()) const {
if (count == size()) {
svst1_f64(ptrue, reinterpret_cast<double*>(ptr), values);
} else {
svbool_t pg = svwhilelt_b64(0ull, count);
svst1_f64(pg, reinterpret_cast<double*>(ptr), values);
}
}
const double& operator[](int idx) const = delete;
double& operator[](int idx) = delete;
int64_t zero_mask() const {
// returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
int64_t mask = 0;
__at_align__ int64_t mask_array[size()];
svbool_t svbool_mask = svcmpeq_f64(ptrue, values, ZERO_F64);
svst1_s64(ptrue, mask_array, svsel_s64(svbool_mask,
ALL_S64_TRUE_MASK,
ALL_S64_FALSE_MASK));
for (int64_t i = 0; i < size(); ++i) {
if (mask_array[i]) mask |= (1ull << i);
}
return mask;
}
Vectorized<double> isnan() const {
// NaN check
svbool_t mask = svcmpuo_f64(ptrue, values, ZERO_F64);
return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK);
}
bool has_inf_nan() const {
return svptest_any(ptrue, svcmpuo_f64(ptrue, svsub_f64_x(ptrue, values, values), ZERO_F64));
}
Vectorized<double> map(double (*f)(double)) const {
__at_align__ double tmp[size()];
store(tmp);
for (int64_t i = 0; i < size(); ++i) {
tmp[i] = f(tmp[i]);
}
return loadu(tmp);
}
Vectorized<double> abs() const {
return svabs_f64_x(ptrue, values);
}
Vectorized<double> angle() const {
const auto nan_vec = svdup_n_f64(NAN);
const auto nan_mask = svcmpuo_f64(ptrue, values, ZERO_F64);
const auto pi = svdup_n_f64(c10::pi<double>);
const auto neg_mask = svcmplt_f64(ptrue, values, ZERO_F64);
auto angle = svsel_f64(neg_mask, pi, ZERO_F64);
angle = svsel_f64(nan_mask, nan_vec, angle);
return angle;
}
Vectorized<double> real() const {
return *this;
}
Vectorized<double> imag() const {
return Vectorized<double>(0.0);
}
Vectorized<double> conj() const {
return *this;
}
Vectorized<double> acos() const {
return USE_SLEEF(Vectorized<double>(Sleef_acosdx_u10sve(values)),map(std::acos));
}
Vectorized<double> acosh() const {
return USE_SLEEF( Vectorized<double>(Sleef_acoshdx_u10sve(values)),map(std::acosh));
}
Vectorized<double> asin() const {
return USE_SLEEF(Vectorized<double>(Sleef_asindx_u10sve(values)),map(std::asin));
}
Vectorized<double> atan() const {
return USE_SLEEF(Vectorized<double>(Sleef_atandx_u10sve(values)),map(std::atan));
}
Vectorized<double> atanh() const {
return USE_SLEEF(Vectorized<double>(Sleef_atanhdx_u10sve(values)),map(std::atanh));
}
Vectorized<double> atan2(const Vectorized<double> &b) const {
USE_SLEEF({return Vectorized<double>(Sleef_atan2dx_u10sve(values, b));},
{
__at_align__ double tmp[size()];
__at_align__ double tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = std::atan2(tmp[i], tmp_b[i]);
}
return loadu(tmp);
}
)
}
Vectorized<double> copysign(const Vectorized<double> &sign) const {
USE_SLEEF( {return Vectorized<double>(Sleef_copysigndx_sve(values, sign));},
{
__at_align__ double tmp[size()];
__at_align__ double tmp_sign[size()];
store(tmp);
sign.store(tmp_sign);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = std::copysign(tmp[i], tmp_sign[i]);
}
return loadu(tmp);
}
)
}
Vectorized<double> erf() const {
return USE_SLEEF(Vectorized<double>(Sleef_erfdx_u10sve(values)),map(std::erf));
}
Vectorized<double> erfc() const {
return USE_SLEEF(Vectorized<double>(Sleef_erfcdx_u15sve(values)),map(std::erfc));
}
Vectorized<double> erfinv() const {
return map(calc_erfinv);
}
Vectorized<double> exp() const {
return USE_SLEEF(Vectorized<double>(Sleef_expdx_u10sve(values)),map(std::exp));
}
Vectorized<double> exp2() const {
return USE_SLEEF(Vectorized<double>(Sleef_exp2dx_u10sve(values)),map(std::exp2));
}
Vectorized<double> expm1() const {
return USE_SLEEF(Vectorized<double>(Sleef_expm1dx_u10sve(values)),map(std::expm1));
}
Vectorized<double> exp_u20() const {
return exp();
}
Vectorized<double> fmod(const Vectorized<double>& q) const {
USE_SLEEF({return Vectorized<double>(Sleef_fmoddx_sve(values, q));},
{
__at_align__ double tmp[size()];
__at_align__ double tmp_q[size()];
store(tmp);
q.store(tmp_q);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = std::fmod(tmp[i], tmp_q[i]);
}
return loadu(tmp);
}
)
}
Vectorized<double> hypot(const Vectorized<double> &b) const {
USE_SLEEF({return Vectorized<double>(Sleef_hypotdx_u05sve(values, b));},
{
__at_align__ double tmp[size()];
__at_align__ double tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = std::hypot(tmp[i], tmp_b[i]);
}
return loadu(tmp);
})
}
Vectorized<double> i0() const {
return map(calc_i0);
}
Vectorized<double> i0e() const {
return map(calc_i0e);
}
Vectorized<double> digamma() const {
return map(calc_digamma);
}
Vectorized<double> igamma(const Vectorized<double> &x) const {
__at_align__ double tmp[size()];
__at_align__ double tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
}
return loadu(tmp);
}
Vectorized<double> igammac(const Vectorized<double> &x) const {
__at_align__ double tmp[size()];
__at_align__ double tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
}
return loadu(tmp);
}
Vectorized<double> nextafter(const Vectorized<double> &b) const {
USE_SLEEF(
{
return Vectorized<double>(Sleef_nextafterfx_sve(values, b));
},
{
__at_align__ double tmp[size()];
__at_align__ double tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (int64_t i = 0; i < size(); ++i) {
tmp[i] = std::nextafter(tmp[i], tmp_b[i]);
}
return loadu(tmp);
}
)
}
Vectorized<double> log() const {
return USE_SLEEF(Vectorized<double>(Sleef_logdx_u10sve(values)),map(std::log));
}
Vectorized<double> log2() const {
return USE_SLEEF(Vectorized<double>(Sleef_log2dx_u10sve(values)),map(std::log2));
}
Vectorized<double> log10() const {
return USE_SLEEF(Vectorized<double>(Sleef_log10dx_u10sve(values)),map(std::log10));
}
Vectorized<double> log1p() const {
return USE_SLEEF(Vectorized<double>(Sleef_log1pdx_u10sve(values)),map(std::log1p));
}
Vectorized<double> frac() const;
Vectorized<double> sin() const {
return USE_SLEEF( Vectorized<double>(Sleef_sindx_u10sve(values)),map(std::sin));
}
Vectorized<double> sinh() const {
return USE_SLEEF(Vectorized<double>(Sleef_sinhdx_u10sve(values)),map(std::sinh));
}
Vectorized<double> cos() const {
return USE_SLEEF(Vectorized<double>(Sleef_cosdx_u10sve(values)),map(std::cos));
}
Vectorized<double> cosh() const {
return USE_SLEEF( Vectorized<double>(Sleef_coshdx_u10sve(values)),map(std::cosh));
}
Vectorized<double> ceil() const {
return svrintp_f64_x(ptrue, values);
}
Vectorized<double> floor() const {
return svrintm_f64_x(ptrue, values);
}
Vectorized<double> neg() const {
return svneg_f64_x(ptrue, values);
}
Vectorized<double> round() const {
return svrinti_f64_x(ptrue, values);
}
Vectorized<double> tan() const {
return USE_SLEEF( Vectorized<double>(Sleef_tandx_u10sve(values)),map(std::tan));
}
Vectorized<double> tanh() const {
return USE_SLEEF( Vectorized<double>(Sleef_tanhdx_u10sve(values)),map(std::tanh));
}
Vectorized<double> trunc() const {
return svrintz_f64_x(ptrue, values);
}
Vectorized<double> lgamma() const {
return USE_SLEEF( Vectorized<double>(Sleef_lgammadx_u10sve(values)),map(std::lgamma));
}
Vectorized<double> sqrt() const {
return svsqrt_f64_x(ptrue, values);
}
Vectorized<double> reciprocal() const {
return svdivr_f64_x(ptrue, values, ONE_F64);
}
Vectorized<double> rsqrt() const {
return svdivr_f64_x(ptrue, svsqrt_f64_x(ptrue, values), ONE_F64);
}
Vectorized<double> pow(const Vectorized<double> &b) const {
USE_SLEEF( {return Vectorized<double>(Sleef_powdx_u10sve(values, b));},
{
__at_align__ double tmp[size()];
__at_align__ double tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = std::pow(tmp[i], tmp_b[i]);
}
return loadu(tmp);
}
)
}
// Comparison using the _CMP_**_OQ predicate.
// `O`: get false if an operand is NaN
// `Q`: do not raise if an operand is NaN
Vectorized<double> operator==(const Vectorized<double>& other) const {
svbool_t mask = svcmpeq_f64(ptrue, values, other);
return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK);
}
Vectorized<double> operator!=(const Vectorized<double>& other) const {
svbool_t mask = svcmpne_f64(ptrue, values, other);
return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK);
}
Vectorized<double> operator<(const Vectorized<double>& other) const {
svbool_t mask = svcmplt_f64(ptrue, values, other);
return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK);
}
Vectorized<double> operator<=(const Vectorized<double>& other) const {
svbool_t mask = svcmple_f64(ptrue, values, other);
return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK);
}
Vectorized<double> operator>(const Vectorized<double>& other) const {
svbool_t mask = svcmpgt_f64(ptrue, values, other);
return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK);
}
Vectorized<double> operator>=(const Vectorized<double>& other) const {
svbool_t mask = svcmpge_f64(ptrue, values, other);
return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK);
}
Vectorized<double> eq(const Vectorized<double>& other) const;
Vectorized<double> ne(const Vectorized<double>& other) const;
Vectorized<double> gt(const Vectorized<double>& other) const;
Vectorized<double> ge(const Vectorized<double>& other) const;
Vectorized<double> lt(const Vectorized<double>& other) const;
Vectorized<double> le(const Vectorized<double>& other) const;
};
template <>
Vectorized<double> inline operator+(const Vectorized<double>& a, const Vectorized<double>& b) {
return svadd_f64_x(ptrue, a, b);
}
template <>
Vectorized<double> inline operator-(const Vectorized<double>& a, const Vectorized<double>& b) {
return svsub_f64_x(ptrue, a, b);
}
template <>
Vectorized<double> inline operator*(const Vectorized<double>& a, const Vectorized<double>& b) {
return svmul_f64_x(ptrue, a, b);
}
template <>
Vectorized<double> inline operator/(const Vectorized<double>& a, const Vectorized<double>& b) {
return svdiv_f64_x(ptrue, a, b);
}
// frac. Implement this here so we can use subtraction
Vectorized<double> inline Vectorized<double>::frac() const {
return *this - this->trunc();
}
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<double> inline maximum(const Vectorized<double>& a, const Vectorized<double>& b) {
return svmax_f64_x(ptrue, a, b);
}
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<double> inline minimum(const Vectorized<double>& a, const Vectorized<double>& b) {
return svmin_f64_x(ptrue, a, b);
}
template <>
Vectorized<double> inline clamp(const Vectorized<double>& a, const Vectorized<double>& min, const Vectorized<double>& max) {
return svmin_f64_x(ptrue, max, svmax_f64_x(ptrue, min, a));
}
template <>
Vectorized<double> inline clamp_max(const Vectorized<double>& a, const Vectorized<double>& max) {
return svmin_f64_x(ptrue, max, a);
}
template <>
Vectorized<double> inline clamp_min(const Vectorized<double>& a, const Vectorized<double>& min) {
return svmax_f64_x(ptrue, min, a);
}
template <>
Vectorized<double> inline operator&(const Vectorized<double>& a, const Vectorized<double>& b) {
return svreinterpret_f64_s64(svand_s64_x(ptrue, svreinterpret_s64_f64(a), svreinterpret_s64_f64(b)));
}
template <>
Vectorized<double> inline operator|(const Vectorized<double>& a, const Vectorized<double>& b) {
return svreinterpret_f64_s64(svorr_s64_x(ptrue, svreinterpret_s64_f64(a), svreinterpret_s64_f64(b)));
}
template <>
Vectorized<double> inline operator^(const Vectorized<double>& a, const Vectorized<double>& b) {
return svreinterpret_f64_s64(sveor_s64_x(ptrue, svreinterpret_s64_f64(a), svreinterpret_s64_f64(b)));
}
Vectorized<double> inline Vectorized<double>::eq(const Vectorized<double>& other) const {
return (*this == other) & Vectorized<double>(1.0);
}
Vectorized<double> inline Vectorized<double>::ne(const Vectorized<double>& other) const {
return (*this != other) & Vectorized<double>(1.0);
}
Vectorized<double> inline Vectorized<double>::gt(const Vectorized<double>& other) const {
return (*this > other) & Vectorized<double>(1.0);
}
Vectorized<double> inline Vectorized<double>::ge(const Vectorized<double>& other) const {
return (*this >= other) & Vectorized<double>(1.0);
}
Vectorized<double> inline Vectorized<double>::lt(const Vectorized<double>& other) const {
return (*this < other) & Vectorized<double>(1.0);
}
Vectorized<double> inline Vectorized<double>::le(const Vectorized<double>& other) const {
return (*this <= other) & Vectorized<double>(1.0);
}
template <>
inline void convert(const double* src, double* dst, int64_t n) {
const int64_t fraction = n % Vectorized<double>::size();
#pragma unroll
for (int64_t i = 0; i < n - fraction; i += Vectorized<double>::size()) {
svst1_f64(ptrue, dst + i, svldnt1_f64(ptrue, src + i));
}
#pragma unroll
for (int64_t i = n - fraction; i < n; i += Vectorized<double>::size()) {
svbool_t pg = svwhilelt_b64(i, n);
svst1_f64(pg, dst + i, svldnt1_f64(pg, src + i));
}
}
template <>
Vectorized<double> inline fmadd(const Vectorized<double>& a, const Vectorized<double>& b, const Vectorized<double>& c) {
return svmad_f64_x(ptrue, a, b, c);
}
#endif // defined(CPU_CAPABILITY_SVE)
}}}

View File

@ -0,0 +1,570 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/sve/sve_helper.h>
#include <cmath>
#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF)
#include <sleef.h>
#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code
#else
#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code
#endif
namespace at {
namespace vec {
// Note [CPU_CAPABILITY namespace]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This header, and all of its subheaders, will be compiled with
// different architecture flags for each supported set of vector
// intrinsics. So we need to make sure they aren't inadvertently
// linked together. We do this by declaring objects in an `inline
// namespace` which changes the name mangling, but can still be
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_SVE)
template <> class Vectorized<float> {
private:
vls_float32_t values;
public:
using value_type = float;
using size_type = int;
static constexpr size_type size() {
return VECTOR_WIDTH / sizeof(float);
}
Vectorized() {}
Vectorized(svfloat32_t v) : values(v) {}
Vectorized(float val) {
values = svdup_n_f32(val);
}
template<typename... Args,
typename = std::enable_if_t<(sizeof...(Args) == size())>>
Vectorized(Args... vals) {
__at_align__ float buffer[size()] = { vals... };
values = svld1_f32(ptrue, buffer);
}
operator svfloat32_t() const {
return values;
}
static Vectorized<float> blendv(const Vectorized<float>& a, const Vectorized<float>& b,
const Vectorized<float>& mask_) {
svbool_t mask = svcmpeq_s32(ptrue, svreinterpret_s32_f32(mask_),
ALL_S32_TRUE_MASK);
return svsel_f32(mask, b, a);
}
template<typename step_t>
static Vectorized<float> arange(float base = 0.f, step_t step = static_cast<step_t>(1)) {
__at_align__ float buffer[size()];
for (int64_t i = 0; i < size(); i++) {
buffer[i] = base + i * step;
}
return svld1_f32(ptrue, buffer);
}
static Vectorized<float> set(const Vectorized<float>& a, const Vectorized<float>& b,
int64_t count = size()) {
if (count == 0) {
return a;
} else if (count < size()) {
return svsel_f32(svwhilelt_b32(0ull, count), b, a);
}
return b;
}
static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
if (count == size())
return svld1_f32(ptrue, reinterpret_cast<const float*>(ptr));
svbool_t pg = svwhilelt_b32(0ull, count);
return svld1_f32(pg, reinterpret_cast<const float*>(ptr));
}
void store(void* ptr, int64_t count = size()) const {
if (count == size()) {
svst1_f32(ptrue, reinterpret_cast<float*>(ptr), values);
} else {
svbool_t pg = svwhilelt_b32(0ull, count);
svst1_f32(pg, reinterpret_cast<float*>(ptr), values);
}
}
const float& operator[](int idx) const = delete;
float& operator[](int idx) = delete;
int64_t zero_mask() const {
// returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
int64_t mask = 0;
__at_align__ int32_t mask_array[size()];
svbool_t svbool_mask = svcmpeq_f32(ptrue, values, ZERO_F32);
svst1_s32(ptrue, mask_array, svsel_s32(svbool_mask,
ALL_S32_TRUE_MASK,
ALL_S32_FALSE_MASK));
for (int64_t i = 0; i < size(); ++i) {
if (mask_array[i]) mask |= (1ull << i);
}
return mask;
}
Vectorized<float> isnan() const {
// NaN check
svbool_t mask = svcmpuo_f32(ptrue, values, ZERO_F32);
return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
}
bool has_inf_nan() const {
return svptest_any(ptrue, svcmpuo_f32(ptrue, svsub_f32_x(ptrue, values, values), ZERO_F32));
}
Vectorized<float> map(float (*f)(float)) const {
__at_align__ float tmp[size()];
store(tmp);
for (int64_t i = 0; i < size(); ++i) {
tmp[i] = f(tmp[i]);
}
return loadu(tmp);
}
Vectorized<float> abs() const {
return svabs_f32_x(ptrue, values);
}
Vectorized<float> angle() const {
const auto nan_vec = svdup_n_f32(NAN);
const auto nan_mask = svcmpuo_f32(ptrue, values, ZERO_F32);
const auto pi = svdup_n_f32(c10::pi<float>);
const auto neg_mask = svcmplt_f32(ptrue, values, ZERO_F32);
auto angle = svsel_f32(neg_mask, pi, ZERO_F32);
angle = svsel_f32(nan_mask, nan_vec, angle);
return angle;
}
Vectorized<float> real() const {
return values;
}
Vectorized<float> imag() const {
return Vectorized<float>(0.f);
}
Vectorized<float> conj() const {
return values;
}
Vectorized<float> acos() const {
return USE_SLEEF(Vectorized<float>(Sleef_acosfx_u10sve(values)),map(std::acos));
}
Vectorized<float> acosh() const {
return USE_SLEEF(Vectorized<float>(Sleef_acoshfx_u10sve(values)),map(std::acosh));
}
Vectorized<float> asin() const {
return USE_SLEEF(Vectorized<float>(Sleef_asinfx_u10sve(values)),map(std::asin));
}
Vectorized<float> atan() const {
return USE_SLEEF(Vectorized<float>(Sleef_atanfx_u10sve(values)),map(std::atan));
}
Vectorized<float> atanh() const {
return USE_SLEEF(Vectorized<float>(Sleef_atanhfx_u10sve(values)),map(std::atanh));
}
Vectorized<float> atan2(const Vectorized<float> &b) const {
USE_SLEEF({return Vectorized<float>(Sleef_atan2fx_u10sve(values, b));},
{
__at_align__ float tmp[size()];
__at_align__ float tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (int64_t i = 0; i < size(); i++){
tmp[i] = std::atan2(tmp[i], tmp_b[i]);
}
return loadu(tmp);
}
)
}
Vectorized<float> copysign(const Vectorized<float> &sign) const {
USE_SLEEF({return Vectorized<float>(Sleef_copysignfx_sve(values, sign));},
{
__at_align__ float tmp[size()];
__at_align__ float tmp_sign[size()];
store(tmp);
sign.store(tmp_sign);
for (int64_t i = 0; i < size(); ++i) {
tmp[i] = std::copysign(tmp[i], tmp_sign[i]);
}
return loadu(tmp);
})
}
Vectorized<float> erf() const {
return USE_SLEEF(Vectorized<float>(Sleef_erffx_u10sve(values)),map(std::erf));
}
Vectorized<float> erfc() const {
return USE_SLEEF(Vectorized<float>(Sleef_erfcfx_u15sve(values)),map(std::erfc));
}
Vectorized<float> erfinv() const {
return map(calc_erfinv);
}
Vectorized<float> exp() const {
return USE_SLEEF(Vectorized<float>(Sleef_expfx_u10sve(values)),map(std::exp));
}
Vectorized<float> exp2() const {
return USE_SLEEF(Vectorized<float>(Sleef_exp2fx_u10sve(values)),map(std::exp2));
}
Vectorized<float> expm1() const {
return USE_SLEEF(Vectorized<float>(Sleef_expm1fx_u10sve(values)),map(std::expm1));
}
Vectorized<float> exp_u20() const {
return exp();
}
Vectorized<float> fmod(const Vectorized<float>& q) const {
USE_SLEEF({return Vectorized<float>(Sleef_fmodfx_sve(values, q));},
{
__at_align__ float tmp[size()];
__at_align__ float tmp_q[size()];
store(tmp);
q.store(tmp_q);
for (int64_t i = 0; i < size(); ++i) {
tmp[i] = std::fmod(tmp[i], tmp_q[i]);
}
return loadu(tmp);
})
}
Vectorized<float> hypot(const Vectorized<float> &b) const {
USE_SLEEF( {return Vectorized<float>(Sleef_hypotfx_u05sve(values, b));},
{
__at_align__ float tmp[size()];
__at_align__ float tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = std::hypot(tmp[i], tmp_b[i]);
}
return loadu(tmp);
}
)
}
Vectorized<float> i0() const {
return map(calc_i0);
}
Vectorized<float> i0e() const {
return map(calc_i0e);
}
Vectorized<float> digamma() const {
return map(calc_digamma);
}
Vectorized<float> igamma(const Vectorized<float> &x) const {
__at_align__ float tmp[size()];
__at_align__ float tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
}
return loadu(tmp);
}
Vectorized<float> igammac(const Vectorized<float> &x) const {
__at_align__ float tmp[size()];
__at_align__ float tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
}
return loadu(tmp);
}
Vectorized<float> nextafter(const Vectorized<float> &b) const {
USE_SLEEF(
{
return Vectorized<float>(Sleef_nextafterfx_sve(values, b));
},
{
__at_align__ float tmp[size()];
__at_align__ float tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (int64_t i = 0; i < size(); ++i) {
tmp[i] = std::nextafter(tmp[i], tmp_b[i]);
}
return loadu(tmp);
}
)
}
Vectorized<float> log() const {
return USE_SLEEF(Vectorized<float>(Sleef_logfx_u10sve(values)),map(std::log));
}
Vectorized<float> log2() const {
return USE_SLEEF(Vectorized<float>(Sleef_log2fx_u10sve(values)),map(std::log2));
}
Vectorized<float> log10() const {
return USE_SLEEF(Vectorized<float>(Sleef_log10fx_u10sve(values)),map(std::log10));
}
Vectorized<float> log1p() const {
return USE_SLEEF(Vectorized<float>(Sleef_log1pfx_u10sve(values)),map(std::log1p));
}
Vectorized<float> frac() const;
Vectorized<float> sin() const {
return USE_SLEEF(Vectorized<float>(Sleef_sinfx_u10sve(values)),map(std::sin));
}
Vectorized<float> sinh() const {
return USE_SLEEF(Vectorized<float>(Sleef_sinhfx_u10sve(values)),map(std::sinh));
}
Vectorized<float> cos() const {
return USE_SLEEF(Vectorized<float>(Sleef_cosfx_u10sve(values)),map(std::cos));
}
Vectorized<float> cosh() const {
return USE_SLEEF(Vectorized<float>(Sleef_coshfx_u10sve(values)),map(std::cosh));
}
Vectorized<float> ceil() const {
return svrintp_f32_x(ptrue, values);
}
Vectorized<float> floor() const {
return svrintm_f32_x(ptrue, values);
}
Vectorized<float> neg() const {
return svneg_f32_x(ptrue, values);
}
Vectorized<float> round() const {
return svrinti_f32_x(ptrue, values);
}
Vectorized<float> tan() const {
return USE_SLEEF(Vectorized<float>(Sleef_tanfx_u10sve(values)),map(std::tan));
}
Vectorized<float> tanh() const {
return USE_SLEEF(Vectorized<float>(Sleef_tanhfx_u10sve(values)),map(std::tanh));
}
Vectorized<float> trunc() const {
return svrintz_f32_x(ptrue, values);
}
Vectorized<float> lgamma() const {
return USE_SLEEF(Vectorized<float>(Sleef_lgammafx_u10sve(values)),map(std::lgamma));
}
Vectorized<float> sqrt() const {
return svsqrt_f32_x(ptrue, values);
}
Vectorized<float> reciprocal() const {
return svdivr_f32_x(ptrue, values, ONE_F32);
}
Vectorized<float> rsqrt() const {
return svdivr_f32_x(ptrue, svsqrt_f32_x(ptrue, values), ONE_F32);
}
Vectorized<float> pow(const Vectorized<float> &b) const {
USE_SLEEF( {return Vectorized<float>(Sleef_powfx_u10sve(values, b));},
{
__at_align__ float tmp[size()];
__at_align__ float tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = std::pow(tmp[i], tmp_b[i]);
}
return loadu(tmp);
}
)
}
// Comparison using the _CMP_**_OQ predicate.
// `O`: get false if an operand is NaN
// `Q`: do not raise if an operand is NaN
Vectorized<float> operator==(const Vectorized<float>& other) const {
svbool_t mask = svcmpeq_f32(ptrue, values, other);
return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
}
Vectorized<float> operator!=(const Vectorized<float>& other) const {
svbool_t mask = svcmpne_f32(ptrue, values, other);
return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
}
Vectorized<float> operator<(const Vectorized<float>& other) const {
svbool_t mask = svcmplt_f32(ptrue, values, other);
return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
}
Vectorized<float> operator<=(const Vectorized<float>& other) const {
svbool_t mask = svcmple_f32(ptrue, values, other);
return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
}
Vectorized<float> operator>(const Vectorized<float>& other) const {
svbool_t mask = svcmpgt_f32(ptrue, values, other);
return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
}
Vectorized<float> operator>=(const Vectorized<float>& other) const {
svbool_t mask = svcmpge_f32(ptrue, values, other);
return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
}
Vectorized<float> eq(const Vectorized<float>& other) const;
Vectorized<float> ne(const Vectorized<float>& other) const;
Vectorized<float> gt(const Vectorized<float>& other) const;
Vectorized<float> ge(const Vectorized<float>& other) const;
Vectorized<float> lt(const Vectorized<float>& other) const;
Vectorized<float> le(const Vectorized<float>& other) const;
};
template <>
Vectorized<float> inline operator+(const Vectorized<float>& a, const Vectorized<float>& b) {
return svadd_f32_x(ptrue, a, b);
}
template <>
Vectorized<float> inline operator-(const Vectorized<float>& a, const Vectorized<float>& b) {
return svsub_f32_x(ptrue, a, b);
}
template <>
Vectorized<float> inline operator*(const Vectorized<float>& a, const Vectorized<float>& b) {
return svmul_f32_x(ptrue, a, b);
}
template <>
Vectorized<float> inline operator/(const Vectorized<float>& a, const Vectorized<float>& b) {
return svdiv_f32_x(ptrue, a, b);
}
// frac. Implement this here so we can use subtraction
Vectorized<float> inline Vectorized<float>::frac() const {
return *this - this->trunc();
}
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<float> inline maximum(const Vectorized<float>& a, const Vectorized<float>& b) {
return svmax_f32_x(ptrue, a, b);
}
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<float> inline minimum(const Vectorized<float>& a, const Vectorized<float>& b) {
return svmin_f32_x(ptrue, a, b);
}
template <>
Vectorized<float> inline clamp(const Vectorized<float>& a, const Vectorized<float>& min, const Vectorized<float>& max) {
return svmin_f32_x(ptrue, max, svmax_f32_x(ptrue, min, a));
}
template <>
Vectorized<float> inline clamp_max(const Vectorized<float>& a, const Vectorized<float>& max) {
return svmin_f32_x(ptrue, max, a);
}
template <>
Vectorized<float> inline clamp_min(const Vectorized<float>& a, const Vectorized<float>& min) {
return svmax_f32_x(ptrue, min, a);
}
template <>
Vectorized<float> inline operator&(const Vectorized<float>& a, const Vectorized<float>& b) {
return svreinterpret_f32_s32(svand_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b)));
}
template <>
Vectorized<float> inline operator|(const Vectorized<float>& a, const Vectorized<float>& b) {
return svreinterpret_f32_s32(svorr_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b)));
}
template <>
Vectorized<float> inline operator^(const Vectorized<float>& a, const Vectorized<float>& b) {
return svreinterpret_f32_s32(sveor_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b)));
}
Vectorized<float> inline Vectorized<float>::eq(const Vectorized<float>& other) const {
return (*this == other) & Vectorized<float>(1.0f);
}
Vectorized<float> inline Vectorized<float>::ne(const Vectorized<float>& other) const {
return (*this != other) & Vectorized<float>(1.0f);
}
Vectorized<float> inline Vectorized<float>::gt(const Vectorized<float>& other) const {
return (*this > other) & Vectorized<float>(1.0f);
}
Vectorized<float> inline Vectorized<float>::ge(const Vectorized<float>& other) const {
return (*this >= other) & Vectorized<float>(1.0f);
}
Vectorized<float> inline Vectorized<float>::lt(const Vectorized<float>& other) const {
return (*this < other) & Vectorized<float>(1.0f);
}
Vectorized<float> inline Vectorized<float>::le(const Vectorized<float>& other) const {
return (*this <= other) & Vectorized<float>(1.0f);
}
template <>
inline void convert(const float* src, float* dst, int64_t n) {
const int64_t fraction = n % Vectorized<float>::size();
#pragma unroll
for (int64_t i = 0; i < n - fraction; i += Vectorized<float>::size()) {
svst1_f32(ptrue, dst + i, svldnt1_f32(ptrue, src + i));
}
#pragma unroll
for (int64_t i = n - fraction; i < n; i += Vectorized<float>::size()) {
svbool_t pg = svwhilelt_b32(i, n);
svst1_f32(pg, dst + i, svldnt1_f32(pg, src + i));
}
}
template <>
inline void convert(const float *src, at::Half *dst, int64_t n) {
const int64_t fraction = n % Vectorized<float>::size();
svbool_t pg_16 = svwhilelt_b16(0ull, Vectorized<float>::size());
svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized<float>::size());
#pragma unroll
for (int64_t i = 0; i < n - fraction; i += Vectorized<float>::size()) {
svfloat16_t src_vec = svuzp1_f16(svcvt_f16_f32_x(ptrue, svldnt1_f32(pg_32, src + i)),
ZERO_F16);
svst1_f16(pg_16, reinterpret_cast<float16_t*>(dst) + i, src_vec);
}
#pragma unroll
for (int64_t i = n - fraction; i < n; i += Vectorized<float>::size()) {
pg_16 = svwhilelt_b16(i, n);
pg_32 = svwhilelt_b32(i, n);
svfloat16_t src_vec = svuzp1_f16(svcvt_f16_f32_x(ptrue, svldnt1_f32(pg_32, src + i)),
ZERO_F16);
svst1_f16(pg_16, reinterpret_cast<float16_t*>(dst) + i, src_vec);
}
}
template <>
inline void convert(const at::Half *src, float *dst, int64_t n) {
const int64_t fraction = n % Vectorized<float>::size();
svbool_t pg_16 = svwhilelt_b16(0ull, Vectorized<float>::size());
svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized<float>::size());
#pragma unroll
for (int64_t i = 0; i < n - fraction; i += Vectorized<float>::size()) {
svfloat16_t src_vec = svzip1_f16(svldnt1_f16(pg_16, reinterpret_cast<const float16_t*>(src) + i),
ZERO_F16);
svst1_f32(pg_32, dst + i, svcvt_f32_f16_x(ptrue, src_vec));
}
#pragma unroll
for (int64_t i = n - fraction; i < n; i += Vectorized<float>::size()) {
pg_16 = svwhilelt_b16(i, n);
pg_32 = svwhilelt_b32(i, n);
svfloat16_t src_vec = svzip1_f16(svldnt1_f16(pg_16, reinterpret_cast<const float16_t*>(src) + i),
ZERO_F16);
svst1_f32(pg_32, dst + i, svcvt_f32_f16_x(ptrue, src_vec));
}
}
template <>
inline void convert(const bool *src, float *dst, int64_t n) {
const int64_t fraction = n % Vectorized<float>::size();
svbool_t pg_8 = svwhilelt_b8(0ull, Vectorized<float>::size());
svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized<float>::size());
#pragma unroll
for (int64_t i = 0; i < n - fraction; i += Vectorized<float>::size()) {
svuint8_t src_vec_u8 = svldnt1_u8(pg_8, reinterpret_cast<const uint8_t*>(src) + i);
svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8));
svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32);
svst1_f32(pg_32, dst + i, svsel_f32(mask, ONE_F32, ZERO_F32));
}
#pragma unroll
for (int64_t i = n - fraction; i < n; i += Vectorized<float>::size()) {
pg_8 = svwhilelt_b8(i, n);
pg_32 = svwhilelt_b32(i, n);
svuint8_t src_vec_u8 = svldnt1_u8(pg_8, reinterpret_cast<const uint8_t*>(src) + i);
svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8));
svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32);
svst1_f32(pg_32, dst + i, svsel_f32(mask, ONE_F32, ZERO_F32));
}
}
template <>
Vectorized<float> inline fmadd(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
return svmad_f32_x(ptrue, a, b, c);
}
#endif // defined(CPU_CAPABILITY_SVE)
}}}

View File

@ -0,0 +1,410 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/sve/sve_helper.h>
namespace at {
namespace vec {
// Note [CPU_CAPABILITY namespace]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This header, and all of its subheaders, will be compiled with
// different architecture flags for each supported set of vector
// intrinsics. So we need to make sure they aren't inadvertently
// linked together. We do this by declaring objects in an `inline
// namespace` which changes the name mangling, but can still be
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_SVE)
#define VEC_INT_SVE_TEMPLATE(vl, bit) \
template <> class Vectorized<int##bit##_t> { \
private: \
vls_int##bit##_t values; \
public: \
using value_type = int##bit##_t; \
using size_type = int; \
static constexpr size_type size() { \
return vl; \
} \
Vectorized() {} \
Vectorized(svint##bit##_t v) : values(v) {} \
Vectorized(int##bit##_t val) { \
values = svdup_n_s##bit(val); \
} \
template<typename... Args, \
typename = std::enable_if_t<(sizeof...(Args) == size())>> \
Vectorized(Args... vals) { \
__at_align__ int##bit##_t buffer[size()] = { vals... }; \
values = svld1_s##bit(ptrue, buffer); \
} \
operator svint##bit##_t() const { \
return values; \
} \
static Vectorized<int##bit##_t> blendv(const Vectorized<int##bit##_t>& a, \
const Vectorized<int##bit##_t>& b, \
const Vectorized<int##bit##_t>& mask_) { \
svbool_t mask = svcmpeq_s##bit(ptrue, mask_, ALL_S##bit##_TRUE_MASK); \
return svsel_s##bit(mask, b, a); \
} \
/* step sometimes requires a higher precision type (e.g., T=int, step_t=double) */ \
template <typename step_t> \
static Vectorized<int##bit##_t> arange(int##bit##_t base = 0, step_t step = static_cast<step_t>(1)) { \
__at_align__ int##bit##_t buffer[size()]; \
for (int64_t i = 0; i < size(); i++) { \
buffer[i] = base + i * step; \
} \
return svld1_s##bit(ptrue, buffer); \
} \
static Vectorized<int##bit##_t> set(const Vectorized<int##bit##_t>& a, \
const Vectorized<int##bit##_t>& b, \
int##bit##_t count = size()) { \
if (count == 0) { \
return a; \
} else if (count < size()) { \
return svsel_s##bit(svwhilelt_b##bit(0ull, count), b, a); \
} \
return b; \
} \
static Vectorized<int##bit##_t> loadu(const void* ptr, int64_t count = size()) { \
if (count == size()) \
return svld1_s##bit(ptrue, reinterpret_cast<const int##bit##_t*>(ptr)); \
svbool_t pg = svwhilelt_b##bit(0ull, count); \
return svld1_s##bit(pg, reinterpret_cast<const int##bit##_t*>(ptr)); \
} \
void store(void* ptr, int64_t count = size()) const { \
if (count == size()) { \
svst1_s##bit(ptrue, reinterpret_cast<int##bit##_t*>(ptr), values); \
} else { \
svbool_t pg = svwhilelt_b##bit(0ull, count); \
svst1_s##bit(pg, reinterpret_cast<int##bit##_t*>(ptr), values); \
} \
} \
const int##bit##_t& operator[](int idx) const = delete; \
int##bit##_t& operator[](int idx) = delete; \
Vectorized<int##bit##_t> abs() const { \
return svabs_s##bit##_x(ptrue, values); \
} \
Vectorized<int##bit##_t> real() const { \
return values; \
} \
Vectorized<int##bit##_t> imag() const { \
return svdup_n_s##bit(0); \
} \
Vectorized<int##bit##_t> conj() const { \
return values; \
} \
Vectorized<int##bit##_t> frac() const; \
Vectorized<int##bit##_t> neg() const { \
return svneg_s##bit##_x(ptrue, values); \
} \
Vectorized<int##bit##_t> operator==(const Vectorized<int##bit##_t>& other) const { \
svbool_t mask = svcmpeq_s##bit(ptrue, values, other); \
return svsel_s##bit(mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \
} \
Vectorized<int##bit##_t> operator!=(const Vectorized<int##bit##_t>& other) const { \
svbool_t mask = svcmpne_s##bit(ptrue, values, other); \
return svsel_s##bit(mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \
} \
Vectorized<int##bit##_t> operator<(const Vectorized<int##bit##_t>& other) const { \
svbool_t mask = svcmplt_s##bit(ptrue, values, other); \
return svsel_s##bit(mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \
} \
Vectorized<int##bit##_t> operator<=(const Vectorized<int##bit##_t>& other) const { \
svbool_t mask = svcmple_s##bit(ptrue, values, other); \
return svsel_s##bit(mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \
} \
Vectorized<int##bit##_t> operator>(const Vectorized<int##bit##_t>& other) const { \
svbool_t mask = svcmpgt_s##bit(ptrue, values, other); \
return svsel_s##bit(mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \
} \
Vectorized<int##bit##_t> operator>=(const Vectorized<int##bit##_t>& other) const { \
svbool_t mask = svcmpge_s##bit(ptrue, values, other); \
return svsel_s##bit(mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \
} \
Vectorized<int##bit##_t> eq(const Vectorized<int##bit##_t>& other) const; \
Vectorized<int##bit##_t> ne(const Vectorized<int##bit##_t>& other) const; \
Vectorized<int##bit##_t> gt(const Vectorized<int##bit##_t>& other) const; \
Vectorized<int##bit##_t> ge(const Vectorized<int##bit##_t>& other) const; \
Vectorized<int##bit##_t> lt(const Vectorized<int##bit##_t>& other) const; \
Vectorized<int##bit##_t> le(const Vectorized<int##bit##_t>& other) const; \
}; \
template <> \
Vectorized<int##bit##_t> inline operator+(const Vectorized<int##bit##_t>& a, \
const Vectorized<int##bit##_t>& b) { \
return svadd_s##bit##_x(ptrue, a, b); \
} \
template <> \
Vectorized<int##bit##_t> inline operator-(const Vectorized<int##bit##_t>& a, \
const Vectorized<int##bit##_t>& b) { \
return svsub_s##bit##_x(ptrue, a, b); \
} \
template <> \
Vectorized<int##bit##_t> inline operator*(const Vectorized<int##bit##_t>& a, \
const Vectorized<int##bit##_t>& b) { \
return svmul_s##bit##_x(ptrue, a, b); \
} \
template <> \
Vectorized<int##bit##_t> inline maximum(const Vectorized<int##bit##_t>& a, \
const Vectorized<int##bit##_t>& b) { \
return svmax_s##bit##_x(ptrue, a, b); \
} \
template <> \
Vectorized<int##bit##_t> inline minimum(const Vectorized<int##bit##_t>& a, \
const Vectorized<int##bit##_t>& b) { \
return svmin_s##bit##_x(ptrue, a, b); \
} \
template <> \
Vectorized<int##bit##_t> inline clamp(const Vectorized<int##bit##_t>& a, \
const Vectorized<int##bit##_t>& min, \
const Vectorized<int##bit##_t>& max) { \
return svmin_s##bit##_x(ptrue, max, svmax_s##bit##_x(ptrue, min, a)); \
} \
template <> \
Vectorized<int##bit##_t> inline clamp_max(const Vectorized<int##bit##_t>& a, \
const Vectorized<int##bit##_t>& max) { \
return svmin_s##bit##_x(ptrue, max, a); \
} \
template <> \
Vectorized<int##bit##_t> inline clamp_min(const Vectorized<int##bit##_t>& a, \
const Vectorized<int##bit##_t>& min) { \
return svmax_s##bit##_x(ptrue, min, a); \
} \
template <> \
Vectorized<int##bit##_t> inline operator&(const Vectorized<int##bit##_t>& a, \
const Vectorized<int##bit##_t>& b) { \
return svand_s##bit##_x(ptrue, a, b); \
} \
template <> \
Vectorized<int##bit##_t> inline operator|(const Vectorized<int##bit##_t>& a, \
const Vectorized<int##bit##_t>& b) { \
return svorr_s##bit##_x(ptrue, a, b); \
} \
template <> \
Vectorized<int##bit##_t> inline operator^(const Vectorized<int##bit##_t>& a, \
const Vectorized<int##bit##_t>& b) { \
return sveor_s##bit##_x(ptrue, a, b); \
} \
template <> \
inline Vectorized<int##bit##_t> operator~(const Vectorized<int##bit##_t>& a) { \
return sveor_s##bit##_x(ptrue, a, svdup_n_s##bit(-1)); \
} \
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::eq(const Vectorized<int##bit##_t>& other) const { \
return (*this == other) & Vectorized<int##bit##_t>(1); \
} \
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::ne(const Vectorized<int##bit##_t>& other) const { \
return (*this != other) & Vectorized<int##bit##_t>(1); \
} \
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::gt(const Vectorized<int##bit##_t>& other) const { \
return (*this > other) & Vectorized<int##bit##_t>(1); \
} \
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::ge(const Vectorized<int##bit##_t>& other) const { \
return (*this >= other) & Vectorized<int##bit##_t>(1); \
} \
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::lt(const Vectorized<int##bit##_t>& other) const { \
return (*this < other) & Vectorized<int##bit##_t>(1); \
} \
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::le(const Vectorized<int##bit##_t>& other) const { \
return (*this <= other) & Vectorized<int##bit##_t>(1); \
}
VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int64_t), 64)
VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int32_t), 32)
VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int16_t), 16)
VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int8_t), 8)
template <typename T>
Vectorized<T> inline intdiv_nosve(const Vectorized<T>& a, const Vectorized<T>& b) {
T values_a[Vectorized<T>::size()];
T values_b[Vectorized<T>::size()];
a.store(values_a);
b.store(values_b);
for (int i = 0; i != Vectorized<T>::size(); i++) {
values_a[i] /= values_b[i];
}
return Vectorized<T>::loadu(values_a);
}
template <>
Vectorized<int64_t> inline operator/(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
return svdiv_s64_x(ptrue, a, b);
}
template <>
Vectorized<int32_t> inline operator/(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
return svdiv_s32_x(ptrue, a, b);
}
template <>
Vectorized<int16_t> inline operator/(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
return intdiv_nosve(a, b);
}
template <>
Vectorized<int8_t> inline operator/(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
return intdiv_nosve(a, b);
}
template <>
inline void convert(const int32_t *src, int64_t *dst, int64_t n) {
const int64_t fraction = n % Vectorized<int64_t>::size();
svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized<int64_t>::size());
svbool_t pg_64 = svwhilelt_b64(0ull, Vectorized<int64_t>::size());
#pragma unroll
for (int64_t i = 0; i < n - fraction; i += Vectorized<int64_t>::size())
svst1_s64(pg_64, dst + i, svunpklo_s64(svldnt1_s32(pg_32, src + i)));
#pragma unroll
for (int64_t i = n - fraction; i < n; i += Vectorized<int64_t>::size()) {
pg_32 = svwhilelt_b32(i, n);
pg_64 = svwhilelt_b64(i, n);
svst1_s64(pg_64, dst + i, svunpklo_s64(svldnt1_s32(pg_32, src + i)));
}
}
template <>
inline void convert(const int64_t *src, float *dst, int64_t n) {
const int64_t fraction = n % Vectorized<int64_t>::size();
svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized<int64_t>::size());
svbool_t pg_64 = svwhilelt_b64(0ull, Vectorized<int64_t>::size());
#pragma unroll
for (int64_t i = 0; i < n - fraction; i += Vectorized<int64_t>::size()) {
svint64_t src_vec_s64 = svldnt1_s64(pg_64, src + i);
svfloat32_t src_vec_f32 = svuzp1_f32(svcvt_f32_s64_x(pg_64, src_vec_s64), ZERO_F32);
svst1_f32(pg_32, dst + i, src_vec_f32);
}
#pragma unroll
for (int64_t i = n - fraction; i < n; i += Vectorized<int64_t>::size()) {
pg_32 = svwhilelt_b32(i, n);
pg_64 = svwhilelt_b64(i, n);
svint64_t src_vec_s64 = svldnt1_s64(pg_64, src + i);
svfloat32_t src_vec_f32 = svuzp1_f32(svcvt_f32_s64_x(pg_64, src_vec_s64), ZERO_F32);
svst1_f32(pg_32, dst + i, src_vec_f32);
}
}
template <>
inline void convert(const int32_t *src, float *dst, int64_t n) {
const int64_t fraction = n % Vectorized<int32_t>::size();
svbool_t pg = svwhilelt_b32(0ull, Vectorized<int32_t>::size());
#pragma unroll
for (int64_t i = 0; i < n - fraction; i += Vectorized<int32_t>::size()) {
svint32_t src_vec = svldnt1_s32(pg, src + i);
svst1_f32(pg, dst + i, svcvt_f32_s32_x(pg, src_vec));
}
#pragma unroll
for (int64_t i = n - fraction; i < n; i += Vectorized<int32_t>::size()) {
pg = svwhilelt_b32(i, n);
svint32_t src_vec = svldnt1_s32(pg, src + i);
svst1_f32(pg, dst + i, svcvt_f32_s32_x(pg, src_vec));
}
}
template <>
inline void convert(const bool *src, int64_t *dst, int64_t n) {
const int64_t fraction = n % Vectorized<int64_t>::size();
svbool_t pg_8 = svwhilelt_b8(0ull, Vectorized<int64_t>::size());
svbool_t pg_64 = svwhilelt_b64(0ull, Vectorized<int64_t>::size());
#pragma unroll
for (int64_t i = 0; i < n - fraction; i += Vectorized<int64_t>::size()) {
svuint8_t src_vec_u8 = svldnt1_u8(pg_8, reinterpret_cast<const uint8_t*>(src) + i);
svuint64_t src_vec_u64 = svunpklo_u64(svunpklo_u32(svunpklo_u16(src_vec_u8)));
svbool_t mask = svcmpne_u64(pg_64, src_vec_u64, ZERO_U64);
svst1_s64(pg_64, dst + i, svsel_s64(mask, ONE_S64, ZERO_S64));
}
#pragma unroll
for (int64_t i = n - fraction; i < n; i += Vectorized<int64_t>::size()) {
pg_8 = svwhilelt_b8(i, n);
pg_64 = svwhilelt_b64(i, n);
svuint8_t src_vec_u8 = svldnt1_u8(pg_8, reinterpret_cast<const uint8_t*>(src) + i);
svuint64_t src_vec_u64 = svunpklo_u64(svunpklo_u32(svunpklo_u16(src_vec_u8)));
svbool_t mask = svcmpne_u64(pg_64, src_vec_u64, ZERO_U64);
svst1_s64(pg_64, dst + i, svsel_s64(mask, ONE_S64, ZERO_S64));
}
}
template <>
inline void convert(const bool *src, int32_t *dst, int64_t n) {
const int64_t fraction = n % Vectorized<int32_t>::size();
svbool_t pg_8 = svwhilelt_b8(0ull, Vectorized<int32_t>::size());
svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized<int32_t>::size());
#pragma unroll
for (int64_t i = 0; i < n - fraction; i += Vectorized<int32_t>::size()) {
svuint8_t src_vec_u8 = svldnt1_u8(pg_8, reinterpret_cast<const uint8_t*>(src) + i);
svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8));
svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32);
svst1_s32(pg_32, dst + i, svsel_s32(mask, ONE_S32, ZERO_S32));
}
#pragma unroll
for (int64_t i = n - fraction; i < n; i += Vectorized<int32_t>::size()) {
pg_8 = svwhilelt_b8(i, n);
pg_32 = svwhilelt_b32(i, n);
svuint8_t src_vec_u8 = svldnt1_u8(pg_8, reinterpret_cast<const uint8_t*>(src) + i);
svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8));
svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32);
svst1_s32(pg_32, dst + i, svsel_s32(mask, ONE_S32, ZERO_S32));
}
}
template <>
inline void convert(const uint8_t *src, bool *dst, int64_t n) {
const int64_t fraction = n % Vectorized<uint8_t>::size();
svbool_t pg = svwhilelt_b8(0ull, Vectorized<uint8_t>::size());
#pragma unroll
for (int64_t i = 0; i < n - fraction; i += Vectorized<uint8_t>::size()) {
svbool_t mask = svcmpne_u8(pg, svldnt1_u8(pg, src + i), ZERO_U8);
svst1_u8(pg, reinterpret_cast<uint8_t*>(dst) + i,
svsel_u8(mask, ALL_U8_TRUE_MASK, ALL_U8_FALSE_MASK));
}
#pragma unroll
for (int64_t i = n - fraction; i < n; i += Vectorized<uint8_t>::size()) {
pg = svwhilelt_b8(i, n);
svbool_t mask = svcmpne_u8(pg, svldnt1_u8(pg, src + i), ZERO_U8);
svst1_u8(pg, reinterpret_cast<uint8_t*>(dst) + i,
svsel_u8(mask, ALL_U8_TRUE_MASK, ALL_U8_FALSE_MASK));
}
}
template <>
Vectorized<int64_t> inline operator<<(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
return svlsl_s64_x(ptrue, a, svreinterpret_u64_s64(b));
}
template <>
Vectorized<int32_t> inline operator<<(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
return svlsl_s32_x(ptrue, a, svreinterpret_u32_s32(b));
}
template <>
Vectorized<int16_t> inline operator<<(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
return svlsl_s16_x(ptrue, a, svreinterpret_u16_s16(b));
}
template <>
Vectorized<int8_t> inline operator<<(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
return svlsl_s8_x(ptrue, a, svreinterpret_u8_s8(b));
}
template <>
Vectorized<int64_t> inline operator>>(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
return svasr_s64_x(ptrue, a, svreinterpret_u64_s64(b));
}
template <>
Vectorized<int32_t> inline operator>>(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
return svasr_s32_x(ptrue, a, svreinterpret_u32_s32(b));
}
template <>
Vectorized<int16_t> inline operator>>(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
return svasr_s16_x(ptrue, a, svreinterpret_u16_s16(b));
}
template <>
Vectorized<int8_t> inline operator>>(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
return svasr_s8_x(ptrue, a, svreinterpret_u8_s8(b));
}
#endif // defined(CPU_CAPABILITY_SVE)
}}}

View File

@ -0,0 +1,567 @@
#pragma once
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with SVE]
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/native/quantized/AffineQuantizerBase.h>
#include <c10/util/qint32.h>
#include <c10/util/qint8.h>
#include <c10/util/quint8.h>
#include <array>
// This file defines Vectorized<> for the quantized types.
//
//
// Currently, we simply use these classes as efficient converters between
// the quantized types and Vectorized<float>, usually in bandwidth-bound cases
// where doing the arithmetic in full-precision is acceptable (e.g.
// elementwise operators).
//
//
// Conversions are as follows:
// Vectorized<qint8> -> 4x Vectorized<float>
// Vectorized<quint8> -> 4x Vectorized<float>
// Vectorized<qint32> -> 1x Vectorized<float>
//
// The size of the returned float vector is specified by the special
// constexpr function float_num_vecs. The type of the value returned
// from dequantize (and expected as an argument to quantize) is
// specified by float_vec_return_type.
//
// When writing kernels with these vectors, it is expected that floating-
// point operations will be carried out in a loop over Vectorized<T>::float_num_vecs
// iterations.
namespace at {
namespace vec {
// Note [CPU_CAPABILITY namespace]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This header, and all of its subheaders, will be compiled with
// different architecture flags for each supported set of vector
// intrinsics. So we need to make sure they aren't inadvertently
// linked together. We do this by declaring objects in an `inline
// namespace` which changes the name mangling, but can still be
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_SVE)
// NOTE: These are low-performance implementations that we fall back on
// if we are not building with SVE. This may not be an issue, because
// currently for quantization we assume the user has at least SVE
// installed, so these can simply act as a reference implementation.
//
// If in the future we relax this requirement (SVE+), we should probably
// revisit these implementations
template <
typename T,
typename float_vec_return_type_,
typename int_vec_return_type_,
int size_>
struct VectorizedQuantizedConverter {
using size_type = int;
static constexpr size_type size() {
return size_;
}
static constexpr int float_num_vecs() {
return size() / Vectorized<float>::size();
}
static constexpr int int_num_vecs() {
return size() / Vectorized<int32_t>::size();
}
using float_vec_return_type = float_vec_return_type_;
using int_vec_return_type = int_vec_return_type_;
using value_type = typename T::underlying;
std::array<value_type, size_> vals;
VectorizedQuantizedConverter(T val) {
for (size_t i = 0; i < size(); ++i) {
vals[i] = val.val_;
}
}
VectorizedQuantizedConverter(const void* ptr) {
memcpy(vals.data(), ptr, sizeof(value_type) * size());
}
void store(void* ptr, int count = size()) const {
memcpy(ptr, vals.data(), count * sizeof(value_type));
}
float_vec_return_type dequantize(
Vectorized<float> scale,
Vectorized<float> zero_point,
Vectorized<float> scale_zp_premul) const {
float_vec_return_type rv;
float tmp_scale[Vectorized<float>::size()];
float tmp_zero_point[Vectorized<float>::size()];
scale.store(tmp_scale);
zero_point.store(tmp_zero_point);
for (int i = 0; i < float_num_vecs(); ++i) {
float tmp_vals[Vectorized<float>::size()];
for (int j = 0; j < Vectorized<float>::size(); ++j) {
tmp_vals[j] =
at::native::dequantize_val<T>(tmp_scale[j], tmp_zero_point[j], T(vals[Vectorized<float>::size() * i + j]));
}
rv[i] = Vectorized<float>::loadu(tmp_vals);
}
return rv;
}
float_vec_return_type dequantize(
Vectorized<float> scale,
Vectorized<float> zero_point) const {
float_vec_return_type rv;
float tmp_scale[Vectorized<float>::size()];
float tmp_zero_point[Vectorized<float>::size()];
scale.store(tmp_scale);
zero_point.store(tmp_zero_point);
for (int i = 0; i < float_num_vecs(); ++i) {
float tmp_vals[Vectorized<float>::size()];
for (int j = 0; j < Vectorized<float>::size(); ++j) {
tmp_vals[j] =
at::native::dequantize_val<T>(tmp_scale[j], tmp_zero_point[j], T(vals[Vectorized<float>::size() * i + j]));
}
rv[i] = Vectorized<float>::loadu(tmp_vals);
}
return rv;
}
protected:
VectorizedQuantizedConverter() {}
};
template <>
struct Vectorized<c10::qint32> : public VectorizedQuantizedConverter<
c10::qint32,
std::array<Vectorized<float>, 1>,
std::array<Vectorized<c10::qint32>, 1>,
VECTOR_WIDTH / 4> {
Vectorized()
: VectorizedQuantizedConverter<
c10::qint32,
std::array<Vectorized<float>, 1>,
std::array<Vectorized<c10::qint32>, 1>,
VECTOR_WIDTH / 4>() {}
Vectorized(c10::qint32 val)
: VectorizedQuantizedConverter<
c10::qint32,
std::array<Vectorized<float>, 1>,
std::array<Vectorized<c10::qint32>, 1>,
VECTOR_WIDTH / 4>(val) {}
Vectorized(const void* ptr)
: VectorizedQuantizedConverter<
c10::qint32,
std::array<Vectorized<float>, 1>,
std::array<Vectorized<c10::qint32>, 1>,
VECTOR_WIDTH / 4>(ptr) {}
#if 1
static Vectorized<c10::qint32> loadu(const void* ptr) {
return Vectorized<c10::qint32>(ptr);
}
static Vectorized<c10::qint32> loadu(const void* ptr, int64_t count) {
__at_align__ value_type tmp_values[size()];
// Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
// for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
// instructions while a loop would be compiled to one instruction.
for (const auto i : c10::irange(size())) {
tmp_values[i] = 0;
}
std::memcpy(tmp_values, reinterpret_cast<const value_type*>(ptr), count * sizeof(value_type));
return loadu(tmp_values);
}
#else
static Vectorized<c10::qint32> loadu(const void* ptr, int64_t count = size()) {
if (count == size())
return svld1_s32(ptrue, reinterpret_cast<const int32_t*>(ptr));
svbool_t pg = svwhilelt_b32(0ull, count);
return svld1_s32(pg, reinterpret_cast<const int32_t*>(ptr));
}
#endif
static Vectorized<c10::qint32> quantize(
const float_vec_return_type& rhs,
float scale,
int32_t zero_point,
float inverse_scale) {
std::array<value_type, size()> qvals;
std::array<float, float_num_vecs() * Vectorized<float>::size()> float_vals;
for (int i = 0; i < float_num_vecs(); ++i) {
rhs[i].store(&float_vals[i * Vectorized<float>::size()], Vectorized<float>::size());
}
at::native::quantize_vec<c10::qint32, /*precision=*/32>(
scale,
zero_point,
float_vals.data(),
(c10::qint32*)qvals.data(),
Vectorized<float>::size() * float_num_vecs());
return Vectorized<c10::qint32>::loadu(qvals.data());
}
Vectorized<c10::qint32> maximum(Vectorized<c10::qint32> b) const {
Vectorized<c10::qint32> retval;
for (size_t i = 0; i < size(); ++i) {
retval.vals[i] = std::max<value_type>(vals[i], b.vals[i]);
}
return retval;
}
Vectorized<c10::qint32> minimum(Vectorized<c10::qint32> b) const {
Vectorized<c10::qint32> retval;
for (size_t i = 0; i < size(); ++i) {
retval.vals[i] = std::min<value_type>(vals[i], b.vals[i]);
}
return retval;
}
Vectorized<c10::qint32> relu(Vectorized<c10::qint32> zero_point) const {
return maximum(zero_point);
}
Vectorized<c10::qint32> relu6(
Vectorized<c10::qint32> zero_point,
Vectorized<c10::qint32> q_six) {
Vectorized<c10::qint32> retval;
for (size_t i = 0; i < size(); ++i) {
retval.vals[i] = std::min<value_type>(
std::max<value_type>(vals[i], zero_point.vals[i]), q_six.vals[i]);
}
return retval;
}
int_vec_return_type widening_subtract(Vectorized<c10::qint32> b) const {
int_vec_return_type retval;
for (size_t i = 0; i < size(); ++i) {
retval[0].vals[i] = vals[i] - b.vals[i];
}
return retval;
}
static Vectorized<c10::qint32> requantize_from_int(
const int_vec_return_type& inp,
float multiplier,
int32_t zero_point) {
Vectorized<c10::qint32> retval;
for (size_t i = 0; i < size(); ++i) {
retval.vals[i] =
nearbyint(static_cast<float>(inp[0].vals[i]) * multiplier) +
zero_point;
}
return retval;
}
};
template <>
Vectorized<c10::qint32> inline maximum(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) {
return a.maximum(b);
}
template <>
Vectorized<c10::qint32> inline operator*(
const Vectorized<c10::qint32>& a,
const Vectorized<c10::qint32>& b) {
Vectorized<c10::qint32> retval;
for (size_t i = 0; i < std::decay_t<decltype(a)>::size(); ++i) {
retval.vals[i] = a.vals[i] * b.vals[i];
}
return retval;
}
template <>
Vectorized<c10::qint32> inline operator+(
const Vectorized<c10::qint32>& a,
const Vectorized<c10::qint32>& b) {
Vectorized<c10::qint32> retval;
for (size_t i = 0; i < std::decay_t<decltype(a)>::size(); ++i) {
retval.vals[i] = a.vals[i] + b.vals[i];
}
return retval;
}
template <>
struct Vectorized<c10::qint8> : public VectorizedQuantizedConverter<
c10::qint8,
std::array<Vectorized<float>, 4>,
std::array<Vectorized<c10::qint32>, 4>,
VECTOR_WIDTH> {
Vectorized()
: VectorizedQuantizedConverter<
c10::qint8,
std::array<Vectorized<float>, 4>,
std::array<Vectorized<c10::qint32>, 4>,
VECTOR_WIDTH>() {}
Vectorized(c10::qint8 val)
: VectorizedQuantizedConverter<
c10::qint8,
std::array<Vectorized<float>, 4>,
std::array<Vectorized<c10::qint32>, 4>,
VECTOR_WIDTH>(val) {}
Vectorized(const void* ptr)
: VectorizedQuantizedConverter<
c10::qint8,
std::array<Vectorized<float>, 4>,
std::array<Vectorized<c10::qint32>, 4>,
VECTOR_WIDTH>(ptr) {}
static Vectorized<c10::qint8> loadu(const void* ptr) {
return Vectorized<c10::qint8>(ptr);
}
static Vectorized<c10::qint8> loadu(const void* ptr, int64_t count) {
__at_align__ value_type tmp_values[size()];
// Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
// for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
// instructions while a loop would be compiled to one instruction.
for (const auto i : c10::irange(size())) {
tmp_values[i] = 0;
}
std::memcpy(tmp_values, reinterpret_cast<const value_type*>(ptr), count * sizeof(value_type));
return loadu(tmp_values);
}
static Vectorized<c10::qint8> quantize(
const float_vec_return_type& rhs,
float scale,
int32_t zero_point,
float inverse_scale) {
std::array<value_type, size()> qvals;
std::array<float, float_num_vecs() * Vectorized<float>::size()> float_vals;
for (int i = 0; i < float_num_vecs(); ++i) {
rhs[i].store(&float_vals[i * Vectorized<float>::size()], Vectorized<float>::size());
}
at::native::quantize_vec<c10::qint8>(
scale,
zero_point,
float_vals.data(),
(c10::qint8*)qvals.data(),
Vectorized<float>::size() * float_num_vecs());
return Vectorized<c10::qint8>::loadu(qvals.data());
}
Vectorized<c10::qint8> maximum(Vectorized<c10::qint8> b) const {
Vectorized<c10::qint8> retval;
for (size_t i = 0; i < size(); ++i) {
retval.vals[i] = std::max<value_type>(vals[i], b.vals[i]);
}
return retval;
}
Vectorized<c10::qint8> minimum(Vectorized<c10::qint8> b) const {
Vectorized<c10::qint8> retval;
for (size_t i = 0; i < size(); ++i) {
retval.vals[i] = std::min<value_type>(vals[i], b.vals[i]);
}
return retval;
}
Vectorized<c10::qint8> relu(Vectorized<c10::qint8> zero_point) const {
return maximum(zero_point);
}
Vectorized<c10::qint8> relu6(
Vectorized<c10::qint8> zero_point,
Vectorized<c10::qint8> q_six) {
Vectorized<c10::qint8> retval;
for (size_t i = 0; i < size(); ++i) {
retval.vals[i] = std::min<value_type>(
std::max<value_type>(vals[i], zero_point.vals[i]), q_six.vals[i]);
}
return retval;
}
int_vec_return_type widening_subtract(Vectorized<c10::qint8> b) const {
int_vec_return_type retval;
constexpr int elem_per_int_vec = size() / int_num_vecs();
for (size_t i = 0; i < int_num_vecs(); ++i) {
for (size_t j = 0; j < elem_per_int_vec; ++j) {
retval[i].vals[j] =
static_cast<int32_t>(vals[i * elem_per_int_vec + j]) -
static_cast<int32_t>(b.vals[i * elem_per_int_vec + j]);
}
}
return retval;
}
static Vectorized<c10::qint8> requantize_from_int(
const int_vec_return_type& inp,
float multiplier,
int32_t zero_point) {
constexpr int elem_per_int_vec = size() / int_num_vecs();
constexpr auto min_val = std::numeric_limits<value_type>::min();
constexpr auto max_val = std::numeric_limits<value_type>::max();
Vectorized<c10::qint8> retval;
for (size_t i = 0; i < int_num_vecs(); ++i) {
for (size_t j = 0; j < elem_per_int_vec; ++j) {
int32_t rounded =
nearbyint(static_cast<float>(inp[i].vals[j]) * multiplier) +
zero_point;
retval.vals[i * elem_per_int_vec + j] =
std::min<int32_t>(std::max<int32_t>(rounded, min_val), max_val);
}
}
return retval;
}
};
template <>
Vectorized<c10::qint8> inline maximum(const Vectorized<c10::qint8>& a, const Vectorized<c10::qint8>& b) {
return a.maximum(b);
}
template <>
struct Vectorized<c10::quint8> : public VectorizedQuantizedConverter<
c10::quint8,
std::array<Vectorized<float>, 4>,
std::array<Vectorized<c10::qint32>, 4>,
VECTOR_WIDTH> {
Vectorized()
: VectorizedQuantizedConverter<
c10::quint8,
std::array<Vectorized<float>, 4>,
std::array<Vectorized<c10::qint32>, 4>,
VECTOR_WIDTH>() {}
Vectorized(c10::quint8 val)
: VectorizedQuantizedConverter<
c10::quint8,
std::array<Vectorized<float>, 4>,
std::array<Vectorized<c10::qint32>, 4>,
VECTOR_WIDTH>(val) {}
Vectorized(const void* ptr)
: VectorizedQuantizedConverter<
c10::quint8,
std::array<Vectorized<float>, 4>,
std::array<Vectorized<c10::qint32>, 4>,
VECTOR_WIDTH>(ptr) {}
#if 1
static Vectorized<c10::quint8> loadu(const void* ptr) {
return Vectorized<c10::quint8>(ptr);
}
static Vectorized<c10::quint8> loadu(const void* ptr, int64_t count) {
__at_align__ value_type tmp_values[size()];
// Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
// for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
// instructions while a loop would be compiled to one instruction.
for (const auto i : c10::irange(size())) {
tmp_values[i] = 0;
}
std::memcpy(tmp_values, reinterpret_cast<const value_type*>(ptr), count * sizeof(value_type));
return loadu(tmp_values);
}
#else
static Vectorized<c10::quint8> loadu(const void* ptr, int64_t count = size()) {
if (count == size())
return svld1_u8(ptrue, reinterpret_cast<const uint8_t*>(ptr));
svbool_t pg = svwhilelt_b8(0ull, count);
return svld1_u8(pg, reinterpret_cast<const uint8_t*>(ptr));
}
#endif
static Vectorized<c10::quint8> quantize(
const float_vec_return_type& rhs,
float scale,
int32_t zero_point,
float inverse_scale) {
std::array<value_type, size()> qvals;
std::array<float, float_num_vecs() * Vectorized<float>::size()> float_vals;
for (int i = 0; i < float_num_vecs(); ++i) {
rhs[i].store(&float_vals[i * Vectorized<float>::size()], Vectorized<float>::size());
}
at::native::quantize_vec<c10::quint8>(
scale,
zero_point,
float_vals.data(),
(c10::quint8*)qvals.data(),
Vectorized<float>::size() * float_num_vecs());
return Vectorized<c10::quint8>::loadu(qvals.data());
}
Vectorized<c10::quint8> maximum(Vectorized<c10::quint8> b) const {
Vectorized<c10::quint8> retval;
for (size_t i = 0; i < size(); ++i) {
retval.vals[i] = std::max<value_type>(vals[i], b.vals[i]);
}
return retval;
}
Vectorized<c10::quint8> minimum(Vectorized<c10::quint8> b) const {
Vectorized<c10::quint8> retval;
for (size_t i = 0; i < size(); ++i) {
retval.vals[i] = std::min<value_type>(vals[i], b.vals[i]);
}
return retval;
}
Vectorized<c10::quint8> relu(Vectorized<c10::quint8> zero_point) const {
return maximum(zero_point);
}
Vectorized<c10::quint8> relu6(
Vectorized<c10::quint8> zero_point,
Vectorized<c10::quint8> q_six) {
Vectorized<c10::quint8> retval;
for (size_t i = 0; i < size(); ++i) {
retval.vals[i] = std::min<value_type>(
std::max<value_type>(vals[i], zero_point.vals[i]), q_six.vals[i]);
}
return retval;
}
int_vec_return_type widening_subtract(Vectorized<c10::quint8> b) const {
int_vec_return_type retval;
constexpr int elem_per_int_vec = size() / int_num_vecs();
for (size_t i = 0; i < int_num_vecs(); ++i) {
for (size_t j = 0; j < elem_per_int_vec; ++j) {
retval[i].vals[j] =
static_cast<int32_t>(vals[i * elem_per_int_vec + j]) -
static_cast<int32_t>(b.vals[i * elem_per_int_vec + j]);
}
}
return retval;
}
static Vectorized<c10::quint8> requantize_from_int(
const int_vec_return_type& inp,
float multiplier,
int32_t zero_point) {
constexpr int elem_per_int_vec = size() / int_num_vecs();
constexpr auto min_val = std::numeric_limits<value_type>::min();
constexpr auto max_val = std::numeric_limits<value_type>::max();
Vectorized<c10::quint8> retval;
for (size_t i = 0; i < int_num_vecs(); ++i) {
for (size_t j = 0; j < elem_per_int_vec; ++j) {
int32_t rounded =
nearbyint(static_cast<float>(inp[i].vals[j]) * multiplier) +
zero_point;
retval.vals[i * elem_per_int_vec + j] =
std::min<int32_t>(std::max<int32_t>(rounded, min_val), max_val);
}
}
return retval;
}
};
template <>
Vectorized<c10::quint8> inline maximum(const Vectorized<c10::quint8>& a, const Vectorized<c10::quint8>& b) {
return a.maximum(b);
}
#endif // defined(CPU_CAPABILITY_SVE)
}}}

View File

@ -7,9 +7,13 @@
#include <ATen/cpu/vec/vec_base.h>
#if !(defined(__VSX__) || defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_ZVECTOR))
#include <ATen/cpu/vec/vec256/vec256_float.h>
#if defined(CPU_CAPABILITY_SVE256)
#include <ATen/cpu/vec/sve/vec_common_sve.h>
#else
#include <ATen/cpu/vec/vec256/vec256_float_neon.h>
#include <ATen/cpu/vec/vec256/vec256_half_neon.h>
#endif
#include <ATen/cpu/vec/vec256/vec256_float.h>
#include <ATen/cpu/vec/vec256/vec256_bfloat16.h>
#include <ATen/cpu/vec/vec256/vec256_double.h>
#include <ATen/cpu/vec/vec256/vec256_int.h>

View File

@ -1097,7 +1097,7 @@ inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const V
return Vectorized<type>::loadu(arr2); \
}
CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16);
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && !defined(CPU_CAPABILITY_SVE256)
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_half_float(const Vectorized<Half>& a) {
static_assert(Vectorized<Half>::size() == 2 * Vectorized<float>::size());
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)

View File

@ -843,7 +843,7 @@ Vectorized<c10::quint8> inline maximum(const Vectorized<c10::quint8>& a, const V
return a.maximum(b);
}
#else
#elif !defined(CPU_CAPABILITY_SVE256)
// NOTE: These are low-performance implementations that we fall back on
// if we are not building with AVX2. This may not be an issue, because

View File

@ -990,7 +990,7 @@ inline mask_gather(const Vectorized<T>& src, T const* base_addr,
buffer[i] = src_arr[i];
}
}
mask = Vectorized<T>(); // "zero out" mask
mask = Vectorized<T>(static_cast<T>(0)); // "zero out" mask
return Vectorized<T>::loadu(static_cast<void*>(buffer));
}

View File

@ -1140,87 +1140,103 @@ REGISTER_AVX512_DISPATCH(cholesky_stub, &cholesky_kernel);
REGISTER_AVX2_DISPATCH(cholesky_stub, &cholesky_kernel);
REGISTER_VSX_DISPATCH(cholesky_stub, &cholesky_kernel);
REGISTER_ZVECTOR_DISPATCH(cholesky_stub, &cholesky_kernel);
REGISTER_SVE256_DISPATCH(cholesky_stub, &cholesky_kernel);
REGISTER_ARCH_DISPATCH(cholesky_inverse_stub, DEFAULT, &cholesky_inverse_kernel_impl);
REGISTER_AVX512_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);
REGISTER_AVX2_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);
REGISTER_VSX_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);
REGISTER_ZVECTOR_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);
REGISTER_SVE256_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);
REGISTER_ARCH_DISPATCH(linalg_eig_stub, DEFAULT, &linalg_eig_kernel);
REGISTER_AVX512_DISPATCH(linalg_eig_stub, &linalg_eig_kernel);
REGISTER_AVX2_DISPATCH(linalg_eig_stub, &linalg_eig_kernel);
REGISTER_VSX_DISPATCH(linalg_eig_stub, &linalg_eig_kernel);
REGISTER_ZVECTOR_DISPATCH(linalg_eig_stub, &linalg_eig_kernel);
REGISTER_SVE256_DISPATCH(linalg_eig_stub, &linalg_eig_kernel);
REGISTER_ARCH_DISPATCH(linalg_eigh_stub, DEFAULT, &linalg_eigh_kernel);
REGISTER_AVX512_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel);
REGISTER_AVX2_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel);
REGISTER_VSX_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel);
REGISTER_ZVECTOR_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel);
REGISTER_SVE256_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel);
REGISTER_ARCH_DISPATCH(geqrf_stub, DEFAULT, &geqrf_kernel);
REGISTER_AVX512_DISPATCH(geqrf_stub, &geqrf_kernel);
REGISTER_AVX2_DISPATCH(geqrf_stub, &geqrf_kernel);
REGISTER_VSX_DISPATCH(geqrf_stub, &geqrf_kernel);
REGISTER_ZVECTOR_DISPATCH(geqrf_stub, &geqrf_kernel);
REGISTER_SVE256_DISPATCH(geqrf_stub, &geqrf_kernel);
REGISTER_ARCH_DISPATCH(orgqr_stub, DEFAULT, &orgqr_kernel_impl);
REGISTER_AVX512_DISPATCH(orgqr_stub, &orgqr_kernel_impl);
REGISTER_AVX2_DISPATCH(orgqr_stub, &orgqr_kernel_impl);
REGISTER_VSX_DISPATCH(orgqr_stub, &orgqr_kernel_impl);
REGISTER_ZVECTOR_DISPATCH(orgqr_stub, &orgqr_kernel_impl);
REGISTER_SVE256_DISPATCH(orgqr_stub, &orgqr_kernel_impl);
REGISTER_ARCH_DISPATCH(ormqr_stub, DEFAULT, &ormqr_kernel);
REGISTER_AVX512_DISPATCH(ormqr_stub, &ormqr_kernel);
REGISTER_AVX2_DISPATCH(ormqr_stub, &ormqr_kernel);
REGISTER_VSX_DISPATCH(ormqr_stub, &ormqr_kernel);
REGISTER_ZVECTOR_DISPATCH(ormqr_stub, &ormqr_kernel);
REGISTER_SVE256_DISPATCH(ormqr_stub, &ormqr_kernel);
REGISTER_ARCH_DISPATCH(lstsq_stub, DEFAULT, &lstsq_kernel);
REGISTER_AVX512_DISPATCH(lstsq_stub, &lstsq_kernel);
REGISTER_AVX2_DISPATCH(lstsq_stub, &lstsq_kernel);
REGISTER_VSX_DISPATCH(lstsq_stub, &lstsq_kernel);
REGISTER_ZVECTOR_DISPATCH(lstsq_stub, &lstsq_kernel);
REGISTER_SVE256_DISPATCH(lstsq_stub, &lstsq_kernel);
REGISTER_ARCH_DISPATCH(triangular_solve_stub, DEFAULT, &triangular_solve_kernel);
REGISTER_AVX512_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
REGISTER_AVX2_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
REGISTER_VSX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
REGISTER_ZVECTOR_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
REGISTER_SVE256_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
REGISTER_ARCH_DISPATCH(lu_factor_stub, DEFAULT, &lu_factor_kernel);
REGISTER_AVX512_DISPATCH(lu_factor_stub, &lu_factor_kernel);
REGISTER_AVX2_DISPATCH(lu_factor_stub, &lu_factor_kernel);
REGISTER_VSX_DISPATCH(lu_factor_stub, &lu_factor_kernel);
REGISTER_ZVECTOR_DISPATCH(lu_factor_stub, &lu_factor_kernel);
REGISTER_SVE256_DISPATCH(lu_factor_stub, &lu_factor_kernel);
REGISTER_ARCH_DISPATCH(ldl_factor_stub, DEFAULT, &ldl_factor_kernel);
REGISTER_AVX512_DISPATCH(ldl_factor_stub, &ldl_factor_kernel);
REGISTER_AVX2_DISPATCH(ldl_factor_stub, &ldl_factor_kernel);
REGISTER_VSX_DISPATCH(ldl_factor_stub, &ldl_factor_kernel);
REGISTER_ZVECTOR_DISPATCH(ldl_factor_stub, &ldl_factor_kernel);
REGISTER_SVE256_DISPATCH(ldl_factor_stub, &ldl_factor_kernel);
REGISTER_ARCH_DISPATCH(ldl_solve_stub, DEFAULT, &ldl_solve_kernel);
REGISTER_AVX512_DISPATCH(ldl_solve_stub, &ldl_solve_kernel);
REGISTER_AVX2_DISPATCH(ldl_solve_stub, &ldl_solve_kernel);
REGISTER_VSX_DISPATCH(ldl_solve_stub, &ldl_solve_kernel);
REGISTER_ZVECTOR_DISPATCH(ldl_solve_stub, &ldl_solve_kernel);
REGISTER_SVE256_DISPATCH(ldl_solve_stub, &ldl_solve_kernel);
REGISTER_ARCH_DISPATCH(lu_solve_stub, DEFAULT, &lu_solve_kernel);
REGISTER_AVX512_DISPATCH(lu_solve_stub, &lu_solve_kernel);
REGISTER_AVX2_DISPATCH(lu_solve_stub, &lu_solve_kernel);
REGISTER_VSX_DISPATCH(lu_solve_stub, &lu_solve_kernel);
REGISTER_ZVECTOR_DISPATCH(lu_solve_stub, &lu_solve_kernel);
REGISTER_SVE256_DISPATCH(lu_solve_stub, &lu_solve_kernel);
REGISTER_ARCH_DISPATCH(svd_stub, DEFAULT, &svd_kernel);
REGISTER_AVX512_DISPATCH(svd_stub, &svd_kernel);
REGISTER_AVX2_DISPATCH(svd_stub, &svd_kernel);
REGISTER_VSX_DISPATCH(svd_stub, &svd_kernel);
REGISTER_ZVECTOR_DISPATCH(svd_stub, &svd_kernel);
REGISTER_SVE256_DISPATCH(svd_stub, &svd_kernel);
REGISTER_ARCH_DISPATCH(unpack_pivots_stub, DEFAULT, &unpack_pivots_cpu_kernel);
REGISTER_AVX512_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel);
REGISTER_AVX2_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel);
REGISTER_VSX_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel);
REGISTER_SVE256_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel);
} // namespace at::native

View File

@ -34,6 +34,17 @@ static CPUCapability compute_cpu_capability() {
if (strcmp(envar, "zvector") == 0) {
return CPUCapability::ZVECTOR;
}
#elif defined(HAVE_SVE_CPU_DEFINITION)
int sve_vl = cpuinfo_get_max_arm_sve_length(); //Returns maximum SVE VL supported by your HW.
#ifdef HAVE_SVE256_CPU_DEFINITION
if (strcmp(envar, "sve256") == 0) {
if (sve_vl == 256) {
return CPUCapability::SVE256;
}
TORCH_WARN("SVE256 capability not available on hardware. Falling back to DEFAULT");
return CPUCapability::DEFAULT;
}
#endif
#else
#ifdef HAVE_AVX512_CPU_DEFINITION
if (strcmp(envar, "avx512") == 0) {
@ -52,7 +63,7 @@ static CPUCapability compute_cpu_capability() {
TORCH_WARN("ignoring invalid value for ATEN_CPU_CAPABILITY: ", envar);
}
#if !defined(__powerpc__) && !defined(__s390x__)
#if !defined(__powerpc__) && !defined(__s390x__) && !defined(HAVE_SVE_CPU_DEFINITION)
if (cpuinfo_initialize()) {
#if defined(HAVE_AVX512_CPU_DEFINITION)
// GCC supports some AVX512 intrinsics such as _mm512_set_epi16 only in
@ -79,6 +90,23 @@ static CPUCapability compute_cpu_capability() {
}
#endif
#if defined(__linux__) && defined(HAVE_SVE_CPU_DEFINITION)
if (cpuinfo_initialize() && cpuinfo_has_arm_sve()) {
int sve_vl = cpuinfo_get_max_arm_sve_length(); //Returns maximum SVE VL supported by your HW.
if (sve_vl <= 0) {
// SVE is not supported on this system.
// Return the default CPU capability.
return CPUCapability::DEFAULT;
}
#ifdef HAVE_SVE256_CPU_DEFINITION
if (sve_vl == 256) { // Check for SVE256
return CPUCapability::SVE256;
}
#endif
// Return the default CPU capability.
return CPUCapability::DEFAULT;
}
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
return CPUCapability::VSX;
#else
@ -106,6 +134,9 @@ DispatchResult DispatchStubImpl::try_get_call_ptr(
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, void *ZVECTOR
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
, void *SVE256
#endif
) {
constexpr auto supported_devices = c10::array_of<c10::DeviceType>(
c10::DeviceType::CPU,
@ -139,6 +170,9 @@ DispatchResult DispatchStubImpl::try_get_call_ptr(
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, ZVECTOR
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
, SVE256
#endif
);
if (!std::holds_alternative<ErrorType>(result)) {
@ -191,6 +225,9 @@ void* DispatchStubImpl::get_call_ptr(
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, void *ZVECTOR
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
, void *SVE256
#endif
) {
auto result = try_get_call_ptr(
@ -211,6 +248,10 @@ void* DispatchStubImpl::get_call_ptr(
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
,
ZVECTOR
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
,
SVE256
#endif
);
if (std::holds_alternative<ErrorType>(result)) {
@ -242,6 +283,9 @@ DispatchResult DispatchStubImpl::try_choose_cpu_impl(
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, void *ZVECTOR
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
, void *SVE256
#endif
){
@ -274,6 +318,16 @@ DispatchResult DispatchStubImpl::try_choose_cpu_impl(
if (capability >= static_cast<int>(CPUCapability::ZVECTOR)) {
return ZVECTOR != nullptr ? DispatchResult(ZVECTOR) : ErrorType::MissingDeviceKernel;
}
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
if (capability >= static_cast<int>(CPUCapability::SVE256)) {
if (C10_UNLIKELY(!SVE256)) {
// dispatch to DEFAULT, since the SVE kernel is missing
return DEFAULT != nullptr ? DispatchResult(DEFAULT) : ErrorType::MissingDeviceKernel;
} else {
return DispatchResult(SVE256);
}
}
#endif
return DEFAULT != nullptr ? DispatchResult(DEFAULT) : ErrorType::MissingDeviceKernel;
}
@ -292,6 +346,9 @@ void* DispatchStubImpl::choose_cpu_impl(
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, void *ZVECTOR
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
, void *SVE256
#endif
) {
auto capability = static_cast<int>(get_cpu_capability());
(void)capability;
@ -326,6 +383,17 @@ void* DispatchStubImpl::choose_cpu_impl(
TORCH_INTERNAL_ASSERT(ZVECTOR, "DispatchStub: missing ZVECTOR kernel");
return ZVECTOR;
}
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
if (capability >= static_cast<int>(CPUCapability::SVE256)) {
if (C10_UNLIKELY(!SVE256)) {
// dispatch to DEFAULT, since the SVE kernel is missing
TORCH_INTERNAL_ASSERT(DEFAULT, "DispatchStub: missing default kernel");
return DEFAULT;
} else {
return SVE256;
}
}
#endif
TORCH_INTERNAL_ASSERT(DEFAULT, "DispatchStub: missing default kernel");
return DEFAULT;

View File

@ -64,6 +64,8 @@ enum class CPUCapability {
VSX = 1,
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
ZVECTOR = 1,
#elif defined(HAVE_SVE_CPU_DEFINITION)
SVE256 = 1,
#else
AVX2 = 1,
AVX512 = 2,
@ -112,6 +114,9 @@ struct TORCH_API DispatchStubImpl {
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, void *ZVECTOR
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
, void *SVE256
#endif
);
@ -130,6 +135,9 @@ struct TORCH_API DispatchStubImpl {
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, void *ZVECTOR
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
, void *SVE256
#endif
);
@ -148,6 +156,9 @@ struct TORCH_API DispatchStubImpl {
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, void *ZVECTOR
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
, void *SVE256
#endif
);
@ -169,6 +180,9 @@ struct TORCH_API DispatchStubImpl {
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, void *ZVECTOR
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
, void *SVE256
#endif
);
@ -221,6 +235,9 @@ private:
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, reinterpret_cast<void*>(ZVECTOR)
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
, reinterpret_cast<void*>(SVE256)
#endif
)
);
@ -275,6 +292,9 @@ public:
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, reinterpret_cast<void*>(ZVECTOR)
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
, reinterpret_cast<void*>(SVE256)
#endif
);
if (std::holds_alternative<ErrorType>(result)){
@ -296,6 +316,9 @@ public:
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
static TORCH_API FnPtr ZVECTOR;
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
static TORCH_API FnPtr SVE256;
#endif
private:
DispatchStubImpl impl;
};
@ -387,6 +410,12 @@ struct RegisterPRIVATEUSE1Dispatch {
#define REGISTER_ZVECTOR_DISPATCH(name, fn)
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
#define REGISTER_SVE256_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, SVE256, fn)
#else
#define REGISTER_SVE256_DISPATCH(name, fn)
#endif
// Macro to register the same kernel for all CPU arch types. This is useful
// if a kernel does not benefit from being recompiled across different arch types.
#define REGISTER_ALL_CPU_DISPATCH(name, fn) \
@ -394,7 +423,8 @@ struct RegisterPRIVATEUSE1Dispatch {
REGISTER_AVX512_DISPATCH(name, fn) \
REGISTER_AVX2_DISPATCH(name, fn) \
REGISTER_VSX_DISPATCH(name, fn) \
REGISTER_ZVECTOR_DISPATCH(name, fn)
REGISTER_ZVECTOR_DISPATCH(name, fn) \
REGISTER_SVE256_DISPATCH(name, fn)
#define REGISTER_NO_CPU_DISPATCH(name) \
REGISTER_ALL_CPU_DISPATCH(name, nullptr)
@ -432,12 +462,14 @@ struct RegisterPRIVATEUSE1Dispatch {
#elif defined(CPU_CAPABILITY)
// REGISTER_DISPATCH now dispatches an AVX512 kernel to nullptr but registers other dispatches.
// ALSO_REGISTER_AVX512_DISPATCH should be used for ensuring AVX512 dispatch, among others.
// ALSO_REGISTER_SVE256_DISPATCH should be used for ensuring SVE256 dispatch, among others.
#ifdef CPU_CAPABILITY_AVX512
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, ((void*)(fn) ? nullptr : nullptr))
#else
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#endif
#define ALSO_REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#define ALSO_REGISTER_SVE256_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#endif
} // namespace at::native

View File

@ -466,6 +466,7 @@ REGISTER_AVX2_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cp
REGISTER_AVX512_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel);
REGISTER_VSX_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel);
REGISTER_SVE256_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel);
// offsets dispatches
REGISTER_ARCH_DISPATCH(
@ -476,6 +477,7 @@ REGISTER_AVX2_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cp
REGISTER_AVX512_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel);
REGISTER_VSX_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel);
REGISTER_SVE256_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel);
// Currently some computation is being duplicated across forward and backward.
// TODO: Cache indices in forward pass to re-use in backward
@ -546,6 +548,9 @@ REGISTER_VSX_DISPATCH(
REGISTER_ZVECTOR_DISPATCH(
_segment_reduce_lengths_backward_stub,
&_segment_reduce_cpu_lengths_backward_kernel);
REGISTER_SVE256_DISPATCH(
_segment_reduce_lengths_backward_stub,
&_segment_reduce_cpu_lengths_backward_kernel);
REGISTER_ARCH_DISPATCH(
_segment_reduce_offsets_backward_stub,
@ -563,5 +568,8 @@ REGISTER_VSX_DISPATCH(
REGISTER_ZVECTOR_DISPATCH(
_segment_reduce_offsets_backward_stub,
&_segment_reduce_cpu_offsets_backward_kernel);
REGISTER_SVE256_DISPATCH(
_segment_reduce_offsets_backward_stub,
&_segment_reduce_cpu_offsets_backward_kernel);
} // namespace at::native

View File

@ -19,7 +19,7 @@ Vectorized<scalar_t> is_lerp_weight_small(Vectorized<scalar_t> weight) {
// is_lerp_weight_small doesn't work for complex because z.abs() returns a
// complex vector which can't be compared. Either implement it with z.abs_2_(),
// or fallback to the scalar function.
#if !(defined(CPU_CAPABILITY_DEFAULT) || defined(_MSC_VER))
#if !(defined(CPU_CAPABILITY_DEFAULT) || defined(_MSC_VER) || defined(CPU_CAPABILITY_SVE))
template <typename value_t>
Vectorized<c10::complex<value_t>> is_lerp_weight_small(Vectorized<c10::complex<value_t>> weight) {
using vec_reg_t = decltype(weight.abs_2_());

View File

@ -165,6 +165,7 @@ REGISTER_AVX2_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_co
REGISTER_AVX512_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_ZVECTOR_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_VSX_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_SVE256_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
// _out variants can be shared between PocketFFT and MKL
Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,

View File

@ -27,5 +27,6 @@ REGISTER_AVX512_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
REGISTER_AVX2_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
REGISTER_VSX_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
REGISTER_SVE256_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel);
} // namespace at::native

View File

@ -161,16 +161,19 @@ REGISTER_AVX512_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_
REGISTER_AVX2_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_VSX_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_SVE256_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_ARCH_DISPATCH(sparse_mask_intersection_out_stub, DEFAULT, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_AVX512_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_AVX2_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_VSX_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_SVE256_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_ARCH_DISPATCH(sparse_mask_projection_out_stub, DEFAULT, &sparse_mask_projection_out_cpu_kernel);
REGISTER_AVX512_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
REGISTER_AVX2_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
REGISTER_VSX_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
REGISTER_SVE256_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
}

View File

@ -449,6 +449,7 @@ REGISTER_AVX2_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp);
REGISTER_AVX512_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp);
REGISTER_VSX_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp);
REGISTER_ZVECTOR_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp);
REGISTER_SVE256_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp);
int64_t _fused_sdp_choice_meta(
const Tensor& query_,

View File

@ -992,6 +992,9 @@ namespace {
blend_init(a, b);
test_blendv<vec, VT, 0, vec::size()>(expected_val, a, b, mask);
}
// NOTE: In this test, blend<mask> is not required to implement SVE Vectorized::set.
// so, this test is disabled for SVE.
#if !defined(CPU_CAPABILITY_SVE)
TYPED_TEST(BitwiseFloatsAdditional2, Blend) {
using vec = TypeParam;
using VT = ValueType<TypeParam>;
@ -1005,6 +1008,7 @@ namespace {
constexpr int64_t power_sets = 1LL << (vec::size());
test_blend<vec, VT, power_sets - 1>(expected_val, a, b);
}
#endif
template<typename vec, typename VT>
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
void test_set(VT expected_val[vec::size()], VT a[vec::size()], VT b[vec::size()], int64_t count){
@ -1606,6 +1610,7 @@ namespace {
ASSERT_TRUE(vec_pinf.has_inf_nan()) << "Test failed for positive Infinity\n";
ASSERT_TRUE(vec_ninf.has_inf_nan()) << "Test failed for negative Infinity\n";
}
#if !defined(CPU_CAPABILITY_SVE)
TYPED_TEST(VecConvertTests, Convert) {
using vec = TypeParam;
using src_t = ValueType<TypeParam>;
@ -1658,6 +1663,7 @@ namespace {
TEST_CONVERT_TO(double);
#undef TEST_CONVERT_TO
}
#endif
TYPED_TEST(VecMaskTests, MaskedLoad) {
using vec = TypeParam;
using src_t = ValueType<TypeParam>;
@ -1716,6 +1722,7 @@ namespace {
#undef TEST_MASK_LOAD
#undef TEST_MASK_LOAD_N
}
#if !defined(CPU_CAPABILITY_SVE)
TYPED_TEST(VecMaskTests, MaskedCheck) {
using VT = ValueType<TypeParam>;
using vec = TypeParam;
@ -1739,6 +1746,8 @@ namespace {
#undef TEST_MASK_CHECK_N
}
#endif
#if !defined(CPU_CAPABILITY_SVE)
TYPED_TEST(VecMaskTests, ToFrom) {
using vec = TypeParam;
using VT = ValueType<TypeParam>;
@ -1764,6 +1773,8 @@ namespace {
<< "Failure Details:\nTest Seed to reproduce: " << seed;
}
}
#endif
#if !defined(CPU_CAPABILITY_SVE)
TYPED_TEST(VecMaskTests, Cast) {
using vec = TypeParam;
using src_t = ValueType<TypeParam>;
@ -1808,6 +1819,7 @@ namespace {
#undef TEST_MASK_CAST
#undef TEST_MASK_CAST_N
}
#endif
#else
#error GTEST does not have TYPED_TEST
#endif

View File

@ -53,6 +53,9 @@ CACHE_ALIGN #define
defined(CPU_CAPABILITY_AVX512) && (defined(__GNUC__) || defined(__GNUG__))
#undef CHECK_DEQUANT_WITH_LOW_PRECISION
#define CHECK_WITH_FMA 1
#elif defined(CPU_CAPABILITY_SVE)
#define CHECK_DEQUANT_WITH_LOW_PRECISION 1
#define CHECK_WITH_FMA 1
#elif !defined(CPU_CAPABILITY_VSX) && !defined(CPU_CAPABILITY_AVX2)
#undef CHECK_DEQUANT_WITH_LOW_PRECISION
#undef CHECK_WITH_FMA

View File

@ -322,6 +322,18 @@ if(INTERN_BUILD_ATEN_OPS)
LIST(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} ${CXX_ZVECTOR_FLAGS}")
endif(CXX_ZVECTOR_FOUND)
if(CXX_SVE_FOUND)
if(CXX_SVE256_FOUND)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_SVE_CPU_DEFINITION -DHAVE_SVE256_CPU_DEFINITION")
list(APPEND CPU_CAPABILITY_NAMES "SVE256")
if("${CMAKE_C_COMPILER_ID}" MATCHES "Clang")
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -O2 -march=armv8.2-a+sve -DCPU_CAPABILITY_SVE -msve-vector-bits=256")
else()
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -march=armv8.2-a+sve -DCPU_CAPABILITY_SVE -msve-vector-bits=256")
endif()
endif(CXX_SVE256_FOUND)
endif(CXX_SVE_FOUND)
list(LENGTH CPU_CAPABILITY_NAMES NUM_CPU_CAPABILITY_NAMES)
math(EXPR NUM_CPU_CAPABILITY_NAMES "${NUM_CPU_CAPABILITY_NAMES}-1")

View File

@ -22,6 +22,15 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
set(ASIMD_FOUND false CACHE BOOL "ASIMD/NEON available on host")
ENDIF (ASIMD_TRUE)
#sve instruction can be found on the majority part of modern ARM processor
STRING(REGEX REPLACE "^.*(sve).*$" "\\1" SVE_THERE ${CPUINFO})
STRING(COMPARE EQUAL "sve" "${SVE_THERE}" SVE_TRUE)
IF (SVE_TRUE)
set(SVE_FOUND true CACHE BOOL "SVE available on host")
ELSE (SVE_TRUE)
set(SVE_FOUND false CACHE BOOL "SVE available on host")
ENDIF (SVE_TRUE)
#Find the processor type (for now OMAP3 or OMAP4)
STRING(REGEX REPLACE "^.*(OMAP3).*$" "\\1" OMAP3_THERE "${CPUINFO}")
STRING(COMPARE EQUAL "OMAP3" "${OMAP3_THERE}" OMAP3_TRUE)
@ -79,3 +88,72 @@ if(NOT CORTEXA9_FOUND)
MESSAGE(STATUS "No OMAP4 processor on this machine.")
endif(NOT CORTEXA9_FOUND)
mark_as_advanced(NEON_FOUND)
#SVE support is availale is only for Linux OS.
IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
# Include necessary modules for checking C and C++ source compilations
INCLUDE(CheckCSourceCompiles)
INCLUDE(CheckCXXSourceCompiles)
# Test code for SVE support
SET(SVE_CODE "
#include <arm_sve.h>
int main()
{
svfloat64_t a;
a = svdup_n_f64(0);
return 0;
}
")
# Macro to check for SVE instruction support
MACRO(CHECK_SVE lang type flags)
# Save the current state of required flags
SET(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
# Set the flags necessary for compiling the test code with SVE support
SET(CMAKE_REQUIRED_FLAGS "${CMAKE_${lang}_FLAGS_INIT} ${flags}")
# Check if the source code compiles with the given flags for the specified language (C or C++)
IF(lang STREQUAL "CXX")
CHECK_CXX_SOURCE_COMPILES("${SVE_CODE}" ${lang}_HAS_${type})
ELSE()
CHECK_C_SOURCE_COMPILES("${SVE_CODE}" ${lang}_HAS_${type})
ENDIF()
# If the compilation test is successful, set appropriate variables indicating support
IF(${lang}_HAS_${type})
set(${lang}_SVE_FOUND TRUE CACHE BOOL "SVE available on host")
SET(${lang}_${type}_FOUND TRUE CACHE BOOL "${lang} ${type} support")
SET(${lang}_${type}_FLAGS "${flags}" CACHE STRING "${lang} ${type} flags")
ENDIF()
# Restore the original state of required flags
SET(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
# If the compilation test fails, indicate that the support is not found
IF(NOT ${lang}_${type}_FOUND)
SET(${lang}_${type}_FOUND FALSE CACHE BOOL "${lang} ${type} support")
SET(${lang}_${type}_FLAGS "" CACHE STRING "${lang} ${type} flags")
ENDIF()
# Mark the variables as advanced to hide them in the default CMake GUI
MARK_AS_ADVANCED(${lang}_${type}_FOUND ${lang}_${type}_FLAGS)
ENDMACRO()
# Check for SVE256 vector length
CHECK_SVE(CXX "SVE256" "-march=armv8-a+sve -msve-vector-bits=256")
# If SVE256 support is not found, set CXX_SVE_FOUND to FALSE and notify the user
if(NOT CXX_SVE256_FOUND)
set(CXX_SVE_FOUND FALSE CACHE BOOL "SVE not available on host")
message(STATUS "No SVE processor on this machine.")
else()
# If SVE256 support is found, set CXX_SVE_FOUND to TRUE and notify the user
set(CXX_SVE_FOUND TRUE CACHE BOOL "SVE available on host")
message(STATUS "SVE support detected.")
endif()
# Mark the SVE support variable as advanced
mark_as_advanced(CXX_SVE_FOUND)
ENDIF(CMAKE_SYSTEM_NAME MATCHES "Linux")

View File

@ -1240,6 +1240,7 @@ def main():
"include/ATen/cpu/vec/vec256/zarch/*.h",
"include/ATen/cpu/vec/vec512/*.h",
"include/ATen/cpu/vec/*.h",
"include/ATen/cpu/vec/sve/*.h",
"include/ATen/core/*.h",
"include/ATen/cuda/*.cuh",
"include/ATen/cuda/*.h",

View File

@ -16,5 +16,6 @@ def get_cpu_capability() -> str:
- "NO AVX"
- "AVX2"
- "AVX512"
- "SVE256"
"""
return torch._C._get_cpu_capability()