Extend vec backend with BF16 SVE intrinsics (#143666)

- Following the work in https://github.com/pytorch/pytorch/pull/119571, this adds BF16 SVE intrinsics to the `Vectorized` class, yielding a ~1.7x speedup on `silu` and `softmax`. (A scalar sketch of the underlying BF16/FP32 conversion trick follows this list.)
- Added BF16 detection in CMake
- Added a guard around the native NEON code to prevent compilation errors under SVE flags
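
For context (not part of the original description): most BF16 operations here are emulated by widening to FP32 and narrowing back, exploiting the fact that BF16 is the upper half of an IEEE FP32. A minimal scalar sketch of that round trip, with hypothetical helper names:

```cpp
#include <cstdint>
#include <cstring>

// BF16 -> FP32: shift the 16 bits into the top half of a float.
float bf16_to_f32(uint16_t b) {
  uint32_t u = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

// FP32 -> BF16 with round-to-nearest-even (NaN is not special-cased here).
uint16_t f32_to_bf16(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  u += 0x7FFF + ((u >> 16) & 1);  // round to nearest, ties to even
  return static_cast<uint16_t>(u >> 16);
}
```

The vectorized code below performs the same widening with `svzip1_bf16`/`svzip2_bf16` against a zero vector, and the narrowing with `svcvt_bf16_f32_z`.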

@aditew01 @maajidkhann please have a look

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143666
Approved by: https://github.com/malfet, https://github.com/aditew01, https://github.com/nikhil-arm

Co-authored-by: Aditya Tewari <aditya.tewari@arm.com>
Ryo Suzuki
2025-04-28 18:25:41 +00:00
committed by PyTorch MergeBot
parent 0c52ee1b35
commit fcbbb03d48
15 changed files with 703 additions and 47 deletions

View File

@@ -105,7 +105,7 @@ std::string get_cpu_capability() {
return "DEFAULT";
case native::CPUCapability::ZVECTOR:
return "Z VECTOR";
#elif defined(HAVE_SVE_CPU_DEFINITION)
#elif defined(HAVE_SVE256_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION)
case native::CPUCapability::DEFAULT:
return "DEFAULT";
case native::CPUCapability::SVE256:

View File

@@ -17,6 +17,7 @@ typedef svuint16_t vls_uint16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8
typedef svuint32_t vls_uint32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svuint64_t vls_uint64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svfloat16_t vls_float16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svbfloat16_t vls_bfloat16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svfloat32_t vls_float32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svfloat64_t vls_float64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
@@ -41,6 +42,7 @@ typedef svfloat64_t vls_float64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDT
#define ONE_U32 svdup_n_u32(1)
#define ONE_U64 svdup_n_u64(1)
#define ONE_F16 svdup_n_f16(1.f)
#define ONE_BF16 svdup_n_bf16(1.f)
#define ONE_F32 svdup_n_f32(1.f)
#define ONE_F64 svdup_n_f64(1.0)
#define ALL_S8_TRUE_MASK svdup_n_s8(0xff)
@@ -55,6 +57,8 @@ typedef svfloat64_t vls_float64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDT
#define ALL_U8_FALSE_MASK svdup_n_u8(0x00)
#define ALL_F16_TRUE_MASK svreinterpret_f16_s16(ALL_S16_TRUE_MASK)
#define ALL_F16_FALSE_MASK svreinterpret_f16_s16(ALL_S16_FALSE_MASK)
#define ALL_BF16_TRUE_MASK svreinterpret_bf16_s16(ALL_S16_TRUE_MASK)
#define ALL_BF16_FALSE_MASK svreinterpret_bf16_s16(ALL_S16_FALSE_MASK)
#define ALL_F32_TRUE_MASK svreinterpret_f32_s32(ALL_S32_TRUE_MASK)
#define ALL_F32_FALSE_MASK svreinterpret_f32_s32(ALL_S32_FALSE_MASK)
#define ALL_F64_TRUE_MASK svreinterpret_f64_s64(ALL_S64_TRUE_MASK)

View File

@@ -0,0 +1,524 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/sve/sve_helper.h>
#include <ATen/cpu/vec/sve/vec_common_sve.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/sve/vec_float.h>
#include <cmath>
namespace at {
namespace vec {
// Note [CPU_CAPABILITY namespace]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This header, and all of its subheaders, will be compiled with
// different architecture flags for each supported set of vector
// intrinsics. So we need to make sure they aren't inadvertently
// linked together. We do this by declaring objects in an `inline
// namespace` which changes the name mangling, but can still be
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16)
template <>
class Vectorized<BFloat16> {
private:
vls_bfloat16_t values;
public:
using value_type = BFloat16;
using size_type = int;
static constexpr size_type size() {
return VECTOR_WIDTH / sizeof(BFloat16);
}
Vectorized() {}
Vectorized(svbfloat16_t v) : values(v) {}
Vectorized(int val);
Vectorized(BFloat16 val);
template <
typename... Args,
typename = std::enable_if_t<(sizeof...(Args) == size())>>
Vectorized(Args... vals) {
__at_align__ BFloat16 buffer[size()] = {vals...};
values = svld1_bf16(ptrue, reinterpret_cast<const bfloat16_t*>(buffer));
}
operator svbfloat16_t() const {
return values;
}
static Vectorized<BFloat16> blendv(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b, const Vectorized<BFloat16>& mask_) {
svbool_t mask = svcmpeq_s16(ptrue, svreinterpret_s16_bf16(mask_), ALL_S16_TRUE_MASK);
return svsel_bf16(mask, b, a);
}
template<typename step_t>
static Vectorized<BFloat16> arange(BFloat16 base = 0.f, step_t step = static_cast<step_t>(1)) {
__at_align__ BFloat16 buffer[size()];
for (int64_t i = 0; i < size(); i++) {
buffer[i] = base + i * step;
}
return svld1_bf16(ptrue, reinterpret_cast<bfloat16_t *>(buffer));
}
static Vectorized<BFloat16> set(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b, int64_t count = size()) {
if (count == 0) {
return a;
} else if (count < size()) {
return svsel_bf16(svwhilelt_b16(0ull, count), b, a);
}
return b;
}
static Vectorized<BFloat16> loadu(const void* ptr, int64_t count = size()) {
if (count == size())
return svld1_bf16(ptrue, reinterpret_cast<const bfloat16_t*>(ptr));
svbool_t pg = svwhilelt_b16(0ull, count);
return svld1_bf16(pg, reinterpret_cast<const bfloat16_t*>(ptr));
}
void store(void* ptr, int64_t count = size()) const {
__at_align__ bfloat16_t tmp[size()];
std::memset(tmp, 0, sizeof(tmp));
if (count == size()) {
svst1_bf16(ptrue, reinterpret_cast<bfloat16_t*>(tmp), values);
} else {
svbool_t pg = svwhilelt_b16(0ull, count);
svst1_bf16(pg, reinterpret_cast<bfloat16_t*>(tmp), values);
}
std::memcpy(
reinterpret_cast<bfloat16_t*>(ptr),
reinterpret_cast<const bfloat16_t*>(tmp),
count * sizeof(bfloat16_t));
}
const BFloat16& operator[](int idx) const = delete;
BFloat16& operator[](int idx) = delete;
int64_t zero_mask() const {
// returns an integer mask where all zero elements are translated to
// 1-bit and others are translated to 0-bit
int64_t mask = 0;
__at_align__ int16_t mask_array[size()];
svbool_t svbool_mask = svcmpeq_f16(ptrue, svreinterpret_f16_bf16(values), ZERO_F16);
svst1_s16(ptrue, mask_array, svsel_s16(svbool_mask, ALL_S16_TRUE_MASK, ALL_S16_FALSE_MASK));
for (int64_t i = 0; i < size(); ++i) {
if (mask_array[i]) mask |= (1ull << i);
}
return mask;
}
Vectorized<BFloat16> isnan() const;
bool has_inf_nan() const;
Vectorized<BFloat16> map(BFloat16 (*f)(BFloat16)) const {
__at_align__ BFloat16 tmp[size()];
store(tmp);
for (int64_t i = 0; i < size(); ++i) {
tmp[i] = f(tmp[i]);
}
return loadu(tmp);
}
Vectorized<BFloat16> abs() const {
auto mask = svdup_n_u16(0x7FFF);
auto vals = svreinterpret_u16_bf16(values);
vals = svand_u16_x(ptrue, vals, mask);
return svreinterpret_bf16_u16(vals);
}
Vectorized<BFloat16> angle() const;
Vectorized<BFloat16> real() const {
return values;
}
Vectorized<BFloat16> imag() const {
return Vectorized<BFloat16>(0.f);
}
Vectorized<BFloat16> conj() const {
return values;
}
Vectorized<BFloat16> acos() const;
Vectorized<BFloat16> acosh() const;
Vectorized<BFloat16> asin() const;
Vectorized<BFloat16> atan() const;
Vectorized<BFloat16> atanh() const;
Vectorized<BFloat16> atan2(const Vectorized<BFloat16> &b) const;
Vectorized<BFloat16> copysign(const Vectorized<BFloat16> &sign) const;
Vectorized<BFloat16> erf() const;
Vectorized<BFloat16> erfc() const;
Vectorized<BFloat16> erfinv() const;
Vectorized<BFloat16> exp() const;
Vectorized<BFloat16> exp2() const;
Vectorized<BFloat16> expm1() const;
Vectorized<BFloat16> exp_u20() const {
return exp();
}
Vectorized<BFloat16> fmod(const Vectorized<BFloat16>& q) const;
Vectorized<BFloat16> hypot(const Vectorized<BFloat16> &b) const;
Vectorized<BFloat16> i0() const;
Vectorized<BFloat16> i0e() const;
Vectorized<BFloat16> digamma() const;
Vectorized<BFloat16> igamma(const Vectorized<BFloat16> &x) const;
Vectorized<BFloat16> igammac(const Vectorized<BFloat16> &x) const;
Vectorized<BFloat16> nextafter(const Vectorized<BFloat16> &b) const;
Vectorized<BFloat16> log() const;
Vectorized<BFloat16> log2() const;
Vectorized<BFloat16> log10() const;
Vectorized<BFloat16> log1p() const;
Vectorized<BFloat16> frac() const;
Vectorized<BFloat16> sin() const;
Vectorized<BFloat16> sinh() const;
Vectorized<BFloat16> cos() const;
Vectorized<BFloat16> cosh() const;
Vectorized<BFloat16> ceil() const;
Vectorized<BFloat16> floor() const;
Vectorized<BFloat16> neg() const {
auto mask = svdup_n_u16(0x8000);
auto vals = svreinterpret_u16_bf16(values);
vals = sveor_u16_x(ptrue, vals, mask);
return svreinterpret_bf16_u16(vals);
};
Vectorized<BFloat16> round() const;
Vectorized<BFloat16> tan() const;
Vectorized<BFloat16> tanh() const;
Vectorized<BFloat16> trunc() const;
Vectorized<BFloat16> lgamma() const;
Vectorized<BFloat16> sqrt() const;
Vectorized<BFloat16> reciprocal() const;
Vectorized<BFloat16> rsqrt() const;
Vectorized<BFloat16> pow(const Vectorized<BFloat16> &b) const;
// Comparison using the _CMP_**_OQ predicate.
// `O`: get false if an operand is NaN
// `Q`: do not raise if an operand is NaN
Vectorized<BFloat16> operator==(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> operator!=(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> operator<(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> operator<=(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> operator>(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> operator>=(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> eq(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> ne(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> gt(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> ge(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> lt(const Vectorized<BFloat16>& other) const;
Vectorized<BFloat16> le(const Vectorized<BFloat16>& other) const;
};
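// Note: a BF16 value occupies the upper 16 bits of the corresponding FP32, so
// the widening below is a lossless bit shuffle: zipping each BF16 lane with a
// zero lane places its bits in the high half of a 32-bit lane.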
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(
const Vectorized<c10::BFloat16>& a) {
static_assert(
Vectorized<c10::BFloat16>::size() == 2 * Vectorized<float>::size());
auto zero = svreinterpret_bf16_f32(svdup_n_f32(0.0f));
auto bf16_vec1 = svzip1_bf16(zero, a);
auto bf16_vec2 = svzip2_bf16(zero, a);
auto x1 = svreinterpret_f32_bf16(bf16_vec1);
auto x2 = svreinterpret_f32_bf16(bf16_vec2);
return {Vectorized<float>(x1), Vectorized<float>(x2)};
}
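// Narrowing is lossy, so it goes through svcvt (round-to-nearest-even); the
// zero-predicated form leaves results in the even 16-bit lanes, and svuzp1
// packs the even lanes of both halves into one BF16 vector.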
inline Vectorized<c10::BFloat16> convert_float_bfloat16(
const Vectorized<float>& a,
const Vectorized<float>& b) {
static_assert(
Vectorized<c10::BFloat16>::size() == 2 * Vectorized<float>::size());
svbfloat16_t x1 = svcvt_bf16_f32_z(ptrue, a);
svbfloat16_t x2 = svcvt_bf16_f32_z(ptrue, b);
return Vectorized<c10::BFloat16>(svuzp1_bf16(x1, x2));
}
inline void load_fp32_from_bf16(const BFloat16* data, Vectorized<float>& out) {
__at_align__ float values[Vectorized<float>::size()];
for (const auto k : c10::irange(Vectorized<float>::size())) {
values[k] = data[k];
}
out = Vectorized<float>::loadu(values);
}
inline void load_fp32_from_bf16(
const BFloat16* data,
Vectorized<float>& out1,
Vectorized<float>& out2) {
Vectorized<BFloat16> bf16_vec = Vectorized<BFloat16>::loadu(data);
auto floats = convert_bfloat16_float(bf16_vec);
out1 = std::get<0>(floats);
out2 = std::get<1>(floats);
}
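// All BF16 arithmetic below follows the same pattern: widen both operands to
// two FP32 vectors, apply the FP32 operation, and narrow the results back.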
template <typename Op>
Vectorized<c10::BFloat16> binary_operator_via_float(
Op op,
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
const auto [a_float_low, a_float_high] = convert_bfloat16_float(a);
const auto [b_float_low, b_float_high] = convert_bfloat16_float(b);
return convert_float_bfloat16(
op(a_float_low, b_float_low), op(a_float_high, b_float_high));
}
template <>
Vectorized<c10::BFloat16> inline operator+(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b);
}
template <>
Vectorized<c10::BFloat16> inline operator-(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b);
}
template <>
Vectorized<c10::BFloat16> inline operator*(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b);
}
template <>
Vectorized<c10::BFloat16> inline operator/(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b);
}
inline Vectorized<BFloat16>::Vectorized(int val) {
auto vals_f = svdup_n_f32(val);
values = convert_float_bfloat16(vals_f, vals_f);
}
inline Vectorized<BFloat16>::Vectorized(BFloat16 val) {
auto vals_f = svdup_n_f32((float) val);
values = convert_float_bfloat16(vals_f, vals_f);
}
bool inline Vectorized<c10::BFloat16>::has_inf_nan() const {
auto [v1, v2] = convert_bfloat16_float(values);
return v1.has_inf_nan() || v2.has_inf_nan();
}
// frac. Implement this here so we can use subtraction
Vectorized<BFloat16> inline Vectorized<BFloat16>::frac() const {
return *this - this->trunc();
}
#define DEFINE_BF16_FUNC_VIA_FLOAT(func_name) \
Vectorized<BFloat16> inline Vectorized<BFloat16>::func_name() const { \
auto [v1, v2] = convert_bfloat16_float(*this); \
v1 = v1.func_name(); \
v2 = v2.func_name(); \
return convert_float_bfloat16(v1, v2); \
}
#define DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(func_name) \
Vectorized<BFloat16> inline Vectorized<BFloat16>::func_name(const Vectorized<BFloat16> &a) const { \
auto [v1, v2] = convert_bfloat16_float(*this); \
auto [v3, v4] = convert_bfloat16_float(a); \
v1 = v1.func_name(v3); \
v2 = v2.func_name(v4); \
return convert_float_bfloat16(v1, v2); \
}
DEFINE_BF16_FUNC_VIA_FLOAT(isnan);
DEFINE_BF16_FUNC_VIA_FLOAT(angle);
DEFINE_BF16_FUNC_VIA_FLOAT(acos);
DEFINE_BF16_FUNC_VIA_FLOAT(acosh);
DEFINE_BF16_FUNC_VIA_FLOAT(asin);
DEFINE_BF16_FUNC_VIA_FLOAT(atan);
DEFINE_BF16_FUNC_VIA_FLOAT(atanh);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign);
DEFINE_BF16_FUNC_VIA_FLOAT(erf);
DEFINE_BF16_FUNC_VIA_FLOAT(erfc);
DEFINE_BF16_FUNC_VIA_FLOAT(exp);
DEFINE_BF16_FUNC_VIA_FLOAT(exp2);
DEFINE_BF16_FUNC_VIA_FLOAT(expm1);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot);
DEFINE_BF16_FUNC_VIA_FLOAT(i0);
DEFINE_BF16_FUNC_VIA_FLOAT(i0e);
DEFINE_BF16_FUNC_VIA_FLOAT(digamma);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter);
DEFINE_BF16_FUNC_VIA_FLOAT(log);
DEFINE_BF16_FUNC_VIA_FLOAT(log2);
DEFINE_BF16_FUNC_VIA_FLOAT(log10);
DEFINE_BF16_FUNC_VIA_FLOAT(log1p);
DEFINE_BF16_FUNC_VIA_FLOAT(sin);
DEFINE_BF16_FUNC_VIA_FLOAT(sinh);
DEFINE_BF16_FUNC_VIA_FLOAT(cos);
DEFINE_BF16_FUNC_VIA_FLOAT(cosh);
DEFINE_BF16_FUNC_VIA_FLOAT(ceil);
DEFINE_BF16_FUNC_VIA_FLOAT(floor);
DEFINE_BF16_FUNC_VIA_FLOAT(round);
DEFINE_BF16_FUNC_VIA_FLOAT(tan);
DEFINE_BF16_FUNC_VIA_FLOAT(tanh);
DEFINE_BF16_FUNC_VIA_FLOAT(trunc);
DEFINE_BF16_FUNC_VIA_FLOAT(lgamma);
DEFINE_BF16_FUNC_VIA_FLOAT(sqrt);
DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal);
DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow);
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator==(const Vectorized<BFloat16>& other) const {
auto [f1, f2] = convert_bfloat16_float(values);
auto [f3, f4] = convert_bfloat16_float(other);
svbool_t mask1 = svcmpeq_f32(ptrue, f1, f3);
svbool_t mask2 = svcmpeq_f32(ptrue, f2, f4);
auto res1 = svsel_f32(mask1, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
auto res2 = svsel_f32(mask2, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
auto bf16_1 = svreinterpret_bf16_f32(res1);
auto bf16_2 = svreinterpret_bf16_f32(res2);
return svuzp1_bf16(bf16_1, bf16_2);
}
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator!=(const Vectorized<BFloat16>& other) const {
auto [f1, f2] = convert_bfloat16_float(values);
auto [f3, f4] = convert_bfloat16_float(other);
svbool_t mask1 = svcmpne_f32(ptrue, f1, f3);
svbool_t mask2 = svcmpne_f32(ptrue, f2, f4);
auto res1 = svsel_f32(mask1, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
auto res2 = svsel_f32(mask2, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);
auto bf16_1 = svreinterpret_bf16_f32(res1);
auto bf16_2 = svreinterpret_bf16_f32(res2);
return svuzp1_bf16(bf16_1, bf16_2);
}
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator>(const Vectorized<BFloat16>& other) const {
auto [v1, v2] = convert_bfloat16_float(*this);
auto [v3, v4] = convert_bfloat16_float(other);
return convert_float_bfloat16(v1 > v3, v2 > v4);
}
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator>=(const Vectorized<BFloat16>& other) const {
auto [v1, v2] = convert_bfloat16_float(*this);
auto [v3, v4] = convert_bfloat16_float(other);
return convert_float_bfloat16(v1 >= v3, v2 >= v4);
}
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator<(const Vectorized<BFloat16>& other) const {
auto [v1, v2] = convert_bfloat16_float(*this);
auto [v3, v4] = convert_bfloat16_float(other);
return convert_float_bfloat16(v1 < v3, v2 < v4);
}
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator<=(const Vectorized<BFloat16>& other) const {
auto [v1, v2] = convert_bfloat16_float(*this);
auto [v3, v4] = convert_bfloat16_float(other);
return convert_float_bfloat16(v1 <= v3, v2 <= v4);
}
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<BFloat16> inline maximum(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
return binary_operator_via_float(static_cast<Vectorized<float>(*)(const Vectorized<float>&, const Vectorized<float>&)>(&maximum), a, b);
}
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<BFloat16> inline minimum(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
return binary_operator_via_float(static_cast<Vectorized<float>(*)(const Vectorized<float>&, const Vectorized<float>&)>(&minimum), a, b);
}
template <>
Vectorized<BFloat16> inline clamp_max(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& max) {
return binary_operator_via_float(static_cast<Vectorized<float>(*)(const Vectorized<float>&, const Vectorized<float>&)>(&clamp_max), a, max);
}
template <>
Vectorized<BFloat16> inline clamp_min(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& min) {
return binary_operator_via_float(static_cast<Vectorized<float>(*)(const Vectorized<float>&, const Vectorized<float>&)>(&clamp_min), a, min);
}
template <>
Vectorized<BFloat16> inline clamp(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& min, const Vectorized<BFloat16>& max) {
return clamp_min(clamp_max(a, max), min);
}
template <>
Vectorized<BFloat16> inline operator&(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
return svreinterpret_bf16_u16(svand_u16_x(ptrue, svreinterpret_u16_bf16(a), svreinterpret_u16_bf16(b)));
}
template <>
Vectorized<BFloat16> inline operator|(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
return svreinterpret_bf16_u16(svorr_u16_x(ptrue, svreinterpret_u16_bf16(a), svreinterpret_u16_bf16(b)));
}
template <>
Vectorized<BFloat16> inline operator^(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
return svreinterpret_bf16_u16(sveor_u16_x(ptrue, svreinterpret_u16_bf16(a), svreinterpret_u16_bf16(b)));
}
Vectorized<BFloat16> inline Vectorized<BFloat16>::eq(const Vectorized<BFloat16>& other) const {
return (*this == other) & Vectorized<BFloat16>(1.0f);
}
Vectorized<BFloat16> inline Vectorized<BFloat16>::ne(const Vectorized<BFloat16>& other) const {
return (*this != other) & Vectorized<BFloat16>(1.0f);
}
Vectorized<BFloat16> inline Vectorized<BFloat16>::gt(const Vectorized<BFloat16>& other) const {
return (*this > other) & Vectorized<BFloat16>(1.0f);
}
Vectorized<BFloat16> inline Vectorized<BFloat16>::ge(const Vectorized<BFloat16>& other) const {
return (*this >= other) & Vectorized<BFloat16>(1.0f);
}
Vectorized<BFloat16> inline Vectorized<BFloat16>::lt(const Vectorized<BFloat16>& other) const {
return (*this < other) & Vectorized<BFloat16>(1.0f);
}
Vectorized<BFloat16> inline Vectorized<BFloat16>::le(const Vectorized<BFloat16>& other) const {
return (*this <= other) & Vectorized<BFloat16>(1.0f);
}
template <>
inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) {
const int64_t fraction = n % Vectorized<BFloat16>::size();
#pragma unroll
for (int64_t i = 0; i < n - fraction; i += Vectorized<BFloat16>::size()) {
svst1_bf16(ptrue, const_cast<bfloat16_t*>(reinterpret_cast<const bfloat16_t*>(dst)) + i, svldnt1_bf16(ptrue, const_cast<bfloat16_t*>(reinterpret_cast<const bfloat16_t*>(src)) + i));
}
#pragma unroll
for (int64_t i = n - fraction; i < n; i += Vectorized<BFloat16>::size()) {
svbool_t pg = svwhilelt_b16(i, n);
svst1_bf16(pg, const_cast<bfloat16_t*>(reinterpret_cast<const bfloat16_t*>(dst)) + i, svldnt1_bf16(pg, const_cast<bfloat16_t*>(reinterpret_cast<const bfloat16_t*>(src)) + i));
}
}
template <>
Vectorized<BFloat16> inline fmadd(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b, const Vectorized<BFloat16>& c) {
return a * b + c;
}
#endif // defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16)
} // namespace CPU_CAPABILITY
} // namespace vec
} // namespace at
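
A hedged usage sketch (not part of the diff): the specialization plugs into the generic `at::vec` API, so callers are unchanged. `scale_bf16` is a hypothetical helper; `loadu`/`store` with a count and `operator*` are the members declared above.

```cpp
#include <ATen/cpu/vec/vec.h>
#include <cstdint>

// Hypothetical example: scale a BF16 buffer in place. On an SVE256+BF16
// build, Vectorized<c10::BFloat16> holds 16 lanes and operator* widens to
// FP32 internally (see binary_operator_via_float above).
void scale_bf16(c10::BFloat16* data, int64_t n, float s) {
  using Vec = at::vec::Vectorized<c10::BFloat16>;
  const Vec vscale(c10::BFloat16(s));
  int64_t i = 0;
  for (; i + Vec::size() <= n; i += Vec::size()) {
    (Vec::loadu(data + i) * vscale).store(data + i);
  }
  if (i < n) {  // predicated tail via the count overloads
    (Vec::loadu(data + i, n - i) * vscale).store(data + i, n - i);
  }
}
```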

View File

@@ -13,6 +13,7 @@
#include <ATen/cpu/vec/sve/vec_double.h>
#include <ATen/cpu/vec/sve/vec_int.h>
#include <ATen/cpu/vec/sve/vec_qint.h>
#include <ATen/cpu/vec/sve/vec_bfloat16.h>
#endif
@@ -30,33 +31,29 @@ inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_SVE)
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template<>
inline Vectorized<float> cast<float, double>(const Vectorized<double>& src) {
return svreinterpret_f32_f64(src);
}
template<>
inline Vectorized<double> cast<double, float>(const Vectorized<float>& src) {
return svreinterpret_f64_f32(src);
}
#define DEFINE_FLOAT_INT_CAST(int_t, int_bit, float_t, float_bit) \
#define DEFINE_SVE_CAST(t1_t, t1_prefix, t2_t, t2_prefix) \
template<> \
inline Vectorized<int_t> cast<int_t, float_t>(const Vectorized<float_t>& src) { \
return svreinterpret_s##int_bit##_f##float_bit(src); \
inline Vectorized<t1_t> cast<t1_t, t2_t>(const Vectorized<t2_t>& src) { \
return svreinterpret_##t1_prefix##_##t2_prefix(src); \
} \
template<> \
inline Vectorized<float_t> cast<float_t, int_t>(const Vectorized<int_t>& src) { \
return svreinterpret_f##float_bit##_s##int_bit(src); \
inline Vectorized<t2_t> cast<t2_t, t1_t>(const Vectorized<t1_t>& src) { \
return svreinterpret_##t2_prefix##_##t1_prefix(src); \
}
DEFINE_FLOAT_INT_CAST(int64_t, 64, double, 64)
DEFINE_FLOAT_INT_CAST(int32_t, 32, double, 64)
DEFINE_FLOAT_INT_CAST(int16_t, 16, double, 64)
DEFINE_FLOAT_INT_CAST(int64_t, 64, float, 32)
DEFINE_FLOAT_INT_CAST(int32_t, 32, float, 32)
DEFINE_FLOAT_INT_CAST(int16_t, 16, float, 32)
DEFINE_SVE_CAST(int64_t, s64, double, f64)
DEFINE_SVE_CAST(int32_t, s32, double, f64)
DEFINE_SVE_CAST(int16_t, s16, double, f64)
DEFINE_SVE_CAST(int64_t, s64, float, f32)
DEFINE_SVE_CAST(int32_t, s32, float, f32)
DEFINE_SVE_CAST(int16_t, s16, float, f32)
DEFINE_SVE_CAST(float, f32, double, f64)
#ifdef __ARM_FEATURE_BF16
DEFINE_SVE_CAST(int64_t, s64, c10::BFloat16, bf16)
DEFINE_SVE_CAST(int32_t, s32, c10::BFloat16, bf16)
DEFINE_SVE_CAST(int16_t, s16, c10::BFloat16, bf16)
#endif // __ARM_FEATURE_BF16
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -143,6 +140,21 @@ inline interleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b
Vectorized<float>(svzip2_f32(a, b)));
}
#ifdef __ARM_FEATURE_BF16
template <>
std::pair<Vectorized<c10::BFloat16>, Vectorized<c10::BFloat16>>
inline interleave2<c10::BFloat16>(const Vectorized<c10::BFloat16>& a, const Vectorized<c10::BFloat16>& b) {
// inputs:
// a = {a0, a1, a2, a3, a4, a5, a6, a7}
// b = {b0, b1, b2, b3, b4, b5, b6, b7}
// group cols crossing lanes:
// return {a0, b0, a1, b1, a2, b2, a3, b3}
// {a4, b4, a5, b5, a6, b6, a7, b7}
return std::make_pair(Vectorized<c10::BFloat16>(svzip1_bf16(a, b)),
Vectorized<c10::BFloat16>(svzip2_bf16(a, b)));
}
#endif // __ARM_FEATURE_BF16
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <>
@@ -171,6 +183,21 @@ inline deinterleave2<float>(const Vectorized<float>& a, const Vectorized<float>&
Vectorized<float>(svuzp2_f32(a, b)));
}
#ifdef __ARM_FEATURE_BF16
template <>
std::pair<Vectorized<c10::BFloat16>, Vectorized<c10::BFloat16>>
inline deinterleave2<c10::BFloat16>(const Vectorized<c10::BFloat16>& a, const Vectorized<c10::BFloat16>& b) {
// inputs:
// a = {a0, b0, a1, b1, a2, b2, a3, b3}
// b = {a4, b4, a5, b5, a6, b6, a7, b7}
// swap lanes:
// return {a0, a1, a2, a3, a4, a5, a6, a7}
// {b0, b1, b2, b3, b4, b5, b6, b7}
return std::make_pair(Vectorized<c10::BFloat16>(svuzp1_bf16((svbfloat16_t) a, (svbfloat16_t) b)),
Vectorized<c10::BFloat16>(svuzp2_bf16((svbfloat16_t) a, (svbfloat16_t) b)));
}
#endif // __ARM_FEATURE_BF16
#endif // defined(CPU_CAPABILITY_SVE)
}}

View File

@@ -9,13 +9,16 @@
#if !(defined(__VSX__) || defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_ZVECTOR))
#if defined(CPU_CAPABILITY_SVE256)
#include <ATen/cpu/vec/sve/vec_common_sve.h>
#endif
#else
#include <ATen/cpu/vec/vec256/vec256_float.h>
#include <ATen/cpu/vec/vec256/vec256_bfloat16.h>
#include <ATen/cpu/vec/vec256/vec256_half.h>
#include <ATen/cpu/vec/vec256/vec256_double.h>
#include <ATen/cpu/vec/vec256/vec256_int.h>
#include <ATen/cpu/vec/vec256/vec256_qint.h>
#endif
#if !defined(CPU_CAPABILITY_SVE256) || !defined(__ARM_FEATURE_BF16)
#include <ATen/cpu/vec/vec256/vec256_bfloat16.h>
#endif
#include <ATen/cpu/vec/vec256/vec256_half.h>
#include <ATen/cpu/vec/vec256/vec256_complex_float.h>
#include <ATen/cpu/vec/vec256/vec256_complex_double.h>
#elif defined(__VSX__) || defined(CPU_CAPABILITY_VSX)

View File

@@ -299,6 +299,46 @@ struct VecConvert<
};
#endif
#if defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16)
template <>
struct VecConvert<float, 1, BFloat16, 1> {
static inline VectorizedN<float, 1> apply(
const VectorizedN<BFloat16, 1>& src) {
VectorizedN<float, 1> res;
// Load 16-bit unsigned integers from src into an SVE vector
svuint16_t u16x4 = svld1_u16(svptrue_b16(), reinterpret_cast<const uint16_t*>(&src[0]));
// Widen to 32 bits by zipping with a zero vector, placing each BF16 in the
// high half of a 32-bit lane; SVE has no direct vmovl_u16 equivalent.
vls_uint32_t u32x4 = svreinterpret_u32_u16(svzip1_u16(svdup_n_u16(0), u16x4));
// Reinterpret as float32
vls_float32_t f32x4 = svreinterpret_f32_u32(u32x4);
res[0] = Vectorized<float>(f32x4);
return res;
}
};
template <>
struct VecConvert<float, 2, BFloat16, 1> {
static inline VectorizedN<float, 2> apply(
const VectorizedN<BFloat16, 1>& src) {
VectorizedN<float, 2> res;
std::tie(res[0], res[1]) = convert_bfloat16_float(src[0]);
return res;
}
};
template <>
struct VecConvert<BFloat16, 1, float, 2> {
static inline VectorizedN<BFloat16, 1> apply(
const VectorizedN<float, 2>& src) {
VectorizedN<BFloat16, 1> res;
res[0] = convert_float_bfloat16(src[0], src[1]);
return res;
}
};
#endif // defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16)
template <typename src_t>
struct VecConvert<
float,

View File

@@ -1323,7 +1323,7 @@ inline Vectorized<IntType> convert_to_int_of_same_size(
static_assert(sizeof(T) == sizeof(IntType));
static constexpr int size = Vectorized<T>::size();
std::array<T, size> src_arr;
std::array<T, size> src_arr = {};
src.store(static_cast<void*>(src_arr.data()));
std::array<IntType, size> buffer;
std::transform(

View File

@@ -41,7 +41,11 @@ static CPUCapability compute_cpu_capability() {
#ifdef HAVE_SVE256_CPU_DEFINITION
if (strcmp(envar, "sve256") == 0) {
if (sve_vl == 256) {
return CPUCapability::SVE256;
#ifdef HAVE_ARM_BF16_CPU_DEFINITION
if (cpuinfo_has_arm_bf16()) {
return CPUCapability::SVE256;
}
#endif
}
TORCH_WARN("SVE256 capability not available on hardware. Falling back to DEFAULT");
return CPUCapability::DEFAULT;
@@ -102,7 +106,10 @@ static CPUCapability compute_cpu_capability() {
}
#ifdef HAVE_SVE256_CPU_DEFINITION
if (sve_vl == 256) { // Check for SVE256
#ifdef HAVE_ARM_BF16_CPU_DEFINITION
if (cpuinfo_has_arm_bf16())
return CPUCapability::SVE256;
#endif
}
#endif
// Return the default CPU capability.

View File

@@ -64,7 +64,7 @@ enum class CPUCapability {
VSX = 1,
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
ZVECTOR = 1,
#elif defined(HAVE_SVE_CPU_DEFINITION)
#elif defined(HAVE_SVE256_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION)
SVE256 = 1,
#else
AVX2 = 1,

View File

@@ -274,6 +274,26 @@ inline Vectorized<scalar_t> div_floor_floating_vec(
return floordiv;
}
#if defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16)
// Since SVE lacks sufficient BF16 intrinsics, do the calculations in f32 to
// avoid rounding errors. This should not hurt performance, as most of these
// operations would be carried out on f32 vectors anyway.
template<>
inline Vectorized<c10::BFloat16> div_floor_floating_vec(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
auto [a1, a2] = convert_bfloat16_float(a);
auto [b1, b2] = convert_bfloat16_float(b);
auto res1 = div_floor_floating_vec(a1, b1);
auto res2 = div_floor_floating_vec(a2, b2);
return convert_float_bfloat16(res1, res2);
}
#endif
void div_floor_kernel(TensorIteratorBase& iter) {
const auto dtype = iter.common_dtype();
if (dtype == kByte) {

View File

@@ -237,7 +237,7 @@ std::pair<vec::Vectorized<float>, vec::Vectorized<float>> fmadd(
// Return a + b_low * c_low + b_high * c_high
vec::Vectorized<float> fmadd(vec::Vectorized<float> a, vec::Vectorized<Half> b, vec::Vectorized<Half> c) {
#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_FML)
#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_FML) && !defined(__ARM_FEATURE_SVE)
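// (SVE builds now take the non-FML path below: per the PR description, this
// guard was added because the NEON FP16-FML code did not compile with SVE
// flags enabled.)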
// NOTE: this instruction is an optional instruction in ARM v8.2 and
// v8.3, but mandatory in v8.4 per
// https://developer.arm.com/documentation/ddi0596/2021-03/SIMD-FP-Instructions/FMLAL--FMLAL2--vector---Floating-point-fused-Multiply-Add-Long-to-accumulator--vector--?lang=en

View File

@@ -582,6 +582,19 @@ namespace {
}
}
}
#if defined(CPU_CAPABILITY_SVE) && defined(__ARM_FEATURE_BF16)
TEST(NanBfloat16, IsNan) {
for (unsigned int ii = 0; ii < 0xFFFF; ++ii) {
c10::BFloat16 val(ii, c10::BFloat16::from_bits());
bool expected = std::isnan(val);
CACHE_ALIGN c10::BFloat16 actual_vals[at::vec::SVE256::Vectorized<c10::BFloat16>::size()];
at::vec::SVE256::Vectorized<c10::BFloat16>(val).isnan().store(actual_vals);
for (int jj = 0; jj < at::vec::SVE256::Vectorized<c10::BFloat16>::size(); ++jj) {
EXPECT_EQ(expected, c10::bit_cast<uint16_t>(actual_vals[jj]) != 0) << "bf16 isnan failure for bit pattern " << std::hex << ii << std::dec;
}
}
}
#endif
TYPED_TEST(LGamma, LGamma) {
using vec = TypeParam;
using UVT = UvalueType<vec>;

View File

@@ -390,17 +390,15 @@ if(INTERN_BUILD_ATEN_OPS)
LIST(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} ${CXX_ZVECTOR_FLAGS}")
endif(CXX_ZVECTOR_FOUND)
if(CXX_SVE_FOUND)
if(CXX_SVE256_FOUND)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_SVE_CPU_DEFINITION -DHAVE_SVE256_CPU_DEFINITION")
list(APPEND CPU_CAPABILITY_NAMES "SVE256")
if("${CMAKE_C_COMPILER_ID}" MATCHES "Clang")
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -O2 -march=armv8-a+sve -DCPU_CAPABILITY_SVE -msve-vector-bits=256")
else()
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -march=armv8-a+sve -DCPU_CAPABILITY_SVE -msve-vector-bits=256")
endif()
endif(CXX_SVE256_FOUND)
endif(CXX_SVE_FOUND)
if(CXX_SVE_FOUND AND CXX_SVE256_FOUND AND CXX_ARM_BF16_FOUND)
list(APPEND CPU_CAPABILITY_NAMES "SVE256")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_SVE_CPU_DEFINITION -DHAVE_SVE256_CPU_DEFINITION -DHAVE_ARM_BF16_CPU_DEFINITION")
if("${CMAKE_C_COMPILER_ID}" MATCHES "Clang")
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -O2 -march=armv8-a+sve+bf16 -D__ARM_FEATURE_BF16 -DCPU_CAPABILITY_SVE -msve-vector-bits=256")
else()
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -march=armv8-a+sve+bf16 -D__ARM_FEATURE_BF16 -DCPU_CAPABILITY_SVE -msve-vector-bits=256")
endif()
endif(CXX_SVE_FOUND AND CXX_SVE256_FOUND AND CXX_ARM_BF16_FOUND)
list(LENGTH CPU_CAPABILITY_NAMES NUM_CPU_CAPABILITY_NAMES)
math(EXPR NUM_CPU_CAPABILITY_NAMES "${NUM_CPU_CAPABILITY_NAMES}-1")

View File

@@ -106,8 +106,18 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
}
")
SET(ARM_BF16_CODE "
#include <arm_neon.h>
int main()
{
float32x4_t a = vdupq_n_f32(0);
bfloat16x8_t b = vreinterpretq_bf16_f32(a);
return 0;
}
")
# Macro to check for SVE instruction support
MACRO(CHECK_SVE lang type flags)
MACRO(CHECK_COMPILES lang type flags code)
# Save the current state of required flags
SET(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
@@ -116,9 +126,9 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
# Check if the source code compiles with the given flags for the specified language (C or C++)
IF(lang STREQUAL "CXX")
CHECK_CXX_SOURCE_COMPILES("${SVE_CODE}" ${lang}_HAS_${type})
CHECK_CXX_SOURCE_COMPILES("${code}" ${lang}_HAS_${type})
ELSE()
CHECK_C_SOURCE_COMPILES("${SVE_CODE}" ${lang}_HAS_${type})
CHECK_C_SOURCE_COMPILES("${code}" ${lang}_HAS_${type})
ENDIF()
# If the compilation test is successful, set appropriate variables indicating support
@@ -142,7 +152,8 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
ENDMACRO()
# Check for SVE256 vector length
CHECK_SVE(CXX "SVE256" "-march=armv8-a+sve -msve-vector-bits=256")
CHECK_COMPILES(CXX "SVE256" "-march=armv8.2-a+sve -msve-vector-bits=256" "${SVE_CODE}")
CHECK_COMPILES(CXX "ARM_BF16" "-march=armv8.2-a+sve+bf16 -msve-vector-bits=256" "${ARM_BF16_CODE}")
# If SVE256 support is not found, set CXX_SVE_FOUND to FALSE and notify the user
if(NOT CXX_SVE256_FOUND)

View File

@@ -175,8 +175,10 @@ class VecSVE256(VecISA):
"CPU_CAPABILITY_SVE",
"CPU_CAPABILITY_SVE256",
"AT_BUILD_ARM_VEC256_WITH_SLEEF",
"__ARM_FEATURE_BF16",
]
_arch_flags = "-march=armv8-a+sve -msve-vector-bits=256"
_arch_flags = "-march=armv8-a+sve+bf16 -msve-vector-bits=256"
_dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
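# 256-bit SVE vectors: 8 fp32 lanes, or 16 bf16/fp16 lanes.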
def __str__(self) -> str:
@@ -332,7 +334,13 @@ def x86_isa_checker() -> list[str]:
invalid_vec_isa = InvalidVecISA()
supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON(), VecSVE256()]
supported_vec_isa_list = [
VecAMX(),
VecAVX512(),
VecAVX2(),
VecNEON(),
VecSVE256(),
]
def get_isa_from_cpu_capability(
@@ -397,6 +405,7 @@ def valid_vec_isa_list() -> list[VecISA]:
isa_list.append(VecSVE256())
else:
isa_list.append(VecNEON())
elif arch in ["x86_64", "AMD64"]:
"""
arch value is x86_64 on Linux, and the value is AMD64 on Windows.