Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
Optimize SVE embedding performance (#150176)
Change the loop unrolling strategy. Previously, the script unrolled only the inner loop over block_size, and only when the block size was a multiple of the vector length. This version instead unrolls the outer loop over the indices of each bag, which reduces the number of loads and stores needed to accumulate into the output array and improves performance when the block size is not a multiple of the vector length.

Benchmarking script:

```python
# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate <open-source-office@arm.com>
# SPDX-License-Identifier: BSD-3-Clause

import torch
import torch.nn as nn
import numpy as np
import time
import sys

np.random.seed(0)
torch.manual_seed(0)

num_embeddings = 400000
embedding_dim = int(sys.argv[1])
multi_hot = 100
batch_size = 400
nrun = 1000


class SimpleEmbeddingBagModel(nn.Module):
    def __init__(self, num_embeddings, embedding_dim):
        super(SimpleEmbeddingBagModel, self).__init__()

        weights = torch.from_numpy((np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype(np.float32)).to(torch.float16)

        # Defining the EmbeddingBag layer
        self.embedding_bag = torch.nn.EmbeddingBag(
            num_embeddings, embedding_dim, _weight=weights,
            mode='sum', include_last_offset=True, dtype=torch.float32)

    def forward(self, input, offsets):
        # Forward pass through the EmbeddingBag layer
        result32 = self.embedding_bag(input, offsets, per_sample_weights=None)
        return result32


# Instantiate the model
model = SimpleEmbeddingBagModel(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
model.eval()

# Example input
input_tensor = torch.randint(0, num_embeddings, (batch_size * multi_hot,), dtype=torch.long)
offsets = torch.tensor(range(0, batch_size * multi_hot + 1, multi_hot))

with torch.no_grad():
    # warm up
    output32 = model(input_tensor, offsets)

    ti = time.time_ns()
    for i in range(nrun):
        _ = model(input_tensor, offsets)
    tf = time.time_ns()
    print("{:3d} {:.3E}".format(embedding_dim, (tf-ti)/nrun/1.e6))
```

Speedup on NEOVERSE V1 with 1 thread:

![image](attachments/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150176
Approved by: https://github.com/digantdesai, https://github.com/malfet
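To make the strategy change concrete, below is a minimal scalar C++ sketch of the two accumulation patterns. It is an illustration only: the generated kernels use SVE intrinsics with predicated tail handling, and the function names, row-pointer layout, and the unroll factor of 4 are chosen for this example rather than taken from the generated code. (The benchmarking script above is invoked with the embedding dimension as its only argument and prints that dimension together with the average time per forward pass in milliseconds.)

```cpp
// Scalar sketch (assumed simplification, no SVE intrinsics) contrasting the
// two unrolling strategies for one output row `op` of length `block_size`,
// accumulating `n` weighted input rows `ip[0..n-1]`.
#include <cstdint>

// Before: the loop over input rows stays outside; every element of `op` is
// re-loaded and re-stored once per input row.
void accumulate_inner_unroll(float* op, const float* const* ip,
                             const float* wgt, int64_t n, int64_t block_size) {
  for (int64_t j = 0; j < n; ++j) {            // one pass over `op` per row
    for (int64_t k = 0; k < block_size; ++k) {
      op[k] += wgt[j] * ip[j][k];              // load op[k], fma, store op[k]
    }
  }
}

// After: the outer loop over input rows is unrolled (here by 4); each element
// of `op` is loaded once, receives contributions from 4 rows, and is stored
// once, cutting output load/store traffic roughly by the unroll factor.
void accumulate_outer_unroll(float* op, const float* const* ip,
                             const float* wgt, int64_t n, int64_t block_size) {
  int64_t j = 0;
  for (; j + 3 < n; j += 4) {
    for (int64_t k = 0; k < block_size; ++k) {
      float out = op[k];                       // single load of the output
      out += wgt[j + 0] * ip[j + 0][k];
      out += wgt[j + 1] * ip[j + 1][k];
      out += wgt[j + 2] * ip[j + 2][k];
      out += wgt[j + 3] * ip[j + 3][k];
      op[k] = out;                             // single store of the output
    }
  }
  for (; j < n; ++j) {                         // tail: remaining rows
    for (int64_t k = 0; k < block_size; ++k) {
      op[k] += wgt[j] * ip[j][k];
    }
  }
}
```

With the outer loop unrolled, each chunk of the output row is loaded and stored once per group of unrolled indices instead of once per index, which is where the reduction in load/store traffic comes from regardless of whether block_size is a multiple of the vector length.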
This commit is contained in:
committed by PyTorch MergeBot
parent 7d2411d30e
commit 6fcffd8cd1
@@ -4,197 +4,32 @@ import sys

# Unroll loops when block_size is a multiple of vector length.
def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
def compute(regid, InType, use_weights):
def unroll(num_unrolls, IndexType, InType, OutType):
def compute_output(num_unrolls, InType, is_main):
code = []

pred = "svAll" if is_main else "pg"
if InType == "float":
code.append(
f" vsum{regid} =\n"
" svmad_f32_x("
f"svAll, vwgt, svld1_f32(svAll, &ip[{regid} * vLen]),"
f" vsum{regid});"
)
for i in range(num_unrolls):
code.append(f" output = svmla_x({pred}, output, svld1(svAll, &ip{i}[k]), wgt{i});")
elif InType == "at::Half":
code.append(
f" vsum{regid} = svmad_f32_x(\n"
" svAll,\n"
" vwgt,\n"
" svcvt_f32_f16_x(\n"
" svAll,\n"
" svreinterpret_f16_u32(svld1uh_u32(\n"
" svAll, reinterpret_cast<const uint16_t*>("
f"&ip[{regid} * vLen])))),\n"
f" vsum{regid});"
)
for i in range(num_unrolls):
code.append(f" auto input{i} = svcvt_f32_x({pred}, svreinterpret_f16(\n"
f" svld1uh_u32({pred}, reinterpret_cast<const uint16_t*>(&ip{i}[k]))));")
for i in range(num_unrolls):
code.append(f" output = svmla_x({pred}, output, input{i}, wgt{i});")
elif InType == "at::BFloat16":
code.append(
f" vsum{regid} = svmad_f32_x(\n"
" svAll,\n"
" vwgt,\n"
" svreinterpret_f32_u32(svlsl_n_u32_x(\n"
" svAll,\n"
" svld1uh_u32(\n"
" svAll, reinterpret_cast<const uint16_t*>("
f"&ip[{regid} * vLen])),\n"
" 16)),\n"
f" vsum{regid});"
)
for i in range(num_unrolls):
code.append(f" auto input{i} = svreinterpret_f32(svlsl_x({pred},\n"
f" svld1uh_u32({pred}, reinterpret_cast<const uint16_t*>(&ip{i}[k])), 16));")
for i in range(num_unrolls):
code.append(f" output = svmla_x({pred}, output, input{i}, wgt{i});")
elif InType == "uint8_t":
code.append(
f" vsum{regid} = svmad_f32_x(\n"
" svAll,\n"
" vwgt,\n"
" svcvt_f32_u32_x(svAll,"
f" svld1ub_u32(svAll, &ip[{regid} * vLen])),\n"
f" svadd_f32_x(svAll, vsum{regid}, vbio));"
)
else:
raise ValueError(f'Unknown datatype "{InType}"')

return code

code = []
code.append(f" // unrolling {num_unrolls} times")

code.append(" for (int64_t i = 0; i < output_size; ++i) {")

code.append(" " + OutType + "* const op = &out[i * block_size];")
code.append(
" if (pos != offsets[i] - offsets[0]) {\n"
+ " return false;\n"
+ " }"
)

# Initialise vector sum registers
for i in range(num_unrolls):
code.append(f" svfloat32_t vsum{i} = svdup_n_f32(0);")

# inner loop
code.append("""\
int64_t start_offset = offsets[i];
int64_t end_offset = offsets[i + 1];""")
code.append(
" for (" + "int64_t" + " j = start_offset; j < end_offset; ++j) {"
)

code.append(" const auto idx = indices[pos];")
code.append(
" if (idx < 0 || idx >= data_size) {\n"
+ " return false;\n"
+ " }"
)

if InType == "uint8_t":
code.append(" " + OutType + " wgt = 1.f;")
code.append(" " + OutType + " bio{};")
code.append(" if (weights) {")
code.append(
" wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];"
)
code.append(" }")
code.append(" if (scale_bias) {")
code.append(" bio = wgt * scale_bias[2 * idx + 1];")
code.append(" wgt = wgt * scale_bias[2 * idx];")
code.append(" }")
code.append(" svfloat32_t vbio = svdup_n_f32(bio);")
else:
code.append(" " + OutType + " wgt = 1.f;")
code.append(" if (weights) {")
code.append(
" wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];"
)
code.append(" }")

code.append(" const svfloat32_t vwgt = svdup_n_f32(wgt);")
code.append(f" const {InType}* const ip = &input[idx * block_size];")
code.append(" // weight * input + out")

for i in range(num_unrolls):
code.extend(compute(i, InType, use_weights))

code.append(" ++pos;")
code.append(" }")

code.append(" // Normalisation")
code.append(" const int64_t length = end_offset - start_offset;")
code.append(" if (normalize_by_lengths && length != 0) {")
code.append(" const float len_inv = 1.0f / length;")
code.append(" const svfloat32_t vlen_inv = svdup_n_f32(len_inv);")

for i in range(num_unrolls):
code.append(
f" svst1_f32(svAll, &op[{i} * vLen],"
+ f" svmul_f32_x(svAll, vsum{i}, vlen_inv));"
)

code.append(" } else {")
# inv of length
for i in range(num_unrolls):
code.append(f" svst1_f32(svAll, &op[{i} * vLen], vsum{i});")

code.append(" }")
code.append(" }")
return code


# Handle the case where block_size is not a multiple of vector length.
def generic(IndexType, InType, OutType, use_weights):
def compute(InType, use_weights):
code = []
if InType == "float":
code.append(
" svst1_f32(\n"
" pg,\n"
" &op[k],\n"
" svmad_f32_x(\n"
" pg, vwgt, svld1_f32(pg, &ip[k]),"
" svld1_f32(pg, &op[k])));"
)
elif InType == "at::Half":
code.append(
" svst1_f32(\n"
" pg,\n"
" &op[k],\n"
" svmad_f32_x(\n"
" pg,\n"
" vwgt,\n"
" svcvt_f32_f16_x(\n"
" pg,\n"
" svreinterpret_f16_u32(svld1uh_u32(\n"
" pg,"
" reinterpret_cast<const uint16_t*>(&ip[k])))),\n"
" svld1_f32(pg, &op[k])));"
)
elif InType == "at::BFloat16":
code.append(
" svst1_f32(\n"
" pg,\n"
" &op[k],\n"
" svmad_f32_x(\n"
" pg,\n"
" vwgt,\n"
" svreinterpret_f32_u32(svlsl_n_u32_x(\n"
" pg,\n"
" svld1uh_u32(\n"
" pg,"
" reinterpret_cast<const uint16_t*>(&ip[k])),\n"
" 16)),\n"
" svld1_f32(pg, &op[k])));"
)
elif InType == "uint8_t":
code.append(
" svst1_f32(\n"
" pg,\n"
" &op[k],\n"
" svmad_f32_x(\n"
" pg,\n"
" vwgt,\n"
" svcvt_f32_u32_x(pg,"
" svld1ub_u32(pg, &ip[k])),\n"
" svadd_f32_x(pg,"
" svld1_f32(pg, &op[k]), vbio)));"
)
code.append(f" output = svadd_x({pred}, output, bio);")
for i in range(num_unrolls):
code.append(f" auto input{i} = svcvt_f32_x({pred}, svld1ub_u32({pred}, &ip{i}[k]));")
for i in range(num_unrolls):
code.append(f" output = svmla_x({pred}, output, input{i}, wgt{i});")
else:
raise ValueError(f'Unknown datatype "{InType}"')

@@ -202,91 +37,72 @@ def generic(IndexType, InType, OutType, use_weights):

code = []

code.append(" for (int64_t i = 0; i < output_size; ++i) {")
if num_unrolls == 1:
code.append(f" // tail loop")
code.append(" if (j < end_offset) {")
else:
code.append(f" // unrolling {num_unrolls} times")
code.append(f" while (j + {num_unrolls - 1} < end_offset) {{")
for i in range(num_unrolls):
code.append(f" const auto idx{i} = indices[pos + {i}];")

code.append(" " + OutType + "* const op = &out[i * block_size];")

# initialize to 0
code.append(" memset(op, 0, sizeof(float) * block_size);")

# inner loop
code.append(
" if (pos != offsets[i] - offsets[0]) {\n"
+ " return false;\n"
+ " }"
)
code.append(
" int64_t start_offset = offsets[i];\n"
+ " int64_t end_offset = offsets[i + 1];"
)
code.append(
" for (" + "int64_t" + " j = start_offset; j < end_offset; ++j) {"
)

code.append(" const auto idx = indices[pos];")
code.append(
" if (idx < 0 || idx >= data_size) {\n"
+ " return false;\n"
+ " }"
)
# check indices
for i in range(num_unrolls):
code.append(
f" if (idx{i} < 0 || idx{i} >= data_size) {{\n"
+ " return false;\n"
+ " }"
)

if InType == "uint8_t":
code.append(" // unimplemented")
code.append(" " + OutType + " wgt = 1.f;")
code.append(" " + OutType + " bio{};")
code.append(" if (weights) {")
code.append(
" wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];"
)
code.append(" }")
code.append(" if (scale_bias) {")
code.append(" bio = wgt * scale_bias[2 * idx + 1];")
code.append(" wgt = wgt * scale_bias[2 * idx];")
code.append(" }")
code.append(" svfloat32_t vbio = svdup_n_f32(bio);")
for i in range(num_unrolls):
code.append(f" {OutType} wgt{i} = 1.f;")
code.append(f" {OutType} bio = 0.f;")
else:
code.append(" " + OutType + " wgt = 1.f;")
code.append(" if (weights) {")
code.append(
" wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];"
)
code.append(" }")
for i in range(num_unrolls):
code.append(f" {OutType} wgt{i} = 1.f;")

code.append(" const svfloat32_t vwgt = svdup_n_f32(wgt);")
code.append(f" const {InType}* ip = &input[idx * block_size];")

# compute and store main loop
code.append(" svbool_t pg;")
code.append(" for (int64_t k = 0;")
code.append(
" svptest_first(svAll, pg = svwhilelt_b32_s64(" + "k, block_size));"
)
code.append(" k += vLen) {")
code.extend(compute(InType, use_weights))
code.append(" }\n")
code.append(" ++pos;")
code.append(" if (weights) {")
for i in range(num_unrolls):
code.append(f" wgt{i} = weights[IS_WEIGHT_POSITIONAL ? (j + {i} - start_offset) : pos + {i}];")
code.append(" }")
if InType == "uint8_t":
code.append(" if (scale_bias) {")
for i in range(num_unrolls):
code.append(f" bio += wgt{i} * scale_bias[2 * idx{i} + 1];")
code.append(f" wgt{i} = wgt{i} * scale_bias[2 * idx{i}];")
code.append(" }")

code.append(" const int64_t length = end_offset - start_offset;\n")
code.append(" if (normalize_by_lengths && length != 0) {")
code.append(" const float len_inv = 1.0f / length;")
code.append(" svfloat32_t vlen_inv = svdup_n_f32(len_inv);")
code.append(" svbool_t pg;")
code.append(
" for (int64_t j = 0;\n"
" svptest_first(svAll, pg = svwhilelt_b32_s64("
"j, block_size));"
)
code.append(" j += vLen) {")
code.append(
" svst1_f32(\n"
" pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv));"
)
code.append(" }")
for i in range(num_unrolls):
code.append(f" const {InType}* const ip{i} = &input[idx{i} * block_size];")

# compute and store
code.append(" svbool_t pg;")
code.append(" int64_t k = 0;")
# main loop
code.append(" while (k + vLen - 1 < block_size) {")
code.append(" auto output = svld1(svAll, &op[k]);")
code.extend(compute_output(num_unrolls, InType, True))
code.append(" svst1(svAll, &op[k], output);")
code.append(" k += vLen;")
code.append(" }")
# tail loop
code.append(" if (k < block_size) {")
code.append(" pg = svwhilelt_b32_s64(k, block_size);")
code.append(" auto output = svld1(pg, &op[k]);")
code.extend(compute_output(num_unrolls, InType, False))
code.append(" svst1(pg, &op[k], output);")
code.append(" k += vLen;")
code.append(" }")
if num_unrolls == 1:
code.append(" pos ++;")
else:
code.append(f" j += {num_unrolls};")
code.append(f" pos += {num_unrolls};")

code.append(" }")
return code

return code

def main():
parser = argparse.ArgumentParser()
@@ -352,22 +168,47 @@ def main():
code.append(" const auto vLen = static_cast<int64_t>(svcntw());")
code.append(" int64_t pos = 0;")

code.append(" if (block_size == 32 * vLen) {")
code += unroll(32, IndexType, InType, OutType, True)
code.append(" } else if (block_size == 16 * vLen) {")
code += unroll(16, IndexType, InType, OutType, True)
code.append(" } else if (block_size == 8 * vLen) {")
code += unroll(8, IndexType, InType, OutType, True)
code.append(" } else if (block_size == 4 * vLen) {")
code += unroll(4, IndexType, InType, OutType, True)
code.append(" } else if (block_size == 2 * vLen) {")
code += unroll(2, IndexType, InType, OutType, True)
code.append(" } else {")
code.append(" // generic code:")
code += generic(IndexType, InType, OutType, True)
code.append(" for (int64_t i = 0; i < output_size; ++i) {")
code.append(" " + OutType + "* const op = &out[i * block_size];")

# initialize to 0
code.append(" memset(op, 0, sizeof(float) * block_size);")

# inner loop
code.append(
" if (pos != offsets[i] - offsets[0]) {\n"
+ " return false;\n"
+ " }"
)
code.append(
" int64_t start_offset = offsets[i];\n"
+ " int64_t end_offset = offsets[i + 1];"
)
code.append(" int64_t j = start_offset;")

code += unroll(16, IndexType, InType, OutType)
code += unroll(8, IndexType, InType, OutType)
code += unroll(4, IndexType, InType, OutType)
code += unroll(2, IndexType, InType, OutType)
code += unroll(1, IndexType, InType, OutType)

code.append(" const int64_t length = end_offset - start_offset;\n")
code.append(" if (normalize_by_lengths && length != 0) {")
code.append(" const float len_inv = 1.0f / length;")
code.append(" svbool_t pg;")
code.append(" int64_t j = 0;")
code.append(" while (j + vLen - 1 < block_size) {")
code.append(" svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));")
code.append(" j += vLen;")
code.append(" }")
code.append(" if (j < block_size) {")
code.append(" pg = svwhilelt_b32_s64(j, block_size);")
code.append(" svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));")
code.append(" }")
code.append(" }")

code.append(" }")
code.append(" return pos == index_size;")

code.append("}")

for is_weight_positional in ["false", "true"]: