Compare commits

...

9 Commits

Author SHA1 Message Date
8ac81dba21 add vec128 vecreduce 2025-09-23 17:16:40 +00:00
dae9a71d99 fix ci issues 2025-09-12 09:32:19 +00:00
bf4b0e8c41 fix double free 2025-09-12 09:32:19 +00:00
0384f48daa Fix compile 2025-09-12 09:32:18 +00:00
3b92a1adfe Fix tests 2025-09-12 09:30:57 +00:00
6ca9dc026d add SVE dispatch 2025-09-12 09:30:57 +00:00
a499828924 Make size non-constexpr 2025-09-12 09:30:57 +00:00
e84eabd4f9 Vec length agnostic SVE Vectorized class POC 2025-09-12 09:30:57 +00:00
5e53e458b9 [feat]: add optimized exp_u20 implementation from Arm Optimized Routines (AOR)
Co-authored-by: Fadi Arafeh <Fadi.Arafeh@arm.com>
Signed-off-by: Analle Abuammar <analle.abuammar@arm.com>
2025-09-12 09:30:57 +00:00
64 changed files with 953 additions and 767 deletions

View File

@ -10,6 +10,10 @@
#include <ideep.hpp>
#endif
#if !defined(__s390x__) && !defined(__powerpc__)
#include <cpuinfo.h>
#endif
#include <caffe2/core/common.h>
#include <ATen/native/DispatchStub.h>
@ -103,7 +107,9 @@ std::string get_cpu_capability() {
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
case native::CPUCapability::ZVECTOR:
return "Z VECTOR";
#elif defined(HAVE_SVE256_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION)
#elif defined(HAVE_SVE_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION)
case native::CPUCapability::SVE:
return "SVE";
case native::CPUCapability::SVE256:
return "SVE256";
#else
@ -118,6 +124,12 @@ std::string get_cpu_capability() {
return "";
}
int get_sve_len() {
// It is possible that we override the cpu_capability with
// environment variable
return cpuinfo_get_max_arm_sve_length();
}
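Note: get_sve_len() reports the hardware SVE vector length through cpuinfo. A minimal sketch (an assumption for illustration, not part of this change) of reading the same quantity with the ACLE intrinsics, where svcntb() gives the vector length in bytes:
#if defined(__ARM_FEATURE_SVE)
#include <arm_sve.h>
// Vector length in bits: svcntb() bytes per SVE register, times 8.
inline int sve_vector_length_bits() {
  return static_cast<int>(svcntb()) * 8;
}
#endif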
static std::string used_cpu_capability() {
// It is possible that we override the cpu_capability with
// environment variable

View File

@ -15,4 +15,6 @@ TORCH_API std::string get_cxx_flags();
TORCH_API std::string get_cpu_capability();
TORCH_API int get_sve_len();
} // namespace at

View File

@ -34,9 +34,9 @@ inline scalar_t vec_reduce_all(
scalar_t acc_arr[Vec::size()];
acc_vec.store(acc_arr);
for (const auto i : c10::irange(1, size)) {
std::array<scalar_t, Vec::size()> acc_arr_next = {0};
scalar_t acc_arr_next[Vec::size()] = {0};
acc_arr_next[0] = acc_arr[i];
Vec acc_vec_next = Vec::loadu(acc_arr_next.data());
Vec acc_vec_next = Vec::loadu(acc_arr_next);
acc_vec = vec_fun(acc_vec, acc_vec_next);
}
acc_vec.store(acc_arr);
@ -102,8 +102,7 @@ struct VecReduceAllSIMD<float, Op> {
#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) &&
// !defined(C10_MOBILE)
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
!defined(CPU_CAPABILITY_SVE)
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && !defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE)
template <typename Op>
struct VecReduceAllSIMD<float, Op> {
static inline float apply(
@ -143,8 +142,7 @@ struct VecReduceAllSIMD<float, std::plus<Vectorized<float>>> {
#endif // defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
// && !defined(CPU_CAPABILITY_SVE)
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
defined(CPU_CAPABILITY_SVE256)
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && (defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE))
template <typename Op>
struct VecReduceAllSIMD<float, Op> {
static inline float apply(
@ -152,18 +150,28 @@ struct VecReduceAllSIMD<float, Op> {
const Vectorized<float>& acc_vec) {
using Vec = Vectorized<float>;
Vec v = acc_vec;
// 128-bit shuffle
svuint32_t ind = svdupq_n_u32(4, 5, 6, 7);
Vec v1 = svtbl_f32(v, ind);
v = vec_fun(v, v1);
// 64-bit shuffle
ind = svdupq_n_u32(2, 3, 0, 1);
v1 = svtbl_f32(v, ind);
v = vec_fun(v, v1);
// 32-bit shuffle
ind = svdupq_n_u32(1, 0, 2, 3);
v1 = svtbl_f32(v, ind);
v = vec_fun(v, v1);
if (Vec::size() == 8) {
// 128-bit shuffle
svuint32_t ind = svdupq_n_u32(4, 5, 6, 7);
Vec v1 = svtbl_f32(v, ind);
v = vec_fun(v, v1);
// 64-bit shuffle
ind = svdupq_n_u32(2, 3, 0, 1);
v1 = svtbl_f32(v, ind);
v = vec_fun(v, v1);
// 32-bit shuffle
ind = svdupq_n_u32(1, 0, 2, 3);
v1 = svtbl_f32(v, ind);
v = vec_fun(v, v1);
} else {
svuint32_t ind = svdupq_n_u32(2, 3, 0, 1); // 64-bit stride-2
Vec v1 = svtbl_f32(v, ind);
v = vec_fun(v, v1);
ind = svdupq_n_u32(1, 0, 2, 3); // 32-bit stride-1
v1 = svtbl_f32(v, ind);
v = vec_fun(v, v1);
}
return svlasta(svpfalse(), v);
}
};
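Note: both branches of this reducer perform the same log2-step reduction: each svtbl shuffle pairs lane i with the lane one stride away, vec_fun combines them, and the stride halves until lane 0 holds the result (svlasta with an all-false predicate then extracts lane 0). A scalar sketch of the idea, for illustration only:
// Reduce n lanes (n a power of two: 8 for 256-bit SVE floats, 4 for 128-bit) with a binary op.
float reduce_lanes(float* v, int n, float (*op)(float, float)) {
  for (int stride = n / 2; stride >= 1; stride /= 2)
    for (int i = 0; i < stride; ++i)
      v[i] = op(v[i], v[i + stride]);
  return v[0];
}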

View File

@ -4,7 +4,7 @@
#include <ATen/cpu/vec/vec_base.h>
#if defined(CPU_CAPABILITY_SVE)
#if defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE)
// Define the data type of VLS(vector-length specific).
typedef svbool_t vls_pred_t
@ -77,4 +77,4 @@ typedef svfloat64_t vls_float64_t
#define ALL_F64_TRUE_MASK svreinterpret_f64_s64(ALL_S64_TRUE_MASK)
#define ALL_F64_FALSE_MASK svreinterpret_f64_s64(ALL_S64_FALSE_MASK)
#endif // defined(CPU_CAPABILITY_SVE)
#endif // defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE)

View File

@ -19,7 +19,7 @@ namespace vec {
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16)
#if (defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE)) && defined(__ARM_FEATURE_BF16)
template <>
struct is_vec_specialized_for<BFloat16> : std::bool_constant<true> {};
@ -230,8 +230,6 @@ __attribute__((optimize("no-tree-vectorize")))
#endif
inline std::tuple<Vectorized<float>, Vectorized<float>>
convert_bfloat16_float(const Vectorized<c10::BFloat16>& a) {
static_assert(
Vectorized<c10::BFloat16>::size() == 2 * Vectorized<float>::size());
auto zero = svreinterpret_bf16_f32(svdup_n_f32(0.0f));
auto bf16_vec1 = svzip1_bf16(zero, a);
auto bf16_vec2 = svzip2_bf16(zero, a);
@ -243,19 +241,18 @@ convert_bfloat16_float(const Vectorized<c10::BFloat16>& a) {
inline Vectorized<c10::BFloat16> convert_float_bfloat16(
const Vectorized<float>& a,
const Vectorized<float>& b) {
static_assert(
Vectorized<c10::BFloat16>::size() == 2 * Vectorized<float>::size());
svbfloat16_t x1 = svcvt_bf16_f32_z(ptrue, a);
svbfloat16_t x2 = svcvt_bf16_f32_z(ptrue, b);
return Vectorized<c10::BFloat16>(svuzp1_bf16(x1, x2));
}
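Note: the zip-with-zero trick works because bfloat16 is the top 16 bits of an IEEE float32, so interleaving a zero vector below each bf16 element reproduces the float bit pattern; the reverse direction really converts (svcvt_bf16_f32) and then packs with svuzp1. A scalar sketch of the widening step, for illustration only (assumes <cstdint> and <cstring>):
float bf16_to_f32(uint16_t bits) {
  uint32_t u = static_cast<uint32_t>(bits) << 16;  // bf16 occupies the high half of a float32
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}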
inline void load_fp32_from_bf16(const BFloat16* data, Vectorized<float>& out) {
__at_align__ float values[Vectorized<float>::size()];
__at_align__ float * values = new float[Vectorized<float>::size()];
for (const auto k : c10::irange(Vectorized<float>::size())) {
values[k] = data[k];
}
out = Vectorized<float>::loadu(values);
delete[] values;
}
inline void load_fp32_from_bf16(
@ -308,8 +305,8 @@ Vectorized<c10::BFloat16> inline operator/(
}
inline Vectorized<BFloat16>::Vectorized() {
const short zero = 0;
values = svdup_n_bf16(c10::bit_cast<bfloat16_t>(zero));
auto vals_f = svdup_n_f32(0);
values = convert_float_bfloat16(vals_f, vals_f);
}
inline Vectorized<BFloat16>::Vectorized(int val) {

View File

@ -8,7 +8,7 @@
#include <ATen/cpu/vec/sve/sve_helper.h>
#include <ATen/cpu/vec/vec_base.h>
#if defined(CPU_CAPABILITY_SVE)
#if defined(CPU_CAPABILITY_SVE) || defined(CPU_CAPABILITY_SVE256)
#include <ATen/cpu/vec/sve/vec_bfloat16.h>
#include <ATen/cpu/vec/sve/vec_double.h>
#include <ATen/cpu/vec/sve/vec_float.h>
@ -27,7 +27,7 @@ namespace at::vec {
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_SVE)
#if defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE)
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#define DEFINE_SVE_CAST(t1_t, t1_prefix, t2_t, t2_prefix) \
@ -231,6 +231,5 @@ std::pair<
#endif // __ARM_FEATURE_BF16
#endif // defined(CPU_CAPABILITY_SVE)
} // namespace CPU_CAPABILITY
} // namespace at::vec
}

View File

@ -22,7 +22,7 @@ namespace at::vec {
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_SVE)
#if defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE)
template <>
struct is_vec_specialized_for<double> : std::bool_constant<true> {};
@ -55,10 +55,11 @@ class Vectorized<double> {
operator svfloat64_t() const {
return values;
}
template <uint64_t mask>
static Vectorized<double> blend(
const Vectorized<double>& a,
const Vectorized<double>& b) {
const Vectorized<double>& b,
int64_t mask
) {
// Build an array of flags: each element is 1 if the corresponding bit in
// 'mask' is set, 0 otherwise.
__at_align__ int64_t flag_arr[size()];

View File

@ -2,8 +2,10 @@
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/sve/sve_helper.h>
#include <ATen/cpu/vec/vec_base.h>
#include <algorithm>
#include <cmath>
#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF)
#include <sleef.h>
#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code
@ -22,7 +24,7 @@ namespace at::vec {
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_SVE)
#if defined(CPU_CAPABILITY_SVE) || defined(CPU_CAPABILITY_SVE256)
template <>
struct is_vec_specialized_for<float> : std::bool_constant<true> {};
@ -30,52 +32,77 @@ struct is_vec_specialized_for<float> : std::bool_constant<true> {};
template <>
class Vectorized<float> {
private:
vls_float32_t values;
__at_align__ float values[2048 / sizeof(float)];
public:
using value_type = float;
using size_type = int;
static constexpr size_type size() {
return VECTOR_WIDTH / sizeof(float);
static inline size_type size() {
return svcntw();
}
Vectorized() {
values = svdup_n_f32(0);
inline Vectorized() {svst1_f32(ptrue, values, svdup_n_f32(0));}
inline Vectorized(const float val) {
svst1_f32(ptrue, values, svdup_n_f32(val));
}
Vectorized(svfloat32_t v) : values(v) {}
Vectorized(float val) {
values = svdup_n_f32(val);
inline Vectorized(const svfloat32_t val) {
svst1_f32(ptrue, values, val);
}
template <
typename... Args,
typename = std::enable_if_t<(sizeof...(Args) == size())>>
Vectorized(Args... vals) {
__at_align__ float buffer[size()] = {vals...};
values = svld1_f32(ptrue, buffer);
template<typename T,
typename = std::enable_if_t<std::is_pointer_v<T>>>
inline Vectorized(float * val) {
svst1_f32(ptrue, values, svld1_f32(ptrue, val));
}
operator svfloat32_t() const {
return values;
template<typename... Args,
typename = std::enable_if_t<(sizeof...(Args) == size())>>
inline Vectorized(Args... vals) {
values = { vals... };
}
template <uint64_t mask>
static Vectorized<float> blend(
const Vectorized<float>& a,
const Vectorized<float>& b) {
// Build an array of flags: each element is 1 if the corresponding bit in
// 'mask' is set, 0 otherwise.
__at_align__ int32_t flag_arr[size()];
inline operator svfloat32_t() const {
return svld1_f32(ptrue, values);
}
static inline Vectorized<float> from_ptr(const float * vs) {
Vectorized<float> v;
svst1_f32(ptrue, v.values, svld1_f32(ptrue, static_cast<const float *>(vs)));
return v;
}
static inline Vectorized<float> from_ptr(const float * vs, int count) {
Vectorized<float> v;
svst1_f32(ptrue, v.values, svld1_f32(svwhilelt_b32_s32(0, count), static_cast<const float *>(vs)));
return v;
}
inline void set_lane(int i, float value) {
values[i] = value;
}
inline Vectorized<float> map(float (*fn)(float)) const {
Vectorized<float> result;
for (int64_t i = 0; i < size(); ++i) {
result.set_lane(i, fn(values[i]));
}
return result;
}
inline Vectorized<float> map2(float (*fn)(float, float), const Vectorized<float> &b) const {
Vectorized<float> result;
for (int64_t i = 0; i < size(); ++i) {
result.set_lane(i, fn(values[i], b.values[i]));
}
return result;
}
static inline Vectorized<float> blend(const Vectorized<float>& a, const Vectorized<float>& b, const uint64_t mask) {
// Build an array of flags: each element is 1 if the corresponding bit in 'mask' is set, 0 otherwise.
__at_align__ int32_t * flag_arr = new int32_t[size()];
for (int i = 0; i < size(); i++) {
flag_arr[i] = (mask & (1ULL << i)) ? 1 : 0;
}
// Load the flag array into an SVE int32 vector.
svint32_t int_mask = svld1_s32(svptrue_b32(), flag_arr);
// Compare each lane of int_mask to 0; returns an svbool_t predicate where
// true indicates a nonzero flag.
svbool_t blend_mask = svcmpne_n_s32(svptrue_b32(), int_mask, 0);
// Use svsel to select elements from b where the predicate is true, else
// from a.
svfloat32_t result = svsel_f32(blend_mask, b.values, a.values);
return Vectorized<float>(result);
svint32_t int_mask = svld1_s32(ptrue, flag_arr);
delete[] flag_arr;
// Compare each lane of int_mask to 0; returns an svbool_t predicate where true indicates a nonzero flag.
svbool_t blend_mask = svcmpne_n_s32(ptrue, int_mask, 0);
// Use svsel to select elements from b where the predicate is true, else from a.
return svsel_f32(blend_mask, b, a);
}
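Note: with size() no longer a compile-time constant, blend() takes the mask as a runtime argument, expands its bits into a per-lane flag array, and selects with svcmpne/svsel. The selection rule, as a scalar sketch for illustration only:
void blend_ref(const float* a, const float* b, float* out, uint64_t mask, int lanes) {
  for (int i = 0; i < lanes; ++i)
    out[i] = (mask & (1ULL << i)) ? b[i] : a[i];  // bit i set -> take b, else keep a
}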
static Vectorized<float> blendv(
static inline Vectorized<float> blendv(
const Vectorized<float>& a,
const Vectorized<float>& b,
const Vectorized<float>& mask_) {
@ -84,16 +111,18 @@ class Vectorized<float> {
return svsel_f32(mask, b, a);
}
template <typename step_t>
static Vectorized<float> arange(
static inline Vectorized<float> arange(
float base = 0.f,
step_t step = static_cast<step_t>(1)) {
__at_align__ float buffer[size()];
__at_align__ float * buffer = new float[size()];
for (int64_t i = 0; i < size(); i++) {
buffer[i] = base + i * step;
}
return svld1_f32(ptrue, buffer);
auto tmp = Vectorized<float>::from_ptr(buffer);
delete[] buffer;
return tmp;
}
static Vectorized<float> set(
static inline Vectorized<float> set(
const Vectorized<float>& a,
const Vectorized<float>& b,
int64_t count = size()) {
@ -169,271 +198,213 @@ class Vectorized<float> {
poly = svsel_f32(svcmpgt_f32(pg, x, max_input), inf, poly);
return poly;
}
static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
if (count == size())
return svld1_f32(ptrue, reinterpret_cast<const float*>(ptr));
svbool_t pg = svwhilelt_b32(0ull, count);
return svld1_f32(pg, reinterpret_cast<const float*>(ptr));
static inline Vectorized<float> loadu(const void* ptr) {
return Vectorized<float>::from_ptr(reinterpret_cast<const float *>(ptr));
}
void store(void* ptr, int64_t count = size()) const {
if (count == size()) {
svst1_f32(ptrue, reinterpret_cast<float*>(ptr), values);
} else {
svbool_t pg = svwhilelt_b32(0ull, count);
svst1_f32(pg, reinterpret_cast<float*>(ptr), values);
}
static inline Vectorized<float> loadu(const void* ptr, int64_t count) {
return Vectorized<float>::from_ptr(reinterpret_cast<const float *>(ptr), count);
}
const float& operator[](int idx) const = delete;
float& operator[](int idx) = delete;
int64_t zero_mask() const {
// returns an integer mask where all zero elements are translated to 1-bit
// and others are translated to 0-bit
inline void store(void* ptr) const {
svst1_f32(ptrue, static_cast<float *>(ptr), svld1_f32(ptrue, values));
}
inline void store(void* ptr, int count) const {
svst1_f32(svwhilelt_b32_s32(0, count), static_cast<float *>(ptr), svld1_f32(ptrue, values));
}
inline const float& operator[](int idx) const {
return values[idx];
};
inline float& operator[](int idx) {
return values[idx];
};
inline int64_t zero_mask() const {
// returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
int64_t mask = 0;
__at_align__ int32_t mask_array[size()];
__at_align__ int32_t * mask_array = new int32_t[size()];
svbool_t svbool_mask = svcmpeq_f32(ptrue, values, ZERO_F32);
svst1_s32(
ptrue,
mask_array,
svsel_s32(svbool_mask, ALL_S32_TRUE_MASK, ALL_S32_FALSE_MASK));
for (int64_t i = 0; i < size(); ++i) {
if (mask_array[i])
mask |= (1ull << i);
svbool_t svbool_mask = svcmpeq_f32(ptrue, *this, ZERO_F32);
svst1_s32(ptrue, mask_array, svsel_s32(svbool_mask,
ALL_S32_TRUE_MASK,
ALL_S32_FALSE_MASK));
for (int64_t j = 0; j < size(); ++j) {
if (mask_array[j]) mask |= (1ull << j);
}
delete[] mask_array;
return mask;
}
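Note: zero_mask() compares each lane against zero, stores an all-ones/all-zeros integer per lane, and packs those lanes into bits. The equivalent scalar loop, for illustration only:
int64_t zero_mask_ref(const float* v, int lanes) {
  int64_t mask = 0;
  for (int i = 0; i < lanes; ++i)
    if (v[i] == 0.f) mask |= (1LL << i);  // bit i set when lane i is zero
  return mask;
}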
Vectorized<float> isnan() const {
inline Vectorized<float> isnan() const {
// NaN check
svbool_t mask = svcmpuo_f32(ptrue, values, ZERO_F32);
auto mask = svcmpuo_f32(ptrue, *this, ZERO_F32);
return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
}
bool has_inf_nan() const {
return svptest_any(
ptrue,
svcmpuo_f32(ptrue, svsub_f32_x(ptrue, values, values), ZERO_F32));
inline bool has_inf_nan() const {
return svptest_any(ptrue, svcmpuo_f32(ptrue, svsub_f32_x(ptrue, *this, *this), ZERO_F32));
}
Vectorized<float> map(float (*f)(float)) const {
__at_align__ float tmp[size()];
store(tmp);
for (int64_t i = 0; i < size(); ++i) {
tmp[i] = f(tmp[i]);
}
return loadu(tmp);
inline Vectorized<float> abs() const {
return svabs_f32_x(ptrue, *this);
}
Vectorized<float> abs() const {
return svabs_f32_x(ptrue, values);
}
Vectorized<float> angle() const {
inline Vectorized<float> angle() const {
const auto nan_vec = svdup_n_f32(NAN);
const auto nan_mask = svcmpuo_f32(ptrue, values, ZERO_F32);
const auto nan_mask = svcmpuo_f32(ptrue, *this, ZERO_F32);
const auto pi = svdup_n_f32(c10::pi<float>);
const auto neg_mask = svcmplt_f32(ptrue, values, ZERO_F32);
const auto neg_mask = svcmplt_f32(ptrue, *this, ZERO_F32);
auto angle = svsel_f32(neg_mask, pi, ZERO_F32);
angle = svsel_f32(nan_mask, nan_vec, angle);
return angle;
return svsel_f32(nan_mask, nan_vec, angle);
}
Vectorized<float> real() const {
return values;
inline Vectorized<float> real() const {
return *this;
}
Vectorized<float> imag() const {
inline Vectorized<float> imag() const {
return Vectorized<float>(0.f);
}
Vectorized<float> conj() const {
return values;
inline Vectorized<float> conj() const {
return *this;
}
Vectorized<float> acos() const {
return USE_SLEEF(
Vectorized<float>(Sleef_acosfx_u10sve(values)), map(std::acos));
inline Vectorized<float> acos() const {
return USE_SLEEF(Sleef_acosfx_u10sve(*this), map(std::acos));
}
Vectorized<float> acosh() const {
return USE_SLEEF(
Vectorized<float>(Sleef_acoshfx_u10sve(values)), map(std::acosh));
inline Vectorized<float> acosh() const {
return USE_SLEEF(Sleef_acoshfx_u10sve(*this), map(std::acosh));
}
Vectorized<float> asin() const {
return USE_SLEEF(
Vectorized<float>(Sleef_asinfx_u10sve(values)), map(std::asin));
inline Vectorized<float> asin() const {
return USE_SLEEF(Sleef_asinfx_u10sve(*this), map(std::asin));
}
Vectorized<float> asinh() const {
return USE_SLEEF(
Vectorized<float>(Sleef_asinhfx_u10sve(values)), map(std::asinh));
inline Vectorized<float> asinh() const {
return USE_SLEEF(Sleef_asinhfx_u10sve(*this), map(std::asinh));
}
Vectorized<float> atan() const {
return USE_SLEEF(
Vectorized<float>(Sleef_atanfx_u10sve(values)), map(std::atan));
inline Vectorized<float> atan() const {
return USE_SLEEF(Sleef_atanfx_u10sve(*this), map(std::atan));
}
Vectorized<float> atanh() const {
return USE_SLEEF(
Vectorized<float>(Sleef_atanhfx_u10sve(values)), map(std::atanh));
inline Vectorized<float> atanh() const {
return USE_SLEEF(Sleef_atanhfx_u10sve(*this), map(std::atanh));
}
Vectorized<float> atan2(const Vectorized<float>& b) const {USE_SLEEF(
{ return Vectorized<float>(Sleef_atan2fx_u10sve(values, b)); },
{
__at_align__ float tmp[size()];
__at_align__ float tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = std::atan2(tmp[i], tmp_b[i]);
}
return loadu(tmp);
})} Vectorized<float> copysign(const Vectorized<float>& sign) const {
USE_SLEEF(
{ return Vectorized<float>(Sleef_copysignfx_sve(values, sign)); },
{
__at_align__ float tmp[size()];
__at_align__ float tmp_sign[size()];
store(tmp);
sign.store(tmp_sign);
for (int64_t i = 0; i < size(); ++i) {
tmp[i] = std::copysign(tmp[i], tmp_sign[i]);
}
return loadu(tmp);
})} Vectorized<float> erf() const {
return USE_SLEEF(
Vectorized<float>(Sleef_erffx_u10sve(values)), map(std::erf));
inline Vectorized<float> atan2(const Vectorized<float> &b) const {
return USE_SLEEF(Sleef_atan2fx_u10sve(*this, b), map2(std::atan2, b));
}
Vectorized<float> erfc() const {
return USE_SLEEF(
Vectorized<float>(Sleef_erfcfx_u15sve(values)), map(std::erfc));
inline Vectorized<float> copysign(const Vectorized<float> &sign) const {
return USE_SLEEF(Sleef_copysignfx_sve(*this, sign), map2(std::copysign, sign));
}
Vectorized<float> erfinv() const {
inline Vectorized<float> erf() const {
return USE_SLEEF(Sleef_erffx_u10sve(*this), map(std::erf));
}
inline Vectorized<float> erfc() const {
return USE_SLEEF(Sleef_erfcfx_u15sve(*this), map(std::erfc));
}
inline Vectorized<float> erfinv() const {
return map(calc_erfinv);
}
Vectorized<float> exp() const {
return USE_SLEEF(
Vectorized<float>(Sleef_expfx_u10sve(values)), map(std::exp));
inline Vectorized<float> exp() const {
return USE_SLEEF(Sleef_expfx_u10sve(*this), map(std::exp));
}
Vectorized<float> exp2() const {
return USE_SLEEF(
Vectorized<float>(Sleef_exp2fx_u10sve(values)), map(std::exp2));
inline Vectorized<float> exp2() const {
return USE_SLEEF(Sleef_exp2fx_u10sve(*this), map(std::exp2));
}
Vectorized<float> expm1() const {
return USE_SLEEF(
Vectorized<float>(Sleef_expm1fx_u10sve(values)), map(std::expm1));
inline Vectorized<float> expm1() const {
return USE_SLEEF(Sleef_expm1fx_u10sve(*this), map(std::expm1));
}
// Implementation copied from Arm Optimized Routines:
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/sve/expf.c
Vectorized<float> exp_u20() const {
return exp();
// special case to handle special inputs that are too large or too small
// i.e. where there's at least one element x, s.t. |x| >= 87.3...
svbool_t is_special_case = svacgt (svptrue_b32(), *this, 0x1.5d5e2ap+6f);
if (svptest_any (svptrue_b32(), is_special_case)) {
return exp();
}
const svfloat32_t ln2_hi = svdup_n_f32(0x1.62e4p-1f);
const svfloat32_t ln2_lo = svdup_n_f32(0x1.7f7d1cp-20f);
const svfloat32_t c1 = svdup_n_f32(0.5f);
const svfloat32_t inv_ln2 = svdup_n_f32(0x1.715476p+0f);
const float shift = 0x1.803f8p17f;
/* n = round(x/(ln2/N)). */
svfloat32_t z = svmad_x (svptrue_b32(), inv_ln2, *this, shift);
svfloat32_t n = svsub_x (svptrue_b32(), z, shift);
/* r = x - n*ln2/N. */
svfloat32_t r = *this;
r = svmls_x(svptrue_b32(), r, n, ln2_hi);
r = svmls_x(svptrue_b32(), r, n, ln2_lo);
/* scale = 2^(n/N). */
svfloat32_t scale = svexpa (svreinterpret_u32 (z));
/* poly(r) = exp(r) - 1 ~= r + 0.5 r^2. */
svfloat32_t r2 = svmul_x (svptrue_b32 (), r, r);
svfloat32_t poly = svmla_x(svptrue_b32(), r, r2, c1);
return svmla_x (svptrue_b32(), scale, scale, poly);
}
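Note: the AOR scheme splits x as n*ln2 + r with |r| <= ln2/2, gets 2^n from the FEXPA instruction (svexpa) via the shift constant, and approximates exp(r) - 1 by r + r^2/2. A scalar reference of the same steps, for illustration only (assumes <cmath>; exp2f stands in for the FEXPA bit trick):
float exp_u20_ref(float x) {
  const float inv_ln2 = 0x1.715476p+0f;
  const float ln2_hi = 0x1.62e4p-1f, ln2_lo = 0x1.7f7d1cp-20f;
  float n = nearbyintf(x * inv_ln2);        // n = round(x / ln2)
  float r = (x - n * ln2_hi) - n * ln2_lo;  // reduced argument
  float scale = exp2f(n);                   // 2^n
  float poly = r + 0.5f * r * r;            // exp(r) - 1 ~= r + r^2/2
  return scale + scale * poly;              // exp(x) = scale * (1 + poly)
}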
Vectorized<float> fexp_u20() const {
return exp();
return exp_u20();
}
Vectorized<float> fmod(const Vectorized<float>& q) const {USE_SLEEF(
{ return Vectorized<float>(Sleef_fmodfx_sve(values, q)); },
{
__at_align__ float tmp[size()];
__at_align__ float tmp_q[size()];
store(tmp);
q.store(tmp_q);
for (int64_t i = 0; i < size(); ++i) {
tmp[i] = std::fmod(tmp[i], tmp_q[i]);
}
return loadu(tmp);
})} Vectorized<float> hypot(const Vectorized<float>& b) const {
USE_SLEEF(
{ return Vectorized<float>(Sleef_hypotfx_u05sve(values, b)); },
{
__at_align__ float tmp[size()];
__at_align__ float tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = std::hypot(tmp[i], tmp_b[i]);
}
return loadu(tmp);
})} Vectorized<float> i0() const {
inline Vectorized<float> fmod(const Vectorized<float>& q) const {
return USE_SLEEF(Sleef_fmodfx_sve(*this, q), map2(std::fmod, q));
}
inline Vectorized<float> hypot(const Vectorized<float> &b) const {
return USE_SLEEF(Sleef_hypotfx_u05sve(*this, b), map2(std::hypot, b));
}
inline Vectorized<float> i0() const {
return map(calc_i0);
}
Vectorized<float> i0e() const {
return map(calc_i0e);
inline Vectorized<float> i0e() const {
return map(calc_i0e<float>);
}
Vectorized<float> digamma() const {
inline Vectorized<float> digamma() const {
return map(calc_digamma);
}
Vectorized<float> igamma(const Vectorized<float>& x) const {
__at_align__ float tmp[size()];
__at_align__ float tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
}
return loadu(tmp);
inline Vectorized<float> igamma(const Vectorized<float> &x) const {
return map2(calc_igamma<float>, x);
}
Vectorized<float> igammac(const Vectorized<float>& x) const {
__at_align__ float tmp[size()];
__at_align__ float tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
}
return loadu(tmp);
inline Vectorized<float> igammac(const Vectorized<float> &x) const {
return map2(calc_igammac<float>, x);
}
Vectorized<float> nextafter(const Vectorized<float>& b) const {USE_SLEEF(
{ return Vectorized<float>(Sleef_nextafterfx_sve(values, b)); },
{
__at_align__ float tmp[size()];
__at_align__ float tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (int64_t i = 0; i < size(); ++i) {
tmp[i] = std::nextafter(tmp[i], tmp_b[i]);
}
return loadu(tmp);
})} Vectorized<float> log() const {
return USE_SLEEF(
Vectorized<float>(Sleef_logfx_u10sve(values)), map(std::log));
inline Vectorized<float> nextafter(const Vectorized<float> &b) const {
return USE_SLEEF(Sleef_nextafterfx_sve(*this, b), map2(std::nextafter, b));
}
Vectorized<float> log2() const {
return USE_SLEEF(
Vectorized<float>(Sleef_log2fx_u10sve(values)), map(std::log2));
inline Vectorized<float> log() const {
return USE_SLEEF(Sleef_logfx_u10sve(*this), map(std::log));
}
Vectorized<float> log10() const {
return USE_SLEEF(
Vectorized<float>(Sleef_log10fx_u10sve(values)), map(std::log10));
inline Vectorized<float> log2() const {
return USE_SLEEF(Sleef_log2fx_u10sve(*this), map(std::log2));
}
Vectorized<float> log1p() const {
return USE_SLEEF(
Vectorized<float>(Sleef_log1pfx_u10sve(values)), map(std::log1p));
inline Vectorized<float> log10() const {
return USE_SLEEF(Sleef_log10fx_u10sve(*this), map(std::log10));
}
Vectorized<float> frac() const;
Vectorized<float> sin() const {
return USE_SLEEF(
Vectorized<float>(Sleef_sinfx_u10sve(values)), map(std::sin));
inline Vectorized<float> log1p() const {
return USE_SLEEF(Sleef_log1pfx_u10sve(*this), map(std::log1p));
}
Vectorized<float> sinh() const {
return USE_SLEEF(
Vectorized<float>(Sleef_sinhfx_u10sve(values)), map(std::sinh));
inline Vectorized<float> frac() const;
inline Vectorized<float> sin() const {
return USE_SLEEF(Sleef_sinfx_u10sve(*this), map(std::sin));
}
Vectorized<float> cos() const {
return USE_SLEEF(
Vectorized<float>(Sleef_cosfx_u10sve(values)), map(std::cos));
inline Vectorized<float> sinh() const {
return USE_SLEEF(Sleef_sinhfx_u10sve(*this), map(std::sinh));
}
Vectorized<float> cosh() const {
return USE_SLEEF(
Vectorized<float>(Sleef_coshfx_u10sve(values)), map(std::cosh));
inline Vectorized<float> cos() const {
return USE_SLEEF(Sleef_cosfx_u10sve(*this), map(std::cos));
}
Vectorized<float> ceil() const {
return svrintp_f32_x(ptrue, values);
inline Vectorized<float> cosh() const {
return USE_SLEEF(Sleef_coshfx_u10sve(*this), map(std::cosh));
}
Vectorized<float> floor() const {
return svrintm_f32_x(ptrue, values);
inline Vectorized<float> ceil() const {
return svrintp_f32_x(ptrue, *this);
}
Vectorized<float> neg() const {
return svneg_f32_x(ptrue, values);
inline Vectorized<float> floor() const {
return svrintm_f32_x(ptrue, *this);
}
Vectorized<float> round() const {
return svrinti_f32_x(ptrue, values);
inline Vectorized<float> neg() const {
return svneg_f32_x(ptrue, *this);
}
Vectorized<float> tan() const {
return USE_SLEEF(
Vectorized<float>(Sleef_tanfx_u10sve(values)), map(std::tan));
inline Vectorized<float> round() const {
return svrinti_f32_x(ptrue, *this);
}
inline Vectorized<float> tan() const {
return USE_SLEEF(Sleef_tanfx_u10sve(*this), map(std::tan));
}
// Implementation is picked from
// https://github.com/ARM-software/ComputeLibrary/blob/v25.01/src/core/NEON/SVEMath.inl#L179
Vectorized<float> tanh() const {
inline Vectorized<float> tanh() const {
// Constants used for the tanh calculation.
const svfloat32_t CONST_1 =
svdup_n_f32(1.f); // Constant 1.0f for the tanh formula.
@ -450,7 +421,7 @@ class Vectorized<float> {
// instability. svmax_f32_z ensures values are greater than -10, and
// svmin_f32_z ensures they are less than 10.
svfloat32_t x = svmin_f32_z(
ptrue, svmax_f32_z(ptrue, values, CONST_MIN_TANH), CONST_MAX_TANH);
ptrue, svmax_f32_z(ptrue, *this, CONST_MIN_TANH), CONST_MAX_TANH);
// Step 2: Calculate exp(2 * x), where x is the clamped value.
// svmul_f32_z computes 2 * x, and svexp_f32_z computes the exponential of
@ -472,104 +443,85 @@ class Vectorized<float> {
// Return the calculated tanh values.
return tanh;
}
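Note: the tanh path clamps the input to [-10, 10] and evaluates (e^{2x} - 1) / (e^{2x} + 1). The same computation as a scalar sketch, for illustration only (assumes <cmath>):
float tanh_ref(float x) {
  x = fminf(fmaxf(x, -10.f), 10.f);   // clamp so exp(2x) stays finite
  float e2x = expf(2.f * x);
  return (e2x - 1.f) / (e2x + 1.f);   // tanh(x) = (e^{2x} - 1) / (e^{2x} + 1)
}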
Vectorized<float> trunc() const {
return svrintz_f32_x(ptrue, values);
inline Vectorized<float> trunc() const {
return svrintz_f32_x(ptrue, *this);
}
Vectorized<float> lgamma() const {
return USE_SLEEF(
Vectorized<float>(Sleef_lgammafx_u10sve(values)), map(std::lgamma));
inline Vectorized<float> lgamma() const {
return USE_SLEEF(Sleef_lgammafx_u10sve(*this), map(std::lgamma));
}
Vectorized<float> sqrt() const {
return svsqrt_f32_x(ptrue, values);
inline Vectorized<float> sqrt() const {
return svsqrt_f32_x(ptrue, *this);
}
Vectorized<float> reciprocal() const {
return svdivr_f32_x(ptrue, values, ONE_F32);
inline Vectorized<float> reciprocal() const {
return svdivr_f32_x(ptrue, *this, svdup_n_f32(1.f));
}
Vectorized<float> rsqrt() const {
return svdivr_f32_x(ptrue, svsqrt_f32_x(ptrue, values), ONE_F32);
inline Vectorized<float> rsqrt() const {
return svdivr_f32_x(ptrue, svsqrt_f32_x(ptrue, *this), ONE_F32);
}
Vectorized<float> pow(const Vectorized<float>& b) const {USE_SLEEF(
{ return Vectorized<float>(Sleef_powfx_u10sve(values, b)); },
{
__at_align__ float tmp[size()];
__at_align__ float tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = std::pow(tmp[i], tmp_b[i]);
}
return loadu(tmp);
})} // Comparison using the _CMP_**_OQ predicate.
// `O`: get false if an operand is NaN
// `Q`: do not raise if an operand is NaN
Vectorized<float> operator==(const Vectorized<float>& other) const {
svbool_t mask = svcmpeq_f32(ptrue, values, other);
inline Vectorized<float> pow(const Vectorized<float> &b) const {
return USE_SLEEF(Sleef_powfx_u10sve(*this, b), map2(std::pow, b));
}
// Comparison using the _CMP_**_OQ predicate.
// `O`: get false if an operand is NaN
// `Q`: do not raise if an operand is NaN
inline Vectorized<float> operator==(const Vectorized<float>& other) const {
svbool_t mask = svcmpeq_f32(ptrue, *this, other);
return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
}
inline Vectorized<float> operator!=(const Vectorized<float>& other) const {
svbool_t mask = svcmpne_f32(ptrue, *this, other);
return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
}
inline Vectorized<float> operator<(const Vectorized<float>& other) const {
svbool_t mask = svcmplt_f32(ptrue, *this, other);
return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
}
Vectorized<float> operator!=(const Vectorized<float>& other) const {
svbool_t mask = svcmpne_f32(ptrue, values, other);
inline Vectorized<float> operator<=(const Vectorized<float>& other) const {
svbool_t mask = svcmple_f32(ptrue, *this, other);
return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
}
Vectorized<float> operator<(const Vectorized<float>& other) const {
svbool_t mask = svcmplt_f32(ptrue, values, other);
inline Vectorized<float> operator>(const Vectorized<float>& other) const {
svbool_t mask = svcmpgt_f32(ptrue, *this, other);
return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
}
Vectorized<float> operator<=(const Vectorized<float>& other) const {
svbool_t mask = svcmple_f32(ptrue, values, other);
inline Vectorized<float> operator>=(const Vectorized<float>& other) const {
svbool_t mask = svcmpge_f32(ptrue, *this, other);
return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
}
Vectorized<float> operator>(const Vectorized<float>& other) const {
svbool_t mask = svcmpgt_f32(ptrue, values, other);
return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
}
Vectorized<float> operator>=(const Vectorized<float>& other) const {
svbool_t mask = svcmpge_f32(ptrue, values, other);
return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
}
Vectorized<float> eq(const Vectorized<float>& other) const;
Vectorized<float> ne(const Vectorized<float>& other) const;
Vectorized<float> gt(const Vectorized<float>& other) const;
Vectorized<float> ge(const Vectorized<float>& other) const;
Vectorized<float> lt(const Vectorized<float>& other) const;
Vectorized<float> le(const Vectorized<float>& other) const;
inline Vectorized<float> eq(const Vectorized<float>& other) const;
inline Vectorized<float> ne(const Vectorized<float>& other) const;
inline Vectorized<float> gt(const Vectorized<float>& other) const;
inline Vectorized<float> ge(const Vectorized<float>& other) const;
inline Vectorized<float> lt(const Vectorized<float>& other) const;
inline Vectorized<float> le(const Vectorized<float>& other) const;
};
template <>
Vectorized<float> inline operator+(
const Vectorized<float>& a,
const Vectorized<float>& b) {
inline Vectorized<float> operator+(const Vectorized<float>& a, const Vectorized<float>& b) {
return svadd_f32_x(ptrue, a, b);
}
template <>
Vectorized<float> inline operator-(
const Vectorized<float>& a,
const Vectorized<float>& b) {
inline Vectorized<float> operator-(const Vectorized<float>& a, const Vectorized<float>& b) {
return svsub_f32_x(ptrue, a, b);
}
template <>
Vectorized<float> inline operator*(
const Vectorized<float>& a,
const Vectorized<float>& b) {
inline Vectorized<float> operator*(const Vectorized<float>& a, const Vectorized<float>& b) {
return svmul_f32_x(ptrue, a, b);
}
template <>
Vectorized<float> inline operator/(
const Vectorized<float>& a,
const Vectorized<float>& b) {
inline Vectorized<float> operator/(const Vectorized<float>& a, const Vectorized<float>& b) {
return svdiv_f32_x(ptrue, a, b);
}
// frac. Implement this here so we can use subtraction
Vectorized<float> inline Vectorized<float>::frac() const {
inline Vectorized<float> Vectorized<float>::frac() const {
return *this - this->trunc();
}
@ -585,115 +537,91 @@ Vectorized<float> inline maximum(
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<float> inline minimum(
const Vectorized<float>& a,
const Vectorized<float>& b) {
inline Vectorized<float> minimum(const Vectorized<float>& a, const Vectorized<float>& b) {
return svmin_f32_x(ptrue, a, b);
}
template <>
Vectorized<float> inline clamp(
const Vectorized<float>& a,
const Vectorized<float>& min,
const Vectorized<float>& max) {
inline Vectorized<float> clamp(const Vectorized<float>& a, const Vectorized<float>& min, const Vectorized<float>& max) {
return svmin_f32_x(ptrue, max, svmax_f32_x(ptrue, min, a));
}
template <>
Vectorized<float> inline clamp_max(
const Vectorized<float>& a,
const Vectorized<float>& max) {
inline Vectorized<float> clamp_max(const Vectorized<float>& a, const Vectorized<float>& max) {
return svmin_f32_x(ptrue, max, a);
}
template <>
Vectorized<float> inline clamp_min(
const Vectorized<float>& a,
const Vectorized<float>& min) {
inline Vectorized<float> clamp_min(const Vectorized<float>& a, const Vectorized<float>& min) {
return svmax_f32_x(ptrue, min, a);
}
template <>
Vectorized<float> inline operator&(
const Vectorized<float>& a,
const Vectorized<float>& b) {
return svreinterpret_f32_s32(
svand_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b)));
inline Vectorized<float> operator&(const Vectorized<float>& a, const Vectorized<float>& b) {
return svreinterpret_f32_s32(svand_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b)));
}
template <>
Vectorized<float> inline operator|(
const Vectorized<float>& a,
const Vectorized<float>& b) {
return svreinterpret_f32_s32(
svorr_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b)));
inline Vectorized<float> operator|(const Vectorized<float>& a, const Vectorized<float>& b) {
return svreinterpret_f32_s32(svorr_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b)));
}
template <>
Vectorized<float> inline operator^(
const Vectorized<float>& a,
const Vectorized<float>& b) {
return svreinterpret_f32_s32(
sveor_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b)));
inline Vectorized<float> operator^(const Vectorized<float>& a, const Vectorized<float>& b) {
return svreinterpret_f32_s32(sveor_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b)));
}
Vectorized<float> inline Vectorized<float>::eq(
const Vectorized<float>& other) const {
inline Vectorized<float> Vectorized<float>::eq(const Vectorized<float>& other) const {
return (*this == other) & Vectorized<float>(1.0f);
}
Vectorized<float> inline Vectorized<float>::ne(
const Vectorized<float>& other) const {
inline Vectorized<float> Vectorized<float>::ne(const Vectorized<float>& other) const {
return (*this != other) & Vectorized<float>(1.0f);
}
Vectorized<float> inline Vectorized<float>::gt(
const Vectorized<float>& other) const {
inline Vectorized<float> Vectorized<float>::gt(const Vectorized<float>& other) const {
return (*this > other) & Vectorized<float>(1.0f);
}
Vectorized<float> inline Vectorized<float>::ge(
const Vectorized<float>& other) const {
inline Vectorized<float> Vectorized<float>::ge(const Vectorized<float>& other) const {
return (*this >= other) & Vectorized<float>(1.0f);
}
Vectorized<float> inline Vectorized<float>::lt(
const Vectorized<float>& other) const {
inline Vectorized<float> Vectorized<float>::lt(const Vectorized<float>& other) const {
return (*this < other) & Vectorized<float>(1.0f);
}
Vectorized<float> inline Vectorized<float>::le(
const Vectorized<float>& other) const {
inline Vectorized<float> Vectorized<float>::le(const Vectorized<float>& other) const {
return (*this <= other) & Vectorized<float>(1.0f);
}
template <>
inline void convert(const float* src, float* dst, int64_t n) {
const int64_t fraction = n % Vectorized<float>::size();
const int64_t fraction = n % svcntw();
#pragma unroll
for (int64_t i = 0; i < n - fraction; i += Vectorized<float>::size()) {
for (int64_t i = 0; i < n - fraction; i += svcntw()) {
svst1_f32(ptrue, dst + i, svldnt1_f32(ptrue, src + i));
}
#pragma unroll
for (int64_t i = n - fraction; i < n; i += Vectorized<float>::size()) {
for (int64_t i = n - fraction; i < n; i += svcntw()) {
svbool_t pg = svwhilelt_b32(i, n);
svst1_f32(pg, dst + i, svldnt1_f32(pg, src + i));
}
}
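Note: these converters run full vectors under an all-true predicate and finish with svwhilelt, which activates only the remaining n % svcntw() lanes, so no scalar tail loop is needed. A minimal sketch of the pattern, for illustration only (assumes <arm_sve.h>):
void copy_f32(const float* src, float* dst, int64_t n) {
  int64_t i = 0;
  for (; i + static_cast<int64_t>(svcntw()) <= n; i += svcntw())
    svst1_f32(svptrue_b32(), dst + i, svld1_f32(svptrue_b32(), src + i));
  svbool_t pg = svwhilelt_b32(i, n);  // predicate covers only the leftover lanes
  svst1_f32(pg, dst + i, svld1_f32(pg, src + i));
}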
template <>
inline void convert(const float* src, at::Half* dst, int64_t n) {
const int64_t fraction = n % Vectorized<float>::size();
svbool_t pg_16 = svwhilelt_b16(0ull, Vectorized<float>::size());
svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized<float>::size());
inline void convert(const float *src, at::Half *dst, int64_t n) {
const int64_t fraction = n % svcntw();
svbool_t pg_16 = svwhilelt_b16(0ull, svcntw());
svbool_t pg_32 = svwhilelt_b32(0ull, svcntw());
#pragma unroll
for (int64_t i = 0; i < n - fraction; i += Vectorized<float>::size()) {
svfloat16_t src_vec = svuzp1_f16(
svcvt_f16_f32_x(ptrue, svldnt1_f32(pg_32, src + i)), ZERO_F16);
for (int64_t i = 0; i < n - fraction; i += svcntw()) {
svfloat16_t src_vec = svuzp1_f16(svcvt_f16_f32_x(ptrue, svldnt1_f32(pg_32, src + i)),
ZERO_F16);
svst1_f16(pg_16, reinterpret_cast<float16_t*>(dst) + i, src_vec);
}
#pragma unroll
for (int64_t i = n - fraction; i < n; i += Vectorized<float>::size()) {
for (int64_t i = n - fraction; i < n; i += svcntw()) {
pg_16 = svwhilelt_b16(i, n);
pg_32 = svwhilelt_b32(i, n);
svfloat16_t src_vec = svuzp1_f16(
@ -703,19 +631,18 @@ inline void convert(const float* src, at::Half* dst, int64_t n) {
}
template <>
inline void convert(const at::Half* src, float* dst, int64_t n) {
const int64_t fraction = n % Vectorized<float>::size();
svbool_t pg_16 = svwhilelt_b16(0ull, Vectorized<float>::size());
svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized<float>::size());
inline void convert(const at::Half *src, float *dst, int64_t n) {
const int64_t fraction = n % svcntw();
svbool_t pg_16 = svwhilelt_b16(0ull, svcntw());
svbool_t pg_32 = svwhilelt_b32(0ull, svcntw());
#pragma unroll
for (int64_t i = 0; i < n - fraction; i += Vectorized<float>::size()) {
svfloat16_t src_vec = svzip1_f16(
svldnt1_f16(pg_16, reinterpret_cast<const float16_t*>(src) + i),
ZERO_F16);
for (int64_t i = 0; i < n - fraction; i += svcntw()) {
svfloat16_t src_vec = svzip1_f16(svldnt1_f16(pg_16, reinterpret_cast<const float16_t*>(src) + i),
ZERO_F16);
svst1_f32(pg_32, dst + i, svcvt_f32_f16_x(ptrue, src_vec));
}
#pragma unroll
for (int64_t i = n - fraction; i < n; i += Vectorized<float>::size()) {
for (int64_t i = n - fraction; i < n; i += svcntw()) {
pg_16 = svwhilelt_b16(i, n);
pg_32 = svwhilelt_b32(i, n);
svfloat16_t src_vec = svzip1_f16(
@ -726,20 +653,19 @@ inline void convert(const at::Half* src, float* dst, int64_t n) {
}
template <>
inline void convert(const bool* src, float* dst, int64_t n) {
const int64_t fraction = n % Vectorized<float>::size();
svbool_t pg_8 = svwhilelt_b8(0ull, Vectorized<float>::size());
svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized<float>::size());
inline void convert(const bool *src, float *dst, int64_t n) {
const int64_t fraction = n % svcntw();
svbool_t pg_8 = svwhilelt_b8(0ull, svcntw());
svbool_t pg_32 = svwhilelt_b32(0ull, svcntw());
#pragma unroll
for (int64_t i = 0; i < n - fraction; i += Vectorized<float>::size()) {
svuint8_t src_vec_u8 =
svldnt1_u8(pg_8, reinterpret_cast<const uint8_t*>(src) + i);
for (int64_t i = 0; i < n - fraction; i += svcntw()) {
svuint8_t src_vec_u8 = svldnt1_u8(pg_8, reinterpret_cast<const uint8_t*>(src) + i);
svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8));
svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32);
svst1_f32(pg_32, dst + i, svsel_f32(mask, ONE_F32, ZERO_F32));
}
#pragma unroll
for (int64_t i = n - fraction; i < n; i += Vectorized<float>::size()) {
for (int64_t i = n - fraction; i < n; i += svcntw()) {
pg_8 = svwhilelt_b8(i, n);
pg_32 = svwhilelt_b32(i, n);
svuint8_t src_vec_u8 =
@ -751,10 +677,7 @@ inline void convert(const bool* src, float* dst, int64_t n) {
}
template <>
Vectorized<float> inline fmadd(
const Vectorized<float>& a,
const Vectorized<float>& b,
const Vectorized<float>& c) {
inline Vectorized<float> fmadd(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
return svmad_f32_x(ptrue, a, b, c);
}
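Note: a hypothetical usage sketch of the length-agnostic API in this header (runtime size(), count-taking loadu/store); the function name and the clamp-at-zero example are illustrative and not part of the PR:
void relu_scale(const float* in, float* out, int64_t n, float s) {
  using Vec = at::vec::Vectorized<float>;
  const Vec zero(0.f), scale(s);
  int64_t i = 0;
  for (; i + Vec::size() <= n; i += Vec::size())
    at::vec::clamp_min(Vec::loadu(in + i) * scale, zero).store(out + i);
  int rest = static_cast<int>(n - i);
  if (rest > 0)  // partial load/store with an explicit lane count
    at::vec::clamp_min(Vec::loadu(in + i, rest) * scale, zero).store(out + i, rest);
}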

View File

@ -15,7 +15,7 @@ namespace at::vec {
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_SVE)
#if defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE)
#define VEC_INT_SVE_TEMPLATE(vl, bit) \
template <> \
@ -49,10 +49,11 @@ inline namespace CPU_CAPABILITY {
operator svint##bit##_t() const { \
return values; \
} \
template <uint64_t mask> \
static Vectorized<int##bit##_t> blend( \
const Vectorized<int##bit##_t>& a, \
const Vectorized<int##bit##_t>& b) { \
const Vectorized<int##bit##_t>& b, \
uint64_t mask \
) { \
__at_align__ int##bit##_t flag_arr[size()]; \
for (int i = 0; i < size(); ++i) { \
flag_arr[i] = (i < 64 && (mask & (1ULL << i))) ? 1 : 0; \
@ -493,7 +494,7 @@ Vectorized<int8_t> inline operator>>(
return svasr_s8_x(ptrue, a, svreinterpret_u8_s8(b));
}
#endif // defined(CPU_CAPABILITY_SVE)
#endif // defined(CPU_CAPABILITY_SVE256)
} // namespace CPU_CAPABILITY
} // namespace at::vec

View File

@ -46,7 +46,7 @@ namespace at::vec {
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_SVE)
#if defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE)
// NOTE: These are low-performance implementations that we fall back on
// if we are not building with SVE. This may not be an issue, because
@ -100,12 +100,12 @@ struct VectorizedQuantizedConverter {
Vectorized<float> zero_point,
Vectorized<float> scale_zp_premul) const {
float_vec_return_type rv;
float tmp_scale[Vectorized<float>::size()];
float tmp_zero_point[Vectorized<float>::size()];
float * tmp_scale = new float[Vectorized<float>::size()];
float * tmp_zero_point = new float[Vectorized<float>::size()];
scale.store(tmp_scale);
zero_point.store(tmp_zero_point);
for (int i = 0; i < float_num_vecs(); ++i) {
float tmp_vals[Vectorized<float>::size()];
float * tmp_vals = new float[Vectorized<float>::size()];
for (int j = 0; j < Vectorized<float>::size(); ++j) {
tmp_vals[j] = at::native::dequantize_val<T>(
tmp_scale[j],
@ -113,7 +113,11 @@ struct VectorizedQuantizedConverter {
T(vals[Vectorized<float>::size() * i + j]));
}
rv[i] = Vectorized<float>::loadu(tmp_vals);
delete[] tmp_vals;
}
delete[] tmp_scale;
delete[] tmp_zero_point;
return rv;
}
@ -121,12 +125,12 @@ struct VectorizedQuantizedConverter {
Vectorized<float> scale,
Vectorized<float> zero_point) const {
float_vec_return_type rv;
float tmp_scale[Vectorized<float>::size()];
float tmp_zero_point[Vectorized<float>::size()];
float * tmp_scale = new float[Vectorized<float>::size()];
float * tmp_zero_point = new float[Vectorized<float>::size()];
scale.store(tmp_scale);
zero_point.store(tmp_zero_point);
for (int i = 0; i < float_num_vecs(); ++i) {
float tmp_vals[Vectorized<float>::size()];
float * tmp_vals = new float[Vectorized<float>::size()];
for (int j = 0; j < Vectorized<float>::size(); ++j) {
tmp_vals[j] = at::native::dequantize_val<T>(
tmp_scale[j],
@ -134,7 +138,10 @@ struct VectorizedQuantizedConverter {
T(vals[Vectorized<float>::size() * i + j]));
}
rv[i] = Vectorized<float>::loadu(tmp_vals);
delete[] tmp_vals;
}
delete[] tmp_scale;
delete[] tmp_zero_point;
return rv;
}
@ -205,7 +212,7 @@ struct Vectorized<c10::qint32> : public VectorizedQuantizedConverter<
int32_t zero_point,
float inverse_scale) {
std::array<value_type, size()> qvals;
std::array<float, float_num_vecs() * Vectorized<float>::size()> float_vals;
float * float_vals = new float[float_num_vecs() * Vectorized<float>::size()];
for (int i = 0; i < float_num_vecs(); ++i) {
rhs[i].store(
@ -216,10 +223,11 @@ struct Vectorized<c10::qint32> : public VectorizedQuantizedConverter<
at::native::quantize_vec<c10::qint32, /*precision=*/32>(
scale,
zero_point,
float_vals.data(),
float_vals,
(c10::qint32*)qvals.data(),
Vectorized<float>::size() * float_num_vecs());
delete[] float_vals;
return Vectorized<c10::qint32>::loadu(qvals.data());
}
@ -359,7 +367,7 @@ struct Vectorized<c10::qint8> : public VectorizedQuantizedConverter<
int32_t zero_point,
float inverse_scale) {
std::array<value_type, size()> qvals;
std::array<float, float_num_vecs() * Vectorized<float>::size()> float_vals;
float * float_vals = new float[float_num_vecs() * Vectorized<float>::size()];
for (int i = 0; i < float_num_vecs(); ++i) {
rhs[i].store(
@ -370,10 +378,11 @@ struct Vectorized<c10::qint8> : public VectorizedQuantizedConverter<
at::native::quantize_vec<c10::qint8>(
scale,
zero_point,
float_vals.data(),
float_vals,
(c10::qint8*)qvals.data(),
Vectorized<float>::size() * float_num_vecs());
delete[] float_vals;
return Vectorized<c10::qint8>::loadu(qvals.data());
}
@ -511,7 +520,7 @@ struct Vectorized<c10::quint8> : public VectorizedQuantizedConverter<
int32_t zero_point,
float inverse_scale) {
std::array<value_type, size()> qvals;
std::array<float, float_num_vecs() * Vectorized<float>::size()> float_vals;
float * float_vals = new float[float_num_vecs() * Vectorized<float>::size()];
for (int i = 0; i < float_num_vecs(); ++i) {
rhs[i].store(
@ -522,10 +531,11 @@ struct Vectorized<c10::quint8> : public VectorizedQuantizedConverter<
at::native::quantize_vec<c10::quint8>(
scale,
zero_point,
float_vals.data(),
float_vals,
(c10::quint8*)qvals.data(),
Vectorized<float>::size() * float_num_vecs());
delete[] float_vals;
return Vectorized<c10::quint8>::loadu(qvals.data());
}
@ -600,7 +610,7 @@ Vectorized<c10::quint8> inline maximum(
return a.maximum(b);
}
#endif // defined(CPU_CAPABILITY_SVE)
#endif // defined(CPU_CAPABILITY_SVE256)
} // namespace CPU_CAPABILITY
} // namespace at::vec

View File

@ -4,7 +4,9 @@
#include <ATen/cpu/vec/intrinsics.h>
#ifdef __aarch64__
#if !defined(CPU_CAPABILITY_SVE)
#if defined(CPU_CAPABILITY_SVE) || defined(CPU_CAPABILITY_SVE256)
#include <ATen/cpu/vec/sve/vec_common_sve.h>
#else
#include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h>
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>

View File

@ -241,7 +241,7 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
Vectorized() = default;
Vectorized(c10::BFloat16 val)
: Vectorized16(at_vdupq_n_bf16(c10::bit_cast<at_bfloat16_t>(val.x))) {}
: Vectorized16(at_vdupq_n_bf16(val.x)) {}
Vectorized(float val) : Vectorized(c10::BFloat16(val)) {}
Vectorized(
value_type val0,
@ -253,14 +253,14 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
value_type val6,
value_type val7)
: Vectorized16(at_bfloat16x8_t{
c10::bit_cast<at_bfloat16_t>(val0.x),
c10::bit_cast<at_bfloat16_t>(val1.x),
c10::bit_cast<at_bfloat16_t>(val2.x),
c10::bit_cast<at_bfloat16_t>(val3.x),
c10::bit_cast<at_bfloat16_t>(val4.x),
c10::bit_cast<at_bfloat16_t>(val5.x),
c10::bit_cast<at_bfloat16_t>(val6.x),
c10::bit_cast<at_bfloat16_t>(val7.x)}) {}
val0.x,
val1.x,
val2.x,
val3.x,
val4.x,
val5.x,
val6.x,
val7.x}) {}
static Vectorized<c10::BFloat16> blendv(
const Vectorized<c10::BFloat16>& a,

View File

@ -4,7 +4,7 @@
namespace at::vec {
inline namespace CPU_CAPABILITY {
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE))
template <typename src_t>
struct VecConvert<
float,

View File

@ -41,32 +41,16 @@ inline namespace CPU_CAPABILITY {
#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code
#endif
template <int index, bool mask_val>
template <int index>
struct BlendRegs {
static float32x4_t impl(
const float32x4_t& a,
const float32x4_t& b,
float32x4_t& res);
};
template <int index>
struct BlendRegs<index, true> {
static float32x4_t impl(
const float32x4_t& a,
const float32x4_t& b,
float32x4_t& res) {
return vsetq_lane_f32(vgetq_lane_f32(b, index), res, index);
}
};
template <int index>
struct BlendRegs<index, false> {
static float32x4_t impl(
const float32x4_t& a,
const float32x4_t& b,
float32x4_t& res) {
return vsetq_lane_f32(vgetq_lane_f32(a, index), res, index);
}
float32x4_t& res,
bool mask_val
) {
return vsetq_lane_f32(vgetq_lane_f32(mask_val ? b : a, index), res, index);
}
};
template <>
@ -94,19 +78,15 @@ class Vectorized<float> {
operator float32x4_t() const {
return values;
}
template <int64_t mask>
static Vectorized<float> blend(
const Vectorized<float>& a,
const Vectorized<float>& b) {
const Vectorized<float>& b,
int64_t mask) {
Vectorized<float> vec;
vec.values = BlendRegs < 0,
(mask & 0x01) != 0 > ::impl(a.values, b.values, vec.values);
vec.values = BlendRegs < 1,
(mask & 0x02) != 0 > ::impl(a.values, b.values, vec.values);
vec.values = BlendRegs < 2,
(mask & 0x04) != 0 > ::impl(a.values, b.values, vec.values);
vec.values = BlendRegs < 3,
(mask & 0x08) != 0 > ::impl(a.values, b.values, vec.values);
vec.values = BlendRegs <0>::impl(a.values, b.values, vec.values, (mask & 0x01) != 0);
vec.values = BlendRegs <1> ::impl(a.values, b.values, vec.values, (mask & 0x02) != 0);
vec.values = BlendRegs <2> ::impl(a.values, b.values, vec.values, (mask & 0x04) != 0);
vec.values = BlendRegs <3> ::impl(a.values, b.values, vec.values, (mask & 0x08) != 0);
return vec;
}
static Vectorized<float> blendv(
@ -307,11 +287,48 @@ class Vectorized<float> {
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp)
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp2)
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1)
// Implementation copied from Arm Optimized Routines: https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c
Vectorized<float> exp_u20() const {
return exp();
// bail out to sleef if it's a special case:
// i.e. there's an input s.t. |input| > 87.3....
const float32x4_t special_bound = vdupq_n_f32(0x1.5d5e2ap+6f);
uint32x4_t cmp = vcagtq_f32 (values, special_bound);
if (vpaddd_u64 (vreinterpretq_u64_u32 (cmp)) != 0) {
return exp();
}
const float32x4_t inv_ln2 = vdupq_n_f32(0x1.715476p+0f);
const float ln2_hi = 0x1.62e4p-1f;
const float ln2_lo = 0x1.7f7d1cp-20f;
const float c0 = 0x1.0e4020p-7f;
const float c2 = 0x1.555e66p-3f;
const float32x4_t ln2_c02 = {ln2_hi, ln2_lo, c0, c2};
const uint32x4_t exponent_bias = vdupq_n_u32(0x3f800000);
const float32x4_t c1 = vdupq_n_f32(0x1.573e2ep-5f);
const float32x4_t c3 = vdupq_n_f32(0x1.fffdb6p-2f);
const float32x4_t c4 = vdupq_n_f32(0x1.ffffecp-1f);
/* exp(x) = 2^n (1 + poly(r)), with 1 + poly(r) in [1/sqrt(2),sqrt(2)]
x = ln2*n + r, with r in [-ln2/2, ln2/2]. */
float32x4_t n = vrndaq_f32 (vmulq_f32 (values, inv_ln2));
float32x4_t r = vfmsq_laneq_f32 (values, n, ln2_c02, 0);
r = vfmsq_laneq_f32 (r, n, ln2_c02, 1);
uint32x4_t e = vshlq_n_u32 (vreinterpretq_u32_s32 (vcvtq_s32_f32 (n)), 23);
float32x4_t scale = vreinterpretq_f32_u32 (vaddq_u32 (e, exponent_bias));
float32x4_t r2 = vmulq_f32 (r, r);
float32x4_t p = vfmaq_laneq_f32 (c1, r, ln2_c02, 2);
float32x4_t q = vfmaq_laneq_f32 (c3, r, ln2_c02, 3);
q = vfmaq_f32 (q, p, r2);
p = vmulq_f32 (c4, r);
float32x4_t poly = vfmaq_f32 (p, q, r2);
return vfmaq_f32 (scale, poly, scale);
}
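Note: NEON has no FEXPA equivalent, so here 2^n is built by shifting n into the float exponent field and adding the bias, and exp(r) - 1 uses a longer fused-multiply-add polynomial. The same evaluation order as a scalar sketch, coefficients copied from the block above, for illustration only:
float expm1_poly_ref(float r) {
  const float c0 = 0x1.0e4020p-7f, c1 = 0x1.573e2ep-5f, c2 = 0x1.555e66p-3f,
              c3 = 0x1.fffdb6p-2f, c4 = 0x1.ffffecp-1f;
  float r2 = r * r;
  float p = c1 + r * c0;
  float q = c3 + r * c2;
  q = q + p * r2;
  p = c4 * r;
  return p + q * r2;  // ~= exp(r) - 1 for |r| <= ln2/2
}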
Vectorized<float> fexp_u20() const {
return exp();
return exp_u20();
}
DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(
fmod,

View File

@ -813,11 +813,12 @@ static inline Vectorized<T> binary_op_as_fp32(
#define LOAD_FP32_NON_VECTORIZED_INIT(type, name) \
inline void load_fp32_from_##name( \
const type* data, Vectorized<float>& out) { \
__at_align__ float values[Vectorized<float>::size()]; \
__at_align__ float * values = new float[Vectorized<float>::size()]; \
for (const auto k : c10::irange(Vectorized<float>::size())) { \
values[k] = data[k]; \
} \
out = Vectorized<float>::loadu(values); \
delete[] values; \
} \
\
inline void load_fp32_from_##name( \

View File

@ -269,12 +269,13 @@ LOAD_FP32_VECTORIZED_INIT(BFloat16, bf16)
#else // defined(CPU_CAPABILITY_AVX2)
#if !( \
defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
!defined(CPU_CAPABILITY_SVE256))
defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__))
CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16)
#endif
#if !defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE)
LOAD_FP32_NON_VECTORIZED_INIT(BFloat16, bf16)
#endif
#endif // defined(CPU_CAPABILITY_AVX2)
} // namespace CPU_CAPABILITY
} // namespace at::vec

View File

@ -294,7 +294,7 @@ struct VecConvert<
};
#endif
#if defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16)
#if (defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE)) && defined(__ARM_FEATURE_BF16)
template <>
struct VecConvert<float, 1, BFloat16, 1> {

View File

@ -270,7 +270,7 @@ LOAD_FP32_VECTORIZED_INIT(Half, fp16)
#if !( \
defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
!defined(CPU_CAPABILITY_SVE256))
!defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE))
CONVERT_NON_VECTORIZED_INIT(Half, half)
#endif

View File

@ -915,7 +915,7 @@ Vectorized<c10::quint8> inline maximum(
return a.maximum(b);
}
#elif !defined(CPU_CAPABILITY_SVE256)
#elif !defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE)
// NOTE: These are low-performance implementations that we fall back on
// if we are not building with AVX2. This may not be an issue, because
@ -1374,11 +1374,11 @@ Vectorized<c10::quint8> inline maximum(
#endif // if defined(CPU_CAPABILITY_AVX2)
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
at::vec::Vectorized<int8_t> src) {
auto s8x8 = vld1_s8(src.operator const int8_t*());
auto s16x8 = vmovl_s8(s8x8);
#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE)
std::pair<Vectorized<float>, Vectorized<float>>
inline convert_int8_to_float(at::vec::Vectorized<int8_t> src) {
auto s8x8 = vld1_s8(src.operator const int8_t*());
auto s16x8 = vmovl_s8(s8x8);
auto s32x4_hi = vmovl_s16(vget_high_s16(s16x8));
auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8));

View File

@ -292,8 +292,7 @@ class Vectorized16 {
_mm512_mask_storeu_epi16(ptr, mask, values);
}
}
template <int64_t mask>
static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b, int64_t mask) {
return _mm512_mask_blend_epi16(mask, a.values, b.values);
}
static Vectorized<T> blendv(

View File

@ -69,10 +69,10 @@ class Vectorized<c10::complex<double>> {
operator __m512d() const {
return values;
}
template <int64_t mask>
static Vectorized<c10::complex<double>> blend(
const Vectorized<c10::complex<double>>& a,
const Vectorized<c10::complex<double>>& b) {
const Vectorized<c10::complex<double>>& b,
int64_t mask) {
// convert c10::complex<V> index mask to V index mask: xy -> xxyy
// NOLINTNEXTLINE(clang-diagnostic-warning)
switch (mask) {

View File

@ -89,10 +89,10 @@ class Vectorized<c10::complex<float>> {
operator __m512() const {
return values;
}
template <int64_t mask>
static Vectorized<c10::complex<float>> blend(
const Vectorized<c10::complex<float>>& a,
const Vectorized<c10::complex<float>>& b) {
const Vectorized<c10::complex<float>>& b,
int64_t mask) {
// convert c10::complex<V> index mask to V index mask: xy -> xxyy
static_assert(mask > -1 && mask < 256, "Unexpected mask value");
// The compiler would hopefully convert this switch condition

View File

@ -55,10 +55,10 @@ class Vectorized<double> {
operator __m512d() const {
return values;
}
template <int64_t mask>
static Vectorized<double> blend(
const Vectorized<double>& a,
const Vectorized<double>& b) {
const Vectorized<double>& b,
int64_t mask) {
return _mm512_mask_blend_pd(mask, a.values, b.values);
}
static Vectorized<double> blendv(

View File

@ -95,10 +95,10 @@ class Vectorized<float> {
operator __m512() const {
return values;
}
template <int64_t mask>
static Vectorized<float> blend(
const Vectorized<float>& a,
const Vectorized<float>& b) {
const Vectorized<float>& b,
int64_t mask) {
return _mm512_mask_blend_ps(mask, a.values, b.values);
}
static Vectorized<float> blendv(

View File

@ -528,10 +528,10 @@ class Vectorized<int16_t> : public Vectorizedi {
val2,
val1);
}
template <int64_t mask>
static Vectorized<int16_t> blend(
Vectorized<int16_t> a,
Vectorized<int16_t> b) {
Vectorized<int16_t> b,
int64_t mask) {
return _mm512_mask_blend_epi16(mask, a.values, b.values);
}
static Vectorized<int16_t> blendv(

View File

@ -68,7 +68,7 @@ Windows llvm will not have this definition.
#define VECTOR_WIDTH 64
#define int_vector __m512i
#elif defined(__aarch64__) && \
!defined(CPU_CAPABILITY_SVE) // CPU_CAPABILITY_AVX512
!defined(CPU_CAPABILITY_SVE) && !defined(CPU_CAPABILITY_SVE256) // CPU_CAPABILITY_AVX512
// SVE code expects 256-vectors; leave that set for SVE?
#if defined(__GNUC__)
#define __at_align__ __attribute__((aligned(16)))
@ -79,6 +79,18 @@ Windows llvm will not have this definition.
#endif
#define VECTOR_WIDTH 16
#else // CPU_CAPABILITY_AVX512
#if defined(CPU_CAPABILITY_SVE)
#if defined(__GNUC__)
#define __at_align__ __attribute__((aligned(16)))
#elif defined(_WIN32)
#define __at_align__ __declspec(align(16))
#else
#define __at_align__
#endif
#define VECTOR_WIDTH 16
#define int_vector __m256i
#else // CPU_CAPABILITY_SVE256 || CPU_CAPABILITY_SVE
#if defined(CPU_CAPABILITY_SVE256)
#if defined(__GNUC__)
#define __at_align__ __attribute__((aligned(32)))
#elif defined(_WIN32)
@ -88,6 +100,18 @@ Windows llvm will not have this definition.
#endif
#define VECTOR_WIDTH 32
#define int_vector __m256i
#else // CPU_CAPABILITY_SVE
#if defined(__GNUC__)
#define __at_align__ __attribute__((aligned(16)))
#elif defined(_WIN32)
#define __at_align__ __declspec(align(16))
#else
#define __at_align__
#endif
#define VECTOR_WIDTH 16
#define int_vector __m256i
#endif // CPU_CAPABILITY_SVE256
#endif // CPU_CAPABILITY_SVE256 || CPU_CAPABILITY_SVE
#endif // CPU_CAPABILITY_AVX512
namespace at::vec {
@ -210,8 +234,7 @@ struct Vectorized {
auto as_bytes() const -> const char* {
return reinterpret_cast<const char*>(values);
}
template <int64_t mask_>
static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b, const int64_t mask_) {
int64_t mask = mask_;
Vectorized vector;
for (const auto i : c10::irange(size())) {
@ -1312,7 +1335,7 @@ std::
T const* base_addr,
const Vectorized<int_same_size_t<T>>& vindex,
Vectorized<T>& mask) {
static constexpr int size = Vectorized<T>::size();
static const int size = Vectorized<T>::size();
T src_arr[size];
int_same_size_t<T> mask_arr[size]; // use int type so we can logical and
int_same_size_t<T> index_arr[size];
@ -1405,7 +1428,7 @@ inline Vectorized<T> convert_to_fp_of_same_size(
// clang-format on
template <typename T>
inline std::enable_if_t<
Vectorized<T>::size() % 2 == 0,
true,
std::pair<Vectorized<T>, Vectorized<T>>>
deinterleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
static constexpr int size = Vectorized<T>::size();
@ -1444,7 +1467,7 @@ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(deinterleave2)
// clang-format on
template <typename T>
inline std::enable_if_t<
Vectorized<T>::size() % 2 == 0,
true,
std::pair<Vectorized<T>, Vectorized<T>>>
interleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
static constexpr int size = Vectorized<T>::size();
@ -1486,7 +1509,7 @@ inline void convert(const src_T* src, dst_T* dst, int64_t n) {
template <typename T>
inline Vectorized<T> flip(const Vectorized<T>& data) {
static constexpr int size = Vectorized<T>::size();
static const int size = Vectorized<T>::size();
T output[size];
T buffer[size];
data.store(static_cast<void*>(buffer));
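The generic blend above now takes its mask as a runtime argument instead of a template parameter. A minimal scalar sketch of the selection rule it implements, assuming nothing beyond the standard library (bit i of the mask picks b[i], a cleared bit keeps a[i]):

```cpp
#include <array>
#include <cstdint>

template <typename T, std::size_t N>
std::array<T, N> blend(const std::array<T, N>& a,
                       const std::array<T, N>& b,
                       int64_t mask) {
  std::array<T, N> out{};
  for (std::size_t i = 0; i < N; ++i) {
    out[i] = (mask & 0x01) ? b[i] : a[i];  // bit i selects lane i of b
    mask >>= 1;
  }
  return out;
}

// e.g. blend<float, 4>({1, 2, 3, 4}, {10, 20, 30, 40}, 0b0101)
//      yields {10, 2, 30, 4}
```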

View File

@ -15,7 +15,7 @@ template <
struct VecConvert {
static inline VectorizedN<dst_t, dst_n> apply(
const VectorizedN<src_t, src_n>& src) {
constexpr int count = std::min(
const int count = std::min(
VectorizedN<src_t, src_n>::size(), VectorizedN<dst_t, dst_n>::size());
__at_align__ src_t src_buf[VectorizedN<src_t, src_n>::size()];
src.store(src_buf);

View File

@ -2,6 +2,8 @@
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec_n.h>
#include <cassert>
namespace at::vec {
inline namespace CPU_CAPABILITY {
@ -38,9 +40,9 @@ struct VecMaskLoad {
static inline VectorizedN<data_t, data_n> apply(
const data_t* ptr,
const VecMask<mask_t, mask_n>& vec_mask) {
constexpr typename VecMask<mask_t, mask_n>::size_type size =
const typename VecMask<mask_t, mask_n>::size_type size =
VecMask<mask_t, mask_n>::size();
static_assert(VectorizedN<data_t, data_n>::size() >= size);
assert((VectorizedN<data_t, data_n>::size() >= size));
__at_align__ data_t data[size];
__at_align__ mask_t mask[size];
auto mask_ = VectorizedN<mask_t, mask_n>(vec_mask);
@ -134,7 +136,7 @@ class VecMask {
template <typename U, int L>
static VecMask<T, N> from(const VectorizedN<U, L>& b_vec) {
__at_align__ U b_buf[size()];
if constexpr (size() >= VectorizedN<U, L>::size()) {
if (size() >= VectorizedN<U, L>::size()) {
b_vec.store(b_buf);
for (int i = VectorizedN<U, L>::size(); i < size(); i++) {
b_buf[i] = static_cast<U>(0);
@ -235,16 +237,18 @@ class VecMask {
template <
typename U,
int L,
std::enable_if_t<L >= 2 && VectorizedN<U, L>::size() >= size(), int> = 0>
std::enable_if_t<L >= 2, int> = 0>
VectorizedN<U, L> loadu(const U* ptr) const {
assert((VectorizedN<U, L>::size() >= size()));
return VecMaskLoad<U, L, T, N>::apply(ptr, *this);
}
template <
typename U,
int L,
std::enable_if_t<L == 1 && Vectorized<U>::size() >= size(), int> = 0>
std::enable_if_t<L == 1, int> = 0>
Vectorized<U> loadu(const U* ptr) const {
assert((Vectorized<U>::size() >= size()));
return VecMaskLoad<U, L, T, N>::apply(ptr, *this);
}
};

View File

@ -28,7 +28,7 @@ class VectorizedN {
using size_type = int;
static constexpr size_type size_T = sizeof(T);
static constexpr size_type size() {
static size_type size() {
return Vectorized<T>::size() * N;
}

View File

@ -1157,6 +1157,7 @@ REGISTER_AVX512_DISPATCH(cholesky_stub, &cholesky_kernel)
REGISTER_AVX2_DISPATCH(cholesky_stub, &cholesky_kernel)
REGISTER_VSX_DISPATCH(cholesky_stub, &cholesky_kernel)
REGISTER_ZVECTOR_DISPATCH(cholesky_stub, &cholesky_kernel)
REGISTER_SVE_DISPATCH(cholesky_stub, &cholesky_kernel)
REGISTER_SVE256_DISPATCH(cholesky_stub, &cholesky_kernel)
REGISTER_ARCH_DISPATCH(cholesky_inverse_stub, DEFAULT, &cholesky_inverse_kernel_impl)
@ -1164,6 +1165,7 @@ REGISTER_AVX512_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
REGISTER_AVX2_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
REGISTER_VSX_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
REGISTER_ZVECTOR_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
REGISTER_SVE_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
REGISTER_SVE256_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
REGISTER_ARCH_DISPATCH(linalg_eig_stub, DEFAULT, &linalg_eig_kernel)
@ -1171,6 +1173,7 @@ REGISTER_AVX512_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
REGISTER_AVX2_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
REGISTER_VSX_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
REGISTER_ZVECTOR_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
REGISTER_SVE_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
REGISTER_SVE256_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
REGISTER_ARCH_DISPATCH(linalg_eigh_stub, DEFAULT, &linalg_eigh_kernel)
@ -1178,6 +1181,7 @@ REGISTER_AVX512_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
REGISTER_AVX2_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
REGISTER_VSX_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
REGISTER_ZVECTOR_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
REGISTER_SVE_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
REGISTER_SVE256_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
REGISTER_ARCH_DISPATCH(geqrf_stub, DEFAULT, &geqrf_kernel)
@ -1185,6 +1189,7 @@ REGISTER_AVX512_DISPATCH(geqrf_stub, &geqrf_kernel)
REGISTER_AVX2_DISPATCH(geqrf_stub, &geqrf_kernel)
REGISTER_VSX_DISPATCH(geqrf_stub, &geqrf_kernel)
REGISTER_ZVECTOR_DISPATCH(geqrf_stub, &geqrf_kernel)
REGISTER_SVE_DISPATCH(geqrf_stub, &geqrf_kernel)
REGISTER_SVE256_DISPATCH(geqrf_stub, &geqrf_kernel)
REGISTER_ARCH_DISPATCH(orgqr_stub, DEFAULT, &orgqr_kernel_impl)
@ -1192,6 +1197,7 @@ REGISTER_AVX512_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
REGISTER_AVX2_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
REGISTER_VSX_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
REGISTER_ZVECTOR_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
REGISTER_SVE_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
REGISTER_SVE256_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
REGISTER_ARCH_DISPATCH(ormqr_stub, DEFAULT, &ormqr_kernel)
@ -1199,6 +1205,7 @@ REGISTER_AVX512_DISPATCH(ormqr_stub, &ormqr_kernel)
REGISTER_AVX2_DISPATCH(ormqr_stub, &ormqr_kernel)
REGISTER_VSX_DISPATCH(ormqr_stub, &ormqr_kernel)
REGISTER_ZVECTOR_DISPATCH(ormqr_stub, &ormqr_kernel)
REGISTER_SVE_DISPATCH(ormqr_stub, &ormqr_kernel)
REGISTER_SVE256_DISPATCH(ormqr_stub, &ormqr_kernel)
REGISTER_ARCH_DISPATCH(lstsq_stub, DEFAULT, &lstsq_kernel)
@ -1206,6 +1213,7 @@ REGISTER_AVX512_DISPATCH(lstsq_stub, &lstsq_kernel)
REGISTER_AVX2_DISPATCH(lstsq_stub, &lstsq_kernel)
REGISTER_VSX_DISPATCH(lstsq_stub, &lstsq_kernel)
REGISTER_ZVECTOR_DISPATCH(lstsq_stub, &lstsq_kernel)
REGISTER_SVE_DISPATCH(lstsq_stub, &lstsq_kernel)
REGISTER_SVE256_DISPATCH(lstsq_stub, &lstsq_kernel)
REGISTER_ARCH_DISPATCH(triangular_solve_stub, DEFAULT, &triangular_solve_kernel)
@ -1213,6 +1221,7 @@ REGISTER_AVX512_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
REGISTER_AVX2_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
REGISTER_VSX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
REGISTER_ZVECTOR_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
REGISTER_SVE_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
REGISTER_SVE256_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
REGISTER_ARCH_DISPATCH(lu_factor_stub, DEFAULT, &lu_factor_kernel)
@ -1220,6 +1229,7 @@ REGISTER_AVX512_DISPATCH(lu_factor_stub, &lu_factor_kernel)
REGISTER_AVX2_DISPATCH(lu_factor_stub, &lu_factor_kernel)
REGISTER_VSX_DISPATCH(lu_factor_stub, &lu_factor_kernel)
REGISTER_ZVECTOR_DISPATCH(lu_factor_stub, &lu_factor_kernel)
REGISTER_SVE_DISPATCH(lu_factor_stub, &lu_factor_kernel)
REGISTER_SVE256_DISPATCH(lu_factor_stub, &lu_factor_kernel)
REGISTER_ARCH_DISPATCH(ldl_factor_stub, DEFAULT, &ldl_factor_kernel)
@ -1227,6 +1237,7 @@ REGISTER_AVX512_DISPATCH(ldl_factor_stub, &ldl_factor_kernel)
REGISTER_AVX2_DISPATCH(ldl_factor_stub, &ldl_factor_kernel)
REGISTER_VSX_DISPATCH(ldl_factor_stub, &ldl_factor_kernel)
REGISTER_ZVECTOR_DISPATCH(ldl_factor_stub, &ldl_factor_kernel)
REGISTER_SVE_DISPATCH(ldl_factor_stub, &ldl_factor_kernel)
REGISTER_SVE256_DISPATCH(ldl_factor_stub, &ldl_factor_kernel)
REGISTER_ARCH_DISPATCH(ldl_solve_stub, DEFAULT, &ldl_solve_kernel)
@ -1234,6 +1245,7 @@ REGISTER_AVX512_DISPATCH(ldl_solve_stub, &ldl_solve_kernel)
REGISTER_AVX2_DISPATCH(ldl_solve_stub, &ldl_solve_kernel)
REGISTER_VSX_DISPATCH(ldl_solve_stub, &ldl_solve_kernel)
REGISTER_ZVECTOR_DISPATCH(ldl_solve_stub, &ldl_solve_kernel)
REGISTER_SVE_DISPATCH(ldl_solve_stub, &ldl_solve_kernel)
REGISTER_SVE256_DISPATCH(ldl_solve_stub, &ldl_solve_kernel)
REGISTER_ARCH_DISPATCH(lu_solve_stub, DEFAULT, &lu_solve_kernel)
@ -1241,6 +1253,7 @@ REGISTER_AVX512_DISPATCH(lu_solve_stub, &lu_solve_kernel)
REGISTER_AVX2_DISPATCH(lu_solve_stub, &lu_solve_kernel)
REGISTER_VSX_DISPATCH(lu_solve_stub, &lu_solve_kernel)
REGISTER_ZVECTOR_DISPATCH(lu_solve_stub, &lu_solve_kernel)
REGISTER_SVE_DISPATCH(lu_solve_stub, &lu_solve_kernel)
REGISTER_SVE256_DISPATCH(lu_solve_stub, &lu_solve_kernel)
REGISTER_ARCH_DISPATCH(svd_stub, DEFAULT, &svd_kernel)
@ -1248,6 +1261,7 @@ REGISTER_AVX512_DISPATCH(svd_stub, &svd_kernel)
REGISTER_AVX2_DISPATCH(svd_stub, &svd_kernel)
REGISTER_VSX_DISPATCH(svd_stub, &svd_kernel)
REGISTER_ZVECTOR_DISPATCH(svd_stub, &svd_kernel)
REGISTER_SVE_DISPATCH(svd_stub, &svd_kernel)
REGISTER_SVE256_DISPATCH(svd_stub, &svd_kernel)
REGISTER_ARCH_DISPATCH(unpack_pivots_stub, DEFAULT, &unpack_pivots_cpu_kernel)
@ -1255,5 +1269,6 @@ REGISTER_AVX512_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel)
REGISTER_AVX2_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel)
REGISTER_VSX_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel)
REGISTER_ZVECTOR_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel)
REGISTER_SVE_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel)
REGISTER_SVE256_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel)
} // namespace at::native

View File

@ -38,17 +38,27 @@ static CPUCapability compute_cpu_capability() {
return CPUCapability::ZVECTOR;
}
#elif defined(HAVE_SVE_CPU_DEFINITION)
int sve_vl = cpuinfo_get_max_arm_sve_length(); //Returns maximum SVE VL supported by your HW.
#ifdef HAVE_SVE256_CPU_DEFINITION
int sve_vl = cpuinfo_get_max_arm_sve_length(); // Returns maximum SVE VL supported by your HW.
#ifdef HAVE_SVE_CPU_DEFINITION
if (envar == "sve256") {
if (sve_vl == 256) {
#ifdef HAVE_ARM_BF16_CPU_DEFINITION
if (cpuinfo_has_arm_bf16()) {
if (cpuinfo_has_arm_bf16()) {
if (sve_vl == 256) {
return CPUCapability::SVE256;
} else if (sve_vl > 0) {
return CPUCapability::SVE;
}
#endif
}
TORCH_WARN("SVE256 capability not available on hardware. Falling back to DEFAULT");
#endif
TORCH_WARN("SVE capability not available on hardware. Falling back to DEFAULT");
return CPUCapability::DEFAULT;
} else if (envar == "sve") {
#ifdef HAVE_ARM_BF16_CPU_DEFINITION
if (cpuinfo_has_arm_bf16() && sve_vl > 0) {
return CPUCapability::SVE;
}
#endif
TORCH_WARN("SVE capability not available on hardware. Falling back to DEFAULT");
return CPUCapability::DEFAULT;
}
#endif
@ -100,19 +110,15 @@ static CPUCapability compute_cpu_capability() {
#if defined(__linux__) && defined(HAVE_SVE_CPU_DEFINITION)
if (cpuinfo_initialize() && cpuinfo_has_arm_sve()) {
int sve_vl = cpuinfo_get_max_arm_sve_length(); //Returns maximum SVE VL supported by your HW.
if (sve_vl <= 0) {
// SVE is not supported on this system.
// Return the default CPU capability.
return CPUCapability::DEFAULT;
#ifdef HAVE_ARM_BF16_CPU_DEFINITION
if (cpuinfo_has_arm_bf16()) {
if (sve_vl == 256) { // Check for SVE256
return CPUCapability::SVE256;
} else if (sve_vl > 0) {
return CPUCapability::SVE;
}
}
#ifdef HAVE_SVE256_CPU_DEFINITION
if (sve_vl == 256) { // Check for SVE256
#ifdef HAVE_ARM_BF16_CPU_DEFINITION
if (cpuinfo_has_arm_bf16())
return CPUCapability::SVE256;
#endif
}
#endif
#endif
// Return the default CPU capability.
return CPUCapability::DEFAULT;
}
@ -144,7 +150,8 @@ DispatchResult DispatchStubImpl::try_get_call_ptr(
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, void *ZVECTOR
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
#ifdef HAVE_SVE_CPU_DEFINITION
, void *SVE
, void *SVE256
#endif
) {
@ -182,7 +189,8 @@ DispatchResult DispatchStubImpl::try_get_call_ptr(
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, ZVECTOR
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
#ifdef HAVE_SVE_CPU_DEFINITION
, SVE
, SVE256
#endif
);
@ -239,7 +247,8 @@ void* DispatchStubImpl::get_call_ptr(
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, void *ZVECTOR
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
#ifdef HAVE_SVE_CPU_DEFINITION
, void *SVE
, void *SVE256
#endif
) {
@ -263,7 +272,9 @@ void* DispatchStubImpl::get_call_ptr(
,
ZVECTOR
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
#ifdef HAVE_SVE_CPU_DEFINITION
,
SVE
,
SVE256
#endif
@ -298,7 +309,8 @@ DispatchResult DispatchStubImpl::try_choose_cpu_impl(
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, void *ZVECTOR
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
#ifdef HAVE_SVE_CPU_DEFINITION
, void *SVE
, void *SVE256
#endif
){
@ -333,7 +345,7 @@ DispatchResult DispatchStubImpl::try_choose_cpu_impl(
return ZVECTOR != nullptr ? DispatchResult(ZVECTOR) : ErrorType::MissingDeviceKernel;
}
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
#ifdef HAVE_SVE_CPU_DEFINITION
if (capability >= static_cast<int>(CPUCapability::SVE256)) {
if (C10_UNLIKELY(!SVE256)) {
// dispatch to DEFAULT, since the SVE kernel is missing
@ -342,6 +354,14 @@ DispatchResult DispatchStubImpl::try_choose_cpu_impl(
return DispatchResult(SVE256);
}
}
if (capability >= static_cast<int>(CPUCapability::SVE)) {
if (C10_UNLIKELY(!SVE)) {
// dispatch to DEFAULT, since the SVE kernel is missing
return DEFAULT != nullptr ? DispatchResult(DEFAULT) : ErrorType::MissingDeviceKernel;
} else {
return DispatchResult(SVE);
}
}
#endif
return DEFAULT != nullptr ? DispatchResult(DEFAULT) : ErrorType::MissingDeviceKernel;
}
@ -360,7 +380,8 @@ void* DispatchStubImpl::choose_cpu_impl(
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, void *ZVECTOR
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
#ifdef HAVE_SVE_CPU_DEFINITION
, void *SVE
, void *SVE256
#endif
) {
@ -398,7 +419,7 @@ void* DispatchStubImpl::choose_cpu_impl(
return ZVECTOR;
}
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
#ifdef HAVE_SVE_CPU_DEFINITION
if (capability >= static_cast<int>(CPUCapability::SVE256)) {
if (C10_UNLIKELY(!SVE256)) {
// dispatch to DEFAULT, since the SVE kernel is missing
@ -408,6 +429,15 @@ void* DispatchStubImpl::choose_cpu_impl(
return SVE256;
}
}
if (capability >= static_cast<int>(CPUCapability::SVE)) {
if (C10_UNLIKELY(!SVE)) {
// dispatch to DEFAULT, since the SVE kernel is missing
TORCH_INTERNAL_ASSERT(DEFAULT, "DispatchStub: missing default kernel");
return DEFAULT;
} else {
return SVE;
}
}
#endif
TORCH_INTERNAL_ASSERT(DEFAULT, "DispatchStub: missing default kernel");
return DEFAULT;
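A hedged sketch of the priority order the chooser follows after this change: SVE256 is preferred when the capability allows it, then length-agnostic SVE, then DEFAULT, falling through whenever a kernel pointer was not compiled in. The function and type names here are illustrative, not the DispatchStub API:

```cpp
// Values mirror the CPUCapability ordering above: DEFAULT < SVE < SVE256.
enum class Capability { DEFAULT = 0, SVE = 1, SVE256 = 2 };

using Kernel = void (*)();

Kernel choose(Capability cap, Kernel sve256, Kernel sve, Kernel def) {
  if (cap >= Capability::SVE256 && sve256 != nullptr) {
    return sve256;  // widest kernel available and hardware supports it
  }
  if (cap >= Capability::SVE && sve != nullptr) {
    return sve;     // vector-length-agnostic SVE kernel
  }
  return def;       // default fallback
}
```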

View File

@ -64,8 +64,9 @@ enum class CPUCapability {
VSX = 1,
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
ZVECTOR = 1,
#elif defined(HAVE_SVE256_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION)
SVE256 = 1,
#elif defined(HAVE_SVE_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION)
SVE = 1,
SVE256 = 2,
#else
AVX2 = 1,
AVX512 = 2,
@ -115,7 +116,8 @@ struct TORCH_API DispatchStubImpl {
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, void *ZVECTOR
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
#ifdef HAVE_SVE_CPU_DEFINITION
, void *SVE
, void *SVE256
#endif
);
@ -136,7 +138,8 @@ struct TORCH_API DispatchStubImpl {
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, void *ZVECTOR
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
#ifdef HAVE_SVE_CPU_DEFINITION
, void *SVE
, void *SVE256
#endif
);
@ -157,7 +160,8 @@ struct TORCH_API DispatchStubImpl {
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, void *ZVECTOR
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
#ifdef HAVE_SVE_CPU_DEFINITION
, void *SVE
, void *SVE256
#endif
);
@ -181,7 +185,8 @@ struct TORCH_API DispatchStubImpl {
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, void *ZVECTOR
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
#ifdef HAVE_SVE_CPU_DEFINITION
, void *SVE
, void *SVE256
#endif
);
@ -238,7 +243,8 @@ private:
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, reinterpret_cast<void*>(ZVECTOR)
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
#ifdef HAVE_SVE_CPU_DEFINITION
, reinterpret_cast<void*>(SVE)
, reinterpret_cast<void*>(SVE256)
#endif
)
@ -299,7 +305,8 @@ public:
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, reinterpret_cast<void*>(ZVECTOR)
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
#ifdef HAVE_SVE_CPU_DEFINITION
, reinterpret_cast<void*>(SVE)
, reinterpret_cast<void*>(SVE256)
#endif
);
@ -322,7 +329,8 @@ public:
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
static TORCH_API FnPtr ZVECTOR;
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
#ifdef HAVE_SVE_CPU_DEFINITION
static TORCH_API FnPtr SVE;
static TORCH_API FnPtr SVE256;
#endif
private:
@ -426,9 +434,11 @@ struct RegisterPRIVATEUSE1Dispatch {
#define REGISTER_ZVECTOR_DISPATCH(name, fn)
#endif
#ifdef HAVE_SVE256_CPU_DEFINITION
#ifdef HAVE_SVE_CPU_DEFINITION
#define REGISTER_SVE_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, SVE, fn)
#define REGISTER_SVE256_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, SVE256, fn)
#else
#define REGISTER_SVE_DISPATCH(name, fn)
#define REGISTER_SVE256_DISPATCH(name, fn)
#endif
@ -440,6 +450,7 @@ struct RegisterPRIVATEUSE1Dispatch {
REGISTER_AVX2_DISPATCH(name, fn) \
REGISTER_VSX_DISPATCH(name, fn) \
REGISTER_ZVECTOR_DISPATCH(name, fn) \
REGISTER_SVE_DISPATCH(name, fn) \
REGISTER_SVE256_DISPATCH(name, fn)
#define REGISTER_NO_CPU_DISPATCH(name) \
@ -488,6 +499,7 @@ struct RegisterPRIVATEUSE1Dispatch {
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#endif
#define ALSO_REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#define ALSO_REGISTER_SVE_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#define ALSO_REGISTER_SVE256_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#endif
} // namespace at::native

View File

@ -466,6 +466,7 @@ REGISTER_AVX2_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cp
REGISTER_AVX512_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel)
REGISTER_VSX_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel)
REGISTER_ZVECTOR_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel)
REGISTER_SVE_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel)
REGISTER_SVE256_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel)
// offsets dispatches
@ -477,6 +478,7 @@ REGISTER_AVX2_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cp
REGISTER_AVX512_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel)
REGISTER_VSX_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel)
REGISTER_ZVECTOR_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel)
REGISTER_SVE_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel)
REGISTER_SVE256_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel)
// Currently some computation is being duplicated across forward and backward.
@ -548,6 +550,9 @@ REGISTER_VSX_DISPATCH(
REGISTER_ZVECTOR_DISPATCH(
_segment_reduce_lengths_backward_stub,
&_segment_reduce_cpu_lengths_backward_kernel)
REGISTER_SVE_DISPATCH(
_segment_reduce_lengths_backward_stub,
&_segment_reduce_cpu_lengths_backward_kernel)
REGISTER_SVE256_DISPATCH(
_segment_reduce_lengths_backward_stub,
&_segment_reduce_cpu_lengths_backward_kernel)
@ -568,6 +573,9 @@ REGISTER_VSX_DISPATCH(
REGISTER_ZVECTOR_DISPATCH(
_segment_reduce_offsets_backward_stub,
&_segment_reduce_cpu_offsets_backward_kernel)
REGISTER_SVE_DISPATCH(
_segment_reduce_offsets_backward_stub,
&_segment_reduce_cpu_offsets_backward_kernel)
REGISTER_SVE256_DISPATCH(
_segment_reduce_offsets_backward_stub,
&_segment_reduce_cpu_offsets_backward_kernel)

View File

@ -274,7 +274,7 @@ inline Vectorized<scalar_t> div_floor_floating_vec(
return floordiv;
}
#if defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16)
#if (defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE)) && defined(__ARM_FEATURE_BF16)
// Since sve lacks sufficient bf16 intrinsics, do the calculations in f32 to
// avoid rounding errors. This should not cause performance issues as

View File

@ -11,6 +11,7 @@
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <c10/util/irange.h>
#include <variant>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -44,13 +45,23 @@ inline void _scale_attn_mask_fusion_kernel(
#endif
const auto vec_size1 = at::vec::Vectorized<T1>::size();
const auto vec_size2 = at::vec::Vectorized<T2>::size();
constexpr int64_t T1_n =
const int64_t T1_n =
(vec_size2 == vec_size1 * 2 && is_reduced_floating_point_v<T2>) ? 2 : 1;
constexpr int64_t T2_n = 1;
auto vec_scale = at::vec::VectorizedN<T1, T1_n>(val);
std::variant<at::vec::VectorizedN<T1, 2>, at::vec::VectorizedN<T1, 1>> vec_scale;
if (T1_n == 2)
vec_scale = at::vec::VectorizedN<T1, 2>(val);
else if (T1_n == 1)
vec_scale = at::vec::VectorizedN<T1, 1>(val);
int64_t i = 0;
for (; i < size - (size % vec_size2); i += vec_size2) {
auto a_n = at::vec::VectorizedN<T1, T1_n>::loadu(a + i);
std::variant<at::vec::VectorizedN<T1, 2>, at::vec::VectorizedN<T1, 1>> a_n;
if (T1_n == 2)
a_n = at::vec::VectorizedN<T1, 2>::loadu(a + i);
else if (T1_n == 1)
a_n = at::vec::VectorizedN<T1, 1>::loadu(a + i);
at::vec::VectorizedN<T2, T2_n> b_n;
#if __GNUC__ == 11 && defined(__ARM_FEATURE_SVE)
if (is_b_stride_zero) {
@ -61,9 +72,16 @@ inline void _scale_attn_mask_fusion_kernel(
} else {
b_n = at::vec::VectorizedN<T2, T2_n>::loadu(b + i);
}
auto b_n_convert = at::vec::convert<T1, T1_n, T2, T2_n, true>(b_n);
auto res = a_n * vec_scale + b_n_convert;
res.store(out + i);
std::variant<at::vec::VectorizedN<T1, 2>, at::vec::VectorizedN<T1, 1>> b_n_convert;
if (T1_n == 2) {
auto b_n_convert = at::vec::convert<T1, 2, T2, T2_n, true>(b_n);
auto res = std::get<at::vec::VectorizedN<T1, 2>>(a_n) * std::get<at::vec::VectorizedN<T1, 2>>(vec_scale) + b_n_convert;
res.store(out + i);
} else if(T1_n == 1) {
auto b_n_convert = at::vec::convert<T1, 1, T2, T2_n, true>(b_n);
auto res = std::get<at::vec::VectorizedN<T1, 1>>(a_n) * std::get<at::vec::VectorizedN<T1, 1>>(vec_scale) + b_n_convert;
res.store(out + i);
}
}
for (; i < size; i++) {
auto tmp0 = a[i];
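Because the width selector is no longer constexpr in the kernel above, the vector payload is carried in a std::variant and recovered with std::get under the same runtime flag that chose it. A standalone sketch of that pattern (std::array stands in for VectorizedN; this is not the ATen kernel):

```cpp
#include <array>
#include <cstdio>
#include <variant>

using Narrow = std::array<float, 4>;  // stands in for VectorizedN<T1, 1>
using Wide   = std::array<float, 8>;  // stands in for VectorizedN<T1, 2>

int main() {
  const int t1_n = 2;  // runtime decision, e.g. vec_size2 == 2 * vec_size1
  std::variant<Wide, Narrow> payload;
  if (t1_n == 2) {
    payload = Wide{};
  } else {
    payload = Narrow{};
  }
  // Every later use re-checks the same flag before calling std::get.
  if (t1_n == 2) {
    std::printf("lanes: %zu\n", std::get<Wide>(payload).size());
  } else {
    std::printf("lanes: %zu\n", std::get<Narrow>(payload).size());
  }
  return 0;
}
```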

View File

@ -694,7 +694,7 @@ struct ApplyGridSample<scalar_t, 2, GridSamplerInterpolation::Bilinear,
gx = gx * gx_mult;
gy = gy * gy_mult;
constexpr int64_t step = Vec::size();
const int64_t step = Vec::size();
auto interleaved_gGrid = interleave2(gx, gy);
auto gGrid_ptr = gGrid_slice.data() + offset * 2;
std::get<0>(interleaved_gGrid).store(gGrid_ptr,
@ -1010,7 +1010,7 @@ struct ApplyGridSample<scalar_t, 2, GridSamplerInterpolation::Bicubic,
gx = gx * gx_mult;
gy = gy * gy_mult;
constexpr int64_t step = Vec::size();
const int64_t step = Vec::size();
auto interleaved_gGrid = interleave2(gx, gy);
auto gGrid_ptr = gGrid_slice.data() + offset * 2;
std::get<0>(interleaved_gGrid).store(gGrid_ptr,
@ -1041,7 +1041,7 @@ static inline void grid_sample_2d_grid_slice_iterator(
using Vec = Vectorized<scalar_t>;
using iVec = Vectorized<int_same_size_t<scalar_t>>;
constexpr int64_t step = Vec::size();
const int64_t step = Vec::size();
// Loop over each output pixel in grid.
// We consider the following three cases (after slicing out the batch

View File

@ -19,7 +19,7 @@ Vectorized<scalar_t> is_lerp_weight_small(Vectorized<scalar_t> weight) {
// is_lerp_weight_small doesn't work for complex because z.abs() returns a
// complex vector which can't be compared. Either implement it with z.abs_2_(),
// or fallback to the scalar function.
#if !(defined(CPU_CAPABILITY_DEFAULT) || defined(_MSC_VER) || defined(CPU_CAPABILITY_SVE))
#if !(defined(CPU_CAPABILITY_DEFAULT) || defined(_MSC_VER) || defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE))
template <typename value_t>
Vectorized<c10::complex<value_t>> is_lerp_weight_small(Vectorized<c10::complex<value_t>> weight) {
using vec_reg_t = decltype(weight.abs_2_());

View File

@ -210,13 +210,22 @@ vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, ve
Vec opt_scalar = Vec(S > 0 ? c10::load((scalar_t*)data[S]) : scalar_t(0));
int64_t i = 0;
for (; i <= n - 2 * Vec::size(); i += 2 * Vec::size()) {
int size = Vec::size();
#if !defined(CPU_CAPABILITY_SVE) && !defined(CPU_CAPABILITY_SVE256)
// Loop unrolling prevents compiler from optimizing the SVE classes
for (; i <= n - 2 * size; i += 2 * size) {
auto args1 = dereference_vec<traits>(&data[1], opt_scalar, S, i);
auto args2 = dereference_vec<traits>(&data[1], opt_scalar, S, i + Vec::size());
auto args2 = dereference_vec<traits>(&data[1], opt_scalar, S, i + size);
auto out1 = std::apply(vop, std::move(args1));
auto out2 = std::apply(vop, std::move(args2));
out1.store(data[0] + i * sizeof(scalar_t));
out2.store(data[0] + (i + Vec::size()) * sizeof(scalar_t));
out2.store(data[0] + (i + size) * sizeof(scalar_t));
}
#endif
for (; i <= n - size; i += size) {
auto args1 = dereference_vec<traits>(&data[1], opt_scalar, S, i);
auto out1 = c10::guts::apply(vop, std::move(args1));
out1.store(data[0] + i * sizeof(scalar_t));
}
if (i < n) {
int64_t strides[ntensors];
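The hunk above keeps the 2x-unrolled main loop for fixed-width builds, guards it out for SVE builds, and adds a single-vector loop before the scalar tail. A scalar stand-in showing that loop structure (the step of 8 and the doubling operation are assumptions for illustration):

```cpp
#include <cstdint>

void apply(float* out, const float* in, int64_t n) {
  const int64_t step = 8;  // stands in for Vec::size()
  int64_t i = 0;
#if !defined(CPU_CAPABILITY_SVE) && !defined(CPU_CAPABILITY_SVE256)
  // 2x-unrolled main loop, skipped for SVE builds
  for (; i <= n - 2 * step; i += 2 * step) {
    for (int64_t k = 0; k < 2 * step; ++k) {
      out[i + k] = in[i + k] * 2.0f;
    }
  }
#endif
  // single-step loop
  for (; i <= n - step; i += step) {
    for (int64_t k = 0; k < step; ++k) {
      out[i + k] = in[i + k] * 2.0f;
    }
  }
  // scalar tail
  for (; i < n; ++i) {
    out[i] = in[i] * 2.0f;
  }
}
```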

View File

@ -80,7 +80,7 @@ inline void UNARY_OUTER_LOOP(char* data[2], const int64_t strides[2], int64_t n,
template <typename func_t, typename vec_func_t>
inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, vec_func_t vop) {
VEC_LOOP_HEADER(func_t, data)
constexpr int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t);
const int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t);
int64_t count = n / (4 * Vec::size());
if (count > 0) {
vectorized_reduction(data, count, vector_stride, op, vop, /*reduce=*/true);
@ -96,7 +96,7 @@ inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_
VEC_LOOP_HEADER(func_t, data)
// reduce down each column of 4 * Vec::size() elements.
constexpr int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t);
const int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t);
int64_t outer_stride[2] = { vector_stride, vector_stride };
UNARY_OUTER_LOOP(data, outer_stride, size1 / (4 * Vec::size()), [&] {
vectorized_reduction(data, size0, inner_stride, op, vop, /*reduce=*/false);

View File

@ -154,8 +154,8 @@ inline void map_acc(
using Vec = vec::Vectorized<scalar_t>;
using aVec = vec::Vectorized<accumut>;
int64_t d = 0;
constexpr int64_t kVecSize = Vec::size();
constexpr int64_t kaVecSize = aVec::size();
const int64_t kVecSize = Vec::size();
const int64_t kaVecSize = aVec::size();
for (d = 0; d < size - (size % kVecSize); d += kVecSize) {
Vec data2_vec = Vec::loadu(input_data2 + d);
auto [data2_avec0, data2_avec1] = convert_to_float<scalar_t>(data2_vec);

View File

@ -22,8 +22,8 @@ inline namespace CPU_CAPABILITY {
constexpr auto kF32RegisterPairsPerIteration = 4;
constexpr auto kF32RegistersPerIteration = kF32RegisterPairsPerIteration * 2;
constexpr auto kF32ElementsPerRegister = vec::Vectorized<float>::size();
constexpr auto kF32ElementsPerIteration = kF32RegistersPerIteration * kF32ElementsPerRegister;
const auto kF32ElementsPerRegister = vec::Vectorized<float>::size();
const auto kF32ElementsPerIteration = kF32RegistersPerIteration * kF32ElementsPerRegister;
namespace {
template <typename T>
@ -150,16 +150,16 @@ float reduce(vec::VectorizedN<float, kF32RegistersPerIteration>& x) {
// BFDOT. Deferring that for now to get the NEON/ASIMD BFDOT path
// working.
#if __ARM_FEATURE_BF16_VECTOR_ARITHMETIC
#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15
#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && !defined(CPU_CAPABILITY_SVE256) && defined(__clang__) && __clang_major__ > 15
// https://godbolt.org/z/z8P4Yncra
#define COMPILER_SUPPORTS_BF16_TARGET 1
#elif defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && !defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 10
#elif defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE) && !defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 10
// https://gcc.gnu.org/gcc-10/changes.html
// https://godbolt.org/z/cdGG7vn8o
#define COMPILER_SUPPORTS_BF16_TARGET 1
#else // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15
#else // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15
#define COMPILER_SUPPORTS_BF16_TARGET 0
#endif // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15
#endif // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && !defined(CPU_CAPABILITY_SVE256) && defined(__clang__) && __clang_major__ > 15
#else // __ARM_FEATURE_BF16_VECTOR_ARITHMETIC
#define COMPILER_SUPPORTS_BF16_TARGET 0
#endif // __ARM_FEATURE_BF16_VECTOR_ARITHMETIC
@ -212,7 +212,7 @@ std::pair<vec::Vectorized<float>, vec::Vectorized<float>> fmadd(
const vec::Vectorized<c10::Half>& b,
const vec::Vectorized<float>& acc_low,
const vec::Vectorized<float>& acc_high) {
#if defined(__ARM_FEATURE_FP16_FML) && !defined(CPU_CAPABILITY_SVE)
#if defined(__ARM_FEATURE_FP16_FML) && !defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE)
return std::make_pair(vfmlalq_low_f16(acc_low, a, b), vfmlalq_high_f16(acc_high, a, b));
#else
const auto [a_float_low, a_float_high] = convert_half_float(a);
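For context on the intrinsic path kept behind the tightened guard above, a standalone sketch of the FP16 FML widening multiply-accumulate (assumes a toolchain targeting __ARM_FEATURE_FP16_FML; this is not the ATen fmadd overload):

```cpp
#include <arm_neon.h>

// Accumulate a*b into two float32 accumulators: vfmlalq_low_f16 consumes
// lanes 0..3 of the half-precision inputs, vfmlalq_high_f16 lanes 4..7.
void fma_halves(float32x4_t* acc_lo, float32x4_t* acc_hi,
                float16x8_t a, float16x8_t b) {
  *acc_lo = vfmlalq_low_f16(*acc_lo, a, b);
  *acc_hi = vfmlalq_high_f16(*acc_hi, a, b);
}
```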

View File

@ -28,8 +28,8 @@ inline void _update(at::opmath_type<scalar_t>* out_ptr, int64_t e, int64_t c, co
using opmath_t = at::opmath_type<scalar_t>;
using Vec = vec::Vectorized<scalar_t>;
using aVec = VecType<scalar_t>;
constexpr int64_t kVecSize = Vec::size();
constexpr int64_t kVLEN = kVecSize * 4;
const int64_t kVecSize = Vec::size();
const int64_t kVLEN = kVecSize * 4;
int64_t k = 0;
aVec val_vec = aVec((opmath_t)val);

View File

@ -21,11 +21,11 @@ Vectorized<acc_t> load_reduce_vec(const scalar_t* data, F reduce, acc_t ident) {
using vacc_t = Vectorized<acc_t>;
static_assert(vacc_t::size() <= vec_t::size());
const auto val = vec_t::loadu(data);
alignas(64) std::array<scalar_t, vec_t::size()> values;
val.store(values.data());
alignas(64) scalar_t values[vec_t::size()];
val.store(values);
constexpr int vstride = vec_t::size() / vacc_t::size();
alignas(64) std::array<acc_t, vacc_t::size()> acc;
alignas(64) acc_t acc[vacc_t::size()];
std::fill(acc, acc + vacc_t::size(), ident);
for (const auto k : c10::irange(vstride)) {
for (const auto i : c10::irange(vacc_t::size())) {
@ -33,7 +33,7 @@ Vectorized<acc_t> load_reduce_vec(const scalar_t* data, F reduce, acc_t ident) {
}
}
return vacc_t::loadu(acc.data());
return vacc_t::loadu(acc);
}
template <typename scalar_t>
@ -138,7 +138,7 @@ struct OuterSumCastLoadPolicy <vec_t, vacc_t,
using scalar_t = vechold_type<vec_t>;
using acc_t = vechold_type<vacc_t>;
static constexpr int64_t memsize() {
static int64_t memsize() {
return sizeof(scalar_t) * vacc_t::size();
}
@ -161,7 +161,7 @@ template <typename vec_t, typename vacc_t>
struct OuterSumCastLoadPolicy <vec_t, vacc_t, std::enable_if_t<is_reduced_floating_point_v<vechold_type<vec_t>>>> {
using scalar_t = vechold_type<vec_t>;
static constexpr int64_t memsize() {
static int64_t memsize() {
return sizeof(scalar_t) * vacc_t::size();
}
@ -198,7 +198,7 @@ template <typename scalar_t>
struct NanSumLoadPolicy<Vectorized<scalar_t>> {
using vec_t = Vectorized<scalar_t>;
static constexpr int64_t memsize() {
static int64_t memsize() {
return LoadPolicy<vec_t>::memsize();
}
@ -267,7 +267,7 @@ struct InnerNanSumCastLoadPolicy <vec_t, vacc_t, std::enable_if_t<is_reduced_flo
template <typename vec_t, typename vacc_t>
struct OuterNanSumCastLoadPolicy {
static constexpr int64_t memsize() {
static int64_t memsize() {
return OuterSumCastLoadPolicy<vec_t, vacc_t>::memsize();
}
@ -300,13 +300,23 @@ static void store(char * C10_RESTRICT data, int64_t stride, int64_t index,
}
}
template <typename StorePolicy, typename scalar_t>
static void store(char * C10_RESTRICT data, int64_t stride, int64_t index,
const scalar_t *values, size_t numel) {
auto *base_ptr = data + stride * index;
for (const auto k : c10::irange(numel)) {
auto val = values[k];
StorePolicy::store(base_ptr, stride, k, val);
}
}
template <typename StorePolicy, typename scalar_t>
static void store(char * C10_RESTRICT data, int64_t stride, int64_t index,
const Vectorized<scalar_t> &values) {
using vec_t = Vectorized<scalar_t>;
alignas(64) std::array<scalar_t, vec_t::size()> array_values{};
values.store(array_values.data());
store<StorePolicy>(data, stride, index, array_values);
alignas(64) scalar_t array_values[vec_t::size()] = {};
values.store(array_values);
store<StorePolicy, scalar_t>(data, stride, index, array_values, vec_t::size());
}
/** Simultaneously sum over n rows at once
@ -436,9 +446,9 @@ void vectorized_inner_sum(
char * C10_RESTRICT data[2], int64_t outer_stride, int64_t out_stride,
int64_t size0, int64_t size1) {
using vacc_t = Vectorized<acc_t>;
constexpr int64_t vec_stride = VecLoadPolicy::memsize();
constexpr int64_t scalar_stride = ScalarLoadPolicy::memsize();
constexpr int64_t vec_numel = vec_stride / scalar_stride;
const int64_t vec_stride = VecLoadPolicy::memsize();
const int64_t scalar_stride = ScalarLoadPolicy::memsize();
const int64_t vec_numel = vec_stride / scalar_stride;
const int64_t vec_size = size0 / vec_numel;
// Input is contiguous over the first (reduced) dimension
@ -451,9 +461,9 @@ void vectorized_inner_sum(
final_acc += ScalarLoadPolicy::load(row_in, scalar_stride, k);
}
alignas(64) std::array<acc_t, vacc_t::size()> partials{};
vec_acc.store(partials.data());
for (const auto k : c10::irange(partials.size())) {
alignas(64) acc_t partials[vacc_t::size()] = {};
vec_acc.store(partials);
for (const auto k : c10::irange(vacc_t::size())) {
final_acc += partials[k];
}
store<StorePolicy>(data[0], out_stride, j, final_acc);
@ -479,7 +489,7 @@ void vectorized_outer_sum(
int64_t size0, int64_t size1) {
using vacc_t = Vectorized<acc_t>;
constexpr int64_t scalar_stride = ScalarLoadPolicy::memsize();
constexpr int64_t vec_stride = VecLoadPolicy::memsize();
const int64_t vec_stride = VecLoadPolicy::memsize();
constexpr int64_t nrows = 4;
// Input is contiguous over the second (non-reduced) dimension
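The new store overload above spills a vector into an aligned scalar array and forwards each element to the store policy with an explicit element count. A standalone sketch of that pattern (the policy and the 8-lane width are assumptions, not the SumKernel code):

```cpp
#include <cstdint>
#include <cstring>

struct ContiguousStore {
  static void store(char* base, int64_t stride, int64_t k, float val) {
    std::memcpy(base + stride * k, &val, sizeof(float));
  }
};

template <typename StorePolicy>
void store_lanes(char* data, int64_t stride, int64_t index,
                 const float* values, int64_t numel) {
  char* base_ptr = data + stride * index;
  for (int64_t k = 0; k < numel; ++k) {
    StorePolicy::store(base_ptr, stride, k, values[k]);
  }
}

int main() {
  alignas(64) float lanes[8] = {0, 1, 2, 3, 4, 5, 6, 7};  // spilled "vector"
  float out[8] = {};
  store_lanes<ContiguousStore>(reinterpret_cast<char*>(out),
                               sizeof(float), /*index=*/0, lanes, 8);
  return 0;
}
```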

View File

@ -93,7 +93,7 @@ ColumnwiseMoments(
int64_t C,
int64_t D) {
using Vec = vec::Vectorized<T>;
constexpr int64_t K = Vec::size();
const int64_t K = Vec::size();
const int64_t inner_size = D / K * K;
Vec acc0_vec{0}, acc1_vec{0};
for (const auto m : c10::irange(HxW)) {
@ -668,20 +668,20 @@ void GroupNormInputBackward(
const opmath_t s = opmath_t(1) / static_cast<opmath_t>(D * HxW);
const bool gamma_null = (gamma == nullptr);
at::parallel_for(0, N * G, 1, [=](int64_t start, int64_t end) {
constexpr int64_t K = vec::Vectorized<PT>::size();
const int64_t K = vec::Vectorized<PT>::size();
const int64_t d = D / K * K;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
std::array<opmath_t, at::vec::Vectorized<opmath_t>::size()> ds_arr;
opmath_t ds_arr[at::vec::Vectorized<opmath_t>::size()];
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
std::array<opmath_t, at::vec::Vectorized<opmath_t>::size()> db_arr;
opmath_t db_arr[at::vec::Vectorized<opmath_t>::size()];
for (const auto i : c10::irange(start, end)) {
const int64_t g = i % G;
const opmath_t* ds_ptr = ds + i * D;
const opmath_t* db_ptr = db + i * D;
const PT* gamma_ptr = gamma_null ? nullptr : (gamma + g * D);
CalcDsDb(ds_ptr, db_ptr, gamma_ptr, d, K, ds_arr.data(), db_arr.data());
opmath_t ds_val = std::accumulate(ds_arr.cbegin(), ds_arr.cend(), opmath_t(0));
opmath_t db_val = std::accumulate(db_arr.cbegin(), db_arr.cend(), opmath_t(0));
CalcDsDb(ds_ptr, db_ptr, gamma_ptr, d, K, ds_arr, db_arr);
opmath_t ds_val = std::accumulate(&ds_arr[0], &ds_arr[at::vec::Vectorized<opmath_t>::size()], opmath_t(0));
opmath_t db_val = std::accumulate(&db_arr[0], &db_arr[at::vec::Vectorized<opmath_t>::size()], opmath_t(0));
for (const auto j : c10::irange(d, D)) {
const opmath_t gamma_v = gamma_null ? opmath_t(1) : opmath_t(gamma[g * D + j]);
ds_val += ds_ptr[j] * gamma_v;
@ -718,7 +718,7 @@ GammaBackward(
PT* dgamma) {
const int64_t G = group;
const int64_t D = C / G;
constexpr int64_t K = at::vec::Vectorized<PT>::size();
const int64_t K = at::vec::Vectorized<PT>::size();
using Vec = at::vec::Vectorized<PT>;
const int64_t inner_size = D / K * K;
for (const auto g : c10::irange(G)) {
@ -818,7 +818,7 @@ template <typename PT, typename opmath_t>
std::enable_if_t<std::is_same_v<PT, opmath_t>, void>
BetaBackward(int64_t N, int64_t C, const opmath_t* db, PT* dbeta) {
using Vec = at::vec::Vectorized<PT>;
constexpr int64_t K = Vec::size();
const int64_t K = Vec::size();
Vec acc_vec{0}, zero{0};
const int64_t inner_size = C / K * K;
int64_t i = 0;
@ -943,7 +943,7 @@ DsDbRowwiseMomentsChannelsLast(
opmath_t* db_ptr,
int64_t C) {
using Vec = vec::Vectorized<T>;
constexpr int64_t K = vec::Vectorized<T>::size();
const int64_t K = vec::Vectorized<T>::size();
const int64_t inner_size = C / K * K;
int64_t d = 0;
for (; d < inner_size; d += K) {
@ -1247,7 +1247,7 @@ inline typename std::
int64_t D) {
using Vec = vec::Vectorized<T>;
const bool gamma_null = (gamma_ptr == nullptr);
constexpr int64_t K = Vec::size();
const int64_t K = Vec::size();
const int64_t inner_size = D / K * K;
int64_t d = 0;
opmath_t ds_gamma{0}, db_gamma{0};

View File

@ -625,7 +625,7 @@ void weight_to_int4pack_kernel(
int K = weight.size(1);
// 64 for avx512 and 32 for avx2/non-vectorized
constexpr int BLOCK_N = vec::Vectorized<float>::size() * 4;
const int BLOCK_N = vec::Vectorized<float>::size() * 4;
const int NB = (N + BLOCK_N - 1) / BLOCK_N;
// parallel on NB blocks
@ -713,7 +713,7 @@ void int4pack_mm_kernel_(
constexpr int BLOCK_M = 4;
// 64 for avx512 and 32 for avx2/non-vectorized
constexpr int BLOCK_N = vec::Vectorized<float>::size() * 4;
const int BLOCK_N = vec::Vectorized<float>::size() * 4;
// 32, 64, 128, 256
const int BLOCK_K = qGroupSize;

View File

@ -109,8 +109,8 @@ template <typename T, int64_t kMaxDepth>
std::pair<opmath_t<T>, opmath_t<T>> RowwiseMomentsImpl(const T* X, int64_t N, int64_t ddof = 0) {
using math_t = opmath_t<T>;
constexpr int64_t kVecSize = vec::Vectorized<T>::size();
constexpr int64_t kAccVecSize = vec::Vectorized<math_t>::size();
const int64_t kVecSize = vec::Vectorized<T>::size();
const int64_t kAccVecSize = vec::Vectorized<math_t>::size();
const int64_t n = N / kVecSize;
const int64_t m = divup(n, kChunkSize);
const int64_t depth = utils::CeilLog2(m);
@ -155,10 +155,10 @@ std::pair<opmath_t<T>, opmath_t<T>> RowwiseMomentsImpl(const T* X, int64_t N, in
m0_stk[i], m1_stk[i], m2_stk[i], m0_stk[0], m1_stk[0], m2_stk[0]);
}
std::array<math_t, kAccVecSize> m1_arr{};
std::array<math_t, kAccVecSize> m2_arr{};
m1_stk[0].store(m1_arr.data());
m2_stk[0].store(m2_arr.data());
math_t m1_arr[kAccVecSize] = {};
math_t m2_arr[kAccVecSize] = {};
m1_stk[0].store(m1_arr);
m2_stk[0].store(m2_arr);
int64_t m0 = 0;
math_t m1 = 0;
@ -182,7 +182,7 @@ std::pair<opmath_t<T>, opmath_t<T>> RowwiseMomentsImpl(const T* X, int64_t N, in
template <typename T>
std::pair<opmath_t<T>, opmath_t<T>> RowwiseMoments(const T* X, int64_t N, int64_t ddof = 0) {
using Vec = vec::Vectorized<T>;
constexpr int64_t kVecSize = Vec::size();
const int64_t kVecSize = Vec::size();
const int64_t n = N / kVecSize;
const int64_t m = divup(n, kChunkSize);
const int64_t depth = utils::CeilLog2(m);

View File

@ -165,6 +165,7 @@ REGISTER_AVX2_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_co
REGISTER_AVX512_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_ZVECTOR_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_VSX_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_SVE_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_SVE256_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
// _out variants can be shared between PocketFFT and MKL

View File

@ -142,7 +142,7 @@ Tensor qcat_nhwc_kernel(
continue;
}
constexpr auto VLEN = Vec::size();
const auto VLEN = Vec::size();
int64_t c = 0;
// Vectorized loop
@ -170,16 +170,16 @@ Tensor qcat_nhwc_kernel(
}
// Vectorized loop for channel between 8 and 32 (avx2)
constexpr auto kVLEN = Vectorized<float>::size();
const auto kVLEN = Vectorized<float>::size();
int64_t elem_size = curr_C - c;
if ((VLEN == 4 * kVLEN) && elem_size >= kVLEN) {
auto curr_scale_vec = Vectorized<float>(curr_scale);
auto curr_zero_pt_vec = Vectorized<float>((float)curr_zero_pt);
auto scale_neg_zp_premul = curr_scale_vec * curr_zero_pt_vec.neg();
int64_t vec_num = elem_size / kVLEN;
std::array<typename scalar_t::underlying, VLEN> buf_in{};
memcpy(buf_in.data(), iptr + c, vec_num * kVLEN);
auto inp_vec = Vec::loadu(buf_in.data());
typename scalar_t::underlying buf_in[VLEN] = {};
memcpy(buf_in, iptr + c, vec_num * kVLEN);
auto inp_vec = Vec::loadu(buf_in);
auto float_values = inp_vec.dequantize(
curr_scale_vec, curr_zero_pt_vec, scale_neg_zp_premul);
Vec::float_vec_return_type retvals;
@ -1487,7 +1487,7 @@ void _qmaxpool_2d_nhwc_kernel(
int64_t c = 0;
// Interleaved vector loop 4x
constexpr auto vec_width = Vectorized<scalar_t>::size();
const auto vec_width = Vectorized<scalar_t>::size();
for (; c + 4 * vec_width <= iC; c += 4 * vec_width) {
Vectorized<scalar_t> acc{
scalar_t(std::numeric_limits<scalar_t_underlying>::lowest())};
@ -1623,7 +1623,7 @@ void qmaxpool_3d_nthwc_kernel(
w_start += dW;
int64_t c = 0;
constexpr auto vec_width = Vectorized<scalar_t>::size();
const auto vec_width = Vectorized<scalar_t>::size();
// Vector loop
for (; c + vec_width <= iC; c += vec_width) {
Vectorized<scalar_t> acc{
@ -2449,7 +2449,7 @@ void q_batch_norm_kernel(
reinterpret_cast<scalar_t::underlying*>(input.data_ptr());
scalar_t::underlying* Y = reinterpret_cast<scalar_t::underlying*>(output.data_ptr());
constexpr int kVLen = Vectorized<float>::size();
const int kVLen = Vectorized<float>::size();
const int64_t outer_size = N * HxW;
using Vec = Vectorized<scalar_t>;
// Hoisted variables
@ -2975,7 +2975,7 @@ void quantized_normalize_kernel(
float y_scale = Y->q_scale();
float y_inv_scale = 1.0f / y_scale;
constexpr int kFloatVLen = fVec::size();
const int kFloatVLen = fVec::size();
int64_t kIntVLen = kFloatVLen * qVec::float_num_vecs();
int64_t kNumIntVecInLayer = N / kIntVLen;
int64_t kNonVecRemInLayer = N % kIntVLen;
@ -3263,7 +3263,7 @@ void quantized_groupnorm_nhwc_kernel(
float y_scale = Y->q_scale();
float y_inv_scale = 1.0f / y_scale;
constexpr int kFloatVLen = fVec::size();
const int kFloatVLen = fVec::size();
int64_t kIntVLen = kFloatVLen * qVec::float_num_vecs();
int64_t channels_per_group = C / G;
int64_t HxW = N / channels_per_group;

View File

@ -27,6 +27,7 @@ REGISTER_AVX512_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
REGISTER_AVX2_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
REGISTER_VSX_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
REGISTER_ZVECTOR_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
REGISTER_SVE_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
REGISTER_SVE256_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
} // namespace at::native

View File

@ -161,6 +161,7 @@ REGISTER_AVX512_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_
REGISTER_AVX2_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
REGISTER_VSX_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
REGISTER_ZVECTOR_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
REGISTER_SVE_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
REGISTER_SVE256_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
REGISTER_ARCH_DISPATCH(sparse_mask_intersection_out_stub, DEFAULT, &sparse_mask_intersection_out_cpu_kernel)
@ -168,6 +169,7 @@ REGISTER_AVX512_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_interse
REGISTER_AVX2_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
REGISTER_VSX_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
REGISTER_ZVECTOR_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
REGISTER_SVE_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
REGISTER_SVE256_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
REGISTER_ARCH_DISPATCH(sparse_mask_projection_out_stub, DEFAULT, &sparse_mask_projection_out_cpu_kernel)
@ -175,5 +177,6 @@ REGISTER_AVX512_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projectio
REGISTER_AVX2_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
REGISTER_VSX_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
REGISTER_ZVECTOR_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
REGISTER_SVE_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
REGISTER_SVE256_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
}

View File

@ -448,6 +448,7 @@ REGISTER_AVX2_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_AVX512_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_VSX_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_ZVECTOR_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_SVE_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_SVE256_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_HPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_meta)

View File

@ -134,7 +134,7 @@ namespace {
TYPED_TEST(Memory, UnAlignedLoadStore) {
using vec = TypeParam;
using VT = ValueType<TypeParam>;
constexpr size_t b_size = vec::size() * sizeof(VT);
const size_t b_size = vec::size() * sizeof(VT);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
CACHE_ALIGN unsigned char ref_storage[128 * b_size];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
@ -164,7 +164,7 @@ namespace {
for (size_t offset = 0; offset < b_size; offset += 1) {
unsigned char* p1 = ref_storage + offset;
unsigned char* p2 = storage + offset;
for (; p1 + b_size <= std::end(ref_storage); p1 += b_size, p2 += b_size) {
for (; p1 + b_size <= &ref_storage[128 * b_size]; p1 += b_size, p2 += b_size) {
vec v = vec::loadu(p1);
v.store(p2);
}
@ -381,7 +381,7 @@ namespace {
TYPED_TEST(Hyperbolic, Tanh) {
using vec = TypeParam;
// NOTE: Because SVE uses ACL logic, the precision changes, hence the adjusted tolerance.
#if defined(CPU_CAPABILITY_SVE)
#if defined(CPU_CAPABILITY_SVE) || defined(CPU_CAPABILITY_SVE256)
using UVT = UvalueType<vec>;
UVT tolerance = getDefaultTolerance<UVT>();
test_unary<vec>(
@ -586,7 +586,7 @@ namespace {
}
}
}
#if defined(CPU_CAPABILITY_SVE) && defined(__ARM_FEATURE_BF16)
#if (defined(CPU_CAPABILITY_SVE256)) && defined(__ARM_FEATURE_BF16)
TEST(NanBfloat16, IsNan) {
for (unsigned int ii = 0; ii < 0xFFFF; ++ii) {
c10::BFloat16 val(ii, c10::BFloat16::from_bits());
@ -598,6 +598,19 @@ namespace {
}
}
}
#endif
#if (defined(CPU_CAPABILITY_SVE)) && defined(__ARM_FEATURE_BF16)
TEST(NanBfloat16, IsNan) {
for (unsigned int ii = 0; ii < 0xFFFF; ++ii) {
c10::BFloat16 val(ii, c10::BFloat16::from_bits());
bool expected = std::isnan(val);
CACHE_ALIGN c10::BFloat16 actual_vals[at::vec::SVE::Vectorized<c10::BFloat16>::size()];
at::vec::SVE::Vectorized<c10::BFloat16>(val).isnan().store(actual_vals);
for (int jj = 0; jj < at::vec::SVE::Vectorized<c10::BFloat16>::size(); ++jj) {
EXPECT_EQ(expected, c10::bit_cast<uint16_t>(actual_vals[jj]) != 0) << "bf16 isnan failure for bit pattern " << std::hex << ii << std::dec;
}
}
}
#endif
TYPED_TEST(LGamma, LGamma) {
using vec = TypeParam;
@ -653,7 +666,7 @@ namespace {
TYPED_TEST(Interleave, Interleave) {
using vec = TypeParam;
using VT = ValueType<TypeParam>;
constexpr auto N = vec::size() * 2LL;
const auto N = vec::size() * 2LL;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
CACHE_ALIGN VT vals[N];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
@ -663,7 +676,7 @@ namespace {
for (VT& v : vals) {
v = generator.get();
}
copy_interleave(vals, interleaved);
copy_interleave<VT>(vals, interleaved, N);
auto a = vec::loadu(vals);
auto b = vec::loadu(vals + vec::size());
auto cc = interleave2(a, b);
@ -673,7 +686,7 @@ namespace {
TYPED_TEST(Interleave, DeInterleave) {
using vec = TypeParam;
using VT = ValueType<TypeParam>;
constexpr auto N = vec::size() * 2LL;
const auto N = vec::size() * 2LL;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
CACHE_ALIGN VT vals[N];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
@ -683,7 +696,7 @@ namespace {
for (VT& v : vals) {
v = generator.get();
}
copy_interleave(vals, interleaved);
copy_interleave<VT>(vals, interleaved, N);
// test interleaved with vals this time
auto a = vec::loadu(interleaved);
auto b = vec::loadu(interleaved + vec::size());
@ -1017,78 +1030,70 @@ namespace {
RESOLVE_OVERLOAD(filter_fmadd));
}
#endif
template<typename vec, typename VT, int64_t mask>
typename std::enable_if_t<(mask < 0 || mask> 255), void>
template<typename vec, typename VT>
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
test_blend(VT expected_val[vec::size()], VT a[vec::size()], VT b[vec::size()])
{
void test_blend(VT * expected_val, VT * a, VT * b, int64_t mask) {
if (mask >= 0 && mask <= 255) {
// generate expected_val
int64_t m = mask;
for (int64_t i = 0; i < vec::size(); i++) {
expected_val[i] = (m & 0x01) ? b[i] : a[i];
m = m >> 1;
}
// test with blend
auto vec_a = vec::loadu(a);
auto vec_b = vec::loadu(b);
auto expected = vec::loadu(expected_val);
auto actual = vec::blend(vec_a, vec_b, mask);
auto mask_str = std::string("\nblend mask: ") + std::to_string(mask);
if (AssertVectorized<vec>(std::string(NAME_INFO(test_blend)) + mask_str, expected, actual).check()) return;
test_blend<vec, VT>(expected_val, a, b, mask - 1);
}
}
template<typename vec, typename VT, int64_t mask>
typename std::enable_if_t<(mask >= 0 && mask <= 255), void>
template<typename vec, typename VT>
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
test_blend(VT expected_val[vec::size()], VT a[vec::size()], VT b[vec::size()]) {
// generate expected_val
int64_t m = mask;
for (int64_t i = 0; i < vec::size(); i++) {
expected_val[i] = (m & 0x01) ? b[i] : a[i];
m = m >> 1;
}
// test with blend
auto vec_a = vec::loadu(a);
auto vec_b = vec::loadu(b);
auto expected = vec::loadu(expected_val);
auto actual = vec::template blend<mask>(vec_a, vec_b);
auto mask_str = std::string("\nblend mask: ") + std::to_string(mask);
if (AssertVectorized<vec>(std::string(NAME_INFO(test_blend)) + mask_str, expected, actual).check()) return;
test_blend<vec, VT, mask - 1>(expected_val, a, b);
bool test_blendv(VT * expected_val, VT * a, VT * b, VT * mask, int64_t idx, size_t N) {
if ((size_t) idx == N) {
using bit_rep = BitType<VT>;
// generate expected_val
for (int64_t i = 0; i < vec::size(); i++) {
bit_rep hex_mask = 0;
hex_mask=c10::bit_cast<bit_rep>(mask[i]);
expected_val[i] = (hex_mask & 0x01) ? b[i] : a[i];
}
// test with blendv
auto vec_a = vec::loadu(a);
auto vec_b = vec::loadu(b);
auto vec_m = vec::loadu(mask);
auto expected = vec::loadu(expected_val);
auto actual = vec::blendv(vec_a, vec_b, vec_m);
auto mask_str = std::string("\nblendv mask: ");
for (int64_t i = 0; i < vec::size(); i++) {
mask_str += std::to_string(mask[i]) + " ";
}
if (AssertVectorized<vec>(std::string(NAME_INFO(test_blendv)) + mask_str, expected, actual).check()) {
return false;
}
return true;
} else {
// shuffle mask and do blendv test
VT m = mask[idx];
if (!test_blendv<vec, VT>(expected_val, a, b, mask, idx+1, N)) return false;
if (m != (VT)0) {
mask[idx] = (VT)0;
}
else {
uint64_t hex_mask = 0xFFFFFFFFFFFFFFFF;
std::memcpy(&mask[idx], &hex_mask, sizeof(VT));
}
if (!test_blendv<vec, VT>(expected_val, a, b, mask, idx+1, N)) return false;
mask[idx] = m;
return true;
}
}
template<typename vec, typename VT, int64_t idx, int64_t N>
std::enable_if_t<(!is_complex<VT>::value && idx == N), bool>
template<typename T>
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
test_blendv(VT expected_val[vec::size()], VT a[vec::size()], VT b[vec::size()], VT mask[vec::size()]) {
using bit_rep = BitType<VT>;
// generate expected_val
for (int64_t i = 0; i < vec::size(); i++) {
bit_rep hex_mask = 0;
hex_mask=c10::bit_cast<bit_rep>(mask[i]);
expected_val[i] = (hex_mask & 0x01) ? b[i] : a[i];
}
// test with blendv
auto vec_a = vec::loadu(a);
auto vec_b = vec::loadu(b);
auto vec_m = vec::loadu(mask);
auto expected = vec::loadu(expected_val);
auto actual = vec::blendv(vec_a, vec_b, vec_m);
auto mask_str = std::string("\nblendv mask: ");
for (int64_t i = 0; i < vec::size(); i++) {
mask_str += std::to_string(mask[i]) + " ";
}
if (AssertVectorized<vec>(std::string(NAME_INFO(test_blendv)) + mask_str, expected, actual).check()) {
return false;
}
return true;
}
template<typename vec, typename VT, int64_t idx, int64_t N>
std::enable_if_t<(!is_complex<VT>::value && idx != N), bool>
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
test_blendv(VT expected_val[vec::size()], VT a[vec::size()], VT b[vec::size()], VT mask[vec::size()]) {
// shuffle mask and do blendv test
VT m = mask[idx];
if (!test_blendv<vec, VT, idx+1, N>(expected_val, a, b, mask)) return false;
if (m != (VT)0) {
mask[idx] = (VT)0;
}
else {
uint64_t hex_mask = 0xFFFFFFFFFFFFFFFF;
std::memcpy(&mask[idx], &hex_mask, sizeof(VT));
}
if (!test_blendv<vec, VT, idx+1, N>(expected_val, a, b, mask)) return false;
mask[idx] = m;
return true;
}
template<typename T, int N>
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
void blend_init(T(&a)[N], T(&b)[N]) {
void blend_init(T * a, T * b, int N) {
a[0] = (T)1.0;
b[0] = a[0] + (T)N;
for (const auto i : c10::irange(1, N)) {
@ -1107,8 +1112,8 @@ namespace {
CACHE_ALIGN VT mask[vec::size()] = {0};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
CACHE_ALIGN VT expected_val[vec::size()];
blend_init(a, b);
test_blendv<vec, VT, 0, vec::size()>(expected_val, a, b, mask);
blend_init(a, b, vec::size());
test_blendv<vec, VT>(expected_val, a, b, mask, 0, vec::size());
}
TYPED_TEST(BitwiseFloatsAdditional2, Blend) {
using vec = TypeParam;
@ -1119,9 +1124,9 @@ namespace {
CACHE_ALIGN VT b[vec::size()];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
CACHE_ALIGN VT expected_val[vec::size()];
blend_init(a, b);
constexpr int64_t power_sets = 1LL << (vec::size());
test_blend<vec, VT, power_sets - 1>(expected_val, a, b);
blend_init(a, b, vec::size());
const int64_t power_sets = 1LL << (vec::size());
test_blend<vec, VT>(expected_val, a, b, power_sets - 1);
}
template<typename vec, typename VT>
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
@ -1152,7 +1157,7 @@ namespace {
CACHE_ALIGN VT b[vec::size()];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
CACHE_ALIGN VT expected_val[vec::size()];
blend_init(a, b);
blend_init(a, b, vec::size());
test_set<vec, VT>(expected_val, a, b, vec::size());
}
template<typename T>
@ -1218,7 +1223,7 @@ namespace {
// NOLINTNEXTLINE(bugprone-signed-char-misuse)
constexpr int min_val = std::numeric_limits<underlying>::min();
constexpr int max_val = std::numeric_limits<underlying>::max();
constexpr int el_count = vfloat::size();
const int el_count = vfloat::size();
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
CACHE_ALIGN float unit_float_vec[el_count];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
@ -1566,7 +1571,7 @@ namespace {
using vec = TypeParam;
using VT = ValueType<TypeParam>;
constexpr auto R = 2LL; // residual
constexpr auto N = vec::size() + R;
const auto N = vec::size() + R;
CACHE_ALIGN VT x1[N];
CACHE_ALIGN VT x2[N];
CACHE_ALIGN VT x3[N];
@ -2130,7 +2135,7 @@ namespace {
ASSERT_TRUE(vec_pinf.has_inf_nan()) << "Test failed for positive Infinity\n";
ASSERT_TRUE(vec_ninf.has_inf_nan()) << "Test failed for negative Infinity\n";
}
#if !defined(CPU_CAPABILITY_SVE)
#if !defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE)
template <typename vec, typename dst_t>
void test_convert_to(const char* dst_t_name) {
using src_t = ValueType<vec>;
@ -2213,13 +2218,13 @@ namespace {
TYPED_TEST(VecMaskTests, MaskedLoad) {
using vec = TypeParam;
using src_t = ValueType<TypeParam>;
constexpr auto size = vec::size();
const auto size = vec::size();
#define TEST_MASK_LOAD(dst_t, mask_t, mask_n) \
do { \
constexpr int dst_size = at::vec::Vectorized<dst_t>::size(); \
constexpr int dst_n = mask_n * size / dst_size; \
if constexpr(dst_n * dst_size >= mask_n * size) { \
int dst_size = at::vec::Vectorized<dst_t>::size(); \
int dst_n = mask_n * size / dst_size; \
if (dst_n * dst_size >= mask_n * size) { \
CACHE_ALIGN dst_t x[mask_n * size]; \
CACHE_ALIGN dst_t y[mask_n * size]; \
CACHE_ALIGN dst_t ref[mask_n * size]; \
@ -2230,9 +2235,47 @@ namespace {
x[i] = generator.get(); \
} \
auto vec_mask = generate_vec_mask<mask_t, mask_n>(seed); \
constexpr int rnd_n = (mask_n * size + dst_size - 1) / dst_size;\
auto x_vec = vec_mask.template loadu<dst_t, rnd_n>(x); \
x_vec.store(y); \
int rnd_n = (mask_n * size + dst_size - 1) / dst_size;\
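/* loadu's element count is a template argument and must be a compile-time constant, \
   so dispatch the runtime rnd_n over the cases this test can produce. */ \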
switch (rnd_n) { \
case 1: \
{ \
auto x_vec = vec_mask.template loadu<dst_t, 1>(x); \
x_vec.store(y); \
break; \
} \
case 2: \
{ \
auto x_vec = vec_mask.template loadu<dst_t, 2>(x); \
x_vec.store(y); \
break; \
} \
case 3: \
{ \
auto x_vec = vec_mask.template loadu<dst_t, 3>(x); \
x_vec.store(y); \
break; \
} \
case 4: \
{ \
auto x_vec = vec_mask.template loadu<dst_t, 4>(x); \
x_vec.store(y); \
break; \
} \
case 8: \
{ \
auto x_vec = vec_mask.template loadu<dst_t, 8>(x); \
x_vec.store(y); \
break; \
} \
case 16: \
{ \
auto x_vec = vec_mask.template loadu<dst_t, 16>(x); \
x_vec.store(y); \
break; \
} \
default: \
throw std::out_of_range("Unexpected rnd_n call to vec_mask"); \
} \
for (const auto i : c10::irange(mask_n * size)) { \
if (vec_mask.is_masked(i)) { \
ref[i] = x[i]; \
@ -2269,7 +2312,7 @@ namespace {
#undef TEST_MASK_LOAD
#undef TEST_MASK_LOAD_N
}
#if !defined(CPU_CAPABILITY_SVE)
#if !defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE)
TYPED_TEST(VecMaskTests, MaskedCheck) {
using VT = ValueType<TypeParam>;
using vec = TypeParam;
@ -2294,7 +2337,7 @@ namespace {
#undef TEST_MASK_CHECK_N
}
#endif
#if !defined(CPU_CAPABILITY_SVE)
#if !defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE)
TYPED_TEST(VecMaskTests, ToFrom) {
using vec = TypeParam;
using VT = ValueType<TypeParam>;
@ -2321,7 +2364,7 @@ namespace {
}
}
#endif
#if !defined(CPU_CAPABILITY_SVE)
#if !defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE)
TYPED_TEST(VecMaskTests, Cast) {
using vec = TypeParam;
using src_t = ValueType<TypeParam>;

View File

@ -56,7 +56,7 @@ CACHE_ALIGN #define
defined(CPU_CAPABILITY_AVX512) && (defined(__GNUC__) || defined(__GNUG__))
#undef CHECK_DEQUANT_WITH_LOW_PRECISION
#define CHECK_WITH_FMA 1
#elif defined(CPU_CAPABILITY_SVE)
#elif defined(CPU_CAPABILITY_SVE256)
#define CHECK_DEQUANT_WITH_LOW_PRECISION 1
#define CHECK_WITH_FMA 1
#elif !defined(CPU_CAPABILITY_VSX) && !defined(CPU_CAPABILITY_AVX2)
@ -136,7 +136,7 @@ template<typename T>
struct VecTypeHelper {
using holdType = typename T::value_type;
using memStorageType = typename T::value_type;
static constexpr int holdCount = T::size();
static inline int holdCount = T::size();
static constexpr int unitStorageCount = 1;
};
@ -399,9 +399,9 @@ T clamp_min(const T& a, const T& min) {
return a < min ? min : a;
}
template <class VT, size_t N>
void copy_interleave(VT(&vals)[N], VT(&interleaved)[N]) {
static_assert(N % 2 == 0, "should be even");
template <class VT>
void copy_interleave(VT * vals, VT * interleaved, size_t N) {
assert(N % 2 == 0);
auto ptr1 = vals;
auto ptr2 = vals + N / 2;
for (size_t i = 0; i < N; i += 2) {
@ -871,10 +871,10 @@ public:
using UVT = UvalueType<T>;
using BVT = BitType<UVT>;
UVT absErr = correctEpsilon(toleranceEps);
constexpr int sizeX = VecTypeHelper<T>::holdCount * VecTypeHelper<T>::unitStorageCount;
const int sizeX = VecTypeHelper<T>::holdCount * VecTypeHelper<T>::unitStorageCount;
constexpr int unitStorageCount = VecTypeHelper<T>::unitStorageCount;
CACHE_ALIGN UVT expArr[sizeX];
CACHE_ALIGN UVT actArr[sizeX];
UVT expArr[sizeX];
UVT actArr[sizeX];
exp.store(expArr);
act.store(actArr);
if (bitwise)
@ -942,7 +942,7 @@ void test_unary(
using vec_type = T;
using VT = ValueType<T>;
using UVT = UvalueType<T>;
constexpr int el_count = vec_type::size();
const int el_count = vec_type::size();
CACHE_ALIGN VT vals[el_count];
CACHE_ALIGN VT expected[el_count];
bool bitwise = testCase.isBitwise();
@ -1000,7 +1000,7 @@ void test_binary(
using vec_type = T;
using VT = ValueType<T>;
using UVT = UvalueType<T>;
constexpr int el_count = vec_type::size();
const int el_count = vec_type::size();
CACHE_ALIGN VT vals0[el_count];
CACHE_ALIGN VT vals1[el_count];
CACHE_ALIGN VT expected[el_count];
@ -1163,7 +1163,7 @@ void test_ternary(
using vec_type = T;
using VT = ValueType<T>;
using UVT = UvalueType<T>;
constexpr int el_count = vec_type::size();
const int el_count = vec_type::size();
CACHE_ALIGN VT vals0[el_count];
CACHE_ALIGN VT vals1[el_count];
CACHE_ALIGN VT vals2[el_count];
@ -1203,12 +1203,15 @@ void test_ternary(
auto input1 = vec_type::loadu(vals1);
auto input2 = vec_type::loadu(vals2);
auto actual = actualFunction(input0, input1, input2);
CACHE_ALIGN VT actual_[vec_type::size()];
actual.store(actual_);
auto vec_expected = vec_type::loadu(expected);
AssertVectorized<vec_type> vecAssert(
testNameInfo, seed, vec_expected, actual, input0, input1, input2);
if (vecAssert.check(
bitwise, dmn.CheckWithTolerance, dmn.ToleranceError))
return;
return;
} // trial
changeSeedBy += 1;
}
@ -1573,19 +1576,19 @@ double getDefaultTolerance() {
template<typename T, int N = 1>
at::vec::VecMask<T, N> create_vec_mask(uint64_t bitmask) {
constexpr auto size = at::vec::Vectorized<T>::size();
std::array<int, N * size> mask;
const auto size = at::vec::Vectorized<T>::size();
int mask[N * size];
for (int n = 0; n < N; n++) {
for (int i = 0; i < size; i++) {
mask[n * size + i] = (bitmask >> i) & 1;
}
}
return at::vec::VecMask<T, N>::from(mask.data());
return at::vec::VecMask<T, N>::from(mask);
}
template<typename T, int N = 1>
at::vec::VecMask<T, N> generate_vec_mask(int seed) {
constexpr auto size = at::vec::Vectorized<T>::size();
const auto size = at::vec::Vectorized<T>::size();
ValueGen<uint64_t> generator(0, (1ULL << size) - 1, seed);
auto bitmask = generator.get();
return create_vec_mask<T, N>(bitmask);

View File

@ -393,15 +393,21 @@ if(INTERN_BUILD_ATEN_OPS)
LIST(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} ${CXX_ZVECTOR_FLAGS}")
endif(CXX_ZVECTOR_FOUND)
if(CXX_SVE_FOUND AND CXX_SVE256_FOUND AND CXX_ARM_BF16_FOUND)
if(CXX_SVE_FOUND AND CXX_ARM_BF16_FOUND)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_SVE_CPU_DEFINITION -DHAVE_ARM_BF16_CPU_DEFINITION")
list(APPEND CPU_CAPABILITY_NAMES "SVE256")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_SVE_CPU_DEFINITION -DHAVE_SVE256_CPU_DEFINITION -DHAVE_ARM_BF16_CPU_DEFINITION")
if("${CMAKE_C_COMPILER_ID}" MATCHES "Clang")
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -O2 -march=armv8-a+sve+bf16 -D__ARM_FEATURE_BF16 -DCPU_CAPABILITY_SVE -msve-vector-bits=256")
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -O2 -march=armv8.2-a+sve+bf16 -msve-vector-bits=256 -D__ARM_FEATURE_BF16")
else()
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -march=armv8-a+sve+bf16 -D__ARM_FEATURE_BF16 -DCPU_CAPABILITY_SVE -msve-vector-bits=256")
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -march=armv8.2-a+sve+bf16 -msve-vector-bits=256 -D__ARM_FEATURE_BF16")
endif()
endif()
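# Also build a generic "SVE" variant; for now it is compiled with a fixed 128-bit vector width.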
list(APPEND CPU_CAPABILITY_NAMES "SVE")
if("${CMAKE_C_COMPILER_ID}" MATCHES "Clang")
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -O2 -march=armv8.2-a+sve+bf16 -msve-vector-bits=128 -D__ARM_FEATURE_BF16")
else()
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -march=armv8.2-a+sve+bf16 -msve-vector-bits=128 -D__ARM_FEATURE_BF16")
endif()
endif(CXX_SVE_FOUND)
list(LENGTH CPU_CAPABILITY_NAMES NUM_CPU_CAPABILITY_NAMES)
math(EXPR NUM_CPU_CAPABILITY_NAMES "${NUM_CPU_CAPABILITY_NAMES}-1")

View File

@ -152,19 +152,7 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
ENDMACRO()
# Check for SVE and BF16 support
CHECK_COMPILES(CXX "SVE256" "-march=armv8.2-a+sve -msve-vector-bits=256" "${SVE_CODE}")
CHECK_COMPILES(CXX "ARM_BF16" "-march=armv8.2-a+sve+bf16 -msve-vector-bits=256" "${ARM_BF16_CODE}")
CHECK_COMPILES(CXX "SVE" "-march=armv8.2-a+sve" "${SVE_CODE}")
CHECK_COMPILES(CXX "ARM_BF16" "-march=armv8.2-a+sve+bf16" "${ARM_BF16_CODE}")
# If SVE256 support is not found, set CXX_SVE_FOUND to FALSE and notify the user
if(NOT CXX_SVE256_FOUND)
set(CXX_SVE_FOUND FALSE CACHE BOOL "SVE not available on host")
message(STATUS "No SVE processor on this machine.")
else()
# If SVE256 support is found, set CXX_SVE_FOUND to TRUE and notify the user
set(CXX_SVE_FOUND TRUE CACHE BOOL "SVE available on host")
message(STATUS "SVE support detected.")
endif()
# Mark the SVE support variable as advanced
mark_as_advanced(CXX_SVE_FOUND)
ENDIF(CMAKE_SYSTEM_NAME MATCHES "Linux")

View File

@ -34,6 +34,7 @@ These backends include:
```{eval-rst}
.. autofunction:: torch.backends.cpu.get_cpu_capability
.. autofunction:: torch.backends.cpu.get_sve_len
```
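A minimal usage sketch (illustrative only; the reported values depend on the host CPU, and `get_sve_len` is assumed to report 0 on machines without SVE):

```python
import torch

# e.g. "SVE256", "SVE", "AVX2", "NO AVX", ...
print(torch.backends.cpu.get_cpu_capability())

# maximum SVE vector length in bits reported by the hardware
print(torch.backends.cpu.get_sve_len())
```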
## torch.backends.cuda

View File

@ -1166,6 +1166,7 @@ def _show_config() -> str: ... # THPModule_showConfig
def _cxx_flags() -> str: ... # THPModule_cxxFlags
def _parallel_info() -> str: ... # THPModule_parallelInfo
def _get_cpu_capability() -> str: ... # THPModule_getCpuCapability
def _get_sve_len() -> _int: ... # THPModule_getSveLen
def _set_backcompat_broadcast_warn(
arg: _bool,
) -> None: ... # THPModule_setBackcompatBroadcastWarn

View File

@ -656,6 +656,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._C._get_constant_bool_symnode",
"torch._C._get_cpp_backtrace",
"torch._C._get_cpu_capability",
"torch._C._get_sve_len",
"torch._C._get_cublas_allow_bf16_reduced_precision_reduction",
"torch._C._get_cublas_allow_fp16_reduced_precision_reduction",
"torch._C._get_cublas_allow_tf32",

View File

@ -173,7 +173,7 @@ class VecSVE256(VecISA):
# this function can be repurposed for SVE with variable vec length
_bit_width = 256
_macro = [
"CPU_CAPABILITY_SVE",
"HAVE_SVE_CPU_DEFINITION",
"CPU_CAPABILITY_SVE256",
"AT_BUILD_ARM_VEC256_WITH_SLEEF",
"__ARM_FEATURE_BF16",
@ -189,6 +189,25 @@ class VecSVE256(VecISA):
__hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment]
@dataclasses.dataclass
class VecSVE(VecISA):
# this class can be repurposed for SVE with a variable vector length
# _bit_width = torch.backends.cpu.get_sve_len()  # disabled for now as it is not yet working
_bit_width = 128
_macro = [
"HAVE_SVE_CPU_DEFINITION",
"CPU_CAPABILITY_SVE",
"AT_BUILD_ARM_VEC256_WITH_SLEEF",
"__ARM_FEATURE_BF16",
]
_arch_flags = "-march=armv8-a+sve+bf16 -msve-vector-bits=128"
_dtype_nelements = {
    torch.float: _bit_width // 32,
    torch.bfloat16: _bit_width // 16,
    torch.float16: _bit_width // 16,
}
def __str__(self) -> str:
return "asimd"
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
@dataclasses.dataclass
class VecAVX512(VecISA):
@ -404,13 +423,7 @@ def x86_isa_checker() -> list[str]:
invalid_vec_isa = InvalidVecISA()
supported_vec_isa_list = [
VecAMX(),
VecAVX512(),
VecAVX2(),
VecNEON(),
VecSVE256(),
]
supported_vec_isa_list = [
    VecAMX(),
    VecAVX512(),
    VecAVX2(),
    VecNEON(),
    VecSVE256(),
    VecSVE(),
]
def get_isa_from_cpu_capability(
@ -473,6 +486,8 @@ def valid_vec_isa_list() -> list[VecISA]:
elif arch == "aarch64":
if torch.backends.cpu.get_cpu_capability() == "SVE256":
isa_list.append(VecSVE256())
elif torch.backends.cpu.get_cpu_capability() == "SVE":
isa_list.append(VecSVE())
else:
isa_list.append(VecNEON())
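# Illustrative only (not part of this change): on aarch64 the capability string
# decides between VecSVE256 ("SVE256"), the new VecSVE ("SVE"), and the VecNEON
# fallback. One way to inspect the selection on a given host:
#
#   from torch._inductor import cpu_vec_isa
#   print(torch.backends.cpu.get_cpu_capability())              # e.g. "SVE"
#   print([str(isa) for isa in cpu_vec_isa.valid_vec_isa_list()])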

View File

@ -3,6 +3,7 @@ import torch
__all__ = [
"get_cpu_capability",
"get_sve_len"
]
@ -16,6 +17,13 @@ def get_cpu_capability() -> str:
- "NO AVX"
- "AVX2"
- "AVX512"
- "SVE"
- "SVE256"
"""
return torch._C._get_cpu_capability()
def get_sve_len() -> int:
    r"""Return the maximum supported SVE vector length in bits."""
    return torch._C._get_sve_len()

View File

@ -585,6 +585,14 @@ static PyObject* THPModule_getCpuCapability(
END_HANDLE_TH_ERRORS
}
static PyObject* THPModule_getSveLen(
PyObject* module,
PyObject* noargs) {
HANDLE_TH_ERRORS
return THPUtils_packInt32(at::get_sve_len());
END_HANDLE_TH_ERRORS
}
namespace {
template <class T>
@ -1584,6 +1592,7 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
{"_cxx_flags", THPModule_cxxFlags, METH_NOARGS, nullptr},
{"_parallel_info", THPModule_parallelInfo, METH_NOARGS, nullptr},
{"_get_cpu_capability", THPModule_getCpuCapability, METH_NOARGS, nullptr},
{"_get_sve_len", THPModule_getSveLen, METH_NOARGS, nullptr},
{"_set_backcompat_broadcast_warn",
THPModule_setBackcompatBroadcastWarn,
METH_O,

View File

@ -32,9 +32,7 @@
#include <c10/util/irange.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || \
defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || \
defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_SVE256)
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) ||    \
    defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) ||   \
    defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_SVE256) ||     \
    defined(CPU_CAPABILITY_SVE)
#define INDUCTOR_USE_VECTOR_TYPES() 1
#else
#define INDUCTOR_USE_VECTOR_TYPES() 0

View File

@ -462,6 +462,10 @@ RegisterOperators reg({
"aten::_get_cpu_capability() -> str",
[](Stack& stack) { push(stack, at::get_cpu_capability()); },
aliasAnalysisConservative()),
Operator(
"aten::_get_sve_len() -> int",
[](Stack& stack) { push(stack, at::get_sve_len()); },
aliasAnalysisConservative()),
});
} // namespace
} // namespace torch::jit

View File

@ -113,6 +113,7 @@ _builtin_ops = [
(torch.nn.init._no_grad_zero_, "aten::_no_grad_zero_"),
(torch._C._get_tracing_state, "aten::_get_tracing_state"),
(torch._C._get_cpu_capability, "aten::_get_cpu_capability"),
(torch._C._get_sve_len, "aten::_get_sve_len"),
(warnings.warn, "aten::warn"),
(torch._VF.stft, "aten::stft"), # type: ignore[attr-defined]
(torch._VF.istft, "aten::istft"), # type: ignore[attr-defined]