pytorch/caffe2/perfkernels/embedding_lookup_idx_sve.cc
Annop Wongwathanarat 6fcffd8cd1 Optimize SVE embedding performance (#150176)
Change the loop unrolling strategy. Previously, the generator script only unrolled the inner loop over block_size, and only when the block size was a multiple of the vector length. This version instead unrolls the outer loop over indices, which reduces the number of loads and stores needed to accumulate into the output array and improves performance when the block size is not a multiple of the vector length.
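
For illustration, here is a minimal sketch of the outer-loop-unrolling idea with an unroll factor of 4 (the helper name `accumulate4` is illustrative and not part of the generated kernel; the generated code also keeps an all-true predicate for full vector chunks and uses a `whilelt` predicate only for the tail):
```cpp
#include <arm_sve.h>
#include <cstdint>

// Hypothetical helper: accumulate four weighted embedding rows into the
// output row `op` of length `block_size`. Because four rows are processed
// per pass, each vector-length chunk of the accumulator is loaded and stored
// once per group of four indices instead of once per index.
static void accumulate4(const float* ip0, const float* ip1,
                        const float* ip2, const float* ip3,
                        float w0, float w1, float w2, float w3,
                        float* op, int64_t block_size) {
  const auto vLen = static_cast<int64_t>(svcntw());
  for (int64_t k = 0; k < block_size; k += vLen) {
    // The predicate covers the tail when block_size is not a multiple of vLen.
    const svbool_t pg = svwhilelt_b32_s64(k, block_size);
    svfloat32_t out = svld1(pg, &op[k]);   // one accumulator load per chunk
    out = svmla_x(pg, out, svld1(pg, &ip0[k]), w0);
    out = svmla_x(pg, out, svld1(pg, &ip1[k]), w1);
    out = svmla_x(pg, out, svld1(pg, &ip2[k]), w2);
    out = svmla_x(pg, out, svld1(pg, &ip3[k]), w3);
    svst1(pg, &op[k], out);                // one accumulator store per chunk
  }
}
```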

Benchmarking script:
```python
# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate <open-source-office@arm.com>
# SPDX-License-Identifier: BSD-3-Clause
import torch
import torch.nn as nn
import numpy as np
import time
import sys

np.random.seed(0)
torch.manual_seed(0)

num_embeddings = 400000
embedding_dim = int(sys.argv[1])
multi_hot = 100
batch_size = 400
nrun = 1000

class SimpleEmbeddingBagModel(nn.Module):
    def __init__(self, num_embeddings, embedding_dim):
        super(SimpleEmbeddingBagModel, self).__init__()

        weights = torch.from_numpy((np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype(np.float32)).to(torch.float16)

        # Defining the EmbeddingBag layer
        self.embedding_bag = torch.nn.EmbeddingBag(num_embeddings, embedding_dim, _weight=weights,
                                                   mode='sum', include_last_offset=True, dtype=torch.float32)

    def forward(self, input, offsets):
        # Forward pass through the EmbeddingBag layer
        result32 = self.embedding_bag(input, offsets, per_sample_weights=None)
        return result32

# Instantiate the model
model = SimpleEmbeddingBagModel(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
model.eval()

# Example input
input_tensor = torch.randint(0, num_embeddings, (batch_size * multi_hot,), dtype=torch.long)

offsets = torch.tensor(range(0, batch_size * multi_hot + 1, multi_hot))

with torch.no_grad():
    # warm up
    output32 = model(input_tensor, offsets)

    ti = time.time_ns()
    for i in range(nrun):
        _ = model(input_tensor, offsets)
    tf = time.time_ns()
    print("{:3d} {:.3E}".format(embedding_dim, (tf-ti)/nrun/1.e6))
```
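
The script takes the embedding dimension as its only command-line argument (e.g. `python bench.py 128`, filename illustrative) and prints that dimension together with the average time per forward pass in milliseconds over `nrun` iterations.
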
Speedup on Neoverse V1 with 1 thread:
![embedding](https://github.com/user-attachments/assets/16e567ed-b9a5-4db3-90b8-dec66d5414a7)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150176
Approved by: https://github.com/digantdesai, https://github.com/malfet
2025-04-07 18:01:54 +00:00


//// --------------------------
//// ATTENTION:
//// THIS CODE IS AUTOGENERATED
//// BY sve_emblookup_codegen.py
//// DO NOT MODIFY!!!
//// --------------------------
#include <arm_sve.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <cstdint>
#include <cstring>
namespace caffe2 {
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int32_t_float_float__sve(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const float* input,
const int32_t* indices,
const int32_t* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
float* out) {
const svbool_t svAll = svptrue_b32();
const auto vLen = static_cast<int64_t>(svcntw());
int64_t pos = 0;
for (int64_t i = 0; i < output_size; ++i) {
float* const op = &out[i * block_size];
memset(op, 0, sizeof(float) * block_size);
if (pos != offsets[i] - offsets[0]) {
return false;
}
int64_t start_offset = offsets[i];
int64_t end_offset = offsets[i + 1];
int64_t j = start_offset;
// unrolling 16 times
while (j + 15 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
const auto idx2 = indices[pos + 2];
const auto idx3 = indices[pos + 3];
const auto idx4 = indices[pos + 4];
const auto idx5 = indices[pos + 5];
const auto idx6 = indices[pos + 6];
const auto idx7 = indices[pos + 7];
const auto idx8 = indices[pos + 8];
const auto idx9 = indices[pos + 9];
const auto idx10 = indices[pos + 10];
const auto idx11 = indices[pos + 11];
const auto idx12 = indices[pos + 12];
const auto idx13 = indices[pos + 13];
const auto idx14 = indices[pos + 14];
const auto idx15 = indices[pos + 15];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
if (idx2 < 0 || idx2 >= data_size) {
return false;
}
if (idx3 < 0 || idx3 >= data_size) {
return false;
}
if (idx4 < 0 || idx4 >= data_size) {
return false;
}
if (idx5 < 0 || idx5 >= data_size) {
return false;
}
if (idx6 < 0 || idx6 >= data_size) {
return false;
}
if (idx7 < 0 || idx7 >= data_size) {
return false;
}
if (idx8 < 0 || idx8 >= data_size) {
return false;
}
if (idx9 < 0 || idx9 >= data_size) {
return false;
}
if (idx10 < 0 || idx10 >= data_size) {
return false;
}
if (idx11 < 0 || idx11 >= data_size) {
return false;
}
if (idx12 < 0 || idx12 >= data_size) {
return false;
}
if (idx13 < 0 || idx13 >= data_size) {
return false;
}
if (idx14 < 0 || idx14 >= data_size) {
return false;
}
if (idx15 < 0 || idx15 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float wgt2 = 1.f;
float wgt3 = 1.f;
float wgt4 = 1.f;
float wgt5 = 1.f;
float wgt6 = 1.f;
float wgt7 = 1.f;
float wgt8 = 1.f;
float wgt9 = 1.f;
float wgt10 = 1.f;
float wgt11 = 1.f;
float wgt12 = 1.f;
float wgt13 = 1.f;
float wgt14 = 1.f;
float wgt15 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8];
wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9];
wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10];
wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11];
wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12];
wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13];
wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14];
wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15];
}
const float* const ip0 = &input[idx0 * block_size];
const float* const ip1 = &input[idx1 * block_size];
const float* const ip2 = &input[idx2 * block_size];
const float* const ip3 = &input[idx3 * block_size];
const float* const ip4 = &input[idx4 * block_size];
const float* const ip5 = &input[idx5 * block_size];
const float* const ip6 = &input[idx6 * block_size];
const float* const ip7 = &input[idx7 * block_size];
const float* const ip8 = &input[idx8 * block_size];
const float* const ip9 = &input[idx9 * block_size];
const float* const ip10 = &input[idx10 * block_size];
const float* const ip11 = &input[idx11 * block_size];
const float* const ip12 = &input[idx12 * block_size];
const float* const ip13 = &input[idx13 * block_size];
const float* const ip14 = &input[idx14 * block_size];
const float* const ip15 = &input[idx15 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1);
output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2);
output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3);
output = svmla_x(svAll, output, svld1(svAll, &ip4[k]), wgt4);
output = svmla_x(svAll, output, svld1(svAll, &ip5[k]), wgt5);
output = svmla_x(svAll, output, svld1(svAll, &ip6[k]), wgt6);
output = svmla_x(svAll, output, svld1(svAll, &ip7[k]), wgt7);
output = svmla_x(svAll, output, svld1(svAll, &ip8[k]), wgt8);
output = svmla_x(svAll, output, svld1(svAll, &ip9[k]), wgt9);
output = svmla_x(svAll, output, svld1(svAll, &ip10[k]), wgt10);
output = svmla_x(svAll, output, svld1(svAll, &ip11[k]), wgt11);
output = svmla_x(svAll, output, svld1(svAll, &ip12[k]), wgt12);
output = svmla_x(svAll, output, svld1(svAll, &ip13[k]), wgt13);
output = svmla_x(svAll, output, svld1(svAll, &ip14[k]), wgt14);
output = svmla_x(svAll, output, svld1(svAll, &ip15[k]), wgt15);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0);
output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1);
output = svmla_x(pg, output, svld1(svAll, &ip2[k]), wgt2);
output = svmla_x(pg, output, svld1(svAll, &ip3[k]), wgt3);
output = svmla_x(pg, output, svld1(svAll, &ip4[k]), wgt4);
output = svmla_x(pg, output, svld1(svAll, &ip5[k]), wgt5);
output = svmla_x(pg, output, svld1(svAll, &ip6[k]), wgt6);
output = svmla_x(pg, output, svld1(svAll, &ip7[k]), wgt7);
output = svmla_x(pg, output, svld1(svAll, &ip8[k]), wgt8);
output = svmla_x(pg, output, svld1(svAll, &ip9[k]), wgt9);
output = svmla_x(pg, output, svld1(svAll, &ip10[k]), wgt10);
output = svmla_x(pg, output, svld1(svAll, &ip11[k]), wgt11);
output = svmla_x(pg, output, svld1(svAll, &ip12[k]), wgt12);
output = svmla_x(pg, output, svld1(svAll, &ip13[k]), wgt13);
output = svmla_x(pg, output, svld1(svAll, &ip14[k]), wgt14);
output = svmla_x(pg, output, svld1(svAll, &ip15[k]), wgt15);
svst1(pg, &op[k], output);
k += vLen;
}
j += 16;
pos += 16;
}
// unrolling 8 times
while (j + 7 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
const auto idx2 = indices[pos + 2];
const auto idx3 = indices[pos + 3];
const auto idx4 = indices[pos + 4];
const auto idx5 = indices[pos + 5];
const auto idx6 = indices[pos + 6];
const auto idx7 = indices[pos + 7];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
if (idx2 < 0 || idx2 >= data_size) {
return false;
}
if (idx3 < 0 || idx3 >= data_size) {
return false;
}
if (idx4 < 0 || idx4 >= data_size) {
return false;
}
if (idx5 < 0 || idx5 >= data_size) {
return false;
}
if (idx6 < 0 || idx6 >= data_size) {
return false;
}
if (idx7 < 0 || idx7 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float wgt2 = 1.f;
float wgt3 = 1.f;
float wgt4 = 1.f;
float wgt5 = 1.f;
float wgt6 = 1.f;
float wgt7 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
}
const float* const ip0 = &input[idx0 * block_size];
const float* const ip1 = &input[idx1 * block_size];
const float* const ip2 = &input[idx2 * block_size];
const float* const ip3 = &input[idx3 * block_size];
const float* const ip4 = &input[idx4 * block_size];
const float* const ip5 = &input[idx5 * block_size];
const float* const ip6 = &input[idx6 * block_size];
const float* const ip7 = &input[idx7 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1);
output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2);
output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3);
output = svmla_x(svAll, output, svld1(svAll, &ip4[k]), wgt4);
output = svmla_x(svAll, output, svld1(svAll, &ip5[k]), wgt5);
output = svmla_x(svAll, output, svld1(svAll, &ip6[k]), wgt6);
output = svmla_x(svAll, output, svld1(svAll, &ip7[k]), wgt7);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0);
output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1);
output = svmla_x(pg, output, svld1(svAll, &ip2[k]), wgt2);
output = svmla_x(pg, output, svld1(svAll, &ip3[k]), wgt3);
output = svmla_x(pg, output, svld1(svAll, &ip4[k]), wgt4);
output = svmla_x(pg, output, svld1(svAll, &ip5[k]), wgt5);
output = svmla_x(pg, output, svld1(svAll, &ip6[k]), wgt6);
output = svmla_x(pg, output, svld1(svAll, &ip7[k]), wgt7);
svst1(pg, &op[k], output);
k += vLen;
}
j += 8;
pos += 8;
}
// unrolling 4 times
while (j + 3 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
const auto idx2 = indices[pos + 2];
const auto idx3 = indices[pos + 3];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
if (idx2 < 0 || idx2 >= data_size) {
return false;
}
if (idx3 < 0 || idx3 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float wgt2 = 1.f;
float wgt3 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
}
const float* const ip0 = &input[idx0 * block_size];
const float* const ip1 = &input[idx1 * block_size];
const float* const ip2 = &input[idx2 * block_size];
const float* const ip3 = &input[idx3 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1);
output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2);
output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0);
output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1);
output = svmla_x(pg, output, svld1(svAll, &ip2[k]), wgt2);
output = svmla_x(pg, output, svld1(svAll, &ip3[k]), wgt3);
svst1(pg, &op[k], output);
k += vLen;
}
j += 4;
pos += 4;
}
// unrolling 2 times
while (j + 1 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
}
const float* const ip0 = &input[idx0 * block_size];
const float* const ip1 = &input[idx1 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0);
output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1);
svst1(pg, &op[k], output);
k += vLen;
}
j += 2;
pos += 2;
}
// tail loop
if (j < end_offset) {
const auto idx0 = indices[pos + 0];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
float wgt0 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
}
const float* const ip0 = &input[idx0 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0);
svst1(pg, &op[k], output);
k += vLen;
}
pos ++;
}
const int64_t length = end_offset - start_offset;
if (normalize_by_lengths && length != 0) {
const float len_inv = 1.0f / length;
svbool_t pg;
int64_t j = 0;
while (j + vLen - 1 < block_size) {
svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));
j += vLen;
}
if (j < block_size) {
pg = svwhilelt_b32_s64(j, block_size);
svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));
}
}
}
return pos == index_size;
}
bool EmbeddingLookupIdx_int32_t_float_float_false__sve(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const float* input,
const int32_t* indices,
const int32_t* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
float* out) {
return EmbeddingLookupIdx_int32_t_float_float__sve<false>(
block_size,
output_size,
index_size,
data_size,
input,
indices,
offsets,
weights,
scale_bias,
normalize_by_lengths,
out);
}
bool EmbeddingLookupIdx_int32_t_float_float_true__sve(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const float* input,
const int32_t* indices,
const int32_t* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
float* out) {
return EmbeddingLookupIdx_int32_t_float_float__sve<true>(
block_size,
output_size,
index_size,
data_size,
input,
indices,
offsets,
weights,
scale_bias,
normalize_by_lengths,
out);
}
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int64_t_float_float__sve(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const float* input,
const int64_t* indices,
const int64_t* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
float* out) {
const svbool_t svAll = svptrue_b32();
const auto vLen = static_cast<int64_t>(svcntw());
int64_t pos = 0;
for (int64_t i = 0; i < output_size; ++i) {
float* const op = &out[i * block_size];
memset(op, 0, sizeof(float) * block_size);
if (pos != offsets[i] - offsets[0]) {
return false;
}
int64_t start_offset = offsets[i];
int64_t end_offset = offsets[i + 1];
int64_t j = start_offset;
// unrolling 16 times
while (j + 15 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
const auto idx2 = indices[pos + 2];
const auto idx3 = indices[pos + 3];
const auto idx4 = indices[pos + 4];
const auto idx5 = indices[pos + 5];
const auto idx6 = indices[pos + 6];
const auto idx7 = indices[pos + 7];
const auto idx8 = indices[pos + 8];
const auto idx9 = indices[pos + 9];
const auto idx10 = indices[pos + 10];
const auto idx11 = indices[pos + 11];
const auto idx12 = indices[pos + 12];
const auto idx13 = indices[pos + 13];
const auto idx14 = indices[pos + 14];
const auto idx15 = indices[pos + 15];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
if (idx2 < 0 || idx2 >= data_size) {
return false;
}
if (idx3 < 0 || idx3 >= data_size) {
return false;
}
if (idx4 < 0 || idx4 >= data_size) {
return false;
}
if (idx5 < 0 || idx5 >= data_size) {
return false;
}
if (idx6 < 0 || idx6 >= data_size) {
return false;
}
if (idx7 < 0 || idx7 >= data_size) {
return false;
}
if (idx8 < 0 || idx8 >= data_size) {
return false;
}
if (idx9 < 0 || idx9 >= data_size) {
return false;
}
if (idx10 < 0 || idx10 >= data_size) {
return false;
}
if (idx11 < 0 || idx11 >= data_size) {
return false;
}
if (idx12 < 0 || idx12 >= data_size) {
return false;
}
if (idx13 < 0 || idx13 >= data_size) {
return false;
}
if (idx14 < 0 || idx14 >= data_size) {
return false;
}
if (idx15 < 0 || idx15 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float wgt2 = 1.f;
float wgt3 = 1.f;
float wgt4 = 1.f;
float wgt5 = 1.f;
float wgt6 = 1.f;
float wgt7 = 1.f;
float wgt8 = 1.f;
float wgt9 = 1.f;
float wgt10 = 1.f;
float wgt11 = 1.f;
float wgt12 = 1.f;
float wgt13 = 1.f;
float wgt14 = 1.f;
float wgt15 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8];
wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9];
wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10];
wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11];
wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12];
wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13];
wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14];
wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15];
}
const float* const ip0 = &input[idx0 * block_size];
const float* const ip1 = &input[idx1 * block_size];
const float* const ip2 = &input[idx2 * block_size];
const float* const ip3 = &input[idx3 * block_size];
const float* const ip4 = &input[idx4 * block_size];
const float* const ip5 = &input[idx5 * block_size];
const float* const ip6 = &input[idx6 * block_size];
const float* const ip7 = &input[idx7 * block_size];
const float* const ip8 = &input[idx8 * block_size];
const float* const ip9 = &input[idx9 * block_size];
const float* const ip10 = &input[idx10 * block_size];
const float* const ip11 = &input[idx11 * block_size];
const float* const ip12 = &input[idx12 * block_size];
const float* const ip13 = &input[idx13 * block_size];
const float* const ip14 = &input[idx14 * block_size];
const float* const ip15 = &input[idx15 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1);
output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2);
output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3);
output = svmla_x(svAll, output, svld1(svAll, &ip4[k]), wgt4);
output = svmla_x(svAll, output, svld1(svAll, &ip5[k]), wgt5);
output = svmla_x(svAll, output, svld1(svAll, &ip6[k]), wgt6);
output = svmla_x(svAll, output, svld1(svAll, &ip7[k]), wgt7);
output = svmla_x(svAll, output, svld1(svAll, &ip8[k]), wgt8);
output = svmla_x(svAll, output, svld1(svAll, &ip9[k]), wgt9);
output = svmla_x(svAll, output, svld1(svAll, &ip10[k]), wgt10);
output = svmla_x(svAll, output, svld1(svAll, &ip11[k]), wgt11);
output = svmla_x(svAll, output, svld1(svAll, &ip12[k]), wgt12);
output = svmla_x(svAll, output, svld1(svAll, &ip13[k]), wgt13);
output = svmla_x(svAll, output, svld1(svAll, &ip14[k]), wgt14);
output = svmla_x(svAll, output, svld1(svAll, &ip15[k]), wgt15);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0);
output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1);
output = svmla_x(pg, output, svld1(svAll, &ip2[k]), wgt2);
output = svmla_x(pg, output, svld1(svAll, &ip3[k]), wgt3);
output = svmla_x(pg, output, svld1(svAll, &ip4[k]), wgt4);
output = svmla_x(pg, output, svld1(svAll, &ip5[k]), wgt5);
output = svmla_x(pg, output, svld1(svAll, &ip6[k]), wgt6);
output = svmla_x(pg, output, svld1(svAll, &ip7[k]), wgt7);
output = svmla_x(pg, output, svld1(svAll, &ip8[k]), wgt8);
output = svmla_x(pg, output, svld1(svAll, &ip9[k]), wgt9);
output = svmla_x(pg, output, svld1(svAll, &ip10[k]), wgt10);
output = svmla_x(pg, output, svld1(svAll, &ip11[k]), wgt11);
output = svmla_x(pg, output, svld1(svAll, &ip12[k]), wgt12);
output = svmla_x(pg, output, svld1(svAll, &ip13[k]), wgt13);
output = svmla_x(pg, output, svld1(svAll, &ip14[k]), wgt14);
output = svmla_x(pg, output, svld1(svAll, &ip15[k]), wgt15);
svst1(pg, &op[k], output);
k += vLen;
}
j += 16;
pos += 16;
}
// unrolling 8 times
while (j + 7 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
const auto idx2 = indices[pos + 2];
const auto idx3 = indices[pos + 3];
const auto idx4 = indices[pos + 4];
const auto idx5 = indices[pos + 5];
const auto idx6 = indices[pos + 6];
const auto idx7 = indices[pos + 7];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
if (idx2 < 0 || idx2 >= data_size) {
return false;
}
if (idx3 < 0 || idx3 >= data_size) {
return false;
}
if (idx4 < 0 || idx4 >= data_size) {
return false;
}
if (idx5 < 0 || idx5 >= data_size) {
return false;
}
if (idx6 < 0 || idx6 >= data_size) {
return false;
}
if (idx7 < 0 || idx7 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float wgt2 = 1.f;
float wgt3 = 1.f;
float wgt4 = 1.f;
float wgt5 = 1.f;
float wgt6 = 1.f;
float wgt7 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
}
const float* const ip0 = &input[idx0 * block_size];
const float* const ip1 = &input[idx1 * block_size];
const float* const ip2 = &input[idx2 * block_size];
const float* const ip3 = &input[idx3 * block_size];
const float* const ip4 = &input[idx4 * block_size];
const float* const ip5 = &input[idx5 * block_size];
const float* const ip6 = &input[idx6 * block_size];
const float* const ip7 = &input[idx7 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1);
output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2);
output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3);
output = svmla_x(svAll, output, svld1(svAll, &ip4[k]), wgt4);
output = svmla_x(svAll, output, svld1(svAll, &ip5[k]), wgt5);
output = svmla_x(svAll, output, svld1(svAll, &ip6[k]), wgt6);
output = svmla_x(svAll, output, svld1(svAll, &ip7[k]), wgt7);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0);
output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1);
output = svmla_x(pg, output, svld1(svAll, &ip2[k]), wgt2);
output = svmla_x(pg, output, svld1(svAll, &ip3[k]), wgt3);
output = svmla_x(pg, output, svld1(svAll, &ip4[k]), wgt4);
output = svmla_x(pg, output, svld1(svAll, &ip5[k]), wgt5);
output = svmla_x(pg, output, svld1(svAll, &ip6[k]), wgt6);
output = svmla_x(pg, output, svld1(svAll, &ip7[k]), wgt7);
svst1(pg, &op[k], output);
k += vLen;
}
j += 8;
pos += 8;
}
// unrolling 4 times
while (j + 3 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
const auto idx2 = indices[pos + 2];
const auto idx3 = indices[pos + 3];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
if (idx2 < 0 || idx2 >= data_size) {
return false;
}
if (idx3 < 0 || idx3 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float wgt2 = 1.f;
float wgt3 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
}
const float* const ip0 = &input[idx0 * block_size];
const float* const ip1 = &input[idx1 * block_size];
const float* const ip2 = &input[idx2 * block_size];
const float* const ip3 = &input[idx3 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1);
output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2);
output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0);
output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1);
output = svmla_x(pg, output, svld1(svAll, &ip2[k]), wgt2);
output = svmla_x(pg, output, svld1(svAll, &ip3[k]), wgt3);
svst1(pg, &op[k], output);
k += vLen;
}
j += 4;
pos += 4;
}
// unrolling 2 times
while (j + 1 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
}
const float* const ip0 = &input[idx0 * block_size];
const float* const ip1 = &input[idx1 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0);
output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1);
svst1(pg, &op[k], output);
k += vLen;
}
j += 2;
pos += 2;
}
// tail loop
if (j < end_offset) {
const auto idx0 = indices[pos + 0];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
float wgt0 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
}
const float* const ip0 = &input[idx0 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0);
svst1(pg, &op[k], output);
k += vLen;
}
pos ++;
}
const int64_t length = end_offset - start_offset;
if (normalize_by_lengths && length != 0) {
const float len_inv = 1.0f / length;
svbool_t pg;
int64_t j = 0;
while (j + vLen - 1 < block_size) {
svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));
j += vLen;
}
if (j < block_size) {
pg = svwhilelt_b32_s64(j, block_size);
svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));
}
}
}
return pos == index_size;
}
bool EmbeddingLookupIdx_int64_t_float_float_false__sve(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const float* input,
const int64_t* indices,
const int64_t* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
float* out) {
return EmbeddingLookupIdx_int64_t_float_float__sve<false>(
block_size,
output_size,
index_size,
data_size,
input,
indices,
offsets,
weights,
scale_bias,
normalize_by_lengths,
out);
}
bool EmbeddingLookupIdx_int64_t_float_float_true__sve(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const float* input,
const int64_t* indices,
const int64_t* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
float* out) {
return EmbeddingLookupIdx_int64_t_float_float__sve<true>(
block_size,
output_size,
index_size,
data_size,
input,
indices,
offsets,
weights,
scale_bias,
normalize_by_lengths,
out);
}
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int32_t_half_float__sve(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const at::Half* input,
const int32_t* indices,
const int32_t* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
float* out) {
const svbool_t svAll = svptrue_b32();
const auto vLen = static_cast<int64_t>(svcntw());
int64_t pos = 0;
for (int64_t i = 0; i < output_size; ++i) {
float* const op = &out[i * block_size];
memset(op, 0, sizeof(float) * block_size);
if (pos != offsets[i] - offsets[0]) {
return false;
}
int64_t start_offset = offsets[i];
int64_t end_offset = offsets[i + 1];
int64_t j = start_offset;
// unrolling 16 times
while (j + 15 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
const auto idx2 = indices[pos + 2];
const auto idx3 = indices[pos + 3];
const auto idx4 = indices[pos + 4];
const auto idx5 = indices[pos + 5];
const auto idx6 = indices[pos + 6];
const auto idx7 = indices[pos + 7];
const auto idx8 = indices[pos + 8];
const auto idx9 = indices[pos + 9];
const auto idx10 = indices[pos + 10];
const auto idx11 = indices[pos + 11];
const auto idx12 = indices[pos + 12];
const auto idx13 = indices[pos + 13];
const auto idx14 = indices[pos + 14];
const auto idx15 = indices[pos + 15];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
if (idx2 < 0 || idx2 >= data_size) {
return false;
}
if (idx3 < 0 || idx3 >= data_size) {
return false;
}
if (idx4 < 0 || idx4 >= data_size) {
return false;
}
if (idx5 < 0 || idx5 >= data_size) {
return false;
}
if (idx6 < 0 || idx6 >= data_size) {
return false;
}
if (idx7 < 0 || idx7 >= data_size) {
return false;
}
if (idx8 < 0 || idx8 >= data_size) {
return false;
}
if (idx9 < 0 || idx9 >= data_size) {
return false;
}
if (idx10 < 0 || idx10 >= data_size) {
return false;
}
if (idx11 < 0 || idx11 >= data_size) {
return false;
}
if (idx12 < 0 || idx12 >= data_size) {
return false;
}
if (idx13 < 0 || idx13 >= data_size) {
return false;
}
if (idx14 < 0 || idx14 >= data_size) {
return false;
}
if (idx15 < 0 || idx15 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float wgt2 = 1.f;
float wgt3 = 1.f;
float wgt4 = 1.f;
float wgt5 = 1.f;
float wgt6 = 1.f;
float wgt7 = 1.f;
float wgt8 = 1.f;
float wgt9 = 1.f;
float wgt10 = 1.f;
float wgt11 = 1.f;
float wgt12 = 1.f;
float wgt13 = 1.f;
float wgt14 = 1.f;
float wgt15 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8];
wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9];
wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10];
wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11];
wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12];
wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13];
wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14];
wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15];
}
const at::Half* const ip0 = &input[idx0 * block_size];
const at::Half* const ip1 = &input[idx1 * block_size];
const at::Half* const ip2 = &input[idx2 * block_size];
const at::Half* const ip3 = &input[idx3 * block_size];
const at::Half* const ip4 = &input[idx4 * block_size];
const at::Half* const ip5 = &input[idx5 * block_size];
const at::Half* const ip6 = &input[idx6 * block_size];
const at::Half* const ip7 = &input[idx7 * block_size];
const at::Half* const ip8 = &input[idx8 * block_size];
const at::Half* const ip9 = &input[idx9 * block_size];
const at::Half* const ip10 = &input[idx10 * block_size];
const at::Half* const ip11 = &input[idx11 * block_size];
const at::Half* const ip12 = &input[idx12 * block_size];
const at::Half* const ip13 = &input[idx13 * block_size];
const at::Half* const ip14 = &input[idx14 * block_size];
const at::Half* const ip15 = &input[idx15 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
auto input1 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k]))));
auto input2 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k]))));
auto input3 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k]))));
auto input4 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip4[k]))));
auto input5 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip5[k]))));
auto input6 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip6[k]))));
auto input7 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip7[k]))));
auto input8 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip8[k]))));
auto input9 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip9[k]))));
auto input10 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip10[k]))));
auto input11 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip11[k]))));
auto input12 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip12[k]))));
auto input13 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip13[k]))));
auto input14 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip14[k]))));
auto input15 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip15[k]))));
output = svmla_x(svAll, output, input0, wgt0);
output = svmla_x(svAll, output, input1, wgt1);
output = svmla_x(svAll, output, input2, wgt2);
output = svmla_x(svAll, output, input3, wgt3);
output = svmla_x(svAll, output, input4, wgt4);
output = svmla_x(svAll, output, input5, wgt5);
output = svmla_x(svAll, output, input6, wgt6);
output = svmla_x(svAll, output, input7, wgt7);
output = svmla_x(svAll, output, input8, wgt8);
output = svmla_x(svAll, output, input9, wgt9);
output = svmla_x(svAll, output, input10, wgt10);
output = svmla_x(svAll, output, input11, wgt11);
output = svmla_x(svAll, output, input12, wgt12);
output = svmla_x(svAll, output, input13, wgt13);
output = svmla_x(svAll, output, input14, wgt14);
output = svmla_x(svAll, output, input15, wgt15);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
auto input1 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k]))));
auto input2 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k]))));
auto input3 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k]))));
auto input4 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip4[k]))));
auto input5 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip5[k]))));
auto input6 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip6[k]))));
auto input7 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip7[k]))));
auto input8 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip8[k]))));
auto input9 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip9[k]))));
auto input10 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip10[k]))));
auto input11 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip11[k]))));
auto input12 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip12[k]))));
auto input13 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip13[k]))));
auto input14 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip14[k]))));
auto input15 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip15[k]))));
output = svmla_x(pg, output, input0, wgt0);
output = svmla_x(pg, output, input1, wgt1);
output = svmla_x(pg, output, input2, wgt2);
output = svmla_x(pg, output, input3, wgt3);
output = svmla_x(pg, output, input4, wgt4);
output = svmla_x(pg, output, input5, wgt5);
output = svmla_x(pg, output, input6, wgt6);
output = svmla_x(pg, output, input7, wgt7);
output = svmla_x(pg, output, input8, wgt8);
output = svmla_x(pg, output, input9, wgt9);
output = svmla_x(pg, output, input10, wgt10);
output = svmla_x(pg, output, input11, wgt11);
output = svmla_x(pg, output, input12, wgt12);
output = svmla_x(pg, output, input13, wgt13);
output = svmla_x(pg, output, input14, wgt14);
output = svmla_x(pg, output, input15, wgt15);
svst1(pg, &op[k], output);
k += vLen;
}
j += 16;
pos += 16;
}
// unrolling 8 times
while (j + 7 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
const auto idx2 = indices[pos + 2];
const auto idx3 = indices[pos + 3];
const auto idx4 = indices[pos + 4];
const auto idx5 = indices[pos + 5];
const auto idx6 = indices[pos + 6];
const auto idx7 = indices[pos + 7];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
if (idx2 < 0 || idx2 >= data_size) {
return false;
}
if (idx3 < 0 || idx3 >= data_size) {
return false;
}
if (idx4 < 0 || idx4 >= data_size) {
return false;
}
if (idx5 < 0 || idx5 >= data_size) {
return false;
}
if (idx6 < 0 || idx6 >= data_size) {
return false;
}
if (idx7 < 0 || idx7 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float wgt2 = 1.f;
float wgt3 = 1.f;
float wgt4 = 1.f;
float wgt5 = 1.f;
float wgt6 = 1.f;
float wgt7 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
}
const at::Half* const ip0 = &input[idx0 * block_size];
const at::Half* const ip1 = &input[idx1 * block_size];
const at::Half* const ip2 = &input[idx2 * block_size];
const at::Half* const ip3 = &input[idx3 * block_size];
const at::Half* const ip4 = &input[idx4 * block_size];
const at::Half* const ip5 = &input[idx5 * block_size];
const at::Half* const ip6 = &input[idx6 * block_size];
const at::Half* const ip7 = &input[idx7 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
auto input1 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k]))));
auto input2 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k]))));
auto input3 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k]))));
auto input4 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip4[k]))));
auto input5 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip5[k]))));
auto input6 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip6[k]))));
auto input7 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip7[k]))));
output = svmla_x(svAll, output, input0, wgt0);
output = svmla_x(svAll, output, input1, wgt1);
output = svmla_x(svAll, output, input2, wgt2);
output = svmla_x(svAll, output, input3, wgt3);
output = svmla_x(svAll, output, input4, wgt4);
output = svmla_x(svAll, output, input5, wgt5);
output = svmla_x(svAll, output, input6, wgt6);
output = svmla_x(svAll, output, input7, wgt7);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
auto input1 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k]))));
auto input2 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k]))));
auto input3 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k]))));
auto input4 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip4[k]))));
auto input5 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip5[k]))));
auto input6 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip6[k]))));
auto input7 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip7[k]))));
output = svmla_x(pg, output, input0, wgt0);
output = svmla_x(pg, output, input1, wgt1);
output = svmla_x(pg, output, input2, wgt2);
output = svmla_x(pg, output, input3, wgt3);
output = svmla_x(pg, output, input4, wgt4);
output = svmla_x(pg, output, input5, wgt5);
output = svmla_x(pg, output, input6, wgt6);
output = svmla_x(pg, output, input7, wgt7);
svst1(pg, &op[k], output);
k += vLen;
}
j += 8;
pos += 8;
}
// unrolling 4 times
while (j + 3 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
const auto idx2 = indices[pos + 2];
const auto idx3 = indices[pos + 3];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
if (idx2 < 0 || idx2 >= data_size) {
return false;
}
if (idx3 < 0 || idx3 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float wgt2 = 1.f;
float wgt3 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
}
const at::Half* const ip0 = &input[idx0 * block_size];
const at::Half* const ip1 = &input[idx1 * block_size];
const at::Half* const ip2 = &input[idx2 * block_size];
const at::Half* const ip3 = &input[idx3 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
auto input1 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k]))));
auto input2 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k]))));
auto input3 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k]))));
output = svmla_x(svAll, output, input0, wgt0);
output = svmla_x(svAll, output, input1, wgt1);
output = svmla_x(svAll, output, input2, wgt2);
output = svmla_x(svAll, output, input3, wgt3);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
auto input1 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k]))));
auto input2 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k]))));
auto input3 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k]))));
output = svmla_x(pg, output, input0, wgt0);
output = svmla_x(pg, output, input1, wgt1);
output = svmla_x(pg, output, input2, wgt2);
output = svmla_x(pg, output, input3, wgt3);
svst1(pg, &op[k], output);
k += vLen;
}
j += 4;
pos += 4;
}
// unrolling 2 times
while (j + 1 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
}
const at::Half* const ip0 = &input[idx0 * block_size];
const at::Half* const ip1 = &input[idx1 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
auto input1 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k]))));
output = svmla_x(svAll, output, input0, wgt0);
output = svmla_x(svAll, output, input1, wgt1);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
auto input1 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k]))));
output = svmla_x(pg, output, input0, wgt0);
output = svmla_x(pg, output, input1, wgt1);
svst1(pg, &op[k], output);
k += vLen;
}
j += 2;
pos += 2;
}
// tail loop
if (j < end_offset) {
const auto idx0 = indices[pos + 0];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
float wgt0 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
}
const at::Half* const ip0 = &input[idx0 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
output = svmla_x(svAll, output, input0, wgt0);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
output = svmla_x(pg, output, input0, wgt0);
svst1(pg, &op[k], output);
k += vLen;
}
pos ++;
}
const int64_t length = end_offset - start_offset;
if (normalize_by_lengths && length != 0) {
const float len_inv = 1.0f / length;
svbool_t pg;
int64_t j = 0;
while (j + vLen - 1 < block_size) {
svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));
j += vLen;
}
if (j < block_size) {
pg = svwhilelt_b32_s64(j, block_size);
svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));
}
}
}
return pos == index_size;
}
bool EmbeddingLookupIdx_int32_t_half_float_false__sve(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const at::Half* input,
const int32_t* indices,
const int32_t* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
float* out) {
return EmbeddingLookupIdx_int32_t_half_float__sve<false>(
block_size,
output_size,
index_size,
data_size,
input,
indices,
offsets,
weights,
scale_bias,
normalize_by_lengths,
out);
}
bool EmbeddingLookupIdx_int32_t_half_float_true__sve(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const at::Half* input,
const int32_t* indices,
const int32_t* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
float* out) {
return EmbeddingLookupIdx_int32_t_half_float__sve<true>(
block_size,
output_size,
index_size,
data_size,
input,
indices,
offsets,
weights,
scale_bias,
normalize_by_lengths,
out);
}
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int64_t_half_float__sve(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const at::Half* input,
const int64_t* indices,
const int64_t* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
float* out) {
const svbool_t svAll = svptrue_b32();
const auto vLen = static_cast<int64_t>(svcntw());
int64_t pos = 0;
for (int64_t i = 0; i < output_size; ++i) {
float* const op = &out[i * block_size];
memset(op, 0, sizeof(float) * block_size);
if (pos != offsets[i] - offsets[0]) {
return false;
}
int64_t start_offset = offsets[i];
int64_t end_offset = offsets[i + 1];
int64_t j = start_offset;
// unrolling 16 times
while (j + 15 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
const auto idx2 = indices[pos + 2];
const auto idx3 = indices[pos + 3];
const auto idx4 = indices[pos + 4];
const auto idx5 = indices[pos + 5];
const auto idx6 = indices[pos + 6];
const auto idx7 = indices[pos + 7];
const auto idx8 = indices[pos + 8];
const auto idx9 = indices[pos + 9];
const auto idx10 = indices[pos + 10];
const auto idx11 = indices[pos + 11];
const auto idx12 = indices[pos + 12];
const auto idx13 = indices[pos + 13];
const auto idx14 = indices[pos + 14];
const auto idx15 = indices[pos + 15];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
if (idx2 < 0 || idx2 >= data_size) {
return false;
}
if (idx3 < 0 || idx3 >= data_size) {
return false;
}
if (idx4 < 0 || idx4 >= data_size) {
return false;
}
if (idx5 < 0 || idx5 >= data_size) {
return false;
}
if (idx6 < 0 || idx6 >= data_size) {
return false;
}
if (idx7 < 0 || idx7 >= data_size) {
return false;
}
if (idx8 < 0 || idx8 >= data_size) {
return false;
}
if (idx9 < 0 || idx9 >= data_size) {
return false;
}
if (idx10 < 0 || idx10 >= data_size) {
return false;
}
if (idx11 < 0 || idx11 >= data_size) {
return false;
}
if (idx12 < 0 || idx12 >= data_size) {
return false;
}
if (idx13 < 0 || idx13 >= data_size) {
return false;
}
if (idx14 < 0 || idx14 >= data_size) {
return false;
}
if (idx15 < 0 || idx15 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float wgt2 = 1.f;
float wgt3 = 1.f;
float wgt4 = 1.f;
float wgt5 = 1.f;
float wgt6 = 1.f;
float wgt7 = 1.f;
float wgt8 = 1.f;
float wgt9 = 1.f;
float wgt10 = 1.f;
float wgt11 = 1.f;
float wgt12 = 1.f;
float wgt13 = 1.f;
float wgt14 = 1.f;
float wgt15 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8];
wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9];
wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10];
wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11];
wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12];
wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13];
wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14];
wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15];
}
const at::Half* const ip0 = &input[idx0 * block_size];
const at::Half* const ip1 = &input[idx1 * block_size];
const at::Half* const ip2 = &input[idx2 * block_size];
const at::Half* const ip3 = &input[idx3 * block_size];
const at::Half* const ip4 = &input[idx4 * block_size];
const at::Half* const ip5 = &input[idx5 * block_size];
const at::Half* const ip6 = &input[idx6 * block_size];
const at::Half* const ip7 = &input[idx7 * block_size];
const at::Half* const ip8 = &input[idx8 * block_size];
const at::Half* const ip9 = &input[idx9 * block_size];
const at::Half* const ip10 = &input[idx10 * block_size];
const at::Half* const ip11 = &input[idx11 * block_size];
const at::Half* const ip12 = &input[idx12 * block_size];
const at::Half* const ip13 = &input[idx13 * block_size];
const at::Half* const ip14 = &input[idx14 * block_size];
const at::Half* const ip15 = &input[idx15 * block_size];
svbool_t pg;
int64_t k = 0;
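      // Full-vector lanes: svld1uh_u32 zero-extends each 16-bit half into a
      // 32-bit lane, svreinterpret_f16 leaves the loaded half in the low part
      // of that lane, and svcvt_f32_x widens it to fp32 in place, so svmla_x
      // can fuse the multiply by the scalar weight with the accumulation.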
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
auto input1 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k]))));
auto input2 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k]))));
auto input3 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k]))));
auto input4 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip4[k]))));
auto input5 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip5[k]))));
auto input6 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip6[k]))));
auto input7 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip7[k]))));
auto input8 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip8[k]))));
auto input9 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip9[k]))));
auto input10 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip10[k]))));
auto input11 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip11[k]))));
auto input12 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip12[k]))));
auto input13 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip13[k]))));
auto input14 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip14[k]))));
auto input15 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip15[k]))));
output = svmla_x(svAll, output, input0, wgt0);
output = svmla_x(svAll, output, input1, wgt1);
output = svmla_x(svAll, output, input2, wgt2);
output = svmla_x(svAll, output, input3, wgt3);
output = svmla_x(svAll, output, input4, wgt4);
output = svmla_x(svAll, output, input5, wgt5);
output = svmla_x(svAll, output, input6, wgt6);
output = svmla_x(svAll, output, input7, wgt7);
output = svmla_x(svAll, output, input8, wgt8);
output = svmla_x(svAll, output, input9, wgt9);
output = svmla_x(svAll, output, input10, wgt10);
output = svmla_x(svAll, output, input11, wgt11);
output = svmla_x(svAll, output, input12, wgt12);
output = svmla_x(svAll, output, input13, wgt13);
output = svmla_x(svAll, output, input14, wgt14);
output = svmla_x(svAll, output, input15, wgt15);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
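        // Ragged tail: svwhilelt_b32_s64 builds a predicate covering the
        // remaining block_size - k lanes, so the final partial vector reuses
        // the same masked loads/stores instead of falling back to scalar code.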
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
auto input1 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k]))));
auto input2 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k]))));
auto input3 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k]))));
auto input4 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip4[k]))));
auto input5 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip5[k]))));
auto input6 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip6[k]))));
auto input7 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip7[k]))));
auto input8 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip8[k]))));
auto input9 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip9[k]))));
auto input10 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip10[k]))));
auto input11 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip11[k]))));
auto input12 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip12[k]))));
auto input13 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip13[k]))));
auto input14 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip14[k]))));
auto input15 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip15[k]))));
output = svmla_x(pg, output, input0, wgt0);
output = svmla_x(pg, output, input1, wgt1);
output = svmla_x(pg, output, input2, wgt2);
output = svmla_x(pg, output, input3, wgt3);
output = svmla_x(pg, output, input4, wgt4);
output = svmla_x(pg, output, input5, wgt5);
output = svmla_x(pg, output, input6, wgt6);
output = svmla_x(pg, output, input7, wgt7);
output = svmla_x(pg, output, input8, wgt8);
output = svmla_x(pg, output, input9, wgt9);
output = svmla_x(pg, output, input10, wgt10);
output = svmla_x(pg, output, input11, wgt11);
output = svmla_x(pg, output, input12, wgt12);
output = svmla_x(pg, output, input13, wgt13);
output = svmla_x(pg, output, input14, wgt14);
output = svmla_x(pg, output, input15, wgt15);
svst1(pg, &op[k], output);
k += vLen;
}
j += 16;
pos += 16;
}
// unrolling 8 times
while (j + 7 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
const auto idx2 = indices[pos + 2];
const auto idx3 = indices[pos + 3];
const auto idx4 = indices[pos + 4];
const auto idx5 = indices[pos + 5];
const auto idx6 = indices[pos + 6];
const auto idx7 = indices[pos + 7];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
if (idx2 < 0 || idx2 >= data_size) {
return false;
}
if (idx3 < 0 || idx3 >= data_size) {
return false;
}
if (idx4 < 0 || idx4 >= data_size) {
return false;
}
if (idx5 < 0 || idx5 >= data_size) {
return false;
}
if (idx6 < 0 || idx6 >= data_size) {
return false;
}
if (idx7 < 0 || idx7 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float wgt2 = 1.f;
float wgt3 = 1.f;
float wgt4 = 1.f;
float wgt5 = 1.f;
float wgt6 = 1.f;
float wgt7 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
}
const at::Half* const ip0 = &input[idx0 * block_size];
const at::Half* const ip1 = &input[idx1 * block_size];
const at::Half* const ip2 = &input[idx2 * block_size];
const at::Half* const ip3 = &input[idx3 * block_size];
const at::Half* const ip4 = &input[idx4 * block_size];
const at::Half* const ip5 = &input[idx5 * block_size];
const at::Half* const ip6 = &input[idx6 * block_size];
const at::Half* const ip7 = &input[idx7 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
auto input1 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k]))));
auto input2 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k]))));
auto input3 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k]))));
auto input4 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip4[k]))));
auto input5 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip5[k]))));
auto input6 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip6[k]))));
auto input7 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip7[k]))));
output = svmla_x(svAll, output, input0, wgt0);
output = svmla_x(svAll, output, input1, wgt1);
output = svmla_x(svAll, output, input2, wgt2);
output = svmla_x(svAll, output, input3, wgt3);
output = svmla_x(svAll, output, input4, wgt4);
output = svmla_x(svAll, output, input5, wgt5);
output = svmla_x(svAll, output, input6, wgt6);
output = svmla_x(svAll, output, input7, wgt7);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
auto input1 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k]))));
auto input2 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k]))));
auto input3 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k]))));
auto input4 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip4[k]))));
auto input5 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip5[k]))));
auto input6 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip6[k]))));
auto input7 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip7[k]))));
output = svmla_x(pg, output, input0, wgt0);
output = svmla_x(pg, output, input1, wgt1);
output = svmla_x(pg, output, input2, wgt2);
output = svmla_x(pg, output, input3, wgt3);
output = svmla_x(pg, output, input4, wgt4);
output = svmla_x(pg, output, input5, wgt5);
output = svmla_x(pg, output, input6, wgt6);
output = svmla_x(pg, output, input7, wgt7);
svst1(pg, &op[k], output);
k += vLen;
}
j += 8;
pos += 8;
}
// unrolling 4 times
while (j + 3 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
const auto idx2 = indices[pos + 2];
const auto idx3 = indices[pos + 3];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
if (idx2 < 0 || idx2 >= data_size) {
return false;
}
if (idx3 < 0 || idx3 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float wgt2 = 1.f;
float wgt3 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
}
const at::Half* const ip0 = &input[idx0 * block_size];
const at::Half* const ip1 = &input[idx1 * block_size];
const at::Half* const ip2 = &input[idx2 * block_size];
const at::Half* const ip3 = &input[idx3 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
auto input1 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k]))));
auto input2 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k]))));
auto input3 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k]))));
output = svmla_x(svAll, output, input0, wgt0);
output = svmla_x(svAll, output, input1, wgt1);
output = svmla_x(svAll, output, input2, wgt2);
output = svmla_x(svAll, output, input3, wgt3);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
auto input1 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k]))));
auto input2 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k]))));
auto input3 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k]))));
output = svmla_x(pg, output, input0, wgt0);
output = svmla_x(pg, output, input1, wgt1);
output = svmla_x(pg, output, input2, wgt2);
output = svmla_x(pg, output, input3, wgt3);
svst1(pg, &op[k], output);
k += vLen;
}
j += 4;
pos += 4;
}
// unrolling 2 times
while (j + 1 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
}
const at::Half* const ip0 = &input[idx0 * block_size];
const at::Half* const ip1 = &input[idx1 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
auto input1 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k]))));
output = svmla_x(svAll, output, input0, wgt0);
output = svmla_x(svAll, output, input1, wgt1);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
auto input1 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k]))));
output = svmla_x(pg, output, input0, wgt0);
output = svmla_x(pg, output, input1, wgt1);
svst1(pg, &op[k], output);
k += vLen;
}
j += 2;
pos += 2;
}
// tail loop
if (j < end_offset) {
const auto idx0 = indices[pos + 0];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
float wgt0 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
}
const at::Half* const ip0 = &input[idx0 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
auto input0 = svcvt_f32_x(svAll, svreinterpret_f16(
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k]))));
output = svmla_x(svAll, output, input0, wgt0);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
auto input0 = svcvt_f32_x(pg, svreinterpret_f16(
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k]))));
output = svmla_x(pg, output, input0, wgt0);
svst1(pg, &op[k], output);
k += vLen;
}
      pos++;
}
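    // When requested, turn the accumulated sum into a mean by scaling the
    // whole output row by 1/length (skipped for empty bags to avoid a
    // division by zero).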
const int64_t length = end_offset - start_offset;
if (normalize_by_lengths && length != 0) {
const float len_inv = 1.0f / length;
svbool_t pg;
int64_t j = 0;
while (j + vLen - 1 < block_size) {
svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));
j += vLen;
}
if (j < block_size) {
pg = svwhilelt_b32_s64(j, block_size);
svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));
}
}
}
return pos == index_size;
}
bool EmbeddingLookupIdx_int64_t_half_float_false__sve(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const at::Half* input,
const int64_t* indices,
const int64_t* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
float* out) {
return EmbeddingLookupIdx_int64_t_half_float__sve<false>(
block_size,
output_size,
index_size,
data_size,
input,
indices,
offsets,
weights,
scale_bias,
normalize_by_lengths,
out);
}
bool EmbeddingLookupIdx_int64_t_half_float_true__sve(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const at::Half* input,
const int64_t* indices,
const int64_t* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
float* out) {
return EmbeddingLookupIdx_int64_t_half_float__sve<true>(
block_size,
output_size,
index_size,
data_size,
input,
indices,
offsets,
weights,
scale_bias,
normalize_by_lengths,
out);
}
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int32_t_bfloat16_float__sve(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const at::BFloat16* input,
const int32_t* indices,
const int32_t* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
float* out) {
const svbool_t svAll = svptrue_b32();
const auto vLen = static_cast<int64_t>(svcntw());
int64_t pos = 0;
for (int64_t i = 0; i < output_size; ++i) {
float* const op = &out[i * block_size];
memset(op, 0, sizeof(float) * block_size);
if (pos != offsets[i] - offsets[0]) {
return false;
}
int64_t start_offset = offsets[i];
int64_t end_offset = offsets[i + 1];
int64_t j = start_offset;
// unrolling 16 times
while (j + 15 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
const auto idx2 = indices[pos + 2];
const auto idx3 = indices[pos + 3];
const auto idx4 = indices[pos + 4];
const auto idx5 = indices[pos + 5];
const auto idx6 = indices[pos + 6];
const auto idx7 = indices[pos + 7];
const auto idx8 = indices[pos + 8];
const auto idx9 = indices[pos + 9];
const auto idx10 = indices[pos + 10];
const auto idx11 = indices[pos + 11];
const auto idx12 = indices[pos + 12];
const auto idx13 = indices[pos + 13];
const auto idx14 = indices[pos + 14];
const auto idx15 = indices[pos + 15];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
if (idx2 < 0 || idx2 >= data_size) {
return false;
}
if (idx3 < 0 || idx3 >= data_size) {
return false;
}
if (idx4 < 0 || idx4 >= data_size) {
return false;
}
if (idx5 < 0 || idx5 >= data_size) {
return false;
}
if (idx6 < 0 || idx6 >= data_size) {
return false;
}
if (idx7 < 0 || idx7 >= data_size) {
return false;
}
if (idx8 < 0 || idx8 >= data_size) {
return false;
}
if (idx9 < 0 || idx9 >= data_size) {
return false;
}
if (idx10 < 0 || idx10 >= data_size) {
return false;
}
if (idx11 < 0 || idx11 >= data_size) {
return false;
}
if (idx12 < 0 || idx12 >= data_size) {
return false;
}
if (idx13 < 0 || idx13 >= data_size) {
return false;
}
if (idx14 < 0 || idx14 >= data_size) {
return false;
}
if (idx15 < 0 || idx15 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float wgt2 = 1.f;
float wgt3 = 1.f;
float wgt4 = 1.f;
float wgt5 = 1.f;
float wgt6 = 1.f;
float wgt7 = 1.f;
float wgt8 = 1.f;
float wgt9 = 1.f;
float wgt10 = 1.f;
float wgt11 = 1.f;
float wgt12 = 1.f;
float wgt13 = 1.f;
float wgt14 = 1.f;
float wgt15 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8];
wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9];
wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10];
wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11];
wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12];
wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13];
wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14];
wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15];
}
const at::BFloat16* const ip0 = &input[idx0 * block_size];
const at::BFloat16* const ip1 = &input[idx1 * block_size];
const at::BFloat16* const ip2 = &input[idx2 * block_size];
const at::BFloat16* const ip3 = &input[idx3 * block_size];
const at::BFloat16* const ip4 = &input[idx4 * block_size];
const at::BFloat16* const ip5 = &input[idx5 * block_size];
const at::BFloat16* const ip6 = &input[idx6 * block_size];
const at::BFloat16* const ip7 = &input[idx7 * block_size];
const at::BFloat16* const ip8 = &input[idx8 * block_size];
const at::BFloat16* const ip9 = &input[idx9 * block_size];
const at::BFloat16* const ip10 = &input[idx10 * block_size];
const at::BFloat16* const ip11 = &input[idx11 * block_size];
const at::BFloat16* const ip12 = &input[idx12 * block_size];
const at::BFloat16* const ip13 = &input[idx13 * block_size];
const at::BFloat16* const ip14 = &input[idx14 * block_size];
const at::BFloat16* const ip15 = &input[idx15 * block_size];
svbool_t pg;
int64_t k = 0;
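      // bfloat16 is the upper 16 bits of an IEEE fp32 value, so rows are
      // widened without a convert instruction: svld1uh_u32 zero-extends each
      // 16-bit element into a 32-bit lane, svlsl_x shifts it left by 16 to
      // rebuild the fp32 bit pattern, and svreinterpret_f32 views the result
      // as float for the svmla_x accumulation.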
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
auto input0 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
auto input1 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
auto input2 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
auto input3 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
auto input4 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip4[k])), 16));
auto input5 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip5[k])), 16));
auto input6 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip6[k])), 16));
auto input7 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip7[k])), 16));
auto input8 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip8[k])), 16));
auto input9 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip9[k])), 16));
auto input10 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip10[k])), 16));
auto input11 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip11[k])), 16));
auto input12 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip12[k])), 16));
auto input13 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip13[k])), 16));
auto input14 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip14[k])), 16));
auto input15 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip15[k])), 16));
output = svmla_x(svAll, output, input0, wgt0);
output = svmla_x(svAll, output, input1, wgt1);
output = svmla_x(svAll, output, input2, wgt2);
output = svmla_x(svAll, output, input3, wgt3);
output = svmla_x(svAll, output, input4, wgt4);
output = svmla_x(svAll, output, input5, wgt5);
output = svmla_x(svAll, output, input6, wgt6);
output = svmla_x(svAll, output, input7, wgt7);
output = svmla_x(svAll, output, input8, wgt8);
output = svmla_x(svAll, output, input9, wgt9);
output = svmla_x(svAll, output, input10, wgt10);
output = svmla_x(svAll, output, input11, wgt11);
output = svmla_x(svAll, output, input12, wgt12);
output = svmla_x(svAll, output, input13, wgt13);
output = svmla_x(svAll, output, input14, wgt14);
output = svmla_x(svAll, output, input15, wgt15);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
auto input0 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
auto input1 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
auto input2 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
auto input3 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
auto input4 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip4[k])), 16));
auto input5 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip5[k])), 16));
auto input6 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip6[k])), 16));
auto input7 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip7[k])), 16));
auto input8 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip8[k])), 16));
auto input9 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip9[k])), 16));
auto input10 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip10[k])), 16));
auto input11 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip11[k])), 16));
auto input12 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip12[k])), 16));
auto input13 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip13[k])), 16));
auto input14 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip14[k])), 16));
auto input15 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip15[k])), 16));
output = svmla_x(pg, output, input0, wgt0);
output = svmla_x(pg, output, input1, wgt1);
output = svmla_x(pg, output, input2, wgt2);
output = svmla_x(pg, output, input3, wgt3);
output = svmla_x(pg, output, input4, wgt4);
output = svmla_x(pg, output, input5, wgt5);
output = svmla_x(pg, output, input6, wgt6);
output = svmla_x(pg, output, input7, wgt7);
output = svmla_x(pg, output, input8, wgt8);
output = svmla_x(pg, output, input9, wgt9);
output = svmla_x(pg, output, input10, wgt10);
output = svmla_x(pg, output, input11, wgt11);
output = svmla_x(pg, output, input12, wgt12);
output = svmla_x(pg, output, input13, wgt13);
output = svmla_x(pg, output, input14, wgt14);
output = svmla_x(pg, output, input15, wgt15);
svst1(pg, &op[k], output);
k += vLen;
}
j += 16;
pos += 16;
}
// unrolling 8 times
while (j + 7 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
const auto idx2 = indices[pos + 2];
const auto idx3 = indices[pos + 3];
const auto idx4 = indices[pos + 4];
const auto idx5 = indices[pos + 5];
const auto idx6 = indices[pos + 6];
const auto idx7 = indices[pos + 7];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
if (idx2 < 0 || idx2 >= data_size) {
return false;
}
if (idx3 < 0 || idx3 >= data_size) {
return false;
}
if (idx4 < 0 || idx4 >= data_size) {
return false;
}
if (idx5 < 0 || idx5 >= data_size) {
return false;
}
if (idx6 < 0 || idx6 >= data_size) {
return false;
}
if (idx7 < 0 || idx7 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float wgt2 = 1.f;
float wgt3 = 1.f;
float wgt4 = 1.f;
float wgt5 = 1.f;
float wgt6 = 1.f;
float wgt7 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
}
const at::BFloat16* const ip0 = &input[idx0 * block_size];
const at::BFloat16* const ip1 = &input[idx1 * block_size];
const at::BFloat16* const ip2 = &input[idx2 * block_size];
const at::BFloat16* const ip3 = &input[idx3 * block_size];
const at::BFloat16* const ip4 = &input[idx4 * block_size];
const at::BFloat16* const ip5 = &input[idx5 * block_size];
const at::BFloat16* const ip6 = &input[idx6 * block_size];
const at::BFloat16* const ip7 = &input[idx7 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
auto input0 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
auto input1 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
auto input2 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
auto input3 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
auto input4 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip4[k])), 16));
auto input5 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip5[k])), 16));
auto input6 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip6[k])), 16));
auto input7 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip7[k])), 16));
output = svmla_x(svAll, output, input0, wgt0);
output = svmla_x(svAll, output, input1, wgt1);
output = svmla_x(svAll, output, input2, wgt2);
output = svmla_x(svAll, output, input3, wgt3);
output = svmla_x(svAll, output, input4, wgt4);
output = svmla_x(svAll, output, input5, wgt5);
output = svmla_x(svAll, output, input6, wgt6);
output = svmla_x(svAll, output, input7, wgt7);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
auto input0 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
auto input1 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
auto input2 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
auto input3 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
auto input4 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip4[k])), 16));
auto input5 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip5[k])), 16));
auto input6 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip6[k])), 16));
auto input7 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip7[k])), 16));
output = svmla_x(pg, output, input0, wgt0);
output = svmla_x(pg, output, input1, wgt1);
output = svmla_x(pg, output, input2, wgt2);
output = svmla_x(pg, output, input3, wgt3);
output = svmla_x(pg, output, input4, wgt4);
output = svmla_x(pg, output, input5, wgt5);
output = svmla_x(pg, output, input6, wgt6);
output = svmla_x(pg, output, input7, wgt7);
svst1(pg, &op[k], output);
k += vLen;
}
j += 8;
pos += 8;
}
// unrolling 4 times
while (j + 3 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
const auto idx2 = indices[pos + 2];
const auto idx3 = indices[pos + 3];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
if (idx2 < 0 || idx2 >= data_size) {
return false;
}
if (idx3 < 0 || idx3 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float wgt2 = 1.f;
float wgt3 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
}
const at::BFloat16* const ip0 = &input[idx0 * block_size];
const at::BFloat16* const ip1 = &input[idx1 * block_size];
const at::BFloat16* const ip2 = &input[idx2 * block_size];
const at::BFloat16* const ip3 = &input[idx3 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
auto input0 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
auto input1 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
auto input2 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
auto input3 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
output = svmla_x(svAll, output, input0, wgt0);
output = svmla_x(svAll, output, input1, wgt1);
output = svmla_x(svAll, output, input2, wgt2);
output = svmla_x(svAll, output, input3, wgt3);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
auto input0 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
auto input1 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
auto input2 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
auto input3 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
output = svmla_x(pg, output, input0, wgt0);
output = svmla_x(pg, output, input1, wgt1);
output = svmla_x(pg, output, input2, wgt2);
output = svmla_x(pg, output, input3, wgt3);
svst1(pg, &op[k], output);
k += vLen;
}
j += 4;
pos += 4;
}
// unrolling 2 times
while (j + 1 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
}
const at::BFloat16* const ip0 = &input[idx0 * block_size];
const at::BFloat16* const ip1 = &input[idx1 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
auto input0 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
auto input1 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
output = svmla_x(svAll, output, input0, wgt0);
output = svmla_x(svAll, output, input1, wgt1);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
auto input0 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
auto input1 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
output = svmla_x(pg, output, input0, wgt0);
output = svmla_x(pg, output, input1, wgt1);
svst1(pg, &op[k], output);
k += vLen;
}
j += 2;
pos += 2;
}
// tail loop
if (j < end_offset) {
const auto idx0 = indices[pos + 0];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
float wgt0 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
}
const at::BFloat16* const ip0 = &input[idx0 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
auto input0 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
output = svmla_x(svAll, output, input0, wgt0);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
auto input0 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
output = svmla_x(pg, output, input0, wgt0);
svst1(pg, &op[k], output);
k += vLen;
}
      pos++;
}
const int64_t length = end_offset - start_offset;
if (normalize_by_lengths && length != 0) {
const float len_inv = 1.0f / length;
svbool_t pg;
int64_t j = 0;
while (j + vLen - 1 < block_size) {
svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));
j += vLen;
}
if (j < block_size) {
pg = svwhilelt_b32_s64(j, block_size);
svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));
}
}
}
return pos == index_size;
}
bool EmbeddingLookupIdx_int32_t_bfloat16_float_false__sve(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const at::BFloat16* input,
const int32_t* indices,
const int32_t* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
float* out) {
return EmbeddingLookupIdx_int32_t_bfloat16_float__sve<false>(
block_size,
output_size,
index_size,
data_size,
input,
indices,
offsets,
weights,
scale_bias,
normalize_by_lengths,
out);
}
bool EmbeddingLookupIdx_int32_t_bfloat16_float_true__sve(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const at::BFloat16* input,
const int32_t* indices,
const int32_t* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
float* out) {
return EmbeddingLookupIdx_int32_t_bfloat16_float__sve<true>(
block_size,
output_size,
index_size,
data_size,
input,
indices,
offsets,
weights,
scale_bias,
normalize_by_lengths,
out);
}
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int64_t_bfloat16_float__sve(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const at::BFloat16* input,
const int64_t* indices,
const int64_t* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
float* out) {
const svbool_t svAll = svptrue_b32();
const auto vLen = static_cast<int64_t>(svcntw());
int64_t pos = 0;
for (int64_t i = 0; i < output_size; ++i) {
float* const op = &out[i * block_size];
memset(op, 0, sizeof(float) * block_size);
if (pos != offsets[i] - offsets[0]) {
return false;
}
int64_t start_offset = offsets[i];
int64_t end_offset = offsets[i + 1];
int64_t j = start_offset;
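    // Same accumulation scheme as the int32_t/BFloat16 kernel above, but
    // indexing with 64-bit indices and offsets.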
// unrolling 16 times
while (j + 15 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
const auto idx2 = indices[pos + 2];
const auto idx3 = indices[pos + 3];
const auto idx4 = indices[pos + 4];
const auto idx5 = indices[pos + 5];
const auto idx6 = indices[pos + 6];
const auto idx7 = indices[pos + 7];
const auto idx8 = indices[pos + 8];
const auto idx9 = indices[pos + 9];
const auto idx10 = indices[pos + 10];
const auto idx11 = indices[pos + 11];
const auto idx12 = indices[pos + 12];
const auto idx13 = indices[pos + 13];
const auto idx14 = indices[pos + 14];
const auto idx15 = indices[pos + 15];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
if (idx2 < 0 || idx2 >= data_size) {
return false;
}
if (idx3 < 0 || idx3 >= data_size) {
return false;
}
if (idx4 < 0 || idx4 >= data_size) {
return false;
}
if (idx5 < 0 || idx5 >= data_size) {
return false;
}
if (idx6 < 0 || idx6 >= data_size) {
return false;
}
if (idx7 < 0 || idx7 >= data_size) {
return false;
}
if (idx8 < 0 || idx8 >= data_size) {
return false;
}
if (idx9 < 0 || idx9 >= data_size) {
return false;
}
if (idx10 < 0 || idx10 >= data_size) {
return false;
}
if (idx11 < 0 || idx11 >= data_size) {
return false;
}
if (idx12 < 0 || idx12 >= data_size) {
return false;
}
if (idx13 < 0 || idx13 >= data_size) {
return false;
}
if (idx14 < 0 || idx14 >= data_size) {
return false;
}
if (idx15 < 0 || idx15 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float wgt2 = 1.f;
float wgt3 = 1.f;
float wgt4 = 1.f;
float wgt5 = 1.f;
float wgt6 = 1.f;
float wgt7 = 1.f;
float wgt8 = 1.f;
float wgt9 = 1.f;
float wgt10 = 1.f;
float wgt11 = 1.f;
float wgt12 = 1.f;
float wgt13 = 1.f;
float wgt14 = 1.f;
float wgt15 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8];
wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9];
wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10];
wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11];
wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12];
wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13];
wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14];
wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15];
}
const at::BFloat16* const ip0 = &input[idx0 * block_size];
const at::BFloat16* const ip1 = &input[idx1 * block_size];
const at::BFloat16* const ip2 = &input[idx2 * block_size];
const at::BFloat16* const ip3 = &input[idx3 * block_size];
const at::BFloat16* const ip4 = &input[idx4 * block_size];
const at::BFloat16* const ip5 = &input[idx5 * block_size];
const at::BFloat16* const ip6 = &input[idx6 * block_size];
const at::BFloat16* const ip7 = &input[idx7 * block_size];
const at::BFloat16* const ip8 = &input[idx8 * block_size];
const at::BFloat16* const ip9 = &input[idx9 * block_size];
const at::BFloat16* const ip10 = &input[idx10 * block_size];
const at::BFloat16* const ip11 = &input[idx11 * block_size];
const at::BFloat16* const ip12 = &input[idx12 * block_size];
const at::BFloat16* const ip13 = &input[idx13 * block_size];
const at::BFloat16* const ip14 = &input[idx14 * block_size];
const at::BFloat16* const ip15 = &input[idx15 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
auto input0 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
auto input1 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
auto input2 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
auto input3 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
auto input4 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip4[k])), 16));
auto input5 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip5[k])), 16));
auto input6 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip6[k])), 16));
auto input7 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip7[k])), 16));
auto input8 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip8[k])), 16));
auto input9 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip9[k])), 16));
auto input10 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip10[k])), 16));
auto input11 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip11[k])), 16));
auto input12 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip12[k])), 16));
auto input13 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip13[k])), 16));
auto input14 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip14[k])), 16));
auto input15 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip15[k])), 16));
output = svmla_x(svAll, output, input0, wgt0);
output = svmla_x(svAll, output, input1, wgt1);
output = svmla_x(svAll, output, input2, wgt2);
output = svmla_x(svAll, output, input3, wgt3);
output = svmla_x(svAll, output, input4, wgt4);
output = svmla_x(svAll, output, input5, wgt5);
output = svmla_x(svAll, output, input6, wgt6);
output = svmla_x(svAll, output, input7, wgt7);
output = svmla_x(svAll, output, input8, wgt8);
output = svmla_x(svAll, output, input9, wgt9);
output = svmla_x(svAll, output, input10, wgt10);
output = svmla_x(svAll, output, input11, wgt11);
output = svmla_x(svAll, output, input12, wgt12);
output = svmla_x(svAll, output, input13, wgt13);
output = svmla_x(svAll, output, input14, wgt14);
output = svmla_x(svAll, output, input15, wgt15);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
auto input0 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
auto input1 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
auto input2 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
auto input3 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
auto input4 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip4[k])), 16));
auto input5 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip5[k])), 16));
auto input6 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip6[k])), 16));
auto input7 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip7[k])), 16));
auto input8 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip8[k])), 16));
auto input9 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip9[k])), 16));
auto input10 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip10[k])), 16));
auto input11 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip11[k])), 16));
auto input12 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip12[k])), 16));
auto input13 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip13[k])), 16));
auto input14 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip14[k])), 16));
auto input15 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip15[k])), 16));
output = svmla_x(pg, output, input0, wgt0);
output = svmla_x(pg, output, input1, wgt1);
output = svmla_x(pg, output, input2, wgt2);
output = svmla_x(pg, output, input3, wgt3);
output = svmla_x(pg, output, input4, wgt4);
output = svmla_x(pg, output, input5, wgt5);
output = svmla_x(pg, output, input6, wgt6);
output = svmla_x(pg, output, input7, wgt7);
output = svmla_x(pg, output, input8, wgt8);
output = svmla_x(pg, output, input9, wgt9);
output = svmla_x(pg, output, input10, wgt10);
output = svmla_x(pg, output, input11, wgt11);
output = svmla_x(pg, output, input12, wgt12);
output = svmla_x(pg, output, input13, wgt13);
output = svmla_x(pg, output, input14, wgt14);
output = svmla_x(pg, output, input15, wgt15);
svst1(pg, &op[k], output);
k += vLen;
}
j += 16;
pos += 16;
}
// unrolling 8 times
while (j + 7 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
const auto idx2 = indices[pos + 2];
const auto idx3 = indices[pos + 3];
const auto idx4 = indices[pos + 4];
const auto idx5 = indices[pos + 5];
const auto idx6 = indices[pos + 6];
const auto idx7 = indices[pos + 7];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
if (idx2 < 0 || idx2 >= data_size) {
return false;
}
if (idx3 < 0 || idx3 >= data_size) {
return false;
}
if (idx4 < 0 || idx4 >= data_size) {
return false;
}
if (idx5 < 0 || idx5 >= data_size) {
return false;
}
if (idx6 < 0 || idx6 >= data_size) {
return false;
}
if (idx7 < 0 || idx7 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float wgt2 = 1.f;
float wgt3 = 1.f;
float wgt4 = 1.f;
float wgt5 = 1.f;
float wgt6 = 1.f;
float wgt7 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
}
const at::BFloat16* const ip0 = &input[idx0 * block_size];
const at::BFloat16* const ip1 = &input[idx1 * block_size];
const at::BFloat16* const ip2 = &input[idx2 * block_size];
const at::BFloat16* const ip3 = &input[idx3 * block_size];
const at::BFloat16* const ip4 = &input[idx4 * block_size];
const at::BFloat16* const ip5 = &input[idx5 * block_size];
const at::BFloat16* const ip6 = &input[idx6 * block_size];
const at::BFloat16* const ip7 = &input[idx7 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
auto input0 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
auto input1 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
auto input2 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
auto input3 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
auto input4 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip4[k])), 16));
auto input5 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip5[k])), 16));
auto input6 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip6[k])), 16));
auto input7 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip7[k])), 16));
output = svmla_x(svAll, output, input0, wgt0);
output = svmla_x(svAll, output, input1, wgt1);
output = svmla_x(svAll, output, input2, wgt2);
output = svmla_x(svAll, output, input3, wgt3);
output = svmla_x(svAll, output, input4, wgt4);
output = svmla_x(svAll, output, input5, wgt5);
output = svmla_x(svAll, output, input6, wgt6);
output = svmla_x(svAll, output, input7, wgt7);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
auto input0 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
auto input1 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
auto input2 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
auto input3 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
auto input4 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip4[k])), 16));
auto input5 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip5[k])), 16));
auto input6 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip6[k])), 16));
auto input7 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip7[k])), 16));
output = svmla_x(pg, output, input0, wgt0);
output = svmla_x(pg, output, input1, wgt1);
output = svmla_x(pg, output, input2, wgt2);
output = svmla_x(pg, output, input3, wgt3);
output = svmla_x(pg, output, input4, wgt4);
output = svmla_x(pg, output, input5, wgt5);
output = svmla_x(pg, output, input6, wgt6);
output = svmla_x(pg, output, input7, wgt7);
svst1(pg, &op[k], output);
k += vLen;
}
j += 8;
pos += 8;
}
// unrolling 4 times
while (j + 3 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
const auto idx2 = indices[pos + 2];
const auto idx3 = indices[pos + 3];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
if (idx2 < 0 || idx2 >= data_size) {
return false;
}
if (idx3 < 0 || idx3 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float wgt2 = 1.f;
float wgt3 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
}
const at::BFloat16* const ip0 = &input[idx0 * block_size];
const at::BFloat16* const ip1 = &input[idx1 * block_size];
const at::BFloat16* const ip2 = &input[idx2 * block_size];
const at::BFloat16* const ip3 = &input[idx3 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
auto input0 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
auto input1 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
auto input2 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
auto input3 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
output = svmla_x(svAll, output, input0, wgt0);
output = svmla_x(svAll, output, input1, wgt1);
output = svmla_x(svAll, output, input2, wgt2);
output = svmla_x(svAll, output, input3, wgt3);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
auto input0 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
auto input1 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
auto input2 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip2[k])), 16));
auto input3 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip3[k])), 16));
output = svmla_x(pg, output, input0, wgt0);
output = svmla_x(pg, output, input1, wgt1);
output = svmla_x(pg, output, input2, wgt2);
output = svmla_x(pg, output, input3, wgt3);
svst1(pg, &op[k], output);
k += vLen;
}
j += 4;
pos += 4;
}
// unrolling 2 times
while (j + 1 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
}
const at::BFloat16* const ip0 = &input[idx0 * block_size];
const at::BFloat16* const ip1 = &input[idx1 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
auto input0 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
auto input1 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
output = svmla_x(svAll, output, input0, wgt0);
output = svmla_x(svAll, output, input1, wgt1);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
auto input0 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
auto input1 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip1[k])), 16));
output = svmla_x(pg, output, input0, wgt0);
output = svmla_x(pg, output, input1, wgt1);
svst1(pg, &op[k], output);
k += vLen;
}
j += 2;
pos += 2;
}
// tail loop
if (j < end_offset) {
const auto idx0 = indices[pos + 0];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
float wgt0 = 1.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
}
const at::BFloat16* const ip0 = &input[idx0 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
auto input0 = svreinterpret_f32(svlsl_x(svAll,
svld1uh_u32(svAll, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
output = svmla_x(svAll, output, input0, wgt0);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
auto input0 = svreinterpret_f32(svlsl_x(pg,
svld1uh_u32(pg, reinterpret_cast<const uint16_t*>(&ip0[k])), 16));
output = svmla_x(pg, output, input0, wgt0);
svst1(pg, &op[k], output);
k += vLen;
}
      pos++;
}
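    // Optionally turn the accumulated sum into a mean by scaling with the
    // inverse segment length.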
const int64_t length = end_offset - start_offset;
if (normalize_by_lengths && length != 0) {
const float len_inv = 1.0f / length;
svbool_t pg;
int64_t j = 0;
while (j + vLen - 1 < block_size) {
svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));
j += vLen;
}
if (j < block_size) {
pg = svwhilelt_b32_s64(j, block_size);
svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));
}
}
}
return pos == index_size;
}
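// Non-templated entry points for the BFloat16 / int64_t-index kernel; they
// select the IS_WEIGHT_POSITIONAL specialization.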
bool EmbeddingLookupIdx_int64_t_bfloat16_float_false__sve(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const at::BFloat16* input,
const int64_t* indices,
const int64_t* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
float* out) {
return EmbeddingLookupIdx_int64_t_bfloat16_float__sve<false>(
block_size,
output_size,
index_size,
data_size,
input,
indices,
offsets,
weights,
scale_bias,
normalize_by_lengths,
out);
}
bool EmbeddingLookupIdx_int64_t_bfloat16_float_true__sve(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const at::BFloat16* input,
const int64_t* indices,
const int64_t* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
float* out) {
return EmbeddingLookupIdx_int64_t_bfloat16_float__sve<true>(
block_size,
output_size,
index_size,
data_size,
input,
indices,
offsets,
weights,
scale_bias,
normalize_by_lengths,
out);
}
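// uint8 kernel with int32_t indices: rows are stored quantized, and
// scale_bias holds a (scale, bias) pair per row. The per-index weight is
// folded into the row's scale, and the weighted biases of each unrolled
// group are summed into a scalar added to every output lane.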
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int32_t_uint8_t_float__sve(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const uint8_t* input,
const int32_t* indices,
const int32_t* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
float* out) {
const svbool_t svAll = svptrue_b32();
const auto vLen = static_cast<int64_t>(svcntw());
int64_t pos = 0;
for (int64_t i = 0; i < output_size; ++i) {
float* const op = &out[i * block_size];
memset(op, 0, sizeof(float) * block_size);
if (pos != offsets[i] - offsets[0]) {
return false;
}
int64_t start_offset = offsets[i];
int64_t end_offset = offsets[i + 1];
int64_t j = start_offset;
// unrolling 16 times
while (j + 15 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
const auto idx2 = indices[pos + 2];
const auto idx3 = indices[pos + 3];
const auto idx4 = indices[pos + 4];
const auto idx5 = indices[pos + 5];
const auto idx6 = indices[pos + 6];
const auto idx7 = indices[pos + 7];
const auto idx8 = indices[pos + 8];
const auto idx9 = indices[pos + 9];
const auto idx10 = indices[pos + 10];
const auto idx11 = indices[pos + 11];
const auto idx12 = indices[pos + 12];
const auto idx13 = indices[pos + 13];
const auto idx14 = indices[pos + 14];
const auto idx15 = indices[pos + 15];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
if (idx2 < 0 || idx2 >= data_size) {
return false;
}
if (idx3 < 0 || idx3 >= data_size) {
return false;
}
if (idx4 < 0 || idx4 >= data_size) {
return false;
}
if (idx5 < 0 || idx5 >= data_size) {
return false;
}
if (idx6 < 0 || idx6 >= data_size) {
return false;
}
if (idx7 < 0 || idx7 >= data_size) {
return false;
}
if (idx8 < 0 || idx8 >= data_size) {
return false;
}
if (idx9 < 0 || idx9 >= data_size) {
return false;
}
if (idx10 < 0 || idx10 >= data_size) {
return false;
}
if (idx11 < 0 || idx11 >= data_size) {
return false;
}
if (idx12 < 0 || idx12 >= data_size) {
return false;
}
if (idx13 < 0 || idx13 >= data_size) {
return false;
}
if (idx14 < 0 || idx14 >= data_size) {
return false;
}
if (idx15 < 0 || idx15 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float wgt2 = 1.f;
float wgt3 = 1.f;
float wgt4 = 1.f;
float wgt5 = 1.f;
float wgt6 = 1.f;
float wgt7 = 1.f;
float wgt8 = 1.f;
float wgt9 = 1.f;
float wgt10 = 1.f;
float wgt11 = 1.f;
float wgt12 = 1.f;
float wgt13 = 1.f;
float wgt14 = 1.f;
float wgt15 = 1.f;
float bio = 0.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8];
wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9];
wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10];
wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11];
wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12];
wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13];
wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14];
wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15];
}
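      // Fold each row's quantization scale into its weight and accumulate the
      // weighted bias into a single scalar.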
if (scale_bias) {
bio += wgt0 * scale_bias[2 * idx0 + 1];
wgt0 = wgt0 * scale_bias[2 * idx0];
bio += wgt1 * scale_bias[2 * idx1 + 1];
wgt1 = wgt1 * scale_bias[2 * idx1];
bio += wgt2 * scale_bias[2 * idx2 + 1];
wgt2 = wgt2 * scale_bias[2 * idx2];
bio += wgt3 * scale_bias[2 * idx3 + 1];
wgt3 = wgt3 * scale_bias[2 * idx3];
bio += wgt4 * scale_bias[2 * idx4 + 1];
wgt4 = wgt4 * scale_bias[2 * idx4];
bio += wgt5 * scale_bias[2 * idx5 + 1];
wgt5 = wgt5 * scale_bias[2 * idx5];
bio += wgt6 * scale_bias[2 * idx6 + 1];
wgt6 = wgt6 * scale_bias[2 * idx6];
bio += wgt7 * scale_bias[2 * idx7 + 1];
wgt7 = wgt7 * scale_bias[2 * idx7];
bio += wgt8 * scale_bias[2 * idx8 + 1];
wgt8 = wgt8 * scale_bias[2 * idx8];
bio += wgt9 * scale_bias[2 * idx9 + 1];
wgt9 = wgt9 * scale_bias[2 * idx9];
bio += wgt10 * scale_bias[2 * idx10 + 1];
wgt10 = wgt10 * scale_bias[2 * idx10];
bio += wgt11 * scale_bias[2 * idx11 + 1];
wgt11 = wgt11 * scale_bias[2 * idx11];
bio += wgt12 * scale_bias[2 * idx12 + 1];
wgt12 = wgt12 * scale_bias[2 * idx12];
bio += wgt13 * scale_bias[2 * idx13 + 1];
wgt13 = wgt13 * scale_bias[2 * idx13];
bio += wgt14 * scale_bias[2 * idx14 + 1];
wgt14 = wgt14 * scale_bias[2 * idx14];
bio += wgt15 * scale_bias[2 * idx15 + 1];
wgt15 = wgt15 * scale_bias[2 * idx15];
}
const uint8_t* const ip0 = &input[idx0 * block_size];
const uint8_t* const ip1 = &input[idx1 * block_size];
const uint8_t* const ip2 = &input[idx2 * block_size];
const uint8_t* const ip3 = &input[idx3 * block_size];
const uint8_t* const ip4 = &input[idx4 * block_size];
const uint8_t* const ip5 = &input[idx5 * block_size];
const uint8_t* const ip6 = &input[idx6 * block_size];
const uint8_t* const ip7 = &input[idx7 * block_size];
const uint8_t* const ip8 = &input[idx8 * block_size];
const uint8_t* const ip9 = &input[idx9 * block_size];
const uint8_t* const ip10 = &input[idx10 * block_size];
const uint8_t* const ip11 = &input[idx11 * block_size];
const uint8_t* const ip12 = &input[idx12 * block_size];
const uint8_t* const ip13 = &input[idx13 * block_size];
const uint8_t* const ip14 = &input[idx14 * block_size];
const uint8_t* const ip15 = &input[idx15 * block_size];
svbool_t pg;
int64_t k = 0;
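      // Full-vector loop: add the accumulated bias once per output chunk, then
      // multiply-accumulate each row (converted to float32) with its folded weight.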
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
output = svadd_x(svAll, output, bio);
auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k]));
auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k]));
auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k]));
auto input4 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip4[k]));
auto input5 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip5[k]));
auto input6 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip6[k]));
auto input7 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip7[k]));
auto input8 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip8[k]));
auto input9 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip9[k]));
auto input10 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip10[k]));
auto input11 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip11[k]));
auto input12 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip12[k]));
auto input13 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip13[k]));
auto input14 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip14[k]));
auto input15 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip15[k]));
output = svmla_x(svAll, output, input0, wgt0);
output = svmla_x(svAll, output, input1, wgt1);
output = svmla_x(svAll, output, input2, wgt2);
output = svmla_x(svAll, output, input3, wgt3);
output = svmla_x(svAll, output, input4, wgt4);
output = svmla_x(svAll, output, input5, wgt5);
output = svmla_x(svAll, output, input6, wgt6);
output = svmla_x(svAll, output, input7, wgt7);
output = svmla_x(svAll, output, input8, wgt8);
output = svmla_x(svAll, output, input9, wgt9);
output = svmla_x(svAll, output, input10, wgt10);
output = svmla_x(svAll, output, input11, wgt11);
output = svmla_x(svAll, output, input12, wgt12);
output = svmla_x(svAll, output, input13, wgt13);
output = svmla_x(svAll, output, input14, wgt14);
output = svmla_x(svAll, output, input15, wgt15);
svst1(svAll, &op[k], output);
k += vLen;
}
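      // Predicated tail for the last partial vector when block_size is not a
      // multiple of vLen.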
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
output = svadd_x(pg, output, bio);
auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k]));
auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k]));
auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k]));
auto input4 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip4[k]));
auto input5 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip5[k]));
auto input6 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip6[k]));
auto input7 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip7[k]));
auto input8 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip8[k]));
auto input9 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip9[k]));
auto input10 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip10[k]));
auto input11 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip11[k]));
auto input12 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip12[k]));
auto input13 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip13[k]));
auto input14 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip14[k]));
auto input15 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip15[k]));
output = svmla_x(pg, output, input0, wgt0);
output = svmla_x(pg, output, input1, wgt1);
output = svmla_x(pg, output, input2, wgt2);
output = svmla_x(pg, output, input3, wgt3);
output = svmla_x(pg, output, input4, wgt4);
output = svmla_x(pg, output, input5, wgt5);
output = svmla_x(pg, output, input6, wgt6);
output = svmla_x(pg, output, input7, wgt7);
output = svmla_x(pg, output, input8, wgt8);
output = svmla_x(pg, output, input9, wgt9);
output = svmla_x(pg, output, input10, wgt10);
output = svmla_x(pg, output, input11, wgt11);
output = svmla_x(pg, output, input12, wgt12);
output = svmla_x(pg, output, input13, wgt13);
output = svmla_x(pg, output, input14, wgt14);
output = svmla_x(pg, output, input15, wgt15);
svst1(pg, &op[k], output);
k += vLen;
}
j += 16;
pos += 16;
}
// unrolling 8 times
while (j + 7 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
const auto idx2 = indices[pos + 2];
const auto idx3 = indices[pos + 3];
const auto idx4 = indices[pos + 4];
const auto idx5 = indices[pos + 5];
const auto idx6 = indices[pos + 6];
const auto idx7 = indices[pos + 7];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
if (idx2 < 0 || idx2 >= data_size) {
return false;
}
if (idx3 < 0 || idx3 >= data_size) {
return false;
}
if (idx4 < 0 || idx4 >= data_size) {
return false;
}
if (idx5 < 0 || idx5 >= data_size) {
return false;
}
if (idx6 < 0 || idx6 >= data_size) {
return false;
}
if (idx7 < 0 || idx7 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float wgt2 = 1.f;
float wgt3 = 1.f;
float wgt4 = 1.f;
float wgt5 = 1.f;
float wgt6 = 1.f;
float wgt7 = 1.f;
float bio = 0.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
}
if (scale_bias) {
bio += wgt0 * scale_bias[2 * idx0 + 1];
wgt0 = wgt0 * scale_bias[2 * idx0];
bio += wgt1 * scale_bias[2 * idx1 + 1];
wgt1 = wgt1 * scale_bias[2 * idx1];
bio += wgt2 * scale_bias[2 * idx2 + 1];
wgt2 = wgt2 * scale_bias[2 * idx2];
bio += wgt3 * scale_bias[2 * idx3 + 1];
wgt3 = wgt3 * scale_bias[2 * idx3];
bio += wgt4 * scale_bias[2 * idx4 + 1];
wgt4 = wgt4 * scale_bias[2 * idx4];
bio += wgt5 * scale_bias[2 * idx5 + 1];
wgt5 = wgt5 * scale_bias[2 * idx5];
bio += wgt6 * scale_bias[2 * idx6 + 1];
wgt6 = wgt6 * scale_bias[2 * idx6];
bio += wgt7 * scale_bias[2 * idx7 + 1];
wgt7 = wgt7 * scale_bias[2 * idx7];
}
const uint8_t* const ip0 = &input[idx0 * block_size];
const uint8_t* const ip1 = &input[idx1 * block_size];
const uint8_t* const ip2 = &input[idx2 * block_size];
const uint8_t* const ip3 = &input[idx3 * block_size];
const uint8_t* const ip4 = &input[idx4 * block_size];
const uint8_t* const ip5 = &input[idx5 * block_size];
const uint8_t* const ip6 = &input[idx6 * block_size];
const uint8_t* const ip7 = &input[idx7 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
output = svadd_x(svAll, output, bio);
auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k]));
auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k]));
auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k]));
auto input4 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip4[k]));
auto input5 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip5[k]));
auto input6 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip6[k]));
auto input7 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip7[k]));
output = svmla_x(svAll, output, input0, wgt0);
output = svmla_x(svAll, output, input1, wgt1);
output = svmla_x(svAll, output, input2, wgt2);
output = svmla_x(svAll, output, input3, wgt3);
output = svmla_x(svAll, output, input4, wgt4);
output = svmla_x(svAll, output, input5, wgt5);
output = svmla_x(svAll, output, input6, wgt6);
output = svmla_x(svAll, output, input7, wgt7);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
output = svadd_x(pg, output, bio);
auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k]));
auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k]));
auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k]));
auto input4 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip4[k]));
auto input5 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip5[k]));
auto input6 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip6[k]));
auto input7 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip7[k]));
output = svmla_x(pg, output, input0, wgt0);
output = svmla_x(pg, output, input1, wgt1);
output = svmla_x(pg, output, input2, wgt2);
output = svmla_x(pg, output, input3, wgt3);
output = svmla_x(pg, output, input4, wgt4);
output = svmla_x(pg, output, input5, wgt5);
output = svmla_x(pg, output, input6, wgt6);
output = svmla_x(pg, output, input7, wgt7);
svst1(pg, &op[k], output);
k += vLen;
}
j += 8;
pos += 8;
}
// unrolling 4 times
while (j + 3 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
const auto idx2 = indices[pos + 2];
const auto idx3 = indices[pos + 3];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
if (idx2 < 0 || idx2 >= data_size) {
return false;
}
if (idx3 < 0 || idx3 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float wgt2 = 1.f;
float wgt3 = 1.f;
float bio = 0.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
}
if (scale_bias) {
bio += wgt0 * scale_bias[2 * idx0 + 1];
wgt0 = wgt0 * scale_bias[2 * idx0];
bio += wgt1 * scale_bias[2 * idx1 + 1];
wgt1 = wgt1 * scale_bias[2 * idx1];
bio += wgt2 * scale_bias[2 * idx2 + 1];
wgt2 = wgt2 * scale_bias[2 * idx2];
bio += wgt3 * scale_bias[2 * idx3 + 1];
wgt3 = wgt3 * scale_bias[2 * idx3];
}
const uint8_t* const ip0 = &input[idx0 * block_size];
const uint8_t* const ip1 = &input[idx1 * block_size];
const uint8_t* const ip2 = &input[idx2 * block_size];
const uint8_t* const ip3 = &input[idx3 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
output = svadd_x(svAll, output, bio);
auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k]));
auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k]));
auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k]));
output = svmla_x(svAll, output, input0, wgt0);
output = svmla_x(svAll, output, input1, wgt1);
output = svmla_x(svAll, output, input2, wgt2);
output = svmla_x(svAll, output, input3, wgt3);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
output = svadd_x(pg, output, bio);
auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k]));
auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k]));
auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k]));
output = svmla_x(pg, output, input0, wgt0);
output = svmla_x(pg, output, input1, wgt1);
output = svmla_x(pg, output, input2, wgt2);
output = svmla_x(pg, output, input3, wgt3);
svst1(pg, &op[k], output);
k += vLen;
}
j += 4;
pos += 4;
}
// unrolling 2 times
while (j + 1 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float bio = 0.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
}
if (scale_bias) {
bio += wgt0 * scale_bias[2 * idx0 + 1];
wgt0 = wgt0 * scale_bias[2 * idx0];
bio += wgt1 * scale_bias[2 * idx1 + 1];
wgt1 = wgt1 * scale_bias[2 * idx1];
}
const uint8_t* const ip0 = &input[idx0 * block_size];
const uint8_t* const ip1 = &input[idx1 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
output = svadd_x(svAll, output, bio);
auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k]));
output = svmla_x(svAll, output, input0, wgt0);
output = svmla_x(svAll, output, input1, wgt1);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
output = svadd_x(pg, output, bio);
auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k]));
output = svmla_x(pg, output, input0, wgt0);
output = svmla_x(pg, output, input1, wgt1);
svst1(pg, &op[k], output);
k += vLen;
}
j += 2;
pos += 2;
}
// tail loop
if (j < end_offset) {
const auto idx0 = indices[pos + 0];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
float wgt0 = 1.f;
float bio = 0.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
}
if (scale_bias) {
bio += wgt0 * scale_bias[2 * idx0 + 1];
wgt0 = wgt0 * scale_bias[2 * idx0];
}
const uint8_t* const ip0 = &input[idx0 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
output = svadd_x(svAll, output, bio);
auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
output = svmla_x(svAll, output, input0, wgt0);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
output = svadd_x(pg, output, bio);
auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
output = svmla_x(pg, output, input0, wgt0);
svst1(pg, &op[k], output);
k += vLen;
}
      pos++;
}
const int64_t length = end_offset - start_offset;
if (normalize_by_lengths && length != 0) {
const float len_inv = 1.0f / length;
svbool_t pg;
int64_t j = 0;
while (j + vLen - 1 < block_size) {
svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));
j += vLen;
}
if (j < block_size) {
pg = svwhilelt_b32_s64(j, block_size);
svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));
}
}
}
return pos == index_size;
}
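// Non-templated entry points for the uint8 / int32_t-index kernel.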
bool EmbeddingLookupIdx_int32_t_uint8_t_float_false__sve(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const uint8_t* input,
const int32_t* indices,
const int32_t* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
float* out) {
return EmbeddingLookupIdx_int32_t_uint8_t_float__sve<false>(
block_size,
output_size,
index_size,
data_size,
input,
indices,
offsets,
weights,
scale_bias,
normalize_by_lengths,
out);
}
bool EmbeddingLookupIdx_int32_t_uint8_t_float_true__sve(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const uint8_t* input,
const int32_t* indices,
const int32_t* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
float* out) {
return EmbeddingLookupIdx_int32_t_uint8_t_float__sve<true>(
block_size,
output_size,
index_size,
data_size,
input,
indices,
offsets,
weights,
scale_bias,
normalize_by_lengths,
out);
}
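// Same uint8 kernel as above, but with int64_t indices and offsets.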
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int64_t_uint8_t_float__sve(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const uint8_t* input,
const int64_t* indices,
const int64_t* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
float* out) {
const svbool_t svAll = svptrue_b32();
const auto vLen = static_cast<int64_t>(svcntw());
int64_t pos = 0;
for (int64_t i = 0; i < output_size; ++i) {
float* const op = &out[i * block_size];
memset(op, 0, sizeof(float) * block_size);
if (pos != offsets[i] - offsets[0]) {
return false;
}
int64_t start_offset = offsets[i];
int64_t end_offset = offsets[i + 1];
int64_t j = start_offset;
// unrolling 16 times
while (j + 15 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
const auto idx2 = indices[pos + 2];
const auto idx3 = indices[pos + 3];
const auto idx4 = indices[pos + 4];
const auto idx5 = indices[pos + 5];
const auto idx6 = indices[pos + 6];
const auto idx7 = indices[pos + 7];
const auto idx8 = indices[pos + 8];
const auto idx9 = indices[pos + 9];
const auto idx10 = indices[pos + 10];
const auto idx11 = indices[pos + 11];
const auto idx12 = indices[pos + 12];
const auto idx13 = indices[pos + 13];
const auto idx14 = indices[pos + 14];
const auto idx15 = indices[pos + 15];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
if (idx2 < 0 || idx2 >= data_size) {
return false;
}
if (idx3 < 0 || idx3 >= data_size) {
return false;
}
if (idx4 < 0 || idx4 >= data_size) {
return false;
}
if (idx5 < 0 || idx5 >= data_size) {
return false;
}
if (idx6 < 0 || idx6 >= data_size) {
return false;
}
if (idx7 < 0 || idx7 >= data_size) {
return false;
}
if (idx8 < 0 || idx8 >= data_size) {
return false;
}
if (idx9 < 0 || idx9 >= data_size) {
return false;
}
if (idx10 < 0 || idx10 >= data_size) {
return false;
}
if (idx11 < 0 || idx11 >= data_size) {
return false;
}
if (idx12 < 0 || idx12 >= data_size) {
return false;
}
if (idx13 < 0 || idx13 >= data_size) {
return false;
}
if (idx14 < 0 || idx14 >= data_size) {
return false;
}
if (idx15 < 0 || idx15 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float wgt2 = 1.f;
float wgt3 = 1.f;
float wgt4 = 1.f;
float wgt5 = 1.f;
float wgt6 = 1.f;
float wgt7 = 1.f;
float wgt8 = 1.f;
float wgt9 = 1.f;
float wgt10 = 1.f;
float wgt11 = 1.f;
float wgt12 = 1.f;
float wgt13 = 1.f;
float wgt14 = 1.f;
float wgt15 = 1.f;
float bio = 0.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8];
wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9];
wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10];
wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11];
wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12];
wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13];
wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14];
wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15];
}
if (scale_bias) {
bio += wgt0 * scale_bias[2 * idx0 + 1];
wgt0 = wgt0 * scale_bias[2 * idx0];
bio += wgt1 * scale_bias[2 * idx1 + 1];
wgt1 = wgt1 * scale_bias[2 * idx1];
bio += wgt2 * scale_bias[2 * idx2 + 1];
wgt2 = wgt2 * scale_bias[2 * idx2];
bio += wgt3 * scale_bias[2 * idx3 + 1];
wgt3 = wgt3 * scale_bias[2 * idx3];
bio += wgt4 * scale_bias[2 * idx4 + 1];
wgt4 = wgt4 * scale_bias[2 * idx4];
bio += wgt5 * scale_bias[2 * idx5 + 1];
wgt5 = wgt5 * scale_bias[2 * idx5];
bio += wgt6 * scale_bias[2 * idx6 + 1];
wgt6 = wgt6 * scale_bias[2 * idx6];
bio += wgt7 * scale_bias[2 * idx7 + 1];
wgt7 = wgt7 * scale_bias[2 * idx7];
bio += wgt8 * scale_bias[2 * idx8 + 1];
wgt8 = wgt8 * scale_bias[2 * idx8];
bio += wgt9 * scale_bias[2 * idx9 + 1];
wgt9 = wgt9 * scale_bias[2 * idx9];
bio += wgt10 * scale_bias[2 * idx10 + 1];
wgt10 = wgt10 * scale_bias[2 * idx10];
bio += wgt11 * scale_bias[2 * idx11 + 1];
wgt11 = wgt11 * scale_bias[2 * idx11];
bio += wgt12 * scale_bias[2 * idx12 + 1];
wgt12 = wgt12 * scale_bias[2 * idx12];
bio += wgt13 * scale_bias[2 * idx13 + 1];
wgt13 = wgt13 * scale_bias[2 * idx13];
bio += wgt14 * scale_bias[2 * idx14 + 1];
wgt14 = wgt14 * scale_bias[2 * idx14];
bio += wgt15 * scale_bias[2 * idx15 + 1];
wgt15 = wgt15 * scale_bias[2 * idx15];
}
const uint8_t* const ip0 = &input[idx0 * block_size];
const uint8_t* const ip1 = &input[idx1 * block_size];
const uint8_t* const ip2 = &input[idx2 * block_size];
const uint8_t* const ip3 = &input[idx3 * block_size];
const uint8_t* const ip4 = &input[idx4 * block_size];
const uint8_t* const ip5 = &input[idx5 * block_size];
const uint8_t* const ip6 = &input[idx6 * block_size];
const uint8_t* const ip7 = &input[idx7 * block_size];
const uint8_t* const ip8 = &input[idx8 * block_size];
const uint8_t* const ip9 = &input[idx9 * block_size];
const uint8_t* const ip10 = &input[idx10 * block_size];
const uint8_t* const ip11 = &input[idx11 * block_size];
const uint8_t* const ip12 = &input[idx12 * block_size];
const uint8_t* const ip13 = &input[idx13 * block_size];
const uint8_t* const ip14 = &input[idx14 * block_size];
const uint8_t* const ip15 = &input[idx15 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
output = svadd_x(svAll, output, bio);
auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k]));
auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k]));
auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k]));
auto input4 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip4[k]));
auto input5 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip5[k]));
auto input6 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip6[k]));
auto input7 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip7[k]));
auto input8 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip8[k]));
auto input9 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip9[k]));
auto input10 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip10[k]));
auto input11 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip11[k]));
auto input12 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip12[k]));
auto input13 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip13[k]));
auto input14 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip14[k]));
auto input15 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip15[k]));
output = svmla_x(svAll, output, input0, wgt0);
output = svmla_x(svAll, output, input1, wgt1);
output = svmla_x(svAll, output, input2, wgt2);
output = svmla_x(svAll, output, input3, wgt3);
output = svmla_x(svAll, output, input4, wgt4);
output = svmla_x(svAll, output, input5, wgt5);
output = svmla_x(svAll, output, input6, wgt6);
output = svmla_x(svAll, output, input7, wgt7);
output = svmla_x(svAll, output, input8, wgt8);
output = svmla_x(svAll, output, input9, wgt9);
output = svmla_x(svAll, output, input10, wgt10);
output = svmla_x(svAll, output, input11, wgt11);
output = svmla_x(svAll, output, input12, wgt12);
output = svmla_x(svAll, output, input13, wgt13);
output = svmla_x(svAll, output, input14, wgt14);
output = svmla_x(svAll, output, input15, wgt15);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
output = svadd_x(pg, output, bio);
auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k]));
auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k]));
auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k]));
auto input4 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip4[k]));
auto input5 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip5[k]));
auto input6 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip6[k]));
auto input7 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip7[k]));
auto input8 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip8[k]));
auto input9 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip9[k]));
auto input10 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip10[k]));
auto input11 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip11[k]));
auto input12 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip12[k]));
auto input13 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip13[k]));
auto input14 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip14[k]));
auto input15 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip15[k]));
output = svmla_x(pg, output, input0, wgt0);
output = svmla_x(pg, output, input1, wgt1);
output = svmla_x(pg, output, input2, wgt2);
output = svmla_x(pg, output, input3, wgt3);
output = svmla_x(pg, output, input4, wgt4);
output = svmla_x(pg, output, input5, wgt5);
output = svmla_x(pg, output, input6, wgt6);
output = svmla_x(pg, output, input7, wgt7);
output = svmla_x(pg, output, input8, wgt8);
output = svmla_x(pg, output, input9, wgt9);
output = svmla_x(pg, output, input10, wgt10);
output = svmla_x(pg, output, input11, wgt11);
output = svmla_x(pg, output, input12, wgt12);
output = svmla_x(pg, output, input13, wgt13);
output = svmla_x(pg, output, input14, wgt14);
output = svmla_x(pg, output, input15, wgt15);
svst1(pg, &op[k], output);
k += vLen;
}
j += 16;
pos += 16;
}
// unrolling 8 times
while (j + 7 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
const auto idx2 = indices[pos + 2];
const auto idx3 = indices[pos + 3];
const auto idx4 = indices[pos + 4];
const auto idx5 = indices[pos + 5];
const auto idx6 = indices[pos + 6];
const auto idx7 = indices[pos + 7];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
if (idx2 < 0 || idx2 >= data_size) {
return false;
}
if (idx3 < 0 || idx3 >= data_size) {
return false;
}
if (idx4 < 0 || idx4 >= data_size) {
return false;
}
if (idx5 < 0 || idx5 >= data_size) {
return false;
}
if (idx6 < 0 || idx6 >= data_size) {
return false;
}
if (idx7 < 0 || idx7 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float wgt2 = 1.f;
float wgt3 = 1.f;
float wgt4 = 1.f;
float wgt5 = 1.f;
float wgt6 = 1.f;
float wgt7 = 1.f;
float bio = 0.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4];
wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5];
wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6];
wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7];
}
if (scale_bias) {
bio += wgt0 * scale_bias[2 * idx0 + 1];
wgt0 = wgt0 * scale_bias[2 * idx0];
bio += wgt1 * scale_bias[2 * idx1 + 1];
wgt1 = wgt1 * scale_bias[2 * idx1];
bio += wgt2 * scale_bias[2 * idx2 + 1];
wgt2 = wgt2 * scale_bias[2 * idx2];
bio += wgt3 * scale_bias[2 * idx3 + 1];
wgt3 = wgt3 * scale_bias[2 * idx3];
bio += wgt4 * scale_bias[2 * idx4 + 1];
wgt4 = wgt4 * scale_bias[2 * idx4];
bio += wgt5 * scale_bias[2 * idx5 + 1];
wgt5 = wgt5 * scale_bias[2 * idx5];
bio += wgt6 * scale_bias[2 * idx6 + 1];
wgt6 = wgt6 * scale_bias[2 * idx6];
bio += wgt7 * scale_bias[2 * idx7 + 1];
wgt7 = wgt7 * scale_bias[2 * idx7];
}
const uint8_t* const ip0 = &input[idx0 * block_size];
const uint8_t* const ip1 = &input[idx1 * block_size];
const uint8_t* const ip2 = &input[idx2 * block_size];
const uint8_t* const ip3 = &input[idx3 * block_size];
const uint8_t* const ip4 = &input[idx4 * block_size];
const uint8_t* const ip5 = &input[idx5 * block_size];
const uint8_t* const ip6 = &input[idx6 * block_size];
const uint8_t* const ip7 = &input[idx7 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
output = svadd_x(svAll, output, bio);
auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k]));
auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k]));
auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k]));
auto input4 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip4[k]));
auto input5 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip5[k]));
auto input6 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip6[k]));
auto input7 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip7[k]));
output = svmla_x(svAll, output, input0, wgt0);
output = svmla_x(svAll, output, input1, wgt1);
output = svmla_x(svAll, output, input2, wgt2);
output = svmla_x(svAll, output, input3, wgt3);
output = svmla_x(svAll, output, input4, wgt4);
output = svmla_x(svAll, output, input5, wgt5);
output = svmla_x(svAll, output, input6, wgt6);
output = svmla_x(svAll, output, input7, wgt7);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
output = svadd_x(pg, output, bio);
auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k]));
auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k]));
auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k]));
auto input4 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip4[k]));
auto input5 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip5[k]));
auto input6 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip6[k]));
auto input7 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip7[k]));
output = svmla_x(pg, output, input0, wgt0);
output = svmla_x(pg, output, input1, wgt1);
output = svmla_x(pg, output, input2, wgt2);
output = svmla_x(pg, output, input3, wgt3);
output = svmla_x(pg, output, input4, wgt4);
output = svmla_x(pg, output, input5, wgt5);
output = svmla_x(pg, output, input6, wgt6);
output = svmla_x(pg, output, input7, wgt7);
svst1(pg, &op[k], output);
k += vLen;
}
j += 8;
pos += 8;
}
// unrolling 4 times
while (j + 3 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
const auto idx2 = indices[pos + 2];
const auto idx3 = indices[pos + 3];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
if (idx2 < 0 || idx2 >= data_size) {
return false;
}
if (idx3 < 0 || idx3 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float wgt2 = 1.f;
float wgt3 = 1.f;
float bio = 0.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2];
wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3];
}
if (scale_bias) {
bio += wgt0 * scale_bias[2 * idx0 + 1];
wgt0 = wgt0 * scale_bias[2 * idx0];
bio += wgt1 * scale_bias[2 * idx1 + 1];
wgt1 = wgt1 * scale_bias[2 * idx1];
bio += wgt2 * scale_bias[2 * idx2 + 1];
wgt2 = wgt2 * scale_bias[2 * idx2];
bio += wgt3 * scale_bias[2 * idx3 + 1];
wgt3 = wgt3 * scale_bias[2 * idx3];
}
const uint8_t* const ip0 = &input[idx0 * block_size];
const uint8_t* const ip1 = &input[idx1 * block_size];
const uint8_t* const ip2 = &input[idx2 * block_size];
const uint8_t* const ip3 = &input[idx3 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
output = svadd_x(svAll, output, bio);
auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k]));
auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k]));
auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k]));
output = svmla_x(svAll, output, input0, wgt0);
output = svmla_x(svAll, output, input1, wgt1);
output = svmla_x(svAll, output, input2, wgt2);
output = svmla_x(svAll, output, input3, wgt3);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
output = svadd_x(pg, output, bio);
auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k]));
auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k]));
auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k]));
output = svmla_x(pg, output, input0, wgt0);
output = svmla_x(pg, output, input1, wgt1);
output = svmla_x(pg, output, input2, wgt2);
output = svmla_x(pg, output, input3, wgt3);
svst1(pg, &op[k], output);
k += vLen;
}
j += 4;
pos += 4;
}
// unrolling 2 times
while (j + 1 < end_offset) {
const auto idx0 = indices[pos + 0];
const auto idx1 = indices[pos + 1];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
if (idx1 < 0 || idx1 >= data_size) {
return false;
}
float wgt0 = 1.f;
float wgt1 = 1.f;
float bio = 0.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1];
}
if (scale_bias) {
bio += wgt0 * scale_bias[2 * idx0 + 1];
wgt0 = wgt0 * scale_bias[2 * idx0];
bio += wgt1 * scale_bias[2 * idx1 + 1];
wgt1 = wgt1 * scale_bias[2 * idx1];
}
const uint8_t* const ip0 = &input[idx0 * block_size];
const uint8_t* const ip1 = &input[idx1 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
output = svadd_x(svAll, output, bio);
auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k]));
output = svmla_x(svAll, output, input0, wgt0);
output = svmla_x(svAll, output, input1, wgt1);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
output = svadd_x(pg, output, bio);
auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k]));
output = svmla_x(pg, output, input0, wgt0);
output = svmla_x(pg, output, input1, wgt1);
svst1(pg, &op[k], output);
k += vLen;
}
j += 2;
pos += 2;
}
// tail loop
if (j < end_offset) {
const auto idx0 = indices[pos + 0];
if (idx0 < 0 || idx0 >= data_size) {
return false;
}
float wgt0 = 1.f;
float bio = 0.f;
if (weights) {
wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0];
}
if (scale_bias) {
bio += wgt0 * scale_bias[2 * idx0 + 1];
wgt0 = wgt0 * scale_bias[2 * idx0];
}
const uint8_t* const ip0 = &input[idx0 * block_size];
svbool_t pg;
int64_t k = 0;
while (k + vLen - 1 < block_size) {
auto output = svld1(svAll, &op[k]);
output = svadd_x(svAll, output, bio);
auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k]));
output = svmla_x(svAll, output, input0, wgt0);
svst1(svAll, &op[k], output);
k += vLen;
}
if (k < block_size) {
pg = svwhilelt_b32_s64(k, block_size);
auto output = svld1(pg, &op[k]);
output = svadd_x(pg, output, bio);
auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k]));
output = svmla_x(pg, output, input0, wgt0);
svst1(pg, &op[k], output);
k += vLen;
}
      pos++;
}
const int64_t length = end_offset - start_offset;
if (normalize_by_lengths && length != 0) {
const float len_inv = 1.0f / length;
svbool_t pg;
int64_t j = 0;
while (j + vLen - 1 < block_size) {
svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));
j += vLen;
}
if (j < block_size) {
pg = svwhilelt_b32_s64(j, block_size);
svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));
}
}
}
return pos == index_size;
}
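// Non-templated entry points for the uint8 / int64_t-index kernel.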
bool EmbeddingLookupIdx_int64_t_uint8_t_float_false__sve(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const uint8_t* input,
const int64_t* indices,
const int64_t* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
float* out) {
return EmbeddingLookupIdx_int64_t_uint8_t_float__sve<false>(
block_size,
output_size,
index_size,
data_size,
input,
indices,
offsets,
weights,
scale_bias,
normalize_by_lengths,
out);
}
bool EmbeddingLookupIdx_int64_t_uint8_t_float_true__sve(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const uint8_t* input,
const int64_t* indices,
const int64_t* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
float* out) {
return EmbeddingLookupIdx_int64_t_uint8_t_float__sve<true>(
block_size,
output_size,
index_size,
data_size,
input,
indices,
offsets,
weights,
scale_bias,
normalize_by_lengths,
out);
}
} // namespace caffe2