mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Extend vec backend with BF16 SVE intrinsics (#143666)
- Following the work in https://github.com/pytorch/pytorch/pull/119571, BF16 SVE intrinsics are added to the Vectorized class, providing ~1.7x speedup on `silu` and `softmax`. - Added bf16 detection in CMake - Added a guard for native NEON code to prevent compilation errors @aditew01 @maajidkhann please have a look Pull Request resolved: https://github.com/pytorch/pytorch/pull/143666 Approved by: https://github.com/malfet, https://github.com/aditew01, https://github.com/nikhil-arm Co-authored-by: Aditya Tewari <aditya.tewari@arm.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
0c52ee1b35
commit
fcbbb03d48
@ -105,7 +105,7 @@ std::string get_cpu_capability() {
|
||||
return "DEFAULT";
|
||||
case native::CPUCapability::ZVECTOR:
|
||||
return "Z VECTOR";
|
||||
#elif defined(HAVE_SVE_CPU_DEFINITION)
|
||||
#elif defined(HAVE_SVE256_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION)
|
||||
case native::CPUCapability::DEFAULT:
|
||||
return "DEFAULT";
|
||||
case native::CPUCapability::SVE256:
|
||||
|
@ -17,6 +17,7 @@ typedef svuint16_t vls_uint16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH
|
||||
typedef svuint32_t vls_uint32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
|
||||
typedef svuint64_t vls_uint64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
|
||||
typedef svfloat16_t vls_float16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
|
||||
typedef svbfloat16_t vls_bfloat16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
|
||||
typedef svfloat32_t vls_float32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
|
||||
typedef svfloat64_t vls_float64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
|
||||
|
||||
@ -41,6 +42,7 @@ typedef svfloat64_t vls_float64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDT
|
||||
#define ONE_U32 svdup_n_u32(1)
|
||||
#define ONE_U64 svdup_n_u64(1)
|
||||
#define ONE_F16 svdup_n_f16(1.f)
|
||||
#define ONE_BF16 svdup_n_bf16(1.f)
|
||||
#define ONE_F32 svdup_n_f32(1.f)
|
||||
#define ONE_F64 svdup_n_f64(1.0)
|
||||
#define ALL_S8_TRUE_MASK svdup_n_s8(0xff)
|
||||
@ -55,6 +57,8 @@ typedef svfloat64_t vls_float64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDT
|
||||
#define ALL_U8_FALSE_MASK svdup_n_u8(0x00)
|
||||
#define ALL_F16_TRUE_MASK svreinterpret_f16_s16(ALL_S16_TRUE_MASK)
|
||||
#define ALL_F16_FALSE_MASK svreinterpret_f16_s16(ALL_S16_FALSE_MASK)
|
||||
#define ALL_BF16_TRUE_MASK svreinterpret_bf16_s16(ALL_S16_TRUE_MASK)
|
||||
#define ALL_BF16_FALSE_MASK svreinterpret_bf16_s16(ALL_S16_FALSE_MASK)
|
||||
#define ALL_F32_TRUE_MASK svreinterpret_f32_s32(ALL_S32_TRUE_MASK)
|
||||
#define ALL_F32_FALSE_MASK svreinterpret_f32_s32(ALL_S32_FALSE_MASK)
|
||||
#define ALL_F64_TRUE_MASK svreinterpret_f64_s64(ALL_S64_TRUE_MASK)
|
||||
|
524
aten/src/ATen/cpu/vec/sve/vec_bfloat16.h
Normal file
524
aten/src/ATen/cpu/vec/sve/vec_bfloat16.h
Normal file
@ -0,0 +1,524 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cpu/vec/intrinsics.h>
|
||||
#include <ATen/cpu/vec/sve/sve_helper.h>
|
||||
#include <ATen/cpu/vec/sve/vec_common_sve.h>
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
#include <ATen/cpu/vec/sve/vec_float.h>
|
||||
#include <cmath>
|
||||
namespace at {
|
||||
namespace vec {
|
||||
// Note [CPU_CAPABILITY namespace]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// This header, and all of its subheaders, will be compiled with
|
||||
// different architecture flags for each supported set of vector
|
||||
// intrinsics. So we need to make sure they aren't inadvertently
|
||||
// linked together. We do this by declaring objects in an `inline
|
||||
// namespace` which changes the name mangling, but can still be
|
||||
// accessed as `at::vec`.
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
#if defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16)
|
||||
|
||||
template <>
|
||||
class Vectorized<BFloat16> {
|
||||
private:
|
||||
vls_bfloat16_t values;
|
||||
|
||||
public:
|
||||
using value_type = BFloat16;
|
||||
using size_type = int;
|
||||
|
||||
static constexpr size_type size() {
|
||||
return VECTOR_WIDTH / sizeof(BFloat16);
|
||||
}
|
||||
|
||||
Vectorized() {}
|
||||
Vectorized(svbfloat16_t v) : values(v) {}
|
||||
Vectorized(int val);
|
||||
Vectorized(BFloat16 val);
|
||||
|
||||
template <
|
||||
typename... Args,
|
||||
typename = std::enable_if_t<(sizeof...(Args) == size())>>
|
||||
Vectorized(Args... vals) {
|
||||
__at_align__ BFloat16 buffer[size()] = {vals...};
|
||||
values = svld1_bf16(ptrue, reinterpret_cast<const bfloat16_t*>(buffer));
|
||||
}
|
||||
|
||||
operator svbfloat16_t() const {
|
||||
return values;
|
||||
}
|
||||
static Vectorized<BFloat16> blendv(const Vectorized<BFloat16>& a, const
|
||||
Vectorized<BFloat16>& b, const Vectorized<BFloat16>& mask_) {
|
||||
svbool_t mask = svcmpeq_s16(ptrue, svreinterpret_s16_bf16(mask_),
|
||||
ALL_S16_TRUE_MASK);
|
||||
return svsel_bf16(mask, b, a);
|
||||
}
|
||||
template<typename step_t>
|
||||
static Vectorized<BFloat16> arange(BFloat16 base = 0.f, step_t step =
|
||||
static_cast<step_t>(1)) {
|
||||
__at_align__ BFloat16 buffer[size()];
|
||||
for (int64_t i = 0; i < size(); i++) {
|
||||
buffer[i] = base + i * step;
|
||||
}
|
||||
return svld1_bf16(ptrue, reinterpret_cast<bfloat16_t *>(buffer));
|
||||
}
|
||||
static Vectorized<BFloat16> set(const Vectorized<BFloat16>& a, const
|
||||
Vectorized<BFloat16>& b, int64_t count = size()) {
|
||||
if (count == 0) {
|
||||
return a;
|
||||
} else if (count < size()) {
|
||||
return svsel_bf16(svwhilelt_b16(0ull, count), b, a);
|
||||
}
|
||||
return b;
|
||||
}
|
||||
static Vectorized<BFloat16> loadu(const void* ptr, int64_t count = size()) {
|
||||
if (count == size())
|
||||
return svld1_bf16(ptrue, reinterpret_cast<const bfloat16_t*>(ptr));
|
||||
svbool_t pg = svwhilelt_b16(0ull, count);
|
||||
return svld1_bf16(pg, reinterpret_cast<const bfloat16_t*>(ptr));
|
||||
}
|
||||
void store(void* ptr, int64_t count = size()) const {
|
||||
__at_align__ bfloat16_t tmp[size()];
|
||||
std::memset(tmp, 0, sizeof(tmp));
|
||||
if (count == size()) {
|
||||
svst1_bf16(ptrue, reinterpret_cast<bfloat16_t*>(tmp), values);
|
||||
} else {
|
||||
svbool_t pg = svwhilelt_b16(0ull, count);
|
||||
svst1_bf16(pg, reinterpret_cast<bfloat16_t*>(tmp), values);
|
||||
}
|
||||
std::memcpy(
|
||||
reinterpret_cast<bfloat16_t*>(ptr),
|
||||
reinterpret_cast<const bfloat16_t*>(tmp),
|
||||
count * sizeof(bfloat16_t));
|
||||
}
|
||||
const BFloat16& operator[](int idx) const = delete;
|
||||
BFloat16& operator[](int idx) = delete;
|
||||
int64_t zero_mask() const {
|
||||
int64_t mask = 0;
|
||||
// returns an integer mask where all zero elements are translated to
|
||||
// 1-bit and others are translated to 0-bit int64_t mask = 0;
|
||||
__at_align__ int16_t mask_array[size()];
|
||||
|
||||
svbool_t svbool_mask = svcmpeq_f16(ptrue, svreinterpret_f16_bf16(values), ZERO_F16);
|
||||
svst1_s16(ptrue, mask_array, svsel_s16(svbool_mask,
|
||||
ALL_S16_TRUE_MASK,
|
||||
ALL_S16_FALSE_MASK));
|
||||
for (int64_t i = 0; i < size(); ++i) {
|
||||
if (mask_array[i]) mask |= (1ull << i);
|
||||
}
|
||||
return mask;
|
||||
}
|
||||
Vectorized<BFloat16> isnan() const;
|
||||
bool has_inf_nan() const;
|
||||
Vectorized<BFloat16> map(BFloat16 (*f)(BFloat16)) const {
|
||||
__at_align__ BFloat16 tmp[size()];
|
||||
store(tmp);
|
||||
for (int64_t i = 0; i < size(); ++i) {
|
||||
tmp[i] = f(tmp[i]);
|
||||
}
|
||||
return loadu(tmp);
|
||||
}
|
||||
Vectorized<BFloat16> abs() const {
|
||||
auto mask = svdup_n_u16(0x7FFF);
|
||||
auto vals = svreinterpret_u16_bf16(values);
|
||||
vals = svand_u16_x(ptrue, vals, mask);
|
||||
return svreinterpret_bf16_u16(vals);
|
||||
}
|
||||
Vectorized<BFloat16> angle() const;
|
||||
Vectorized<BFloat16> real() const {
|
||||
return values;
|
||||
}
|
||||
Vectorized<BFloat16> imag() const {
|
||||
return Vectorized<BFloat16>(0.f);
|
||||
}
|
||||
Vectorized<BFloat16> conj() const {
|
||||
return values;
|
||||
}
|
||||
Vectorized<BFloat16> acos() const;
|
||||
Vectorized<BFloat16> acosh() const;
|
||||
Vectorized<BFloat16> asin() const;
|
||||
Vectorized<BFloat16> atan() const;
|
||||
Vectorized<BFloat16> atanh() const;
|
||||
Vectorized<BFloat16> atan2(const Vectorized<BFloat16> &b) const;
|
||||
Vectorized<BFloat16> copysign(const Vectorized<BFloat16> &sign) const;
|
||||
Vectorized<BFloat16> erf() const;
|
||||
Vectorized<BFloat16> erfc() const;
|
||||
Vectorized<BFloat16> erfinv() const;
|
||||
Vectorized<BFloat16> exp() const;
|
||||
Vectorized<BFloat16> exp2() const;
|
||||
Vectorized<BFloat16> expm1() const;
|
||||
Vectorized<BFloat16> exp_u20() const {
|
||||
return exp();
|
||||
}
|
||||
Vectorized<BFloat16> fmod(const Vectorized<BFloat16>& q) const;
|
||||
Vectorized<BFloat16> hypot(const Vectorized<BFloat16> &b) const;
|
||||
Vectorized<BFloat16> i0() const;
|
||||
Vectorized<BFloat16> i0e() const;
|
||||
Vectorized<BFloat16> digamma() const;
|
||||
Vectorized<BFloat16> igamma(const Vectorized<BFloat16> &x) const;
|
||||
Vectorized<BFloat16> igammac(const Vectorized<BFloat16> &x) const;
|
||||
Vectorized<BFloat16> nextafter(const Vectorized<BFloat16> &b) const;
|
||||
Vectorized<BFloat16> log() const;
|
||||
Vectorized<BFloat16> log2() const;
|
||||
Vectorized<BFloat16> log10() const;
|
||||
Vectorized<BFloat16> log1p() const;
|
||||
Vectorized<BFloat16> frac() const;
|
||||
Vectorized<BFloat16> sin() const;
|
||||
Vectorized<BFloat16> sinh() const;
|
||||
Vectorized<BFloat16> cos() const;
|
||||
Vectorized<BFloat16> cosh() const;
|
||||
Vectorized<BFloat16> ceil() const;
|
||||
Vectorized<BFloat16> floor() const;
|
||||
Vectorized<BFloat16> neg() const {
|
||||
auto mask = svdup_n_u16(0x8000);
|
||||
auto vals = svreinterpret_u16_bf16(values);
|
||||
vals = sveor_u16_x(ptrue, vals, mask);
|
||||
return svreinterpret_bf16_u16(vals);
|
||||
};
|
||||
Vectorized<BFloat16> round() const;
|
||||
Vectorized<BFloat16> tan() const;
|
||||
Vectorized<BFloat16> tanh() const;
|
||||
Vectorized<BFloat16> trunc() const;
|
||||
Vectorized<BFloat16> lgamma() const;
|
||||
Vectorized<BFloat16> sqrt() const;
|
||||
Vectorized<BFloat16> reciprocal() const;
|
||||
Vectorized<BFloat16> rsqrt() const;
|
||||
Vectorized<BFloat16> pow(const Vectorized<BFloat16> &b) const;
|
||||
// Comparison using the _CMP_**_OQ predicate.
|
||||
// `O`: get false if an operand is NaN
|
||||
// `Q`: do not raise if an operand is NaN
|
||||
Vectorized<BFloat16> operator==(const Vectorized<BFloat16>& other) const;
|
||||
|
||||
Vectorized<BFloat16> operator!=(const Vectorized<BFloat16>& other) const;
|
||||
|
||||
Vectorized<BFloat16> operator<(const Vectorized<BFloat16>& other) const;
|
||||
|
||||
Vectorized<BFloat16> operator<=(const Vectorized<BFloat16>& other) const;
|
||||
|
||||
Vectorized<BFloat16> operator>(const Vectorized<BFloat16>& other) const;
|
||||
|
||||
Vectorized<BFloat16> operator>=(const Vectorized<BFloat16>& other) const;
|
||||
|
||||
Vectorized<BFloat16> eq(const Vectorized<BFloat16>& other) const;
|
||||
Vectorized<BFloat16> ne(const Vectorized<BFloat16>& other) const;
|
||||
Vectorized<BFloat16> gt(const Vectorized<BFloat16>& other) const;
|
||||
Vectorized<BFloat16> ge(const Vectorized<BFloat16>& other) const;
|
||||
Vectorized<BFloat16> lt(const Vectorized<BFloat16>& other) const;
|
||||
Vectorized<BFloat16> le(const Vectorized<BFloat16>& other) const;
|
||||
};
|
||||
|
||||
// Widens one bf16 vector into two float vectors.
//
// Returns a tuple whose first element is built from the lower half of `a`
// (svzip1) and whose second element is built from the upper half (svzip2).
// Each bf16 element is paired with a preceding zero bf16 element, and the
// resulting 32-bit pair is reinterpreted as a float, i.e. the bf16 bits
// become one half of the float and the other half is zero.
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(
    const Vectorized<c10::BFloat16>& a) {
  static_assert(
      Vectorized<c10::BFloat16>::size() == 2 * Vectorized<float>::size());
  // All-zero bf16 lanes used as the padding half of every widened element.
  auto padding = svreinterpret_bf16_f32(svdup_n_f32(0.0f));
  auto lower = svreinterpret_f32_bf16(svzip1_bf16(padding, a));
  auto upper = svreinterpret_f32_bf16(svzip2_bf16(padding, a));
  return {Vectorized<float>(lower), Vectorized<float>(upper)};
}
|
||||
|
||||
// Narrows two float vectors into a single bf16 vector.
//
// `a` supplies the first half of the result lanes and `b` the second half.
// Each input is converted with svcvt_bf16_f32_z under an all-true
// predicate, then svuzp1 picks the even-indexed 16-bit elements of both
// converted vectors and concatenates them.
inline Vectorized<c10::BFloat16> convert_float_bfloat16(
    const Vectorized<float>& a,
    const Vectorized<float>& b) {
  static_assert(
      Vectorized<c10::BFloat16>::size() == 2 * Vectorized<float>::size());
  svbfloat16_t narrowed_a = svcvt_bf16_f32_z(ptrue, a);
  svbfloat16_t narrowed_b = svcvt_bf16_f32_z(ptrue, b);
  svbfloat16_t packed = svuzp1_bf16(narrowed_a, narrowed_b);
  return Vectorized<c10::BFloat16>(packed);
}
|
||||
|
||||
inline void load_fp32_from_bf16(const BFloat16* data, Vectorized<float>& out) {
|
||||
__at_align__ float values[Vectorized<float>::size()];
|
||||
for (const auto k : c10::irange(Vectorized<float>::size())) {
|
||||
values[k] = data[k];
|
||||
}
|
||||
out = Vectorized<float>::loadu(values);
|
||||
}
|
||||
|
||||
inline void load_fp32_from_bf16(
|
||||
const BFloat16* data,
|
||||
Vectorized<float>& out1,
|
||||
Vectorized<float>& out2) {
|
||||
Vectorized<BFloat16> bf16_vec = Vectorized<BFloat16>::loadu(data);
|
||||
auto floats = convert_bfloat16_float(bf16_vec);
|
||||
out1 = std::get<0>(floats);
|
||||
out2 = std::get<1>(floats);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
Vectorized<c10::BFloat16> binary_operator_via_float(
|
||||
Op op,
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b) {
|
||||
const auto [a_float_low, a_float_high] = convert_bfloat16_float(a);
|
||||
const auto [b_float_low, b_float_high] = convert_bfloat16_float(b);
|
||||
return convert_float_bfloat16(
|
||||
op(a_float_low, b_float_low), op(a_float_high, b_float_high));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<c10::BFloat16> inline operator+(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b) {
|
||||
return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<c10::BFloat16> inline operator-(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b) {
|
||||
return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<c10::BFloat16> inline operator*(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b) {
|
||||
return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<c10::BFloat16> inline operator/(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b) {
|
||||
return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b);
|
||||
}
|
||||
|
||||
inline Vectorized<BFloat16>::Vectorized(int val) {
|
||||
auto vals_f = svdup_n_f32(val);
|
||||
values = convert_float_bfloat16(vals_f, vals_f);
|
||||
}
|
||||
|
||||
|
||||
inline Vectorized<BFloat16>::Vectorized(BFloat16 val) {
|
||||
auto vals_f = svdup_n_f32((float) val);
|
||||
values = convert_float_bfloat16(vals_f, vals_f);
|
||||
}
|
||||
|
||||
bool inline Vectorized<c10::BFloat16>::has_inf_nan() const {
|
||||
auto [v1, v2] = convert_bfloat16_float(values);
|
||||
return v1.has_inf_nan() || v2.has_inf_nan();
|
||||
}
|
||||
// frac. Implement this here so we can use subtraction
|
||||
Vectorized<BFloat16> inline Vectorized<BFloat16>::frac() const {
|
||||
return *this - this->trunc();
|
||||
}
|
||||
|
||||
#define DEFINE_BF16_FUNC_VIA_FLOAT(func_name) \
|
||||
Vectorized<BFloat16> inline Vectorized<BFloat16>::func_name() const { \
|
||||
auto [v1, v2] = convert_bfloat16_float(*this); \
|
||||
v1 = v1.func_name(); \
|
||||
v2 = v2.func_name(); \
|
||||
return convert_float_bfloat16(v1, v2); \
|
||||
}
|
||||
|
||||
#define DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(func_name) \
|
||||
Vectorized<BFloat16> inline Vectorized<BFloat16>::func_name(const Vectorized<BFloat16> &a) const { \
|
||||
auto [v1, v2] = convert_bfloat16_float(*this); \
|
||||
auto [v3, v4] = convert_bfloat16_float(a); \
|
||||
v1 = v1.func_name(v3); \
|
||||
v2 = v2.func_name(v4); \
|
||||
return convert_float_bfloat16(v1, v2); \
|
||||
}
|
||||
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(isnan);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(angle);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(acos);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(acosh);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(asin);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(atan);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(atanh);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(erf);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(erfc);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(exp);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(exp2);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(expm1);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(i0);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(i0e);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(digamma);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(log);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(log2);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(log10);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(log1p);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(sin);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(sinh);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(cos);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(cosh);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(ceil);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(floor);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(round);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(tan);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(tanh);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(trunc);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(lgamma);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(sqrt);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow);
|
||||
|
||||
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator==(const Vectorized<BFloat16>& other) const {
|
||||
auto [f1, f2] = convert_bfloat16_float(values);
|
||||
auto [f3, f4] = convert_bfloat16_float(other);
|
||||
svbool_t mask1 = svcmpeq_f32(ptrue, f1, f3);
|
||||
svbool_t mask2 = svcmpeq_f32(ptrue, f2, f4);
|
||||
auto res1 = svsel_f32(mask1, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
|
||||
auto res2 = svsel_f32(mask2, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
|
||||
|
||||
auto bf16_1 = svreinterpret_bf16_f32(res1);
|
||||
auto bf16_2 = svreinterpret_bf16_f32(res2);
|
||||
return svuzp1_bf16(bf16_1, bf16_2);
|
||||
}
|
||||
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator!=(const Vectorized<BFloat16>& other) const {
|
||||
auto [f1, f2] = convert_bfloat16_float(values);
|
||||
auto [f3, f4] = convert_bfloat16_float(other);
|
||||
svbool_t mask1 = svcmpne_f32(ptrue, f1, f3);
|
||||
svbool_t mask2 = svcmpne_f32(ptrue, f2, f4);
|
||||
auto res1 = svsel_f32(mask1, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
|
||||
auto res2 = svsel_f32(mask2, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
|
||||
|
||||
auto bf16_1 = svreinterpret_bf16_f32(res1);
|
||||
auto bf16_2 = svreinterpret_bf16_f32(res2);
|
||||
return svuzp1_bf16(bf16_1, bf16_2);
|
||||
}
|
||||
// Lane-wise `this > other`. Returns a bit mask: every bit of a lane is set
// when the comparison holds and cleared otherwise (including when either
// operand is NaN, since the float compare yields false for unordered
// inputs).
//
// The mask is assembled bitwise, the same way operator== and operator!=
// do it: compare as float, select all-ones/all-zeros 32-bit lanes, then
// keep the even 16-bit elements. Routing an all-ones float mask through
// the numeric float->bf16 conversion would instead depend on how that
// conversion treats NaN bit patterns, and blendv() requires the mask lanes
// to be exactly all-ones.
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator>(const Vectorized<BFloat16>& other) const {
  auto [f1, f2] = convert_bfloat16_float(*this);
  auto [f3, f4] = convert_bfloat16_float(other);
  svbool_t mask1 = svcmpgt_f32(ptrue, f1, f3);
  svbool_t mask2 = svcmpgt_f32(ptrue, f2, f4);
  auto res1 = svsel_f32(mask1, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
  auto res2 = svsel_f32(mask2, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);

  auto bf16_1 = svreinterpret_bf16_f32(res1);
  auto bf16_2 = svreinterpret_bf16_f32(res2);
  return svuzp1_bf16(bf16_1, bf16_2);
}
|
||||
// Lane-wise `this >= other`. Returns a bit mask: every bit of a lane is
// set when the comparison holds and cleared otherwise (including when
// either operand is NaN, since the float compare yields false for
// unordered inputs).
//
// The mask is assembled bitwise, the same way operator== and operator!=
// do it: compare as float, select all-ones/all-zeros 32-bit lanes, then
// keep the even 16-bit elements. Routing an all-ones float mask through
// the numeric float->bf16 conversion would instead depend on how that
// conversion treats NaN bit patterns, and blendv() requires the mask lanes
// to be exactly all-ones.
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator>=(const Vectorized<BFloat16>& other) const {
  auto [f1, f2] = convert_bfloat16_float(*this);
  auto [f3, f4] = convert_bfloat16_float(other);
  svbool_t mask1 = svcmpge_f32(ptrue, f1, f3);
  svbool_t mask2 = svcmpge_f32(ptrue, f2, f4);
  auto res1 = svsel_f32(mask1, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
  auto res2 = svsel_f32(mask2, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);

  auto bf16_1 = svreinterpret_bf16_f32(res1);
  auto bf16_2 = svreinterpret_bf16_f32(res2);
  return svuzp1_bf16(bf16_1, bf16_2);
}
|
||||
// Lane-wise `this < other`. Returns a bit mask: every bit of a lane is set
// when the comparison holds and cleared otherwise (including when either
// operand is NaN, since the float compare yields false for unordered
// inputs).
//
// The mask is assembled bitwise, the same way operator== and operator!=
// do it: compare as float, select all-ones/all-zeros 32-bit lanes, then
// keep the even 16-bit elements. Routing an all-ones float mask through
// the numeric float->bf16 conversion would instead depend on how that
// conversion treats NaN bit patterns, and blendv() requires the mask lanes
// to be exactly all-ones.
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator<(const Vectorized<BFloat16>& other) const {
  auto [f1, f2] = convert_bfloat16_float(*this);
  auto [f3, f4] = convert_bfloat16_float(other);
  svbool_t mask1 = svcmplt_f32(ptrue, f1, f3);
  svbool_t mask2 = svcmplt_f32(ptrue, f2, f4);
  auto res1 = svsel_f32(mask1, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
  auto res2 = svsel_f32(mask2, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);

  auto bf16_1 = svreinterpret_bf16_f32(res1);
  auto bf16_2 = svreinterpret_bf16_f32(res2);
  return svuzp1_bf16(bf16_1, bf16_2);
}
|
||||
// Lane-wise `this <= other`. Returns a bit mask: every bit of a lane is
// set when the comparison holds and cleared otherwise (including when
// either operand is NaN, since the float compare yields false for
// unordered inputs).
//
// The mask is assembled bitwise, the same way operator== and operator!=
// do it: compare as float, select all-ones/all-zeros 32-bit lanes, then
// keep the even 16-bit elements. Routing an all-ones float mask through
// the numeric float->bf16 conversion would instead depend on how that
// conversion treats NaN bit patterns, and blendv() requires the mask lanes
// to be exactly all-ones.
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator<=(const Vectorized<BFloat16>& other) const {
  auto [f1, f2] = convert_bfloat16_float(*this);
  auto [f3, f4] = convert_bfloat16_float(other);
  svbool_t mask1 = svcmple_f32(ptrue, f1, f3);
  svbool_t mask2 = svcmple_f32(ptrue, f2, f4);
  auto res1 = svsel_f32(mask1, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
  auto res2 = svsel_f32(mask2, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);

  auto bf16_1 = svreinterpret_bf16_f32(res1);
  auto bf16_2 = svreinterpret_bf16_f32(res2);
  return svuzp1_bf16(bf16_1, bf16_2);
}
|
||||
|
||||
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
|
||||
// either input is a NaN.
|
||||
template <>
|
||||
Vectorized<BFloat16> inline maximum(const Vectorized<BFloat16>& a, const
|
||||
Vectorized<BFloat16>& b) {
|
||||
return binary_operator_via_float(static_cast<Vectorized<float>(*)(const Vectorized<float>&, const Vectorized<float>&)>(&maximum), a, b);
|
||||
}
|
||||
|
||||
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
|
||||
// either input is a NaN.
|
||||
template <>
|
||||
Vectorized<BFloat16> inline minimum(const Vectorized<BFloat16>& a, const
|
||||
Vectorized<BFloat16>& b) {
|
||||
return binary_operator_via_float(static_cast<Vectorized<float>(*)(const Vectorized<float>&, const Vectorized<float>&)>(&minimum), a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<BFloat16> inline clamp_max(const Vectorized<BFloat16>& a, const
|
||||
Vectorized<BFloat16>& max) {
|
||||
return binary_operator_via_float(static_cast<Vectorized<float>(*)(const Vectorized<float>&, const Vectorized<float>&)>(&clamp_max), a, max);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<BFloat16> inline clamp_min(const Vectorized<BFloat16>& a, const
|
||||
Vectorized<BFloat16>& min) {
|
||||
return binary_operator_via_float(static_cast<Vectorized<float>(*)(const Vectorized<float>&, const Vectorized<float>&)>(&clamp_min), a, min);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<BFloat16> inline clamp(const Vectorized<BFloat16>& a, const
|
||||
Vectorized<BFloat16>& min, const Vectorized<BFloat16>& max) {
|
||||
return clamp_min(clamp_max(a, max), min);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<BFloat16> inline operator&(const Vectorized<BFloat16>& a, const
|
||||
Vectorized<BFloat16>& b) {
|
||||
return svreinterpret_bf16_u16(svand_u16_x(ptrue, svreinterpret_u16_bf16(a),
|
||||
svreinterpret_u16_bf16(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<BFloat16> inline operator|(const Vectorized<BFloat16>& a, const
|
||||
Vectorized<BFloat16>& b) {
|
||||
return svreinterpret_bf16_u16(svorr_u16_x(ptrue, svreinterpret_u16_bf16(a),
|
||||
svreinterpret_u16_bf16(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<BFloat16> inline operator^(const Vectorized<BFloat16>& a, const
|
||||
Vectorized<BFloat16>& b) {
|
||||
return svreinterpret_bf16_u16(sveor_u16_x(ptrue, svreinterpret_u16_bf16(a),
|
||||
svreinterpret_u16_bf16(b)));
|
||||
}
|
||||
|
||||
Vectorized<BFloat16> inline Vectorized<BFloat16>::eq(const Vectorized<BFloat16>&
|
||||
other) const {
|
||||
return (*this == other) & Vectorized<BFloat16>(1.0f);
|
||||
}
|
||||
|
||||
Vectorized<BFloat16> inline Vectorized<BFloat16>::ne(const Vectorized<BFloat16>&
|
||||
other) const {
|
||||
return (*this != other) & Vectorized<BFloat16>(1.0f);
|
||||
}
|
||||
|
||||
Vectorized<BFloat16> inline Vectorized<BFloat16>::gt(const Vectorized<BFloat16>&
|
||||
other) const {
|
||||
return (*this > other) & Vectorized<BFloat16>(1.0f);
|
||||
}
|
||||
|
||||
Vectorized<BFloat16> inline Vectorized<BFloat16>::ge(const Vectorized<BFloat16>&
|
||||
other) const {
|
||||
return (*this >= other) & Vectorized<BFloat16>(1.0f);
|
||||
}
|
||||
|
||||
Vectorized<BFloat16> inline Vectorized<BFloat16>::lt(const Vectorized<BFloat16>&
|
||||
other) const {
|
||||
return (*this < other) & Vectorized<BFloat16>(1.0f);
|
||||
}
|
||||
|
||||
Vectorized<BFloat16> inline Vectorized<BFloat16>::le(const Vectorized<BFloat16>&
|
||||
other) const {
|
||||
return (*this <= other) & Vectorized<BFloat16>(1.0f);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) {
|
||||
const int64_t fraction = n % Vectorized<BFloat16>::size();
|
||||
#pragma unroll
|
||||
for (int64_t i = 0; i < n - fraction; i += Vectorized<BFloat16>::size()) {
|
||||
svst1_bf16(ptrue, const_cast<bfloat16_t*>(reinterpret_cast<const bfloat16_t*>(dst)) + i, svldnt1_bf16(ptrue, const_cast<bfloat16_t*>(reinterpret_cast<const bfloat16_t*>(src)) + i));
|
||||
}
|
||||
#pragma unroll
|
||||
for (int64_t i = n - fraction; i < n; i += Vectorized<BFloat16>::size()) {
|
||||
svbool_t pg = svwhilelt_b16(i, n);
|
||||
svst1_bf16(pg, const_cast<bfloat16_t*>(reinterpret_cast<const bfloat16_t*>(dst)) + i, svldnt1_bf16(pg, const_cast<bfloat16_t*>(reinterpret_cast<const bfloat16_t*>(src)) + i));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<BFloat16> inline fmadd(const Vectorized<BFloat16>& a, const
|
||||
Vectorized<BFloat16>& b, const Vectorized<BFloat16>& c) {
|
||||
return a * b + c;
|
||||
}
|
||||
|
||||
#endif // defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16)
|
||||
|
||||
} // namespace CPU_CAPABILITY
|
||||
} // namespace vec
|
||||
} // namespace at
|
@ -13,6 +13,7 @@
|
||||
#include <ATen/cpu/vec/sve/vec_double.h>
|
||||
#include <ATen/cpu/vec/sve/vec_int.h>
|
||||
#include <ATen/cpu/vec/sve/vec_qint.h>
|
||||
#include <ATen/cpu/vec/sve/vec_bfloat16.h>
|
||||
#endif
|
||||
|
||||
|
||||
@ -30,33 +31,29 @@ inline namespace CPU_CAPABILITY {
|
||||
#if defined(CPU_CAPABILITY_SVE)
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
template<>
|
||||
inline Vectorized<float> cast<float, double>(const Vectorized<double>& src) {
|
||||
return svreinterpret_f32_f64(src);
|
||||
}
|
||||
|
||||
template<>
|
||||
inline Vectorized<double> cast<double, float>(const Vectorized<float>& src) {
|
||||
return svreinterpret_f64_f32(src);
|
||||
}
|
||||
|
||||
#define DEFINE_FLOAT_INT_CAST(int_t, int_bit, float_t, float_bit) \
|
||||
#define DEFINE_SVE_CAST(t1_t, t1_prefix, t2_t, t2_prefix) \
|
||||
template<> \
|
||||
inline Vectorized<int_t> cast<int_t, float_t>(const Vectorized<float_t>& src) { \
|
||||
return svreinterpret_s##int_bit##_f##float_bit(src); \
|
||||
inline Vectorized<t1_t> cast<t1_t, t2_t>(const Vectorized<t2_t>& src) { \
|
||||
return svreinterpret_##t1_prefix##_##t2_prefix(src); \
|
||||
} \
|
||||
template<> \
|
||||
inline Vectorized<float_t> cast<float_t, int_t>(const Vectorized<int_t>& src) { \
|
||||
return svreinterpret_f##float_bit##_s##int_bit(src); \
|
||||
inline Vectorized<t2_t> cast<t2_t, t1_t>(const Vectorized<t1_t>& src) { \
|
||||
return svreinterpret_##t2_prefix##_##t1_prefix(src); \
|
||||
}
|
||||
|
||||
DEFINE_FLOAT_INT_CAST(int64_t, 64, double, 64)
|
||||
DEFINE_FLOAT_INT_CAST(int32_t, 32, double, 64)
|
||||
DEFINE_FLOAT_INT_CAST(int16_t, 16, double, 64)
|
||||
DEFINE_FLOAT_INT_CAST(int64_t, 64, float, 32)
|
||||
DEFINE_FLOAT_INT_CAST(int32_t, 32, float, 32)
|
||||
DEFINE_FLOAT_INT_CAST(int16_t, 16, float, 32)
|
||||
DEFINE_SVE_CAST(int64_t, s64, double, f64)
|
||||
DEFINE_SVE_CAST(int32_t, s32, double, f64)
|
||||
DEFINE_SVE_CAST(int16_t, s16, double, f64)
|
||||
DEFINE_SVE_CAST(int64_t, s64, float, f32)
|
||||
DEFINE_SVE_CAST(int32_t, s32, float, f32)
|
||||
DEFINE_SVE_CAST(int16_t, s16, float, f32)
|
||||
DEFINE_SVE_CAST(float, f32, double, f64)
|
||||
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
DEFINE_SVE_CAST(int64_t, s64, c10::BFloat16, bf16)
|
||||
DEFINE_SVE_CAST(int32_t, s32, c10::BFloat16, bf16)
|
||||
DEFINE_SVE_CAST(int16_t, s16, c10::BFloat16, bf16)
|
||||
#endif // __ARM_FEATURE_BF16
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@ -143,6 +140,21 @@ inline interleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b
|
||||
Vectorized<float>(svzip2_f32(a, b)));
|
||||
}
|
||||
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
template <>
|
||||
std::pair<Vectorized<c10::BFloat16>, Vectorized<c10::BFloat16>>
|
||||
inline interleave2<c10::BFloat16>(const Vectorized<c10::BFloat16>& a, const Vectorized<c10::BFloat16>& b) {
|
||||
// inputs:
|
||||
// a = {a0, a1, a2, a3, a4, a5, a6, a7}
|
||||
// b = {b0, b1, b2, b3, b4, b5, b6, b7}
|
||||
// group cols crossing lanes:
|
||||
// return {a0, b0, a1, b1, a2, b2, a3, b3}
|
||||
// {a4, b4, a5, b5, a6, b6, a7, b7}
|
||||
return std::make_pair(Vectorized<c10::BFloat16>(svzip1_bf16(a, b)),
|
||||
Vectorized<c10::BFloat16>(svzip2_bf16(a, b)));
|
||||
}
|
||||
#endif // __ARM_FEATURE_BF16
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
template <>
|
||||
@ -171,6 +183,21 @@ inline deinterleave2<float>(const Vectorized<float>& a, const Vectorized<float>&
|
||||
Vectorized<float>(svuzp2_f32(a, b)));
|
||||
}
|
||||
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
template <>
|
||||
std::pair<Vectorized<c10::BFloat16>, Vectorized<c10::BFloat16>>
|
||||
inline deinterleave2<c10::BFloat16>(const Vectorized<c10::BFloat16>& a, const Vectorized<c10::BFloat16>& b) {
|
||||
// inputs:
|
||||
// a = {a0, b0, a1, b1, a2, b2, a3, b3}
|
||||
// b = {a4, b4, a5, b5, a6, b6, a7, b7}
|
||||
// swap lanes:
|
||||
// return {a0, a1, a2, a3, a4, a5, a6, a7}
|
||||
// {b0, b1, b2, b3, b4, b5, b6, b7}
|
||||
return std::make_pair(Vectorized<c10::BFloat16>(svuzp1_bf16((svbfloat16_t) a, (svbfloat16_t) b)),
|
||||
Vectorized<c10::BFloat16>(svuzp2_bf16((svbfloat16_t) a, (svbfloat16_t) b)));
|
||||
}
|
||||
#endif // __ARM_FEATURE_BF16
|
||||
|
||||
#endif // defined(CPU_CAPABILITY_SVE)
|
||||
|
||||
}}
|
||||
|
@ -9,13 +9,16 @@
|
||||
#if !(defined(__VSX__) || defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_ZVECTOR))
|
||||
#if defined(CPU_CAPABILITY_SVE256)
|
||||
#include <ATen/cpu/vec/sve/vec_common_sve.h>
|
||||
#endif
|
||||
#else
|
||||
#include <ATen/cpu/vec/vec256/vec256_float.h>
|
||||
#include <ATen/cpu/vec/vec256/vec256_bfloat16.h>
|
||||
#include <ATen/cpu/vec/vec256/vec256_half.h>
|
||||
#include <ATen/cpu/vec/vec256/vec256_double.h>
|
||||
#include <ATen/cpu/vec/vec256/vec256_int.h>
|
||||
#include <ATen/cpu/vec/vec256/vec256_qint.h>
|
||||
#endif
|
||||
#if !defined(CPU_CAPABILITY_SVE256) || !defined(__ARM_FEATURE_BF16)
|
||||
#include <ATen/cpu/vec/vec256/vec256_bfloat16.h>
|
||||
#endif
|
||||
#include <ATen/cpu/vec/vec256/vec256_half.h>
|
||||
#include <ATen/cpu/vec/vec256/vec256_complex_float.h>
|
||||
#include <ATen/cpu/vec/vec256/vec256_complex_double.h>
|
||||
#elif defined(__VSX__) || defined(CPU_CAPABILITY_VSX)
|
||||
|
@ -299,6 +299,46 @@ struct VecConvert<
|
||||
};
|
||||
#endif
|
||||
|
||||
#if defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16)
|
||||
|
||||
template <>
|
||||
struct VecConvert<float, 1, BFloat16, 1> {
|
||||
static inline VectorizedN<float, 1> apply(
|
||||
const VectorizedN<BFloat16, 1>& src) {
|
||||
VectorizedN<float, 1> res;
|
||||
// Load 16-bit unsigned integers from src into an SVE vector
|
||||
svuint16_t u16x4 = svld1_u16(svptrue_b16(), reinterpret_cast<const uint16_t*>(&src[0]));
|
||||
// Zero-extend to 32-bit SVE does not have direct vmovl_u16 equivalent.
|
||||
vls_uint32_t u32x4 = svreinterpret_u32_u16(svzip1_u16(svdup_n_u16(0), u16x4));
|
||||
// Reinterpret as float32
|
||||
vls_float32_t f32x4 = svreinterpret_f32_u32(u32x4);
|
||||
res[0] = Vectorized<float>(f32x4);
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct VecConvert<float, 2, BFloat16, 1> {
|
||||
static inline VectorizedN<float, 2> apply(
|
||||
const VectorizedN<BFloat16, 1>& src) {
|
||||
VectorizedN<float, 2> res;
|
||||
std::tie(res[0], res[1]) = convert_bfloat16_float(src[0]);
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct VecConvert<BFloat16, 1, float, 2> {
|
||||
static inline VectorizedN<BFloat16, 1> apply(
|
||||
const VectorizedN<float, 2>& src) {
|
||||
VectorizedN<BFloat16, 1> res;
|
||||
res[0] = convert_float_bfloat16(src[0], src[1]);
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
#endif // defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16)
|
||||
|
||||
template <typename src_t>
|
||||
struct VecConvert<
|
||||
float,
|
||||
|
@ -1323,7 +1323,7 @@ inline Vectorized<IntType> convert_to_int_of_same_size(
|
||||
static_assert(sizeof(T) == sizeof(IntType));
|
||||
static constexpr int size = Vectorized<T>::size();
|
||||
|
||||
std::array<T, size> src_arr;
|
||||
std::array<T, size> src_arr = {};
|
||||
src.store(static_cast<void*>(src_arr.data()));
|
||||
std::array<IntType, size> buffer;
|
||||
std::transform(
|
||||
|
@ -41,7 +41,11 @@ static CPUCapability compute_cpu_capability() {
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
if (strcmp(envar, "sve256") == 0) {
|
||||
if (sve_vl == 256) {
|
||||
return CPUCapability::SVE256;
|
||||
#ifdef HAVE_ARM_BF16_CPU_DEFINITION
|
||||
if (cpuinfo_has_arm_bf16()) {
|
||||
return CPUCapability::SVE256;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
TORCH_WARN("SVE256 capability not available on hardware. Falling back to DEFAULT");
|
||||
return CPUCapability::DEFAULT;
|
||||
@ -102,7 +106,10 @@ static CPUCapability compute_cpu_capability() {
|
||||
}
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
if (sve_vl == 256) { // Check for SVE256
|
||||
#ifdef HAVE_ARM_BF16_CPU_DEFINITION
|
||||
if (cpuinfo_has_arm_bf16())
|
||||
return CPUCapability::SVE256;
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
// Return the default CPU capability.
|
||||
|
@ -64,7 +64,7 @@ enum class CPUCapability {
|
||||
VSX = 1,
|
||||
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
|
||||
ZVECTOR = 1,
|
||||
#elif defined(HAVE_SVE_CPU_DEFINITION)
|
||||
#elif defined(HAVE_SVE256_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION)
|
||||
SVE256 = 1,
|
||||
#else
|
||||
AVX2 = 1,
|
||||
|
@ -274,6 +274,26 @@ inline Vectorized<scalar_t> div_floor_floating_vec(
|
||||
return floordiv;
|
||||
}
|
||||
|
||||
#if defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16)
|
||||
|
||||
// Since sve lacks sufficient bf16 intrinsics, do the calculations in f32 to
|
||||
// avoid rounding errors. This should not cause performance issues as
|
||||
// most of the used instructions would be cast to f32 vectors anyway.
|
||||
template<>
|
||||
inline Vectorized<c10::BFloat16> div_floor_floating_vec(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b) {
|
||||
auto [a1, a2] = convert_bfloat16_float(a);
|
||||
auto [b1, b2] = convert_bfloat16_float(b);
|
||||
|
||||
auto res1 = div_floor_floating_vec(a1, b1);
|
||||
auto res2 = div_floor_floating_vec(a2, b2);
|
||||
|
||||
return convert_float_bfloat16(res1, res2);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
void div_floor_kernel(TensorIteratorBase& iter) {
|
||||
const auto dtype = iter.common_dtype();
|
||||
if (dtype == kByte) {
|
||||
|
@ -237,7 +237,7 @@ std::pair<vec::Vectorized<float>, vec::Vectorized<float>> fmadd(
|
||||
|
||||
// Return a + b_low * c_low + b_high * c_high
|
||||
vec::Vectorized<float> fmadd(vec::Vectorized<float> a, vec::Vectorized<Half> b, vec::Vectorized<Half> c) {
|
||||
#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_FML)
|
||||
#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_FML) && !defined(__ARM_FEATURE_SVE)
|
||||
// NOTE: this instruction is an optional instruction in ARM v8.2 and
|
||||
// v8.3, but mandatory in v8.4 per
|
||||
// https://developer.arm.com/documentation/ddi0596/2021-03/SIMD-FP-Instructions/FMLAL--FMLAL2--vector---Floating-point-fused-Multiply-Add-Long-to-accumulator--vector--?lang=en
|
||||
|
@ -582,6 +582,19 @@ namespace {
|
||||
}
|
||||
}
|
||||
}
|
||||
#if defined(CPU_CAPABILITY_SVE) && defined(__ARM_FEATURE_BF16)
|
||||
// Exhaustively compares Vectorized<BFloat16>::isnan() with std::isnan for
// every 16-bit pattern. Each pattern is broadcast to all lanes, and every lane
// of the resulting mask must be non-zero exactly when the scalar is NaN.
TEST(NanBfloat16, IsNan) {
  using Vec = at::vec::SVE256::Vectorized<c10::BFloat16>;
  // The bound is inclusive: 0xFFFF (all exponent and mantissa bits set) is
  // itself a NaN encoding and was previously never exercised.
  for (unsigned int ii = 0; ii <= 0xFFFF; ++ii) {
    c10::BFloat16 val(ii, c10::BFloat16::from_bits());
    bool expected = std::isnan(val);
    CACHE_ALIGN c10::BFloat16 actual_vals[Vec::size()];
    const Vec broadcast(val);
    broadcast.isnan().store(actual_vals);
    for (int jj = 0; jj < Vec::size(); ++jj) {
      // Only zero vs. non-zero mask bits matter, so compare raw bits.
      EXPECT_EQ(expected, c10::bit_cast<uint16_t>(actual_vals[jj]) != 0) << "bf16 isnan failure for bit pattern " << std::hex << ii << std::dec;
    }
  }
}
|
||||
#endif
|
||||
TYPED_TEST(LGamma, LGamma) {
|
||||
using vec = TypeParam;
|
||||
using UVT = UvalueType<vec>;
|
||||
|
@ -390,17 +390,15 @@ if(INTERN_BUILD_ATEN_OPS)
|
||||
LIST(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} ${CXX_ZVECTOR_FLAGS}")
|
||||
endif(CXX_ZVECTOR_FOUND)
|
||||
|
||||
if(CXX_SVE_FOUND)
|
||||
if(CXX_SVE256_FOUND)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_SVE_CPU_DEFINITION -DHAVE_SVE256_CPU_DEFINITION")
|
||||
list(APPEND CPU_CAPABILITY_NAMES "SVE256")
|
||||
if("${CMAKE_C_COMPILER_ID}" MATCHES "Clang")
|
||||
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -O2 -march=armv8-a+sve -DCPU_CAPABILITY_SVE -msve-vector-bits=256")
|
||||
else()
|
||||
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -march=armv8-a+sve -DCPU_CAPABILITY_SVE -msve-vector-bits=256")
|
||||
endif()
|
||||
endif(CXX_SVE256_FOUND)
|
||||
endif(CXX_SVE_FOUND)
|
||||
if(CXX_SVE_FOUND AND CXX_SVE256_FOUND AND CXX_ARM_BF16_FOUND)
|
||||
list(APPEND CPU_CAPABILITY_NAMES "SVE256")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_SVE_CPU_DEFINITION -DHAVE_SVE256_CPU_DEFINITION -DHAVE_ARM_BF16_CPU_DEFINITION")
|
||||
if("${CMAKE_C_COMPILER_ID}" MATCHES "Clang")
|
||||
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -O2 -march=armv8-a+sve+bf16 -D__ARM_FEATURE_BF16 -DCPU_CAPABILITY_SVE -msve-vector-bits=256")
|
||||
else()
|
||||
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -march=armv8-a+sve+bf16 -D__ARM_FEATURE_BF16 -DCPU_CAPABILITY_SVE -msve-vector-bits=256")
|
||||
endif()
|
||||
endif(CXX_SVE_FOUND AND CXX_SVE256_FOUND)
|
||||
|
||||
list(LENGTH CPU_CAPABILITY_NAMES NUM_CPU_CAPABILITY_NAMES)
|
||||
math(EXPR NUM_CPU_CAPABILITY_NAMES "${NUM_CPU_CAPABILITY_NAMES}-1")
|
||||
|
@ -106,8 +106,18 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
}
|
||||
")
|
||||
|
||||
# Minimal probe program for bf16 support. It is passed as the `code` argument
# of the CHECK_COMPILES(CXX "ARM_BF16" ...) call further down, and it only
# compiles when <arm_neon.h> provides the bfloat16x8_t type and the
# vreinterpretq_bf16_f32 intrinsic under the flags being tested.
SET(ARM_BF16_CODE "
|
||||
#include <arm_neon.h>
|
||||
int main()
|
||||
{
|
||||
float32x4_t a = vdupq_n_f32(0);
|
||||
bfloat16x8_t b = vreinterpretq_bf16_f32(a);
|
||||
return 0;
|
||||
}
|
||||
")
|
||||
|
||||
# Macro to check for SVE instruction support
|
||||
MACRO(CHECK_SVE lang type flags)
|
||||
MACRO(CHECK_COMPILES lang type flags code)
|
||||
# Save the current state of required flags
|
||||
SET(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
|
||||
|
||||
@ -116,9 +126,9 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
|
||||
# Check if the source code compiles with the given flags for the specified language (C or C++)
|
||||
IF(lang STREQUAL "CXX")
|
||||
CHECK_CXX_SOURCE_COMPILES("${SVE_CODE}" ${lang}_HAS_${type})
|
||||
CHECK_CXX_SOURCE_COMPILES("${code}" ${lang}_HAS_${type})
|
||||
ELSE()
|
||||
CHECK_C_SOURCE_COMPILES("${SVE_CODE}" ${lang}_HAS_${type})
|
||||
CHECK_C_SOURCE_COMPILES("${code}" ${lang}_HAS_${type})
|
||||
ENDIF()
|
||||
|
||||
# If the compilation test is successful, set appropriate variables indicating support
|
||||
@ -142,7 +152,8 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
ENDMACRO()
|
||||
|
||||
# Check for SVE256 vector length
|
||||
CHECK_SVE(CXX "SVE256" "-march=armv8-a+sve -msve-vector-bits=256")
|
||||
CHECK_COMPILES(CXX "SVE256" "-march=armv8.2-a+sve -msve-vector-bits=256" "${SVE_CODE}")
|
||||
CHECK_COMPILES(CXX "ARM_BF16" "-march=armv8.2-a+sve+bf16 -msve-vector-bits=256" "${ARM_BF16_CODE}")
|
||||
|
||||
# If SVE256 support is not found, set CXX_SVE_FOUND to FALSE and notify the user
|
||||
if(NOT CXX_SVE256_FOUND)
|
||||
|
@ -175,8 +175,10 @@ class VecSVE256(VecISA):
|
||||
"CPU_CAPABILITY_SVE",
|
||||
"CPU_CAPABILITY_SVE256",
|
||||
"AT_BUILD_ARM_VEC256_WITH_SLEEF",
|
||||
"__ARM_FEATURE_BF16",
|
||||
]
|
||||
_arch_flags = "-march=armv8-a+sve -msve-vector-bits=256"
|
||||
_arch_flags = "-march=armv8-a+sve+bf16 -msve-vector-bits=256"
|
||||
|
||||
_dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
|
||||
|
||||
def __str__(self) -> str:
|
||||
@ -332,7 +334,13 @@ def x86_isa_checker() -> list[str]:
|
||||
|
||||
|
||||
invalid_vec_isa = InvalidVecISA()
|
||||
supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON(), VecSVE256()]
|
||||
supported_vec_isa_list = [
|
||||
VecAMX(),
|
||||
VecAVX512(),
|
||||
VecAVX2(),
|
||||
VecNEON(),
|
||||
VecSVE256(),
|
||||
]
|
||||
|
||||
|
||||
def get_isa_from_cpu_capability(
|
||||
@ -397,6 +405,7 @@ def valid_vec_isa_list() -> list[VecISA]:
|
||||
isa_list.append(VecSVE256())
|
||||
else:
|
||||
isa_list.append(VecNEON())
|
||||
|
||||
elif arch in ["x86_64", "AMD64"]:
|
||||
"""
|
||||
arch value is x86_64 on Linux, and the value is AMD64 on Windows.
|
||||
|
Reference in New Issue
Block a user