Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Reintroduce s390x SIMD support (#99057)
Reintroduce s390x SIMD support.
Use vectorized FMA to fix test precision failures.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99057
Approved by: https://github.com/malfet
Committed by: PyTorch MergeBot
Parent: 7cb581d42f
Commit: c130b8a716
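The precision fix in this patch comes from replacing separate multiply-then-add sequences with fused multiply-add, which rounds only once per lane. A minimal standalone C++ sketch (not part of the patch; it uses the portable std::fmaf rather than the s390x vector builtins) of how the two can differ:

    #include <cmath>
    #include <cstdio>

    int main() {
      // x*x is not exactly representable in float, so the separate multiply
      // rounds before the subtraction and the cancellation exposes that error.
      // std::fmaf performs the multiply and the add with one final rounding.
      float x = 1.0f + std::ldexp(1.0f, -12);  // 1 + 2^-12
      float separate = x * x - 1.0f;           // rounded twice
      float fused = std::fmaf(x, x, -1.0f);    // rounded once
      std::printf("separate = %.10g\nfused    = %.10g\n", separate, fused);
      return 0;
    }

The z/Architecture vfmasb/vfmadb instructions used in the hunks below give the same single-rounding behaviour for each vector lane.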
@@ -16,7 +16,8 @@
namespace at {
namespace vec {

namespace {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {

template <typename T>
constexpr bool is_zarch_implemented() {
@@ -290,6 +291,8 @@ constexpr int64_t allbitset(int16_t x) {
  return (onex << x) - onex;
}

namespace { /* unnamed namespace */

ZSimdVect<float> vec_mergee(ZSimdVect<float> x, ZSimdVect<float> y) {
  constexpr ZSimdVectBinary<uint8_t> mergee_mask{
      0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 24, 25, 26, 27};
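The mergee_mask above is a byte-permute pattern over the 32-byte concatenation of x and y (presumably fed to vec_perm just below the shown context): bytes 0-3 and 8-11 select the even-indexed floats of x, bytes 16-19 and 24-27 the even-indexed floats of y, so the result is {x[0], y[0], x[2], y[2]}. A small portable sketch (plain arrays instead of vector registers, hypothetical names) of the same selection:

    #include <array>
    #include <cstdio>
    #include <cstring>

    // Emulate a byte permute over the concatenation of two float vectors,
    // using the same byte indices as mergee_mask.
    std::array<float, 4> merge_even(const std::array<float, 4>& x,
                                    const std::array<float, 4>& y) {
      unsigned char src[32], dst[16];
      std::memcpy(src, x.data(), 16);
      std::memcpy(src + 16, y.data(), 16);
      const unsigned char mask[16] = {0, 1, 2,  3,  16, 17, 18, 19,
                                      8, 9, 10, 11, 24, 25, 26, 27};
      for (int i = 0; i < 16; i++) dst[i] = src[mask[i]];
      std::array<float, 4> out;
      std::memcpy(out.data(), dst, 16);
      return out;  // {x[0], y[0], x[2], y[2]}
    }

    int main() {
      auto r = merge_even({0, 1, 2, 3}, {10, 11, 12, 13});
      std::printf("%g %g %g %g\n", r[0], r[1], r[2], r[3]);  // 0 10 2 12
    }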
@@ -310,6 +313,8 @@ ZSimdVect<double> vec_mergeo(ZSimdVect<double> x, ZSimdVect<double> y) {
  return vec_mergel(x, y);
}

} /* unnamed namespace */

//
template <typename T>
constexpr auto GetBpermZeroMask() {
@@ -696,8 +701,27 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented<T>()>> {
    return set_inner<1, size()>(a, b, count);
  }

  const T& operator[](int idx) const = delete;
  T& operator[](int idx) = delete;
  const ElementType& operator[](int idx) const {
    if (idx < size() / 2)
    {
      return _vec0[idx];
    }
    else
    {
      return _vec1[idx - (size() / 2)];
    }
  }

  ElementType& operator[](int idx) {
    if (idx < size() / 2)
    {
      return _vec0[idx];
    }
    else
    {
      return _vec1[idx - (size() / 2)];
    }
  }

  Vectorized<T> C10_ALWAYS_INLINE operator+(const Vectorized<T>& other) const {
    return Vectorized<T>{_vec0 + other._vec0, _vec1 + other._vec1};
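Each zarch Vectorized<T> is backed by two 16-byte hardware registers, _vec0 and _vec1, and the operator[] overloads above (replacing the previously deleted ones) index across that split. A rough standalone sketch of the same scheme, with made-up types and sizes:

    #include <cassert>
    #include <cstdio>

    // Hypothetical stand-in for a Vectorized<float> built from two 4-lane halves.
    struct SplitVec {
      float lo[4];  // plays the role of _vec0
      float hi[4];  // plays the role of _vec1
      static constexpr int size() { return 8; }

      float& operator[](int idx) {
        // The first half of the logical vector lives in lo, the second in hi.
        return idx < size() / 2 ? lo[idx] : hi[idx - size() / 2];
      }
    };

    int main() {
      SplitVec v = {{0, 1, 2, 3}, {4, 5, 6, 7}};
      assert(v[2] == 2.0f && v[6] == 6.0f);
      v[6] = 60.0f;
      std::printf("%g\n", v.hi[2]);  // 60
    }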
@@ -1247,6 +1271,8 @@ ZSimdVect<int> vec_flt_int(const ZSimdVect<float> x) {
#define vec_flt_int vec_signed
#endif

namespace { /* unnamed namespace */

Vectorized<float> convert_to_float(const Vectorized<int32_t>& x) {
  return {vec_int_flt(x.vec0()), vec_int_flt(x.vec1())};
}
@@ -1263,6 +1289,8 @@ Vectorized<int64_t> convert_to_int(const Vectorized<double>& x) {
  return {vec_signed(x.vec0()), vec_signed(x.vec1())};
}

} /* unnamed namespace */

template <typename T, typename V>
Vectorized<V> cast_zvector(const Vectorized<T>& x) {
  using cast_type = typename Vectorized<V>::vtype;
@@ -1275,7 +1303,8 @@ Vectorized<float> C10_ALWAYS_INLINE fmadd(
    const Vectorized<float>& b,
    const Vectorized<float>& c) {
  return Vectorized<float>{
      a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()};
      __builtin_s390_vfmasb(a.vec0(), b.vec0(), c.vec0()),
      __builtin_s390_vfmasb(a.vec1(), b.vec1(), c.vec1())};
}
template <>
Vectorized<double> C10_ALWAYS_INLINE fmadd(
@@ -1283,7 +1312,8 @@ Vectorized<double> C10_ALWAYS_INLINE fmadd(
    const Vectorized<double>& b,
    const Vectorized<double>& c) {
  return Vectorized<double>{
      a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()};
      __builtin_s390_vfmadb(a.vec0(), b.vec0(), c.vec0()),
      __builtin_s390_vfmadb(a.vec1(), b.vec1(), c.vec1())};
}
template <>
Vectorized<int16_t> C10_ALWAYS_INLINE fmadd(
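In both specializations the old a.vecN() * b.vecN() + c.vecN() expression (two rounded operations per half) gives way to one fused builtin per half: vfmasb for float, vfmadb for double. A portable sketch of the same per-half structure, using a hypothetical two-half wrapper and std::fmaf in place of the z/Architecture builtins:

    #include <array>
    #include <cmath>

    // Hypothetical stand-in for Vectorized<float>: two 4-lane halves.
    struct Vec8f {
      std::array<float, 4> v0, v1;
    };

    // fmadd(a, b, c) = a * b + c, fused lane by lane in each half.
    Vec8f fmadd(const Vec8f& a, const Vec8f& b, const Vec8f& c) {
      Vec8f r;
      for (int i = 0; i < 4; i++) {
        r.v0[i] = std::fmaf(a.v0[i], b.v0[i], c.v0[i]);  // stands in for vfmasb
        r.v1[i] = std::fmaf(a.v1[i], b.v1[i], c.v1[i]);
      }
      return r;
    }

    int main() {
      Vec8f a{{1, 2, 3, 4}, {5, 6, 7, 8}}, b = a, c{{1, 1, 1, 1}, {1, 1, 1, 1}};
      Vec8f r = fmadd(a, b, c);
      return r.v1[3] == 65.0f ? 0 : 1;  // 8*8 + 1
    }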
@@ -1408,6 +1438,8 @@ struct pack_type<int32_t> {
  using type = int16_t;
};

namespace { /* unnamed namespace */

template <typename T, typename V = typename unpack_type<T>::type>
std::pair<Vectorized<V>, Vectorized<V>> unpack(const Vectorized<T>& x) {
  auto vec0 = vec_unpackh(x.vec0());
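unpack widens each element to the larger type named by unpack_type, expanding the high and low halves of each register with vec_unpackh/vec_unpackl. A scalar sketch of that widening, assuming the usual sign-extending unpack-high/unpack-low semantics of the z vector intrinsics:

    #include <array>
    #include <cstdint>
    #include <utility>

    // Widen 8 int16 lanes into two vectors of 4 int32 lanes each:
    // the first result holds the leading half, the second the trailing half.
    std::pair<std::array<int32_t, 4>, std::array<int32_t, 4>> unpack16(
        const std::array<int16_t, 8>& x) {
      std::array<int32_t, 4> hi, lo;
      for (int i = 0; i < 4; i++) {
        hi[i] = x[i];      // sign-extended, like vec_unpackh
        lo[i] = x[i + 4];  // sign-extended, like vec_unpackl
      }
      return {hi, lo};
    }

    int main() {
      auto r = unpack16({-1, 2, -3, 4, -5, 6, -7, 8});
      return r.first[0] == -1 && r.second[3] == 8 ? 0 : 1;
    }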
@@ -1451,6 +1483,8 @@ Vectorized<uint8_t> pack(
  return Vectorized<uint8_t>{vec0, vec1};
}

} /* unnamed namespace */

//////////////////////////////////QUANT///////////////////////////////////////////
template <typename T>
struct Vectorized<T, std::enable_if_t<is_zarch_implemented_quant<T>()>> {
@@ -1522,7 +1556,7 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_quant<T>()>> {
      Vectorized<float> zero_point,
      Vectorized<float> scale_zp_premul) const {
    auto float_val = convert_to_float(_vec);
    return {(scale * float_val + scale_zp_premul)};
    return {fmadd(scale, float_val, scale_zp_premul)};
  }

  template <
@@ -1593,10 +1627,10 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_quant<T>()>> {
    auto vecf_2 = convert_to_float(ret32_1.first);
    auto vecf_3 = convert_to_float(ret32_1.second);
    return {
        Vectorized<float>{scale * vecf_0 + scale_zp_premul},
        Vectorized<float>{scale * vecf_1 + scale_zp_premul},
        Vectorized<float>{scale * vecf_2 + scale_zp_premul},
        Vectorized<float>{scale * vecf_3 + scale_zp_premul}};
        fmadd(scale, vecf_0, scale_zp_premul),
        fmadd(scale, vecf_1, scale_zp_premul),
        fmadd(scale, vecf_2, scale_zp_premul),
        fmadd(scale, vecf_3, scale_zp_premul)};
  }

  template <
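Both dequantize overloads compute scale * q_float + scale_zp_premul, now routed through fmadd. Assuming scale_zp_premul is the precomputed -scale * zero_point (its computation lies outside this diff), this is the usual dequantization formula scale * (q - zero_point) with a single fused rounding. A scalar sketch of that identity:

    #include <cmath>
    #include <cstdio>

    int main() {
      float scale = 0.05f, zero_point = 3.0f;
      float q = 17.0f;  // quantized value already converted to float

      // Reference formula: dequant = scale * (q - zero_point)
      float reference = scale * (q - zero_point);

      // Fused form used above: dequant = fma(scale, q, scale_zp_premul)
      float scale_zp_premul = -scale * zero_point;  // assumed precomputed meaning
      float fused = std::fmaf(scale, q, scale_zp_premul);

      std::printf("reference = %g, fused = %g\n", reference, fused);
      return 0;
    }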
@@ -1869,7 +1903,7 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_complex<T>()>> {
 public:
  Vectorized() {}

  explicit C10_ALWAYS_INLINE Vectorized(vinner_type v) : _vec{v} {}
  C10_ALWAYS_INLINE Vectorized(vinner_type v) : _vec{v} {}

  template <typename U = T, std::enable_if_t<(sizeof(U) == 16), int> = 0>
  C10_ALWAYS_INLINE Vectorized(T s1, T s2)
@@ -1893,6 +1927,10 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_complex<T>()>> {
  template <typename U = T, std::enable_if_t<(sizeof(U) == 8), int> = 0>
  C10_ALWAYS_INLINE Vectorized(T s) : Vectorized<T>(s, s, s, s) {}

  C10_ALWAYS_INLINE operator vinner_type() const {
    return _vec;
  }

  C10_ALWAYS_INLINE const vinner_type& vec() const {
    return _vec;
  }
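Dropping explicit from the vinner_type constructor and adding the operator vinner_type() conversion lets complex Vectorized values and their raw inner vectors mix freely in expressions such as the fmadd calls below, without explicit wrapping at every call site. A tiny standalone sketch of that pattern, with invented names:

    // Hypothetical miniature of the wrapper/inner-type relationship.
    struct Inner {
      float v;
    };

    struct Wrapper {
      Inner _vec;
      Wrapper(Inner v) : _vec{v} {}            // non-explicit: Inner -> Wrapper
      operator Inner() const { return _vec; }  // Wrapper -> Inner
    };

    Wrapper fmadd(Wrapper a, Wrapper b, Wrapper c) {
      return Wrapper{Inner{a._vec.v * b._vec.v + c._vec.v}};
    }

    int main() {
      Inner raw{2.0f};
      Wrapper w{Inner{3.0f}};
      // Mixed arguments compile because of the implicit conversions above.
      Wrapper r = fmadd(w, raw, Inner{1.0f});
      Inner back = r;  // conversion operator
      return back.v == 7.0f ? 0 : 1;
    }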
@@ -2071,7 +2109,7 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_complex<T>()>> {
    vi = vi ^ rsign_mask<underline_type>();
    vinner_type ret = _vec * vr;
    vinner_type vx_swapped = _vec.swapped();
    ret = vx_swapped * vi + ret;
    ret = fmadd(vx_swapped, vi, ret);
#else
    vinner_type ac_bd = _vec * b;
    vinner_type d_c = bv.swapped();
@@ -2094,7 +2132,7 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_complex<T>()>> {
    vi = vi ^ isign_mask<underline_type>();
    vinner_type ret = _vec * vr;
    vinner_type vx_swapped = _vec.swapped();
    ret = vx_swapped * vi + ret;
    ret = fmadd(vx_swapped, vi, ret);
    ret = ret / abs_b;
#else
    // Vectorized x86 simulation
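Both the complex multiply and divide paths build (ac - bd) and (ad + bc) from an elementwise product, a lane-swapped copy, and a sign flip, with the final accumulate now fused. A scalar sketch of the multiply case for one complex number stored as two lanes {re, im}; which lane the sign mask flips is an assumption here, chosen so the algebra comes out right:

    #include <cmath>
    #include <complex>
    #include <cstdio>

    int main() {
      // One complex value per two lanes: {re, im}.
      float x[2] = {3.0f, 4.0f};   // a + bi
      float y[2] = {2.0f, -1.0f};  // c + di

      float vr[2] = {y[0], y[0]};       // broadcast real part: {c, c}
      float vi[2] = {-y[1], y[1]};      // broadcast imag with sign flip: {-d, d}
      float swapped[2] = {x[1], x[0]};  // lane-swapped x: {b, a}

      float ret[2];
      for (int i = 0; i < 2; i++) {
        ret[i] = x[i] * vr[i];                          // {a*c, b*c}
        ret[i] = std::fmaf(swapped[i], vi[i], ret[i]);  // {ac - bd, ad + bc}
      }

      std::complex<float> check =
          std::complex<float>(3, 4) * std::complex<float>(2, -1);
      std::printf("lanes: %g %g  std::complex: %g %g\n",
                  ret[0], ret[1], check.real(), check.imag());
    }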
@@ -2260,6 +2298,10 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_complex<T>()>> {
    return mapOrdinary(std::exp);
  }

  Vectorized<T> exp2() const {
    return mapOrdinary(exp2_impl);
  }

  Vectorized<T> log() const {
    return mapOrdinary(std::log);
  }
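exp2 is applied element-wise through mapOrdinary, like the neighbouring exp and log. The header's exp2_impl is not shown in this diff; for a complex argument the natural definition follows the identity 2^z = exp(z ln 2), sketched here as an assumption:

    #include <cmath>
    #include <complex>
    #include <cstdio>

    // Assumed shape of an exp2 for complex values: 2^z = exp(z * ln(2)).
    template <typename T>
    std::complex<T> exp2_sketch(const std::complex<T>& z) {
      return std::exp(z * std::log(T(2)));
    }

    int main() {
      auto r = exp2_sketch(std::complex<double>(3.0, 0.0));
      std::printf("%g %g\n", r.real(), r.imag());  // ~8 0
    }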
@@ -2275,6 +2317,10 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_complex<T>()>> {
    return Vectorized<T>{ret._vec * vinner_type(log10e_inv<underline_type>())};
  }

  Vectorized<T> log1p() const {
    return mapOrdinary(std::log1p);
  }

  Vectorized<T> sgn() const {
    return mapOrdinary(at::native::sgn_impl);
  }
@@ -1805,6 +1805,8 @@ if(NOT INTERN_BUILD_MOBILE)
  else(NOT C_HAS_THREAD)
    add_compile_options(-DTH_HAVE_THREAD)
  endif(NOT C_HAS_THREAD)

  find_package(ZVECTOR) # s390x simd support
endif()

#
@@ -7,13 +7,14 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
  if(bintrue MATCHES "AT_PLATFORM:[ \\t\\n\\r]*([a-zA-Z0-9_]+)[ \\t\\n\\r]*")
    if(CMAKE_MATCH_COUNT GREATER 0)
      string(TOLOWER ${CMAKE_MATCH_1} platform)
      if(${platform} MATCHES "^z(14|15)")
      if(${platform} MATCHES "^z(14|15|16)")
        message("-- Z ARCH Platform: ${platform}")
        list( APPEND Z_ARCH_LIST "${platform}" )
      endif()
    endif()
  endif()
  #adds other archs in descending order. as its cached nothing will be checked twice
  list( APPEND Z_ARCH_LIST "z16" )
  list( APPEND Z_ARCH_LIST "z15" )
  list( APPEND Z_ARCH_LIST "z14" )