Reintroduce s390x SIMD support (#99057)

Reintroduce s390x SIMD support

Use vectorized FMA to fix test precision failures

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99057
Approved by: https://github.com/malfet
Aleksei Nikiforov
2023-04-15 00:24:40 +00:00
committed by PyTorch MergeBot
parent 7cb581d42f
commit c130b8a716
3 changed files with 63 additions and 14 deletions
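
For context on the precision fix mentioned in the commit message: a fused multiply-add rounds only once, whereas a separate multiply followed by an add rounds twice, and that difference is what the affected tests were sensitive to. A minimal standalone illustration (not part of the patch):

#include <cmath>
#include <cstdio>
#include <limits>

int main() {
  const double eps = std::numeric_limits<double>::epsilon();
  double a = 1.0 + eps;
  double b = 1.0 - eps;
  double c = -1.0;
  // a * b is exactly 1 - eps*eps, which rounds to 1.0 in double precision,
  // so the separate multiply-then-add collapses to zero.
  double separate = a * b + c;
  // std::fma rounds the exact value of a*b + c once, keeping the tiny residual.
  double fused = std::fma(a, b, c);
  std::printf("separate = %.17g, fused = %.17g\n", separate, fused);
  return 0;
}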

View File

@ -16,7 +16,8 @@
namespace at {
namespace vec {
namespace {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
template <typename T>
constexpr bool is_zarch_implemented() {
@ -290,6 +291,8 @@ constexpr int64_t allbitset(int16_t x) {
return (onex << x) - onex;
}
namespace { /* unnamed namespace */
ZSimdVect<float> vec_mergee(ZSimdVect<float> x, ZSimdVect<float> y) {
constexpr ZSimdVectBinary<uint8_t> mergee_mask{
0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 24, 25, 26, 27};
@ -310,6 +313,8 @@ ZSimdVect<double> vec_mergeo(ZSimdVect<double> x, ZSimdVect<double> y) {
return vec_mergel(x, y);
}
} /* unnamed namespace */
//
template <typename T>
constexpr auto GetBpermZeroMask() {
@ -696,8 +701,27 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented<T>()>> {
return set_inner<1, size()>(a, b, count);
}
const T& operator[](int idx) const = delete;
T& operator[](int idx) = delete;
const ElementType& operator[](int idx) const {
if (idx < size() / 2)
{
return _vec0[idx];
}
else
{
return _vec1[idx - (size() / 2)];
}
}
ElementType& operator[](int idx) {
if (idx < size() / 2)
{
return _vec0[idx];
}
else
{
return _vec1[idx - (size() / 2)];
}
}
Vectorized<T> C10_ALWAYS_INLINE operator+(const Vectorized<T>& other) const {
return Vectorized<T>{_vec0 + other._vec0, _vec1 + other._vec1};
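
The operator[] overloads added above reflect the storage layout of this specialization: the lanes live in two hardware-vector halves, _vec0 and _vec1, so indexing first selects the half and then the lane within it. A simplified standalone sketch of the same scheme (not the actual at::vec::Vectorized type):

#include <array>
#include <cstddef>
#include <cstdio>

template <typename T, std::size_t N>
struct TwoHalves {
  std::array<T, N / 2> vec0{};
  std::array<T, N / 2> vec1{};
  static constexpr std::size_t size() { return N; }
  // Same indexing rule as the patch: low indices hit the first half,
  // high indices hit the second half shifted by size() / 2.
  T& operator[](std::size_t idx) {
    return idx < N / 2 ? vec0[idx] : vec1[idx - N / 2];
  }
};

int main() {
  TwoHalves<float, 8> v;  // e.g. 8 floats split across two 16-byte registers
  v[3] = 1.0f;            // stored in vec0[3]
  v[5] = 2.0f;            // stored in vec1[1]
  std::printf("%g %g\n", v.vec0[3], v.vec1[1]);
  return 0;
}
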
@ -1247,6 +1271,8 @@ ZSimdVect<int> vec_flt_int(const ZSimdVect<float> x) {
#define vec_flt_int vec_signed
#endif
namespace { /* unnamed namespace */
Vectorized<float> convert_to_float(const Vectorized<int32_t>& x) {
return {vec_int_flt(x.vec0()), vec_int_flt(x.vec1())};
}
@ -1263,6 +1289,8 @@ Vectorized<int64_t> convert_to_int(const Vectorized<double>& x) {
return {vec_signed(x.vec0()), vec_signed(x.vec1())};
}
} /* unnamed namespace */
template <typename T, typename V>
Vectorized<V> cast_zvector(const Vectorized<T>& x) {
using cast_type = typename Vectorized<V>::vtype;
@ -1275,7 +1303,8 @@ Vectorized<float> C10_ALWAYS_INLINE fmadd(
const Vectorized<float>& b,
const Vectorized<float>& c) {
return Vectorized<float>{
a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()};
__builtin_s390_vfmasb(a.vec0(), b.vec0(), c.vec0()),
__builtin_s390_vfmasb(a.vec1(), b.vec1(), c.vec1())};
}
template <>
Vectorized<double> C10_ALWAYS_INLINE fmadd(
@ -1283,7 +1312,8 @@ Vectorized<double> C10_ALWAYS_INLINE fmadd(
const Vectorized<double>& b,
const Vectorized<double>& c) {
return Vectorized<double>{
a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()};
__builtin_s390_vfmadb(a.vec0(), b.vec0(), c.vec0()),
__builtin_s390_vfmadb(a.vec1(), b.vec1(), c.vec1())};
}
template <>
Vectorized<int16_t> C10_ALWAYS_INLINE fmadd(
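
__builtin_s390_vfmasb and __builtin_s390_vfmadb are the compiler builtins for the z/Architecture vector fused multiply-add on single- and double-precision lanes, replacing the two-rounding a * b + c expression above. As a semantic model only (a scalar sketch, not the intrinsic itself): each 128-bit single-precision operand holds four lanes, and every lane gets one fused multiply-add:

#include <cmath>
#include <cstdio>

int main() {
  float a[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  float b[4] = {0.5f, 0.5f, 0.5f, 0.5f};
  float c[4] = {1.0f, 1.0f, 1.0f, 1.0f};
  float out[4];
  for (int lane = 0; lane < 4; ++lane) {
    // one rounding per lane, matching the fused hardware instruction
    out[lane] = std::fmaf(a[lane], b[lane], c[lane]);
  }
  std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);
  return 0;
}
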
@ -1408,6 +1438,8 @@ struct pack_type<int32_t> {
using type = int16_t;
};
namespace { /* unnamed namespace */
template <typename T, typename V = typename unpack_type<T>::type>
std::pair<Vectorized<V>, Vectorized<V>> unpack(const Vectorized<T>& x) {
auto vec0 = vec_unpackh(x.vec0());
@ -1451,6 +1483,8 @@ Vectorized<uint8_t> pack(
return Vectorized<uint8_t>{vec0, vec1};
}
} /* unnamed namespace */
//////////////////////////////////QUANT///////////////////////////////////////////
template <typename T>
struct Vectorized<T, std::enable_if_t<is_zarch_implemented_quant<T>()>> {
@ -1522,7 +1556,7 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_quant<T>()>> {
Vectorized<float> zero_point,
Vectorized<float> scale_zp_premul) const {
auto float_val = convert_to_float(_vec);
return {(scale * float_val + scale_zp_premul)};
return {fmadd(scale, float_val, scale_zp_premul)};
}
template <
@ -1593,10 +1627,10 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_quant<T>()>> {
auto vecf_2 = convert_to_float(ret32_1.first);
auto vecf_3 = convert_to_float(ret32_1.second);
return {
Vectorized<float>{scale * vecf_0 + scale_zp_premul},
Vectorized<float>{scale * vecf_1 + scale_zp_premul},
Vectorized<float>{scale * vecf_2 + scale_zp_premul},
Vectorized<float>{scale * vecf_3 + scale_zp_premul}};
fmadd(scale, vecf_0, scale_zp_premul),
fmadd(scale, vecf_1, scale_zp_premul),
fmadd(scale, vecf_2, scale_zp_premul),
fmadd(scale, vecf_3, scale_zp_premul)};
}
template <
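
The dequantize path above now routes through fmadd as well. Assuming the usual affine scheme value = scale * (q - zero_point), with scale_zp_premul precomputed by the caller as -zero_point * scale (an assumption, not shown in this hunk), the scalar equivalent is:

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  float scale = 0.05f;
  float zero_point = 12.0f;
  float scale_zp_premul = -zero_point * scale;  // premultiplied by the caller
  std::int8_t q = 87;
  float q_f = static_cast<float>(q);            // convert_to_float in the vector path
  // One fused rounding, matching fmadd(scale, float_val, scale_zp_premul) above.
  float value = std::fma(scale, q_f, scale_zp_premul);
  std::printf("%f (reference: %f)\n", value, scale * (q_f - zero_point));
  return 0;
}
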
@ -1869,7 +1903,7 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_complex<T>()>> {
public:
Vectorized() {}
explicit C10_ALWAYS_INLINE Vectorized(vinner_type v) : _vec{v} {}
C10_ALWAYS_INLINE Vectorized(vinner_type v) : _vec{v} {}
template <typename U = T, std::enable_if_t<(sizeof(U) == 16), int> = 0>
C10_ALWAYS_INLINE Vectorized(T s1, T s2)
@ -1893,6 +1927,10 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_complex<T>()>> {
template <typename U = T, std::enable_if_t<(sizeof(U) == 8), int> = 0>
C10_ALWAYS_INLINE Vectorized(T s) : Vectorized<T>(s, s, s, s) {}
C10_ALWAYS_INLINE operator vinner_type() const {
return _vec;
}
C10_ALWAYS_INLINE const vinner_type& vec() const {
return _vec;
}
@ -2071,7 +2109,7 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_complex<T>()>> {
vi = vi ^ rsign_mask<underline_type>();
vinner_type ret = _vec * vr;
vinner_type vx_swapped = _vec.swapped();
ret = vx_swapped * vi + ret;
ret = fmadd(vx_swapped, vi, ret);
#else
vinner_type ac_bd = _vec * b;
vinner_type d_c = bv.swapped();
@ -2094,7 +2132,7 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_complex<T>()>> {
vi = vi ^ isign_mask<underline_type>();
vinner_type ret = _vec * vr;
vinner_type vx_swapped = _vec.swapped();
ret = vx_swapped * vi + ret;
ret = fmadd(vx_swapped, vi, ret);
ret = ret / abs_b;
#else
// Vectorized x86 simulation
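
The two hunks above switch the complex multiply and divide kernels to the same fmadd helper. Reading the surrounding code, the product (a+bi)(c+di) is built lane-wise as _vec * vr and then finished with a fused multiply-add of the swapped lanes against the sign-masked imaginary parts; a rough scalar analogue under those assumptions (not the exact PyTorch code):

#include <cmath>
#include <cstdio>

int main() {
  double a = 1.5, b = -2.0;   // lhs = a + bi
  double c = 0.25, d = 3.0;   // rhs = c + di
  // step 1: lane-wise product against the broadcast real part -> (a*c, b*c)
  // step 2: fused multiply-add of the swapped lhs against (-d, d)
  double re = std::fma(b, -d, a * c);  // a*c - b*d
  double im = std::fma(a, d, b * c);   // b*c + a*d
  std::printf("(%g) + (%g)i\n", re, im);
  return 0;
}
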
@ -2260,6 +2298,10 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_complex<T>()>> {
return mapOrdinary(std::exp);
}
Vectorized<T> exp2() const {
return mapOrdinary(exp2_impl);
}
Vectorized<T> log() const {
return mapOrdinary(std::log);
}
@ -2275,6 +2317,10 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_complex<T>()>> {
return Vectorized<T>{ret._vec * vinner_type(log10e_inv<underline_type>())};
}
Vectorized<T> log1p() const {
return mapOrdinary(std::log1p);
}
Vectorized<T> sgn() const {
return mapOrdinary(at::native::sgn_impl);
}
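
The new exp2() and log1p() members follow the existing mapOrdinary pattern: each complex element is run through a scalar function and the results are packed back into a vector. A simplified standalone sketch of that element-wise mapping (the map_ordinary helper below is illustrative, not the PyTorch implementation):

#include <complex>
#include <cstdio>
#include <vector>

template <typename T, typename F>
std::vector<std::complex<T>> map_ordinary(const std::vector<std::complex<T>>& v, F f) {
  std::vector<std::complex<T>> out;
  out.reserve(v.size());
  for (const auto& z : v) {
    out.push_back(f(z));  // scalar complex function applied per element
  }
  return out;
}

int main() {
  std::vector<std::complex<float>> v{{1.0f, 0.0f}, {0.0f, 1.0f}};
  auto logged = map_ordinary(v, [](std::complex<float> z) { return std::log(z); });
  std::printf("(%g, %g) (%g, %g)\n",
              logged[0].real(), logged[0].imag(),
              logged[1].real(), logged[1].imag());
  return 0;
}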

View File

@ -1805,6 +1805,8 @@ if(NOT INTERN_BUILD_MOBILE)
else(NOT C_HAS_THREAD)
add_compile_options(-DTH_HAVE_THREAD)
endif(NOT C_HAS_THREAD)
find_package(ZVECTOR) # s390x simd support
endif()
#

View File

@ -7,13 +7,14 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
if(bintrue MATCHES "AT_PLATFORM:[ \\t\\n\\r]*([a-zA-Z0-9_]+)[ \\t\\n\\r]*")
if(CMAKE_MATCH_COUNT GREATER 0)
string(TOLOWER ${CMAKE_MATCH_1} platform)
if(${platform} MATCHES "^z(14|15)")
if(${platform} MATCHES "^z(14|15|16)")
message("-- Z ARCH Platform: ${platform}")
list( APPEND Z_ARCH_LIST "${platform}" )
endif()
endif()
endif()
# Add other archs in descending order; as the result is cached, nothing will be checked twice.
list( APPEND Z_ARCH_LIST "z16" )
list( APPEND Z_ARCH_LIST "z15" )
list( APPEND Z_ARCH_LIST "z14" )