mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-11 22:34:53 +08:00

Compare commits (9 commits): ciflow/tru... sve-poc

| SHA1 |
|---|
| 8ac81dba21 |
| dae9a71d99 |
| bf4b0e8c41 |
| 0384f48daa |
| 3b92a1adfe |
| 6ca9dc026d |
| a499828924 |
| e84eabd4f9 |
| 5e53e458b9 |
@@ -10,6 +10,10 @@
#include <ideep.hpp>
#endif

#if !defined(__s390x__) && !defined(__powerpc__)
#include <cpuinfo.h>
#endif

#include <caffe2/core/common.h>

#include <ATen/native/DispatchStub.h>
@@ -103,7 +107,9 @@ std::string get_cpu_capability() {
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
case native::CPUCapability::ZVECTOR:
return "Z VECTOR";
#elif defined(HAVE_SVE256_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION)
#elif defined(HAVE_SVE_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION)
case native::CPUCapability::SVE:
return "SVE";
case native::CPUCapability::SVE256:
return "SVE256";
#else
@@ -118,6 +124,12 @@ std::string get_cpu_capability() {
return "";
}

int get_sve_len() {
// It is possible that we override the cpu_capability with
// environment variable
return cpuinfo_get_max_arm_sve_length();
}

static std::string used_cpu_capability() {
// It is possible that we override the cpu_capability with
// environment variable

@@ -15,4 +15,6 @@ TORCH_API std::string get_cxx_flags();

TORCH_API std::string get_cpu_capability();

TORCH_API int get_sve_len();

} // namespace at

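The hunks above add a `get_sve_len()` entry point alongside `get_cpu_capability()`. A minimal usage sketch, not taken from the diff itself (the header path and the exact unit of the returned length follow cpuinfo's `cpuinfo_get_max_arm_sve_length()` and are assumptions here):

```cpp
#include <iostream>
#include <string>
#include <ATen/Version.h>  // assumed header for get_cpu_capability()/get_sve_len()

int main() {
  // Report which CPU dispatch level ATen selected at runtime, and on SVE
  // machines also the vector length that cpuinfo reports via get_sve_len().
  std::string cap = at::get_cpu_capability();
  std::cout << "cpu capability: " << cap << "\n";
  if (cap == "SVE" || cap == "SVE256") {
    std::cout << "sve length: " << at::get_sve_len() << "\n";
  }
  return 0;
}
```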
@@ -34,9 +34,9 @@ inline scalar_t vec_reduce_all(
scalar_t acc_arr[Vec::size()];
acc_vec.store(acc_arr);
for (const auto i : c10::irange(1, size)) {
std::array<scalar_t, Vec::size()> acc_arr_next = {0};
scalar_t acc_arr_next[Vec::size()] = {0};
acc_arr_next[0] = acc_arr[i];
Vec acc_vec_next = Vec::loadu(acc_arr_next.data());
Vec acc_vec_next = Vec::loadu(acc_arr_next);
acc_vec = vec_fun(acc_vec, acc_vec_next);
}
acc_vec.store(acc_arr);
@@ -102,8 +102,7 @@ struct VecReduceAllSIMD<float, Op> {
#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) &&
// !defined(C10_MOBILE)

#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
!defined(CPU_CAPABILITY_SVE)
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && !defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE)
template <typename Op>
struct VecReduceAllSIMD<float, Op> {
static inline float apply(
@@ -143,8 +142,7 @@ struct VecReduceAllSIMD<float, std::plus<Vectorized<float>>> {
#endif // defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
// && !defined(CPU_CAPABILITY_SVE)

#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
defined(CPU_CAPABILITY_SVE256)
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && (defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE))
template <typename Op>
struct VecReduceAllSIMD<float, Op> {
static inline float apply(
@@ -152,18 +150,28 @@ struct VecReduceAllSIMD<float, Op> {
const Vectorized<float>& acc_vec) {
using Vec = Vectorized<float>;
Vec v = acc_vec;
// 128-bit shuffle
svuint32_t ind = svdupq_n_u32(4, 5, 6, 7);
Vec v1 = svtbl_f32(v, ind);
v = vec_fun(v, v1);
// 64-bit shuffle
ind = svdupq_n_u32(2, 3, 0, 1);
v1 = svtbl_f32(v, ind);
v = vec_fun(v, v1);
// 32-bit shuffle
ind = svdupq_n_u32(1, 0, 2, 3);
v1 = svtbl_f32(v, ind);
v = vec_fun(v, v1);
if (Vec::size() == 8) {
// 128-bit shuffle
svuint32_t ind = svdupq_n_u32(4, 5, 6, 7);
Vec v1 = svtbl_f32(v, ind);
v = vec_fun(v, v1);
// 64-bit shuffle
ind = svdupq_n_u32(2, 3, 0, 1);
v1 = svtbl_f32(v, ind);
v = vec_fun(v, v1);
// 32-bit shuffle
ind = svdupq_n_u32(1, 0, 2, 3);
v1 = svtbl_f32(v, ind);
v = vec_fun(v, v1);
} else {
svuint32_t ind = svdupq_n_u32(2, 3, 0, 1); // 64-bit stride-2
Vec v1 = svtbl_f32(v, ind);
v = vec_fun(v, v1);

ind = svdupq_n_u32(1, 0, 2, 3); // 32-bit stride-1
v1 = svtbl_f32(v, ind);
v = vec_fun(v, v1);
}
return svlasta(svpfalse(), v);
}
};

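For reference, the SVE `VecReduceAllSIMD` above combines the vector with `svtbl`-shuffled copies of itself, so after log2(lanes) rounds lane 0 holds the reduction of all lanes and `svlasta(svpfalse(), v)` extracts it. A scalar model of the same idea (the lane count 8 and the XOR-stride shuffle are illustrative assumptions; they agree with the diff's index patterns in the lanes that feed the final extract):

```cpp
#include <array>
#include <functional>
#include <iostream>

// Combine the vector with a shuffled copy of itself, halving the number of
// distinct partial results each round; after log2(N) rounds every lane (in
// particular lane 0) holds the full reduction.
template <typename Op>
float reduce_all_model(std::array<float, 8> v, Op op) {
  for (int stride = 4; stride >= 1; stride /= 2) {
    std::array<float, 8> shuffled;
    for (int i = 0; i < 8; ++i) {
      shuffled[i] = v[i ^ stride];   // models the svtbl index vectors
    }
    for (int i = 0; i < 8; ++i) {
      v[i] = op(v[i], shuffled[i]);  // models vec_fun(v, v1)
    }
  }
  return v[0];  // models svlasta(svpfalse(), v)
}

int main() {
  std::cout << reduce_all_model({1, 2, 3, 4, 5, 6, 7, 8}, std::plus<float>{}) << "\n";  // 36
}
```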
@@ -4,7 +4,7 @@

#include <ATen/cpu/vec/vec_base.h>

#if defined(CPU_CAPABILITY_SVE)
#if defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE)

// Define the data type of VLS(vector-length specific).
typedef svbool_t vls_pred_t
@@ -77,4 +77,4 @@ typedef svfloat64_t vls_float64_t
#define ALL_F64_TRUE_MASK svreinterpret_f64_s64(ALL_S64_TRUE_MASK)
#define ALL_F64_FALSE_MASK svreinterpret_f64_s64(ALL_S64_FALSE_MASK)

#endif // defined(CPU_CAPABILITY_SVE)
#endif // defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE)

@@ -19,7 +19,7 @@ namespace vec {
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {

#if defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16)
#if (defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE)) && defined(__ARM_FEATURE_BF16)

template <>
struct is_vec_specialized_for<BFloat16> : std::bool_constant<true> {};
@@ -230,8 +230,6 @@ __attribute__((optimize("no-tree-vectorize")))
#endif
inline std::tuple<Vectorized<float>, Vectorized<float>>
convert_bfloat16_float(const Vectorized<c10::BFloat16>& a) {
static_assert(
Vectorized<c10::BFloat16>::size() == 2 * Vectorized<float>::size());
auto zero = svreinterpret_bf16_f32(svdup_n_f32(0.0f));
auto bf16_vec1 = svzip1_bf16(zero, a);
auto bf16_vec2 = svzip2_bf16(zero, a);
@@ -243,19 +241,18 @@ convert_bfloat16_float(const Vectorized<c10::BFloat16>& a) {
inline Vectorized<c10::BFloat16> convert_float_bfloat16(
const Vectorized<float>& a,
const Vectorized<float>& b) {
static_assert(
Vectorized<c10::BFloat16>::size() == 2 * Vectorized<float>::size());
svbfloat16_t x1 = svcvt_bf16_f32_z(ptrue, a);
svbfloat16_t x2 = svcvt_bf16_f32_z(ptrue, b);
return Vectorized<c10::BFloat16>(svuzp1_bf16(x1, x2));
}

inline void load_fp32_from_bf16(const BFloat16* data, Vectorized<float>& out) {
__at_align__ float values[Vectorized<float>::size()];
__at_align__ float * values = new float[Vectorized<float>::size()];
for (const auto k : c10::irange(Vectorized<float>::size())) {
values[k] = data[k];
}
out = Vectorized<float>::loadu(values);
delete[] values;
}

inline void load_fp32_from_bf16(
@@ -308,8 +305,8 @@ Vectorized<c10::BFloat16> inline operator/(
}

inline Vectorized<BFloat16>::Vectorized() {
const short zero = 0;
values = svdup_n_bf16(c10::bit_cast<bfloat16_t>(zero));
auto vals_f = svdup_n_f32(0);
values = convert_float_bfloat16(vals_f, vals_f);
}

inline Vectorized<BFloat16>::Vectorized(int val) {

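As background for `convert_bfloat16_float` / `convert_float_bfloat16` above: bfloat16 is the upper 16 bits of an IEEE-754 float, so zipping each bf16 lane with a zero lane (`svzip1`/`svzip2` against a zero vector) produces the exact float bit pattern, and `svuzp1` of two `svcvt_bf16_f32` results packs the narrowed halves back together. A scalar sketch of the widening direction (the sample bit pattern is an assumed example, not from the diff):

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>

// A bfloat16 value is just the top 16 bits of a float, so placing the bf16
// bits in the high half of a 32-bit word (what the zip-with-zero does
// lane-wise) yields the exact float value.
float bf16_bits_to_float(uint16_t bf16_bits) {
  uint32_t f32_bits = static_cast<uint32_t>(bf16_bits) << 16;  // zero in the low half
  float f;
  std::memcpy(&f, &f32_bits, sizeof(f));
  return f;
}

int main() {
  // 0x3F80 is the bf16 pattern for 1.0f (assumed example value).
  std::cout << bf16_bits_to_float(0x3F80) << "\n";  // prints 1
}
```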
@@ -8,7 +8,7 @@
#include <ATen/cpu/vec/sve/sve_helper.h>
#include <ATen/cpu/vec/vec_base.h>

#if defined(CPU_CAPABILITY_SVE)
#if defined(CPU_CAPABILITY_SVE) || defined(CPU_CAPABILITY_SVE256)
#include <ATen/cpu/vec/sve/vec_bfloat16.h>
#include <ATen/cpu/vec/sve/vec_double.h>
#include <ATen/cpu/vec/sve/vec_float.h>
@@ -27,7 +27,7 @@ namespace at::vec {
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {

#if defined(CPU_CAPABILITY_SVE)
#if defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE)

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#define DEFINE_SVE_CAST(t1_t, t1_prefix, t2_t, t2_prefix) \
@@ -231,6 +231,5 @@ std::pair<
#endif // __ARM_FEATURE_BF16

#endif // defined(CPU_CAPABILITY_SVE)

} // namespace CPU_CAPABILITY
} // namespace at::vec
}

@@ -22,7 +22,7 @@ namespace at::vec {
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {

#if defined(CPU_CAPABILITY_SVE)
#if defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE)

template <>
struct is_vec_specialized_for<double> : std::bool_constant<true> {};
@@ -55,10 +55,11 @@ class Vectorized<double> {
operator svfloat64_t() const {
return values;
}
template <uint64_t mask>
static Vectorized<double> blend(
const Vectorized<double>& a,
const Vectorized<double>& b) {
const Vectorized<double>& b,
int64_t mask
) {
// Build an array of flags: each element is 1 if the corresponding bit in
// 'mask' is set, 0 otherwise.
__at_align__ int64_t flag_arr[size()];

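The `blend` change above (and the matching float and integer versions later in the diff) drops the `template <uint64_t mask>` parameter in favour of a runtime mask, since with vector-length-agnostic SVE the lane count is only known at run time. A minimal sketch of that pattern for doubles, assuming an SVE toolchain and `<arm_sve.h>` (the function name and the 64-lane cap are illustrative, not from the diff):

```cpp
#include <arm_sve.h>
#include <cstdint>

// Select lane i from b when bit i of `mask` is set, otherwise from a.
svfloat64_t blend_f64(svfloat64_t a, svfloat64_t b, uint64_t mask) {
  int64_t flag_arr[64];              // large enough for any current SVE vector length
  uint64_t lanes = svcntd();         // 64-bit lanes per vector, known only at run time
  for (uint64_t i = 0; i < lanes && i < 64; ++i) {
    flag_arr[i] = ((mask >> i) & 1) ? 1 : 0;
  }
  svbool_t pg = svwhilelt_b64_u64((uint64_t)0, lanes);
  svint64_t flags = svld1_s64(pg, flag_arr);
  // Lanes whose flag is nonzero come from b, the rest from a.
  svbool_t sel = svcmpne_n_s64(pg, flags, 0);
  return svsel_f64(sel, b, a);
}
```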
@@ -2,8 +2,10 @@

#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/sve/sve_helper.h>
#include <ATen/cpu/vec/vec_base.h>

#include <algorithm>
#include <cmath>

#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF)
#include <sleef.h>
#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code
@@ -22,7 +24,7 @@ namespace at::vec {
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {

#if defined(CPU_CAPABILITY_SVE)
#if defined(CPU_CAPABILITY_SVE) || defined(CPU_CAPABILITY_SVE256)

template <>
struct is_vec_specialized_for<float> : std::bool_constant<true> {};
@@ -30,52 +32,77 @@ struct is_vec_specialized_for<float> : std::bool_constant<true> {};
template <>
class Vectorized<float> {
private:
vls_float32_t values;

__at_align__ float values[2048 / sizeof(float)];
public:

using value_type = float;
using size_type = int;
static constexpr size_type size() {
return VECTOR_WIDTH / sizeof(float);
static inline size_type size() {
return svcntw();
}
Vectorized() {
values = svdup_n_f32(0);
inline Vectorized() {svst1_f32(ptrue, values, svdup_n_f32(0));}
inline Vectorized(const float val) {
svst1_f32(ptrue, values, svdup_n_f32(val));
}
Vectorized(svfloat32_t v) : values(v) {}
Vectorized(float val) {
values = svdup_n_f32(val);
inline Vectorized(const svfloat32_t val) {
svst1_f32(ptrue, values, val);
}
template <
typename... Args,
typename = std::enable_if_t<(sizeof...(Args) == size())>>
Vectorized(Args... vals) {
__at_align__ float buffer[size()] = {vals...};
values = svld1_f32(ptrue, buffer);
template<typename T,
typename = std::enable_if_t<std::is_pointer_v<T>>>
inline Vectorized(float * val) {
svst1_f32(ptrue, values, svld1_f32(ptrue, val));
}
operator svfloat32_t() const {
return values;
template<typename... Args,
typename = std::enable_if_t<(sizeof...(Args) == size())>>
inline Vectorized(Args... vals) {
values = { vals... };
}
template <uint64_t mask>
static Vectorized<float> blend(
const Vectorized<float>& a,
const Vectorized<float>& b) {
// Build an array of flags: each element is 1 if the corresponding bit in
// 'mask' is set, 0 otherwise.
__at_align__ int32_t flag_arr[size()];
inline operator svfloat32_t() const {
return svld1_f32(ptrue, values);
}
static inline Vectorized<float> from_ptr(const float * vs) {
Vectorized<float> v;
svst1_f32(ptrue, v.values, svld1_f32(ptrue, static_cast<const float *>(vs)));
return v;
}
static inline Vectorized<float> from_ptr(const float * vs, int count) {
Vectorized<float> v;
svst1_f32(ptrue, v.values, svld1_f32(svwhilelt_b32_s32(0, count), static_cast<const float *>(vs)));
return v;
}
inline void set_lane(int i, float value) {
values[i] = value;
}
inline Vectorized<float> map(float (*fn)(float)) const {
Vectorized<float> result;
for (int64_t i = 0; i < size(); ++i) {
result.set_lane(i, fn(values[i]));
}
return result;
}
inline Vectorized<float> map2(float (*fn)(float, float), const Vectorized<float> &b) const {
Vectorized<float> result;
for (int64_t i = 0; i < size(); ++i) {
result.set_lane(i, fn(values[i], b.values[i]));
}
return result;
}

static inline Vectorized<float> blend(const Vectorized<float>& a, const Vectorized<float>& b, const uint64_t mask) {
// Build an array of flags: each element is 1 if the corresponding bit in 'mask' is set, 0 otherwise.
__at_align__ int32_t * flag_arr = new int32_t[size()];
for (int i = 0; i < size(); i++) {
flag_arr[i] = (mask & (1ULL << i)) ? 1 : 0;
}
// Load the flag array into an SVE int32 vector.
svint32_t int_mask = svld1_s32(svptrue_b32(), flag_arr);
// Compare each lane of int_mask to 0; returns an svbool_t predicate where
// true indicates a nonzero flag.
svbool_t blend_mask = svcmpne_n_s32(svptrue_b32(), int_mask, 0);
// Use svsel to select elements from b where the predicate is true, else
// from a.
svfloat32_t result = svsel_f32(blend_mask, b.values, a.values);
return Vectorized<float>(result);
svint32_t int_mask = svld1_s32(ptrue, flag_arr);
delete[] flag_arr;
// Compare each lane of int_mask to 0; returns an svbool_t predicate where true indicates a nonzero flag.
svbool_t blend_mask = svcmpne_n_s32(ptrue, int_mask, 0);
// Use svsel to select elements from b where the predicate is true, else from a.
return svsel_f32(blend_mask, b, a);
}
static Vectorized<float> blendv(
static inline Vectorized<float> blendv(
const Vectorized<float>& a,
const Vectorized<float>& b,
const Vectorized<float>& mask_) {
@@ -84,16 +111,18 @@ class Vectorized<float> {
return svsel_f32(mask, b, a);
}
template <typename step_t>
static Vectorized<float> arange(
static inline Vectorized<float> arange(
float base = 0.f,
step_t step = static_cast<step_t>(1)) {
__at_align__ float buffer[size()];
__at_align__ float * buffer = new float[size()];
for (int64_t i = 0; i < size(); i++) {
buffer[i] = base + i * step;
}
return svld1_f32(ptrue, buffer);
auto tmp = Vectorized<float>::from_ptr(buffer);
delete[] buffer;
return tmp;
}
static Vectorized<float> set(
static inline Vectorized<float> set(
const Vectorized<float>& a,
const Vectorized<float>& b,
int64_t count = size()) {
@@ -169,271 +198,213 @@ class Vectorized<float> {
poly = svsel_f32(svcmpgt_f32(pg, x, max_input), inf, poly);
return poly;
}
static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
if (count == size())
return svld1_f32(ptrue, reinterpret_cast<const float*>(ptr));
svbool_t pg = svwhilelt_b32(0ull, count);
return svld1_f32(pg, reinterpret_cast<const float*>(ptr));
static inline Vectorized<float> loadu(const void* ptr) {
return Vectorized<float>::from_ptr(reinterpret_cast<const float *>(ptr));
}
void store(void* ptr, int64_t count = size()) const {
if (count == size()) {
svst1_f32(ptrue, reinterpret_cast<float*>(ptr), values);
} else {
svbool_t pg = svwhilelt_b32(0ull, count);
svst1_f32(pg, reinterpret_cast<float*>(ptr), values);
}
static inline Vectorized<float> loadu(const void* ptr, int64_t count) {
return Vectorized<float>::from_ptr(reinterpret_cast<const float *>(ptr), count);
}
const float& operator[](int idx) const = delete;
float& operator[](int idx) = delete;
int64_t zero_mask() const {
// returns an integer mask where all zero elements are translated to 1-bit
// and others are translated to 0-bit
inline void store(void* ptr) const {
svst1_f32(ptrue, static_cast<float *>(ptr), svld1_f32(ptrue, values));
}
inline void store(void* ptr, int count) const {
svst1_f32(svwhilelt_b32_s32(0, count), static_cast<float *>(ptr), svld1_f32(ptrue, values));
}
inline const float& operator[](int idx) const {
return values[idx];
};
inline float& operator[](int idx) {
return values[idx];
};
inline int64_t zero_mask() const {
// returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
int64_t mask = 0;
__at_align__ int32_t mask_array[size()];
__at_align__ int32_t * mask_array = new int32_t[size()];

svbool_t svbool_mask = svcmpeq_f32(ptrue, values, ZERO_F32);
svst1_s32(
ptrue,
mask_array,
svsel_s32(svbool_mask, ALL_S32_TRUE_MASK, ALL_S32_FALSE_MASK));
for (int64_t i = 0; i < size(); ++i) {
if (mask_array[i])
mask |= (1ull << i);
svbool_t svbool_mask = svcmpeq_f32(ptrue, *this, ZERO_F32);
svst1_s32(ptrue, mask_array, svsel_s32(svbool_mask,
ALL_S32_TRUE_MASK,
ALL_S32_FALSE_MASK));
for (int64_t j = 0; j < size(); ++j) {
if (mask_array[j]) mask |= (1ull << j);
}
delete[] mask_array;
return mask;
}
Vectorized<float> isnan() const {
inline Vectorized<float> isnan() const {
// NaN check
svbool_t mask = svcmpuo_f32(ptrue, values, ZERO_F32);
auto mask = svcmpuo_f32(ptrue, *this, ZERO_F32);
return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
}
bool has_inf_nan() const {
return svptest_any(
ptrue,
svcmpuo_f32(ptrue, svsub_f32_x(ptrue, values, values), ZERO_F32));
inline bool has_inf_nan() const {
return svptest_any(ptrue, svcmpuo_f32(ptrue, svsub_f32_x(ptrue, *this, *this), ZERO_F32));
}
Vectorized<float> map(float (*f)(float)) const {
__at_align__ float tmp[size()];
store(tmp);
for (int64_t i = 0; i < size(); ++i) {
tmp[i] = f(tmp[i]);
}
return loadu(tmp);

inline Vectorized<float> abs() const {
return svabs_f32_x(ptrue, *this);
}
Vectorized<float> abs() const {
return svabs_f32_x(ptrue, values);
}
Vectorized<float> angle() const {
inline Vectorized<float> angle() const {
const auto nan_vec = svdup_n_f32(NAN);
const auto nan_mask = svcmpuo_f32(ptrue, values, ZERO_F32);
const auto nan_mask = svcmpuo_f32(ptrue, *this, ZERO_F32);
const auto pi = svdup_n_f32(c10::pi<float>);

const auto neg_mask = svcmplt_f32(ptrue, values, ZERO_F32);
const auto neg_mask = svcmplt_f32(ptrue, *this, ZERO_F32);
auto angle = svsel_f32(neg_mask, pi, ZERO_F32);
angle = svsel_f32(nan_mask, nan_vec, angle);
return angle;
return svsel_f32(nan_mask, nan_vec, angle);
}
Vectorized<float> real() const {
return values;
inline Vectorized<float> real() const {
return *this;
}
Vectorized<float> imag() const {
inline Vectorized<float> imag() const {
return Vectorized<float>(0.f);
}
Vectorized<float> conj() const {
return values;
inline Vectorized<float> conj() const {
return *this;
}
Vectorized<float> acos() const {
return USE_SLEEF(
Vectorized<float>(Sleef_acosfx_u10sve(values)), map(std::acos));
inline Vectorized<float> acos() const {
return USE_SLEEF(Sleef_acosfx_u10sve(*this), map(std::acos));
}
Vectorized<float> acosh() const {
return USE_SLEEF(
Vectorized<float>(Sleef_acoshfx_u10sve(values)), map(std::acosh));
inline Vectorized<float> acosh() const {
return USE_SLEEF(Sleef_acoshfx_u10sve(*this), map(std::acosh));
}
Vectorized<float> asin() const {
return USE_SLEEF(
Vectorized<float>(Sleef_asinfx_u10sve(values)), map(std::asin));
inline Vectorized<float> asin() const {
return USE_SLEEF(Sleef_asinfx_u10sve(*this), map(std::asin));
}
Vectorized<float> asinh() const {
return USE_SLEEF(
Vectorized<float>(Sleef_asinhfx_u10sve(values)), map(std::asinh));
inline Vectorized<float> asinh() const {
return USE_SLEEF(Sleef_asinhfx_u10sve(*this), map(std::asinh));
}
Vectorized<float> atan() const {
return USE_SLEEF(
Vectorized<float>(Sleef_atanfx_u10sve(values)), map(std::atan));
inline Vectorized<float> atan() const {
return USE_SLEEF(Sleef_atanfx_u10sve(*this), map(std::atan));
}
Vectorized<float> atanh() const {
return USE_SLEEF(
Vectorized<float>(Sleef_atanhfx_u10sve(values)), map(std::atanh));
inline Vectorized<float> atanh() const {
return USE_SLEEF(Sleef_atanhfx_u10sve(*this), map(std::atanh));
}
Vectorized<float> atan2(const Vectorized<float>& b) const {USE_SLEEF(
{ return Vectorized<float>(Sleef_atan2fx_u10sve(values, b)); },
{
__at_align__ float tmp[size()];
__at_align__ float tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = std::atan2(tmp[i], tmp_b[i]);
}
return loadu(tmp);
})} Vectorized<float> copysign(const Vectorized<float>& sign) const {

USE_SLEEF(
{ return Vectorized<float>(Sleef_copysignfx_sve(values, sign)); },
{
__at_align__ float tmp[size()];
__at_align__ float tmp_sign[size()];
store(tmp);
sign.store(tmp_sign);
for (int64_t i = 0; i < size(); ++i) {
tmp[i] = std::copysign(tmp[i], tmp_sign[i]);
}
return loadu(tmp);
})} Vectorized<float> erf() const {
return USE_SLEEF(
Vectorized<float>(Sleef_erffx_u10sve(values)), map(std::erf));
inline Vectorized<float> atan2(const Vectorized<float> &b) const {
return USE_SLEEF(Sleef_atan2fx_u10sve(*this, b), map2(std::atan2, b));
}
Vectorized<float> erfc() const {
return USE_SLEEF(
Vectorized<float>(Sleef_erfcfx_u15sve(values)), map(std::erfc));
inline Vectorized<float> copysign(const Vectorized<float> &sign) const {
return USE_SLEEF(Sleef_copysignfx_sve(*this, sign), map2(std::copysign, sign));
}
Vectorized<float> erfinv() const {
inline Vectorized<float> erf() const {
return USE_SLEEF(Sleef_erffx_u10sve(*this), map(std::erf));
}
inline Vectorized<float> erfc() const {
return USE_SLEEF(Sleef_erfcfx_u15sve(*this), map(std::erfc));
}
inline Vectorized<float> erfinv() const {
return map(calc_erfinv);
}
Vectorized<float> exp() const {
return USE_SLEEF(
Vectorized<float>(Sleef_expfx_u10sve(values)), map(std::exp));
inline Vectorized<float> exp() const {
return USE_SLEEF(Sleef_expfx_u10sve(*this), map(std::exp));
}
Vectorized<float> exp2() const {
return USE_SLEEF(
Vectorized<float>(Sleef_exp2fx_u10sve(values)), map(std::exp2));
inline Vectorized<float> exp2() const {
return USE_SLEEF(Sleef_exp2fx_u10sve(*this), map(std::exp2));
}
Vectorized<float> expm1() const {
return USE_SLEEF(
Vectorized<float>(Sleef_expm1fx_u10sve(values)), map(std::expm1));
inline Vectorized<float> expm1() const {
return USE_SLEEF(Sleef_expm1fx_u10sve(*this), map(std::expm1));
}
// Implementation copied from Arm Optimized Routines:
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/sve/expf.c
Vectorized<float> exp_u20() const {
return exp();
// special case to handle special inputs that are too large or too small
// i.e. where there's at least one element x, s.t. |x| >= 87.3...
svbool_t is_special_case = svacgt (svptrue_b32(), *this, 0x1.5d5e2ap+6f);
if (svptest_any (svptrue_b32(), is_special_case)) {
return exp();
}
const svfloat32_t ln2_hi = svdup_n_f32(0x1.62e4p-1f);
const svfloat32_t ln2_lo = svdup_n_f32(0x1.7f7d1cp-20f);
const svfloat32_t c1 = svdup_n_f32(0.5f);
const svfloat32_t inv_ln2 = svdup_n_f32(0x1.715476p+0f);

const float shift = 0x1.803f8p17f;

/* n = round(x/(ln2/N)). */
svfloat32_t z = svmad_x (svptrue_b32(), inv_ln2, *this, shift);
svfloat32_t n = svsub_x (svptrue_b32(), z, shift);

/* r = x - n*ln2/N. */
svfloat32_t r = *this;
r = svmls_x(svptrue_b32(), r, n, ln2_hi);
r = svmls_x(svptrue_b32(), r, n, ln2_lo);

/* scale = 2^(n/N). */
svfloat32_t scale = svexpa (svreinterpret_u32 (z));

/* poly(r) = exp(r) - 1 ~= r + 0.5 r^2. */
svfloat32_t r2 = svmul_x (svptrue_b32 (), r, r);
svfloat32_t poly = svmla_x(svptrue_b32(), r, r2, c1);
return svmla_x (svptrue_b32(), scale, scale, poly);
}
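// Sketch of the idea behind exp_u20 above, assuming the usual Arm Optimized
// Routines scheme it references: write exp(x) = 2^n * exp(r) with n ~= x/ln2
// quantized to 1/N steps, keep r = x - n*ln2 small and accurate via the split
// ln2_hi + ln2_lo, approximate exp(r) - 1 ~= r + 0.5*r^2, and obtain 2^n with
// FEXPA (svexpa) directly from the bit pattern of z = x/ln2 + shift. The final
// svmla_x(scale, scale, poly) computes scale * (1 + poly) = 2^n * exp(r).
// Inputs with |x| above ~87.3 (0x1.5d5e2ap+6f) take the accurate exp() path.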
Vectorized<float> fexp_u20() const {
return exp();
return exp_u20();
}
Vectorized<float> fmod(const Vectorized<float>& q) const {USE_SLEEF(
{ return Vectorized<float>(Sleef_fmodfx_sve(values, q)); },
{
__at_align__ float tmp[size()];
__at_align__ float tmp_q[size()];
store(tmp);
q.store(tmp_q);
for (int64_t i = 0; i < size(); ++i) {
tmp[i] = std::fmod(tmp[i], tmp_q[i]);
}
return loadu(tmp);
})} Vectorized<float> hypot(const Vectorized<float>& b) const {
USE_SLEEF(
{ return Vectorized<float>(Sleef_hypotfx_u05sve(values, b)); },
{
__at_align__ float tmp[size()];
__at_align__ float tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = std::hypot(tmp[i], tmp_b[i]);
}
return loadu(tmp);
})} Vectorized<float> i0() const {
inline Vectorized<float> fmod(const Vectorized<float>& q) const {
return USE_SLEEF(Sleef_fmodfx_sve(*this, q), return map2(std::fmod, q));
}
inline Vectorized<float> hypot(const Vectorized<float> &b) const {
return USE_SLEEF(Sleef_hypotfx_u05sve(*this, b), map2(std::hypot, b));
}
inline Vectorized<float> i0() const {
return map(calc_i0);
}
Vectorized<float> i0e() const {
return map(calc_i0e);
inline Vectorized<float> i0e() const {
return map(calc_i0e<float>);
}
Vectorized<float> digamma() const {
inline Vectorized<float> digamma() const {
return map(calc_digamma);
}
Vectorized<float> igamma(const Vectorized<float>& x) const {
__at_align__ float tmp[size()];
__at_align__ float tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
}
return loadu(tmp);
inline Vectorized<float> igamma(const Vectorized<float> &x) const {
return map2(calc_igamma<float>, x);
}
Vectorized<float> igammac(const Vectorized<float>& x) const {
__at_align__ float tmp[size()];
__at_align__ float tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
}
return loadu(tmp);
inline Vectorized<float> igammac(const Vectorized<float> &x) const {
return map2(calc_igammac<float>, x);
}
Vectorized<float> nextafter(const Vectorized<float>& b) const {USE_SLEEF(
{ return Vectorized<float>(Sleef_nextafterfx_sve(values, b)); },
{
__at_align__ float tmp[size()];
__at_align__ float tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (int64_t i = 0; i < size(); ++i) {
tmp[i] = std::nextafter(tmp[i], tmp_b[i]);
}
return loadu(tmp);
})} Vectorized<float> log() const {
return USE_SLEEF(
Vectorized<float>(Sleef_logfx_u10sve(values)), map(std::log));
inline Vectorized<float> nextafter(const Vectorized<float> &b) const {
return USE_SLEEF(Sleef_nextafterfx_sve(*this, b), map2(std::nextafter, b));
}
Vectorized<float> log2() const {
return USE_SLEEF(
Vectorized<float>(Sleef_log2fx_u10sve(values)), map(std::log2));
inline Vectorized<float> log() const {
return USE_SLEEF(Sleef_logfx_u10sve(*this), map(std::log));
}
Vectorized<float> log10() const {
return USE_SLEEF(
Vectorized<float>(Sleef_log10fx_u10sve(values)), map(std::log10));
inline Vectorized<float> log2() const {
return USE_SLEEF(Sleef_log2fx_u10sve(*this), map(std::log2));
}
Vectorized<float> log1p() const {
return USE_SLEEF(
Vectorized<float>(Sleef_log1pfx_u10sve(values)), map(std::log1p));
inline Vectorized<float> log10() const {
return USE_SLEEF(Sleef_log10fx_u10sve(*this), map(std::log10));
}
Vectorized<float> frac() const;
Vectorized<float> sin() const {
return USE_SLEEF(
Vectorized<float>(Sleef_sinfx_u10sve(values)), map(std::sin));
inline Vectorized<float> log1p() const {
return USE_SLEEF(Sleef_log1pfx_u10sve(*this), map(std::log1p));
}
Vectorized<float> sinh() const {
return USE_SLEEF(
Vectorized<float>(Sleef_sinhfx_u10sve(values)), map(std::sinh));
inline Vectorized<float> frac() const;
inline Vectorized<float> sin() const {
return USE_SLEEF(Sleef_sinfx_u10sve(*this), map(std::sin));
}
Vectorized<float> cos() const {
return USE_SLEEF(
Vectorized<float>(Sleef_cosfx_u10sve(values)), map(std::cos));
inline Vectorized<float> sinh() const {
return USE_SLEEF(Sleef_sinhfx_u10sve(*this), map(std::sinh));
}
Vectorized<float> cosh() const {
return USE_SLEEF(
Vectorized<float>(Sleef_coshfx_u10sve(values)), map(std::cosh));
inline Vectorized<float> cos() const {
return USE_SLEEF(Sleef_cosfx_u10sve(*this), map(std::cos));
}
Vectorized<float> ceil() const {
return svrintp_f32_x(ptrue, values);
inline Vectorized<float> cosh() const {
return USE_SLEEF(Sleef_coshfx_u10sve(*this), map(std::cosh));
}
Vectorized<float> floor() const {
return svrintm_f32_x(ptrue, values);
inline Vectorized<float> ceil() const {
return svrintp_f32_x(ptrue, *this);
}
Vectorized<float> neg() const {
return svneg_f32_x(ptrue, values);
inline Vectorized<float> floor() const {
return svrintm_f32_x(ptrue, *this);
}
Vectorized<float> round() const {
return svrinti_f32_x(ptrue, values);
inline Vectorized<float> neg() const {
return svneg_f32_x(ptrue, *this);
}
Vectorized<float> tan() const {
return USE_SLEEF(
Vectorized<float>(Sleef_tanfx_u10sve(values)), map(std::tan));
inline Vectorized<float> round() const {
return svrinti_f32_x(ptrue, *this);
}
inline Vectorized<float> tan() const {
return USE_SLEEF(Sleef_tanfx_u10sve(*this), map(std::tan));
}
// Implementation is picked from
// https://github.com/ARM-software/ComputeLibrary/blob/v25.01/src/core/NEON/SVEMath.inl#L179
Vectorized<float> tanh() const {
inline Vectorized<float> tanh() const {
// Constants used for the tanh calculation.
const svfloat32_t CONST_1 =
svdup_n_f32(1.f); // Constant 1.0f for the tanh formula.
@@ -450,7 +421,7 @@ class Vectorized<float> {
// instability. svmax_f32_z ensures values are greater than -10, and
// svmin_f32_z ensures they are less than 10.
svfloat32_t x = svmin_f32_z(
ptrue, svmax_f32_z(ptrue, values, CONST_MIN_TANH), CONST_MAX_TANH);
ptrue, svmax_f32_z(ptrue, *this, CONST_MIN_TANH), CONST_MAX_TANH);

// Step 2: Calculate exp(2 * x), where x is the clamped value.
// svmul_f32_z computes 2 * x, and svexp_f32_z computes the exponential of
@@ -472,104 +443,85 @@ class Vectorized<float> {
// Return the calculated tanh values.
return tanh;
}
Vectorized<float> trunc() const {
return svrintz_f32_x(ptrue, values);
inline Vectorized<float> trunc() const {
return svrintz_f32_x(ptrue, *this);
}
Vectorized<float> lgamma() const {
return USE_SLEEF(
Vectorized<float>(Sleef_lgammafx_u10sve(values)), map(std::lgamma));
inline Vectorized<float> lgamma() const {
return USE_SLEEF(Sleef_lgammafx_u10sve(*this), map(std::lgamma));
}
Vectorized<float> sqrt() const {
return svsqrt_f32_x(ptrue, values);
inline Vectorized<float> sqrt() const {
return svsqrt_f32_x(ptrue, *this);
}
Vectorized<float> reciprocal() const {
return svdivr_f32_x(ptrue, values, ONE_F32);
inline Vectorized<float> reciprocal() const {
return svdivr_f32_x(ptrue, *this, svdup_n_f32(1.f));
}
Vectorized<float> rsqrt() const {
return svdivr_f32_x(ptrue, svsqrt_f32_x(ptrue, values), ONE_F32);
inline Vectorized<float> rsqrt() const {
return svdivr_f32_x(ptrue, svsqrt_f32_x(ptrue, *this), ONE_F32);
}
Vectorized<float> pow(const Vectorized<float>& b) const {USE_SLEEF(
{ return Vectorized<float>(Sleef_powfx_u10sve(values, b)); },
{
__at_align__ float tmp[size()];
__at_align__ float tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = std::pow(tmp[i], tmp_b[i]);
}
return loadu(tmp);
})} // Comparison using the _CMP_**_OQ predicate.
// `O`: get false if an operand is NaN
// `Q`: do not raise if an operand is NaN
Vectorized<float> operator==(const Vectorized<float>& other) const {
svbool_t mask = svcmpeq_f32(ptrue, values, other);
inline Vectorized<float> pow(const Vectorized<float> &b) const {
return USE_SLEEF(Sleef_powfx_u10sve(*this, b), map(std::pow, b));
}
// Comparison using the _CMP_**_OQ predicate.
// `O`: get false if an operand is NaN
// `Q`: do not raise if an operand is NaN
inline Vectorized<float> operator==(const Vectorized<float>& other) const {
svbool_t mask = svcmpeq_f32(ptrue, *this, other);
return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
}
inline Vectorized<float> operator!=(const Vectorized<float>& other) const {
svbool_t mask = svcmpne_f32(ptrue, *this, other);
return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
}
inline Vectorized<float> operator<(const Vectorized<float>& other) const {
svbool_t mask = svcmplt_f32(ptrue, *this, other);
return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
}

Vectorized<float> operator!=(const Vectorized<float>& other) const {
svbool_t mask = svcmpne_f32(ptrue, values, other);
inline Vectorized<float> operator<=(const Vectorized<float>& other) const {
svbool_t mask = svcmple_f32(ptrue, *this, other);
return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
}

Vectorized<float> operator<(const Vectorized<float>& other) const {
svbool_t mask = svcmplt_f32(ptrue, values, other);
inline Vectorized<float> operator>(const Vectorized<float>& other) const {
svbool_t mask = svcmpgt_f32(ptrue, *this, other);
return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
}

Vectorized<float> operator<=(const Vectorized<float>& other) const {
svbool_t mask = svcmple_f32(ptrue, values, other);
inline Vectorized<float> operator>=(const Vectorized<float>& other) const {
svbool_t mask = svcmpge_f32(ptrue, *this, other);
return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
}

Vectorized<float> operator>(const Vectorized<float>& other) const {
svbool_t mask = svcmpgt_f32(ptrue, values, other);
return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
}

Vectorized<float> operator>=(const Vectorized<float>& other) const {
svbool_t mask = svcmpge_f32(ptrue, values, other);
return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
}

Vectorized<float> eq(const Vectorized<float>& other) const;
Vectorized<float> ne(const Vectorized<float>& other) const;
Vectorized<float> gt(const Vectorized<float>& other) const;
Vectorized<float> ge(const Vectorized<float>& other) const;
Vectorized<float> lt(const Vectorized<float>& other) const;
Vectorized<float> le(const Vectorized<float>& other) const;
inline Vectorized<float> eq(const Vectorized<float>& other) const;
inline Vectorized<float> ne(const Vectorized<float>& other) const;
inline Vectorized<float> gt(const Vectorized<float>& other) const;
inline Vectorized<float> ge(const Vectorized<float>& other) const;
inline Vectorized<float> lt(const Vectorized<float>& other) const;
inline Vectorized<float> le(const Vectorized<float>& other) const;
};

template <>
Vectorized<float> inline operator+(
const Vectorized<float>& a,
const Vectorized<float>& b) {
inline Vectorized<float> operator+(const Vectorized<float>& a, const Vectorized<float>& b) {
return svadd_f32_x(ptrue, a, b);
}

template <>
Vectorized<float> inline operator-(
const Vectorized<float>& a,
const Vectorized<float>& b) {
inline Vectorized<float> operator-(const Vectorized<float>& a, const Vectorized<float>& b) {
return svsub_f32_x(ptrue, a, b);
}

template <>
Vectorized<float> inline operator*(
const Vectorized<float>& a,
const Vectorized<float>& b) {
inline Vectorized<float> operator*(const Vectorized<float>& a, const Vectorized<float>& b) {
return svmul_f32_x(ptrue, a, b);
}

template <>
Vectorized<float> inline operator/(
const Vectorized<float>& a,
const Vectorized<float>& b) {
inline Vectorized<float> operator/(const Vectorized<float>& a, const Vectorized<float>& b) {
return svdiv_f32_x(ptrue, a, b);
}

// frac. Implement this here so we can use subtraction
Vectorized<float> inline Vectorized<float>::frac() const {
inline Vectorized<float> Vectorized<float>::frac() const {
return *this - this->trunc();
}

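The `loadu`/`store` overloads that take a `count` (above) and the `convert()` loops further down rely on `svwhilelt` predicates, so partial vectors need no scalar tail loop. A self-contained sketch of that pattern, assuming an SVE toolchain and `<arm_sve.h>` (the function itself is illustrative, not from the diff):

```cpp
#include <arm_sve.h>
#include <cstdint>

// Scale an arbitrary-length float buffer in place. svwhilelt builds a
// predicate that is all-true except on the final partial vector, where only
// the first (n - i) lanes are active, so inactive lanes are never touched.
void scale_in_place(float* data, int64_t n, float factor) {
  int64_t step = svcntw();                    // lanes per vector, known only at run time
  for (int64_t i = 0; i < n; i += step) {
    svbool_t pg = svwhilelt_b32_s64(i, n);
    svfloat32_t v = svld1_f32(pg, data + i);
    v = svmul_n_f32_x(pg, v, factor);
    svst1_f32(pg, data + i, v);
  }
}
```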
@@ -585,115 +537,91 @@ Vectorized<float> inline maximum(
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<float> inline minimum(
const Vectorized<float>& a,
const Vectorized<float>& b) {
inline Vectorized<float> minimum(const Vectorized<float>& a, const Vectorized<float>& b) {
return svmin_f32_x(ptrue, a, b);
}

template <>
Vectorized<float> inline clamp(
const Vectorized<float>& a,
const Vectorized<float>& min,
const Vectorized<float>& max) {
inline Vectorized<float> clamp(const Vectorized<float>& a, const Vectorized<float>& min, const Vectorized<float>& max) {
return svmin_f32_x(ptrue, max, svmax_f32_x(ptrue, min, a));
}

template <>
Vectorized<float> inline clamp_max(
const Vectorized<float>& a,
const Vectorized<float>& max) {
inline Vectorized<float> clamp_max(const Vectorized<float>& a, const Vectorized<float>& max) {
return svmin_f32_x(ptrue, max, a);
}

template <>
Vectorized<float> inline clamp_min(
const Vectorized<float>& a,
const Vectorized<float>& min) {
inline Vectorized<float> clamp_min(const Vectorized<float>& a, const Vectorized<float>& min) {
return svmax_f32_x(ptrue, min, a);
}

template <>
Vectorized<float> inline operator&(
const Vectorized<float>& a,
const Vectorized<float>& b) {
return svreinterpret_f32_s32(
svand_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b)));
inline Vectorized<float> operator&(const Vectorized<float>& a, const Vectorized<float>& b) {
return svreinterpret_f32_s32(svand_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b)));
}

template <>
Vectorized<float> inline operator|(
const Vectorized<float>& a,
const Vectorized<float>& b) {
return svreinterpret_f32_s32(
svorr_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b)));
inline Vectorized<float> operator|(const Vectorized<float>& a, const Vectorized<float>& b) {
return svreinterpret_f32_s32(svorr_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b)));
}

template <>
Vectorized<float> inline operator^(
const Vectorized<float>& a,
const Vectorized<float>& b) {
return svreinterpret_f32_s32(
sveor_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b)));
inline Vectorized<float> operator^(const Vectorized<float>& a, const Vectorized<float>& b) {
return svreinterpret_f32_s32(sveor_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b)));
}

Vectorized<float> inline Vectorized<float>::eq(
const Vectorized<float>& other) const {
inline Vectorized<float> Vectorized<float>::eq(const Vectorized<float>& other) const {
return (*this == other) & Vectorized<float>(1.0f);
}

Vectorized<float> inline Vectorized<float>::ne(
const Vectorized<float>& other) const {
inline Vectorized<float> Vectorized<float>::ne(const Vectorized<float>& other) const {
return (*this != other) & Vectorized<float>(1.0f);
}

Vectorized<float> inline Vectorized<float>::gt(
const Vectorized<float>& other) const {
inline Vectorized<float> Vectorized<float>::gt(const Vectorized<float>& other) const {
return (*this > other) & Vectorized<float>(1.0f);
}

Vectorized<float> inline Vectorized<float>::ge(
const Vectorized<float>& other) const {
inline Vectorized<float> Vectorized<float>::ge(const Vectorized<float>& other) const {
return (*this >= other) & Vectorized<float>(1.0f);
}

Vectorized<float> inline Vectorized<float>::lt(
const Vectorized<float>& other) const {
inline Vectorized<float> Vectorized<float>::lt(const Vectorized<float>& other) const {
return (*this < other) & Vectorized<float>(1.0f);
}

Vectorized<float> inline Vectorized<float>::le(
const Vectorized<float>& other) const {
inline Vectorized<float> Vectorized<float>::le(const Vectorized<float>& other) const {
return (*this <= other) & Vectorized<float>(1.0f);
}

template <>
inline void convert(const float* src, float* dst, int64_t n) {
const int64_t fraction = n % Vectorized<float>::size();
const int64_t fraction = n % svcntw();
#pragma unroll
for (int64_t i = 0; i < n - fraction; i += Vectorized<float>::size()) {
for (int64_t i = 0; i < n - fraction; i += svcntw()) {
svst1_f32(ptrue, dst + i, svldnt1_f32(ptrue, src + i));
}
#pragma unroll
for (int64_t i = n - fraction; i < n; i += Vectorized<float>::size()) {
for (int64_t i = n - fraction; i < n; i += svcntw()) {
svbool_t pg = svwhilelt_b32(i, n);
svst1_f32(pg, dst + i, svldnt1_f32(pg, src + i));
}
}

template <>
inline void convert(const float* src, at::Half* dst, int64_t n) {
const int64_t fraction = n % Vectorized<float>::size();
svbool_t pg_16 = svwhilelt_b16(0ull, Vectorized<float>::size());
svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized<float>::size());
inline void convert(const float *src, at::Half *dst, int64_t n) {
const int64_t fraction = n % svcntw();
svbool_t pg_16 = svwhilelt_b16(0ull, svcntw());
svbool_t pg_32 = svwhilelt_b32(0ull, svcntw());
#pragma unroll
for (int64_t i = 0; i < n - fraction; i += Vectorized<float>::size()) {
svfloat16_t src_vec = svuzp1_f16(
svcvt_f16_f32_x(ptrue, svldnt1_f32(pg_32, src + i)), ZERO_F16);
for (int64_t i = 0; i < n - fraction; i += svcntw()) {
svfloat16_t src_vec = svuzp1_f16(svcvt_f16_f32_x(ptrue, svldnt1_f32(pg_32, src + i)),
ZERO_F16);
svst1_f16(pg_16, reinterpret_cast<float16_t*>(dst) + i, src_vec);
}
#pragma unroll
for (int64_t i = n - fraction; i < n; i += Vectorized<float>::size()) {
for (int64_t i = n - fraction; i < n; i += svcntw()) {
pg_16 = svwhilelt_b16(i, n);
pg_32 = svwhilelt_b32(i, n);
svfloat16_t src_vec = svuzp1_f16(
@@ -703,19 +631,18 @@ inline void convert(const float* src, at::Half* dst, int64_t n) {
}

template <>
inline void convert(const at::Half* src, float* dst, int64_t n) {
const int64_t fraction = n % Vectorized<float>::size();
svbool_t pg_16 = svwhilelt_b16(0ull, Vectorized<float>::size());
svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized<float>::size());
inline void convert(const at::Half *src, float *dst, int64_t n) {
const int64_t fraction = n % svcntw();
svbool_t pg_16 = svwhilelt_b16(0ull, svcntw());
svbool_t pg_32 = svwhilelt_b32(0ull, svcntw());
#pragma unroll
for (int64_t i = 0; i < n - fraction; i += Vectorized<float>::size()) {
svfloat16_t src_vec = svzip1_f16(
svldnt1_f16(pg_16, reinterpret_cast<const float16_t*>(src) + i),
ZERO_F16);
for (int64_t i = 0; i < n - fraction; i += svcntw()) {
svfloat16_t src_vec = svzip1_f16(svldnt1_f16(pg_16, reinterpret_cast<const float16_t*>(src) + i),
ZERO_F16);
svst1_f32(pg_32, dst + i, svcvt_f32_f16_x(ptrue, src_vec));
}
#pragma unroll
for (int64_t i = n - fraction; i < n; i += Vectorized<float>::size()) {
for (int64_t i = n - fraction; i < n; i += svcntw()) {
pg_16 = svwhilelt_b16(i, n);
pg_32 = svwhilelt_b32(i, n);
svfloat16_t src_vec = svzip1_f16(
@@ -726,20 +653,19 @@ inline void convert(const at::Half* src, float* dst, int64_t n) {
}

template <>
inline void convert(const bool* src, float* dst, int64_t n) {
const int64_t fraction = n % Vectorized<float>::size();
svbool_t pg_8 = svwhilelt_b8(0ull, Vectorized<float>::size());
svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized<float>::size());
inline void convert(const bool *src, float *dst, int64_t n) {
const int64_t fraction = n % svcntw();
svbool_t pg_8 = svwhilelt_b8(0ull, svcntw());
svbool_t pg_32 = svwhilelt_b32(0ull, svcntw());
#pragma unroll
for (int64_t i = 0; i < n - fraction; i += Vectorized<float>::size()) {
svuint8_t src_vec_u8 =
svldnt1_u8(pg_8, reinterpret_cast<const uint8_t*>(src) + i);
for (int64_t i = 0; i < n - fraction; i += svcntw()) {
svuint8_t src_vec_u8 = svldnt1_u8(pg_8, reinterpret_cast<const uint8_t*>(src) + i);
svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8));
svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32);
svst1_f32(pg_32, dst + i, svsel_f32(mask, ONE_F32, ZERO_F32));
}
#pragma unroll
for (int64_t i = n - fraction; i < n; i += Vectorized<float>::size()) {
for (int64_t i = n - fraction; i < n; i += svcntw()) {
pg_8 = svwhilelt_b8(i, n);
pg_32 = svwhilelt_b32(i, n);
svuint8_t src_vec_u8 =
@@ -751,10 +677,7 @@ inline void convert(const bool* src, float* dst, int64_t n) {
}

template <>
Vectorized<float> inline fmadd(
const Vectorized<float>& a,
const Vectorized<float>& b,
const Vectorized<float>& c) {
inline Vectorized<float> fmadd(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
return svmad_f32_x(ptrue, a, b, c);
}


@@ -15,7 +15,7 @@ namespace at::vec {
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {

#if defined(CPU_CAPABILITY_SVE)
#if defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE)

#define VEC_INT_SVE_TEMPLATE(vl, bit) \
template <> \
@@ -49,10 +49,11 @@ inline namespace CPU_CAPABILITY {
operator svint##bit##_t() const { \
return values; \
} \
template <uint64_t mask> \
static Vectorized<int##bit##_t> blend( \
const Vectorized<int##bit##_t>& a, \
const Vectorized<int##bit##_t>& b) { \
const Vectorized<int##bit##_t>& b, \
uint64_t mask \
) { \
__at_align__ int##bit##_t flag_arr[size()]; \
for (int i = 0; i < size(); ++i) { \
flag_arr[i] = (i < 64 && (mask & (1ULL << i))) ? 1 : 0; \
@@ -493,7 +494,7 @@ Vectorized<int8_t> inline operator>>(
return svasr_s8_x(ptrue, a, svreinterpret_u8_s8(b));
}

#endif // defined(CPU_CAPABILITY_SVE)
#endif // defined(CPU_CAPABILITY_SVE256)

} // namespace CPU_CAPABILITY
} // namespace at::vec

@@ -46,7 +46,7 @@ namespace at::vec {
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {

#if defined(CPU_CAPABILITY_SVE)
#if defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE)

// NOTE: These are low-performance implementations that we fall back on
// if we are not building with SVE. This may not be an issue, because
@@ -100,12 +100,12 @@ struct VectorizedQuantizedConverter {
Vectorized<float> zero_point,
Vectorized<float> scale_zp_premul) const {
float_vec_return_type rv;
float tmp_scale[Vectorized<float>::size()];
float tmp_zero_point[Vectorized<float>::size()];
float * tmp_scale = new float[Vectorized<float>::size()];
float * tmp_zero_point = new float[Vectorized<float>::size()];
scale.store(tmp_scale);
zero_point.store(tmp_zero_point);
for (int i = 0; i < float_num_vecs(); ++i) {
float tmp_vals[Vectorized<float>::size()];
float * tmp_vals = new float[Vectorized<float>::size()];
for (int j = 0; j < Vectorized<float>::size(); ++j) {
tmp_vals[j] = at::native::dequantize_val<T>(
tmp_scale[j],
@@ -113,7 +113,11 @@ struct VectorizedQuantizedConverter {
T(vals[Vectorized<float>::size() * i + j]));
}
rv[i] = Vectorized<float>::loadu(tmp_vals);

delete[] tmp_vals;
}
delete[] tmp_scale;
delete[] tmp_zero_point;
return rv;
}

@@ -121,12 +125,12 @@ struct VectorizedQuantizedConverter {
Vectorized<float> scale,
Vectorized<float> zero_point) const {
float_vec_return_type rv;
float tmp_scale[Vectorized<float>::size()];
float tmp_zero_point[Vectorized<float>::size()];
float * tmp_scale = new float[Vectorized<float>::size()];
float * tmp_zero_point = new float[Vectorized<float>::size()];
scale.store(tmp_scale);
zero_point.store(tmp_zero_point);
for (int i = 0; i < float_num_vecs(); ++i) {
float tmp_vals[Vectorized<float>::size()];
float * tmp_vals = new float[Vectorized<float>::size()];
for (int j = 0; j < Vectorized<float>::size(); ++j) {
tmp_vals[j] = at::native::dequantize_val<T>(
tmp_scale[j],
@@ -134,7 +138,10 @@ struct VectorizedQuantizedConverter {
T(vals[Vectorized<float>::size() * i + j]));
}
rv[i] = Vectorized<float>::loadu(tmp_vals);
delete[] tmp_vals;
}
delete[] tmp_scale;
delete[] tmp_zero_point;
return rv;
}

@@ -205,7 +212,7 @@ struct Vectorized<c10::qint32> : public VectorizedQuantizedConverter<
int32_t zero_point,
float inverse_scale) {
std::array<value_type, size()> qvals;
std::array<float, float_num_vecs() * Vectorized<float>::size()> float_vals;
float * float_vals = new float[float_num_vecs() * Vectorized<float>::size()];

for (int i = 0; i < float_num_vecs(); ++i) {
rhs[i].store(
@@ -216,10 +223,11 @@ struct Vectorized<c10::qint32> : public VectorizedQuantizedConverter<
at::native::quantize_vec<c10::qint32, /*precision=*/32>(
scale,
zero_point,
float_vals.data(),
float_vals,
(c10::qint32*)qvals.data(),
Vectorized<float>::size() * float_num_vecs());

delete[] float_vals;
return Vectorized<c10::qint32>::loadu(qvals.data());
}

@@ -359,7 +367,7 @@ struct Vectorized<c10::qint8> : public VectorizedQuantizedConverter<
int32_t zero_point,
float inverse_scale) {
std::array<value_type, size()> qvals;
std::array<float, float_num_vecs() * Vectorized<float>::size()> float_vals;
float * float_vals = new float[float_num_vecs() * Vectorized<float>::size()];

for (int i = 0; i < float_num_vecs(); ++i) {
rhs[i].store(
@@ -370,10 +378,11 @@ struct Vectorized<c10::qint8> : public VectorizedQuantizedConverter<
at::native::quantize_vec<c10::qint8>(
scale,
zero_point,
float_vals.data(),
float_vals,
(c10::qint8*)qvals.data(),
Vectorized<float>::size() * float_num_vecs());

delete[] float_vals;
return Vectorized<c10::qint8>::loadu(qvals.data());
}

@@ -511,7 +520,7 @@ struct Vectorized<c10::quint8> : public VectorizedQuantizedConverter<
int32_t zero_point,
float inverse_scale) {
std::array<value_type, size()> qvals;
std::array<float, float_num_vecs() * Vectorized<float>::size()> float_vals;
float * float_vals = new float[float_num_vecs() * Vectorized<float>::size()];

for (int i = 0; i < float_num_vecs(); ++i) {
rhs[i].store(
@@ -522,10 +531,11 @@ struct Vectorized<c10::quint8> : public VectorizedQuantizedConverter<
at::native::quantize_vec<c10::quint8>(
scale,
zero_point,
float_vals.data(),
float_vals,
(c10::quint8*)qvals.data(),
Vectorized<float>::size() * float_num_vecs());

delete[] float_vals;
return Vectorized<c10::quint8>::loadu(qvals.data());
}

@@ -600,7 +610,7 @@ Vectorized<c10::quint8> inline maximum(
return a.maximum(b);
}

#endif // defined(CPU_CAPABILITY_SVE)
#endif // defined(CPU_CAPABILITY_SVE256)

} // namespace CPU_CAPABILITY
} // namespace at::vec

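A note on the buffer changes in the quantized converters above: once `Vectorized<float>::size()` is the runtime `svcntw()` rather than a compile-time constant, it can no longer size a `std::array` or a fixed stack array, which is presumably why these temporaries switch to heap allocations. A minimal illustration of the constraint (the `lanes()` helper is a stand-in, not part of the diff):

```cpp
#include <vector>

int lanes() { return 8; }  // stand-in for a runtime lane count such as svcntw()

void example() {
  // std::array<float, lanes()> buf;   // does not compile: the bound must be a constant expression
  std::vector<float> buf(lanes());     // a runtime-sized buffer works
  buf[0] = 1.0f;
  // The diff uses new[]/delete[] directly; a std::vector (or c10::SmallVector)
  // would give the same runtime sizing without the manual delete[] calls.
}
```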
@ -4,7 +4,9 @@
|
||||
#include <ATen/cpu/vec/intrinsics.h>
|
||||
|
||||
#ifdef __aarch64__
|
||||
#if !defined(CPU_CAPABILITY_SVE)
|
||||
#if defined(CPU_CAPABILITY_SVE) || defined(CPU_CAPABILITY_SVE256)
|
||||
#include <ATen/cpu/vec/sve/vec_common_sve.h>
|
||||
#else
|
||||
#include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h>
|
||||
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
|
||||
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
|
||||
|
||||
@ -241,7 +241,7 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
|
||||
Vectorized() = default;
|
||||
|
||||
Vectorized(c10::BFloat16 val)
|
||||
: Vectorized16(at_vdupq_n_bf16(c10::bit_cast<at_bfloat16_t>(val.x))) {}
|
||||
: Vectorized16(at_vdupq_n_bf16(val.x)) {}
|
||||
Vectorized(float val) : Vectorized(c10::BFloat16(val)) {}
|
||||
Vectorized(
|
||||
value_type val0,
|
||||
@ -253,14 +253,14 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
|
||||
value_type val6,
|
||||
value_type val7)
|
||||
: Vectorized16(at_bfloat16x8_t{
|
||||
c10::bit_cast<at_bfloat16_t>(val0.x),
|
||||
c10::bit_cast<at_bfloat16_t>(val1.x),
|
||||
c10::bit_cast<at_bfloat16_t>(val2.x),
|
||||
c10::bit_cast<at_bfloat16_t>(val3.x),
|
||||
c10::bit_cast<at_bfloat16_t>(val4.x),
|
||||
c10::bit_cast<at_bfloat16_t>(val5.x),
|
||||
c10::bit_cast<at_bfloat16_t>(val6.x),
|
||||
c10::bit_cast<at_bfloat16_t>(val7.x)}) {}
|
||||
val0.x,
|
||||
val1.x,
|
||||
val2.x,
|
||||
val3.x,
|
||||
val4.x,
|
||||
val5.x,
|
||||
val6.x,
|
||||
val7.x}) {}
|
||||
|
||||
static Vectorized<c10::BFloat16> blendv(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
|
||||
namespace at::vec {
|
||||
inline namespace CPU_CAPABILITY {
|
||||
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
|
||||
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE))
|
||||
template <typename src_t>
|
||||
struct VecConvert<
|
||||
float,
|
||||
|
||||
@ -41,32 +41,16 @@ inline namespace CPU_CAPABILITY {
|
||||
#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code
|
||||
#endif
|
||||
|
||||
template <int index, bool mask_val>
|
||||
template <int index>
|
||||
struct BlendRegs {
|
||||
static float32x4_t impl(
|
||||
const float32x4_t& a,
|
||||
const float32x4_t& b,
|
||||
float32x4_t& res);
|
||||
};
|
||||
|
||||
template <int index>
|
||||
struct BlendRegs<index, true> {
|
||||
static float32x4_t impl(
|
||||
const float32x4_t& a,
|
||||
const float32x4_t& b,
|
||||
float32x4_t& res) {
|
||||
return vsetq_lane_f32(vgetq_lane_f32(b, index), res, index);
|
||||
}
|
||||
};
|
||||
|
||||
template <int index>
|
||||
struct BlendRegs<index, false> {
|
||||
static float32x4_t impl(
|
||||
const float32x4_t& a,
|
||||
const float32x4_t& b,
|
||||
float32x4_t& res) {
|
||||
return vsetq_lane_f32(vgetq_lane_f32(a, index), res, index);
|
||||
}
|
||||
float32x4_t& res,
|
||||
bool mask_val
|
||||
) {
|
||||
return vsetq_lane_f32(vgetq_lane_f32(mask_val ? b : a, index), res, index);
|
||||
}
|
||||
};
|
||||
|
||||
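For reference, the collapsed `BlendRegs` above now receives the mask bit as an ordinary argument instead of a template parameter, which is what lets `blend` take its mask at runtime further down. A self-contained NEON sketch of the same lane-wise blend (aarch64 only; the helper names here are illustrative, not the real ATen ones):

#include <arm_neon.h>
#include <cstdint>

// Copy lane `index` from either a or b into res, chosen by mask_val.
// The lane index itself must stay a compile-time constant for the intrinsics.
template <int index>
float32x4_t blend_lane(const float32x4_t& a, const float32x4_t& b,
                       float32x4_t res, bool mask_val) {
  return vsetq_lane_f32(vgetq_lane_f32(mask_val ? b : a, index), res, index);
}

// Blend the four lanes of a and b according to the low four bits of mask.
float32x4_t blend4(const float32x4_t& a, const float32x4_t& b, int64_t mask) {
  float32x4_t res = a;
  res = blend_lane<0>(a, b, res, (mask & 0x01) != 0);
  res = blend_lane<1>(a, b, res, (mask & 0x02) != 0);
  res = blend_lane<2>(a, b, res, (mask & 0x04) != 0);
  res = blend_lane<3>(a, b, res, (mask & 0x08) != 0);
  return res;
}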
template <>
|
||||
@ -94,19 +78,15 @@ class Vectorized<float> {
|
||||
operator float32x4_t() const {
|
||||
return values;
|
||||
}
|
||||
template <int64_t mask>
|
||||
static Vectorized<float> blend(
|
||||
const Vectorized<float>& a,
|
||||
const Vectorized<float>& b) {
|
||||
const Vectorized<float>& b,
|
||||
int64_t mask) {
|
||||
Vectorized<float> vec;
|
||||
vec.values = BlendRegs < 0,
|
||||
(mask & 0x01) != 0 > ::impl(a.values, b.values, vec.values);
|
||||
vec.values = BlendRegs < 1,
|
||||
(mask & 0x02) != 0 > ::impl(a.values, b.values, vec.values);
|
||||
vec.values = BlendRegs < 2,
|
||||
(mask & 0x04) != 0 > ::impl(a.values, b.values, vec.values);
|
||||
vec.values = BlendRegs < 3,
|
||||
(mask & 0x08) != 0 > ::impl(a.values, b.values, vec.values);
|
||||
vec.values = BlendRegs <0>::impl(a.values, b.values, vec.values, (mask & 0x01) != 0);
|
||||
vec.values = BlendRegs <1> ::impl(a.values, b.values, vec.values, (mask & 0x02) != 0);
|
||||
vec.values = BlendRegs <2> ::impl(a.values, b.values, vec.values, (mask & 0x04) != 0);
|
||||
vec.values = BlendRegs <3> ::impl(a.values, b.values, vec.values, (mask & 0x08) != 0);
|
||||
return vec;
|
||||
}
|
||||
static Vectorized<float> blendv(
|
||||
@ -307,11 +287,48 @@ class Vectorized<float> {
|
||||
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp)
|
||||
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp2)
|
||||
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1)
|
||||
// Implementation copied from Arm Optimized Routine https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c
|
||||
Vectorized<float> exp_u20() const {
|
||||
return exp();
|
||||
// bail out to sleef if it's a special case:
|
||||
// i.e. there's an input s.t. |input| > 87.3....
|
||||
const float32x4_t special_bound = vdupq_n_f32(0x1.5d5e2ap+6f);
|
||||
uint32x4_t cmp = vcagtq_f32 (values, special_bound);
|
||||
if (vpaddd_u64 (vreinterpretq_u64_u32 (cmp)) != 0) {
|
||||
return exp();
|
||||
}
|
||||
|
||||
const float32x4_t inv_ln2 = vdupq_n_f32(0x1.715476p+0f);
|
||||
const float ln2_hi = 0x1.62e4p-1f;
|
||||
const float ln2_lo = 0x1.7f7d1cp-20f;
|
||||
const float c0 = 0x1.0e4020p-7f;
|
||||
const float c2 = 0x1.555e66p-3f;
|
||||
const float32x4_t ln2_c02 = {ln2_hi, ln2_lo, c0, c2};
|
||||
|
||||
const uint32x4_t exponent_bias = vdupq_n_u32(0x3f800000);
|
||||
const float32x4_t c1 = vdupq_n_f32(0x1.573e2ep-5f);
|
||||
const float32x4_t c3 = vdupq_n_f32(0x1.fffdb6p-2f);
|
||||
const float32x4_t c4 = vdupq_n_f32(0x1.ffffecp-1f);
|
||||
|
||||
/* exp(x) = 2^n (1 + poly(r)), with 1 + poly(r) in [1/sqrt(2),sqrt(2)]
|
||||
x = ln2*n + r, with r in [-ln2/2, ln2/2]. */
|
||||
|
||||
float32x4_t n = vrndaq_f32 (vmulq_f32 (values, inv_ln2));
|
||||
float32x4_t r = vfmsq_laneq_f32 (values, n, ln2_c02, 0);
|
||||
r = vfmsq_laneq_f32 (r, n, ln2_c02, 1);
|
||||
uint32x4_t e = vshlq_n_u32 (vreinterpretq_u32_s32 (vcvtq_s32_f32 (n)), 23);
|
||||
float32x4_t scale = vreinterpretq_f32_u32 (vaddq_u32 (e, exponent_bias));
|
||||
|
||||
float32x4_t r2 = vmulq_f32 (r, r);
|
||||
float32x4_t p = vfmaq_laneq_f32 (c1, r, ln2_c02, 2);
|
||||
float32x4_t q = vfmaq_laneq_f32 (c3, r, ln2_c02, 3);
|
||||
q = vfmaq_f32 (q, p, r2);
|
||||
p = vmulq_f32 (c4, r);
|
||||
float32x4_t poly = vfmaq_f32 (p, q, r2);
|
||||
|
||||
return vfmaq_f32 (scale, poly, scale);
|
||||
}
|
||||
Vectorized<float> fexp_u20() const {
|
||||
return exp();
|
||||
return exp_u20();
|
||||
}
|
||||
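The `exp_u20` body above follows the standard range reduction exp(x) = 2^n * (1 + poly(r)) with x = n*ln2 + r: n is rounded to an integer, 2^n is materialized by adding n to the float exponent field, and poly(r) is a short polynomial fit; inputs with |x| above 0x1.5d5e2ap+6 (about 87.3) bail out to the Sleef-backed exp(). A scalar C++17 sketch of the same computation, using the coefficients quoted above (illustrative only, not the shipped kernel):

#include <cmath>
#include <cstdint>
#include <cstring>

// Scalar sketch of exp_u20's range reduction: exp(x) = 2^n * (1 + poly(r)),
// x = n*ln2 + r with r in [-ln2/2, ln2/2]; 2^n is built from the exponent
// bits. Coefficients are the ones in the vector code above.
float expf_sketch(float x) {
  if (std::fabs(x) > 0x1.5d5e2ap+6f) {
    return std::exp(x);  // the vector code bails out to exp() here
  }
  const float inv_ln2 = 0x1.715476p+0f;
  const float ln2_hi = 0x1.62e4p-1f, ln2_lo = 0x1.7f7d1cp-20f;
  const float c0 = 0x1.0e4020p-7f, c1 = 0x1.573e2ep-5f;
  const float c2 = 0x1.555e66p-3f, c3 = 0x1.fffdb6p-2f, c4 = 0x1.ffffecp-1f;

  const float n = std::round(x * inv_ln2);        // vrndaq_f32
  float r = x - n * ln2_hi;                       // two-part ln2 for accuracy
  r -= n * ln2_lo;

  const uint32_t e =
      (static_cast<uint32_t>(static_cast<int32_t>(n)) << 23) + 0x3f800000u;
  float scale;                                    // scale = 2^n
  std::memcpy(&scale, &e, sizeof(scale));

  const float r2 = r * r;
  float p = c1 + c0 * r;
  float q = c3 + c2 * r;
  q += p * r2;
  p = c4 * r;
  const float poly = p + q * r2;
  return scale + poly * scale;                    // scale * (1 + poly)
}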
DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(
|
||||
fmod,
|
||||
|
||||
@ -813,11 +813,12 @@ static inline Vectorized<T> binary_op_as_fp32(
|
||||
#define LOAD_FP32_NON_VECTORIZED_INIT(type, name) \
|
||||
inline void load_fp32_from_##name( \
|
||||
const type* data, Vectorized<float>& out) { \
|
||||
__at_align__ float values[Vectorized<float>::size()]; \
|
||||
__at_align__ float * values = new float[Vectorized<float>::size()]; \
|
||||
for (const auto k : c10::irange(Vectorized<float>::size())) { \
|
||||
values[k] = data[k]; \
|
||||
} \
|
||||
out = Vectorized<float>::loadu(values); \
|
||||
delete[] values; \
|
||||
} \
|
||||
\
|
||||
inline void load_fp32_from_##name( \
|
||||
|
||||
@ -269,12 +269,13 @@ LOAD_FP32_VECTORIZED_INIT(BFloat16, bf16)
|
||||
#else // defined(CPU_CAPABILITY_AVX2)
|
||||
|
||||
#if !( \
|
||||
defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
|
||||
!defined(CPU_CAPABILITY_SVE256))
|
||||
defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__))
|
||||
CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16)
|
||||
#endif
|
||||
|
||||
#if !defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE)
|
||||
LOAD_FP32_NON_VECTORIZED_INIT(BFloat16, bf16)
|
||||
#endif
|
||||
#endif // defined(CPU_CAPABILITY_AVX2)
|
||||
} // namespace CPU_CAPABILITY
|
||||
} // namespace at::vec
|
||||
|
||||
@ -294,7 +294,7 @@ struct VecConvert<
|
||||
};
|
||||
#endif
|
||||
|
||||
#if defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16)
|
||||
#if (defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE)) && defined(__ARM_FEATURE_BF16)
|
||||
|
||||
template <>
|
||||
struct VecConvert<float, 1, BFloat16, 1> {
|
||||
|
||||
@ -270,7 +270,7 @@ LOAD_FP32_VECTORIZED_INIT(Half, fp16)
|
||||
|
||||
#if !( \
|
||||
defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
|
||||
!defined(CPU_CAPABILITY_SVE256))
|
||||
!defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE))
|
||||
CONVERT_NON_VECTORIZED_INIT(Half, half)
|
||||
#endif
|
||||
|
||||
|
||||
@ -915,7 +915,7 @@ Vectorized<c10::quint8> inline maximum(
|
||||
return a.maximum(b);
|
||||
}
|
||||
|
||||
#elif !defined(CPU_CAPABILITY_SVE256)
|
||||
#elif !defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE)
|
||||
|
||||
// NOTE: These are low-performance implementations that we fall back on
|
||||
// if we are not building with AVX2. This may not be an issue, because
|
||||
@ -1374,11 +1374,11 @@ Vectorized<c10::quint8> inline maximum(
|
||||
|
||||
#endif // if defined(CPU_CAPABILITY_AVX2)
|
||||
|
||||
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
|
||||
std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
|
||||
at::vec::Vectorized<int8_t> src) {
|
||||
auto s8x8 = vld1_s8(src.operator const int8_t*());
|
||||
auto s16x8 = vmovl_s8(s8x8);
|
||||
#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE)
|
||||
std::pair<Vectorized<float>, Vectorized<float>>
|
||||
inline convert_int8_to_float(at::vec::Vectorized<int8_t> src) {
|
||||
auto s8x8 = vld1_s8(src.operator const int8_t*());
|
||||
auto s16x8 = vmovl_s8(s8x8);
|
||||
|
||||
auto s32x4_hi = vmovl_s16(vget_high_s16(s16x8));
|
||||
auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8));
|
||||
|
||||
@ -292,8 +292,7 @@ class Vectorized16 {
|
||||
_mm512_mask_storeu_epi16(ptr, mask, values);
|
||||
}
|
||||
}
|
||||
template <int64_t mask>
|
||||
static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
|
||||
static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b, int64_t mask) {
|
||||
return _mm512_mask_blend_epi16(mask, a.values, b.values);
|
||||
}
|
||||
static Vectorized<T> blendv(
|
||||
|
||||
@ -69,10 +69,10 @@ class Vectorized<c10::complex<double>> {
|
||||
operator __m512d() const {
|
||||
return values;
|
||||
}
|
||||
template <int64_t mask>
|
||||
static Vectorized<c10::complex<double>> blend(
|
||||
const Vectorized<c10::complex<double>>& a,
|
||||
const Vectorized<c10::complex<double>>& b) {
|
||||
const Vectorized<c10::complex<double>>& b,
|
||||
int64_t mask) {
|
||||
// convert c10::complex<V> index mask to V index mask: xy -> xxyy
|
||||
// NOLINTNEXTLINE(clang-diagnostic-warning)
|
||||
switch (mask) {
|
||||
|
||||
@ -89,10 +89,10 @@ class Vectorized<c10::complex<float>> {
|
||||
operator __m512() const {
|
||||
return values;
|
||||
}
|
||||
template <int64_t mask>
|
||||
static Vectorized<c10::complex<float>> blend(
|
||||
const Vectorized<c10::complex<float>>& a,
|
||||
const Vectorized<c10::complex<float>>& b) {
|
||||
const Vectorized<c10::complex<float>>& b,
|
||||
int64_t mask) {
|
||||
// convert c10::complex<V> index mask to V index mask: xy -> xxyy
|
||||
static_assert(mask > -1 && mask < 256, "Unexpected mask value");
|
||||
// The compiler would hopefully convert this switch condition
|
||||
|
||||
@ -55,10 +55,10 @@ class Vectorized<double> {
|
||||
operator __m512d() const {
|
||||
return values;
|
||||
}
|
||||
template <int64_t mask>
|
||||
static Vectorized<double> blend(
|
||||
const Vectorized<double>& a,
|
||||
const Vectorized<double>& b) {
|
||||
const Vectorized<double>& b,
|
||||
int64_t mask) {
|
||||
return _mm512_mask_blend_pd(mask, a.values, b.values);
|
||||
}
|
||||
static Vectorized<double> blendv(
|
||||
|
||||
@ -95,10 +95,10 @@ class Vectorized<float> {
|
||||
operator __m512() const {
|
||||
return values;
|
||||
}
|
||||
template <int64_t mask>
|
||||
static Vectorized<float> blend(
|
||||
const Vectorized<float>& a,
|
||||
const Vectorized<float>& b) {
|
||||
const Vectorized<float>& b,
|
||||
int64_t mask) {
|
||||
return _mm512_mask_blend_ps(mask, a.values, b.values);
|
||||
}
|
||||
static Vectorized<float> blendv(
|
||||
|
||||
@ -528,10 +528,10 @@ class Vectorized<int16_t> : public Vectorizedi {
|
||||
val2,
|
||||
val1);
|
||||
}
|
||||
template <int64_t mask>
|
||||
static Vectorized<int16_t> blend(
|
||||
Vectorized<int16_t> a,
|
||||
Vectorized<int16_t> b) {
|
||||
Vectorized<int16_t> b,
|
||||
int64_t mask) {
|
||||
return _mm512_mask_blend_epi16(mask, a.values, b.values);
|
||||
}
|
||||
static Vectorized<int16_t> blendv(
|
||||
|
||||
@ -68,7 +68,7 @@ Windows llvm will not have this definition.
|
||||
#define VECTOR_WIDTH 64
|
||||
#define int_vector __m512i
|
||||
#elif defined(__aarch64__) && \
|
||||
!defined(CPU_CAPABILITY_SVE) // CPU_CAPABILITY_AVX512
|
||||
!defined(CPU_CAPABILITY_SVE) && !defined(CPU_CAPABILITY_SVE256) // CPU_CAPABILITY_AVX512
|
||||
// SVE code expects 256-vectors; leave that set for SVE?
|
||||
#if defined(__GNUC__)
|
||||
#define __at_align__ __attribute__((aligned(16)))
|
||||
@ -79,6 +79,18 @@ Windows llvm will not have this definition.
|
||||
#endif
|
||||
#define VECTOR_WIDTH 16
|
||||
#else // CPU_CAPABILITY_AVX512
|
||||
#if defined(CPU_CAPABILITY_SVE)
|
||||
#if defined(__GNUC__)
|
||||
#define __at_align__ __attribute__((aligned(16)))
|
||||
#elif defined(_WIN32)
|
||||
#define __at_align__ __declspec(align(16))
|
||||
#else
|
||||
#define __at_align__
|
||||
#endif
|
||||
#define VECTOR_WIDTH 16
|
||||
#define int_vector __m256i
|
||||
#else // CPU_CAPABILITY_SVE256 || CPU_CAPABILITY_SVE
|
||||
#if defined(CPU_CAPABILITY_SVE256)
|
||||
#if defined(__GNUC__)
|
||||
#define __at_align__ __attribute__((aligned(32)))
|
||||
#elif defined(_WIN32)
|
||||
@ -88,6 +100,18 @@ Windows llvm will not have this definition.
|
||||
#endif
|
||||
#define VECTOR_WIDTH 32
|
||||
#define int_vector __m256i
|
||||
#else // CPU_CAPABILITY_SVE
|
||||
#if defined(__GNUC__)
|
||||
#define __at_align__ __attribute__((aligned(16)))
|
||||
#elif defined(_WIN32)
|
||||
#define __at_align__ __declspec(align(16))
|
||||
#else
|
||||
#define __at_align__
|
||||
#endif
|
||||
#define VECTOR_WIDTH 16
|
||||
#define int_vector __m256i
|
||||
#endif // CPU_CAPABILITY_SVE256
|
||||
#endif // CPU_CAPABILITY_SVE256 || CPU_CAPABILITY_SVE
|
||||
#endif // CPU_CAPABILITY_AVX512
|
||||
|
||||
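In short, the block above now selects per-capability alignment and vector width as follows; condensed here for the GCC/Clang branch only (the real header also covers MSVC via __declspec(align(...)) and defines int_vector for the AVX tiers):

// Condensed view of the selection above, not the literal header text.
#if defined(CPU_CAPABILITY_AVX512)
  #define __at_align__ __attribute__((aligned(64)))
  #define VECTOR_WIDTH 64
#elif defined(CPU_CAPABILITY_SVE256)
  #define __at_align__ __attribute__((aligned(32)))
  #define VECTOR_WIDTH 32
#else  // NEON, generic SVE, and the scalar fallback
  #define __at_align__ __attribute__((aligned(16)))
  #define VECTOR_WIDTH 16
#endif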
namespace at::vec {
|
||||
@ -210,8 +234,7 @@ struct Vectorized {
|
||||
auto as_bytes() const -> const char* {
|
||||
return reinterpret_cast<const char*>(values);
|
||||
}
|
||||
template <int64_t mask_>
|
||||
static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
|
||||
static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b, const int64_t mask_) {
|
||||
int64_t mask = mask_;
|
||||
Vectorized vector;
|
||||
for (const auto i : c10::irange(size())) {
|
||||
@ -1312,7 +1335,7 @@ std::
|
||||
T const* base_addr,
|
||||
const Vectorized<int_same_size_t<T>>& vindex,
|
||||
Vectorized<T>& mask) {
|
||||
static constexpr int size = Vectorized<T>::size();
|
||||
static const int size = Vectorized<T>::size();
|
||||
T src_arr[size];
|
||||
int_same_size_t<T> mask_arr[size]; // use int type so we can logical and
|
||||
int_same_size_t<T> index_arr[size];
|
||||
@ -1405,7 +1428,7 @@ inline Vectorized<T> convert_to_fp_of_same_size(
|
||||
// clang-format on
|
||||
template <typename T>
|
||||
inline std::enable_if_t<
|
||||
Vectorized<T>::size() % 2 == 0,
|
||||
true,
|
||||
std::pair<Vectorized<T>, Vectorized<T>>>
|
||||
deinterleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
|
||||
static constexpr int size = Vectorized<T>::size();
|
||||
@ -1444,7 +1467,7 @@ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(deinterleave2)
|
||||
// clang-format on
|
||||
template <typename T>
|
||||
inline std::enable_if_t<
|
||||
Vectorized<T>::size() % 2 == 0,
|
||||
true,
|
||||
std::pair<Vectorized<T>, Vectorized<T>>>
|
||||
interleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
|
||||
static constexpr int size = Vectorized<T>::size();
|
||||
@ -1486,7 +1509,7 @@ inline void convert(const src_T* src, dst_T* dst, int64_t n) {
|
||||
|
||||
template <typename T>
|
||||
inline Vectorized<T> flip(const Vectorized<T>& data) {
|
||||
static constexpr int size = Vectorized<T>::size();
|
||||
static const int size = Vectorized<T>::size();
|
||||
T output[size];
|
||||
T buffer[size];
|
||||
data.store(static_cast<void*>(buffer));
|
||||
|
||||
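The `static constexpr int size` to `static const int size` changes here (and in many kernels below) all trace back to the same constraint: with vector-length-agnostic SVE, `Vectorized<T>::size()` is no longer a constant expression, so it cannot feed constexpr variables, non-type template arguments, or `std::array` bounds. A tiny illustration of which uses survive the change (names hypothetical):

#include <array>

constexpr int fixed_lane_count() { return 8; }   // e.g. a NEON-style fixed width
int runtime_lane_count() { return 8; }           // e.g. derived from the SVE VL

void lane_count_example() {
  constexpr int a = fixed_lane_count();          // OK: constant expression
  std::array<float, fixed_lane_count()> buf{};   // OK: constant array bound
  const int b = runtime_lane_count();            // OK, but only a runtime constant
  // constexpr int c = runtime_lane_count();     // would not compile
  // std::array<float, runtime_lane_count()> d;  // would not compile
  (void)a; (void)buf; (void)b;
}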
@ -15,7 +15,7 @@ template <
|
||||
struct VecConvert {
|
||||
static inline VectorizedN<dst_t, dst_n> apply(
|
||||
const VectorizedN<src_t, src_n>& src) {
|
||||
constexpr int count = std::min(
|
||||
const int count = std::min(
|
||||
VectorizedN<src_t, src_n>::size(), VectorizedN<dst_t, dst_n>::size());
|
||||
__at_align__ src_t src_buf[VectorizedN<src_t, src_n>::size()];
|
||||
src.store(src_buf);
|
||||
|
||||
@ -2,6 +2,8 @@
|
||||
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
#include <ATen/cpu/vec/vec_n.h>
|
||||
|
||||
#include <cassert>
|
||||
namespace at::vec {
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
@ -38,9 +40,9 @@ struct VecMaskLoad {
|
||||
static inline VectorizedN<data_t, data_n> apply(
|
||||
const data_t* ptr,
|
||||
const VecMask<mask_t, mask_n>& vec_mask) {
|
||||
constexpr typename VecMask<mask_t, mask_n>::size_type size =
|
||||
const typename VecMask<mask_t, mask_n>::size_type size =
|
||||
VecMask<mask_t, mask_n>::size();
|
||||
static_assert(VectorizedN<data_t, data_n>::size() >= size);
|
||||
assert((VectorizedN<data_t, data_n>::size() >= size));
|
||||
__at_align__ data_t data[size];
|
||||
__at_align__ mask_t mask[size];
|
||||
auto mask_ = VectorizedN<mask_t, mask_n>(vec_mask);
|
||||
@ -134,7 +136,7 @@ class VecMask {
|
||||
template <typename U, int L>
|
||||
static VecMask<T, N> from(const VectorizedN<U, L>& b_vec) {
|
||||
__at_align__ U b_buf[size()];
|
||||
if constexpr (size() >= VectorizedN<U, L>::size()) {
|
||||
if (size() >= VectorizedN<U, L>::size()) {
|
||||
b_vec.store(b_buf);
|
||||
for (int i = VectorizedN<U, L>::size(); i < size(); i++) {
|
||||
b_buf[i] = static_cast<U>(0);
|
||||
@ -235,16 +237,18 @@ class VecMask {
|
||||
template <
|
||||
typename U,
|
||||
int L,
|
||||
std::enable_if_t<L >= 2 && VectorizedN<U, L>::size() >= size(), int> = 0>
|
||||
std::enable_if_t<L >= 2, int> = 0>
|
||||
VectorizedN<U, L> loadu(const U* ptr) const {
|
||||
assert((VectorizedN<U, L>::size() >= size()));
|
||||
return VecMaskLoad<U, L, T, N>::apply(ptr, *this);
|
||||
}
|
||||
|
||||
template <
|
||||
typename U,
|
||||
int L,
|
||||
std::enable_if_t<L == 1 && Vectorized<U>::size() >= size(), int> = 0>
|
||||
std::enable_if_t<L == 1, int> = 0>
|
||||
Vectorized<U> loadu(const U* ptr) const {
|
||||
assert((Vectorized<U>::size() >= size()));
|
||||
return VecMaskLoad<U, L, T, N>::apply(ptr, *this);
|
||||
}
|
||||
};
|
||||
|
||||
@ -28,7 +28,7 @@ class VectorizedN {
|
||||
using size_type = int;
|
||||
|
||||
static constexpr size_type size_T = sizeof(T);
|
||||
static constexpr size_type size() {
|
||||
static size_type size() {
|
||||
return Vectorized<T>::size() * N;
|
||||
}
|
||||
|
||||
|
||||
@ -1157,6 +1157,7 @@ REGISTER_AVX512_DISPATCH(cholesky_stub, &cholesky_kernel)
|
||||
REGISTER_AVX2_DISPATCH(cholesky_stub, &cholesky_kernel)
|
||||
REGISTER_VSX_DISPATCH(cholesky_stub, &cholesky_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(cholesky_stub, &cholesky_kernel)
|
||||
REGISTER_SVE_DISPATCH(cholesky_stub, &cholesky_kernel)
|
||||
REGISTER_SVE256_DISPATCH(cholesky_stub, &cholesky_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(cholesky_inverse_stub, DEFAULT, &cholesky_inverse_kernel_impl)
|
||||
@ -1164,6 +1165,7 @@ REGISTER_AVX512_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
|
||||
REGISTER_AVX2_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
|
||||
REGISTER_VSX_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
|
||||
REGISTER_ZVECTOR_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
|
||||
REGISTER_SVE_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
|
||||
REGISTER_SVE256_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(linalg_eig_stub, DEFAULT, &linalg_eig_kernel)
|
||||
@ -1171,6 +1173,7 @@ REGISTER_AVX512_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
|
||||
REGISTER_AVX2_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
|
||||
REGISTER_VSX_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
|
||||
REGISTER_SVE_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
|
||||
REGISTER_SVE256_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(linalg_eigh_stub, DEFAULT, &linalg_eigh_kernel)
|
||||
@ -1178,6 +1181,7 @@ REGISTER_AVX512_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
|
||||
REGISTER_AVX2_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
|
||||
REGISTER_VSX_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
|
||||
REGISTER_SVE_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
|
||||
REGISTER_SVE256_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(geqrf_stub, DEFAULT, &geqrf_kernel)
|
||||
@ -1185,6 +1189,7 @@ REGISTER_AVX512_DISPATCH(geqrf_stub, &geqrf_kernel)
|
||||
REGISTER_AVX2_DISPATCH(geqrf_stub, &geqrf_kernel)
|
||||
REGISTER_VSX_DISPATCH(geqrf_stub, &geqrf_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(geqrf_stub, &geqrf_kernel)
|
||||
REGISTER_SVE_DISPATCH(geqrf_stub, &geqrf_kernel)
|
||||
REGISTER_SVE256_DISPATCH(geqrf_stub, &geqrf_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(orgqr_stub, DEFAULT, &orgqr_kernel_impl)
|
||||
@ -1192,6 +1197,7 @@ REGISTER_AVX512_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
|
||||
REGISTER_AVX2_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
|
||||
REGISTER_VSX_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
|
||||
REGISTER_ZVECTOR_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
|
||||
REGISTER_SVE_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
|
||||
REGISTER_SVE256_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(ormqr_stub, DEFAULT, &ormqr_kernel)
|
||||
@ -1199,6 +1205,7 @@ REGISTER_AVX512_DISPATCH(ormqr_stub, &ormqr_kernel)
|
||||
REGISTER_AVX2_DISPATCH(ormqr_stub, &ormqr_kernel)
|
||||
REGISTER_VSX_DISPATCH(ormqr_stub, &ormqr_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(ormqr_stub, &ormqr_kernel)
|
||||
REGISTER_SVE_DISPATCH(ormqr_stub, &ormqr_kernel)
|
||||
REGISTER_SVE256_DISPATCH(ormqr_stub, &ormqr_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(lstsq_stub, DEFAULT, &lstsq_kernel)
|
||||
@ -1206,6 +1213,7 @@ REGISTER_AVX512_DISPATCH(lstsq_stub, &lstsq_kernel)
|
||||
REGISTER_AVX2_DISPATCH(lstsq_stub, &lstsq_kernel)
|
||||
REGISTER_VSX_DISPATCH(lstsq_stub, &lstsq_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(lstsq_stub, &lstsq_kernel)
|
||||
REGISTER_SVE_DISPATCH(lstsq_stub, &lstsq_kernel)
|
||||
REGISTER_SVE256_DISPATCH(lstsq_stub, &lstsq_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(triangular_solve_stub, DEFAULT, &triangular_solve_kernel)
|
||||
@ -1213,6 +1221,7 @@ REGISTER_AVX512_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
|
||||
REGISTER_AVX2_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
|
||||
REGISTER_VSX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
|
||||
REGISTER_SVE_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
|
||||
REGISTER_SVE256_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(lu_factor_stub, DEFAULT, &lu_factor_kernel)
|
||||
@ -1220,6 +1229,7 @@ REGISTER_AVX512_DISPATCH(lu_factor_stub, &lu_factor_kernel)
|
||||
REGISTER_AVX2_DISPATCH(lu_factor_stub, &lu_factor_kernel)
|
||||
REGISTER_VSX_DISPATCH(lu_factor_stub, &lu_factor_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(lu_factor_stub, &lu_factor_kernel)
|
||||
REGISTER_SVE_DISPATCH(lu_factor_stub, &lu_factor_kernel)
|
||||
REGISTER_SVE256_DISPATCH(lu_factor_stub, &lu_factor_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(ldl_factor_stub, DEFAULT, &ldl_factor_kernel)
|
||||
@ -1227,6 +1237,7 @@ REGISTER_AVX512_DISPATCH(ldl_factor_stub, &ldl_factor_kernel)
|
||||
REGISTER_AVX2_DISPATCH(ldl_factor_stub, &ldl_factor_kernel)
|
||||
REGISTER_VSX_DISPATCH(ldl_factor_stub, &ldl_factor_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(ldl_factor_stub, &ldl_factor_kernel)
|
||||
REGISTER_SVE_DISPATCH(ldl_factor_stub, &ldl_factor_kernel)
|
||||
REGISTER_SVE256_DISPATCH(ldl_factor_stub, &ldl_factor_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(ldl_solve_stub, DEFAULT, &ldl_solve_kernel)
|
||||
@ -1234,6 +1245,7 @@ REGISTER_AVX512_DISPATCH(ldl_solve_stub, &ldl_solve_kernel)
|
||||
REGISTER_AVX2_DISPATCH(ldl_solve_stub, &ldl_solve_kernel)
|
||||
REGISTER_VSX_DISPATCH(ldl_solve_stub, &ldl_solve_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(ldl_solve_stub, &ldl_solve_kernel)
|
||||
REGISTER_SVE_DISPATCH(ldl_solve_stub, &ldl_solve_kernel)
|
||||
REGISTER_SVE256_DISPATCH(ldl_solve_stub, &ldl_solve_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(lu_solve_stub, DEFAULT, &lu_solve_kernel)
|
||||
@ -1241,6 +1253,7 @@ REGISTER_AVX512_DISPATCH(lu_solve_stub, &lu_solve_kernel)
|
||||
REGISTER_AVX2_DISPATCH(lu_solve_stub, &lu_solve_kernel)
|
||||
REGISTER_VSX_DISPATCH(lu_solve_stub, &lu_solve_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(lu_solve_stub, &lu_solve_kernel)
|
||||
REGISTER_SVE_DISPATCH(lu_solve_stub, &lu_solve_kernel)
|
||||
REGISTER_SVE256_DISPATCH(lu_solve_stub, &lu_solve_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(svd_stub, DEFAULT, &svd_kernel)
|
||||
@ -1248,6 +1261,7 @@ REGISTER_AVX512_DISPATCH(svd_stub, &svd_kernel)
|
||||
REGISTER_AVX2_DISPATCH(svd_stub, &svd_kernel)
|
||||
REGISTER_VSX_DISPATCH(svd_stub, &svd_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(svd_stub, &svd_kernel)
|
||||
REGISTER_SVE_DISPATCH(svd_stub, &svd_kernel)
|
||||
REGISTER_SVE256_DISPATCH(svd_stub, &svd_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(unpack_pivots_stub, DEFAULT, &unpack_pivots_cpu_kernel)
|
||||
@ -1255,5 +1269,6 @@ REGISTER_AVX512_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel)
|
||||
REGISTER_AVX2_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel)
|
||||
REGISTER_VSX_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel)
|
||||
REGISTER_SVE_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel)
|
||||
REGISTER_SVE256_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel)
|
||||
} // namespace at::native
|
||||
|
||||
@ -38,17 +38,27 @@ static CPUCapability compute_cpu_capability() {
return CPUCapability::ZVECTOR;
}
#elif defined(HAVE_SVE_CPU_DEFINITION)
int sve_vl = cpuinfo_get_max_arm_sve_length(); //Returns maximum SVE VL supported by your HW.
#ifdef HAVE_SVE256_CPU_DEFINITION
int sve_vl = cpuinfo_get_max_arm_sve_length(); // Returns maximum SVE VL supported by your HW.
#ifdef HAVE_SVE_CPU_DEFINITION
if (envar == "sve256") {
if (sve_vl == 256) {
#ifdef HAVE_ARM_BF16_CPU_DEFINITION
if (cpuinfo_has_arm_bf16()) {
if (cpuinfo_has_arm_bf16()) {
if (sve_vl == 256) {
return CPUCapability::SVE256;
} else if (sve_vl > 0) {
return CPUCapability::SVE;
}
#endif
}
TORCH_WARN("SVE256 capability not available on hardware. Falling back to DEFAULT");
#endif
TORCH_WARN("SVE capability not available on hardware. Falling back to DEFAULT");
return CPUCapability::DEFAULT;
} else if (envar == "sve") {
#ifdef HAVE_ARM_BF16_CPU_DEFINITION
if (cpuinfo_has_arm_bf16() && sve_vl > 0) {
return CPUCapability::SVE;
}
#endif
TORCH_WARN("SVE capability not available on hardware. Falling back to DEFAULT");
return CPUCapability::DEFAULT;
}
#endif
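Taken together, the environment-variable branch above and the hardware-detection hunk below boil down to: query the SVE vector length via cpuinfo, prefer the fixed-width SVE256 kernels when the VL is exactly 256 bits, otherwise fall back to the generic SVE tier, and drop to DEFAULT when SVE or BF16 support is missing. A hedged sketch of that decision, using the cpuinfo calls seen in this file and ignoring the override and the HAVE_* build guards:

#include <cpuinfo.h>

enum class Capability { DEFAULT = 0, SVE = 1, SVE256 = 2 };

// Sketch only: mirrors the detection order in this file, not the real enum.
Capability detect_sve_capability() {
  if (!cpuinfo_initialize() || !cpuinfo_has_arm_sve() || !cpuinfo_has_arm_bf16()) {
    return Capability::DEFAULT;
  }
  const int sve_vl = cpuinfo_get_max_arm_sve_length();  // max VL reported by HW
  if (sve_vl == 256) {
    return Capability::SVE256;  // fixed 256-bit kernels
  }
  if (sve_vl > 0) {
    return Capability::SVE;     // vector-length-agnostic kernels
  }
  return Capability::DEFAULT;
}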
@ -100,19 +110,15 @@ static CPUCapability compute_cpu_capability() {
|
||||
#if defined(__linux__) && defined(HAVE_SVE_CPU_DEFINITION)
|
||||
if (cpuinfo_initialize() && cpuinfo_has_arm_sve()) {
|
||||
int sve_vl = cpuinfo_get_max_arm_sve_length(); //Returns maximum SVE VL supported by your HW.
|
||||
if (sve_vl <= 0) {
|
||||
// SVE is not supported on this system.
|
||||
// Return the default CPU capability.
|
||||
return CPUCapability::DEFAULT;
|
||||
#ifdef HAVE_ARM_BF16_CPU_DEFINITION
|
||||
if (cpuinfo_has_arm_bf16()) {
|
||||
if (sve_vl == 256) { // Check for SVE256
|
||||
return CPUCapability::SVE256;
|
||||
} else if (sve_vl > 0) {
|
||||
return CPUCapability::SVE;
|
||||
}
|
||||
}
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
if (sve_vl == 256) { // Check for SVE256
|
||||
#ifdef HAVE_ARM_BF16_CPU_DEFINITION
|
||||
if (cpuinfo_has_arm_bf16())
|
||||
return CPUCapability::SVE256;
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
// Return the default CPU capability.
|
||||
return CPUCapability::DEFAULT;
|
||||
}
|
||||
@ -144,7 +150,8 @@ DispatchResult DispatchStubImpl::try_get_call_ptr(
|
||||
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
||||
, void *ZVECTOR
|
||||
#endif
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
#ifdef HAVE_SVE_CPU_DEFINITION
|
||||
, void *SVE
|
||||
, void *SVE256
|
||||
#endif
|
||||
) {
|
||||
@ -182,7 +189,8 @@ DispatchResult DispatchStubImpl::try_get_call_ptr(
|
||||
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
||||
, ZVECTOR
|
||||
#endif
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
#ifdef HAVE_SVE_CPU_DEFINITION
|
||||
, SVE
|
||||
, SVE256
|
||||
#endif
|
||||
);
|
||||
@ -239,7 +247,8 @@ void* DispatchStubImpl::get_call_ptr(
|
||||
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
||||
, void *ZVECTOR
|
||||
#endif
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
#ifdef HAVE_SVE_CPU_DEFINITION
|
||||
, void *SVE
|
||||
, void *SVE256
|
||||
#endif
|
||||
) {
|
||||
@ -263,7 +272,9 @@ void* DispatchStubImpl::get_call_ptr(
|
||||
,
|
||||
ZVECTOR
|
||||
#endif
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
#ifdef HAVE_SVE_CPU_DEFINITION
|
||||
,
|
||||
SVE
|
||||
,
|
||||
SVE256
|
||||
#endif
|
||||
@ -298,7 +309,8 @@ DispatchResult DispatchStubImpl::try_choose_cpu_impl(
|
||||
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
||||
, void *ZVECTOR
|
||||
#endif
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
#ifdef HAVE_SVE_CPU_DEFINITION
|
||||
, void *SVE
|
||||
, void *SVE256
|
||||
#endif
|
||||
){
|
||||
@ -333,7 +345,7 @@ DispatchResult DispatchStubImpl::try_choose_cpu_impl(
|
||||
return ZVECTOR != nullptr ? DispatchResult(ZVECTOR) : ErrorType::MissingDeviceKernel;
|
||||
}
|
||||
#endif
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
#ifdef HAVE_SVE_CPU_DEFINITION
|
||||
if (capability >= static_cast<int>(CPUCapability::SVE256)) {
|
||||
if (C10_UNLIKELY(!SVE256)) {
|
||||
// dispatch to DEFAULT, since the SVE kernel is missing
|
||||
@ -342,6 +354,14 @@ DispatchResult DispatchStubImpl::try_choose_cpu_impl(
|
||||
return DispatchResult(SVE256);
|
||||
}
|
||||
}
|
||||
if (capability >= static_cast<int>(CPUCapability::SVE)) {
|
||||
if (C10_UNLIKELY(!SVE)) {
|
||||
// dispatch to DEFAULT, since the SVE kernel is missing
|
||||
return DEFAULT != nullptr ? DispatchResult(DEFAULT) : ErrorType::MissingDeviceKernel;
|
||||
} else {
|
||||
return DispatchResult(SVE);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return DEFAULT != nullptr ? DispatchResult(DEFAULT) : ErrorType::MissingDeviceKernel;
|
||||
}
|
||||
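The added branch makes the selection order explicit: SVE256 is tried first, then the new generic SVE tier, and each tier independently falls back to DEFAULT if its kernel pointer was never registered. A stripped-down sketch of that fallthrough (illustrative types; the real code returns DispatchResult or ErrorType and asserts the default kernel exists):

enum class Capability { DEFAULT = 0, SVE = 1, SVE256 = 2 };

// Pick the best available kernel pointer for the detected capability.
void* choose_kernel_sketch(int capability,
                           void* sve256_ptr, void* sve_ptr, void* default_ptr) {
  if (capability >= static_cast<int>(Capability::SVE256)) {
    // dispatch to DEFAULT when the SVE256 kernel is missing, as above
    return sve256_ptr != nullptr ? sve256_ptr : default_ptr;
  }
  if (capability >= static_cast<int>(Capability::SVE)) {
    return sve_ptr != nullptr ? sve_ptr : default_ptr;
  }
  return default_ptr;
}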
@ -360,7 +380,8 @@ void* DispatchStubImpl::choose_cpu_impl(
|
||||
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
||||
, void *ZVECTOR
|
||||
#endif
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
#ifdef HAVE_SVE_CPU_DEFINITION
|
||||
, void *SVE
|
||||
, void *SVE256
|
||||
#endif
|
||||
) {
|
||||
@ -398,7 +419,7 @@ void* DispatchStubImpl::choose_cpu_impl(
|
||||
return ZVECTOR;
|
||||
}
|
||||
#endif
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
#ifdef HAVE_SVE_CPU_DEFINITION
|
||||
if (capability >= static_cast<int>(CPUCapability::SVE256)) {
|
||||
if (C10_UNLIKELY(!SVE256)) {
|
||||
// dispatch to DEFAULT, since the SVE kernel is missing
|
||||
@ -408,6 +429,15 @@ void* DispatchStubImpl::choose_cpu_impl(
|
||||
return SVE256;
|
||||
}
|
||||
}
|
||||
if (capability >= static_cast<int>(CPUCapability::SVE)) {
|
||||
if (C10_UNLIKELY(!SVE)) {
|
||||
// dispatch to DEFAULT, since the SVE kernel is missing
|
||||
TORCH_INTERNAL_ASSERT(DEFAULT, "DispatchStub: missing default kernel");
|
||||
return DEFAULT;
|
||||
} else {
|
||||
return SVE;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
TORCH_INTERNAL_ASSERT(DEFAULT, "DispatchStub: missing default kernel");
|
||||
return DEFAULT;
|
||||
|
||||
@ -64,8 +64,9 @@ enum class CPUCapability {
|
||||
VSX = 1,
|
||||
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
|
||||
ZVECTOR = 1,
|
||||
#elif defined(HAVE_SVE256_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION)
|
||||
SVE256 = 1,
|
||||
#elif defined(HAVE_SVE_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION)
|
||||
SVE=1,
|
||||
SVE256 = 2,
|
||||
#else
|
||||
AVX2 = 1,
|
||||
AVX512 = 2,
|
||||
@ -115,7 +116,8 @@ struct TORCH_API DispatchStubImpl {
|
||||
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
||||
, void *ZVECTOR
|
||||
#endif
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
#ifdef HAVE_SVE_CPU_DEFINITION
|
||||
, void *SVE
|
||||
, void *SVE256
|
||||
#endif
|
||||
);
|
||||
@ -136,7 +138,8 @@ struct TORCH_API DispatchStubImpl {
|
||||
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
||||
, void *ZVECTOR
|
||||
#endif
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
#ifdef HAVE_SVE_CPU_DEFINITION
|
||||
, void *SVE
|
||||
, void *SVE256
|
||||
#endif
|
||||
);
|
||||
@ -157,7 +160,8 @@ struct TORCH_API DispatchStubImpl {
|
||||
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
||||
, void *ZVECTOR
|
||||
#endif
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
#ifdef HAVE_SVE_CPU_DEFINITION
|
||||
, void *SVE
|
||||
, void *SVE256
|
||||
#endif
|
||||
);
|
||||
@ -181,7 +185,8 @@ struct TORCH_API DispatchStubImpl {
|
||||
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
||||
, void *ZVECTOR
|
||||
#endif
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
#ifdef HAVE_SVE_CPU_DEFINITION
|
||||
, void *SVE
|
||||
, void *SVE256
|
||||
#endif
|
||||
);
|
||||
@ -238,7 +243,8 @@ private:
|
||||
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
||||
, reinterpret_cast<void*>(ZVECTOR)
|
||||
#endif
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
#ifdef HAVE_SVE_CPU_DEFINITION
|
||||
, reinterpret_cast<void*>(SVE)
|
||||
, reinterpret_cast<void*>(SVE256)
|
||||
#endif
|
||||
)
|
||||
@ -299,7 +305,8 @@ public:
|
||||
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
||||
, reinterpret_cast<void*>(ZVECTOR)
|
||||
#endif
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
#ifdef HAVE_SVE_CPU_DEFINITION
|
||||
, reinterpret_cast<void*>(SVE)
|
||||
, reinterpret_cast<void*>(SVE256)
|
||||
#endif
|
||||
);
|
||||
@ -322,7 +329,8 @@ public:
|
||||
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
||||
static TORCH_API FnPtr ZVECTOR;
|
||||
#endif
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
#ifdef HAVE_SVE_CPU_DEFINITION
|
||||
static TORCH_API FnPtr SVE;
|
||||
static TORCH_API FnPtr SVE256;
|
||||
#endif
|
||||
private:
|
||||
@ -426,9 +434,11 @@ struct RegisterPRIVATEUSE1Dispatch {
|
||||
#define REGISTER_ZVECTOR_DISPATCH(name, fn)
|
||||
#endif
|
||||
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
#ifdef HAVE_SVE_CPU_DEFINITION
|
||||
#define REGISTER_SVE_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, SVE, fn)
|
||||
#define REGISTER_SVE256_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, SVE256, fn)
|
||||
#else
|
||||
#define REGISTER_SVE_DISPATCH(name, fn)
|
||||
#define REGISTER_SVE256_DISPATCH(name, fn)
|
||||
#endif
|
||||
|
||||
@ -440,6 +450,7 @@ struct RegisterPRIVATEUSE1Dispatch {
|
||||
REGISTER_AVX2_DISPATCH(name, fn) \
|
||||
REGISTER_VSX_DISPATCH(name, fn) \
|
||||
REGISTER_ZVECTOR_DISPATCH(name, fn) \
|
||||
REGISTER_SVE_DISPATCH(name, fn) \
|
||||
REGISTER_SVE256_DISPATCH(name, fn)
|
||||
|
||||
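For kernel authors, the practical effect of the macro changes above is that a stub registered through the usual DECLARE/DEFINE/REGISTER pattern now also gets an SVE slot. A hedged sketch of a registration site; the stub and kernel names are made up, and the macros are used as they appear in DispatchStub.h and the lines above:

#include <ATen/native/DispatchStub.h>
#include <cstdint>

namespace at::native {

// Hypothetical stub, for illustration only.
using example_fn = void (*)(float* out, const float* in, int64_t n);
DECLARE_DISPATCH(example_fn, example_stub);
DEFINE_DISPATCH(example_stub);

static void example_kernel(float* out, const float* in, int64_t n) {
  for (int64_t i = 0; i < n; ++i) out[i] = in[i];  // placeholder body
}

// One registration per capability the build defines; with this patch the
// expansion also covers the generic SVE tier alongside SVE256.
REGISTER_ALL_CPU_DISPATCH(example_stub, &example_kernel)

} // namespace at::native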
#define REGISTER_NO_CPU_DISPATCH(name) \
|
||||
@ -488,6 +499,7 @@ struct RegisterPRIVATEUSE1Dispatch {
|
||||
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
|
||||
#endif
|
||||
#define ALSO_REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
|
||||
#define ALSO_REGISTER_SVE_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
|
||||
#define ALSO_REGISTER_SVE256_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
|
||||
#endif
|
||||
} // namespace at::native
|
||||
|
||||
@ -466,6 +466,7 @@ REGISTER_AVX2_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cp
|
||||
REGISTER_AVX512_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel)
|
||||
REGISTER_VSX_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel)
|
||||
REGISTER_SVE_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel)
|
||||
REGISTER_SVE256_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel)
|
||||
|
||||
// offsets dispatches
|
||||
@ -477,6 +478,7 @@ REGISTER_AVX2_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cp
|
||||
REGISTER_AVX512_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel)
|
||||
REGISTER_VSX_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel)
|
||||
REGISTER_SVE_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel)
|
||||
REGISTER_SVE256_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel)
|
||||
|
||||
// Currently some computation is being duplicated across forward and backward.
|
||||
@ -548,6 +550,9 @@ REGISTER_VSX_DISPATCH(
|
||||
REGISTER_ZVECTOR_DISPATCH(
|
||||
_segment_reduce_lengths_backward_stub,
|
||||
&_segment_reduce_cpu_lengths_backward_kernel)
|
||||
REGISTER_SVE_DISPATCH(
|
||||
_segment_reduce_lengths_backward_stub,
|
||||
&_segment_reduce_cpu_lengths_backward_kernel)
|
||||
REGISTER_SVE256_DISPATCH(
|
||||
_segment_reduce_lengths_backward_stub,
|
||||
&_segment_reduce_cpu_lengths_backward_kernel)
|
||||
@ -568,6 +573,9 @@ REGISTER_VSX_DISPATCH(
|
||||
REGISTER_ZVECTOR_DISPATCH(
|
||||
_segment_reduce_offsets_backward_stub,
|
||||
&_segment_reduce_cpu_offsets_backward_kernel)
|
||||
REGISTER_SVE_DISPATCH(
|
||||
_segment_reduce_offsets_backward_stub,
|
||||
&_segment_reduce_cpu_offsets_backward_kernel)
|
||||
REGISTER_SVE256_DISPATCH(
|
||||
_segment_reduce_offsets_backward_stub,
|
||||
&_segment_reduce_cpu_offsets_backward_kernel)
|
||||
|
||||
@ -274,7 +274,7 @@ inline Vectorized<scalar_t> div_floor_floating_vec(
|
||||
return floordiv;
|
||||
}
|
||||
|
||||
#if defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16)
|
||||
#if (defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE)) && defined(__ARM_FEATURE_BF16)
|
||||
|
||||
// Since sve lacks sufficient bf16 intrinsics, do the calculations in f32 to
|
||||
// avoid rounding errors. This should not cause performance issues as
|
||||
|
||||
@ -11,6 +11,7 @@
|
||||
#include <ATen/native/transformers/attention.h>
|
||||
#include <ATen/native/transformers/sdp_utils_cpp.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <variant>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
@ -44,13 +45,23 @@ inline void _scale_attn_mask_fusion_kernel(
|
||||
#endif
|
||||
const auto vec_size1 = at::vec::Vectorized<T1>::size();
|
||||
const auto vec_size2 = at::vec::Vectorized<T2>::size();
|
||||
constexpr int64_t T1_n =
|
||||
const int64_t T1_n =
|
||||
(vec_size2 == vec_size1 * 2 && is_reduced_floating_point_v<T2>) ? 2 : 1;
|
||||
constexpr int64_t T2_n = 1;
|
||||
auto vec_scale = at::vec::VectorizedN<T1, T1_n>(val);
|
||||
std::variant<at::vec::VectorizedN<T1, 2>, at::vec::VectorizedN<T1, 1>> vec_scale;
|
||||
if (T1_n == 2)
|
||||
vec_scale = at::vec::VectorizedN<T1, 2>(val);
|
||||
else if (T1_n == 1)
|
||||
vec_scale = at::vec::VectorizedN<T1, 1>(val);
|
||||
|
||||
int64_t i = 0;
|
||||
for (; i < size - (size % vec_size2); i += vec_size2) {
|
||||
auto a_n = at::vec::VectorizedN<T1, T1_n>::loadu(a + i);
|
||||
std::variant<at::vec::VectorizedN<T1, 2>, at::vec::VectorizedN<T1, 1>> a_n;
|
||||
if (T1_n == 2)
|
||||
a_n = at::vec::VectorizedN<T1, 2>::loadu(a + i);
|
||||
else if (T1_n == 1)
|
||||
a_n = at::vec::VectorizedN<T1, 1>::loadu(a + i);
|
||||
|
||||
at::vec::VectorizedN<T2, T2_n> b_n;
|
||||
#if __GNUC__ == 11 && defined(__ARM_FEATURE_SVE)
|
||||
if (is_b_stride_zero) {
|
||||
@ -61,9 +72,16 @@ inline void _scale_attn_mask_fusion_kernel(
|
||||
} else {
|
||||
b_n = at::vec::VectorizedN<T2, T2_n>::loadu(b + i);
|
||||
}
|
||||
auto b_n_convert = at::vec::convert<T1, T1_n, T2, T2_n, true>(b_n);
|
||||
auto res = a_n * vec_scale + b_n_convert;
|
||||
res.store(out + i);
|
||||
std::variant<at::vec::VectorizedN<T1, 2>, at::vec::VectorizedN<T1, 1>> b_n_convert;
|
||||
if (T1_n == 2) {
|
||||
auto b_n_convert = at::vec::convert<T1, 2, T2, T2_n, true>(b_n);
|
||||
auto res = std::get<at::vec::VectorizedN<T1, 2>>(a_n) * std::get<at::vec::VectorizedN<T1, 2>>(vec_scale) + b_n_convert;
|
||||
res.store(out + i);
|
||||
} else if(T1_n == 1) {
|
||||
auto b_n_convert = at::vec::convert<T1, 1, T2, T2_n, true>(b_n);
|
||||
auto res = std::get<at::vec::VectorizedN<T1, 1>>(a_n) * std::get<at::vec::VectorizedN<T1, 1>>(vec_scale) + b_n_convert;
|
||||
res.store(out + i);
|
||||
}
|
||||
}
|
||||
for (; i < size; i++) {
|
||||
auto tmp0 = a[i];
|
||||
|
||||
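The rewrite above works around T1_n no longer being usable as a constant template argument: the two possible VectorizedN widths are wrapped in a std::variant, built once before the loop, and unpacked with std::get on the branch that matches the runtime width. A compact stand-alone illustration of that pattern (the vector types here are simplified stand-ins, not at::vec::VectorizedN):

#include <array>
#include <variant>

template <int N>
struct FakeVecN {                       // stand-in for VectorizedN<float, N>
  std::array<float, 4 * N> lanes{};
  explicit FakeVecN(float v = 0.f) { lanes.fill(v); }
};

using EitherWidth = std::variant<FakeVecN<2>, FakeVecN<1>>;

// Build the broadcast "scale" vector at the width chosen at runtime.
EitherWidth make_scale(int t1_n, float val) {
  if (t1_n == 2) return FakeVecN<2>(val);
  return FakeVecN<1>(val);
}

// Use it: branch once on the runtime width and std::get<> the active member.
float sum_scaled(const EitherWidth& scale, int t1_n, float x) {
  if (t1_n == 2) return x * std::get<FakeVecN<2>>(scale).lanes[0];
  return x * std::get<FakeVecN<1>>(scale).lanes[0];
}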
@ -694,7 +694,7 @@ struct ApplyGridSample<scalar_t, 2, GridSamplerInterpolation::Bilinear,
|
||||
gx = gx * gx_mult;
|
||||
gy = gy * gy_mult;
|
||||
|
||||
constexpr int64_t step = Vec::size();
|
||||
const int64_t step = Vec::size();
|
||||
auto interleaved_gGrid = interleave2(gx, gy);
|
||||
auto gGrid_ptr = gGrid_slice.data() + offset * 2;
|
||||
std::get<0>(interleaved_gGrid).store(gGrid_ptr,
|
||||
@ -1010,7 +1010,7 @@ struct ApplyGridSample<scalar_t, 2, GridSamplerInterpolation::Bicubic,
|
||||
gx = gx * gx_mult;
|
||||
gy = gy * gy_mult;
|
||||
|
||||
constexpr int64_t step = Vec::size();
|
||||
const int64_t step = Vec::size();
|
||||
auto interleaved_gGrid = interleave2(gx, gy);
|
||||
auto gGrid_ptr = gGrid_slice.data() + offset * 2;
|
||||
std::get<0>(interleaved_gGrid).store(gGrid_ptr,
|
||||
@ -1041,7 +1041,7 @@ static inline void grid_sample_2d_grid_slice_iterator(
|
||||
|
||||
using Vec = Vectorized<scalar_t>;
|
||||
using iVec = Vectorized<int_same_size_t<scalar_t>>;
|
||||
constexpr int64_t step = Vec::size();
|
||||
const int64_t step = Vec::size();
|
||||
|
||||
// Loop over each output pixel in grid.
|
||||
// We consider the following three cases (after slicing out the batch
|
||||
|
||||
@ -19,7 +19,7 @@ Vectorized<scalar_t> is_lerp_weight_small(Vectorized<scalar_t> weight) {
|
||||
// is_lerp_weight_small doesn't work for complex because z.abs() returns a
|
||||
// complex vector which can't be compared. Either implement it with z.abs_2_(),
|
||||
// or fallback to the scalar function.
|
||||
#if !(defined(CPU_CAPABILITY_DEFAULT) || defined(_MSC_VER) || defined(CPU_CAPABILITY_SVE))
|
||||
#if !(defined(CPU_CAPABILITY_DEFAULT) || defined(_MSC_VER) || defined(CPU_CAPABILITY_SVE256) || defined(CPU_CAPABILITY_SVE))
|
||||
template <typename value_t>
|
||||
Vectorized<c10::complex<value_t>> is_lerp_weight_small(Vectorized<c10::complex<value_t>> weight) {
|
||||
using vec_reg_t = decltype(weight.abs_2_());
|
||||
|
||||
@ -210,13 +210,22 @@ vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, ve
|
||||
|
||||
Vec opt_scalar = Vec(S > 0 ? c10::load((scalar_t*)data[S]) : scalar_t(0));
|
||||
int64_t i = 0;
|
||||
for (; i <= n - 2 * Vec::size(); i += 2 * Vec::size()) {
|
||||
int size = Vec::size();
|
||||
#if !defined(CPU_CAPABILITY_SVE) && !defined(CPU_CAPABILITY_SVE256)
|
||||
// Loop unrolling prevents compiler from optimizing the SVE classes
|
||||
for (; i <= n - 2 * size; i += 2 * size) {
|
||||
auto args1 = dereference_vec<traits>(&data[1], opt_scalar, S, i);
|
||||
auto args2 = dereference_vec<traits>(&data[1], opt_scalar, S, i + Vec::size());
|
||||
auto args2 = dereference_vec<traits>(&data[1], opt_scalar, S, i + size);
|
||||
auto out1 = std::apply(vop, std::move(args1));
|
||||
auto out2 = std::apply(vop, std::move(args2));
|
||||
out1.store(data[0] + i * sizeof(scalar_t));
|
||||
out2.store(data[0] + (i + Vec::size()) * sizeof(scalar_t));
|
||||
out2.store(data[0] + (i + size) * sizeof(scalar_t));
|
||||
}
|
||||
#endif
|
||||
for (; i <= n - size; i += size) {
|
||||
auto args1 = dereference_vec<traits>(&data[1], opt_scalar, S, i);
|
||||
auto out1 = c10::guts::apply(vop, std::move(args1));
|
||||
out1.store(data[0] + i * sizeof(scalar_t));
|
||||
}
|
||||
if (i < n) {
|
||||
int64_t strides[ntensors];
|
||||
|
||||
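As the new comment notes, the manually 2x-unrolled vector loop is now compiled out for SVE builds so the compiler is free to optimize the sizeless SVE classes itself; only the single-vector loop and the scalar remainder stay unconditional. Schematically, with plain scalar bodies standing in for the vector ops and a runtime vec_size mirroring the non-constexpr Vec::size():

#include <cstdint>

void scale_inplace(float* data, int64_t n, float factor, int64_t vec_size) {
  int64_t i = 0;
#if !defined(CPU_CAPABILITY_SVE) && !defined(CPU_CAPABILITY_SVE256)
  // 2x-unrolled main loop: two "vector" bodies per iteration.
  for (; i <= n - 2 * vec_size; i += 2 * vec_size) {
    for (int64_t k = 0; k < 2 * vec_size; ++k) data[i + k] *= factor;
  }
#endif
  // Single-vector loop (always present).
  for (; i <= n - vec_size; i += vec_size) {
    for (int64_t k = 0; k < vec_size; ++k) data[i + k] *= factor;
  }
  // Scalar tail.
  for (; i < n; ++i) data[i] *= factor;
}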
@ -80,7 +80,7 @@ inline void UNARY_OUTER_LOOP(char* data[2], const int64_t strides[2], int64_t n,
|
||||
template <typename func_t, typename vec_func_t>
|
||||
inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, vec_func_t vop) {
|
||||
VEC_LOOP_HEADER(func_t, data)
|
||||
constexpr int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t);
|
||||
const int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t);
|
||||
int64_t count = n / (4 * Vec::size());
|
||||
if (count > 0) {
|
||||
vectorized_reduction(data, count, vector_stride, op, vop, /*reduce=*/true);
|
||||
@ -96,7 +96,7 @@ inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_
|
||||
VEC_LOOP_HEADER(func_t, data)
|
||||
|
||||
// reduce down each column of 4 * Vec::size() elements.
|
||||
constexpr int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t);
|
||||
const int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t);
|
||||
int64_t outer_stride[2] = { vector_stride, vector_stride };
|
||||
UNARY_OUTER_LOOP(data, outer_stride, size1 / (4 * Vec::size()), [&] {
|
||||
vectorized_reduction(data, size0, inner_stride, op, vop, /*reduce=*/false);
|
||||
|
||||
@ -154,8 +154,8 @@ inline void map_acc(
|
||||
using Vec = vec::Vectorized<scalar_t>;
|
||||
using aVec = vec::Vectorized<accumut>;
|
||||
int64_t d = 0;
|
||||
constexpr int64_t kVecSize = Vec::size();
|
||||
constexpr int64_t kaVecSize = aVec::size();
|
||||
const int64_t kVecSize = Vec::size();
|
||||
const int64_t kaVecSize = aVec::size();
|
||||
for (d = 0; d < size - (size % kVecSize); d += kVecSize) {
|
||||
Vec data2_vec = Vec::loadu(input_data2 + d);
|
||||
auto [data2_avec0, data2_avec1] = convert_to_float<scalar_t>(data2_vec);
|
||||
|
||||
@ -22,8 +22,8 @@ inline namespace CPU_CAPABILITY {
|
||||
|
||||
constexpr auto kF32RegisterPairsPerIteration = 4;
|
||||
constexpr auto kF32RegistersPerIteration = kF32RegisterPairsPerIteration * 2;
|
||||
constexpr auto kF32ElementsPerRegister = vec::Vectorized<float>::size();
|
||||
constexpr auto kF32ElementsPerIteration = kF32RegistersPerIteration * kF32ElementsPerRegister;
|
||||
const auto kF32ElementsPerRegister = vec::Vectorized<float>::size();
|
||||
const auto kF32ElementsPerIteration = kF32RegistersPerIteration * kF32ElementsPerRegister;
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
@ -150,16 +150,16 @@ float reduce(vec::VectorizedN<float, kF32RegistersPerIteration>& x) {
|
||||
// BFDOT. Deferring that for now to get the NEON/ASIMD BFDOT path
|
||||
// working.
|
||||
#if __ARM_FEATURE_BF16_VECTOR_ARITHMETIC
|
||||
#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15
|
||||
#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && !defined(CPU_CAPABILITY_SVE256) && defined(__clang__) && __clang_major__ > 15
|
||||
// https://godbolt.org/z/z8P4Yncra
|
||||
#define COMPILER_SUPPORTS_BF16_TARGET 1
|
||||
#elif defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && !defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 10
|
||||
#elif defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE) && !defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 10
|
||||
// https://gcc.gnu.org/gcc-10/changes.html
|
||||
// https://godbolt.org/z/cdGG7vn8o
|
||||
#define COMPILER_SUPPORTS_BF16_TARGET 1
|
||||
#else // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15
|
||||
#else // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15
|
||||
#define COMPILER_SUPPORTS_BF16_TARGET 0
|
||||
#endif // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15
|
||||
#endif // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15
|
||||
#else // __ARM_FEATURE_BF16_VECTOR_ARITHMETIC
|
||||
#define COMPILER_SUPPORTS_BF16_TARGET 0
|
||||
#endif // __ARM_FEATURE_BF16_VECTOR_ARITHMETIC
|
||||
@ -212,7 +212,7 @@ std::pair<vec::Vectorized<float>, vec::Vectorized<float>> fmadd(
|
||||
const vec::Vectorized<c10::Half>& b,
|
||||
const vec::Vectorized<float>& acc_low,
|
||||
const vec::Vectorized<float>& acc_high) {
|
||||
#if defined(__ARM_FEATURE_FP16_FML) && !defined(CPU_CAPABILITY_SVE)
|
||||
#if defined(__ARM_FEATURE_FP16_FML) && !defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE)
|
||||
return std::make_pair(vfmlalq_low_f16(acc_low, a, b), vfmlalq_high_f16(acc_high, a, b));
|
||||
#else
|
||||
const auto [a_float_low, a_float_high] = convert_half_float(a);
|
||||
|
||||
@ -28,8 +28,8 @@ inline void _update(at::opmath_type<scalar_t>* out_ptr, int64_t e, int64_t c, co
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
using Vec = vec::Vectorized<scalar_t>;
|
||||
using aVec = VecType<scalar_t>;
|
||||
constexpr int64_t kVecSize = Vec::size();
|
||||
constexpr int64_t kVLEN = kVecSize * 4;
|
||||
const int64_t kVecSize = Vec::size();
|
||||
const int64_t kVLEN = kVecSize * 4;
|
||||
|
||||
int64_t k = 0;
|
||||
aVec val_vec = aVec((opmath_t)val);
|
||||
|
||||
@ -21,11 +21,11 @@ Vectorized<acc_t> load_reduce_vec(const scalar_t* data, F reduce, acc_t ident) {
|
||||
using vacc_t = Vectorized<acc_t>;
|
||||
static_assert(vacc_t::size() <= vec_t::size());
|
||||
const auto val = vec_t::loadu(data);
|
||||
alignas(64) std::array<scalar_t, vec_t::size()> values;
|
||||
val.store(values.data());
|
||||
alignas(64) scalar_t values[vec_t::size()];
|
||||
val.store(values);
|
||||
|
||||
constexpr int vstride = vec_t::size() / vacc_t::size();
|
||||
alignas(64) std::array<acc_t, vacc_t::size()> acc;
|
||||
alignas(64) acc_t acc[vacc_t::size()];
|
||||
acc.fill(ident);
|
||||
for (const auto k : c10::irange(vstride)) {
|
||||
for (const auto i : c10::irange(vacc_t::size())) {
|
||||
@ -33,7 +33,7 @@ Vectorized<acc_t> load_reduce_vec(const scalar_t* data, F reduce, acc_t ident) {
|
||||
}
|
||||
}
|
||||
|
||||
return vacc_t::loadu(acc.data());
|
||||
return vacc_t::loadu(acc);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
@ -138,7 +138,7 @@ struct OuterSumCastLoadPolicy <vec_t, vacc_t,
|
||||
using scalar_t = vechold_type<vec_t>;
|
||||
using acc_t = vechold_type<vacc_t>;
|
||||
|
||||
static constexpr int64_t memsize() {
|
||||
static int64_t memsize() {
|
||||
return sizeof(scalar_t) * vacc_t::size();
|
||||
}
|
||||
|
||||
@ -161,7 +161,7 @@ template <typename vec_t, typename vacc_t>
|
||||
struct OuterSumCastLoadPolicy <vec_t, vacc_t, std::enable_if_t<is_reduced_floating_point_v<vechold_type<vec_t>>>> {
|
||||
using scalar_t = vechold_type<vec_t>;
|
||||
|
||||
static constexpr int64_t memsize() {
|
||||
static int64_t memsize() {
|
||||
return sizeof(scalar_t) * vacc_t::size();
|
||||
}
|
||||
|
||||
@ -198,7 +198,7 @@ template <typename scalar_t>
|
||||
struct NanSumLoadPolicy<Vectorized<scalar_t>> {
|
||||
using vec_t = Vectorized<scalar_t>;
|
||||
|
||||
static constexpr int64_t memsize() {
|
||||
static int64_t memsize() {
|
||||
return LoadPolicy<vec_t>::memsize();
|
||||
}
|
||||
|
||||
@ -267,7 +267,7 @@ struct InnerNanSumCastLoadPolicy <vec_t, vacc_t, std::enable_if_t<is_reduced_flo
|
||||
|
||||
template <typename vec_t, typename vacc_t>
|
||||
struct OuterNanSumCastLoadPolicy {
|
||||
static constexpr int64_t memsize() {
|
||||
static int64_t memsize() {
|
||||
return OuterSumCastLoadPolicy<vec_t, vacc_t>::memsize();
|
||||
}
|
||||
|
||||
@ -300,13 +300,23 @@ static void store(char * C10_RESTRICT data, int64_t stride, int64_t index,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename StorePolicy, typename scalar_t>
|
||||
static void store(char * C10_RESTRICT data, int64_t stride, int64_t index,
|
||||
const scalar_t *values, size_t numel) {
|
||||
auto *base_ptr = data + stride * index;
|
||||
for (const auto k : c10::irange(numel)) {
|
||||
auto val = values[k];
|
||||
StorePolicy::store(base_ptr, stride, k, val);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename StorePolicy, typename scalar_t>
|
||||
static void store(char * C10_RESTRICT data, int64_t stride, int64_t index,
|
||||
const Vectorized<scalar_t> &values) {
|
||||
using vec_t = Vectorized<scalar_t>;
|
||||
alignas(64) std::array<scalar_t, vec_t::size()> array_values{};
|
||||
values.store(array_values.data());
|
||||
store<StorePolicy>(data, stride, index, array_values);
|
||||
alignas(64) scalar_t array_values[vec_t::size()] = {};
|
||||
values.store(array_values);
|
||||
store<StorePolicy, scalar_t>(data, stride, index, array_values, vec_t::size());
|
||||
}
|
||||
|
||||
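The extra overload above lets the Vectorized path spill to a plain aligned buffer and then reuse the same per-element store loop, passing the element count explicitly now that it is no longer encoded in a std::array type. A generic sketch of that spill-then-store shape (placeholder template parameters, not the real load/store policies):

#include <cstddef>
#include <cstdint>

// Per-element store loop: the element count now travels as an argument.
template <typename StorePolicy, typename scalar_t>
void store_elems(char* data, int64_t stride, int64_t index,
                 const scalar_t* values, std::size_t numel) {
  char* base_ptr = data + stride * index;
  for (std::size_t k = 0; k < numel; ++k) {
    StorePolicy::store(base_ptr, stride, static_cast<int64_t>(k), values[k]);
  }
}

// Vector overload: spill the register to an aligned buffer, then reuse the
// element-wise loop above. Vec stands in for at::vec::Vectorized<T>.
template <typename StorePolicy, typename Vec>
void store_vec(char* data, int64_t stride, int64_t index, const Vec& v) {
  using scalar_t = typename Vec::value_type;  // Vectorized<T> exposes value_type
  alignas(64) scalar_t buf[Vec::size()];      // assumes size() is usable as a bound here
  v.store(buf);
  store_elems<StorePolicy>(data, stride, index, buf, Vec::size());
}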
/** Simultaneously sum over n rows at once
|
||||
@ -436,9 +446,9 @@ void vectorized_inner_sum(
|
||||
char * C10_RESTRICT data[2], int64_t outer_stride, int64_t out_stride,
|
||||
int64_t size0, int64_t size1) {
|
||||
using vacc_t = Vectorized<acc_t>;
|
||||
constexpr int64_t vec_stride = VecLoadPolicy::memsize();
|
||||
constexpr int64_t scalar_stride = ScalarLoadPolicy::memsize();
|
||||
constexpr int64_t vec_numel = vec_stride / scalar_stride;
|
||||
const int64_t vec_stride = VecLoadPolicy::memsize();
|
||||
const int64_t scalar_stride = ScalarLoadPolicy::memsize();
|
||||
const int64_t vec_numel = vec_stride / scalar_stride;
|
||||
const int64_t vec_size = size0 / vec_numel;
|
||||
|
||||
// Input is contiguous over the first (reduced) dimension
|
||||
@ -451,9 +461,9 @@ void vectorized_inner_sum(
|
||||
final_acc += ScalarLoadPolicy::load(row_in, scalar_stride, k);
|
||||
}
|
||||
|
||||
alignas(64) std::array<acc_t, vacc_t::size()> partials{};
|
||||
vec_acc.store(partials.data());
|
||||
for (const auto k : c10::irange(partials.size())) {
|
||||
alignas(64) acc_t partials[vacc_t::size()] = {};
|
||||
vec_acc.store(partials);
|
||||
for (const auto k : c10::irange(vacc_t::size())) {
|
||||
final_acc += partials[k];
|
||||
}
|
||||
store<StorePolicy>(data[0], out_stride, j, final_acc);
|
||||
@ -479,7 +489,7 @@ void vectorized_outer_sum(
|
||||
int64_t size0, int64_t size1) {
|
||||
using vacc_t = Vectorized<acc_t>;
|
||||
constexpr int64_t scalar_stride = ScalarLoadPolicy::memsize();
|
||||
constexpr int64_t vec_stride = VecLoadPolicy::memsize();
|
||||
const int64_t vec_stride = VecLoadPolicy::memsize();
|
||||
constexpr int64_t nrows = 4;
|
||||
|
||||
// Input is contiguous over the second (non-reduced) dimension
|
||||
|
||||
@ -93,7 +93,7 @@ ColumnwiseMoments(
int64_t C,
int64_t D) {
using Vec = vec::Vectorized<T>;
constexpr int64_t K = Vec::size();
const int64_t K = Vec::size();
const int64_t inner_size = D / K * K;
Vec acc0_vec{0}, acc1_vec{0};
for (const auto m : c10::irange(HxW)) {
@ -668,20 +668,20 @@ void GroupNormInputBackward(
const opmath_t s = opmath_t(1) / static_cast<opmath_t>(D * HxW);
const bool gamma_null = (gamma == nullptr);
at::parallel_for(0, N * G, 1, [=](int64_t start, int64_t end) {
constexpr int64_t K = vec::Vectorized<PT>::size();
const int64_t K = vec::Vectorized<PT>::size();
const int64_t d = D / K * K;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
std::array<opmath_t, at::vec::Vectorized<opmath_t>::size()> ds_arr;
opmath_t ds_arr[at::vec::Vectorized<opmath_t>::size()];
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
std::array<opmath_t, at::vec::Vectorized<opmath_t>::size()> db_arr;
opmath_t db_arr[at::vec::Vectorized<opmath_t>::size()];
for (const auto i : c10::irange(start, end)) {
const int64_t g = i % G;
const opmath_t* ds_ptr = ds + i * D;
const opmath_t* db_ptr = db + i * D;
const PT* gamma_ptr = gamma_null ? nullptr : (gamma + g * D);
CalcDsDb(ds_ptr, db_ptr, gamma_ptr, d, K, ds_arr.data(), db_arr.data());
opmath_t ds_val = std::accumulate(ds_arr.cbegin(), ds_arr.cend(), opmath_t(0));
opmath_t db_val = std::accumulate(db_arr.cbegin(), db_arr.cend(), opmath_t(0));
CalcDsDb(ds_ptr, db_ptr, gamma_ptr, d, K, ds_arr, db_arr);
opmath_t ds_val = std::accumulate(&ds_arr[0], &ds_arr[at::vec::Vectorized<opmath_t>::size()], opmath_t(0));
opmath_t db_val = std::accumulate(&db_arr[0], &db_arr[at::vec::Vectorized<opmath_t>::size()], opmath_t(0));
for (const auto j : c10::irange(d, D)) {
const opmath_t gamma_v = gamma_null ? opmath_t(1) : opmath_t(gamma[g * D + j]);
ds_val += ds_ptr[j] * gamma_v;
@ -718,7 +718,7 @@ GammaBackward(
PT* dgamma) {
const int64_t G = group;
const int64_t D = C / G;
constexpr int64_t K = at::vec::Vectorized<PT>::size();
const int64_t K = at::vec::Vectorized<PT>::size();
using Vec = at::vec::Vectorized<PT>;
const int64_t inner_size = D / K * K;
for (const auto g : c10::irange(G)) {
@ -818,7 +818,7 @@ template <typename PT, typename opmath_t>
std::enable_if_t<std::is_same_v<PT, opmath_t>, void>
BetaBackward(int64_t N, int64_t C, const opmath_t* db, PT* dbeta) {
using Vec = at::vec::Vectorized<PT>;
constexpr int64_t K = Vec::size();
const int64_t K = Vec::size();
Vec acc_vec{0}, zero{0};
const int64_t inner_size = C / K * K;
int64_t i = 0;
@ -943,7 +943,7 @@ DsDbRowwiseMomentsChannelsLast(
opmath_t* db_ptr,
int64_t C) {
using Vec = vec::Vectorized<T>;
constexpr int64_t K = vec::Vectorized<T>::size();
const int64_t K = vec::Vectorized<T>::size();
const int64_t inner_size = C / K * K;
int64_t d = 0;
for (; d < inner_size; d += K) {
@ -1247,7 +1247,7 @@ inline typename std::
int64_t D) {
using Vec = vec::Vectorized<T>;
const bool gamma_null = (gamma_ptr == nullptr);
constexpr int64_t K = Vec::size();
const int64_t K = Vec::size();
const int64_t inner_size = D / K * K;
int64_t d = 0;
opmath_t ds_gamma{0}, db_gamma{0};

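The `std::array<opmath_t, Vectorized<opmath_t>::size()>` scratch buffers in `GroupNormInputBackward` cannot survive this change, since `std::array` needs a constant extent; the hunk switches them to plain C arrays and sums them with pointer-range `std::accumulate`. A reduced sketch of that buffer-plus-accumulate shape (the function name and fixed bound are illustrative):

```cpp
#include <ATen/cpu/vec/vec.h>
#include <cstdint>
#include <numeric>

template <typename opmath_t>
opmath_t sum_first_lanes(const opmath_t* src) {
  const int64_t K = at::vec::Vectorized<opmath_t>::size();  // runtime lane count
  constexpr int kMaxLanes = 64;           // assumed upper bound on K
  opmath_t scratch[kMaxLanes] = {};       // was std::array<opmath_t, K>
  for (int64_t k = 0; k < K; ++k) {
    scratch[k] = src[k];
  }
  // cbegin()/cend() are gone with std::array, so accumulate over a pointer range.
  return std::accumulate(&scratch[0], &scratch[K], opmath_t(0));
}
```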
@ -625,7 +625,7 @@ void weight_to_int4pack_kernel(
int K = weight.size(1);

// 64 for avx512 and 32 for avx2/non-vectorized
constexpr int BLOCK_N = vec::Vectorized<float>::size() * 4;
const int BLOCK_N = vec::Vectorized<float>::size() * 4;
const int NB = (N + BLOCK_N - 1) / BLOCK_N;

// parallel on NB blocks
@ -713,7 +713,7 @@ void int4pack_mm_kernel_(

constexpr int BLOCK_M = 4;
// 64 for avx512 and 32 for avx2/non-vectorized
constexpr int BLOCK_N = vec::Vectorized<float>::size() * 4;
const int BLOCK_N = vec::Vectorized<float>::size() * 4;
// 32, 64, 128, 256
const int BLOCK_K = qGroupSize;


@ -109,8 +109,8 @@ template <typename T, int64_t kMaxDepth>
std::pair<opmath_t<T>, opmath_t<T>> RowwiseMomentsImpl(const T* X, int64_t N, int64_t ddof = 0) {
using math_t = opmath_t<T>;

constexpr int64_t kVecSize = vec::Vectorized<T>::size();
constexpr int64_t kAccVecSize = vec::Vectorized<math_t>::size();
const int64_t kVecSize = vec::Vectorized<T>::size();
const int64_t kAccVecSize = vec::Vectorized<math_t>::size();
const int64_t n = N / kVecSize;
const int64_t m = divup(n, kChunkSize);
const int64_t depth = utils::CeilLog2(m);
@ -155,10 +155,10 @@ std::pair<opmath_t<T>, opmath_t<T>> RowwiseMomentsImpl(const T* X, int64_t N, in
m0_stk[i], m1_stk[i], m2_stk[i], m0_stk[0], m1_stk[0], m2_stk[0]);
}

std::array<math_t, kAccVecSize> m1_arr{};
std::array<math_t, kAccVecSize> m2_arr{};
m1_stk[0].store(m1_arr.data());
m2_stk[0].store(m2_arr.data());
math_t m1_arr[kAccVecSize] = {};
math_t m2_arr[kAccVecSize] = {};
m1_stk[0].store(m1_arr);
m2_stk[0].store(m2_arr);

int64_t m0 = 0;
math_t m1 = 0;
@ -182,7 +182,7 @@ std::pair<opmath_t<T>, opmath_t<T>> RowwiseMomentsImpl(const T* X, int64_t N, in
template <typename T>
std::pair<opmath_t<T>, opmath_t<T>> RowwiseMoments(const T* X, int64_t N, int64_t ddof = 0) {
using Vec = vec::Vectorized<T>;
constexpr int64_t kVecSize = Vec::size();
const int64_t kVecSize = Vec::size();
const int64_t n = N / kVecSize;
const int64_t m = divup(n, kChunkSize);
const int64_t depth = utils::CeilLog2(m);

@ -165,6 +165,7 @@ REGISTER_AVX2_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_co
REGISTER_AVX512_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_ZVECTOR_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_VSX_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_SVE_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_SVE256_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)

// _out variants can be shared between PocketFFT and MKL

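The dispatch-registration hunks here and below all follow one pattern: every stub that already had per-ISA registrations gains a plain `REGISTER_SVE_DISPATCH` line next to the existing `REGISTER_SVE256_DISPATCH` one, so the new capability resolves to the same kernel. Schematically, for a hypothetical stub (`my_example_stub` and `my_example_cpu_kernel` are placeholder names):

```cpp
#include <ATen/native/DispatchStub.h>

// One registration per CPU capability; the SVE line is the new addition.
REGISTER_ARCH_DISPATCH(my_example_stub, DEFAULT, &my_example_cpu_kernel)
REGISTER_AVX2_DISPATCH(my_example_stub, &my_example_cpu_kernel)
REGISTER_AVX512_DISPATCH(my_example_stub, &my_example_cpu_kernel)
REGISTER_VSX_DISPATCH(my_example_stub, &my_example_cpu_kernel)
REGISTER_ZVECTOR_DISPATCH(my_example_stub, &my_example_cpu_kernel)
REGISTER_SVE_DISPATCH(my_example_stub, &my_example_cpu_kernel)
REGISTER_SVE256_DISPATCH(my_example_stub, &my_example_cpu_kernel)
```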
@ -142,7 +142,7 @@ Tensor qcat_nhwc_kernel(
continue;
}

constexpr auto VLEN = Vec::size();
const auto VLEN = Vec::size();
int64_t c = 0;

// Vectorized loop
@ -170,16 +170,16 @@ Tensor qcat_nhwc_kernel(
}

// Vectorized loop for channel between 8 and 32 (avx2)
constexpr auto kVLEN = Vectorized<float>::size();
const auto kVLEN = Vectorized<float>::size();
int64_t elem_size = curr_C - c;
if ((VLEN == 4 * kVLEN) && elem_size >= kVLEN) {
auto curr_scale_vec = Vectorized<float>(curr_scale);
auto curr_zero_pt_vec = Vectorized<float>((float)curr_zero_pt);
auto scale_neg_zp_premul = curr_scale_vec * curr_zero_pt_vec.neg();
int64_t vec_num = elem_size / kVLEN;
std::array<typename scalar_t::underlying, VLEN> buf_in{};
memcpy(buf_in.data(), iptr + c, vec_num * kVLEN);
auto inp_vec = Vec::loadu(buf_in.data());
typename scalar_t::underlying buf_in[VLEN] = {};
memcpy(buf_in, iptr + c, vec_num * kVLEN);
auto inp_vec = Vec::loadu(buf_in);
auto float_values = inp_vec.dequantize(
curr_scale_vec, curr_zero_pt_vec, scale_neg_zp_premul);
Vec::float_vec_return_type retvals;
@ -1487,7 +1487,7 @@ void _qmaxpool_2d_nhwc_kernel(
int64_t c = 0;

// Interleaved vector loop 4x
constexpr auto vec_width = Vectorized<scalar_t>::size();
const auto vec_width = Vectorized<scalar_t>::size();
for (; c + 4 * vec_width <= iC; c += 4 * vec_width) {
Vectorized<scalar_t> acc{
scalar_t(std::numeric_limits<scalar_t_underlying>::lowest())};
@ -1623,7 +1623,7 @@ void qmaxpool_3d_nthwc_kernel(
w_start += dW;

int64_t c = 0;
constexpr auto vec_width = Vectorized<scalar_t>::size();
const auto vec_width = Vectorized<scalar_t>::size();
// Vector loop
for (; c + vec_width <= iC; c += vec_width) {
Vectorized<scalar_t> acc{
@ -2449,7 +2449,7 @@ void q_batch_norm_kernel(
reinterpret_cast<scalar_t::underlying*>(input.data_ptr());
scalar_t::underlying* Y = reinterpret_cast<scalar_t::underlying*>(output.data_ptr());

constexpr int kVLen = Vectorized<float>::size();
const int kVLen = Vectorized<float>::size();
const int64_t outer_size = N * HxW;
using Vec = Vectorized<scalar_t>;
// Hoisted variables
@ -2975,7 +2975,7 @@ void quantized_normalize_kernel(
float y_scale = Y->q_scale();
float y_inv_scale = 1.0f / y_scale;

constexpr int kFloatVLen = fVec::size();
const int kFloatVLen = fVec::size();
int64_t kIntVLen = kFloatVLen * qVec::float_num_vecs();
int64_t kNumIntVecInLayer = N / kIntVLen;
int64_t kNonVecRemInLayer = N % kIntVLen;
@ -3263,7 +3263,7 @@ void quantized_groupnorm_nhwc_kernel(
float y_scale = Y->q_scale();
float y_inv_scale = 1.0f / y_scale;

constexpr int kFloatVLen = fVec::size();
const int kFloatVLen = fVec::size();
int64_t kIntVLen = kFloatVLen * qVec::float_num_vecs();
int64_t channels_per_group = C / G;
int64_t HxW = N / channels_per_group;

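In `qcat_nhwc_kernel` the ragged channel tail is handled by copying the remaining elements into a zero-initialized stack buffer and loading a whole vector from it; the hunk only changes that buffer from `std::array` to a C array so its extent no longer has to be a constant expression. A generic sketch of the load-through-buffer idiom (the `kMaxLanes` bound and function name are illustrative):

```cpp
#include <cstdint>
#include <cstring>

// Copy `count` valid elements into a zeroed buffer, then load a full vector.
// Lanes past `count` are zero, so downstream math can ignore them.
template <typename Vec, typename T>
Vec load_partial_vector(const T* src, int64_t count) {
  constexpr int kMaxLanes = 64;            // assumed upper bound on Vec::size()
  alignas(64) T buf[kMaxLanes] = {};       // was std::array<T, VLEN>
  std::memcpy(buf, src, count * sizeof(T));
  return Vec::loadu(buf);
}
```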
@ -27,6 +27,7 @@ REGISTER_AVX512_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
REGISTER_AVX2_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
REGISTER_VSX_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
REGISTER_ZVECTOR_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
REGISTER_SVE_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
REGISTER_SVE256_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)

} // namespace at::native

@ -161,6 +161,7 @@ REGISTER_AVX512_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_
REGISTER_AVX2_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
REGISTER_VSX_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
REGISTER_ZVECTOR_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
REGISTER_SVE_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
REGISTER_SVE256_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)

REGISTER_ARCH_DISPATCH(sparse_mask_intersection_out_stub, DEFAULT, &sparse_mask_intersection_out_cpu_kernel)
@ -168,6 +169,7 @@ REGISTER_AVX512_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_interse
REGISTER_AVX2_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
REGISTER_VSX_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
REGISTER_ZVECTOR_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
REGISTER_SVE_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
REGISTER_SVE256_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)

REGISTER_ARCH_DISPATCH(sparse_mask_projection_out_stub, DEFAULT, &sparse_mask_projection_out_cpu_kernel)
@ -175,5 +177,6 @@ REGISTER_AVX512_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projectio
REGISTER_AVX2_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
REGISTER_VSX_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
REGISTER_ZVECTOR_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
REGISTER_SVE_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
REGISTER_SVE256_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
}

@ -448,6 +448,7 @@ REGISTER_AVX2_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_AVX512_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_VSX_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_ZVECTOR_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_SVE_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_SVE256_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
REGISTER_HPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_meta)


@ -134,7 +134,7 @@ namespace {
TYPED_TEST(Memory, UnAlignedLoadStore) {
using vec = TypeParam;
using VT = ValueType<TypeParam>;
constexpr size_t b_size = vec::size() * sizeof(VT);
const size_t b_size = vec::size() * sizeof(VT);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
CACHE_ALIGN unsigned char ref_storage[128 * b_size];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
@ -164,7 +164,7 @@ namespace {
for (size_t offset = 0; offset < b_size; offset += 1) {
unsigned char* p1 = ref_storage + offset;
unsigned char* p2 = storage + offset;
for (; p1 + b_size <= std::end(ref_storage); p1 += b_size, p2 += b_size) {
for (; p1 + b_size <= &ref_storage[128 * b_size]; p1 += b_size, p2 += b_size) {
vec v = vec::loadu(p1);
v.store(p2);
}
@ -381,7 +381,7 @@ namespace {
TYPED_TEST(Hyperbolic, Tanh) {
using vec = TypeParam;
// NOTE: Because SVE uses ACL logic, the precision changes, hence the adjusted tolerance.
#if defined(CPU_CAPABILITY_SVE)
#if defined(CPU_CAPABILITY_SVE) || defined(CPU_CAPABILITY_SVE256)
using UVT = UvalueType<vec>;
UVT tolerance = getDefaultTolerance<UVT>();
test_unary<vec>(
@ -586,7 +586,7 @@ namespace {
}
}
}
#if defined(CPU_CAPABILITY_SVE) && defined(__ARM_FEATURE_BF16)
#if (defined(CPU_CAPABILITY_SVE256)) && defined(__ARM_FEATURE_BF16)
TEST(NanBfloat16, IsNan) {
for (unsigned int ii = 0; ii < 0xFFFF; ++ii) {
c10::BFloat16 val(ii, c10::BFloat16::from_bits());
@ -598,6 +598,19 @@ namespace {
}
}
}
#endif
#if (defined(CPU_CAPABILITY_SVE)) && defined(__ARM_FEATURE_BF16)
TEST(NanBfloat16, IsNan) {
for (unsigned int ii = 0; ii < 0xFFFF; ++ii) {
c10::BFloat16 val(ii, c10::BFloat16::from_bits());
bool expected = std::isnan(val);
CACHE_ALIGN c10::BFloat16 actual_vals[at::vec::SVE::Vectorized<c10::BFloat16>::size()];
at::vec::SVE::Vectorized<c10::BFloat16>(val).isnan().store(actual_vals);
for (int jj = 0; jj < at::vec::SVE::Vectorized<c10::BFloat16>::size(); ++jj) {
EXPECT_EQ(expected, c10::bit_cast<uint16_t>(actual_vals[jj]) != 0) << "bf16 isnan failure for bit pattern " << std::hex << ii << std::dec;
}
}
}
#endif
TYPED_TEST(LGamma, LGamma) {
using vec = TypeParam;
@ -653,7 +666,7 @@ namespace {
TYPED_TEST(Interleave, Interleave) {
using vec = TypeParam;
using VT = ValueType<TypeParam>;
constexpr auto N = vec::size() * 2LL;
const auto N = vec::size() * 2LL;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
CACHE_ALIGN VT vals[N];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
@ -663,7 +676,7 @@ namespace {
for (VT& v : vals) {
v = generator.get();
}
copy_interleave(vals, interleaved);
copy_interleave<VT>(vals, interleaved, N);
auto a = vec::loadu(vals);
auto b = vec::loadu(vals + vec::size());
auto cc = interleave2(a, b);
@ -673,7 +686,7 @@ namespace {
TYPED_TEST(Interleave, DeInterleave) {
using vec = TypeParam;
using VT = ValueType<TypeParam>;
constexpr auto N = vec::size() * 2LL;
const auto N = vec::size() * 2LL;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
CACHE_ALIGN VT vals[N];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
@ -683,7 +696,7 @@ namespace {
for (VT& v : vals) {
v = generator.get();
}
copy_interleave(vals, interleaved);
copy_interleave<VT>(vals, interleaved, N);
// test interleaved with vals this time
auto a = vec::loadu(interleaved);
auto b = vec::loadu(interleaved + vec::size());
@ -1017,78 +1030,70 @@ namespace {
RESOLVE_OVERLOAD(filter_fmadd));
}
#endif
template<typename vec, typename VT, int64_t mask>
typename std::enable_if_t<(mask < 0 || mask> 255), void>
template<typename vec, typename VT>
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
test_blend(VT expected_val[vec::size()], VT a[vec::size()], VT b[vec::size()])
{
void test_blend(VT * expected_val, VT * a, VT * b, int64_t mask) {
if (mask >= 0 && mask <= 255) {
// generate expected_val
int64_t m = mask;
for (int64_t i = 0; i < vec::size(); i++) {
expected_val[i] = (m & 0x01) ? b[i] : a[i];
m = m >> 1;
}
// test with blend
auto vec_a = vec::loadu(a);
auto vec_b = vec::loadu(b);
auto expected = vec::loadu(expected_val);
auto actual = vec::blend(vec_a, vec_b, mask);
auto mask_str = std::string("\nblend mask: ") + std::to_string(mask);
if (AssertVectorized<vec>(std::string(NAME_INFO(test_blend)) + mask_str, expected, actual).check()) return;
test_blend<vec, VT>(expected_val, a, b, mask - 1);
}
}
template<typename vec, typename VT, int64_t mask>
typename std::enable_if_t<(mask >= 0 && mask <= 255), void>
template<typename vec, typename VT>
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
test_blend(VT expected_val[vec::size()], VT a[vec::size()], VT b[vec::size()]) {
// generate expected_val
int64_t m = mask;
for (int64_t i = 0; i < vec::size(); i++) {
expected_val[i] = (m & 0x01) ? b[i] : a[i];
m = m >> 1;
}
// test with blend
auto vec_a = vec::loadu(a);
auto vec_b = vec::loadu(b);
auto expected = vec::loadu(expected_val);
auto actual = vec::template blend<mask>(vec_a, vec_b);
auto mask_str = std::string("\nblend mask: ") + std::to_string(mask);
if (AssertVectorized<vec>(std::string(NAME_INFO(test_blend)) + mask_str, expected, actual).check()) return;
test_blend<vec, VT, mask - 1>(expected_val, a, b);
bool test_blendv(VT * expected_val, VT * a, VT * b, VT * mask, int64_t idx, size_t N) {
if ((size_t) idx == N) {
using bit_rep = BitType<VT>;
// generate expected_val
for (int64_t i = 0; i < vec::size(); i++) {
bit_rep hex_mask = 0;
hex_mask=c10::bit_cast<bit_rep>(mask[i]);
expected_val[i] = (hex_mask & 0x01) ? b[i] : a[i];
}
// test with blendv
auto vec_a = vec::loadu(a);
auto vec_b = vec::loadu(b);
auto vec_m = vec::loadu(mask);
auto expected = vec::loadu(expected_val);
auto actual = vec::blendv(vec_a, vec_b, vec_m);
auto mask_str = std::string("\nblendv mask: ");
for (int64_t i = 0; i < vec::size(); i++) {
mask_str += std::to_string(mask[i]) + " ";
}
if (AssertVectorized<vec>(std::string(NAME_INFO(test_blendv)) + mask_str, expected, actual).check()) {
return false;
}
return true;
} else {
// shuffle mask and do blendv test
VT m = mask[idx];
if (!test_blendv<vec, VT>(expected_val, a, b, mask, idx+1, N)) return false;
if (m != (VT)0) {
mask[idx] = (VT)0;
}
else {
uint64_t hex_mask = 0xFFFFFFFFFFFFFFFF;
std::memcpy(&mask[idx], &hex_mask, sizeof(VT));
}
if (!test_blendv<vec, VT>(expected_val, a, b, mask, idx+1, N)) return false;
mask[idx] = m;
return true;
}
}
template<typename vec, typename VT, int64_t idx, int64_t N>
std::enable_if_t<(!is_complex<VT>::value && idx == N), bool>
template<typename T>
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
test_blendv(VT expected_val[vec::size()], VT a[vec::size()], VT b[vec::size()], VT mask[vec::size()]) {
using bit_rep = BitType<VT>;
// generate expected_val
for (int64_t i = 0; i < vec::size(); i++) {
bit_rep hex_mask = 0;
hex_mask=c10::bit_cast<bit_rep>(mask[i]);
expected_val[i] = (hex_mask & 0x01) ? b[i] : a[i];
}
// test with blendv
auto vec_a = vec::loadu(a);
auto vec_b = vec::loadu(b);
auto vec_m = vec::loadu(mask);
auto expected = vec::loadu(expected_val);
auto actual = vec::blendv(vec_a, vec_b, vec_m);
auto mask_str = std::string("\nblendv mask: ");
for (int64_t i = 0; i < vec::size(); i++) {
mask_str += std::to_string(mask[i]) + " ";
}
if (AssertVectorized<vec>(std::string(NAME_INFO(test_blendv)) + mask_str, expected, actual).check()) {
return false;
}
return true;
}
template<typename vec, typename VT, int64_t idx, int64_t N>
std::enable_if_t<(!is_complex<VT>::value && idx != N), bool>
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
test_blendv(VT expected_val[vec::size()], VT a[vec::size()], VT b[vec::size()], VT mask[vec::size()]) {
// shuffle mask and do blendv test
VT m = mask[idx];
if (!test_blendv<vec, VT, idx+1, N>(expected_val, a, b, mask)) return false;
if (m != (VT)0) {
mask[idx] = (VT)0;
}
else {
uint64_t hex_mask = 0xFFFFFFFFFFFFFFFF;
std::memcpy(&mask[idx], &hex_mask, sizeof(VT));
}
if (!test_blendv<vec, VT, idx+1, N>(expected_val, a, b, mask)) return false;
mask[idx] = m;
return true;
}
template<typename T, int N>
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
void blend_init(T(&a)[N], T(&b)[N]) {
void blend_init(T * a, T * b, int N) {
a[0] = (T)1.0;
b[0] = a[0] + (T)N;
for (const auto i : c10::irange(1, N)) {
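The blend/blendv helpers above are the most invasive test change: the old versions recursed over a `template <int64_t mask>` (or `<int64_t idx, int64_t N>`) parameter, which requires `vec::size()`-derived bounds to be constant expressions, so the rewrite recurses on ordinary function arguments instead. The same walk can also be written as a plain loop; a sketch using the runtime-mask `vec::blend(a, b, mask)` form that the new test calls (the buffer bound and assertion style are illustrative):

```cpp
#include <gtest/gtest.h>

// Check every blend mask in [0, max_mask] against a scalar reference.
template <typename vec, typename VT>
void test_blend_runtime(VT* expected, const VT* a, const VT* b, int64_t max_mask) {
  for (int64_t mask = max_mask; mask >= 0; --mask) {
    int64_t m = mask;
    for (int64_t i = 0; i < vec::size(); ++i, m >>= 1) {
      expected[i] = (m & 0x01) ? b[i] : a[i];   // scalar reference result
    }
    auto actual = vec::blend(vec::loadu(a), vec::loadu(b), mask);
    alignas(64) VT out[64] = {};                // assumed upper bound on lanes
    actual.store(out);
    for (int64_t i = 0; i < vec::size(); ++i) {
      ASSERT_EQ(expected[i], out[i]) << "blend mask " << mask;
    }
  }
}
```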
@ -1107,8 +1112,8 @@ namespace {
CACHE_ALIGN VT mask[vec::size()] = {0};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
CACHE_ALIGN VT expected_val[vec::size()];
blend_init(a, b);
test_blendv<vec, VT, 0, vec::size()>(expected_val, a, b, mask);
blend_init(a, b, vec::size());
test_blendv<vec, VT>(expected_val, a, b, mask, 0, vec::size());
}
TYPED_TEST(BitwiseFloatsAdditional2, Blend) {
using vec = TypeParam;
@ -1119,9 +1124,9 @@ namespace {
CACHE_ALIGN VT b[vec::size()];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
CACHE_ALIGN VT expected_val[vec::size()];
blend_init(a, b);
constexpr int64_t power_sets = 1LL << (vec::size());
test_blend<vec, VT, power_sets - 1>(expected_val, a, b);
blend_init(a, b, vec::size());
const int64_t power_sets = 1LL << (vec::size());
test_blend<vec, VT>(expected_val, a, b, power_sets - 1);
}
template<typename vec, typename VT>
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
@ -1152,7 +1157,7 @@ namespace {
CACHE_ALIGN VT b[vec::size()];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
CACHE_ALIGN VT expected_val[vec::size()];
blend_init(a, b);
blend_init(a, b, vec::size());
test_set<vec, VT>(expected_val, a, b, vec::size());
}
template<typename T>
@ -1218,7 +1223,7 @@ namespace {
// NOLINTNEXTLINE(bugprone-signed-char-misuse)
constexpr int min_val = std::numeric_limits<underlying>::min();
constexpr int max_val = std::numeric_limits<underlying>::max();
constexpr int el_count = vfloat::size();
const int el_count = vfloat::size();
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
CACHE_ALIGN float unit_float_vec[el_count];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
@ -1566,7 +1571,7 @@ namespace {
using vec = TypeParam;
using VT = ValueType<TypeParam>;
constexpr auto R = 2LL; // residual
constexpr auto N = vec::size() + R;
const auto N = vec::size() + R;
CACHE_ALIGN VT x1[N];
CACHE_ALIGN VT x2[N];
CACHE_ALIGN VT x3[N];
@ -2130,7 +2135,7 @@ namespace {
ASSERT_TRUE(vec_pinf.has_inf_nan()) << "Test failed for positive Infinity\n";
ASSERT_TRUE(vec_ninf.has_inf_nan()) << "Test failed for negative Infinity\n";
}
#if !defined(CPU_CAPABILITY_SVE)
#if !defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE)
template <typename vec, typename dst_t>
void test_convert_to(const char* dst_t_name) {
using src_t = ValueType<vec>;
@ -2213,13 +2218,13 @@ namespace {
TYPED_TEST(VecMaskTests, MaskedLoad) {
using vec = TypeParam;
using src_t = ValueType<TypeParam>;
constexpr auto size = vec::size();
const auto size = vec::size();

#define TEST_MASK_LOAD(dst_t, mask_t, mask_n) \
do { \
constexpr int dst_size = at::vec::Vectorized<dst_t>::size(); \
constexpr int dst_n = mask_n * size / dst_size; \
if constexpr(dst_n * dst_size >= mask_n * size) { \
int dst_size = at::vec::Vectorized<dst_t>::size(); \
int dst_n = mask_n * size / dst_size; \
if (dst_n * dst_size >= mask_n * size) { \
CACHE_ALIGN dst_t x[mask_n * size]; \
CACHE_ALIGN dst_t y[mask_n * size]; \
CACHE_ALIGN dst_t ref[mask_n * size]; \
@ -2230,9 +2235,47 @@ namespace {
x[i] = generator.get(); \
} \
auto vec_mask = generate_vec_mask<mask_t, mask_n>(seed); \
constexpr int rnd_n = (mask_n * size + dst_size - 1) / dst_size;\
auto x_vec = vec_mask.template loadu<dst_t, rnd_n>(x); \
x_vec.store(y); \
int rnd_n = (mask_n * size + dst_size - 1) / dst_size;\
switch (rnd_n) { \
case 1: \
{ \
auto x_vec = vec_mask.template loadu<dst_t, 1>(x); \
x_vec.store(y); \
break; \
} \
case 2: \
{ \
auto x_vec = vec_mask.template loadu<dst_t, 2>(x); \
x_vec.store(y); \
break; \
} \
case 3: \
{ \
auto x_vec = vec_mask.template loadu<dst_t, 3>(x); \
x_vec.store(y); \
break; \
} \
case 4: \
{ \
auto x_vec = vec_mask.template loadu<dst_t, 4>(x); \
x_vec.store(y); \
break; \
} \
case 8: \
{ \
auto x_vec = vec_mask.template loadu<dst_t, 8>(x); \
x_vec.store(y); \
break; \
} \
case 16: \
{ \
auto x_vec = vec_mask.template loadu<dst_t, 16>(x); \
x_vec.store(y); \
break; \
} \
default: \
throw std::out_of_range("Unexpected rnd_n call to vec_mask"); \
} \
for (const auto i : c10::irange(mask_n * size)) { \
if (vec_mask.is_masked(i)) { \
ref[i] = x[i]; \
@ -2269,7 +2312,7 @@ namespace {
#undef TEST_MASK_LOAD
#undef TEST_MASK_LOAD_N
}
#if !defined(CPU_CAPABILITY_SVE)
#if !defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE)
TYPED_TEST(VecMaskTests, MaskedCheck) {
using VT = ValueType<TypeParam>;
using vec = TypeParam;
@ -2294,7 +2337,7 @@ namespace {
#undef TEST_MASK_CHECK_N
}
#endif
#if !defined(CPU_CAPABILITY_SVE)
#if !defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE)
TYPED_TEST(VecMaskTests, ToFrom) {
using vec = TypeParam;
using VT = ValueType<TypeParam>;
@ -2321,7 +2364,7 @@ namespace {
}
}
#endif
#if !defined(CPU_CAPABILITY_SVE)
#if !defined(CPU_CAPABILITY_SVE256) && !defined(CPU_CAPABILITY_SVE)
TYPED_TEST(VecMaskTests, Cast) {
using vec = TypeParam;
using src_t = ValueType<TypeParam>;

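`MaskedLoad` faces the opposite problem: `VecMask::loadu<dst_t, N>` keeps `N` as a template parameter, but `rnd_n` is now computed from a runtime size, so the macro dispatches through a `switch` over the handful of counts that can occur. The same trick in reusable form (the helper name is illustrative):

```cpp
#include <stdexcept>
#include <type_traits>
#include <utility>

// Map a small runtime count onto a compile-time constant and call f with it.
template <typename F>
void dispatch_vec_count(int n, F&& f) {
  switch (n) {
    case 1:  f(std::integral_constant<int, 1>{});  break;
    case 2:  f(std::integral_constant<int, 2>{});  break;
    case 3:  f(std::integral_constant<int, 3>{});  break;
    case 4:  f(std::integral_constant<int, 4>{});  break;
    case 8:  f(std::integral_constant<int, 8>{});  break;
    case 16: f(std::integral_constant<int, 16>{}); break;
    default: throw std::out_of_range("unexpected vector count");
  }
}

// Usage sketch:
//   dispatch_vec_count(rnd_n, [&](auto N) {
//     auto x_vec = vec_mask.template loadu<dst_t, decltype(N)::value>(x);
//     x_vec.store(y);
//   });
```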
@ -56,7 +56,7 @@ CACHE_ALIGN #define
defined(CPU_CAPABILITY_AVX512) && (defined(__GNUC__) || defined(__GNUG__))
#undef CHECK_DEQUANT_WITH_LOW_PRECISION
#define CHECK_WITH_FMA 1
#elif defined(CPU_CAPABILITY_SVE)
#elif defined(CPU_CAPABILITY_SVE256)
#define CHECK_DEQUANT_WITH_LOW_PRECISION 1
#define CHECK_WITH_FMA 1
#elif !defined(CPU_CAPABILITY_VSX) && !defined(CPU_CAPABILITY_AVX2)
@ -136,7 +136,7 @@ template<typename T>
struct VecTypeHelper {
using holdType = typename T::value_type;
using memStorageType = typename T::value_type;
static constexpr int holdCount = T::size();
static inline int holdCount = T::size();
static constexpr int unitStorageCount = 1;
};

@ -399,9 +399,9 @@ T clamp_min(const T& a, const T& min) {
return a < min ? min : a;
}

template <class VT, size_t N>
void copy_interleave(VT(&vals)[N], VT(&interleaved)[N]) {
static_assert(N % 2 == 0, "should be even");
template <class VT>
void copy_interleave(VT * vals, VT * interleaved, size_t N) {
assert(N % 2 == 0);
auto ptr1 = vals;
auto ptr2 = vals + N / 2;
for (size_t i = 0; i < N; i += 2) {
@ -871,10 +871,10 @@ public:
using UVT = UvalueType<T>;
using BVT = BitType<UVT>;
UVT absErr = correctEpsilon(toleranceEps);
constexpr int sizeX = VecTypeHelper<T>::holdCount * VecTypeHelper<T>::unitStorageCount;
const int sizeX = VecTypeHelper<T>::holdCount * VecTypeHelper<T>::unitStorageCount;
constexpr int unitStorageCount = VecTypeHelper<T>::unitStorageCount;
CACHE_ALIGN UVT expArr[sizeX];
CACHE_ALIGN UVT actArr[sizeX];
UVT expArr[sizeX];
UVT actArr[sizeX];
exp.store(expArr);
act.store(actArr);
if (bitwise)
@ -942,7 +942,7 @@ void test_unary(
using vec_type = T;
using VT = ValueType<T>;
using UVT = UvalueType<T>;
constexpr int el_count = vec_type::size();
const int el_count = vec_type::size();
CACHE_ALIGN VT vals[el_count];
CACHE_ALIGN VT expected[el_count];
bool bitwise = testCase.isBitwise();
@ -1000,7 +1000,7 @@ void test_binary(
using vec_type = T;
using VT = ValueType<T>;
using UVT = UvalueType<T>;
constexpr int el_count = vec_type::size();
const int el_count = vec_type::size();
CACHE_ALIGN VT vals0[el_count];
CACHE_ALIGN VT vals1[el_count];
CACHE_ALIGN VT expected[el_count];
@ -1163,7 +1163,7 @@ void test_ternary(
using vec_type = T;
using VT = ValueType<T>;
using UVT = UvalueType<T>;
constexpr int el_count = vec_type::size();
const int el_count = vec_type::size();
CACHE_ALIGN VT vals0[el_count];
CACHE_ALIGN VT vals1[el_count];
CACHE_ALIGN VT vals2[el_count];
@ -1203,12 +1203,15 @@ void test_ternary(
auto input1 = vec_type::loadu(vals1);
auto input2 = vec_type::loadu(vals2);
auto actual = actualFunction(input0, input1, input2);
CACHE_ALIGN VT actual_[vec_type::size()];
actual.store(actual_);
auto vec_expected = vec_type::loadu(expected);

AssertVectorized<vec_type> vecAssert(
testNameInfo, seed, vec_expected, actual, input0, input1, input2);
if (vecAssert.check(
bitwise, dmn.CheckWithTolerance, dmn.ToleranceError))
return;
return;
} // trial
changeSeedBy += 1;
}
@ -1573,19 +1576,19 @@ double getDefaultTolerance() {

template<typename T, int N = 1>
at::vec::VecMask<T, N> create_vec_mask(uint64_t bitmask) {
constexpr auto size = at::vec::Vectorized<T>::size();
std::array<int, N * size> mask;
const auto size = at::vec::Vectorized<T>::size();
int mask[N * size];
for (int n = 0; n < N; n++) {
for (int i = 0; i < size; i++) {
mask[n * size + i] = (bitmask >> i) & 1;
}
}
return at::vec::VecMask<T, N>::from(mask.data());
return at::vec::VecMask<T, N>::from(mask);
}

template<typename T, int N = 1>
at::vec::VecMask<T, N> generate_vec_mask(int seed) {
constexpr auto size = at::vec::Vectorized<T>::size();
const auto size = at::vec::Vectorized<T>::size();
ValueGen<uint64_t> generator(0, (1ULL << size) - 1, seed);
auto bitmask = generator.get();
return create_vec_mask<T, N>(bitmask);

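The header-side helpers change in the same direction: `copy_interleave` and `create_vec_mask` take a pointer plus a runtime length instead of fixed-extent array references or `std::array`, since the extent can no longer be a template constant. The interleave helper's shape, spelled out (a sketch; the real helper's body is only partially visible in the hunk):

```cpp
#include <cassert>
#include <cstddef>

// Interleave the first and second halves of `vals` into `interleaved`.
template <class VT>
void copy_interleave_sketch(const VT* vals, VT* interleaved, std::size_t N) {
  assert(N % 2 == 0);
  const VT* lo = vals;           // first half
  const VT* hi = vals + N / 2;   // second half
  for (std::size_t i = 0; i < N; i += 2) {
    interleaved[i] = *lo++;
    interleaved[i + 1] = *hi++;
  }
}
```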
@ -393,15 +393,21 @@ if(INTERN_BUILD_ATEN_OPS)
LIST(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} ${CXX_ZVECTOR_FLAGS}")
endif(CXX_ZVECTOR_FOUND)

if(CXX_SVE_FOUND AND CXX_SVE256_FOUND AND CXX_ARM_BF16_FOUND)
if(CXX_SVE_FOUND AND CXX_ARM_BF16_FOUND)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_SVE_CPU_DEFINITION -DHAVE_ARM_BF16_CPU_DEFINITION")
list(APPEND CPU_CAPABILITY_NAMES "SVE256")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_SVE_CPU_DEFINITION -DHAVE_SVE256_CPU_DEFINITION -DHAVE_ARM_BF16_CPU_DEFINITION")
if("${CMAKE_C_COMPILER_ID}" MATCHES "Clang")
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -O2 -march=armv8-a+sve+bf16 -D__ARM_FEATURE_BF16 -DCPU_CAPABILITY_SVE -msve-vector-bits=256")
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -O2 -march=armv8.2-a+sve+bf16 -msve-vector-bits=256 -D__ARM_FEATURE_BF16")
else()
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -march=armv8-a+sve+bf16 -D__ARM_FEATURE_BF16 -DCPU_CAPABILITY_SVE -msve-vector-bits=256")
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -march=armv8.2-a+sve+bf16 -msve-vector-bits=256 -D__ARM_FEATURE_BF16")
endif()
endif()
list(APPEND CPU_CAPABILITY_NAMES "SVE")
if("${CMAKE_C_COMPILER_ID}" MATCHES "Clang")
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -O2 -march=armv8.2-a+sve+bf16 -msve-vector-bits=128 -D__ARM_FEATURE_BF16")
else()
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -march=armv8.2-a+sve+bf16 -msve-vector-bits=128 -D__ARM_FEATURE_BF16")
endif()
endif(CXX_SVE_FOUND)

list(LENGTH CPU_CAPABILITY_NAMES NUM_CPU_CAPABILITY_NAMES)
math(EXPR NUM_CPU_CAPABILITY_NAMES "${NUM_CPU_CAPABILITY_NAMES}-1")

@ -152,19 +152,7 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
ENDMACRO()

# Check for SVE256 vector length
CHECK_COMPILES(CXX "SVE256" "-march=armv8.2-a+sve -msve-vector-bits=256" "${SVE_CODE}")
CHECK_COMPILES(CXX "ARM_BF16" "-march=armv8.2-a+sve+bf16 -msve-vector-bits=256" "${ARM_BF16_CODE}")
CHECK_COMPILES(CXX "SVE" "-march=armv8.2-a+sve" "${SVE_CODE}")
CHECK_COMPILES(CXX "ARM_BF16" "-march=armv8.2-a+sve+bf16" "${ARM_BF16_CODE}")

# If SVE256 support is not found, set CXX_SVE_FOUND to FALSE and notify the user
if(NOT CXX_SVE256_FOUND)
set(CXX_SVE_FOUND FALSE CACHE BOOL "SVE not available on host")
message(STATUS "No SVE processor on this machine.")
else()
# If SVE256 support is found, set CXX_SVE_FOUND to TRUE and notify the user
set(CXX_SVE_FOUND TRUE CACHE BOOL "SVE available on host")
message(STATUS "SVE support detected.")
endif()

# Mark the SVE support variable as advanced
mark_as_advanced(CXX_SVE_FOUND)
ENDIF(CMAKE_SYSTEM_NAME MATCHES "Linux")

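On the build side, `CXX_SVE_FOUND` plus BF16 is now enough to enable SVE: the 256-bit fixed-width flavor keeps `-msve-vector-bits=256`, and a second "SVE" capability is compiled with `-msve-vector-bits=128`. Kernels see the result through preprocessor macros; a minimal sketch of how such guards are typically consumed in C++ (comments only, assuming the macro-per-capability convention used in the hunks above):

```cpp
// Compile-time view of the flags wired up in the CMake hunks above.
#if defined(HAVE_SVE_CPU_DEFINITION)
// SVE kernels were built at all; runtime capability selection decides their use.
#endif

#if defined(CPU_CAPABILITY_SVE256)
// Translation unit built as the 256-bit fixed-length SVE flavor.
#elif defined(CPU_CAPABILITY_SVE)
// Translation unit built as the new flavor compiled with -msve-vector-bits=128.
#endif
```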
@ -34,6 +34,7 @@ These backends include:

```{eval-rst}
.. autofunction:: torch.backends.cpu.get_cpu_capability
.. autofunction:: torch.backends.cpu.get_sve_len
```

## torch.backends.cuda

@ -1166,6 +1166,7 @@ def _show_config() -> str: ... # THPModule_showConfig
def _cxx_flags() -> str: ... # THPModule_cxxFlags
def _parallel_info() -> str: ... # THPModule_parallelInfo
def _get_cpu_capability() -> str: ... # THPModule_getCpuCapability
def _get_sve_len() -> _int: ... # THPModule_getCpuSveLen
def _set_backcompat_broadcast_warn(
arg: _bool,
) -> None: ... # THPModule_setBackcompatBroadcastWarn

@ -656,6 +656,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._C._get_constant_bool_symnode",
"torch._C._get_cpp_backtrace",
"torch._C._get_cpu_capability",
"torch._C._get_sve_len",
"torch._C._get_cublas_allow_bf16_reduced_precision_reduction",
"torch._C._get_cublas_allow_fp16_reduced_precision_reduction",
"torch._C._get_cublas_allow_tf32",

@ -173,7 +173,7 @@ class VecSVE256(VecISA):
# this function can be repurposed for SVE with variable vec length
_bit_width = 256
_macro = [
"CPU_CAPABILITY_SVE",
"HAVE_SVE_CPU_DEFINITION",
"CPU_CAPABILITY_SVE256",
"AT_BUILD_ARM_VEC256_WITH_SLEEF",
"__ARM_FEATURE_BF16",
@ -189,6 +189,25 @@ class VecSVE256(VecISA):

__hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment]

@dataclasses.dataclass
class VecSVE(VecISA):
# this function can be repurposed for SVE with variable vec length
# _bit_width = torch.backends.cpu.get_sve_len() # disable for now as it is not working
_bit_width = 128
_macro = [
"HAVE_SVE_CPU_DEFINITION",
"CPU_CAPABILITY_SVE",
"AT_BUILD_ARM_VEC256_WITH_SLEEF",
"__ARM_FEATURE_BF16",
]
_arch_flags = "-march=armv8-a+sve+bf16 -msve-vector-bits=128"
_dtype_nelements = {torch.float: int(_bit_width / 32), torch.bfloat16: int(_bit_width / 16), torch.float16: int(_bit_width / 16)}

def __str__(self) -> str:
return "asimd"

__hash__: Callable[[VecISA], Any] = VecISA.__hash__


@dataclasses.dataclass
class VecAVX512(VecISA):
@ -404,13 +423,7 @@ def x86_isa_checker() -> list[str]:


invalid_vec_isa = InvalidVecISA()
supported_vec_isa_list = [
VecAMX(),
VecAVX512(),
VecAVX2(),
VecNEON(),
VecSVE256(),
]
supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON(), VecSVE256(), VecSVE()]


def get_isa_from_cpu_capability(
@ -473,6 +486,8 @@ def valid_vec_isa_list() -> list[VecISA]:
elif arch == "aarch64":
if torch.backends.cpu.get_cpu_capability() == "SVE256":
isa_list.append(VecSVE256())
elif torch.backends.cpu.get_cpu_capability() == "SVE":
isa_list.append(VecSVE())
else:
isa_list.append(VecNEON())


@ -3,6 +3,7 @@ import torch

__all__ = [
"get_cpu_capability",
"get_sve_len"
]


@ -16,6 +17,13 @@ def get_cpu_capability() -> str:
- "NO AVX"
- "AVX2"
- "AVX512"
- "SVE"
- "SVE256"
"""
return torch._C._get_cpu_capability()


def get_sve_len() -> int:
r"""Return the maximum supported SVE length in bits.
|
||||
"""
|
||||
return torch._C._get_sve_len()
|
||||
|
||||
@ -585,6 +585,14 @@ static PyObject* THPModule_getCpuCapability(
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* THPModule_getSveLen(
|
||||
PyObject* module,
|
||||
PyObject* noargs) {
|
||||
HANDLE_TH_ERRORS
|
||||
return THPUtils_packInt32(at::get_sve_len());
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
template <class T>
|
||||
@ -1584,6 +1592,7 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
|
||||
{"_cxx_flags", THPModule_cxxFlags, METH_NOARGS, nullptr},
|
||||
{"_parallel_info", THPModule_parallelInfo, METH_NOARGS, nullptr},
|
||||
{"_get_cpu_capability", THPModule_getCpuCapability, METH_NOARGS, nullptr},
|
||||
{"_get_sve_len", THPModule_getSveLen, METH_NOARGS, nullptr},
|
||||
{"_set_backcompat_broadcast_warn",
|
||||
THPModule_setBackcompatBroadcastWarn,
|
||||
METH_O,
|
||||
|
||||
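With `_get_sve_len` registered above, the maximum SVE length in bits is queryable end to end: C++ callers use `at::get_sve_len()` and Python callers `torch.backends.cpu.get_sve_len()`. A small illustrative C++ consumer, assuming the header that declares `at::get_sve_len()` is included (the function name and tile numbers are made up):

```cpp
#include <cstdint>

// Pick a tile width from the reported SVE vector length in bits.
int64_t pick_tile_width() {
  const int sve_bits = at::get_sve_len();  // maximum supported SVE length in bits
  return sve_bits >= 256 ? 64 : 32;        // illustrative thresholds only
}
```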
@ -32,9 +32,7 @@
#include <c10/util/irange.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>

#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || \
defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || \
defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_SVE256)
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_SVE256)|| defined(CPU_CAPABILITY_SVE)
#define INDUCTOR_USE_VECTOR_TYPES() 1
#else
#define INDUCTOR_USE_VECTOR_TYPES() 0

@ -462,6 +462,10 @@ RegisterOperators reg({
"aten::_get_cpu_capability() -> str",
[](Stack& stack) { push(stack, at::get_cpu_capability()); },
aliasAnalysisConservative()),
Operator(
"aten::_get_sve_len() -> int",
[](Stack& stack) { push(stack, at::get_sve_len()); },
aliasAnalysisConservative()),
});
} // namespace
} // namespace torch::jit

@ -113,6 +113,7 @@ _builtin_ops = [
(torch.nn.init._no_grad_zero_, "aten::_no_grad_zero_"),
(torch._C._get_tracing_state, "aten::_get_tracing_state"),
(torch._C._get_cpu_capability, "aten::_get_cpu_capability"),
(torch._C._get_sve_len, "aten::_get_sve_len"),
(warnings.warn, "aten::warn"),
(torch._VF.stft, "aten::stft"), # type: ignore[attr-defined]
(torch._VF.istft, "aten::istft"), # type: ignore[attr-defined]