From e1e6417d4cd0737c29f4fe39a9ca38a74f3fc24e Mon Sep 17 00:00:00 2001 From: Siddhartha Menon Date: Tue, 15 Oct 2024 18:52:44 +0000 Subject: [PATCH] Add SVE implementation of embedding_lookup_idx (#133995) Adds an accelerated version of the embedding_lookup_idx perfkernels. This is done via a python codegen file similarly to `caffe2/perfkernels/hp_emblookup_codegen.py` Pull Request resolved: https://github.com/pytorch/pytorch/pull/133995 Approved by: https://github.com/malfet, https://github.com/huydhn --- caffe2/perfkernels/CMakeLists.txt | 22 +- caffe2/perfkernels/common.h | 17 +- caffe2/perfkernels/common_sve.cc | 22 + caffe2/perfkernels/embedding_lookup_idx.cc | 22 +- .../perfkernels/embedding_lookup_idx_sve.cc | 6769 +++++++++++++++++ caffe2/perfkernels/sve_emblookup_codegen.py | 408 + cmake/MiscCheck.cmake | 10 + cmake/public/ComputeLibrary.cmake | 4 +- 8 files changed, 7265 insertions(+), 9 deletions(-) create mode 100644 caffe2/perfkernels/common_sve.cc create mode 100644 caffe2/perfkernels/embedding_lookup_idx_sve.cc create mode 100644 caffe2/perfkernels/sve_emblookup_codegen.py diff --git a/caffe2/perfkernels/CMakeLists.txt b/caffe2/perfkernels/CMakeLists.txt index 83e4a5f915d1..1b46916cb927 100644 --- a/caffe2/perfkernels/CMakeLists.txt +++ b/caffe2/perfkernels/CMakeLists.txt @@ -10,9 +10,13 @@ endif() file(GLOB common_srcs *.cc) file(GLOB avx_srcs *_avx.cc) file(GLOB avx2_srcs *_avx2.cc) -# exclude avx and avx2 srcs from common_srcs +file(GLOB avx512_srcs *_avx512.cc) +file(GLOB sve_srcs *_sve.cc) +# exclude avx, avx2, avx512, and sve srcs from common_srcs exclude(common_srcs "${common_srcs}" ${avx_srcs}) exclude(common_srcs "${common_srcs}" ${avx2_srcs}) +exclude(common_srcs "${common_srcs}" ${avx512_srcs}) +exclude(common_srcs "${common_srcs}" ${sve_srcs}) # We will always build common srcs. set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${common_srcs}) @@ -42,6 +46,22 @@ if(CXX_AVX2_FOUND) "Caffe2_perfkernels_avx2_interface") endif() +# We will only build the SVE perfkernel files if the compiler supports SVE +# extensions. +if(CXX_SVE_FOUND) + add_library(Caffe2_perfkernels_sve STATIC ${sve_srcs}) + target_link_libraries(Caffe2_perfkernels_sve PRIVATE c10) + install(TARGETS Caffe2_perfkernels_sve + ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}") + + target_compile_options(Caffe2_perfkernels_sve PRIVATE "-march=armv8-a+sve") + + caffe2_interface_library( + Caffe2_perfkernels_sve Caffe2_perfkernels_sve_interface) + list(APPEND + Caffe2_DEPENDENCY_WHOLE_LINK_LIBS "Caffe2_perfkernels_sve_interface") +endif() + # TODO(jiayq): currently, we only implement the very base files for the # perfkernels. This is because to implement avx and avx2 files, we actually # need to set up different compilation units and this is a bit more involving diff --git a/caffe2/perfkernels/common.h b/caffe2/perfkernels/common.h index 6fed9e1d6d06..6e069861b28d 100644 --- a/caffe2/perfkernels/common.h +++ b/caffe2/perfkernels/common.h @@ -61,9 +61,8 @@ In foo.cc, do: // we use cpuinfo to identify cpu support and run the proper functions. #pragma once - -#if defined(CAFFE2_PERF_WITH_AVX512) || defined(CAFFE2_PERF_WITH_AVX2) \ - || defined(CAFFE2_PERF_WITH_AVX) +#if defined(CAFFE2_PERF_WITH_SVE) || defined(CAFFE2_PERF_WITH_AVX512) || \ + defined(CAFFE2_PERF_WITH_AVX2) || defined(CAFFE2_PERF_WITH_AVX) #include #endif @@ -72,6 +71,18 @@ In foo.cc, do: #define BASE_DO(funcname, ...) return funcname##__base(__VA_ARGS__); +#ifdef CAFFE2_PERF_WITH_SVE +#define SVE_DO(funcname, ...) 
\ + { \ + static const bool isDo = cpuinfo_initialize() && cpuinfo_has_arm_sve(); \ + if (isDo) { \ + return funcname##__sve(__VA_ARGS__); \ + } \ + } +#else // CAFFE2_PERF_WITH_SVE +#define SVE_DO(funcname, ...) +#endif // CAFFE2_PERF_WITH_SVE + #ifdef CAFFE2_PERF_WITH_AVX512 #define AVX512_DO(funcname, ...) \ { \ diff --git a/caffe2/perfkernels/common_sve.cc b/caffe2/perfkernels/common_sve.cc new file mode 100644 index 000000000000..03b0bf983c80 --- /dev/null +++ b/caffe2/perfkernels/common_sve.cc @@ -0,0 +1,22 @@ +// This file is here merely to check that the flags are not mixed up: for +// example, if your compiler did not specify -march=armv8-a+sve, you should not +// provide the CAFFE2_PERF_WITH_SVE macro. + +#include "caffe2/core/common.h" + +#ifdef CAFFE2_PERF_WITH_SVE +#ifndef __ARM_FEATURE_SVE +#error( \ + "You found a build system error: CAFFE2_PERF_WITH_SVE is defined" \ + "but __ARM_FEATURE_SVE is not defined (via e.g. -march=armv8-a+sve)."); +#endif // __ARM_FEATURE_SVE +#endif // CAFFE2_PERF_WITH_SVE + +#ifdef __ARM_FEATURE_SVE +#ifndef CAFFE2_PERF_WITH_SVE +#error( \ + "You found a build system error: __SVE__ is defined \ + (via e.g. -march=armv8-a+sve) " \ + "but CAFFE2_PERF_WITH_SVE is not defined."); +#endif // CAFFE2_PERF_WITH_SVE +#endif diff --git a/caffe2/perfkernels/embedding_lookup_idx.cc b/caffe2/perfkernels/embedding_lookup_idx.cc index 5fcf71016aea..7c62d9e883fd 100644 --- a/caffe2/perfkernels/embedding_lookup_idx.cc +++ b/caffe2/perfkernels/embedding_lookup_idx.cc @@ -88,7 +88,7 @@ static bool EmbeddingLookupGenericSlowIdx( const int64_t data_size, \ const InType* input, \ const IndexType* indices, \ - const IndexType* offsets, \ + const IndexType* offsets, \ const float* weights, \ const float* scale_bias, \ bool normalize_by_lengths, \ @@ -113,6 +113,9 @@ static bool EmbeddingLookupGenericSlowIdx( decltype( \ EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base) \ EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__avx2_fma; \ + decltype( \ + EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base) \ + EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__sve; \ bool \ EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL( \ const int64_t block_size, \ @@ -121,7 +124,7 @@ static bool EmbeddingLookupGenericSlowIdx( const int64_t data_size, \ const InType* input, \ const IndexType* indices, \ - const IndexType* offsets, \ + const IndexType* offsets, \ const float* weights, \ const float* scale_bias, \ bool normalize_by_lengths, \ @@ -131,6 +134,19 @@ static bool EmbeddingLookupGenericSlowIdx( } else { \ CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr"); \ } \ + SVE_DO( \ + EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL, \ + block_size, \ + output_size, \ + index_size, \ + data_size, \ + input, \ + indices, \ + offsets, \ + weights, \ + scale_bias, \ + normalize_by_lengths, \ + out); \ AVX2_FMA_DO( \ EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL, \ block_size, \ @@ -166,7 +182,7 @@ static bool EmbeddingLookupGenericSlowIdx( const int64_t data_size, \ const InType* input, \ const IndexType* indices, \ - const IndexType* offsets, \ + const IndexType* offsets, \ const float* weights, \ const float* scale_bias, \ bool normalize_by_lengths, \ diff --git a/caffe2/perfkernels/embedding_lookup_idx_sve.cc 
b/caffe2/perfkernels/embedding_lookup_idx_sve.cc new file mode 100644 index 000000000000..873823536b55 --- /dev/null +++ b/caffe2/perfkernels/embedding_lookup_idx_sve.cc @@ -0,0 +1,6769 @@ +//// -------------------------- +//// ATTENTION: +//// THIS CODE IS AUTOGENERATED +//// BY sve_emblookup_codegen.py +//// DO NOT MODIFY!!! +//// -------------------------- + +#include +#include +#include +#include +#include +namespace caffe2 { + +template +static bool EmbeddingLookupIdx_int32_t_float_float__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const float* input, + const int32_t* indices, + const int32_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + const svbool_t svAll = svptrue_b32(); + const auto vLen = static_cast(svcntw()); + int64_t pos = 0; + if (block_size == 32 * vLen) { + // unrolling 32 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + svfloat32_t vsum16 = svdup_n_f32(0); + svfloat32_t vsum17 = svdup_n_f32(0); + svfloat32_t vsum18 = svdup_n_f32(0); + svfloat32_t vsum19 = svdup_n_f32(0); + svfloat32_t vsum20 = svdup_n_f32(0); + svfloat32_t vsum21 = svdup_n_f32(0); + svfloat32_t vsum22 = svdup_n_f32(0); + svfloat32_t vsum23 = svdup_n_f32(0); + svfloat32_t vsum24 = svdup_n_f32(0); + svfloat32_t vsum25 = svdup_n_f32(0); + svfloat32_t vsum26 = svdup_n_f32(0); + svfloat32_t vsum27 = svdup_n_f32(0); + svfloat32_t vsum28 = svdup_n_f32(0); + svfloat32_t vsum29 = svdup_n_f32(0); + svfloat32_t vsum30 = svdup_n_f32(0); + svfloat32_t vsum31 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const float* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); + vsum1 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); + vsum2 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); + vsum3 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); + vsum4 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); + vsum5 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); + vsum6 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); + vsum7 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); + vsum8 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[8 * vLen]), vsum8); + vsum9 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[9 * vLen]), vsum9); + vsum10 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[10 * vLen]), vsum10); + vsum11 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[11 * vLen]), vsum11); + vsum12 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[12 * vLen]), vsum12); + vsum13 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[13 * vLen]), vsum13); + vsum14 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[14 * vLen]), vsum14); + vsum15 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[15 * vLen]), vsum15); + vsum16 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[16 * vLen]), vsum16); + vsum17 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[17 * vLen]), vsum17); + vsum18 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[18 * vLen]), vsum18); + vsum19 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[19 * vLen]), vsum19); + vsum20 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[20 * vLen]), vsum20); + vsum21 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[21 * vLen]), vsum21); + vsum22 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[22 * vLen]), vsum22); + vsum23 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[23 * vLen]), vsum23); + vsum24 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[24 * vLen]), vsum24); + vsum25 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[25 * vLen]), vsum25); + vsum26 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[26 * vLen]), vsum26); + vsum27 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[27 * vLen]), vsum27); + vsum28 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[28 * vLen]), vsum28); + vsum29 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[29 * vLen]), vsum29); + vsum30 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[30 * vLen]), vsum30); + vsum31 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[31 * vLen]), vsum31); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], 
svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); + svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); + svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); + svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); + svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); + svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); + svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); + svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); + svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); + svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); + svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); + svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); + svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); + svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); + svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); + svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + svst1_f32(svAll, &op[16 * vLen], vsum16); + svst1_f32(svAll, &op[17 * vLen], vsum17); + svst1_f32(svAll, &op[18 * vLen], vsum18); + svst1_f32(svAll, &op[19 * vLen], vsum19); + svst1_f32(svAll, &op[20 * vLen], vsum20); + svst1_f32(svAll, &op[21 * vLen], vsum21); + svst1_f32(svAll, &op[22 * vLen], vsum22); + svst1_f32(svAll, &op[23 * vLen], vsum23); + svst1_f32(svAll, &op[24 * vLen], vsum24); + svst1_f32(svAll, &op[25 * vLen], vsum25); + svst1_f32(svAll, &op[26 * vLen], vsum26); + svst1_f32(svAll, &op[27 * vLen], vsum27); + svst1_f32(svAll, &op[28 * vLen], vsum28); + svst1_f32(svAll, &op[29 * vLen], vsum29); + svst1_f32(svAll, &op[30 * vLen], vsum30); + svst1_f32(svAll, &op[31 * vLen], vsum31); + } + } + } else if (block_size == 16 * vLen) { + // unrolling 16 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + 
svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const float* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); + vsum1 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); + vsum2 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); + vsum3 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); + vsum4 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); + vsum5 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); + vsum6 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); + vsum7 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); + vsum8 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[8 * vLen]), vsum8); + vsum9 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[9 * vLen]), vsum9); + vsum10 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[10 * vLen]), vsum10); + vsum11 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[11 * vLen]), vsum11); + vsum12 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[12 * vLen]), vsum12); + vsum13 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[13 * vLen]), vsum13); + vsum14 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[14 * vLen]), vsum14); + vsum15 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[15 * vLen]), vsum15); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + } else { + svst1_f32(svAll, 
&op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + } + } + } else if (block_size == 8 * vLen) { + // unrolling 8 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const float* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); + vsum1 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); + vsum2 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); + vsum3 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); + vsum4 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); + vsum5 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); + vsum6 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); + vsum7 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + } + } + } else if (block_size == 4 * vLen) { + // unrolling 4 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * 
block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const float* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); + vsum1 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); + vsum2 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); + vsum3 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + } + } + } else if (block_size == 2 * vLen) { + // unrolling 2 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const float* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); + vsum1 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + } + } + } else { + // generic code: + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; + } + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const float* ip = &input[idx * block_size]; + svbool_t pg; + for (int64_t k = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); + k += vLen) { + svst1_f32( + pg, + &op[k], + svmad_f32_x( + pg, vwgt, svld1_f32(pg, &ip[k]), svld1_f32(pg, &op[k]))); + } + + ++pos; + } + const int64_t length = end_offset - start_offset; + + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svbool_t pg; + for (int64_t j = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); + j += vLen) { + svst1_f32( + pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); + } + } + } + } + return pos == index_size; +} +bool EmbeddingLookupIdx_int32_t_float_float_false__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const float* input, + const int32_t* indices, + const int32_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int32_t_float_float__sve( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} +bool EmbeddingLookupIdx_int32_t_float_float_true__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const float* input, + const int32_t* indices, + const int32_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int32_t_float_float__sve( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} + +template +static bool EmbeddingLookupIdx_int64_t_float_float__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const float* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + const svbool_t svAll = svptrue_b32(); 
+ const auto vLen = static_cast(svcntw()); + int64_t pos = 0; + if (block_size == 32 * vLen) { + // unrolling 32 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + svfloat32_t vsum16 = svdup_n_f32(0); + svfloat32_t vsum17 = svdup_n_f32(0); + svfloat32_t vsum18 = svdup_n_f32(0); + svfloat32_t vsum19 = svdup_n_f32(0); + svfloat32_t vsum20 = svdup_n_f32(0); + svfloat32_t vsum21 = svdup_n_f32(0); + svfloat32_t vsum22 = svdup_n_f32(0); + svfloat32_t vsum23 = svdup_n_f32(0); + svfloat32_t vsum24 = svdup_n_f32(0); + svfloat32_t vsum25 = svdup_n_f32(0); + svfloat32_t vsum26 = svdup_n_f32(0); + svfloat32_t vsum27 = svdup_n_f32(0); + svfloat32_t vsum28 = svdup_n_f32(0); + svfloat32_t vsum29 = svdup_n_f32(0); + svfloat32_t vsum30 = svdup_n_f32(0); + svfloat32_t vsum31 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const float* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); + vsum1 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); + vsum2 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); + vsum3 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); + vsum4 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); + vsum5 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); + vsum6 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); + vsum7 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); + vsum8 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[8 * vLen]), vsum8); + vsum9 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[9 * vLen]), vsum9); + vsum10 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[10 * vLen]), vsum10); + vsum11 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[11 * vLen]), vsum11); + vsum12 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[12 * vLen]), vsum12); + vsum13 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[13 * vLen]), vsum13); + vsum14 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[14 * vLen]), vsum14); + vsum15 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[15 * vLen]), vsum15); + vsum16 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[16 * vLen]), vsum16); + vsum17 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[17 * vLen]), vsum17); + vsum18 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[18 * vLen]), vsum18); + vsum19 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[19 * vLen]), vsum19); + vsum20 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[20 * vLen]), vsum20); + vsum21 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[21 * vLen]), vsum21); + vsum22 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[22 * vLen]), vsum22); + vsum23 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[23 * vLen]), vsum23); + vsum24 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[24 * vLen]), vsum24); + vsum25 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[25 * vLen]), vsum25); + vsum26 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[26 * vLen]), vsum26); + vsum27 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[27 * vLen]), vsum27); + vsum28 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[28 * vLen]), vsum28); + vsum29 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[29 * vLen]), vsum29); + vsum30 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[30 * vLen]), vsum30); + vsum31 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[31 * vLen]), vsum31); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], 
svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); + svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); + svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); + svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); + svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); + svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); + svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); + svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); + svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); + svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); + svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); + svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); + svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); + svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); + svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); + svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + svst1_f32(svAll, &op[16 * vLen], vsum16); + svst1_f32(svAll, &op[17 * vLen], vsum17); + svst1_f32(svAll, &op[18 * vLen], vsum18); + svst1_f32(svAll, &op[19 * vLen], vsum19); + svst1_f32(svAll, &op[20 * vLen], vsum20); + svst1_f32(svAll, &op[21 * vLen], vsum21); + svst1_f32(svAll, &op[22 * vLen], vsum22); + svst1_f32(svAll, &op[23 * vLen], vsum23); + svst1_f32(svAll, &op[24 * vLen], vsum24); + svst1_f32(svAll, &op[25 * vLen], vsum25); + svst1_f32(svAll, &op[26 * vLen], vsum26); + svst1_f32(svAll, &op[27 * vLen], vsum27); + svst1_f32(svAll, &op[28 * vLen], vsum28); + svst1_f32(svAll, &op[29 * vLen], vsum29); + svst1_f32(svAll, &op[30 * vLen], vsum30); + svst1_f32(svAll, &op[31 * vLen], vsum31); + } + } + } else if (block_size == 16 * vLen) { + // unrolling 16 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + 
svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const float* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); + vsum1 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); + vsum2 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); + vsum3 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); + vsum4 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); + vsum5 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); + vsum6 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); + vsum7 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); + vsum8 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[8 * vLen]), vsum8); + vsum9 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[9 * vLen]), vsum9); + vsum10 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[10 * vLen]), vsum10); + vsum11 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[11 * vLen]), vsum11); + vsum12 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[12 * vLen]), vsum12); + vsum13 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[13 * vLen]), vsum13); + vsum14 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[14 * vLen]), vsum14); + vsum15 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[15 * vLen]), vsum15); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + } else { + svst1_f32(svAll, 
&op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + } + } + } else if (block_size == 8 * vLen) { + // unrolling 8 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const float* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); + vsum1 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); + vsum2 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); + vsum3 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); + vsum4 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); + vsum5 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); + vsum6 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); + vsum7 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + } + } + } else if (block_size == 4 * vLen) { + // unrolling 4 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * 
block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const float* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); + vsum1 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); + vsum2 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); + vsum3 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + } + } + } else if (block_size == 2 * vLen) { + // unrolling 2 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const float* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); + vsum1 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + } + } + } else { + // generic code: + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; + } + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const float* ip = &input[idx * block_size]; + svbool_t pg; + for (int64_t k = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); + k += vLen) { + svst1_f32( + pg, + &op[k], + svmad_f32_x( + pg, vwgt, svld1_f32(pg, &ip[k]), svld1_f32(pg, &op[k]))); + } + + ++pos; + } + const int64_t length = end_offset - start_offset; + + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svbool_t pg; + for (int64_t j = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); + j += vLen) { + svst1_f32( + pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); + } + } + } + } + return pos == index_size; +} +bool EmbeddingLookupIdx_int64_t_float_float_false__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const float* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int64_t_float_float__sve( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} +bool EmbeddingLookupIdx_int64_t_float_float_true__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const float* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int64_t_float_float__sve( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} + +template +static bool EmbeddingLookupIdx_int32_t_half_float__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::Half* input, + const int32_t* indices, + const int32_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + const svbool_t svAll = 
svptrue_b32(); + const auto vLen = static_cast(svcntw()); + int64_t pos = 0; + if (block_size == 32 * vLen) { + // unrolling 32 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + svfloat32_t vsum16 = svdup_n_f32(0); + svfloat32_t vsum17 = svdup_n_f32(0); + svfloat32_t vsum18 = svdup_n_f32(0); + svfloat32_t vsum19 = svdup_n_f32(0); + svfloat32_t vsum20 = svdup_n_f32(0); + svfloat32_t vsum21 = svdup_n_f32(0); + svfloat32_t vsum22 = svdup_n_f32(0); + svfloat32_t vsum23 = svdup_n_f32(0); + svfloat32_t vsum24 = svdup_n_f32(0); + svfloat32_t vsum25 = svdup_n_f32(0); + svfloat32_t vsum26 = svdup_n_f32(0); + svfloat32_t vsum27 = svdup_n_f32(0); + svfloat32_t vsum28 = svdup_n_f32(0); + svfloat32_t vsum29 = svdup_n_f32(0); + svfloat32_t vsum30 = svdup_n_f32(0); + svfloat32_t vsum31 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::Half* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])))), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])))), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])))), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])))), + vsum3); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[4 * vLen])))), + vsum4); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[5 * vLen])))), + vsum5); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[6 * vLen])))), + vsum6); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[7 * vLen])))), + vsum7); + vsum8 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[8 * vLen])))), + vsum8); + vsum9 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[9 * vLen])))), + vsum9); + vsum10 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[10 * vLen])))), + vsum10); + vsum11 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[11 * vLen])))), + vsum11); + vsum12 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[12 * vLen])))), + vsum12); + vsum13 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[13 * vLen])))), + vsum13); + vsum14 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[14 * vLen])))), + vsum14); + vsum15 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[15 * vLen])))), + vsum15); + vsum16 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[16 * vLen])))), + vsum16); + vsum17 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[17 * vLen])))), + vsum17); + vsum18 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[18 * vLen])))), + vsum18); + vsum19 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[19 * vLen])))), + vsum19); + vsum20 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[20 * vLen])))), + vsum20); + 
vsum21 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[21 * vLen])))), + vsum21); + vsum22 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[22 * vLen])))), + vsum22); + vsum23 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[23 * vLen])))), + vsum23); + vsum24 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[24 * vLen])))), + vsum24); + vsum25 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[25 * vLen])))), + vsum25); + vsum26 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[26 * vLen])))), + vsum26); + vsum27 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[27 * vLen])))), + vsum27); + vsum28 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[28 * vLen])))), + vsum28); + vsum29 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[29 * vLen])))), + vsum29); + vsum30 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[30 * vLen])))), + vsum30); + vsum31 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[31 * vLen])))), + vsum31); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); + svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); + svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); + svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); + svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); + 
svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); + svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); + svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); + svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); + svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); + svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); + svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); + svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); + svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); + svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); + svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + svst1_f32(svAll, &op[16 * vLen], vsum16); + svst1_f32(svAll, &op[17 * vLen], vsum17); + svst1_f32(svAll, &op[18 * vLen], vsum18); + svst1_f32(svAll, &op[19 * vLen], vsum19); + svst1_f32(svAll, &op[20 * vLen], vsum20); + svst1_f32(svAll, &op[21 * vLen], vsum21); + svst1_f32(svAll, &op[22 * vLen], vsum22); + svst1_f32(svAll, &op[23 * vLen], vsum23); + svst1_f32(svAll, &op[24 * vLen], vsum24); + svst1_f32(svAll, &op[25 * vLen], vsum25); + svst1_f32(svAll, &op[26 * vLen], vsum26); + svst1_f32(svAll, &op[27 * vLen], vsum27); + svst1_f32(svAll, &op[28 * vLen], vsum28); + svst1_f32(svAll, &op[29 * vLen], vsum29); + svst1_f32(svAll, &op[30 * vLen], vsum30); + svst1_f32(svAll, &op[31 * vLen], vsum31); + } + } + } else if (block_size == 16 * vLen) { + // unrolling 16 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::Half* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])))), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])))), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])))), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])))), + vsum3); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[4 * vLen])))), + vsum4); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[5 * vLen])))), + vsum5); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[6 * vLen])))), + vsum6); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[7 * vLen])))), + vsum7); + vsum8 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[8 * vLen])))), + vsum8); + vsum9 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[9 * vLen])))), + vsum9); + vsum10 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[10 * vLen])))), + vsum10); + vsum11 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[11 * vLen])))), + vsum11); + vsum12 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[12 * vLen])))), + vsum12); + vsum13 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[13 * vLen])))), + vsum13); + vsum14 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[14 * vLen])))), + vsum14); + vsum15 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[15 * vLen])))), + vsum15); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, 
&op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + } + } + } else if (block_size == 8 * vLen) { + // unrolling 8 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::Half* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])))), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])))), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])))), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])))), + vsum3); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[4 * vLen])))), + vsum4); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[5 * vLen])))), + vsum5); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[6 * vLen])))), + vsum6); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[7 * vLen])))), + vsum7); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + } + } + } else if (block_size == 4 * vLen) { + // unrolling 4 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::Half* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])))), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])))), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])))), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])))), + vsum3); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + } + } + } else if (block_size == 2 * vLen) { + // unrolling 2 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::Half* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])))), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])))), + vsum1); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + } + } + } else { + // generic code: + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; + } + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::Half* ip = &input[idx * block_size]; + svbool_t pg; + for (int64_t k = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); + k += vLen) { + svst1_f32( + pg, + &op[k], + svmad_f32_x( + pg, + vwgt, + svcvt_f32_f16_x( + pg, + svreinterpret_f16_u32(svld1uh_u32( + pg, reinterpret_cast<const uint16_t*>(&ip[k])))), + svld1_f32(pg, &op[k]))); + } + + ++pos; + } + const int64_t length = end_offset - start_offset; + + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svbool_t pg; + for (int64_t j = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); + j += vLen) { + svst1_f32( + pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); + } + } + } + return pos == index_size; +} +bool EmbeddingLookupIdx_int32_t_half_float_false__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::Half* input, + const int32_t* indices, + const int32_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int32_t_half_float__sve<false>( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} +bool EmbeddingLookupIdx_int32_t_half_float_true__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::Half* input, + const int32_t* indices, + const int32_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int32_t_half_float__sve<true>( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} + +template <bool IS_WEIGHT_POSITIONAL> +static bool EmbeddingLookupIdx_int64_t_half_float__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::Half* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + const svbool_t svAll = svptrue_b32(); + const auto vLen = static_cast<int64_t>(svcntw()); + int64_t pos = 0; + if (block_size == 32 * vLen) { + // unrolling 32 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + svfloat32_t vsum16 = svdup_n_f32(0); + svfloat32_t vsum17 = svdup_n_f32(0); + svfloat32_t vsum18 = svdup_n_f32(0); + svfloat32_t vsum19 = svdup_n_f32(0); + svfloat32_t vsum20 = svdup_n_f32(0); + svfloat32_t vsum21 = svdup_n_f32(0); + svfloat32_t vsum22 = svdup_n_f32(0); + svfloat32_t vsum23 = svdup_n_f32(0); + svfloat32_t vsum24
= svdup_n_f32(0); + svfloat32_t vsum25 = svdup_n_f32(0); + svfloat32_t vsum26 = svdup_n_f32(0); + svfloat32_t vsum27 = svdup_n_f32(0); + svfloat32_t vsum28 = svdup_n_f32(0); + svfloat32_t vsum29 = svdup_n_f32(0); + svfloat32_t vsum30 = svdup_n_f32(0); + svfloat32_t vsum31 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::Half* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])))), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])))), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])))), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])))), + vsum3); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[4 * vLen])))), + vsum4); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[5 * vLen])))), + vsum5); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[6 * vLen])))), + vsum6); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[7 * vLen])))), + vsum7); + vsum8 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[8 * vLen])))), + vsum8); + vsum9 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[9 * vLen])))), + vsum9); + vsum10 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[10 * vLen])))), + vsum10); + vsum11 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[11 * vLen])))), + vsum11); + vsum12 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[12 * vLen])))), + vsum12); + vsum13 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[13 * vLen])))), + vsum13); + vsum14 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[14 * vLen])))), + vsum14); + vsum15 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[15 * vLen])))), + vsum15); + vsum16 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[16 * vLen])))), + vsum16); + vsum17 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, 
+ svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[17 * vLen])))), + vsum17); + vsum18 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[18 * vLen])))), + vsum18); + vsum19 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[19 * vLen])))), + vsum19); + vsum20 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[20 * vLen])))), + vsum20); + vsum21 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[21 * vLen])))), + vsum21); + vsum22 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[22 * vLen])))), + vsum22); + vsum23 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[23 * vLen])))), + vsum23); + vsum24 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[24 * vLen])))), + vsum24); + vsum25 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[25 * vLen])))), + vsum25); + vsum26 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[26 * vLen])))), + vsum26); + vsum27 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[27 * vLen])))), + vsum27); + vsum28 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[28 * vLen])))), + vsum28); + vsum29 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[29 * vLen])))), + vsum29); + vsum30 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[30 * vLen])))), + vsum30); + vsum31 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[31 * vLen])))), + vsum31); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + 
svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); + svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); + svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); + svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); + svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); + svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); + svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); + svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); + svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); + svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); + svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); + svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); + svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); + svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); + svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); + svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + svst1_f32(svAll, &op[16 * vLen], vsum16); + svst1_f32(svAll, &op[17 * vLen], vsum17); + svst1_f32(svAll, &op[18 * vLen], vsum18); + svst1_f32(svAll, &op[19 * vLen], vsum19); + svst1_f32(svAll, &op[20 * vLen], vsum20); + svst1_f32(svAll, &op[21 * vLen], vsum21); + svst1_f32(svAll, &op[22 * vLen], vsum22); + svst1_f32(svAll, &op[23 * vLen], vsum23); + svst1_f32(svAll, &op[24 * vLen], vsum24); + svst1_f32(svAll, &op[25 * vLen], vsum25); + svst1_f32(svAll, &op[26 * vLen], vsum26); + svst1_f32(svAll, &op[27 * vLen], vsum27); + svst1_f32(svAll, &op[28 * vLen], vsum28); + svst1_f32(svAll, &op[29 * vLen], vsum29); + svst1_f32(svAll, &op[30 * vLen], vsum30); + svst1_f32(svAll, &op[31 * vLen], vsum31); + } + } + } else if (block_size == 16 * vLen) { + // unrolling 16 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = 
svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::Half* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])))), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])))), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])))), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])))), + vsum3); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[4 * vLen])))), + vsum4); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[5 * vLen])))), + vsum5); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[6 * vLen])))), + vsum6); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[7 * vLen])))), + vsum7); + vsum8 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[8 * vLen])))), + vsum8); + vsum9 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[9 * vLen])))), + vsum9); + vsum10 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[10 * vLen])))), + vsum10); + vsum11 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[11 * vLen])))), + vsum11); + vsum12 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[12 * vLen])))), + vsum12); + vsum13 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[13 * vLen])))), + vsum13); + vsum14 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[14 * vLen])))), + vsum14); + vsum15 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[15 * vLen])))), + vsum15); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, 
vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + } + } + } else if (block_size == 8 * vLen) { + // unrolling 8 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::Half* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])))), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])))), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])))), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])))), + vsum3); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[4 * vLen])))), + vsum4); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[5 * vLen])))), + vsum5); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[6 * vLen])))), + vsum6); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[7 * vLen])))), + vsum7); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + } + } + } else if (block_size == 4 * vLen) { + // unrolling 4 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::Half* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])))), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])))), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])))), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])))), + vsum3); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + } + } + } else if (block_size == 2 * vLen) { + // unrolling 2 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::Half* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])))), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])))), + vsum1); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + } + } + } else { + // generic code: + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; + } + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::Half* ip = &input[idx * block_size]; + svbool_t pg; + for (int64_t k = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); + k += vLen) { + svst1_f32( + pg, + &op[k], + svmad_f32_x( + pg, + vwgt, + svcvt_f32_f16_x( + pg, + svreinterpret_f16_u32(svld1uh_u32( + pg, reinterpret_cast<const uint16_t*>(&ip[k])))), + svld1_f32(pg, &op[k]))); + } + + ++pos; + } + const int64_t length = end_offset - start_offset; + + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svbool_t pg; + for (int64_t j = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); + j += vLen) { + svst1_f32( + pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); + } + } + } + return pos == index_size; +} +bool EmbeddingLookupIdx_int64_t_half_float_false__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::Half* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int64_t_half_float__sve<false>( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} +bool EmbeddingLookupIdx_int64_t_half_float_true__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::Half* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int64_t_half_float__sve<true>( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} + +template <bool IS_WEIGHT_POSITIONAL> +static bool EmbeddingLookupIdx_int32_t_bfloat16_float__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::BFloat16* input, + const int32_t* indices, + const int32_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + const svbool_t svAll = svptrue_b32(); + const auto vLen = static_cast<int64_t>(svcntw()); + int64_t pos = 0; + if (block_size == 32 * vLen) { + // unrolling 32 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + svfloat32_t vsum16 = svdup_n_f32(0); + svfloat32_t vsum17 = svdup_n_f32(0); + svfloat32_t vsum18 = svdup_n_f32(0); + svfloat32_t vsum19 = svdup_n_f32(0); + svfloat32_t vsum20 = svdup_n_f32(0); + svfloat32_t vsum21 = svdup_n_f32(0); + svfloat32_t vsum22 = svdup_n_f32(0); + svfloat32_t vsum23 = svdup_n_f32(0); +
svfloat32_t vsum24 = svdup_n_f32(0); + svfloat32_t vsum25 = svdup_n_f32(0); + svfloat32_t vsum26 = svdup_n_f32(0); + svfloat32_t vsum27 = svdup_n_f32(0); + svfloat32_t vsum28 = svdup_n_f32(0); + svfloat32_t vsum29 = svdup_n_f32(0); + svfloat32_t vsum30 = svdup_n_f32(0); + svfloat32_t vsum31 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::BFloat16* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])), + 16)), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])), + 16)), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])), + 16)), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])), + 16)), + vsum3); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[4 * vLen])), + 16)), + vsum4); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[5 * vLen])), + 16)), + vsum5); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[6 * vLen])), + 16)), + vsum6); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[7 * vLen])), + 16)), + vsum7); + vsum8 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[8 * vLen])), + 16)), + vsum8); + vsum9 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[9 * vLen])), + 16)), + vsum9); + vsum10 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[10 * vLen])), + 16)), + vsum10); + vsum11 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[11 * vLen])), + 16)), + vsum11); + vsum12 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[12 * vLen])), + 16)), + vsum12); + vsum13 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[13 * vLen])), + 16)), + vsum13); + vsum14 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[14 * vLen])), + 16)), + vsum14); + vsum15 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[15 * vLen])), + 16)), + vsum15); + vsum16 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[16 * 
vLen])), + 16)), + vsum16); + vsum17 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[17 * vLen])), + 16)), + vsum17); + vsum18 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[18 * vLen])), + 16)), + vsum18); + vsum19 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[19 * vLen])), + 16)), + vsum19); + vsum20 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[20 * vLen])), + 16)), + vsum20); + vsum21 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[21 * vLen])), + 16)), + vsum21); + vsum22 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[22 * vLen])), + 16)), + vsum22); + vsum23 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[23 * vLen])), + 16)), + vsum23); + vsum24 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[24 * vLen])), + 16)), + vsum24); + vsum25 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[25 * vLen])), + 16)), + vsum25); + vsum26 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[26 * vLen])), + 16)), + vsum26); + vsum27 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[27 * vLen])), + 16)), + vsum27); + vsum28 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[28 * vLen])), + 16)), + vsum28); + vsum29 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[29 * vLen])), + 16)), + vsum29); + vsum30 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[30 * vLen])), + 16)), + vsum30); + vsum31 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[31 * vLen])), + 16)), + vsum31); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + 
svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); + svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); + svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); + svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); + svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); + svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); + svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); + svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); + svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); + svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); + svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); + svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); + svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); + svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); + svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); + svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + svst1_f32(svAll, &op[16 * vLen], vsum16); + svst1_f32(svAll, &op[17 * vLen], vsum17); + svst1_f32(svAll, &op[18 * vLen], vsum18); + svst1_f32(svAll, &op[19 * vLen], vsum19); + svst1_f32(svAll, &op[20 * vLen], vsum20); + svst1_f32(svAll, &op[21 * vLen], vsum21); + svst1_f32(svAll, &op[22 * vLen], vsum22); + svst1_f32(svAll, &op[23 * vLen], vsum23); + svst1_f32(svAll, &op[24 * vLen], vsum24); + svst1_f32(svAll, &op[25 * vLen], vsum25); + svst1_f32(svAll, &op[26 * vLen], vsum26); + svst1_f32(svAll, &op[27 * vLen], vsum27); + svst1_f32(svAll, &op[28 * vLen], vsum28); + svst1_f32(svAll, &op[29 * vLen], vsum29); + svst1_f32(svAll, &op[30 * vLen], vsum30); + svst1_f32(svAll, &op[31 * vLen], vsum31); + } + } + } else if (block_size == 16 * vLen) { + // unrolling 16 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + 
svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::BFloat16* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])), + 16)), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])), + 16)), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])), + 16)), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])), + 16)), + vsum3); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[4 * vLen])), + 16)), + vsum4); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[5 * vLen])), + 16)), + vsum5); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[6 * vLen])), + 16)), + vsum6); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[7 * vLen])), + 16)), + vsum7); + vsum8 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[8 * vLen])), + 16)), + vsum8); + vsum9 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[9 * vLen])), + 16)), + vsum9); + vsum10 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[10 * vLen])), + 16)), + vsum10); + vsum11 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[11 * vLen])), + 16)), + vsum11); + vsum12 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[12 * vLen])), + 16)), + vsum12); + vsum13 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[13 * vLen])), + 16)), + vsum13); + vsum14 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[14 * vLen])), + 16)), + vsum14); + vsum15 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[15 * vLen])), + 16)), + vsum15); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + 
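+ // When normalize_by_lengths is true, each accumulator is scaled by len_inv (1.0f / length) before the store, so the output row holds the mean of the gathered embedding rows rather than their sum.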
svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + } + } + } else if (block_size == 8 * vLen) { + // unrolling 8 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::BFloat16* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])), + 16)), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])), + 16)), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])), + 16)), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])), + 16)), + vsum3); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[4 * vLen])), + 16)), + vsum4); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[5 * vLen])), + 16)), + vsum5); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[6 * vLen])), + 16)), + vsum6); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[7 * vLen])), + 16)), + vsum7); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + } + } + } else if (block_size == 4 * vLen) { + // unrolling 4 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::BFloat16* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])), + 16)), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])), + 16)), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])), + 16)), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])), + 16)), + vsum3); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + } + } + } else if (block_size == 2 * vLen) { + // unrolling 2 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::BFloat16* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast<const uint16_t*>(&ip[0 * vLen])), + 16)), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast<const uint16_t*>(&ip[1 * vLen])), + 16)), + vsum1); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + } + } + } else { + // generic code: + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; + } + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::BFloat16* ip = &input[idx * block_size]; + svbool_t pg; + for (int64_t k = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); + k += vLen) { + svst1_f32( + pg, + &op[k], + svmad_f32_x( + pg, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + pg, + svld1uh_u32( + pg, reinterpret_cast<const uint16_t*>(&ip[k])), + 16)), + svld1_f32(pg, &op[k]))); + } + + ++pos; + } + const int64_t length = end_offset - start_offset; + + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svbool_t pg; + for (int64_t j = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); + j += vLen) { + svst1_f32( + pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); + } + } + } + } + return pos == index_size; +} +bool EmbeddingLookupIdx_int32_t_bfloat16_float_false__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::BFloat16* input, + const int32_t* indices, + const int32_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int32_t_bfloat16_float__sve<false>( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} +bool EmbeddingLookupIdx_int32_t_bfloat16_float_true__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::BFloat16* input, + const int32_t* indices, + const int32_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int32_t_bfloat16_float__sve<true>( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} + +template <bool IS_WEIGHT_POSITIONAL> +static bool EmbeddingLookupIdx_int64_t_bfloat16_float__sve( + const int64_t
block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::BFloat16* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + const svbool_t svAll = svptrue_b32(); + const auto vLen = static_cast(svcntw()); + int64_t pos = 0; + if (block_size == 32 * vLen) { + // unrolling 32 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + svfloat32_t vsum16 = svdup_n_f32(0); + svfloat32_t vsum17 = svdup_n_f32(0); + svfloat32_t vsum18 = svdup_n_f32(0); + svfloat32_t vsum19 = svdup_n_f32(0); + svfloat32_t vsum20 = svdup_n_f32(0); + svfloat32_t vsum21 = svdup_n_f32(0); + svfloat32_t vsum22 = svdup_n_f32(0); + svfloat32_t vsum23 = svdup_n_f32(0); + svfloat32_t vsum24 = svdup_n_f32(0); + svfloat32_t vsum25 = svdup_n_f32(0); + svfloat32_t vsum26 = svdup_n_f32(0); + svfloat32_t vsum27 = svdup_n_f32(0); + svfloat32_t vsum28 = svdup_n_f32(0); + svfloat32_t vsum29 = svdup_n_f32(0); + svfloat32_t vsum30 = svdup_n_f32(0); + svfloat32_t vsum31 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::BFloat16* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])), + 16)), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])), + 16)), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])), + 16)), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])), + 16)), + vsum3); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[4 * vLen])), + 16)), + vsum4); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[5 * vLen])), + 16)), + vsum5); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[6 * vLen])), + 16)), + vsum6); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[7 * vLen])), + 16)), + vsum7); + vsum8 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[8 * vLen])), + 16)), + vsum8); + vsum9 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[9 * vLen])), + 16)), + vsum9); + vsum10 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[10 * vLen])), + 16)), + vsum10); + vsum11 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[11 * vLen])), + 16)), + vsum11); + vsum12 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[12 * vLen])), + 16)), + vsum12); + vsum13 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[13 * vLen])), + 16)), + vsum13); + vsum14 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[14 * vLen])), + 16)), + vsum14); + vsum15 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[15 * vLen])), + 16)), + vsum15); + vsum16 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[16 * vLen])), + 16)), + vsum16); + vsum17 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[17 * vLen])), + 16)), + vsum17); + vsum18 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[18 * vLen])), + 16)), + vsum18); + vsum19 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[19 * vLen])), + 16)), + vsum19); + vsum20 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + 
svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[20 * vLen])), + 16)), + vsum20); + vsum21 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[21 * vLen])), + 16)), + vsum21); + vsum22 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[22 * vLen])), + 16)), + vsum22); + vsum23 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[23 * vLen])), + 16)), + vsum23); + vsum24 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[24 * vLen])), + 16)), + vsum24); + vsum25 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[25 * vLen])), + 16)), + vsum25); + vsum26 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[26 * vLen])), + 16)), + vsum26); + vsum27 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[27 * vLen])), + 16)), + vsum27); + vsum28 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[28 * vLen])), + 16)), + vsum28); + vsum29 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[29 * vLen])), + 16)), + vsum29); + vsum30 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[30 * vLen])), + 16)), + vsum30); + vsum31 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[31 * vLen])), + 16)), + vsum31); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); + svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); + svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); + svst1_f32(svAll, &op[19 
* vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); + svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); + svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); + svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); + svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); + svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); + svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); + svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); + svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); + svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); + svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); + svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); + svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + svst1_f32(svAll, &op[16 * vLen], vsum16); + svst1_f32(svAll, &op[17 * vLen], vsum17); + svst1_f32(svAll, &op[18 * vLen], vsum18); + svst1_f32(svAll, &op[19 * vLen], vsum19); + svst1_f32(svAll, &op[20 * vLen], vsum20); + svst1_f32(svAll, &op[21 * vLen], vsum21); + svst1_f32(svAll, &op[22 * vLen], vsum22); + svst1_f32(svAll, &op[23 * vLen], vsum23); + svst1_f32(svAll, &op[24 * vLen], vsum24); + svst1_f32(svAll, &op[25 * vLen], vsum25); + svst1_f32(svAll, &op[26 * vLen], vsum26); + svst1_f32(svAll, &op[27 * vLen], vsum27); + svst1_f32(svAll, &op[28 * vLen], vsum28); + svst1_f32(svAll, &op[29 * vLen], vsum29); + svst1_f32(svAll, &op[30 * vLen], vsum30); + svst1_f32(svAll, &op[31 * vLen], vsum31); + } + } + } else if (block_size == 16 * vLen) { + // unrolling 16 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::BFloat16* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])), + 16)), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])), + 16)), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])), + 16)), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])), + 16)), + vsum3); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[4 * vLen])), + 16)), + vsum4); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[5 * vLen])), + 16)), + vsum5); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[6 * vLen])), + 16)), + vsum6); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[7 * vLen])), + 16)), + vsum7); + vsum8 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[8 * vLen])), + 16)), + vsum8); + vsum9 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[9 * vLen])), + 16)), + vsum9); + vsum10 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[10 * vLen])), + 16)), + vsum10); + vsum11 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[11 * vLen])), + 16)), + vsum11); + vsum12 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[12 * vLen])), + 16)), + vsum12); + vsum13 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[13 * vLen])), + 16)), + vsum13); + vsum14 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[14 * vLen])), + 16)), + vsum14); + vsum15 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[15 * vLen])), + 16)), + vsum15); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * 
vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + } + } + } else if (block_size == 8 * vLen) { + // unrolling 8 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::BFloat16* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])), + 16)), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])), + 16)), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])), + 16)), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])), + 16)), + vsum3); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[4 * vLen])), + 16)), + vsum4); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[5 * vLen])), + 16)), + vsum5); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[6 * vLen])), + 16)), + vsum6); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[7 * vLen])), + 16)), + vsum7); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + } + } + } else if (block_size == 4 * vLen) { + // unrolling 4 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::BFloat16* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])), + 16)), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])), + 16)), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])), + 16)), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])), + 16)), + vsum3); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + } + } + } else if (block_size == 2 * vLen) { + // unrolling 2 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::BFloat16* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast<const uint16_t*>(&ip[0 * vLen])), + 16)), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast<const uint16_t*>(&ip[1 * vLen])), + 16)), + vsum1); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + } + } + } else { + // generic code: + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; + } + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::BFloat16* ip = &input[idx * block_size]; + svbool_t pg; + for (int64_t k = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); + k += vLen) { + svst1_f32( + pg, + &op[k], + svmad_f32_x( + pg, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + pg, + svld1uh_u32( + pg, reinterpret_cast<const uint16_t*>(&ip[k])), + 16)), + svld1_f32(pg, &op[k]))); + } + + ++pos; + } + const int64_t length = end_offset - start_offset; + + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svbool_t pg; + for (int64_t j = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); + j += vLen) { + svst1_f32( + pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); + } + } + } + } + return pos == index_size; +} +bool EmbeddingLookupIdx_int64_t_bfloat16_float_false__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::BFloat16* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int64_t_bfloat16_float__sve<false>( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} +bool EmbeddingLookupIdx_int64_t_bfloat16_float_true__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::BFloat16* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int64_t_bfloat16_float__sve<true>( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} + +template <bool IS_WEIGHT_POSITIONAL> +static bool EmbeddingLookupIdx_int32_t_uint8_t_float__sve( + const int64_t
block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const uint8_t* input, + const int32_t* indices, + const int32_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + const svbool_t svAll = svptrue_b32(); + const auto vLen = static_cast(svcntw()); + int64_t pos = 0; + if (block_size == 32 * vLen) { + // unrolling 32 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + svfloat32_t vsum16 = svdup_n_f32(0); + svfloat32_t vsum17 = svdup_n_f32(0); + svfloat32_t vsum18 = svdup_n_f32(0); + svfloat32_t vsum19 = svdup_n_f32(0); + svfloat32_t vsum20 = svdup_n_f32(0); + svfloat32_t vsum21 = svdup_n_f32(0); + svfloat32_t vsum22 = svdup_n_f32(0); + svfloat32_t vsum23 = svdup_n_f32(0); + svfloat32_t vsum24 = svdup_n_f32(0); + svfloat32_t vsum25 = svdup_n_f32(0); + svfloat32_t vsum26 = svdup_n_f32(0); + svfloat32_t vsum27 = svdup_n_f32(0); + svfloat32_t vsum28 = svdup_n_f32(0); + svfloat32_t vsum29 = svdup_n_f32(0); + svfloat32_t vsum30 = svdup_n_f32(0); + svfloat32_t vsum31 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + float bio{}; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + if (scale_bias) { + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + } + svfloat32_t vbio = svdup_n_f32(bio); + const svfloat32_t vwgt = svdup_n_f32(wgt); + const uint8_t* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), + svadd_f32_x(svAll, vsum0, vbio)); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), + svadd_f32_x(svAll, vsum1, vbio)); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), + svadd_f32_x(svAll, vsum2, vbio)); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), + svadd_f32_x(svAll, vsum3, vbio)); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), + svadd_f32_x(svAll, vsum4, vbio)); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), + svadd_f32_x(svAll, vsum5, vbio)); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), + svadd_f32_x(svAll, vsum6, vbio)); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), + svadd_f32_x(svAll, vsum7, vbio)); + vsum8 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[8 * vLen])), + svadd_f32_x(svAll, vsum8, vbio)); + vsum9 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[9 * vLen])), + svadd_f32_x(svAll, vsum9, vbio)); + vsum10 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[10 * vLen])), + svadd_f32_x(svAll, vsum10, vbio)); + vsum11 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[11 * vLen])), + svadd_f32_x(svAll, vsum11, vbio)); + vsum12 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[12 * vLen])), + svadd_f32_x(svAll, vsum12, vbio)); + vsum13 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[13 * vLen])), + svadd_f32_x(svAll, vsum13, vbio)); + vsum14 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[14 * vLen])), + svadd_f32_x(svAll, vsum14, vbio)); + vsum15 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[15 * vLen])), + svadd_f32_x(svAll, vsum15, vbio)); + vsum16 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[16 * vLen])), + svadd_f32_x(svAll, vsum16, vbio)); + vsum17 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[17 * vLen])), + svadd_f32_x(svAll, vsum17, vbio)); + vsum18 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[18 * vLen])), + svadd_f32_x(svAll, vsum18, vbio)); + vsum19 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[19 * vLen])), + svadd_f32_x(svAll, vsum19, vbio)); + vsum20 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[20 * vLen])), + svadd_f32_x(svAll, vsum20, vbio)); + vsum21 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[21 * vLen])), + svadd_f32_x(svAll, vsum21, vbio)); + vsum22 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[22 * vLen])), + svadd_f32_x(svAll, vsum22, vbio)); + vsum23 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, 
svld1ub_u32(svAll, &ip[23 * vLen])), + svadd_f32_x(svAll, vsum23, vbio)); + vsum24 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[24 * vLen])), + svadd_f32_x(svAll, vsum24, vbio)); + vsum25 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[25 * vLen])), + svadd_f32_x(svAll, vsum25, vbio)); + vsum26 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[26 * vLen])), + svadd_f32_x(svAll, vsum26, vbio)); + vsum27 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[27 * vLen])), + svadd_f32_x(svAll, vsum27, vbio)); + vsum28 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[28 * vLen])), + svadd_f32_x(svAll, vsum28, vbio)); + vsum29 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[29 * vLen])), + svadd_f32_x(svAll, vsum29, vbio)); + vsum30 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[30 * vLen])), + svadd_f32_x(svAll, vsum30, vbio)); + vsum31 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[31 * vLen])), + svadd_f32_x(svAll, vsum31, vbio)); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); + svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); + svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); + svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); + svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); + svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); + svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); + svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); + svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); + svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); + svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); + svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); + svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); + 
svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); + svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); + svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + svst1_f32(svAll, &op[16 * vLen], vsum16); + svst1_f32(svAll, &op[17 * vLen], vsum17); + svst1_f32(svAll, &op[18 * vLen], vsum18); + svst1_f32(svAll, &op[19 * vLen], vsum19); + svst1_f32(svAll, &op[20 * vLen], vsum20); + svst1_f32(svAll, &op[21 * vLen], vsum21); + svst1_f32(svAll, &op[22 * vLen], vsum22); + svst1_f32(svAll, &op[23 * vLen], vsum23); + svst1_f32(svAll, &op[24 * vLen], vsum24); + svst1_f32(svAll, &op[25 * vLen], vsum25); + svst1_f32(svAll, &op[26 * vLen], vsum26); + svst1_f32(svAll, &op[27 * vLen], vsum27); + svst1_f32(svAll, &op[28 * vLen], vsum28); + svst1_f32(svAll, &op[29 * vLen], vsum29); + svst1_f32(svAll, &op[30 * vLen], vsum30); + svst1_f32(svAll, &op[31 * vLen], vsum31); + } + } + } else if (block_size == 16 * vLen) { + // unrolling 16 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + float bio{}; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + if (scale_bias) { + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + } + svfloat32_t vbio = svdup_n_f32(bio); + const svfloat32_t vwgt = svdup_n_f32(wgt); + const uint8_t* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), + svadd_f32_x(svAll, vsum0, vbio)); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), + svadd_f32_x(svAll, vsum1, vbio)); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), + svadd_f32_x(svAll, vsum2, vbio)); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), + svadd_f32_x(svAll, vsum3, vbio)); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), + svadd_f32_x(svAll, vsum4, vbio)); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), + svadd_f32_x(svAll, vsum5, vbio)); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), + svadd_f32_x(svAll, vsum6, vbio)); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), + svadd_f32_x(svAll, vsum7, vbio)); + vsum8 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[8 * vLen])), + svadd_f32_x(svAll, vsum8, vbio)); + vsum9 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[9 * vLen])), + svadd_f32_x(svAll, vsum9, vbio)); + vsum10 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[10 * vLen])), + svadd_f32_x(svAll, vsum10, vbio)); + vsum11 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[11 * vLen])), + svadd_f32_x(svAll, vsum11, vbio)); + vsum12 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[12 * vLen])), + svadd_f32_x(svAll, vsum12, vbio)); + vsum13 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[13 * vLen])), + svadd_f32_x(svAll, vsum13, vbio)); + vsum14 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[14 * vLen])), + svadd_f32_x(svAll, vsum14, vbio)); + vsum15 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[15 * vLen])), + svadd_f32_x(svAll, vsum15, vbio)); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], 
svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + } + } + } else if (block_size == 8 * vLen) { + // unrolling 8 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + float bio{}; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + if (scale_bias) { + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + } + svfloat32_t vbio = svdup_n_f32(bio); + const svfloat32_t vwgt = svdup_n_f32(wgt); + const uint8_t* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), + svadd_f32_x(svAll, vsum0, vbio)); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), + svadd_f32_x(svAll, vsum1, vbio)); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), + svadd_f32_x(svAll, vsum2, vbio)); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), + svadd_f32_x(svAll, vsum3, vbio)); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), + svadd_f32_x(svAll, vsum4, vbio)); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), + svadd_f32_x(svAll, vsum5, vbio)); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), + svadd_f32_x(svAll, vsum6, vbio)); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), + svadd_f32_x(svAll, vsum7, vbio)); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + } + } + } else if (block_size == 4 * vLen) { + // unrolling 4 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + float bio{}; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + if (scale_bias) { + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + } + svfloat32_t vbio = svdup_n_f32(bio); + const svfloat32_t vwgt = svdup_n_f32(wgt); + const uint8_t* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), + svadd_f32_x(svAll, vsum0, vbio)); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), + svadd_f32_x(svAll, vsum1, vbio)); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), + svadd_f32_x(svAll, vsum2, vbio)); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), + svadd_f32_x(svAll, vsum3, vbio)); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + } + } + } else if (block_size == 2 * vLen) { + // unrolling 2 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + float bio{}; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + if (scale_bias) { + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + } + svfloat32_t vbio = svdup_n_f32(bio); + const svfloat32_t vwgt = svdup_n_f32(wgt); + const uint8_t* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), + svadd_f32_x(svAll, vsum0, vbio)); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), + svadd_f32_x(svAll, vsum1, vbio)); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + } + } + } else { + // generic code: + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; + } + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + // unimplemented + float wgt = 1.f; + float bio{}; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + if (scale_bias) { + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + } + svfloat32_t vbio = svdup_n_f32(bio); + const svfloat32_t vwgt = svdup_n_f32(wgt); + const uint8_t* ip = &input[idx * block_size]; + svbool_t pg; + for (int64_t k = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); + k += vLen) { + svst1_f32( + pg, + &op[k], + svmad_f32_x( + pg, + vwgt, + svcvt_f32_u32_x(pg, svld1ub_u32(pg, &ip[k])), + svadd_f32_x(pg, svld1_f32(pg, &op[k]), vbio))); + } + + ++pos; + } + const int64_t length = end_offset - start_offset; + + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svbool_t pg; + for (int64_t j = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); + j += vLen) { + svst1_f32( + pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); + } + } + } + } + return pos == index_size; +} +bool EmbeddingLookupIdx_int32_t_uint8_t_float_false__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const uint8_t* input, + const int32_t* indices, + const int32_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int32_t_uint8_t_float__sve( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} +bool EmbeddingLookupIdx_int32_t_uint8_t_float_true__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const uint8_t* input, + const int32_t* indices, + const int32_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int32_t_uint8_t_float__sve( + block_size, + output_size, + index_size, + 
data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} + +template +static bool EmbeddingLookupIdx_int64_t_uint8_t_float__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const uint8_t* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + const svbool_t svAll = svptrue_b32(); + const auto vLen = static_cast(svcntw()); + int64_t pos = 0; + if (block_size == 32 * vLen) { + // unrolling 32 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + svfloat32_t vsum16 = svdup_n_f32(0); + svfloat32_t vsum17 = svdup_n_f32(0); + svfloat32_t vsum18 = svdup_n_f32(0); + svfloat32_t vsum19 = svdup_n_f32(0); + svfloat32_t vsum20 = svdup_n_f32(0); + svfloat32_t vsum21 = svdup_n_f32(0); + svfloat32_t vsum22 = svdup_n_f32(0); + svfloat32_t vsum23 = svdup_n_f32(0); + svfloat32_t vsum24 = svdup_n_f32(0); + svfloat32_t vsum25 = svdup_n_f32(0); + svfloat32_t vsum26 = svdup_n_f32(0); + svfloat32_t vsum27 = svdup_n_f32(0); + svfloat32_t vsum28 = svdup_n_f32(0); + svfloat32_t vsum29 = svdup_n_f32(0); + svfloat32_t vsum30 = svdup_n_f32(0); + svfloat32_t vsum31 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + float bio{}; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + if (scale_bias) { + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + } + svfloat32_t vbio = svdup_n_f32(bio); + const svfloat32_t vwgt = svdup_n_f32(wgt); + const uint8_t* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), + svadd_f32_x(svAll, vsum0, vbio)); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), + svadd_f32_x(svAll, vsum1, vbio)); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), + svadd_f32_x(svAll, vsum2, vbio)); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), + svadd_f32_x(svAll, vsum3, vbio)); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), + svadd_f32_x(svAll, vsum4, vbio)); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), + svadd_f32_x(svAll, vsum5, vbio)); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), + svadd_f32_x(svAll, vsum6, vbio)); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), + svadd_f32_x(svAll, vsum7, vbio)); + vsum8 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[8 * vLen])), + svadd_f32_x(svAll, vsum8, vbio)); + vsum9 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[9 * vLen])), + svadd_f32_x(svAll, vsum9, vbio)); + vsum10 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[10 * vLen])), + svadd_f32_x(svAll, vsum10, vbio)); + vsum11 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[11 * vLen])), + svadd_f32_x(svAll, vsum11, vbio)); + vsum12 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[12 * vLen])), + svadd_f32_x(svAll, vsum12, vbio)); + vsum13 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[13 * vLen])), + svadd_f32_x(svAll, vsum13, vbio)); + vsum14 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[14 * vLen])), + svadd_f32_x(svAll, vsum14, vbio)); + vsum15 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[15 * vLen])), + svadd_f32_x(svAll, vsum15, vbio)); + vsum16 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[16 * vLen])), + svadd_f32_x(svAll, vsum16, vbio)); + vsum17 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[17 * vLen])), + svadd_f32_x(svAll, vsum17, vbio)); + vsum18 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[18 * vLen])), + svadd_f32_x(svAll, vsum18, vbio)); + vsum19 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[19 * vLen])), + svadd_f32_x(svAll, vsum19, vbio)); + vsum20 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[20 * vLen])), + svadd_f32_x(svAll, vsum20, vbio)); + vsum21 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[21 * vLen])), + svadd_f32_x(svAll, vsum21, vbio)); + vsum22 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[22 * vLen])), + svadd_f32_x(svAll, vsum22, vbio)); + vsum23 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, 
svld1ub_u32(svAll, &ip[23 * vLen])), + svadd_f32_x(svAll, vsum23, vbio)); + vsum24 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[24 * vLen])), + svadd_f32_x(svAll, vsum24, vbio)); + vsum25 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[25 * vLen])), + svadd_f32_x(svAll, vsum25, vbio)); + vsum26 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[26 * vLen])), + svadd_f32_x(svAll, vsum26, vbio)); + vsum27 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[27 * vLen])), + svadd_f32_x(svAll, vsum27, vbio)); + vsum28 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[28 * vLen])), + svadd_f32_x(svAll, vsum28, vbio)); + vsum29 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[29 * vLen])), + svadd_f32_x(svAll, vsum29, vbio)); + vsum30 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[30 * vLen])), + svadd_f32_x(svAll, vsum30, vbio)); + vsum31 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[31 * vLen])), + svadd_f32_x(svAll, vsum31, vbio)); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); + svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); + svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); + svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); + svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); + svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); + svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); + svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); + svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); + svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); + svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); + svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); + svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); + 
svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); + svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); + svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + svst1_f32(svAll, &op[16 * vLen], vsum16); + svst1_f32(svAll, &op[17 * vLen], vsum17); + svst1_f32(svAll, &op[18 * vLen], vsum18); + svst1_f32(svAll, &op[19 * vLen], vsum19); + svst1_f32(svAll, &op[20 * vLen], vsum20); + svst1_f32(svAll, &op[21 * vLen], vsum21); + svst1_f32(svAll, &op[22 * vLen], vsum22); + svst1_f32(svAll, &op[23 * vLen], vsum23); + svst1_f32(svAll, &op[24 * vLen], vsum24); + svst1_f32(svAll, &op[25 * vLen], vsum25); + svst1_f32(svAll, &op[26 * vLen], vsum26); + svst1_f32(svAll, &op[27 * vLen], vsum27); + svst1_f32(svAll, &op[28 * vLen], vsum28); + svst1_f32(svAll, &op[29 * vLen], vsum29); + svst1_f32(svAll, &op[30 * vLen], vsum30); + svst1_f32(svAll, &op[31 * vLen], vsum31); + } + } + } else if (block_size == 16 * vLen) { + // unrolling 16 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + float bio{}; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + if (scale_bias) { + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + } + svfloat32_t vbio = svdup_n_f32(bio); + const svfloat32_t vwgt = svdup_n_f32(wgt); + const uint8_t* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), + svadd_f32_x(svAll, vsum0, vbio)); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), + svadd_f32_x(svAll, vsum1, vbio)); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), + svadd_f32_x(svAll, vsum2, vbio)); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), + svadd_f32_x(svAll, vsum3, vbio)); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), + svadd_f32_x(svAll, vsum4, vbio)); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), + svadd_f32_x(svAll, vsum5, vbio)); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), + svadd_f32_x(svAll, vsum6, vbio)); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), + svadd_f32_x(svAll, vsum7, vbio)); + vsum8 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[8 * vLen])), + svadd_f32_x(svAll, vsum8, vbio)); + vsum9 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[9 * vLen])), + svadd_f32_x(svAll, vsum9, vbio)); + vsum10 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[10 * vLen])), + svadd_f32_x(svAll, vsum10, vbio)); + vsum11 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[11 * vLen])), + svadd_f32_x(svAll, vsum11, vbio)); + vsum12 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[12 * vLen])), + svadd_f32_x(svAll, vsum12, vbio)); + vsum13 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[13 * vLen])), + svadd_f32_x(svAll, vsum13, vbio)); + vsum14 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[14 * vLen])), + svadd_f32_x(svAll, vsum14, vbio)); + vsum15 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[15 * vLen])), + svadd_f32_x(svAll, vsum15, vbio)); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], 
svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + } + } + } else if (block_size == 8 * vLen) { + // unrolling 8 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + float bio{}; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + if (scale_bias) { + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + } + svfloat32_t vbio = svdup_n_f32(bio); + const svfloat32_t vwgt = svdup_n_f32(wgt); + const uint8_t* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), + svadd_f32_x(svAll, vsum0, vbio)); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), + svadd_f32_x(svAll, vsum1, vbio)); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), + svadd_f32_x(svAll, vsum2, vbio)); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), + svadd_f32_x(svAll, vsum3, vbio)); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), + svadd_f32_x(svAll, vsum4, vbio)); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), + svadd_f32_x(svAll, vsum5, vbio)); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), + svadd_f32_x(svAll, vsum6, vbio)); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), + svadd_f32_x(svAll, vsum7, vbio)); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + } + } + } else if (block_size == 4 * vLen) { + // unrolling 4 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + float bio{}; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + if (scale_bias) { + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + } + svfloat32_t vbio = svdup_n_f32(bio); + const svfloat32_t vwgt = svdup_n_f32(wgt); + const uint8_t* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), + svadd_f32_x(svAll, vsum0, vbio)); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), + svadd_f32_x(svAll, vsum1, vbio)); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), + svadd_f32_x(svAll, vsum2, vbio)); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), + svadd_f32_x(svAll, vsum3, vbio)); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + } + } + } else if (block_size == 2 * vLen) { + // unrolling 2 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + float bio{}; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; + } + if (scale_bias) { + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + } + svfloat32_t vbio = svdup_n_f32(bio); + const svfloat32_t vwgt = svdup_n_f32(wgt); + const uint8_t* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), + svadd_f32_x(svAll, vsum0, vbio)); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), + svadd_f32_x(svAll, vsum1, vbio)); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + } + } + } else { + // generic code: + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; + } + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + // unimplemented + float wgt = 1.f; + float bio{}; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + if (scale_bias) { + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + } + svfloat32_t vbio = svdup_n_f32(bio); + const svfloat32_t vwgt = svdup_n_f32(wgt); + const uint8_t* ip = &input[idx * block_size]; + svbool_t pg; + for (int64_t k = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); + k += vLen) { + svst1_f32( + pg, + &op[k], + svmad_f32_x( + pg, + vwgt, + svcvt_f32_u32_x(pg, svld1ub_u32(pg, &ip[k])), + svadd_f32_x(pg, svld1_f32(pg, &op[k]), vbio))); + } + + ++pos; + } + const int64_t length = end_offset - start_offset; + + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svbool_t pg; + for (int64_t j = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); + j += vLen) { + svst1_f32( + pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); + } + } + } + } + return pos == index_size; +} +bool EmbeddingLookupIdx_int64_t_uint8_t_float_false__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const uint8_t* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int64_t_uint8_t_float__sve( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} +bool EmbeddingLookupIdx_int64_t_uint8_t_float_true__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const uint8_t* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int64_t_uint8_t_float__sve( + block_size, + output_size, + index_size, + 
data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} + +} // namespace caffe2 diff --git a/caffe2/perfkernels/sve_emblookup_codegen.py b/caffe2/perfkernels/sve_emblookup_codegen.py new file mode 100644 index 000000000000..02f010ccc250 --- /dev/null +++ b/caffe2/perfkernels/sve_emblookup_codegen.py @@ -0,0 +1,408 @@ +# mypy: allow-untyped-defs +import argparse +import sys + +# Unroll loops when block_size is a multiple of vector length. +def unroll(num_unrolls, IndexType, InType, OutType, use_weights): + def compute(regid, InType, use_weights): + code = [] + + if InType == "float": + code.append( + f" vsum{regid} =\n" + " svmad_f32_x(" + f"svAll, vwgt, svld1_f32(svAll, &ip[{regid} * vLen])," + f" vsum{regid});" + ) + elif InType == "at::Half": + code.append( + f" vsum{regid} = svmad_f32_x(\n" + " svAll,\n" + " vwgt,\n" + " svcvt_f32_f16_x(\n" + " svAll,\n" + " svreinterpret_f16_u32(svld1uh_u32(\n" + " svAll, reinterpret_cast(" + f"&ip[{regid} * vLen])))),\n" # noqa + f" vsum{regid});" + ) + elif InType == "at::BFloat16": + code.append( + f" vsum{regid} = svmad_f32_x(\n" + " svAll,\n" + " vwgt,\n" + " svreinterpret_f32_u32(svlsl_n_u32_x(\n" + " svAll,\n" + " svld1uh_u32(\n" + " svAll, reinterpret_cast(" + f"&ip[{regid} * vLen])),\n" + " 16)),\n" # noqa + f" vsum{regid});" + ) + elif InType == "uint8_t": + code.append( + f" vsum{regid} = svmad_f32_x(\n" + " svAll,\n" + " vwgt,\n" + " svcvt_f32_u32_x(svAll," + f" svld1ub_u32(svAll, &ip[{regid} * vLen])),\n" # noqa + f" svadd_f32_x(svAll, vsum{regid}, vbio));" + ) + else: + raise ValueError(f"Unknown datatype \"{InType}\"") + + return code + + code = [] + code.append(f" // unrolling {num_unrolls} times") + + code.append(" for (int64_t i = 0; i < output_size; ++i) {") + + code.append(" " + OutType + "* const op = &out[i * block_size];") + code.append( + " if (pos != offsets[i] - offsets[0]) {\n" + + " return false;\n" + + " }" + ) + + # Initialise vector sum registers + for i in range(num_unrolls): + code.append(f" svfloat32_t vsum{i} = svdup_n_f32(0);") + + # inner loop + code.append("""\ + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1];""") + code.append( + " for (" + + "int64_t" + + " j = start_offset; j < end_offset; ++j) {" # noqa + ) + + code.append(" const auto idx = indices[pos];") + code.append( + " if (idx < 0 || idx >= data_size) {\n" + + " return false;\n" + + " }" + ) + + if InType == "uint8_t": + code.append(" " + OutType + " wgt = 1.f;") + code.append(" " + OutType + " bio{};") + code.append(" if (weights) {") + code.append( + " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa + ) + code.append(" }") + code.append(" if (scale_bias) {") + code.append(" bio = wgt * scale_bias[2 * idx + 1];") + code.append(" wgt = wgt * scale_bias[2 * idx];") + code.append(" }") + code.append(" svfloat32_t vbio = svdup_n_f32(bio);") + else: + code.append(" " + OutType + " wgt = 1.f;") + code.append(" if (weights) {") + code.append( + " wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos];" # noqa + ) + code.append(" }") + + code.append(" const svfloat32_t vwgt = svdup_n_f32(wgt);") + code.append(f" const {InType}* const ip = &input[idx * block_size];") + code.append(" // weight * input + out") + + for i in range(num_unrolls): + code.extend(compute(i, InType, use_weights)) + + code.append(" ++pos;") + code.append(" }") + + code.append(" // Normalisation") + code.append(" const int64_t length = end_offset - start_offset;") + code.append(" if (normalize_by_lengths && length != 0) {") + code.append(" const float len_inv = 1.0f / length;") + code.append(" const svfloat32_t vlen_inv = svdup_n_f32(len_inv);") + + for i in range(num_unrolls): + code.append(f" svst1_f32(svAll, &op[{i} * vLen]," + + f" svmul_f32_x(svAll, vsum{i}, vlen_inv));") + + code.append(" } else {") + # inv of length + for i in range(num_unrolls): + code.append(f" svst1_f32(svAll, &op[{i} * vLen], vsum{i});") + + code.append(" }") + code.append(" }") + return code + + +# Handle the case where block_size is not a multiple of vector length. +def generic(IndexType, InType, OutType, use_weights): + def compute(InType, use_weights): + code = [] + if InType == "float": + code.append( + " svst1_f32(\n" + " pg,\n" + " &op[k],\n" + " svmad_f32_x(\n" + " pg, vwgt, svld1_f32(pg, &ip[k])," + " svld1_f32(pg, &op[k])));" + ) + elif InType == "at::Half": + code.append( + " svst1_f32(\n" + " pg,\n" + " &op[k],\n" + " svmad_f32_x(\n" + " pg,\n" + " vwgt,\n" + " svcvt_f32_f16_x(\n" + " pg,\n" + " svreinterpret_f16_u32(svld1uh_u32(\n" + " pg," + " reinterpret_cast(&ip[k])))),\n" + " svld1_f32(pg, &op[k])));" + ) + elif InType == "at::BFloat16": + code.append( + " svst1_f32(\n" + " pg,\n" + " &op[k],\n" + " svmad_f32_x(\n" + " pg,\n" + " vwgt,\n" + " svreinterpret_f32_u32(svlsl_n_u32_x(\n" + " pg,\n" + " svld1uh_u32(\n" + " pg," + " reinterpret_cast(&ip[k])),\n" + " 16)),\n" + " svld1_f32(pg, &op[k])));" + ) + elif InType == "uint8_t": + code.append( + " svst1_f32(\n" + " pg,\n" + " &op[k],\n" + " svmad_f32_x(\n" + " pg,\n" + " vwgt,\n" + " svcvt_f32_u32_x(pg," + " svld1ub_u32(pg, &ip[k])),\n" # noqa + " svadd_f32_x(pg," + " svld1_f32(pg, &op[k]), vbio)));" + ) + else: + raise ValueError(f"Unknown datatype \"{InType}\"") + + return code + + code = [] + + code.append( + " for (int64_t i = 0; i < output_size; ++i) {" + ) + + code.append(" " + OutType + "* const op = &out[i * block_size];") + + # initialize to 0 + code.append(" memset(op, 0, sizeof(float) * block_size);") + + # inner loop + code.append( + " if (pos != offsets[i] - offsets[0]) {\n" + + " return false;\n" + + " }" + ) + code.append( + " int64_t start_offset = offsets[i];\n" + + " int64_t end_offset = offsets[i + 1];" + ) + code.append( + " for (" + + "int64_t" + + " j = start_offset; j < end_offset; ++j) {" # noqa + ) + + code.append(" const auto idx = indices[pos];") + code.append( + " if (idx < 0 || idx >= data_size) {\n" + + " return false;\n" + + " }" + ) + + if InType == "uint8_t": + code.append(" // unimplemented") + code.append(" " + OutType + " wgt = 1.f;") + code.append(" " + OutType + " bio{};") + code.append(" if (weights) {") + code.append( + " wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos];" # noqa + ) + code.append(" }") + code.append(" if (scale_bias) {") + code.append(" bio = wgt * scale_bias[2 * idx + 1];") + code.append(" wgt = wgt * scale_bias[2 * idx];") + code.append(" }") + code.append(" svfloat32_t vbio = svdup_n_f32(bio);") + else: + code.append(" " + OutType + " wgt = 1.f;") + code.append(" if (weights) {") + code.append( + " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa + ) + code.append(" }") + + code.append(" const svfloat32_t vwgt = svdup_n_f32(wgt);") + code.append(f" const {InType}* ip = &input[idx * block_size];") + + # compute and store main loop + code.append(" svbool_t pg;") + code.append(" for (int64_t k = 0;") + code.append(" svptest_first(svAll, pg = svwhilelt_b32_s64(" + + "k, block_size));") + code.append(" k += vLen) {") + code.extend(compute(InType, use_weights)) + code.append(" }\n") + code.append(" ++pos;") + code.append(" }") + + code.append(" const int64_t length = end_offset - start_offset;\n") + code.append(" if (normalize_by_lengths && length != 0) {") + code.append(" const float len_inv = 1.0f / length;") + code.append(" svfloat32_t vlen_inv = svdup_n_f32(len_inv);") + code.append(" svbool_t pg;") + code.append(" for (int64_t j = 0;\n" + " svptest_first(svAll, pg = svwhilelt_b32_s64(" + "j, block_size));") + code.append(" j += vLen) {") + code.append( + " svst1_f32(\n" + " pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv));" + ) + code.append(" }") + code.append(" }") + code.append(" }") + return code + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("-f", "--filename", help="file name") + opts = parser.parse_args() + if opts.filename: + filename = opts.filename + else: + filename = "embedding_lookup_idx_sve.cc" + + options = [ + ["int32_t", "int32_t", "float", "float", "float", "float"], + ["int64_t", "int64_t", "float", "float", "float", "float"], + ["int32_t", "int32_t", "half", "at::Half", "float", "float"], + ["int64_t", "int64_t", "half", "at::Half", "float", "float"], + ["int32_t", "int32_t", "bfloat16", "at::BFloat16", "float", "float"], + ["int64_t", "int64_t", "bfloat16", "at::BFloat16", "float", "float"], + ["int32_t", "int32_t", "uint8_t", "uint8_t", "float", "float"], + ["int64_t", "int64_t", "uint8_t", "uint8_t", "float", "float"], + ] + + code = [] + # includes + code.append("//// --------------------------") + code.append("//// ATTENTION:") + code.append("//// THIS CODE IS AUTOGENERATED") + code.append(f"//// BY {' '.join(sys.argv)}") + code.append("//// DO NOT MODIFY!!!") + code.append("//// --------------------------\n") + + code.append("#include ") + code.append("#include ") + code.append("#include ") + code.append("#include ") + code.append("#include ") + + code.append("namespace caffe2 {\n") + for o in options: + [IndexTypeName, IndexType, InTypeName, InType, OutTypeName, OutType] = o + + code.append("template ") + fn_base = f"EmbeddingLookupIdx_{IndexTypeName}_{InTypeName}_{OutTypeName}" + + suffix = "__sve" + fn = "static bool " + fn_base + suffix + code.append(fn + "(") + + args = [] + args.append(" const int64_t block_size,") + args.append(" const int64_t output_size,") + args.append(" const int64_t index_size,") + args.append(" const int64_t data_size,") + args.append(" const " + InType + "* input,") + args.append(" const " + IndexType + "* indices,") + args.append(" const " + IndexType + "* offsets,") + args.append(" const float* weights,") + args.append(" const float* scale_bias,") + args.append(" bool normalize_by_lengths,") 
+ args.append(" " + OutType + "* out) {") + code += args + + code.append(" const svbool_t svAll = svptrue_b32();") + code.append(" const auto vLen = static_cast(svcntw());") + code.append(" int64_t pos = 0;") + + code.append(" if (block_size == 32 * vLen) {") + code += unroll(32, IndexType, InType, OutType, True) + code.append(" } else if (block_size == 16 * vLen) {") + code += unroll(16, IndexType, InType, OutType, True) + code.append(" } else if (block_size == 8 * vLen) {") + code += unroll(8, IndexType, InType, OutType, True) + code.append(" } else if (block_size == 4 * vLen) {") + code += unroll(4, IndexType, InType, OutType, True) + code.append(" } else if (block_size == 2 * vLen) {") + code += unroll(2, IndexType, InType, OutType, True) + code.append(" } else {") + code.append(" // generic code:") + code += generic(IndexType, InType, OutType, True) + code.append(" }") + code.append(" return pos == index_size;") + + code.append("}") + + for is_weight_positional in ["false", "true"]: + code.append("bool " + fn_base + "_" + is_weight_positional + suffix + "(") + code += args + + # Resolve the Lint warnings: Limit of 80 characters in one line. + extra_space = "\n " + ret_string = " return " + fn_base + suffix \ + + "<" + is_weight_positional + ">(" + if len(ret_string) <= 80: + code.append(ret_string) + else: + code.append(" return " + fn_base + suffix + "<" + extra_space + is_weight_positional + ">(") + + code.append(" block_size,") + code.append(" output_size,") + code.append(" index_size,") + code.append(" data_size,") + code.append(" input,") + code.append(" indices,") + code.append(" offsets,") + code.append(" weights,") + code.append(" scale_bias,") + code.append(" normalize_by_lengths,") + code.append(" out);") + code.append("}") + + code.append("") + + code.append("} // namespace caffe2") + + with open(filename, "w") as fout: + fout.write("\n".join(code) + "\n") + + print("Created " + filename) + +if __name__ == "__main__": + main() diff --git a/cmake/MiscCheck.cmake b/cmake/MiscCheck.cmake index 10fa810b8fdf..74fc1487333a 100644 --- a/cmake/MiscCheck.cmake +++ b/cmake/MiscCheck.cmake @@ -101,6 +101,16 @@ endif() # Also, we will turn off deprecated-declarations # due to protobuf. +# ---[ Check if the compiler has SVE support. +find_package(ARM) # checks SVE +if(CXX_SVE_FOUND) + message(STATUS "Compiler supports SVE extension. Will build perfkernels.") + # Also see CMakeLists.txt under caffe2/perfkernels. + add_compile_definitions(CAFFE2_PERF_WITH_SVE=1) +else() + message(STATUS "Compiler does not support SVE extension. Will not build perfkernels.") +endif() + if(IOS AND (${IOS_ARCH} MATCHES "armv7*")) add_definitions("-mfpu=neon-fp16") add_definitions("-arch" ${IOS_ARCH}) diff --git a/cmake/public/ComputeLibrary.cmake b/cmake/public/ComputeLibrary.cmake index d0b3b56ff531..e18527ce65b0 100644 --- a/cmake/public/ComputeLibrary.cmake +++ b/cmake/public/ComputeLibrary.cmake @@ -21,10 +21,10 @@ if("${ACL_VERSION_FILE}" STREQUAL "") message(WARNING "Build may fail: Could not determine ACL version (minimum required is ${ACL_MINIMUM_VERSION})") else() file(READ ${ACL_VERSION_FILE} ACL_VERSION_STRING) - string(REGEX MATCH "v([0-9]+\\.[0-9]+)" ACL_VERSION ${ACL_VERSION_STRING}) + string(REGEX MATCH "v([0-9]+\\.[0-9]+)" ACL_VERSION "${ACL_VERSION_STRING}") set(ACL_VERSION "${CMAKE_MATCH_1}") - if(${ACL_VERSION} VERSION_EQUAL "0.0") + if("${ACL_VERSION}" VERSION_EQUAL "0.0") # Unreleased ACL versions come with version string "v0.0-unreleased", and may not be compatible with oneDNN. 
# It is recommended to use the latest release of ACL.
    message(WARNING "Build may fail: Using unreleased ACL version (minimum required is ${ACL_MINIMUM_VERSION})")
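
For reference, the sketch below shows one way the new uint8/int64 SVE entry point could be exercised once a build with CXX_SVE_FOUND links caffe2/perfkernels/embedding_lookup_idx_sve.cc. Only the function signature is taken from the patch above; the standalone main(), the table contents, the index and offset values, and the chosen block_size are illustrative assumptions, not part of this change.

// Minimal usage sketch (assumed harness, not part of the patch).
#include <cstdint>
#include <vector>

namespace caffe2 {
// Declaration copied from the signature generated in embedding_lookup_idx_sve.cc.
bool EmbeddingLookupIdx_int64_t_uint8_t_float_false__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out);
} // namespace caffe2

int main() {
  const int64_t block_size = 64;    // embedding dimension (illustrative)
  const int64_t data_size = 1000;   // rows in the quantized table (illustrative)
  std::vector<uint8_t> table(data_size * block_size, 1);
  // One (scale, bias) pair per table row, matching scale_bias[2*idx] / [2*idx+1].
  std::vector<float> scale_bias(2 * data_size, 0.5f);
  std::vector<int64_t> indices = {3, 7, 42, 7};
  std::vector<int64_t> offsets = {0, 2, 4};  // CSR-style: output_size + 1 entries
  const int64_t output_size = static_cast<int64_t>(offsets.size()) - 1;
  std::vector<float> out(output_size * block_size, 0.0f);

  // Returns false on malformed indices/offsets; true on success.
  const bool ok = caffe2::EmbeddingLookupIdx_int64_t_uint8_t_float_false__sve(
      block_size,
      output_size,
      /*index_size=*/static_cast<int64_t>(indices.size()),
      data_size,
      table.data(),
      indices.data(),
      offsets.data(),
      /*weights=*/nullptr,
      scale_bias.data(),
      /*normalize_by_lengths=*/true,
      out.data());
  return ok ? 0 : 1;
}

Note that scale_bias supplies one (scale, bias) pair per row of the quantized table, which is why the generated uint8 paths read scale_bias[2 * idx] and scale_bias[2 * idx + 1]; block sizes that are not a 2/4/8/16/32 multiple of the SVE vector length fall through to the predicated generic loop.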