mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Change loop unrolling strategy. Previously, the script only unrolls the inner loop over block_size when block size is multiple of vector length. This version instead unrolls the outer loop which reduces the number of load/store for accumulation into the output array and improves performance for cases when block size is not multiple of vector length. Benchmarking script: ```python # SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate <open-source-office@arm.com> # SPDX-License-Identifier: BSD-3-Clause import torch import torch.nn as nn import numpy as np import time import sys np.random.seed(0) torch.manual_seed(0) num_embeddings = 400000 embedding_dim = int(sys.argv[1]) multi_hot = 100 batch_size = 400 nrun = 1000 class SimpleEmbeddingBagModel(nn.Module): def __init__(self, num_embeddings, embedding_dim): super(SimpleEmbeddingBagModel, self).__init__() weights = torch.from_numpy((np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype(np.float32)).to(torch.float16) # Defining the EmbeddingBag layer self.embedding_bag = torch.nn.EmbeddingBag(num_embeddings, embedding_dim, _weight=weights, mode='sum', include_last_offset=True, dtype=torch.float32) def forward(self, input, offsets): # Forward pass through the EmbeddingBag layer result32 = self.embedding_bag(input, offsets, per_sample_weights=None) return result32 # Instantiate the model model = SimpleEmbeddingBagModel(num_embeddings=num_embeddings, embedding_dim=embedding_dim) model.eval() # Example input input_tensor = torch.randint(0, num_embeddings, (batch_size * multi_hot,), dtype=torch.long) offsets = torch.tensor(range(0, batch_size * multi_hot + 1, multi_hot)) with torch.no_grad(): # warm up output32 = model(input_tensor, offsets) ti = time.time_ns() for i in range(nrun): _ = model(input_tensor, offsets) tf = time.time_ns() print("{:3d} {:.3E}".format(embedding_dim, (tf-ti)/nrun/1.e6)) ``` Speedup on NEOVERSEV1 with 1 thread  Pull Request resolved: https://github.com/pytorch/pytorch/pull/150176 Approved by: https://github.com/digantdesai, https://github.com/malfet
4586 lines
183 KiB
C++
4586 lines
183 KiB
C++
//// --------------------------
|
|
//// ATTENTION:
|
|
//// THIS CODE IS AUTOGENERATED
|
|
//// BY sve_emblookup_codegen.py
|
|
//// DO NOT MODIFY!!!
|
|
//// --------------------------
|
|
|
|
#include <arm_sve.h>
|
|
#include <c10/util/BFloat16.h>
|
|
#include <c10/util/Half.h>
|
|
#include <cstdint>
|
|
#include <cstring>
|
|
namespace caffe2 {
|
|
|
|
template <bool IS_WEIGHT_POSITIONAL>
|
|
static bool EmbeddingLookupIdx_int32_t_float_float__sve(
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t data_size,
|
|
const float* input,
|
|
const int32_t* indices,
|
|
const int32_t* offsets,
|
|
const float* weights,
|
|
const float* scale_bias,
|
|
bool normalize_by_lengths,
|
|
float* out) {
|
|
const svbool_t svAll = svptrue_b32();
|
|
const auto vLen = static_cast<int64_t>(svcntw());
|
|
int64_t pos = 0;
|
|
for (int64_t i = 0; i < output_size; ++i) {
|
|
float* const op = &out[i * block_size];
|
|
memset(op, 0, sizeof(float) * block_size);
|
|
if (pos != offsets[i] - offsets[0]) {
|
|
return false;
|
|
}
|
|
int64_t start_offset = offsets[i];
|
|
int64_t end_offset = offsets[i + 1];
|
|
int64_t j = start_offset;
|
|
// unrolling 16 times
|
|
while (j + 15 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
const auto idx2 = indices[pos + 2];
|
|
const auto idx3 = indices[pos + 3];
|
|
const auto idx4 = indices[pos + 4];
|
|
const auto idx5 = indices[pos + 5];
|
|
const auto idx6 = indices[pos + 6];
|
|
const auto idx7 = indices[pos + 7];
|
|
const auto idx8 = indices[pos + 8];
|
|
const auto idx9 = indices[pos + 9];
|
|
const auto idx10 = indices[pos + 10];
|
|
const auto idx11 = indices[pos + 11];
|
|
const auto idx12 = indices[pos + 12];
|
|
const auto idx13 = indices[pos + 13];
|
|
const auto idx14 = indices[pos + 14];
|
|
const auto idx15 = indices[pos + 15];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx2 < 0 || idx2 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx3 < 0 || idx3 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx4 < 0 || idx4 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx5 < 0 || idx5 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx6 < 0 || idx6 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx7 < 0 || idx7 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx8 < 0 || idx8 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx9 < 0 || idx9 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx10 < 0 || idx10 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx11 < 0 || idx11 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx12 < 0 || idx12 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx13 < 0 || idx13 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx14 < 0 || idx14 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx15 < 0 || idx15 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
float wgt2 = 1.f;
|
|
float wgt3 = 1.f;
|
|
float wgt4 = 1.f;
|
|
float wgt5 = 1.f;
|
|
float wgt6 = 1.f;
|
|
float wgt7 = 1.f;
|
|
float wgt8 = 1.f;
|
|
float wgt9 = 1.f;
|
|
float wgt10 = 1.f;
|
|
float wgt11 = 1.f;
|
|
float wgt12 = 1.f;
|
|
float wgt13 = 1.f;
|
|
float wgt14 = 1.f;
|
|
float wgt15 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
|
|
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
|
|
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
|
|
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
|
|
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
|
|
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
|
|
wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8];
|
|
wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9];
|
|
wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10];
|
|
wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11];
|
|
wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12];
|
|
wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13];
|
|
wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14];
|
|
wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15];
|
|
}
|
|
const float* const ip0 = &input[idx0 * block_size];
|
|
const float* const ip1 = &input[idx1 * block_size];
|
|
const float* const ip2 = &input[idx2 * block_size];
|
|
const float* const ip3 = &input[idx3 * block_size];
|
|
const float* const ip4 = &input[idx4 * block_size];
|
|
const float* const ip5 = &input[idx5 * block_size];
|
|
const float* const ip6 = &input[idx6 * block_size];
|
|
const float* const ip7 = &input[idx7 * block_size];
|
|
const float* const ip8 = &input[idx8 * block_size];
|
|
const float* const ip9 = &input[idx9 * block_size];
|
|
const float* const ip10 = &input[idx10 * block_size];
|
|
const float* const ip11 = &input[idx11 * block_size];
|
|
const float* const ip12 = &input[idx12 * block_size];
|
|
const float* const ip13 = &input[idx13 * block_size];
|
|
const float* const ip14 = &input[idx14 * block_size];
|
|
const float* const ip15 = &input[idx15 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip4[k]), wgt4);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip5[k]), wgt5);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip6[k]), wgt6);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip7[k]), wgt7);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip8[k]), wgt8);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip9[k]), wgt9);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip10[k]), wgt10);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip11[k]), wgt11);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip12[k]), wgt12);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip13[k]), wgt13);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip14[k]), wgt14);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip15[k]), wgt15);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip2[k]), wgt2);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip3[k]), wgt3);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip4[k]), wgt4);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip5[k]), wgt5);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip6[k]), wgt6);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip7[k]), wgt7);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip8[k]), wgt8);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip9[k]), wgt9);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip10[k]), wgt10);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip11[k]), wgt11);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip12[k]), wgt12);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip13[k]), wgt13);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip14[k]), wgt14);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip15[k]), wgt15);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 16;
|
|
pos += 16;
|
|
}
|
|
// unrolling 8 times
|
|
while (j + 7 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
const auto idx2 = indices[pos + 2];
|
|
const auto idx3 = indices[pos + 3];
|
|
const auto idx4 = indices[pos + 4];
|
|
const auto idx5 = indices[pos + 5];
|
|
const auto idx6 = indices[pos + 6];
|
|
const auto idx7 = indices[pos + 7];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx2 < 0 || idx2 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx3 < 0 || idx3 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx4 < 0 || idx4 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx5 < 0 || idx5 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx6 < 0 || idx6 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx7 < 0 || idx7 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
float wgt2 = 1.f;
|
|
float wgt3 = 1.f;
|
|
float wgt4 = 1.f;
|
|
float wgt5 = 1.f;
|
|
float wgt6 = 1.f;
|
|
float wgt7 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
|
|
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
|
|
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
|
|
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
|
|
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
|
|
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
|
|
}
|
|
const float* const ip0 = &input[idx0 * block_size];
|
|
const float* const ip1 = &input[idx1 * block_size];
|
|
const float* const ip2 = &input[idx2 * block_size];
|
|
const float* const ip3 = &input[idx3 * block_size];
|
|
const float* const ip4 = &input[idx4 * block_size];
|
|
const float* const ip5 = &input[idx5 * block_size];
|
|
const float* const ip6 = &input[idx6 * block_size];
|
|
const float* const ip7 = &input[idx7 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip4[k]), wgt4);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip5[k]), wgt5);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip6[k]), wgt6);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip7[k]), wgt7);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip2[k]), wgt2);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip3[k]), wgt3);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip4[k]), wgt4);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip5[k]), wgt5);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip6[k]), wgt6);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip7[k]), wgt7);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 8;
|
|
pos += 8;
|
|
}
|
|
// unrolling 4 times
|
|
while (j + 3 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
const auto idx2 = indices[pos + 2];
|
|
const auto idx3 = indices[pos + 3];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx2 < 0 || idx2 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx3 < 0 || idx3 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
float wgt2 = 1.f;
|
|
float wgt3 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
|
|
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
|
|
}
|
|
const float* const ip0 = &input[idx0 * block_size];
|
|
const float* const ip1 = &input[idx1 * block_size];
|
|
const float* const ip2 = &input[idx2 * block_size];
|
|
const float* const ip3 = &input[idx3 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip2[k]), wgt2);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip3[k]), wgt3);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 4;
|
|
pos += 4;
|
|
}
|
|
// unrolling 2 times
|
|
while (j + 1 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
}
|
|
const float* const ip0 = &input[idx0 * block_size];
|
|
const float* const ip1 = &input[idx1 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 2;
|
|
pos += 2;
|
|
}
|
|
// tail loop
|
|
if (j < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
}
|
|
const float* const ip0 = &input[idx0 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
pos ++;
|
|
}
|
|
const int64_t length = end_offset - start_offset;
|
|
|
|
if (normalize_by_lengths && length != 0) {
|
|
const float len_inv = 1.0f / length;
|
|
svbool_t pg;
|
|
int64_t j = 0;
|
|
while (j + vLen - 1 < block_size) {
|
|
svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));
|
|
j += vLen;
|
|
}
|
|
if (j < block_size) {
|
|
pg = svwhilelt_b32_s64(j, block_size);
|
|
svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));
|
|
}
|
|
}
|
|
}
|
|
return pos == index_size;
|
|
}
|
|
bool EmbeddingLookupIdx_int32_t_float_float_false__sve(
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t data_size,
|
|
const float* input,
|
|
const int32_t* indices,
|
|
const int32_t* offsets,
|
|
const float* weights,
|
|
const float* scale_bias,
|
|
bool normalize_by_lengths,
|
|
float* out) {
|
|
return EmbeddingLookupIdx_int32_t_float_float__sve<false>(
|
|
block_size,
|
|
output_size,
|
|
index_size,
|
|
data_size,
|
|
input,
|
|
indices,
|
|
offsets,
|
|
weights,
|
|
scale_bias,
|
|
normalize_by_lengths,
|
|
out);
|
|
}
|
|
bool EmbeddingLookupIdx_int32_t_float_float_true__sve(
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t data_size,
|
|
const float* input,
|
|
const int32_t* indices,
|
|
const int32_t* offsets,
|
|
const float* weights,
|
|
const float* scale_bias,
|
|
bool normalize_by_lengths,
|
|
float* out) {
|
|
return EmbeddingLookupIdx_int32_t_float_float__sve<true>(
|
|
block_size,
|
|
output_size,
|
|
index_size,
|
|
data_size,
|
|
input,
|
|
indices,
|
|
offsets,
|
|
weights,
|
|
scale_bias,
|
|
normalize_by_lengths,
|
|
out);
|
|
}
|
|
|
|
template <bool IS_WEIGHT_POSITIONAL>
|
|
static bool EmbeddingLookupIdx_int64_t_float_float__sve(
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t data_size,
|
|
const float* input,
|
|
const int64_t* indices,
|
|
const int64_t* offsets,
|
|
const float* weights,
|
|
const float* scale_bias,
|
|
bool normalize_by_lengths,
|
|
float* out) {
|
|
const svbool_t svAll = svptrue_b32();
|
|
const auto vLen = static_cast<int64_t>(svcntw());
|
|
int64_t pos = 0;
|
|
for (int64_t i = 0; i < output_size; ++i) {
|
|
float* const op = &out[i * block_size];
|
|
memset(op, 0, sizeof(float) * block_size);
|
|
if (pos != offsets[i] - offsets[0]) {
|
|
return false;
|
|
}
|
|
int64_t start_offset = offsets[i];
|
|
int64_t end_offset = offsets[i + 1];
|
|
int64_t j = start_offset;
|
|
// unrolling 16 times
|
|
while (j + 15 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
const auto idx2 = indices[pos + 2];
|
|
const auto idx3 = indices[pos + 3];
|
|
const auto idx4 = indices[pos + 4];
|
|
const auto idx5 = indices[pos + 5];
|
|
const auto idx6 = indices[pos + 6];
|
|
const auto idx7 = indices[pos + 7];
|
|
const auto idx8 = indices[pos + 8];
|
|
const auto idx9 = indices[pos + 9];
|
|
const auto idx10 = indices[pos + 10];
|
|
const auto idx11 = indices[pos + 11];
|
|
const auto idx12 = indices[pos + 12];
|
|
const auto idx13 = indices[pos + 13];
|
|
const auto idx14 = indices[pos + 14];
|
|
const auto idx15 = indices[pos + 15];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx2 < 0 || idx2 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx3 < 0 || idx3 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx4 < 0 || idx4 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx5 < 0 || idx5 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx6 < 0 || idx6 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx7 < 0 || idx7 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx8 < 0 || idx8 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx9 < 0 || idx9 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx10 < 0 || idx10 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx11 < 0 || idx11 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx12 < 0 || idx12 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx13 < 0 || idx13 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx14 < 0 || idx14 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx15 < 0 || idx15 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
float wgt2 = 1.f;
|
|
float wgt3 = 1.f;
|
|
float wgt4 = 1.f;
|
|
float wgt5 = 1.f;
|
|
float wgt6 = 1.f;
|
|
float wgt7 = 1.f;
|
|
float wgt8 = 1.f;
|
|
float wgt9 = 1.f;
|
|
float wgt10 = 1.f;
|
|
float wgt11 = 1.f;
|
|
float wgt12 = 1.f;
|
|
float wgt13 = 1.f;
|
|
float wgt14 = 1.f;
|
|
float wgt15 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
|
|
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
|
|
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
|
|
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
|
|
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
|
|
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
|
|
wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8];
|
|
wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9];
|
|
wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10];
|
|
wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11];
|
|
wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12];
|
|
wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13];
|
|
wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14];
|
|
wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15];
|
|
}
|
|
const float* const ip0 = &input[idx0 * block_size];
|
|
const float* const ip1 = &input[idx1 * block_size];
|
|
const float* const ip2 = &input[idx2 * block_size];
|
|
const float* const ip3 = &input[idx3 * block_size];
|
|
const float* const ip4 = &input[idx4 * block_size];
|
|
const float* const ip5 = &input[idx5 * block_size];
|
|
const float* const ip6 = &input[idx6 * block_size];
|
|
const float* const ip7 = &input[idx7 * block_size];
|
|
const float* const ip8 = &input[idx8 * block_size];
|
|
const float* const ip9 = &input[idx9 * block_size];
|
|
const float* const ip10 = &input[idx10 * block_size];
|
|
const float* const ip11 = &input[idx11 * block_size];
|
|
const float* const ip12 = &input[idx12 * block_size];
|
|
const float* const ip13 = &input[idx13 * block_size];
|
|
const float* const ip14 = &input[idx14 * block_size];
|
|
const float* const ip15 = &input[idx15 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip4[k]), wgt4);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip5[k]), wgt5);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip6[k]), wgt6);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip7[k]), wgt7);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip8[k]), wgt8);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip9[k]), wgt9);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip10[k]), wgt10);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip11[k]), wgt11);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip12[k]), wgt12);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip13[k]), wgt13);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip14[k]), wgt14);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip15[k]), wgt15);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip2[k]), wgt2);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip3[k]), wgt3);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip4[k]), wgt4);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip5[k]), wgt5);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip6[k]), wgt6);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip7[k]), wgt7);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip8[k]), wgt8);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip9[k]), wgt9);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip10[k]), wgt10);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip11[k]), wgt11);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip12[k]), wgt12);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip13[k]), wgt13);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip14[k]), wgt14);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip15[k]), wgt15);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 16;
|
|
pos += 16;
|
|
}
|
|
// unrolling 8 times
|
|
while (j + 7 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
const auto idx2 = indices[pos + 2];
|
|
const auto idx3 = indices[pos + 3];
|
|
const auto idx4 = indices[pos + 4];
|
|
const auto idx5 = indices[pos + 5];
|
|
const auto idx6 = indices[pos + 6];
|
|
const auto idx7 = indices[pos + 7];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx2 < 0 || idx2 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx3 < 0 || idx3 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx4 < 0 || idx4 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx5 < 0 || idx5 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx6 < 0 || idx6 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx7 < 0 || idx7 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
float wgt2 = 1.f;
|
|
float wgt3 = 1.f;
|
|
float wgt4 = 1.f;
|
|
float wgt5 = 1.f;
|
|
float wgt6 = 1.f;
|
|
float wgt7 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
|
|
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
|
|
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
|
|
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
|
|
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
|
|
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
|
|
}
|
|
const float* const ip0 = &input[idx0 * block_size];
|
|
const float* const ip1 = &input[idx1 * block_size];
|
|
const float* const ip2 = &input[idx2 * block_size];
|
|
const float* const ip3 = &input[idx3 * block_size];
|
|
const float* const ip4 = &input[idx4 * block_size];
|
|
const float* const ip5 = &input[idx5 * block_size];
|
|
const float* const ip6 = &input[idx6 * block_size];
|
|
const float* const ip7 = &input[idx7 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip4[k]), wgt4);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip5[k]), wgt5);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip6[k]), wgt6);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip7[k]), wgt7);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip2[k]), wgt2);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip3[k]), wgt3);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip4[k]), wgt4);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip5[k]), wgt5);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip6[k]), wgt6);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip7[k]), wgt7);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 8;
|
|
pos += 8;
|
|
}
|
|
// unrolling 4 times
|
|
while (j + 3 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
const auto idx2 = indices[pos + 2];
|
|
const auto idx3 = indices[pos + 3];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx2 < 0 || idx2 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx3 < 0 || idx3 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
float wgt2 = 1.f;
|
|
float wgt3 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
|
|
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
|
|
}
|
|
const float* const ip0 = &input[idx0 * block_size];
|
|
const float* const ip1 = &input[idx1 * block_size];
|
|
const float* const ip2 = &input[idx2 * block_size];
|
|
const float* const ip3 = &input[idx3 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip2[k]), wgt2);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip3[k]), wgt3);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 4;
|
|
pos += 4;
|
|
}
|
|
// unrolling 2 times
|
|
while (j + 1 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
}
|
|
const float* const ip0 = &input[idx0 * block_size];
|
|
const float* const ip1 = &input[idx1 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 2;
|
|
pos += 2;
|
|
}
|
|
// tail loop
|
|
if (j < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
}
|
|
const float* const ip0 = &input[idx0 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
pos ++;
|
|
}
|
|
const int64_t length = end_offset - start_offset;
|
|
|
|
if (normalize_by_lengths && length != 0) {
|
|
const float len_inv = 1.0f / length;
|
|
svbool_t pg;
|
|
int64_t j = 0;
|
|
while (j + vLen - 1 < block_size) {
|
|
svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));
|
|
j += vLen;
|
|
}
|
|
if (j < block_size) {
|
|
pg = svwhilelt_b32_s64(j, block_size);
|
|
svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));
|
|
}
|
|
}
|
|
}
|
|
return pos == index_size;
|
|
}
|
|
bool EmbeddingLookupIdx_int64_t_float_float_false__sve(
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t data_size,
|
|
const float* input,
|
|
const int64_t* indices,
|
|
const int64_t* offsets,
|
|
const float* weights,
|
|
const float* scale_bias,
|
|
bool normalize_by_lengths,
|
|
float* out) {
|
|
return EmbeddingLookupIdx_int64_t_float_float__sve<false>(
|
|
block_size,
|
|
output_size,
|
|
index_size,
|
|
data_size,
|
|
input,
|
|
indices,
|
|
offsets,
|
|
weights,
|
|
scale_bias,
|
|
normalize_by_lengths,
|
|
out);
|
|
}
|
|
bool EmbeddingLookupIdx_int64_t_float_float_true__sve(
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t data_size,
|
|
const float* input,
|
|
const int64_t* indices,
|
|
const int64_t* offsets,
|
|
const float* weights,
|
|
const float* scale_bias,
|
|
bool normalize_by_lengths,
|
|
float* out) {
|
|
return EmbeddingLookupIdx_int64_t_float_float__sve<true>(
|
|
block_size,
|
|
output_size,
|
|
index_size,
|
|
data_size,
|
|
input,
|
|
indices,
|
|
offsets,
|
|
weights,
|
|
scale_bias,
|
|
normalize_by_lengths,
|
|
out);
|
|
}
|
|
|
|
template <bool IS_WEIGHT_POSITIONAL>
|
|
static bool EmbeddingLookupIdx_int32_t_half_float__sve(
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t data_size,
|
|
const at::Half* input,
|
|
const int32_t* indices,
|
|
const int32_t* offsets,
|
|
const float* weights,
|
|
const float* scale_bias,
|
|
bool normalize_by_lengths,
|
|
float* out) {
|
|
const svbool_t svAll = svptrue_b32();
|
|
const auto vLen = static_cast<int64_t>(svcntw());
|
|
int64_t pos = 0;
|
|
for (int64_t i = 0; i < output_size; ++i) {
|
|
float* const op = &out[i * block_size];
|
|
memset(op, 0, sizeof(float) * block_size);
|
|
if (pos != offsets[i] - offsets[0]) {
|
|
return false;
|
|
}
|
|
int64_t start_offset = offsets[i];
|
|
int64_t end_offset = offsets[i + 1];
|
|
int64_t j = start_offset;
|
|
// unrolling 16 times
|
|
while (j + 15 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
const auto idx2 = indices[pos + 2];
|
|
const auto idx3 = indices[pos + 3];
|
|
const auto idx4 = indices[pos + 4];
|
|
const auto idx5 = indices[pos + 5];
|
|
const auto idx6 = indices[pos + 6];
|
|
const auto idx7 = indices[pos + 7];
|
|
const auto idx8 = indices[pos + 8];
|
|
const auto idx9 = indices[pos + 9];
|
|
const auto idx10 = indices[pos + 10];
|
|
const auto idx11 = indices[pos + 11];
|
|
const auto idx12 = indices[pos + 12];
|
|
const auto idx13 = indices[pos + 13];
|
|
const auto idx14 = indices[pos + 14];
|
|
const auto idx15 = indices[pos + 15];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx2 < 0 || idx2 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx3 < 0 || idx3 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx4 < 0 || idx4 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx5 < 0 || idx5 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx6 < 0 || idx6 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx7 < 0 || idx7 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx8 < 0 || idx8 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx9 < 0 || idx9 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx10 < 0 || idx10 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx11 < 0 || idx11 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx12 < 0 || idx12 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx13 < 0 || idx13 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx14 < 0 || idx14 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx15 < 0 || idx15 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
float wgt2 = 1.f;
|
|
float wgt3 = 1.f;
|
|
float wgt4 = 1.f;
|
|
float wgt5 = 1.f;
|
|
float wgt6 = 1.f;
|
|
float wgt7 = 1.f;
|
|
float wgt8 = 1.f;
|
|
float wgt9 = 1.f;
|
|
float wgt10 = 1.f;
|
|
float wgt11 = 1.f;
|
|
float wgt12 = 1.f;
|
|
float wgt13 = 1.f;
|
|
float wgt14 = 1.f;
|
|
float wgt15 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
|
|
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
|
|
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
|
|
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
|
|
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
|
|
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
|
|
wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8];
|
|
wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9];
|
|
wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10];
|
|
wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11];
|
|
wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12];
|
|
wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13];
|
|
wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14];
|
|
wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15];
|
|
}
|
|
const at::Half* const ip0 = &input[idx0 * block_size];
|
|
const at::Half* const ip1 = &input[idx1 * block_size];
|
|
const at::Half* const ip2 = &input[idx2 * block_size];
|
|
const at::Half* const ip3 = &input[idx3 * block_size];
|
|
const at::Half* const ip4 = &input[idx4 * block_size];
|
|
const at::Half* const ip5 = &input[idx5 * block_size];
|
|
const at::Half* const ip6 = &input[idx6 * block_size];
|
|
const at::Half* const ip7 = &input[idx7 * block_size];
|
|
const at::Half* const ip8 = &input[idx8 * block_size];
|
|
const at::Half* const ip9 = &input[idx9 * block_size];
|
|
const at::Half* const ip10 = &input[idx10 * block_size];
|
|
const at::Half* const ip11 = &input[idx11 * block_size];
|
|
const at::Half* const ip12 = &input[idx12 * block_size];
|
|
const at::Half* const ip13 = &input[idx13 * block_size];
|
|
const at::Half* const ip14 = &input[idx14 * block_size];
|
|
const at::Half* const ip15 = &input[idx15 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
|
|
auto input1 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k]))));
|
|
auto input2 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k]))));
|
|
auto input3 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k]))));
|
|
auto input4 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip4[k]))));
|
|
auto input5 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip5[k]))));
|
|
auto input6 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip6[k]))));
|
|
auto input7 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip7[k]))));
|
|
auto input8 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip8[k]))));
|
|
auto input9 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip9[k]))));
|
|
auto input10 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip10[k]))));
|
|
auto input11 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip11[k]))));
|
|
auto input12 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip12[k]))));
|
|
auto input13 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip13[k]))));
|
|
auto input14 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip14[k]))));
|
|
auto input15 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip15[k]))));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
output = svmla_x(svAll, output, input1, wgt1);
|
|
output = svmla_x(svAll, output, input2, wgt2);
|
|
output = svmla_x(svAll, output, input3, wgt3);
|
|
output = svmla_x(svAll, output, input4, wgt4);
|
|
output = svmla_x(svAll, output, input5, wgt5);
|
|
output = svmla_x(svAll, output, input6, wgt6);
|
|
output = svmla_x(svAll, output, input7, wgt7);
|
|
output = svmla_x(svAll, output, input8, wgt8);
|
|
output = svmla_x(svAll, output, input9, wgt9);
|
|
output = svmla_x(svAll, output, input10, wgt10);
|
|
output = svmla_x(svAll, output, input11, wgt11);
|
|
output = svmla_x(svAll, output, input12, wgt12);
|
|
output = svmla_x(svAll, output, input13, wgt13);
|
|
output = svmla_x(svAll, output, input14, wgt14);
|
|
output = svmla_x(svAll, output, input15, wgt15);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
|
|
auto input1 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k]))));
|
|
auto input2 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k]))));
|
|
auto input3 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k]))));
|
|
auto input4 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip4[k]))));
|
|
auto input5 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip5[k]))));
|
|
auto input6 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip6[k]))));
|
|
auto input7 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip7[k]))));
|
|
auto input8 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip8[k]))));
|
|
auto input9 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip9[k]))));
|
|
auto input10 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip10[k]))));
|
|
auto input11 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip11[k]))));
|
|
auto input12 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip12[k]))));
|
|
auto input13 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip13[k]))));
|
|
auto input14 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip14[k]))));
|
|
auto input15 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip15[k]))));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
output = svmla_x(pg, output, input1, wgt1);
|
|
output = svmla_x(pg, output, input2, wgt2);
|
|
output = svmla_x(pg, output, input3, wgt3);
|
|
output = svmla_x(pg, output, input4, wgt4);
|
|
output = svmla_x(pg, output, input5, wgt5);
|
|
output = svmla_x(pg, output, input6, wgt6);
|
|
output = svmla_x(pg, output, input7, wgt7);
|
|
output = svmla_x(pg, output, input8, wgt8);
|
|
output = svmla_x(pg, output, input9, wgt9);
|
|
output = svmla_x(pg, output, input10, wgt10);
|
|
output = svmla_x(pg, output, input11, wgt11);
|
|
output = svmla_x(pg, output, input12, wgt12);
|
|
output = svmla_x(pg, output, input13, wgt13);
|
|
output = svmla_x(pg, output, input14, wgt14);
|
|
output = svmla_x(pg, output, input15, wgt15);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 16;
|
|
pos += 16;
|
|
}
|
|
// unrolling 8 times
|
|
while (j + 7 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
const auto idx2 = indices[pos + 2];
|
|
const auto idx3 = indices[pos + 3];
|
|
const auto idx4 = indices[pos + 4];
|
|
const auto idx5 = indices[pos + 5];
|
|
const auto idx6 = indices[pos + 6];
|
|
const auto idx7 = indices[pos + 7];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx2 < 0 || idx2 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx3 < 0 || idx3 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx4 < 0 || idx4 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx5 < 0 || idx5 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx6 < 0 || idx6 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx7 < 0 || idx7 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
float wgt2 = 1.f;
|
|
float wgt3 = 1.f;
|
|
float wgt4 = 1.f;
|
|
float wgt5 = 1.f;
|
|
float wgt6 = 1.f;
|
|
float wgt7 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
|
|
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
|
|
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
|
|
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
|
|
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
|
|
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
|
|
}
|
|
const at::Half* const ip0 = &input[idx0 * block_size];
|
|
const at::Half* const ip1 = &input[idx1 * block_size];
|
|
const at::Half* const ip2 = &input[idx2 * block_size];
|
|
const at::Half* const ip3 = &input[idx3 * block_size];
|
|
const at::Half* const ip4 = &input[idx4 * block_size];
|
|
const at::Half* const ip5 = &input[idx5 * block_size];
|
|
const at::Half* const ip6 = &input[idx6 * block_size];
|
|
const at::Half* const ip7 = &input[idx7 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
|
|
auto input1 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k]))));
|
|
auto input2 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k]))));
|
|
auto input3 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k]))));
|
|
auto input4 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip4[k]))));
|
|
auto input5 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip5[k]))));
|
|
auto input6 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip6[k]))));
|
|
auto input7 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip7[k]))));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
output = svmla_x(svAll, output, input1, wgt1);
|
|
output = svmla_x(svAll, output, input2, wgt2);
|
|
output = svmla_x(svAll, output, input3, wgt3);
|
|
output = svmla_x(svAll, output, input4, wgt4);
|
|
output = svmla_x(svAll, output, input5, wgt5);
|
|
output = svmla_x(svAll, output, input6, wgt6);
|
|
output = svmla_x(svAll, output, input7, wgt7);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
|
|
auto input1 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k]))));
|
|
auto input2 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k]))));
|
|
auto input3 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k]))));
|
|
auto input4 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip4[k]))));
|
|
auto input5 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip5[k]))));
|
|
auto input6 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip6[k]))));
|
|
auto input7 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip7[k]))));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
output = svmla_x(pg, output, input1, wgt1);
|
|
output = svmla_x(pg, output, input2, wgt2);
|
|
output = svmla_x(pg, output, input3, wgt3);
|
|
output = svmla_x(pg, output, input4, wgt4);
|
|
output = svmla_x(pg, output, input5, wgt5);
|
|
output = svmla_x(pg, output, input6, wgt6);
|
|
output = svmla_x(pg, output, input7, wgt7);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 8;
|
|
pos += 8;
|
|
}
|
|
// unrolling 4 times
|
|
while (j + 3 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
const auto idx2 = indices[pos + 2];
|
|
const auto idx3 = indices[pos + 3];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx2 < 0 || idx2 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx3 < 0 || idx3 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
float wgt2 = 1.f;
|
|
float wgt3 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
|
|
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
|
|
}
|
|
const at::Half* const ip0 = &input[idx0 * block_size];
|
|
const at::Half* const ip1 = &input[idx1 * block_size];
|
|
const at::Half* const ip2 = &input[idx2 * block_size];
|
|
const at::Half* const ip3 = &input[idx3 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
|
|
auto input1 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k]))));
|
|
auto input2 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k]))));
|
|
auto input3 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k]))));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
output = svmla_x(svAll, output, input1, wgt1);
|
|
output = svmla_x(svAll, output, input2, wgt2);
|
|
output = svmla_x(svAll, output, input3, wgt3);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
|
|
auto input1 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k]))));
|
|
auto input2 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k]))));
|
|
auto input3 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k]))));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
output = svmla_x(pg, output, input1, wgt1);
|
|
output = svmla_x(pg, output, input2, wgt2);
|
|
output = svmla_x(pg, output, input3, wgt3);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 4;
|
|
pos += 4;
|
|
}
|
|
// unrolling 2 times
|
|
while (j + 1 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
}
|
|
const at::Half* const ip0 = &input[idx0 * block_size];
|
|
const at::Half* const ip1 = &input[idx1 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
|
|
auto input1 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k]))));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
output = svmla_x(svAll, output, input1, wgt1);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
|
|
auto input1 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k]))));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
output = svmla_x(pg, output, input1, wgt1);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 2;
|
|
pos += 2;
|
|
}
|
|
// tail loop
|
|
if (j < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
}
|
|
const at::Half* const ip0 = &input[idx0 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
pos ++;
|
|
}
|
|
const int64_t length = end_offset - start_offset;
|
|
|
|
if (normalize_by_lengths && length != 0) {
|
|
const float len_inv = 1.0f / length;
|
|
svbool_t pg;
|
|
int64_t j = 0;
|
|
while (j + vLen - 1 < block_size) {
|
|
svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));
|
|
j += vLen;
|
|
}
|
|
if (j < block_size) {
|
|
pg = svwhilelt_b32_s64(j, block_size);
|
|
svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));
|
|
}
|
|
}
|
|
}
|
|
return pos == index_size;
|
|
}
|
|
bool EmbeddingLookupIdx_int32_t_half_float_false__sve(
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t data_size,
|
|
const at::Half* input,
|
|
const int32_t* indices,
|
|
const int32_t* offsets,
|
|
const float* weights,
|
|
const float* scale_bias,
|
|
bool normalize_by_lengths,
|
|
float* out) {
|
|
return EmbeddingLookupIdx_int32_t_half_float__sve<false>(
|
|
block_size,
|
|
output_size,
|
|
index_size,
|
|
data_size,
|
|
input,
|
|
indices,
|
|
offsets,
|
|
weights,
|
|
scale_bias,
|
|
normalize_by_lengths,
|
|
out);
|
|
}
|
|
bool EmbeddingLookupIdx_int32_t_half_float_true__sve(
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t data_size,
|
|
const at::Half* input,
|
|
const int32_t* indices,
|
|
const int32_t* offsets,
|
|
const float* weights,
|
|
const float* scale_bias,
|
|
bool normalize_by_lengths,
|
|
float* out) {
|
|
return EmbeddingLookupIdx_int32_t_half_float__sve<true>(
|
|
block_size,
|
|
output_size,
|
|
index_size,
|
|
data_size,
|
|
input,
|
|
indices,
|
|
offsets,
|
|
weights,
|
|
scale_bias,
|
|
normalize_by_lengths,
|
|
out);
|
|
}
|
|
|
|
template <bool IS_WEIGHT_POSITIONAL>
|
|
static bool EmbeddingLookupIdx_int64_t_half_float__sve(
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t data_size,
|
|
const at::Half* input,
|
|
const int64_t* indices,
|
|
const int64_t* offsets,
|
|
const float* weights,
|
|
const float* scale_bias,
|
|
bool normalize_by_lengths,
|
|
float* out) {
|
|
const svbool_t svAll = svptrue_b32();
|
|
const auto vLen = static_cast<int64_t>(svcntw());
|
|
int64_t pos = 0;
|
|
for (int64_t i = 0; i < output_size; ++i) {
|
|
float* const op = &out[i * block_size];
|
|
memset(op, 0, sizeof(float) * block_size);
|
|
if (pos != offsets[i] - offsets[0]) {
|
|
return false;
|
|
}
|
|
int64_t start_offset = offsets[i];
|
|
int64_t end_offset = offsets[i + 1];
|
|
int64_t j = start_offset;
|
|
// unrolling 16 times
|
|
while (j + 15 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
const auto idx2 = indices[pos + 2];
|
|
const auto idx3 = indices[pos + 3];
|
|
const auto idx4 = indices[pos + 4];
|
|
const auto idx5 = indices[pos + 5];
|
|
const auto idx6 = indices[pos + 6];
|
|
const auto idx7 = indices[pos + 7];
|
|
const auto idx8 = indices[pos + 8];
|
|
const auto idx9 = indices[pos + 9];
|
|
const auto idx10 = indices[pos + 10];
|
|
const auto idx11 = indices[pos + 11];
|
|
const auto idx12 = indices[pos + 12];
|
|
const auto idx13 = indices[pos + 13];
|
|
const auto idx14 = indices[pos + 14];
|
|
const auto idx15 = indices[pos + 15];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx2 < 0 || idx2 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx3 < 0 || idx3 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx4 < 0 || idx4 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx5 < 0 || idx5 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx6 < 0 || idx6 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx7 < 0 || idx7 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx8 < 0 || idx8 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx9 < 0 || idx9 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx10 < 0 || idx10 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx11 < 0 || idx11 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx12 < 0 || idx12 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx13 < 0 || idx13 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx14 < 0 || idx14 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx15 < 0 || idx15 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
float wgt2 = 1.f;
|
|
float wgt3 = 1.f;
|
|
float wgt4 = 1.f;
|
|
float wgt5 = 1.f;
|
|
float wgt6 = 1.f;
|
|
float wgt7 = 1.f;
|
|
float wgt8 = 1.f;
|
|
float wgt9 = 1.f;
|
|
float wgt10 = 1.f;
|
|
float wgt11 = 1.f;
|
|
float wgt12 = 1.f;
|
|
float wgt13 = 1.f;
|
|
float wgt14 = 1.f;
|
|
float wgt15 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
|
|
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
|
|
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
|
|
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
|
|
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
|
|
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
|
|
wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8];
|
|
wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9];
|
|
wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10];
|
|
wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11];
|
|
wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12];
|
|
wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13];
|
|
wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14];
|
|
wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15];
|
|
}
|
|
const at::Half* const ip0 = &input[idx0 * block_size];
|
|
const at::Half* const ip1 = &input[idx1 * block_size];
|
|
const at::Half* const ip2 = &input[idx2 * block_size];
|
|
const at::Half* const ip3 = &input[idx3 * block_size];
|
|
const at::Half* const ip4 = &input[idx4 * block_size];
|
|
const at::Half* const ip5 = &input[idx5 * block_size];
|
|
const at::Half* const ip6 = &input[idx6 * block_size];
|
|
const at::Half* const ip7 = &input[idx7 * block_size];
|
|
const at::Half* const ip8 = &input[idx8 * block_size];
|
|
const at::Half* const ip9 = &input[idx9 * block_size];
|
|
const at::Half* const ip10 = &input[idx10 * block_size];
|
|
const at::Half* const ip11 = &input[idx11 * block_size];
|
|
const at::Half* const ip12 = &input[idx12 * block_size];
|
|
const at::Half* const ip13 = &input[idx13 * block_size];
|
|
const at::Half* const ip14 = &input[idx14 * block_size];
|
|
const at::Half* const ip15 = &input[idx15 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
|
|
auto input1 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k]))));
|
|
auto input2 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k]))));
|
|
auto input3 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k]))));
|
|
auto input4 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip4[k]))));
|
|
auto input5 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip5[k]))));
|
|
auto input6 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip6[k]))));
|
|
auto input7 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip7[k]))));
|
|
auto input8 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip8[k]))));
|
|
auto input9 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip9[k]))));
|
|
auto input10 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip10[k]))));
|
|
auto input11 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip11[k]))));
|
|
auto input12 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip12[k]))));
|
|
auto input13 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip13[k]))));
|
|
auto input14 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip14[k]))));
|
|
auto input15 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip15[k]))));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
output = svmla_x(svAll, output, input1, wgt1);
|
|
output = svmla_x(svAll, output, input2, wgt2);
|
|
output = svmla_x(svAll, output, input3, wgt3);
|
|
output = svmla_x(svAll, output, input4, wgt4);
|
|
output = svmla_x(svAll, output, input5, wgt5);
|
|
output = svmla_x(svAll, output, input6, wgt6);
|
|
output = svmla_x(svAll, output, input7, wgt7);
|
|
output = svmla_x(svAll, output, input8, wgt8);
|
|
output = svmla_x(svAll, output, input9, wgt9);
|
|
output = svmla_x(svAll, output, input10, wgt10);
|
|
output = svmla_x(svAll, output, input11, wgt11);
|
|
output = svmla_x(svAll, output, input12, wgt12);
|
|
output = svmla_x(svAll, output, input13, wgt13);
|
|
output = svmla_x(svAll, output, input14, wgt14);
|
|
output = svmla_x(svAll, output, input15, wgt15);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
|
|
auto input1 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k]))));
|
|
auto input2 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k]))));
|
|
auto input3 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k]))));
|
|
auto input4 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip4[k]))));
|
|
auto input5 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip5[k]))));
|
|
auto input6 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip6[k]))));
|
|
auto input7 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip7[k]))));
|
|
auto input8 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip8[k]))));
|
|
auto input9 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip9[k]))));
|
|
auto input10 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip10[k]))));
|
|
auto input11 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip11[k]))));
|
|
auto input12 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip12[k]))));
|
|
auto input13 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip13[k]))));
|
|
auto input14 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip14[k]))));
|
|
auto input15 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip15[k]))));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
output = svmla_x(pg, output, input1, wgt1);
|
|
output = svmla_x(pg, output, input2, wgt2);
|
|
output = svmla_x(pg, output, input3, wgt3);
|
|
output = svmla_x(pg, output, input4, wgt4);
|
|
output = svmla_x(pg, output, input5, wgt5);
|
|
output = svmla_x(pg, output, input6, wgt6);
|
|
output = svmla_x(pg, output, input7, wgt7);
|
|
output = svmla_x(pg, output, input8, wgt8);
|
|
output = svmla_x(pg, output, input9, wgt9);
|
|
output = svmla_x(pg, output, input10, wgt10);
|
|
output = svmla_x(pg, output, input11, wgt11);
|
|
output = svmla_x(pg, output, input12, wgt12);
|
|
output = svmla_x(pg, output, input13, wgt13);
|
|
output = svmla_x(pg, output, input14, wgt14);
|
|
output = svmla_x(pg, output, input15, wgt15);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 16;
|
|
pos += 16;
|
|
}
|
|
// unrolling 8 times
|
|
while (j + 7 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
const auto idx2 = indices[pos + 2];
|
|
const auto idx3 = indices[pos + 3];
|
|
const auto idx4 = indices[pos + 4];
|
|
const auto idx5 = indices[pos + 5];
|
|
const auto idx6 = indices[pos + 6];
|
|
const auto idx7 = indices[pos + 7];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx2 < 0 || idx2 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx3 < 0 || idx3 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx4 < 0 || idx4 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx5 < 0 || idx5 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx6 < 0 || idx6 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx7 < 0 || idx7 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
float wgt2 = 1.f;
|
|
float wgt3 = 1.f;
|
|
float wgt4 = 1.f;
|
|
float wgt5 = 1.f;
|
|
float wgt6 = 1.f;
|
|
float wgt7 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
|
|
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
|
|
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
|
|
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
|
|
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
|
|
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
|
|
}
|
|
const at::Half* const ip0 = &input[idx0 * block_size];
|
|
const at::Half* const ip1 = &input[idx1 * block_size];
|
|
const at::Half* const ip2 = &input[idx2 * block_size];
|
|
const at::Half* const ip3 = &input[idx3 * block_size];
|
|
const at::Half* const ip4 = &input[idx4 * block_size];
|
|
const at::Half* const ip5 = &input[idx5 * block_size];
|
|
const at::Half* const ip6 = &input[idx6 * block_size];
|
|
const at::Half* const ip7 = &input[idx7 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
|
|
auto input1 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k]))));
|
|
auto input2 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k]))));
|
|
auto input3 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k]))));
|
|
auto input4 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip4[k]))));
|
|
auto input5 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip5[k]))));
|
|
auto input6 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip6[k]))));
|
|
auto input7 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip7[k]))));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
output = svmla_x(svAll, output, input1, wgt1);
|
|
output = svmla_x(svAll, output, input2, wgt2);
|
|
output = svmla_x(svAll, output, input3, wgt3);
|
|
output = svmla_x(svAll, output, input4, wgt4);
|
|
output = svmla_x(svAll, output, input5, wgt5);
|
|
output = svmla_x(svAll, output, input6, wgt6);
|
|
output = svmla_x(svAll, output, input7, wgt7);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
|
|
auto input1 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k]))));
|
|
auto input2 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k]))));
|
|
auto input3 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k]))));
|
|
auto input4 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip4[k]))));
|
|
auto input5 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip5[k]))));
|
|
auto input6 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip6[k]))));
|
|
auto input7 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip7[k]))));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
output = svmla_x(pg, output, input1, wgt1);
|
|
output = svmla_x(pg, output, input2, wgt2);
|
|
output = svmla_x(pg, output, input3, wgt3);
|
|
output = svmla_x(pg, output, input4, wgt4);
|
|
output = svmla_x(pg, output, input5, wgt5);
|
|
output = svmla_x(pg, output, input6, wgt6);
|
|
output = svmla_x(pg, output, input7, wgt7);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 8;
|
|
pos += 8;
|
|
}
|
|
// unrolling 4 times
|
|
while (j + 3 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
const auto idx2 = indices[pos + 2];
|
|
const auto idx3 = indices[pos + 3];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx2 < 0 || idx2 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx3 < 0 || idx3 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
float wgt2 = 1.f;
|
|
float wgt3 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
|
|
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
|
|
}
|
|
const at::Half* const ip0 = &input[idx0 * block_size];
|
|
const at::Half* const ip1 = &input[idx1 * block_size];
|
|
const at::Half* const ip2 = &input[idx2 * block_size];
|
|
const at::Half* const ip3 = &input[idx3 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
|
|
auto input1 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k]))));
|
|
auto input2 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k]))));
|
|
auto input3 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k]))));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
output = svmla_x(svAll, output, input1, wgt1);
|
|
output = svmla_x(svAll, output, input2, wgt2);
|
|
output = svmla_x(svAll, output, input3, wgt3);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
|
|
auto input1 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k]))));
|
|
auto input2 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k]))));
|
|
auto input3 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k]))));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
output = svmla_x(pg, output, input1, wgt1);
|
|
output = svmla_x(pg, output, input2, wgt2);
|
|
output = svmla_x(pg, output, input3, wgt3);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 4;
|
|
pos += 4;
|
|
}
|
|
// unrolling 2 times
|
|
while (j + 1 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
}
|
|
const at::Half* const ip0 = &input[idx0 * block_size];
|
|
const at::Half* const ip1 = &input[idx1 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
|
|
auto input1 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k]))));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
output = svmla_x(svAll, output, input1, wgt1);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
|
|
auto input1 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k]))));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
output = svmla_x(pg, output, input1, wgt1);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 2;
|
|
pos += 2;
|
|
}
|
|
// tail loop
|
|
if (j < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
}
|
|
const at::Half* const ip0 = &input[idx0 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
pos ++;
|
|
}
|
|
const int64_t length = end_offset - start_offset;
|
|
|
|
if (normalize_by_lengths && length != 0) {
|
|
const float len_inv = 1.0f / length;
|
|
svbool_t pg;
|
|
int64_t j = 0;
|
|
while (j + vLen - 1 < block_size) {
|
|
svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));
|
|
j += vLen;
|
|
}
|
|
if (j < block_size) {
|
|
pg = svwhilelt_b32_s64(j, block_size);
|
|
svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));
|
|
}
|
|
}
|
|
}
|
|
return pos == index_size;
|
|
}
|
|
bool EmbeddingLookupIdx_int64_t_half_float_false__sve(
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t data_size,
|
|
const at::Half* input,
|
|
const int64_t* indices,
|
|
const int64_t* offsets,
|
|
const float* weights,
|
|
const float* scale_bias,
|
|
bool normalize_by_lengths,
|
|
float* out) {
|
|
return EmbeddingLookupIdx_int64_t_half_float__sve<false>(
|
|
block_size,
|
|
output_size,
|
|
index_size,
|
|
data_size,
|
|
input,
|
|
indices,
|
|
offsets,
|
|
weights,
|
|
scale_bias,
|
|
normalize_by_lengths,
|
|
out);
|
|
}
|
|
bool EmbeddingLookupIdx_int64_t_half_float_true__sve(
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t data_size,
|
|
const at::Half* input,
|
|
const int64_t* indices,
|
|
const int64_t* offsets,
|
|
const float* weights,
|
|
const float* scale_bias,
|
|
bool normalize_by_lengths,
|
|
float* out) {
|
|
return EmbeddingLookupIdx_int64_t_half_float__sve<true>(
|
|
block_size,
|
|
output_size,
|
|
index_size,
|
|
data_size,
|
|
input,
|
|
indices,
|
|
offsets,
|
|
weights,
|
|
scale_bias,
|
|
normalize_by_lengths,
|
|
out);
|
|
}
|
|
|
|
template <bool IS_WEIGHT_POSITIONAL>
|
|
static bool EmbeddingLookupIdx_int32_t_bfloat16_float__sve(
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t data_size,
|
|
const at::BFloat16* input,
|
|
const int32_t* indices,
|
|
const int32_t* offsets,
|
|
const float* weights,
|
|
const float* scale_bias,
|
|
bool normalize_by_lengths,
|
|
float* out) {
|
|
const svbool_t svAll = svptrue_b32();
|
|
const auto vLen = static_cast<int64_t>(svcntw());
|
|
int64_t pos = 0;
|
|
for (int64_t i = 0; i < output_size; ++i) {
|
|
float* const op = &out[i * block_size];
|
|
memset(op, 0, sizeof(float) * block_size);
|
|
if (pos != offsets[i] - offsets[0]) {
|
|
return false;
|
|
}
|
|
int64_t start_offset = offsets[i];
|
|
int64_t end_offset = offsets[i + 1];
|
|
int64_t j = start_offset;
|
|
// unrolling 16 times
|
|
while (j + 15 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
const auto idx2 = indices[pos + 2];
|
|
const auto idx3 = indices[pos + 3];
|
|
const auto idx4 = indices[pos + 4];
|
|
const auto idx5 = indices[pos + 5];
|
|
const auto idx6 = indices[pos + 6];
|
|
const auto idx7 = indices[pos + 7];
|
|
const auto idx8 = indices[pos + 8];
|
|
const auto idx9 = indices[pos + 9];
|
|
const auto idx10 = indices[pos + 10];
|
|
const auto idx11 = indices[pos + 11];
|
|
const auto idx12 = indices[pos + 12];
|
|
const auto idx13 = indices[pos + 13];
|
|
const auto idx14 = indices[pos + 14];
|
|
const auto idx15 = indices[pos + 15];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx2 < 0 || idx2 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx3 < 0 || idx3 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx4 < 0 || idx4 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx5 < 0 || idx5 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx6 < 0 || idx6 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx7 < 0 || idx7 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx8 < 0 || idx8 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx9 < 0 || idx9 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx10 < 0 || idx10 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx11 < 0 || idx11 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx12 < 0 || idx12 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx13 < 0 || idx13 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx14 < 0 || idx14 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx15 < 0 || idx15 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
float wgt2 = 1.f;
|
|
float wgt3 = 1.f;
|
|
float wgt4 = 1.f;
|
|
float wgt5 = 1.f;
|
|
float wgt6 = 1.f;
|
|
float wgt7 = 1.f;
|
|
float wgt8 = 1.f;
|
|
float wgt9 = 1.f;
|
|
float wgt10 = 1.f;
|
|
float wgt11 = 1.f;
|
|
float wgt12 = 1.f;
|
|
float wgt13 = 1.f;
|
|
float wgt14 = 1.f;
|
|
float wgt15 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
|
|
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
|
|
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
|
|
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
|
|
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
|
|
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
|
|
wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8];
|
|
wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9];
|
|
wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10];
|
|
wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11];
|
|
wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12];
|
|
wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13];
|
|
wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14];
|
|
wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15];
|
|
}
|
|
const at::BFloat16* const ip0 = &input[idx0 * block_size];
|
|
const at::BFloat16* const ip1 = &input[idx1 * block_size];
|
|
const at::BFloat16* const ip2 = &input[idx2 * block_size];
|
|
const at::BFloat16* const ip3 = &input[idx3 * block_size];
|
|
const at::BFloat16* const ip4 = &input[idx4 * block_size];
|
|
const at::BFloat16* const ip5 = &input[idx5 * block_size];
|
|
const at::BFloat16* const ip6 = &input[idx6 * block_size];
|
|
const at::BFloat16* const ip7 = &input[idx7 * block_size];
|
|
const at::BFloat16* const ip8 = &input[idx8 * block_size];
|
|
const at::BFloat16* const ip9 = &input[idx9 * block_size];
|
|
const at::BFloat16* const ip10 = &input[idx10 * block_size];
|
|
const at::BFloat16* const ip11 = &input[idx11 * block_size];
|
|
const at::BFloat16* const ip12 = &input[idx12 * block_size];
|
|
const at::BFloat16* const ip13 = &input[idx13 * block_size];
|
|
const at::BFloat16* const ip14 = &input[idx14 * block_size];
|
|
const at::BFloat16* const ip15 = &input[idx15 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
auto input0 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
|
|
auto input1 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
|
|
auto input2 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
|
|
auto input3 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
|
|
auto input4 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip4[k])), 16));
|
|
auto input5 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip5[k])), 16));
|
|
auto input6 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip6[k])), 16));
|
|
auto input7 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip7[k])), 16));
|
|
auto input8 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip8[k])), 16));
|
|
auto input9 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip9[k])), 16));
|
|
auto input10 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip10[k])), 16));
|
|
auto input11 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip11[k])), 16));
|
|
auto input12 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip12[k])), 16));
|
|
auto input13 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip13[k])), 16));
|
|
auto input14 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip14[k])), 16));
|
|
auto input15 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip15[k])), 16));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
output = svmla_x(svAll, output, input1, wgt1);
|
|
output = svmla_x(svAll, output, input2, wgt2);
|
|
output = svmla_x(svAll, output, input3, wgt3);
|
|
output = svmla_x(svAll, output, input4, wgt4);
|
|
output = svmla_x(svAll, output, input5, wgt5);
|
|
output = svmla_x(svAll, output, input6, wgt6);
|
|
output = svmla_x(svAll, output, input7, wgt7);
|
|
output = svmla_x(svAll, output, input8, wgt8);
|
|
output = svmla_x(svAll, output, input9, wgt9);
|
|
output = svmla_x(svAll, output, input10, wgt10);
|
|
output = svmla_x(svAll, output, input11, wgt11);
|
|
output = svmla_x(svAll, output, input12, wgt12);
|
|
output = svmla_x(svAll, output, input13, wgt13);
|
|
output = svmla_x(svAll, output, input14, wgt14);
|
|
output = svmla_x(svAll, output, input15, wgt15);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
auto input0 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
|
|
auto input1 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
|
|
auto input2 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
|
|
auto input3 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
|
|
auto input4 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip4[k])), 16));
|
|
auto input5 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip5[k])), 16));
|
|
auto input6 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip6[k])), 16));
|
|
auto input7 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip7[k])), 16));
|
|
auto input8 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip8[k])), 16));
|
|
auto input9 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip9[k])), 16));
|
|
auto input10 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip10[k])), 16));
|
|
auto input11 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip11[k])), 16));
|
|
auto input12 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip12[k])), 16));
|
|
auto input13 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip13[k])), 16));
|
|
auto input14 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip14[k])), 16));
|
|
auto input15 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip15[k])), 16));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
output = svmla_x(pg, output, input1, wgt1);
|
|
output = svmla_x(pg, output, input2, wgt2);
|
|
output = svmla_x(pg, output, input3, wgt3);
|
|
output = svmla_x(pg, output, input4, wgt4);
|
|
output = svmla_x(pg, output, input5, wgt5);
|
|
output = svmla_x(pg, output, input6, wgt6);
|
|
output = svmla_x(pg, output, input7, wgt7);
|
|
output = svmla_x(pg, output, input8, wgt8);
|
|
output = svmla_x(pg, output, input9, wgt9);
|
|
output = svmla_x(pg, output, input10, wgt10);
|
|
output = svmla_x(pg, output, input11, wgt11);
|
|
output = svmla_x(pg, output, input12, wgt12);
|
|
output = svmla_x(pg, output, input13, wgt13);
|
|
output = svmla_x(pg, output, input14, wgt14);
|
|
output = svmla_x(pg, output, input15, wgt15);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 16;
|
|
pos += 16;
|
|
}
|
|
// unrolling 8 times
|
|
while (j + 7 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
const auto idx2 = indices[pos + 2];
|
|
const auto idx3 = indices[pos + 3];
|
|
const auto idx4 = indices[pos + 4];
|
|
const auto idx5 = indices[pos + 5];
|
|
const auto idx6 = indices[pos + 6];
|
|
const auto idx7 = indices[pos + 7];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx2 < 0 || idx2 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx3 < 0 || idx3 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx4 < 0 || idx4 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx5 < 0 || idx5 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx6 < 0 || idx6 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx7 < 0 || idx7 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
float wgt2 = 1.f;
|
|
float wgt3 = 1.f;
|
|
float wgt4 = 1.f;
|
|
float wgt5 = 1.f;
|
|
float wgt6 = 1.f;
|
|
float wgt7 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
|
|
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
|
|
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
|
|
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
|
|
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
|
|
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
|
|
}
|
|
const at::BFloat16* const ip0 = &input[idx0 * block_size];
|
|
const at::BFloat16* const ip1 = &input[idx1 * block_size];
|
|
const at::BFloat16* const ip2 = &input[idx2 * block_size];
|
|
const at::BFloat16* const ip3 = &input[idx3 * block_size];
|
|
const at::BFloat16* const ip4 = &input[idx4 * block_size];
|
|
const at::BFloat16* const ip5 = &input[idx5 * block_size];
|
|
const at::BFloat16* const ip6 = &input[idx6 * block_size];
|
|
const at::BFloat16* const ip7 = &input[idx7 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
auto input0 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
|
|
auto input1 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
|
|
auto input2 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
|
|
auto input3 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
|
|
auto input4 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip4[k])), 16));
|
|
auto input5 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip5[k])), 16));
|
|
auto input6 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip6[k])), 16));
|
|
auto input7 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip7[k])), 16));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
output = svmla_x(svAll, output, input1, wgt1);
|
|
output = svmla_x(svAll, output, input2, wgt2);
|
|
output = svmla_x(svAll, output, input3, wgt3);
|
|
output = svmla_x(svAll, output, input4, wgt4);
|
|
output = svmla_x(svAll, output, input5, wgt5);
|
|
output = svmla_x(svAll, output, input6, wgt6);
|
|
output = svmla_x(svAll, output, input7, wgt7);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
auto input0 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
|
|
auto input1 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
|
|
auto input2 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
|
|
auto input3 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
|
|
auto input4 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip4[k])), 16));
|
|
auto input5 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip5[k])), 16));
|
|
auto input6 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip6[k])), 16));
|
|
auto input7 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip7[k])), 16));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
output = svmla_x(pg, output, input1, wgt1);
|
|
output = svmla_x(pg, output, input2, wgt2);
|
|
output = svmla_x(pg, output, input3, wgt3);
|
|
output = svmla_x(pg, output, input4, wgt4);
|
|
output = svmla_x(pg, output, input5, wgt5);
|
|
output = svmla_x(pg, output, input6, wgt6);
|
|
output = svmla_x(pg, output, input7, wgt7);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 8;
|
|
pos += 8;
|
|
}
|
|
// unrolling 4 times
|
|
while (j + 3 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
const auto idx2 = indices[pos + 2];
|
|
const auto idx3 = indices[pos + 3];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx2 < 0 || idx2 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx3 < 0 || idx3 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
float wgt2 = 1.f;
|
|
float wgt3 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
|
|
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
|
|
}
|
|
const at::BFloat16* const ip0 = &input[idx0 * block_size];
|
|
const at::BFloat16* const ip1 = &input[idx1 * block_size];
|
|
const at::BFloat16* const ip2 = &input[idx2 * block_size];
|
|
const at::BFloat16* const ip3 = &input[idx3 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
auto input0 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
|
|
auto input1 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
|
|
auto input2 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
|
|
auto input3 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
output = svmla_x(svAll, output, input1, wgt1);
|
|
output = svmla_x(svAll, output, input2, wgt2);
|
|
output = svmla_x(svAll, output, input3, wgt3);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
auto input0 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
|
|
auto input1 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
|
|
auto input2 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
|
|
auto input3 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
output = svmla_x(pg, output, input1, wgt1);
|
|
output = svmla_x(pg, output, input2, wgt2);
|
|
output = svmla_x(pg, output, input3, wgt3);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 4;
|
|
pos += 4;
|
|
}
|
|
// unrolling 2 times
|
|
while (j + 1 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
}
|
|
const at::BFloat16* const ip0 = &input[idx0 * block_size];
|
|
const at::BFloat16* const ip1 = &input[idx1 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
auto input0 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
|
|
auto input1 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
output = svmla_x(svAll, output, input1, wgt1);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
auto input0 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
|
|
auto input1 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
output = svmla_x(pg, output, input1, wgt1);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 2;
|
|
pos += 2;
|
|
}
|
|
// tail loop
|
|
if (j < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
}
|
|
const at::BFloat16* const ip0 = &input[idx0 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
auto input0 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
auto input0 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
pos ++;
|
|
}
|
|
const int64_t length = end_offset - start_offset;
|
|
|
|
if (normalize_by_lengths && length != 0) {
|
|
const float len_inv = 1.0f / length;
|
|
svbool_t pg;
|
|
int64_t j = 0;
|
|
while (j + vLen - 1 < block_size) {
|
|
svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));
|
|
j += vLen;
|
|
}
|
|
if (j < block_size) {
|
|
pg = svwhilelt_b32_s64(j, block_size);
|
|
svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));
|
|
}
|
|
}
|
|
}
|
|
return pos == index_size;
|
|
}
|
|
bool EmbeddingLookupIdx_int32_t_bfloat16_float_false__sve(
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t data_size,
|
|
const at::BFloat16* input,
|
|
const int32_t* indices,
|
|
const int32_t* offsets,
|
|
const float* weights,
|
|
const float* scale_bias,
|
|
bool normalize_by_lengths,
|
|
float* out) {
|
|
return EmbeddingLookupIdx_int32_t_bfloat16_float__sve<false>(
|
|
block_size,
|
|
output_size,
|
|
index_size,
|
|
data_size,
|
|
input,
|
|
indices,
|
|
offsets,
|
|
weights,
|
|
scale_bias,
|
|
normalize_by_lengths,
|
|
out);
|
|
}
|
|
bool EmbeddingLookupIdx_int32_t_bfloat16_float_true__sve(
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t data_size,
|
|
const at::BFloat16* input,
|
|
const int32_t* indices,
|
|
const int32_t* offsets,
|
|
const float* weights,
|
|
const float* scale_bias,
|
|
bool normalize_by_lengths,
|
|
float* out) {
|
|
return EmbeddingLookupIdx_int32_t_bfloat16_float__sve<true>(
|
|
block_size,
|
|
output_size,
|
|
index_size,
|
|
data_size,
|
|
input,
|
|
indices,
|
|
offsets,
|
|
weights,
|
|
scale_bias,
|
|
normalize_by_lengths,
|
|
out);
|
|
}
|
|
|
|
template <bool IS_WEIGHT_POSITIONAL>
|
|
static bool EmbeddingLookupIdx_int64_t_bfloat16_float__sve(
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t data_size,
|
|
const at::BFloat16* input,
|
|
const int64_t* indices,
|
|
const int64_t* offsets,
|
|
const float* weights,
|
|
const float* scale_bias,
|
|
bool normalize_by_lengths,
|
|
float* out) {
|
|
const svbool_t svAll = svptrue_b32();
|
|
const auto vLen = static_cast<int64_t>(svcntw());
|
|
int64_t pos = 0;
|
|
for (int64_t i = 0; i < output_size; ++i) {
|
|
float* const op = &out[i * block_size];
|
|
memset(op, 0, sizeof(float) * block_size);
|
|
if (pos != offsets[i] - offsets[0]) {
|
|
return false;
|
|
}
|
|
int64_t start_offset = offsets[i];
|
|
int64_t end_offset = offsets[i + 1];
|
|
int64_t j = start_offset;
|
|
// unrolling 16 times
|
|
while (j + 15 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
const auto idx2 = indices[pos + 2];
|
|
const auto idx3 = indices[pos + 3];
|
|
const auto idx4 = indices[pos + 4];
|
|
const auto idx5 = indices[pos + 5];
|
|
const auto idx6 = indices[pos + 6];
|
|
const auto idx7 = indices[pos + 7];
|
|
const auto idx8 = indices[pos + 8];
|
|
const auto idx9 = indices[pos + 9];
|
|
const auto idx10 = indices[pos + 10];
|
|
const auto idx11 = indices[pos + 11];
|
|
const auto idx12 = indices[pos + 12];
|
|
const auto idx13 = indices[pos + 13];
|
|
const auto idx14 = indices[pos + 14];
|
|
const auto idx15 = indices[pos + 15];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx2 < 0 || idx2 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx3 < 0 || idx3 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx4 < 0 || idx4 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx5 < 0 || idx5 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx6 < 0 || idx6 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx7 < 0 || idx7 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx8 < 0 || idx8 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx9 < 0 || idx9 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx10 < 0 || idx10 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx11 < 0 || idx11 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx12 < 0 || idx12 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx13 < 0 || idx13 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx14 < 0 || idx14 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx15 < 0 || idx15 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
float wgt2 = 1.f;
|
|
float wgt3 = 1.f;
|
|
float wgt4 = 1.f;
|
|
float wgt5 = 1.f;
|
|
float wgt6 = 1.f;
|
|
float wgt7 = 1.f;
|
|
float wgt8 = 1.f;
|
|
float wgt9 = 1.f;
|
|
float wgt10 = 1.f;
|
|
float wgt11 = 1.f;
|
|
float wgt12 = 1.f;
|
|
float wgt13 = 1.f;
|
|
float wgt14 = 1.f;
|
|
float wgt15 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
|
|
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
|
|
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
|
|
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
|
|
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
|
|
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
|
|
wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8];
|
|
wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9];
|
|
wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10];
|
|
wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11];
|
|
wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12];
|
|
wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13];
|
|
wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14];
|
|
wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15];
|
|
}
|
|
const at::BFloat16* const ip0 = &input[idx0 * block_size];
|
|
const at::BFloat16* const ip1 = &input[idx1 * block_size];
|
|
const at::BFloat16* const ip2 = &input[idx2 * block_size];
|
|
const at::BFloat16* const ip3 = &input[idx3 * block_size];
|
|
const at::BFloat16* const ip4 = &input[idx4 * block_size];
|
|
const at::BFloat16* const ip5 = &input[idx5 * block_size];
|
|
const at::BFloat16* const ip6 = &input[idx6 * block_size];
|
|
const at::BFloat16* const ip7 = &input[idx7 * block_size];
|
|
const at::BFloat16* const ip8 = &input[idx8 * block_size];
|
|
const at::BFloat16* const ip9 = &input[idx9 * block_size];
|
|
const at::BFloat16* const ip10 = &input[idx10 * block_size];
|
|
const at::BFloat16* const ip11 = &input[idx11 * block_size];
|
|
const at::BFloat16* const ip12 = &input[idx12 * block_size];
|
|
const at::BFloat16* const ip13 = &input[idx13 * block_size];
|
|
const at::BFloat16* const ip14 = &input[idx14 * block_size];
|
|
const at::BFloat16* const ip15 = &input[idx15 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
auto input0 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
|
|
auto input1 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
|
|
auto input2 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
|
|
auto input3 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
|
|
auto input4 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip4[k])), 16));
|
|
auto input5 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip5[k])), 16));
|
|
auto input6 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip6[k])), 16));
|
|
auto input7 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip7[k])), 16));
|
|
auto input8 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip8[k])), 16));
|
|
auto input9 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip9[k])), 16));
|
|
auto input10 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip10[k])), 16));
|
|
auto input11 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip11[k])), 16));
|
|
auto input12 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip12[k])), 16));
|
|
auto input13 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip13[k])), 16));
|
|
auto input14 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip14[k])), 16));
|
|
auto input15 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip15[k])), 16));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
output = svmla_x(svAll, output, input1, wgt1);
|
|
output = svmla_x(svAll, output, input2, wgt2);
|
|
output = svmla_x(svAll, output, input3, wgt3);
|
|
output = svmla_x(svAll, output, input4, wgt4);
|
|
output = svmla_x(svAll, output, input5, wgt5);
|
|
output = svmla_x(svAll, output, input6, wgt6);
|
|
output = svmla_x(svAll, output, input7, wgt7);
|
|
output = svmla_x(svAll, output, input8, wgt8);
|
|
output = svmla_x(svAll, output, input9, wgt9);
|
|
output = svmla_x(svAll, output, input10, wgt10);
|
|
output = svmla_x(svAll, output, input11, wgt11);
|
|
output = svmla_x(svAll, output, input12, wgt12);
|
|
output = svmla_x(svAll, output, input13, wgt13);
|
|
output = svmla_x(svAll, output, input14, wgt14);
|
|
output = svmla_x(svAll, output, input15, wgt15);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
auto input0 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
|
|
auto input1 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
|
|
auto input2 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
|
|
auto input3 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
|
|
auto input4 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip4[k])), 16));
|
|
auto input5 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip5[k])), 16));
|
|
auto input6 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip6[k])), 16));
|
|
auto input7 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip7[k])), 16));
|
|
auto input8 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip8[k])), 16));
|
|
auto input9 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip9[k])), 16));
|
|
auto input10 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip10[k])), 16));
|
|
auto input11 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip11[k])), 16));
|
|
auto input12 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip12[k])), 16));
|
|
auto input13 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip13[k])), 16));
|
|
auto input14 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip14[k])), 16));
|
|
auto input15 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip15[k])), 16));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
output = svmla_x(pg, output, input1, wgt1);
|
|
output = svmla_x(pg, output, input2, wgt2);
|
|
output = svmla_x(pg, output, input3, wgt3);
|
|
output = svmla_x(pg, output, input4, wgt4);
|
|
output = svmla_x(pg, output, input5, wgt5);
|
|
output = svmla_x(pg, output, input6, wgt6);
|
|
output = svmla_x(pg, output, input7, wgt7);
|
|
output = svmla_x(pg, output, input8, wgt8);
|
|
output = svmla_x(pg, output, input9, wgt9);
|
|
output = svmla_x(pg, output, input10, wgt10);
|
|
output = svmla_x(pg, output, input11, wgt11);
|
|
output = svmla_x(pg, output, input12, wgt12);
|
|
output = svmla_x(pg, output, input13, wgt13);
|
|
output = svmla_x(pg, output, input14, wgt14);
|
|
output = svmla_x(pg, output, input15, wgt15);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 16;
|
|
pos += 16;
|
|
}
|
|
// unrolling 8 times
|
|
while (j + 7 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
const auto idx2 = indices[pos + 2];
|
|
const auto idx3 = indices[pos + 3];
|
|
const auto idx4 = indices[pos + 4];
|
|
const auto idx5 = indices[pos + 5];
|
|
const auto idx6 = indices[pos + 6];
|
|
const auto idx7 = indices[pos + 7];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx2 < 0 || idx2 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx3 < 0 || idx3 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx4 < 0 || idx4 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx5 < 0 || idx5 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx6 < 0 || idx6 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx7 < 0 || idx7 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
float wgt2 = 1.f;
|
|
float wgt3 = 1.f;
|
|
float wgt4 = 1.f;
|
|
float wgt5 = 1.f;
|
|
float wgt6 = 1.f;
|
|
float wgt7 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
|
|
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
|
|
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
|
|
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
|
|
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
|
|
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
|
|
}
|
|
const at::BFloat16* const ip0 = &input[idx0 * block_size];
|
|
const at::BFloat16* const ip1 = &input[idx1 * block_size];
|
|
const at::BFloat16* const ip2 = &input[idx2 * block_size];
|
|
const at::BFloat16* const ip3 = &input[idx3 * block_size];
|
|
const at::BFloat16* const ip4 = &input[idx4 * block_size];
|
|
const at::BFloat16* const ip5 = &input[idx5 * block_size];
|
|
const at::BFloat16* const ip6 = &input[idx6 * block_size];
|
|
const at::BFloat16* const ip7 = &input[idx7 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
auto input0 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
|
|
auto input1 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
|
|
auto input2 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
|
|
auto input3 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
|
|
auto input4 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip4[k])), 16));
|
|
auto input5 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip5[k])), 16));
|
|
auto input6 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip6[k])), 16));
|
|
auto input7 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip7[k])), 16));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
output = svmla_x(svAll, output, input1, wgt1);
|
|
output = svmla_x(svAll, output, input2, wgt2);
|
|
output = svmla_x(svAll, output, input3, wgt3);
|
|
output = svmla_x(svAll, output, input4, wgt4);
|
|
output = svmla_x(svAll, output, input5, wgt5);
|
|
output = svmla_x(svAll, output, input6, wgt6);
|
|
output = svmla_x(svAll, output, input7, wgt7);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
auto input0 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
|
|
auto input1 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
|
|
auto input2 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
|
|
auto input3 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
|
|
auto input4 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip4[k])), 16));
|
|
auto input5 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip5[k])), 16));
|
|
auto input6 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip6[k])), 16));
|
|
auto input7 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip7[k])), 16));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
output = svmla_x(pg, output, input1, wgt1);
|
|
output = svmla_x(pg, output, input2, wgt2);
|
|
output = svmla_x(pg, output, input3, wgt3);
|
|
output = svmla_x(pg, output, input4, wgt4);
|
|
output = svmla_x(pg, output, input5, wgt5);
|
|
output = svmla_x(pg, output, input6, wgt6);
|
|
output = svmla_x(pg, output, input7, wgt7);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 8;
|
|
pos += 8;
|
|
}
|
|
// unrolling 4 times
|
|
while (j + 3 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
const auto idx2 = indices[pos + 2];
|
|
const auto idx3 = indices[pos + 3];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx2 < 0 || idx2 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx3 < 0 || idx3 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
float wgt2 = 1.f;
|
|
float wgt3 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
|
|
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
|
|
}
|
|
const at::BFloat16* const ip0 = &input[idx0 * block_size];
|
|
const at::BFloat16* const ip1 = &input[idx1 * block_size];
|
|
const at::BFloat16* const ip2 = &input[idx2 * block_size];
|
|
const at::BFloat16* const ip3 = &input[idx3 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
auto input0 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
|
|
auto input1 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
|
|
auto input2 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
|
|
auto input3 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
output = svmla_x(svAll, output, input1, wgt1);
|
|
output = svmla_x(svAll, output, input2, wgt2);
|
|
output = svmla_x(svAll, output, input3, wgt3);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
auto input0 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
|
|
auto input1 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
|
|
auto input2 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
|
|
auto input3 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
output = svmla_x(pg, output, input1, wgt1);
|
|
output = svmla_x(pg, output, input2, wgt2);
|
|
output = svmla_x(pg, output, input3, wgt3);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 4;
|
|
pos += 4;
|
|
}
|
|
// unrolling 2 times
|
|
while (j + 1 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
}
|
|
const at::BFloat16* const ip0 = &input[idx0 * block_size];
|
|
const at::BFloat16* const ip1 = &input[idx1 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
auto input0 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
|
|
auto input1 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
output = svmla_x(svAll, output, input1, wgt1);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
auto input0 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
|
|
auto input1 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
output = svmla_x(pg, output, input1, wgt1);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 2;
|
|
pos += 2;
|
|
}
|
|
// tail loop
|
|
if (j < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
}
|
|
const at::BFloat16* const ip0 = &input[idx0 * block_size];
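      // A bf16 value is the upper half of an f32 bit pattern, so the vector
      // loop below widens it with a zero-extending 16-bit load (svld1uh_u32),
      // a left shift by 16, and a reinterpret to f32.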
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
auto input0 = svreinterpret_f32(svlsl_x(svAll,
|
|
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
auto input0 = svreinterpret_f32(svlsl_x(pg,
|
|
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
pos ++;
|
|
}
|
|
    const int64_t length = end_offset - start_offset;

    if (normalize_by_lengths && length != 0) {
      const float len_inv = 1.0f / length;
      svbool_t pg;
      int64_t j = 0;
      while (j + vLen - 1 < block_size) {
        svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));
        j += vLen;
      }
      if (j < block_size) {
        pg = svwhilelt_b32_s64(j, block_size);
        svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));
      }
    }
  }
  return pos == index_size;
}
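// Non-template entry points. The _false/_true suffix fixes the
// IS_WEIGHT_POSITIONAL template parameter: positional weights are indexed by
// the position inside the bag (j - start_offset), otherwise by the running
// index `pos`.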
bool EmbeddingLookupIdx_int64_t_bfloat16_float_false__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::BFloat16* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  return EmbeddingLookupIdx_int64_t_bfloat16_float__sve<false>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      offsets,
      weights,
      scale_bias,
      normalize_by_lengths,
      out);
}
bool EmbeddingLookupIdx_int64_t_bfloat16_float_true__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::BFloat16* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  return EmbeddingLookupIdx_int64_t_bfloat16_float__sve<true>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      offsets,
      weights,
      scale_bias,
      normalize_by_lengths,
      out);
}

template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int32_t_uint8_t_float__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
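  // Strategy (shared by all unrolled variants below): each output row is
  // zero-initialized, then the indices of its bag are processed in groups of
  // 16/8/4/2/1. The inner loop walks the embedding row in SVE vectors of
  // vLen floats, and a whilelt predicate covers the final partial vector when
  // block_size is not a multiple of vLen. For quantized uint8 rows the
  // per-row (scale, bias) pair from scale_bias is folded into scalars:
  //   out[k] += sum_j wgt_j * (scale_j * q_j[k] + bias_j)
  //          =  sum_j (wgt_j * scale_j) * q_j[k] + sum_j wgt_j * bias_j
  // The second term is accumulated in `bio` and added once per output vector.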
  const svbool_t svAll = svptrue_b32();
  const auto vLen = static_cast<int64_t>(svcntw());
  int64_t pos = 0;
  for (int64_t i = 0; i < output_size; ++i) {
    float* const op = &out[i * block_size];
    memset(op, 0, sizeof(float) * block_size);
    if (pos != offsets[i] - offsets[0]) {
      return false;
    }
    int64_t start_offset = offsets[i];
    int64_t end_offset = offsets[i + 1];
    int64_t j = start_offset;
    // unrolling 16 times
    while (j + 15 < end_offset) {
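      // Load the 16 indices for this group and validate each against data_size.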
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
const auto idx2 = indices[pos + 2];
|
|
const auto idx3 = indices[pos + 3];
|
|
const auto idx4 = indices[pos + 4];
|
|
const auto idx5 = indices[pos + 5];
|
|
const auto idx6 = indices[pos + 6];
|
|
const auto idx7 = indices[pos + 7];
|
|
const auto idx8 = indices[pos + 8];
|
|
const auto idx9 = indices[pos + 9];
|
|
const auto idx10 = indices[pos + 10];
|
|
const auto idx11 = indices[pos + 11];
|
|
const auto idx12 = indices[pos + 12];
|
|
const auto idx13 = indices[pos + 13];
|
|
const auto idx14 = indices[pos + 14];
|
|
const auto idx15 = indices[pos + 15];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx2 < 0 || idx2 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx3 < 0 || idx3 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx4 < 0 || idx4 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx5 < 0 || idx5 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx6 < 0 || idx6 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx7 < 0 || idx7 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx8 < 0 || idx8 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx9 < 0 || idx9 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx10 < 0 || idx10 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx11 < 0 || idx11 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx12 < 0 || idx12 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx13 < 0 || idx13 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx14 < 0 || idx14 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx15 < 0 || idx15 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
float wgt2 = 1.f;
|
|
float wgt3 = 1.f;
|
|
float wgt4 = 1.f;
|
|
float wgt5 = 1.f;
|
|
float wgt6 = 1.f;
|
|
float wgt7 = 1.f;
|
|
float wgt8 = 1.f;
|
|
float wgt9 = 1.f;
|
|
float wgt10 = 1.f;
|
|
float wgt11 = 1.f;
|
|
float wgt12 = 1.f;
|
|
float wgt13 = 1.f;
|
|
float wgt14 = 1.f;
|
|
float wgt15 = 1.f;
|
|
float bio = 0.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
|
|
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
|
|
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
|
|
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
|
|
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
|
|
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
|
|
wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8];
|
|
wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9];
|
|
wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10];
|
|
wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11];
|
|
wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12];
|
|
wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13];
|
|
wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14];
|
|
wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15];
|
|
}
|
|
if (scale_bias) {
|
|
bio += wgt0 * scale_bias[2 * idx0 + 1];
|
|
wgt0 = wgt0 * scale_bias[2 * idx0];
|
|
bio += wgt1 * scale_bias[2 * idx1 + 1];
|
|
wgt1 = wgt1 * scale_bias[2 * idx1];
|
|
bio += wgt2 * scale_bias[2 * idx2 + 1];
|
|
wgt2 = wgt2 * scale_bias[2 * idx2];
|
|
bio += wgt3 * scale_bias[2 * idx3 + 1];
|
|
wgt3 = wgt3 * scale_bias[2 * idx3];
|
|
bio += wgt4 * scale_bias[2 * idx4 + 1];
|
|
wgt4 = wgt4 * scale_bias[2 * idx4];
|
|
bio += wgt5 * scale_bias[2 * idx5 + 1];
|
|
wgt5 = wgt5 * scale_bias[2 * idx5];
|
|
bio += wgt6 * scale_bias[2 * idx6 + 1];
|
|
wgt6 = wgt6 * scale_bias[2 * idx6];
|
|
bio += wgt7 * scale_bias[2 * idx7 + 1];
|
|
wgt7 = wgt7 * scale_bias[2 * idx7];
|
|
bio += wgt8 * scale_bias[2 * idx8 + 1];
|
|
wgt8 = wgt8 * scale_bias[2 * idx8];
|
|
bio += wgt9 * scale_bias[2 * idx9 + 1];
|
|
wgt9 = wgt9 * scale_bias[2 * idx9];
|
|
bio += wgt10 * scale_bias[2 * idx10 + 1];
|
|
wgt10 = wgt10 * scale_bias[2 * idx10];
|
|
bio += wgt11 * scale_bias[2 * idx11 + 1];
|
|
wgt11 = wgt11 * scale_bias[2 * idx11];
|
|
bio += wgt12 * scale_bias[2 * idx12 + 1];
|
|
wgt12 = wgt12 * scale_bias[2 * idx12];
|
|
bio += wgt13 * scale_bias[2 * idx13 + 1];
|
|
wgt13 = wgt13 * scale_bias[2 * idx13];
|
|
bio += wgt14 * scale_bias[2 * idx14 + 1];
|
|
wgt14 = wgt14 * scale_bias[2 * idx14];
|
|
bio += wgt15 * scale_bias[2 * idx15 + 1];
|
|
wgt15 = wgt15 * scale_bias[2 * idx15];
|
|
}
|
|
const uint8_t* const ip0 = &input[idx0 * block_size];
|
|
const uint8_t* const ip1 = &input[idx1 * block_size];
|
|
const uint8_t* const ip2 = &input[idx2 * block_size];
|
|
const uint8_t* const ip3 = &input[idx3 * block_size];
|
|
const uint8_t* const ip4 = &input[idx4 * block_size];
|
|
const uint8_t* const ip5 = &input[idx5 * block_size];
|
|
const uint8_t* const ip6 = &input[idx6 * block_size];
|
|
const uint8_t* const ip7 = &input[idx7 * block_size];
|
|
const uint8_t* const ip8 = &input[idx8 * block_size];
|
|
const uint8_t* const ip9 = &input[idx9 * block_size];
|
|
const uint8_t* const ip10 = &input[idx10 * block_size];
|
|
const uint8_t* const ip11 = &input[idx11 * block_size];
|
|
const uint8_t* const ip12 = &input[idx12 * block_size];
|
|
const uint8_t* const ip13 = &input[idx13 * block_size];
|
|
const uint8_t* const ip14 = &input[idx14 * block_size];
|
|
const uint8_t* const ip15 = &input[idx15 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
output = svadd_x(svAll, output, bio);
|
|
auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
|
|
auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k]));
|
|
auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k]));
|
|
auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k]));
|
|
auto input4 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip4[k]));
|
|
auto input5 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip5[k]));
|
|
auto input6 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip6[k]));
|
|
auto input7 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip7[k]));
|
|
auto input8 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip8[k]));
|
|
auto input9 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip9[k]));
|
|
auto input10 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip10[k]));
|
|
auto input11 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip11[k]));
|
|
auto input12 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip12[k]));
|
|
auto input13 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip13[k]));
|
|
auto input14 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip14[k]));
|
|
auto input15 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip15[k]));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
output = svmla_x(svAll, output, input1, wgt1);
|
|
output = svmla_x(svAll, output, input2, wgt2);
|
|
output = svmla_x(svAll, output, input3, wgt3);
|
|
output = svmla_x(svAll, output, input4, wgt4);
|
|
output = svmla_x(svAll, output, input5, wgt5);
|
|
output = svmla_x(svAll, output, input6, wgt6);
|
|
output = svmla_x(svAll, output, input7, wgt7);
|
|
output = svmla_x(svAll, output, input8, wgt8);
|
|
output = svmla_x(svAll, output, input9, wgt9);
|
|
output = svmla_x(svAll, output, input10, wgt10);
|
|
output = svmla_x(svAll, output, input11, wgt11);
|
|
output = svmla_x(svAll, output, input12, wgt12);
|
|
output = svmla_x(svAll, output, input13, wgt13);
|
|
output = svmla_x(svAll, output, input14, wgt14);
|
|
output = svmla_x(svAll, output, input15, wgt15);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
output = svadd_x(pg, output, bio);
|
|
auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
|
|
auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k]));
|
|
auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k]));
|
|
auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k]));
|
|
auto input4 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip4[k]));
|
|
auto input5 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip5[k]));
|
|
auto input6 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip6[k]));
|
|
auto input7 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip7[k]));
|
|
auto input8 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip8[k]));
|
|
auto input9 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip9[k]));
|
|
auto input10 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip10[k]));
|
|
auto input11 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip11[k]));
|
|
auto input12 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip12[k]));
|
|
auto input13 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip13[k]));
|
|
auto input14 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip14[k]));
|
|
auto input15 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip15[k]));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
output = svmla_x(pg, output, input1, wgt1);
|
|
output = svmla_x(pg, output, input2, wgt2);
|
|
output = svmla_x(pg, output, input3, wgt3);
|
|
output = svmla_x(pg, output, input4, wgt4);
|
|
output = svmla_x(pg, output, input5, wgt5);
|
|
output = svmla_x(pg, output, input6, wgt6);
|
|
output = svmla_x(pg, output, input7, wgt7);
|
|
output = svmla_x(pg, output, input8, wgt8);
|
|
output = svmla_x(pg, output, input9, wgt9);
|
|
output = svmla_x(pg, output, input10, wgt10);
|
|
output = svmla_x(pg, output, input11, wgt11);
|
|
output = svmla_x(pg, output, input12, wgt12);
|
|
output = svmla_x(pg, output, input13, wgt13);
|
|
output = svmla_x(pg, output, input14, wgt14);
|
|
output = svmla_x(pg, output, input15, wgt15);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 16;
|
|
pos += 16;
|
|
}
|
|
// unrolling 8 times
|
|
while (j + 7 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
const auto idx2 = indices[pos + 2];
|
|
const auto idx3 = indices[pos + 3];
|
|
const auto idx4 = indices[pos + 4];
|
|
const auto idx5 = indices[pos + 5];
|
|
const auto idx6 = indices[pos + 6];
|
|
const auto idx7 = indices[pos + 7];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx2 < 0 || idx2 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx3 < 0 || idx3 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx4 < 0 || idx4 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx5 < 0 || idx5 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx6 < 0 || idx6 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx7 < 0 || idx7 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
float wgt2 = 1.f;
|
|
float wgt3 = 1.f;
|
|
float wgt4 = 1.f;
|
|
float wgt5 = 1.f;
|
|
float wgt6 = 1.f;
|
|
float wgt7 = 1.f;
|
|
float bio = 0.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
|
|
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
|
|
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
|
|
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
|
|
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
|
|
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
|
|
}
|
|
if (scale_bias) {
|
|
bio += wgt0 * scale_bias[2 * idx0 + 1];
|
|
wgt0 = wgt0 * scale_bias[2 * idx0];
|
|
bio += wgt1 * scale_bias[2 * idx1 + 1];
|
|
wgt1 = wgt1 * scale_bias[2 * idx1];
|
|
bio += wgt2 * scale_bias[2 * idx2 + 1];
|
|
wgt2 = wgt2 * scale_bias[2 * idx2];
|
|
bio += wgt3 * scale_bias[2 * idx3 + 1];
|
|
wgt3 = wgt3 * scale_bias[2 * idx3];
|
|
bio += wgt4 * scale_bias[2 * idx4 + 1];
|
|
wgt4 = wgt4 * scale_bias[2 * idx4];
|
|
bio += wgt5 * scale_bias[2 * idx5 + 1];
|
|
wgt5 = wgt5 * scale_bias[2 * idx5];
|
|
bio += wgt6 * scale_bias[2 * idx6 + 1];
|
|
wgt6 = wgt6 * scale_bias[2 * idx6];
|
|
bio += wgt7 * scale_bias[2 * idx7 + 1];
|
|
wgt7 = wgt7 * scale_bias[2 * idx7];
|
|
}
|
|
const uint8_t* const ip0 = &input[idx0 * block_size];
|
|
const uint8_t* const ip1 = &input[idx1 * block_size];
|
|
const uint8_t* const ip2 = &input[idx2 * block_size];
|
|
const uint8_t* const ip3 = &input[idx3 * block_size];
|
|
const uint8_t* const ip4 = &input[idx4 * block_size];
|
|
const uint8_t* const ip5 = &input[idx5 * block_size];
|
|
const uint8_t* const ip6 = &input[idx6 * block_size];
|
|
const uint8_t* const ip7 = &input[idx7 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
output = svadd_x(svAll, output, bio);
|
|
auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
|
|
auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k]));
|
|
auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k]));
|
|
auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k]));
|
|
auto input4 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip4[k]));
|
|
auto input5 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip5[k]));
|
|
auto input6 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip6[k]));
|
|
auto input7 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip7[k]));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
output = svmla_x(svAll, output, input1, wgt1);
|
|
output = svmla_x(svAll, output, input2, wgt2);
|
|
output = svmla_x(svAll, output, input3, wgt3);
|
|
output = svmla_x(svAll, output, input4, wgt4);
|
|
output = svmla_x(svAll, output, input5, wgt5);
|
|
output = svmla_x(svAll, output, input6, wgt6);
|
|
output = svmla_x(svAll, output, input7, wgt7);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
output = svadd_x(pg, output, bio);
|
|
auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
|
|
auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k]));
|
|
auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k]));
|
|
auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k]));
|
|
auto input4 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip4[k]));
|
|
auto input5 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip5[k]));
|
|
auto input6 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip6[k]));
|
|
auto input7 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip7[k]));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
output = svmla_x(pg, output, input1, wgt1);
|
|
output = svmla_x(pg, output, input2, wgt2);
|
|
output = svmla_x(pg, output, input3, wgt3);
|
|
output = svmla_x(pg, output, input4, wgt4);
|
|
output = svmla_x(pg, output, input5, wgt5);
|
|
output = svmla_x(pg, output, input6, wgt6);
|
|
output = svmla_x(pg, output, input7, wgt7);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 8;
|
|
pos += 8;
|
|
}
|
|
// unrolling 4 times
|
|
while (j + 3 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
const auto idx2 = indices[pos + 2];
|
|
const auto idx3 = indices[pos + 3];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx2 < 0 || idx2 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx3 < 0 || idx3 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
float wgt2 = 1.f;
|
|
float wgt3 = 1.f;
|
|
float bio = 0.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
|
|
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
|
|
}
|
|
if (scale_bias) {
|
|
bio += wgt0 * scale_bias[2 * idx0 + 1];
|
|
wgt0 = wgt0 * scale_bias[2 * idx0];
|
|
bio += wgt1 * scale_bias[2 * idx1 + 1];
|
|
wgt1 = wgt1 * scale_bias[2 * idx1];
|
|
bio += wgt2 * scale_bias[2 * idx2 + 1];
|
|
wgt2 = wgt2 * scale_bias[2 * idx2];
|
|
bio += wgt3 * scale_bias[2 * idx3 + 1];
|
|
wgt3 = wgt3 * scale_bias[2 * idx3];
|
|
}
|
|
const uint8_t* const ip0 = &input[idx0 * block_size];
|
|
const uint8_t* const ip1 = &input[idx1 * block_size];
|
|
const uint8_t* const ip2 = &input[idx2 * block_size];
|
|
const uint8_t* const ip3 = &input[idx3 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
output = svadd_x(svAll, output, bio);
|
|
auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
|
|
auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k]));
|
|
auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k]));
|
|
auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k]));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
output = svmla_x(svAll, output, input1, wgt1);
|
|
output = svmla_x(svAll, output, input2, wgt2);
|
|
output = svmla_x(svAll, output, input3, wgt3);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
output = svadd_x(pg, output, bio);
|
|
auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
|
|
auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k]));
|
|
auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k]));
|
|
auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k]));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
output = svmla_x(pg, output, input1, wgt1);
|
|
output = svmla_x(pg, output, input2, wgt2);
|
|
output = svmla_x(pg, output, input3, wgt3);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 4;
|
|
pos += 4;
|
|
}
|
|
// unrolling 2 times
|
|
while (j + 1 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
float bio = 0.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
}
|
|
if (scale_bias) {
|
|
bio += wgt0 * scale_bias[2 * idx0 + 1];
|
|
wgt0 = wgt0 * scale_bias[2 * idx0];
|
|
bio += wgt1 * scale_bias[2 * idx1 + 1];
|
|
wgt1 = wgt1 * scale_bias[2 * idx1];
|
|
}
|
|
const uint8_t* const ip0 = &input[idx0 * block_size];
|
|
const uint8_t* const ip1 = &input[idx1 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
output = svadd_x(svAll, output, bio);
|
|
auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
|
|
auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k]));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
output = svmla_x(svAll, output, input1, wgt1);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
output = svadd_x(pg, output, bio);
|
|
auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
|
|
auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k]));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
output = svmla_x(pg, output, input1, wgt1);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 2;
|
|
pos += 2;
|
|
}
|
|
// tail loop
|
|
if (j < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float bio = 0.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
}
|
|
if (scale_bias) {
|
|
bio += wgt0 * scale_bias[2 * idx0 + 1];
|
|
wgt0 = wgt0 * scale_bias[2 * idx0];
|
|
}
|
|
const uint8_t* const ip0 = &input[idx0 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
output = svadd_x(svAll, output, bio);
|
|
auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
output = svadd_x(pg, output, bio);
|
|
auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
pos ++;
|
|
}
|
|
    const int64_t length = end_offset - start_offset;

    if (normalize_by_lengths && length != 0) {
      const float len_inv = 1.0f / length;
      svbool_t pg;
      int64_t j = 0;
      while (j + vLen - 1 < block_size) {
        svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));
        j += vLen;
      }
      if (j < block_size) {
        pg = svwhilelt_b32_s64(j, block_size);
        svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));
      }
    }
  }
  return pos == index_size;
}
bool EmbeddingLookupIdx_int32_t_uint8_t_float_false__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  return EmbeddingLookupIdx_int32_t_uint8_t_float__sve<false>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      offsets,
      weights,
      scale_bias,
      normalize_by_lengths,
      out);
}
bool EmbeddingLookupIdx_int32_t_uint8_t_float_true__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int32_t* indices,
    const int32_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  return EmbeddingLookupIdx_int32_t_uint8_t_float__sve<true>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      offsets,
      weights,
      scale_bias,
      normalize_by_lengths,
      out);
}

template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int64_t_uint8_t_float__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
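  // Identical strategy to the int32_t index variant above; only the index
  // type differs.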
  const svbool_t svAll = svptrue_b32();
  const auto vLen = static_cast<int64_t>(svcntw());
  int64_t pos = 0;
  for (int64_t i = 0; i < output_size; ++i) {
    float* const op = &out[i * block_size];
    memset(op, 0, sizeof(float) * block_size);
    if (pos != offsets[i] - offsets[0]) {
      return false;
    }
    int64_t start_offset = offsets[i];
    int64_t end_offset = offsets[i + 1];
    int64_t j = start_offset;
    // unrolling 16 times
    while (j + 15 < end_offset) {
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
const auto idx2 = indices[pos + 2];
|
|
const auto idx3 = indices[pos + 3];
|
|
const auto idx4 = indices[pos + 4];
|
|
const auto idx5 = indices[pos + 5];
|
|
const auto idx6 = indices[pos + 6];
|
|
const auto idx7 = indices[pos + 7];
|
|
const auto idx8 = indices[pos + 8];
|
|
const auto idx9 = indices[pos + 9];
|
|
const auto idx10 = indices[pos + 10];
|
|
const auto idx11 = indices[pos + 11];
|
|
const auto idx12 = indices[pos + 12];
|
|
const auto idx13 = indices[pos + 13];
|
|
const auto idx14 = indices[pos + 14];
|
|
const auto idx15 = indices[pos + 15];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx2 < 0 || idx2 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx3 < 0 || idx3 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx4 < 0 || idx4 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx5 < 0 || idx5 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx6 < 0 || idx6 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx7 < 0 || idx7 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx8 < 0 || idx8 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx9 < 0 || idx9 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx10 < 0 || idx10 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx11 < 0 || idx11 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx12 < 0 || idx12 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx13 < 0 || idx13 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx14 < 0 || idx14 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx15 < 0 || idx15 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
float wgt2 = 1.f;
|
|
float wgt3 = 1.f;
|
|
float wgt4 = 1.f;
|
|
float wgt5 = 1.f;
|
|
float wgt6 = 1.f;
|
|
float wgt7 = 1.f;
|
|
float wgt8 = 1.f;
|
|
float wgt9 = 1.f;
|
|
float wgt10 = 1.f;
|
|
float wgt11 = 1.f;
|
|
float wgt12 = 1.f;
|
|
float wgt13 = 1.f;
|
|
float wgt14 = 1.f;
|
|
float wgt15 = 1.f;
|
|
float bio = 0.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
|
|
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
|
|
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
|
|
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
|
|
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
|
|
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
|
|
wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8];
|
|
wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9];
|
|
wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10];
|
|
wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11];
|
|
wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12];
|
|
wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13];
|
|
wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14];
|
|
wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15];
|
|
}
|
|
if (scale_bias) {
|
|
bio += wgt0 * scale_bias[2 * idx0 + 1];
|
|
wgt0 = wgt0 * scale_bias[2 * idx0];
|
|
bio += wgt1 * scale_bias[2 * idx1 + 1];
|
|
wgt1 = wgt1 * scale_bias[2 * idx1];
|
|
bio += wgt2 * scale_bias[2 * idx2 + 1];
|
|
wgt2 = wgt2 * scale_bias[2 * idx2];
|
|
bio += wgt3 * scale_bias[2 * idx3 + 1];
|
|
wgt3 = wgt3 * scale_bias[2 * idx3];
|
|
bio += wgt4 * scale_bias[2 * idx4 + 1];
|
|
wgt4 = wgt4 * scale_bias[2 * idx4];
|
|
bio += wgt5 * scale_bias[2 * idx5 + 1];
|
|
wgt5 = wgt5 * scale_bias[2 * idx5];
|
|
bio += wgt6 * scale_bias[2 * idx6 + 1];
|
|
wgt6 = wgt6 * scale_bias[2 * idx6];
|
|
bio += wgt7 * scale_bias[2 * idx7 + 1];
|
|
wgt7 = wgt7 * scale_bias[2 * idx7];
|
|
bio += wgt8 * scale_bias[2 * idx8 + 1];
|
|
wgt8 = wgt8 * scale_bias[2 * idx8];
|
|
bio += wgt9 * scale_bias[2 * idx9 + 1];
|
|
wgt9 = wgt9 * scale_bias[2 * idx9];
|
|
bio += wgt10 * scale_bias[2 * idx10 + 1];
|
|
wgt10 = wgt10 * scale_bias[2 * idx10];
|
|
bio += wgt11 * scale_bias[2 * idx11 + 1];
|
|
wgt11 = wgt11 * scale_bias[2 * idx11];
|
|
bio += wgt12 * scale_bias[2 * idx12 + 1];
|
|
wgt12 = wgt12 * scale_bias[2 * idx12];
|
|
bio += wgt13 * scale_bias[2 * idx13 + 1];
|
|
wgt13 = wgt13 * scale_bias[2 * idx13];
|
|
bio += wgt14 * scale_bias[2 * idx14 + 1];
|
|
wgt14 = wgt14 * scale_bias[2 * idx14];
|
|
bio += wgt15 * scale_bias[2 * idx15 + 1];
|
|
wgt15 = wgt15 * scale_bias[2 * idx15];
|
|
}
|
|
const uint8_t* const ip0 = &input[idx0 * block_size];
|
|
const uint8_t* const ip1 = &input[idx1 * block_size];
|
|
const uint8_t* const ip2 = &input[idx2 * block_size];
|
|
const uint8_t* const ip3 = &input[idx3 * block_size];
|
|
const uint8_t* const ip4 = &input[idx4 * block_size];
|
|
const uint8_t* const ip5 = &input[idx5 * block_size];
|
|
const uint8_t* const ip6 = &input[idx6 * block_size];
|
|
const uint8_t* const ip7 = &input[idx7 * block_size];
|
|
const uint8_t* const ip8 = &input[idx8 * block_size];
|
|
const uint8_t* const ip9 = &input[idx9 * block_size];
|
|
const uint8_t* const ip10 = &input[idx10 * block_size];
|
|
const uint8_t* const ip11 = &input[idx11 * block_size];
|
|
const uint8_t* const ip12 = &input[idx12 * block_size];
|
|
const uint8_t* const ip13 = &input[idx13 * block_size];
|
|
const uint8_t* const ip14 = &input[idx14 * block_size];
|
|
const uint8_t* const ip15 = &input[idx15 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
output = svadd_x(svAll, output, bio);
|
|
auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
|
|
auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k]));
|
|
auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k]));
|
|
auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k]));
|
|
auto input4 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip4[k]));
|
|
auto input5 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip5[k]));
|
|
auto input6 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip6[k]));
|
|
auto input7 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip7[k]));
|
|
auto input8 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip8[k]));
|
|
auto input9 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip9[k]));
|
|
auto input10 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip10[k]));
|
|
auto input11 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip11[k]));
|
|
auto input12 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip12[k]));
|
|
auto input13 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip13[k]));
|
|
auto input14 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip14[k]));
|
|
auto input15 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip15[k]));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
output = svmla_x(svAll, output, input1, wgt1);
|
|
output = svmla_x(svAll, output, input2, wgt2);
|
|
output = svmla_x(svAll, output, input3, wgt3);
|
|
output = svmla_x(svAll, output, input4, wgt4);
|
|
output = svmla_x(svAll, output, input5, wgt5);
|
|
output = svmla_x(svAll, output, input6, wgt6);
|
|
output = svmla_x(svAll, output, input7, wgt7);
|
|
output = svmla_x(svAll, output, input8, wgt8);
|
|
output = svmla_x(svAll, output, input9, wgt9);
|
|
output = svmla_x(svAll, output, input10, wgt10);
|
|
output = svmla_x(svAll, output, input11, wgt11);
|
|
output = svmla_x(svAll, output, input12, wgt12);
|
|
output = svmla_x(svAll, output, input13, wgt13);
|
|
output = svmla_x(svAll, output, input14, wgt14);
|
|
output = svmla_x(svAll, output, input15, wgt15);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
output = svadd_x(pg, output, bio);
|
|
auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
|
|
auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k]));
|
|
auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k]));
|
|
auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k]));
|
|
auto input4 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip4[k]));
|
|
auto input5 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip5[k]));
|
|
auto input6 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip6[k]));
|
|
auto input7 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip7[k]));
|
|
auto input8 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip8[k]));
|
|
auto input9 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip9[k]));
|
|
auto input10 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip10[k]));
|
|
auto input11 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip11[k]));
|
|
auto input12 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip12[k]));
|
|
auto input13 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip13[k]));
|
|
auto input14 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip14[k]));
|
|
auto input15 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip15[k]));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
output = svmla_x(pg, output, input1, wgt1);
|
|
output = svmla_x(pg, output, input2, wgt2);
|
|
output = svmla_x(pg, output, input3, wgt3);
|
|
output = svmla_x(pg, output, input4, wgt4);
|
|
output = svmla_x(pg, output, input5, wgt5);
|
|
output = svmla_x(pg, output, input6, wgt6);
|
|
output = svmla_x(pg, output, input7, wgt7);
|
|
output = svmla_x(pg, output, input8, wgt8);
|
|
output = svmla_x(pg, output, input9, wgt9);
|
|
output = svmla_x(pg, output, input10, wgt10);
|
|
output = svmla_x(pg, output, input11, wgt11);
|
|
output = svmla_x(pg, output, input12, wgt12);
|
|
output = svmla_x(pg, output, input13, wgt13);
|
|
output = svmla_x(pg, output, input14, wgt14);
|
|
output = svmla_x(pg, output, input15, wgt15);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 16;
|
|
pos += 16;
|
|
}
|
|
// unrolling 8 times
|
|
while (j + 7 < end_offset) {
|
|
const auto idx0 = indices[pos + 0];
|
|
const auto idx1 = indices[pos + 1];
|
|
const auto idx2 = indices[pos + 2];
|
|
const auto idx3 = indices[pos + 3];
|
|
const auto idx4 = indices[pos + 4];
|
|
const auto idx5 = indices[pos + 5];
|
|
const auto idx6 = indices[pos + 6];
|
|
const auto idx7 = indices[pos + 7];
|
|
if (idx0 < 0 || idx0 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx1 < 0 || idx1 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx2 < 0 || idx2 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx3 < 0 || idx3 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx4 < 0 || idx4 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx5 < 0 || idx5 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx6 < 0 || idx6 >= data_size) {
|
|
return false;
|
|
}
|
|
if (idx7 < 0 || idx7 >= data_size) {
|
|
return false;
|
|
}
|
|
float wgt0 = 1.f;
|
|
float wgt1 = 1.f;
|
|
float wgt2 = 1.f;
|
|
float wgt3 = 1.f;
|
|
float wgt4 = 1.f;
|
|
float wgt5 = 1.f;
|
|
float wgt6 = 1.f;
|
|
float wgt7 = 1.f;
|
|
float bio = 0.f;
|
|
if (weights) {
|
|
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
|
|
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
|
|
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
|
|
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
|
|
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
|
|
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
|
|
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
|
|
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
|
|
}
|
|
if (scale_bias) {
|
|
bio += wgt0 * scale_bias[2 * idx0 + 1];
|
|
wgt0 = wgt0 * scale_bias[2 * idx0];
|
|
bio += wgt1 * scale_bias[2 * idx1 + 1];
|
|
wgt1 = wgt1 * scale_bias[2 * idx1];
|
|
bio += wgt2 * scale_bias[2 * idx2 + 1];
|
|
wgt2 = wgt2 * scale_bias[2 * idx2];
|
|
bio += wgt3 * scale_bias[2 * idx3 + 1];
|
|
wgt3 = wgt3 * scale_bias[2 * idx3];
|
|
bio += wgt4 * scale_bias[2 * idx4 + 1];
|
|
wgt4 = wgt4 * scale_bias[2 * idx4];
|
|
bio += wgt5 * scale_bias[2 * idx5 + 1];
|
|
wgt5 = wgt5 * scale_bias[2 * idx5];
|
|
bio += wgt6 * scale_bias[2 * idx6 + 1];
|
|
wgt6 = wgt6 * scale_bias[2 * idx6];
|
|
bio += wgt7 * scale_bias[2 * idx7 + 1];
|
|
wgt7 = wgt7 * scale_bias[2 * idx7];
|
|
}
|
|
const uint8_t* const ip0 = &input[idx0 * block_size];
|
|
const uint8_t* const ip1 = &input[idx1 * block_size];
|
|
const uint8_t* const ip2 = &input[idx2 * block_size];
|
|
const uint8_t* const ip3 = &input[idx3 * block_size];
|
|
const uint8_t* const ip4 = &input[idx4 * block_size];
|
|
const uint8_t* const ip5 = &input[idx5 * block_size];
|
|
const uint8_t* const ip6 = &input[idx6 * block_size];
|
|
const uint8_t* const ip7 = &input[idx7 * block_size];
|
|
svbool_t pg;
|
|
int64_t k = 0;
|
|
while (k + vLen - 1 < block_size) {
|
|
auto output = svld1(svAll, &op[k]);
|
|
output = svadd_x(svAll, output, bio);
|
|
auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
|
|
auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k]));
|
|
auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k]));
|
|
auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k]));
|
|
auto input4 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip4[k]));
|
|
auto input5 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip5[k]));
|
|
auto input6 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip6[k]));
|
|
auto input7 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip7[k]));
|
|
output = svmla_x(svAll, output, input0, wgt0);
|
|
output = svmla_x(svAll, output, input1, wgt1);
|
|
output = svmla_x(svAll, output, input2, wgt2);
|
|
output = svmla_x(svAll, output, input3, wgt3);
|
|
output = svmla_x(svAll, output, input4, wgt4);
|
|
output = svmla_x(svAll, output, input5, wgt5);
|
|
output = svmla_x(svAll, output, input6, wgt6);
|
|
output = svmla_x(svAll, output, input7, wgt7);
|
|
svst1(svAll, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
if (k < block_size) {
|
|
pg = svwhilelt_b32_s64(k, block_size);
|
|
auto output = svld1(pg, &op[k]);
|
|
output = svadd_x(pg, output, bio);
|
|
auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
|
|
auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k]));
|
|
auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k]));
|
|
auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k]));
|
|
auto input4 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip4[k]));
|
|
auto input5 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip5[k]));
|
|
auto input6 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip6[k]));
|
|
auto input7 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip7[k]));
|
|
output = svmla_x(pg, output, input0, wgt0);
|
|
output = svmla_x(pg, output, input1, wgt1);
|
|
output = svmla_x(pg, output, input2, wgt2);
|
|
output = svmla_x(pg, output, input3, wgt3);
|
|
output = svmla_x(pg, output, input4, wgt4);
|
|
output = svmla_x(pg, output, input5, wgt5);
|
|
output = svmla_x(pg, output, input6, wgt6);
|
|
output = svmla_x(pg, output, input7, wgt7);
|
|
svst1(pg, &op[k], output);
|
|
k += vLen;
|
|
}
|
|
j += 8;
|
|
pos += 8;
|
|
}
|
|
    // unrolling 4 times
    while (j + 3 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      const auto idx2 = indices[pos + 2];
      const auto idx3 = indices[pos + 3];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      if (idx2 < 0 || idx2 >= data_size) {
        return false;
      }
      if (idx3 < 0 || idx3 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float wgt2 = 1.f;
      float wgt3 = 1.f;
      float bio = 0.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
        wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
        wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
      }
      if (scale_bias) {
        bio += wgt0 * scale_bias[2 * idx0 + 1];
        wgt0 = wgt0 * scale_bias[2 * idx0];
        bio += wgt1 * scale_bias[2 * idx1 + 1];
        wgt1 = wgt1 * scale_bias[2 * idx1];
        bio += wgt2 * scale_bias[2 * idx2 + 1];
        wgt2 = wgt2 * scale_bias[2 * idx2];
        bio += wgt3 * scale_bias[2 * idx3 + 1];
        wgt3 = wgt3 * scale_bias[2 * idx3];
      }
      const uint8_t* const ip0 = &input[idx0 * block_size];
      const uint8_t* const ip1 = &input[idx1 * block_size];
      const uint8_t* const ip2 = &input[idx2 * block_size];
      const uint8_t* const ip3 = &input[idx3 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        output = svadd_x(svAll, output, bio);
        auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
        auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k]));
        auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k]));
        auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k]));
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        output = svmla_x(svAll, output, input2, wgt2);
        output = svmla_x(svAll, output, input3, wgt3);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        output = svadd_x(pg, output, bio);
        auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
        auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k]));
        auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k]));
        auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k]));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        output = svmla_x(pg, output, input2, wgt2);
        output = svmla_x(pg, output, input3, wgt3);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 4;
      pos += 4;
    }
    // unrolling 2 times
    while (j + 1 < end_offset) {
      const auto idx0 = indices[pos + 0];
      const auto idx1 = indices[pos + 1];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      if (idx1 < 0 || idx1 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float wgt1 = 1.f;
      float bio = 0.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
        wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
      }
      if (scale_bias) {
        bio += wgt0 * scale_bias[2 * idx0 + 1];
        wgt0 = wgt0 * scale_bias[2 * idx0];
        bio += wgt1 * scale_bias[2 * idx1 + 1];
        wgt1 = wgt1 * scale_bias[2 * idx1];
      }
      const uint8_t* const ip0 = &input[idx0 * block_size];
      const uint8_t* const ip1 = &input[idx1 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        output = svadd_x(svAll, output, bio);
        auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
        auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k]));
        output = svmla_x(svAll, output, input0, wgt0);
        output = svmla_x(svAll, output, input1, wgt1);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        output = svadd_x(pg, output, bio);
        auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
        auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k]));
        output = svmla_x(pg, output, input0, wgt0);
        output = svmla_x(pg, output, input1, wgt1);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      j += 2;
      pos += 2;
    }
    // tail loop
    if (j < end_offset) {
      const auto idx0 = indices[pos + 0];
      if (idx0 < 0 || idx0 >= data_size) {
        return false;
      }
      float wgt0 = 1.f;
      float bio = 0.f;
      if (weights) {
        wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
      }
      if (scale_bias) {
        bio += wgt0 * scale_bias[2 * idx0 + 1];
        wgt0 = wgt0 * scale_bias[2 * idx0];
      }
      const uint8_t* const ip0 = &input[idx0 * block_size];
      svbool_t pg;
      int64_t k = 0;
      while (k + vLen - 1 < block_size) {
        auto output = svld1(svAll, &op[k]);
        output = svadd_x(svAll, output, bio);
        auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
        output = svmla_x(svAll, output, input0, wgt0);
        svst1(svAll, &op[k], output);
        k += vLen;
      }
      if (k < block_size) {
        pg = svwhilelt_b32_s64(k, block_size);
        auto output = svld1(pg, &op[k]);
        output = svadd_x(pg, output, bio);
        auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
        output = svmla_x(pg, output, input0, wgt0);
        svst1(pg, &op[k], output);
        k += vLen;
      }
      pos++;
    }
    const int64_t length = end_offset - start_offset;

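    // Optional mean pooling: rescale the accumulated sums by 1 / bag length.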
    if (normalize_by_lengths && length != 0) {
      const float len_inv = 1.0f / length;
      svbool_t pg;
      int64_t j = 0;
      while (j + vLen - 1 < block_size) {
        svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));
        j += vLen;
      }
      if (j < block_size) {
        pg = svwhilelt_b32_s64(j, block_size);
        svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));
      }
    }
  }
  return pos == index_size;
}
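
// Non-templated entry points. Each pins IS_WEIGHT_POSITIONAL at compile time:
// false indexes `weights` by global position in `indices`, true indexes it by
// position within the current bag.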
bool EmbeddingLookupIdx_int64_t_uint8_t_float_false__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  return EmbeddingLookupIdx_int64_t_uint8_t_float__sve<false>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      offsets,
      weights,
      scale_bias,
      normalize_by_lengths,
      out);
}

bool EmbeddingLookupIdx_int64_t_uint8_t_float_true__sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  return EmbeddingLookupIdx_int64_t_uint8_t_float__sve<true>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      offsets,
      weights,
      scale_bias,
      normalize_by_lengths,
      out);
}

} // namespace caffe2