mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add SVE implementation of embedding_lookup_idx (#133995)
Adds an accelerated version of the embedding_lookup_idx perfkernels. This is done via a python codegen file similarly to `caffe2/perfkernels/hp_emblookup_codegen.py` Pull Request resolved: https://github.com/pytorch/pytorch/pull/133995 Approved by: https://github.com/malfet, https://github.com/huydhn
This commit is contained in:
committed by
PyTorch MergeBot
parent
b09d6f3a7d
commit
e1e6417d4c
@ -10,9 +10,13 @@ endif()
|
||||
file(GLOB common_srcs *.cc)
|
||||
file(GLOB avx_srcs *_avx.cc)
|
||||
file(GLOB avx2_srcs *_avx2.cc)
|
||||
# exclude avx and avx2 srcs from common_srcs
|
||||
file(GLOB avx512_srcs *_avx512.cc)
|
||||
file(GLOB sve_srcs *_sve.cc)
|
||||
# exclude avx, avx2, avx512, and sve srcs from common_srcs
|
||||
exclude(common_srcs "${common_srcs}" ${avx_srcs})
|
||||
exclude(common_srcs "${common_srcs}" ${avx2_srcs})
|
||||
exclude(common_srcs "${common_srcs}" ${avx512_srcs})
|
||||
exclude(common_srcs "${common_srcs}" ${sve_srcs})
|
||||
|
||||
# We will always build common srcs.
|
||||
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${common_srcs})
|
||||
@ -42,6 +46,22 @@ if(CXX_AVX2_FOUND)
|
||||
"Caffe2_perfkernels_avx2_interface")
|
||||
endif()
|
||||
|
||||
# We will only build the SVE perfkernel files if the compiler supports SVE
|
||||
# extensions.
|
||||
if(CXX_SVE_FOUND)
|
||||
add_library(Caffe2_perfkernels_sve STATIC ${sve_srcs})
|
||||
target_link_libraries(Caffe2_perfkernels_sve PRIVATE c10)
|
||||
install(TARGETS Caffe2_perfkernels_sve
|
||||
ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}")
|
||||
|
||||
target_compile_options(Caffe2_perfkernels_sve PRIVATE "-march=armv8-a+sve")
|
||||
|
||||
caffe2_interface_library(
|
||||
Caffe2_perfkernels_sve Caffe2_perfkernels_sve_interface)
|
||||
list(APPEND
|
||||
Caffe2_DEPENDENCY_WHOLE_LINK_LIBS "Caffe2_perfkernels_sve_interface")
|
||||
endif()
|
||||
|
||||
# TODO(jiayq): currently, we only implement the very base files for the
|
||||
# perfkernels. This is because to implement avx and avx2 files, we actually
|
||||
# need to set up different compilation units and this is a bit more involving
|
||||
|
@ -61,9 +61,8 @@ In foo.cc, do:
|
||||
// we use cpuinfo to identify cpu support and run the proper functions.
|
||||
|
||||
#pragma once
|
||||
|
||||
#if defined(CAFFE2_PERF_WITH_AVX512) || defined(CAFFE2_PERF_WITH_AVX2) \
|
||||
|| defined(CAFFE2_PERF_WITH_AVX)
|
||||
#if defined(CAFFE2_PERF_WITH_SVE) || defined(CAFFE2_PERF_WITH_AVX512) || \
|
||||
defined(CAFFE2_PERF_WITH_AVX2) || defined(CAFFE2_PERF_WITH_AVX)
|
||||
#include <cpuinfo.h>
|
||||
#endif
|
||||
|
||||
@ -72,6 +71,18 @@ In foo.cc, do:
|
||||
|
||||
#define BASE_DO(funcname, ...) return funcname##__base(__VA_ARGS__);
|
||||
|
||||
#ifdef CAFFE2_PERF_WITH_SVE
|
||||
#define SVE_DO(funcname, ...) \
|
||||
{ \
|
||||
static const bool isDo = cpuinfo_initialize() && cpuinfo_has_arm_sve(); \
|
||||
if (isDo) { \
|
||||
return funcname##__sve(__VA_ARGS__); \
|
||||
} \
|
||||
}
|
||||
#else // CAFFE2_PERF_WITH_SVE
|
||||
#define SVE_DO(funcname, ...)
|
||||
#endif // CAFFE2_PERF_WITH_SVE
|
||||
|
||||
#ifdef CAFFE2_PERF_WITH_AVX512
|
||||
#define AVX512_DO(funcname, ...) \
|
||||
{ \
|
||||
|
22
caffe2/perfkernels/common_sve.cc
Normal file
22
caffe2/perfkernels/common_sve.cc
Normal file
@ -0,0 +1,22 @@
|
||||
// This file is here merely to check that the flags are not mixed up: for
|
||||
// example, if your compiler did not specify -march=armv8-a+sve, you should not
|
||||
// provide the CAFFE2_PERF_WITH_SVE macro.
|
||||
|
||||
#include "caffe2/core/common.h"
|
||||
|
||||
#ifdef CAFFE2_PERF_WITH_SVE
|
||||
#ifndef __ARM_FEATURE_SVE
|
||||
#error( \
|
||||
"You found a build system error: CAFFE2_PERF_WITH_SVE is defined" \
|
||||
"but __ARM_FEATURE_SVE is not defined (via e.g. -march=armv8-a+sve).");
|
||||
#endif // __ARM_FEATURE_SVE
|
||||
#endif // CAFFE2_PERF_WITH_SVE
|
||||
|
||||
#ifdef __ARM_FEATURE_SVE
|
||||
#ifndef CAFFE2_PERF_WITH_SVE
|
||||
#error( \
|
||||
"You found a build system error: __SVE__ is defined \
|
||||
(via e.g. -march=armv8-a+sve) " \
|
||||
"but CAFFE2_PERF_WITH_SVE is not defined.");
|
||||
#endif // CAFFE2_PERF_WITH_SVE
|
||||
#endif
|
@ -88,7 +88,7 @@ static bool EmbeddingLookupGenericSlowIdx(
|
||||
const int64_t data_size, \
|
||||
const InType* input, \
|
||||
const IndexType* indices, \
|
||||
const IndexType* offsets, \
|
||||
const IndexType* offsets, \
|
||||
const float* weights, \
|
||||
const float* scale_bias, \
|
||||
bool normalize_by_lengths, \
|
||||
@ -113,6 +113,9 @@ static bool EmbeddingLookupGenericSlowIdx(
|
||||
decltype( \
|
||||
EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base) \
|
||||
EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__avx2_fma; \
|
||||
decltype( \
|
||||
EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base) \
|
||||
EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__sve; \
|
||||
bool \
|
||||
EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL( \
|
||||
const int64_t block_size, \
|
||||
@ -121,7 +124,7 @@ static bool EmbeddingLookupGenericSlowIdx(
|
||||
const int64_t data_size, \
|
||||
const InType* input, \
|
||||
const IndexType* indices, \
|
||||
const IndexType* offsets, \
|
||||
const IndexType* offsets, \
|
||||
const float* weights, \
|
||||
const float* scale_bias, \
|
||||
bool normalize_by_lengths, \
|
||||
@ -131,6 +134,19 @@ static bool EmbeddingLookupGenericSlowIdx(
|
||||
} else { \
|
||||
CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr"); \
|
||||
} \
|
||||
SVE_DO( \
|
||||
EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL, \
|
||||
block_size, \
|
||||
output_size, \
|
||||
index_size, \
|
||||
data_size, \
|
||||
input, \
|
||||
indices, \
|
||||
offsets, \
|
||||
weights, \
|
||||
scale_bias, \
|
||||
normalize_by_lengths, \
|
||||
out); \
|
||||
AVX2_FMA_DO( \
|
||||
EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL, \
|
||||
block_size, \
|
||||
@ -166,7 +182,7 @@ static bool EmbeddingLookupGenericSlowIdx(
|
||||
const int64_t data_size, \
|
||||
const InType* input, \
|
||||
const IndexType* indices, \
|
||||
const IndexType* offsets, \
|
||||
const IndexType* offsets, \
|
||||
const float* weights, \
|
||||
const float* scale_bias, \
|
||||
bool normalize_by_lengths, \
|
||||
|
6769
caffe2/perfkernels/embedding_lookup_idx_sve.cc
Normal file
6769
caffe2/perfkernels/embedding_lookup_idx_sve.cc
Normal file
File diff suppressed because it is too large
Load Diff
408
caffe2/perfkernels/sve_emblookup_codegen.py
Normal file
408
caffe2/perfkernels/sve_emblookup_codegen.py
Normal file
@ -0,0 +1,408 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
# Unroll loops when block_size is a multiple of vector length.
|
||||
def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
|
||||
def compute(regid, InType, use_weights):
|
||||
code = []
|
||||
|
||||
if InType == "float":
|
||||
code.append(
|
||||
f" vsum{regid} =\n"
|
||||
" svmad_f32_x("
|
||||
f"svAll, vwgt, svld1_f32(svAll, &ip[{regid} * vLen]),"
|
||||
f" vsum{regid});"
|
||||
)
|
||||
elif InType == "at::Half":
|
||||
code.append(
|
||||
f" vsum{regid} = svmad_f32_x(\n"
|
||||
" svAll,\n"
|
||||
" vwgt,\n"
|
||||
" svcvt_f32_f16_x(\n"
|
||||
" svAll,\n"
|
||||
" svreinterpret_f16_u32(svld1uh_u32(\n"
|
||||
" svAll, reinterpret_cast<const uint16_t*>("
|
||||
f"&ip[{regid} * vLen])))),\n" # noqa
|
||||
f" vsum{regid});"
|
||||
)
|
||||
elif InType == "at::BFloat16":
|
||||
code.append(
|
||||
f" vsum{regid} = svmad_f32_x(\n"
|
||||
" svAll,\n"
|
||||
" vwgt,\n"
|
||||
" svreinterpret_f32_u32(svlsl_n_u32_x(\n"
|
||||
" svAll,\n"
|
||||
" svld1uh_u32(\n"
|
||||
" svAll, reinterpret_cast<const uint16_t*>("
|
||||
f"&ip[{regid} * vLen])),\n"
|
||||
" 16)),\n" # noqa
|
||||
f" vsum{regid});"
|
||||
)
|
||||
elif InType == "uint8_t":
|
||||
code.append(
|
||||
f" vsum{regid} = svmad_f32_x(\n"
|
||||
" svAll,\n"
|
||||
" vwgt,\n"
|
||||
" svcvt_f32_u32_x(svAll,"
|
||||
f" svld1ub_u32(svAll, &ip[{regid} * vLen])),\n" # noqa
|
||||
f" svadd_f32_x(svAll, vsum{regid}, vbio));"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown datatype \"{InType}\"")
|
||||
|
||||
return code
|
||||
|
||||
code = []
|
||||
code.append(f" // unrolling {num_unrolls} times")
|
||||
|
||||
code.append(" for (int64_t i = 0; i < output_size; ++i) {")
|
||||
|
||||
code.append(" " + OutType + "* const op = &out[i * block_size];")
|
||||
code.append(
|
||||
" if (pos != offsets[i] - offsets[0]) {\n"
|
||||
+ " return false;\n"
|
||||
+ " }"
|
||||
)
|
||||
|
||||
# Initialise vector sum registers
|
||||
for i in range(num_unrolls):
|
||||
code.append(f" svfloat32_t vsum{i} = svdup_n_f32(0);")
|
||||
|
||||
# inner loop
|
||||
code.append("""\
|
||||
int64_t start_offset = offsets[i];
|
||||
int64_t end_offset = offsets[i + 1];""")
|
||||
code.append(
|
||||
" for ("
|
||||
+ "int64_t"
|
||||
+ " j = start_offset; j < end_offset; ++j) {" # noqa
|
||||
)
|
||||
|
||||
code.append(" const auto idx = indices[pos];")
|
||||
code.append(
|
||||
" if (idx < 0 || idx >= data_size) {\n"
|
||||
+ " return false;\n"
|
||||
+ " }"
|
||||
)
|
||||
|
||||
if InType == "uint8_t":
|
||||
code.append(" " + OutType + " wgt = 1.f;")
|
||||
code.append(" " + OutType + " bio{};")
|
||||
code.append(" if (weights) {")
|
||||
code.append(
|
||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa
|
||||
)
|
||||
code.append(" }")
|
||||
code.append(" if (scale_bias) {")
|
||||
code.append(" bio = wgt * scale_bias[2 * idx + 1];")
|
||||
code.append(" wgt = wgt * scale_bias[2 * idx];")
|
||||
code.append(" }")
|
||||
code.append(" svfloat32_t vbio = svdup_n_f32(bio);")
|
||||
else:
|
||||
code.append(" " + OutType + " wgt = 1.f;")
|
||||
code.append(" if (weights) {")
|
||||
code.append(
|
||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa
|
||||
)
|
||||
code.append(" }")
|
||||
|
||||
code.append(" const svfloat32_t vwgt = svdup_n_f32(wgt);")
|
||||
code.append(f" const {InType}* const ip = &input[idx * block_size];")
|
||||
code.append(" // weight * input + out")
|
||||
|
||||
for i in range(num_unrolls):
|
||||
code.extend(compute(i, InType, use_weights))
|
||||
|
||||
code.append(" ++pos;")
|
||||
code.append(" }")
|
||||
|
||||
code.append(" // Normalisation")
|
||||
code.append(" const int64_t length = end_offset - start_offset;")
|
||||
code.append(" if (normalize_by_lengths && length != 0) {")
|
||||
code.append(" const float len_inv = 1.0f / length;")
|
||||
code.append(" const svfloat32_t vlen_inv = svdup_n_f32(len_inv);")
|
||||
|
||||
for i in range(num_unrolls):
|
||||
code.append(f" svst1_f32(svAll, &op[{i} * vLen],"
|
||||
+ f" svmul_f32_x(svAll, vsum{i}, vlen_inv));")
|
||||
|
||||
code.append(" } else {")
|
||||
# inv of length
|
||||
for i in range(num_unrolls):
|
||||
code.append(f" svst1_f32(svAll, &op[{i} * vLen], vsum{i});")
|
||||
|
||||
code.append(" }")
|
||||
code.append(" }")
|
||||
return code
|
||||
|
||||
|
||||
# Handle the case where block_size is not a multiple of vector length.
|
||||
def generic(IndexType, InType, OutType, use_weights):
|
||||
def compute(InType, use_weights):
|
||||
code = []
|
||||
if InType == "float":
|
||||
code.append(
|
||||
" svst1_f32(\n"
|
||||
" pg,\n"
|
||||
" &op[k],\n"
|
||||
" svmad_f32_x(\n"
|
||||
" pg, vwgt, svld1_f32(pg, &ip[k]),"
|
||||
" svld1_f32(pg, &op[k])));"
|
||||
)
|
||||
elif InType == "at::Half":
|
||||
code.append(
|
||||
" svst1_f32(\n"
|
||||
" pg,\n"
|
||||
" &op[k],\n"
|
||||
" svmad_f32_x(\n"
|
||||
" pg,\n"
|
||||
" vwgt,\n"
|
||||
" svcvt_f32_f16_x(\n"
|
||||
" pg,\n"
|
||||
" svreinterpret_f16_u32(svld1uh_u32(\n"
|
||||
" pg,"
|
||||
" reinterpret_cast<const uint16_t*>(&ip[k])))),\n"
|
||||
" svld1_f32(pg, &op[k])));"
|
||||
)
|
||||
elif InType == "at::BFloat16":
|
||||
code.append(
|
||||
" svst1_f32(\n"
|
||||
" pg,\n"
|
||||
" &op[k],\n"
|
||||
" svmad_f32_x(\n"
|
||||
" pg,\n"
|
||||
" vwgt,\n"
|
||||
" svreinterpret_f32_u32(svlsl_n_u32_x(\n"
|
||||
" pg,\n"
|
||||
" svld1uh_u32(\n"
|
||||
" pg,"
|
||||
" reinterpret_cast<const uint16_t*>(&ip[k])),\n"
|
||||
" 16)),\n"
|
||||
" svld1_f32(pg, &op[k])));"
|
||||
)
|
||||
elif InType == "uint8_t":
|
||||
code.append(
|
||||
" svst1_f32(\n"
|
||||
" pg,\n"
|
||||
" &op[k],\n"
|
||||
" svmad_f32_x(\n"
|
||||
" pg,\n"
|
||||
" vwgt,\n"
|
||||
" svcvt_f32_u32_x(pg,"
|
||||
" svld1ub_u32(pg, &ip[k])),\n" # noqa
|
||||
" svadd_f32_x(pg,"
|
||||
" svld1_f32(pg, &op[k]), vbio)));"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown datatype \"{InType}\"")
|
||||
|
||||
return code
|
||||
|
||||
code = []
|
||||
|
||||
code.append(
|
||||
" for (int64_t i = 0; i < output_size; ++i) {"
|
||||
)
|
||||
|
||||
code.append(" " + OutType + "* const op = &out[i * block_size];")
|
||||
|
||||
# initialize to 0
|
||||
code.append(" memset(op, 0, sizeof(float) * block_size);")
|
||||
|
||||
# inner loop
|
||||
code.append(
|
||||
" if (pos != offsets[i] - offsets[0]) {\n"
|
||||
+ " return false;\n"
|
||||
+ " }"
|
||||
)
|
||||
code.append(
|
||||
" int64_t start_offset = offsets[i];\n"
|
||||
+ " int64_t end_offset = offsets[i + 1];"
|
||||
)
|
||||
code.append(
|
||||
" for ("
|
||||
+ "int64_t"
|
||||
+ " j = start_offset; j < end_offset; ++j) {" # noqa
|
||||
)
|
||||
|
||||
code.append(" const auto idx = indices[pos];")
|
||||
code.append(
|
||||
" if (idx < 0 || idx >= data_size) {\n"
|
||||
+ " return false;\n"
|
||||
+ " }"
|
||||
)
|
||||
|
||||
if InType == "uint8_t":
|
||||
code.append(" // unimplemented")
|
||||
code.append(" " + OutType + " wgt = 1.f;")
|
||||
code.append(" " + OutType + " bio{};")
|
||||
code.append(" if (weights) {")
|
||||
code.append(
|
||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa
|
||||
)
|
||||
code.append(" }")
|
||||
code.append(" if (scale_bias) {")
|
||||
code.append(" bio = wgt * scale_bias[2 * idx + 1];")
|
||||
code.append(" wgt = wgt * scale_bias[2 * idx];")
|
||||
code.append(" }")
|
||||
code.append(" svfloat32_t vbio = svdup_n_f32(bio);")
|
||||
else:
|
||||
code.append(" " + OutType + " wgt = 1.f;")
|
||||
code.append(" if (weights) {")
|
||||
code.append(
|
||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa
|
||||
)
|
||||
code.append(" }")
|
||||
|
||||
code.append(" const svfloat32_t vwgt = svdup_n_f32(wgt);")
|
||||
code.append(f" const {InType}* ip = &input[idx * block_size];")
|
||||
|
||||
# compute and store main loop
|
||||
code.append(" svbool_t pg;")
|
||||
code.append(" for (int64_t k = 0;")
|
||||
code.append(" svptest_first(svAll, pg = svwhilelt_b32_s64("
|
||||
+ "k, block_size));")
|
||||
code.append(" k += vLen) {")
|
||||
code.extend(compute(InType, use_weights))
|
||||
code.append(" }\n")
|
||||
code.append(" ++pos;")
|
||||
code.append(" }")
|
||||
|
||||
code.append(" const int64_t length = end_offset - start_offset;\n")
|
||||
code.append(" if (normalize_by_lengths && length != 0) {")
|
||||
code.append(" const float len_inv = 1.0f / length;")
|
||||
code.append(" svfloat32_t vlen_inv = svdup_n_f32(len_inv);")
|
||||
code.append(" svbool_t pg;")
|
||||
code.append(" for (int64_t j = 0;\n"
|
||||
" svptest_first(svAll, pg = svwhilelt_b32_s64("
|
||||
"j, block_size));")
|
||||
code.append(" j += vLen) {")
|
||||
code.append(
|
||||
" svst1_f32(\n"
|
||||
" pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv));"
|
||||
)
|
||||
code.append(" }")
|
||||
code.append(" }")
|
||||
code.append(" }")
|
||||
return code
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-f", "--filename", help="file name")
|
||||
opts = parser.parse_args()
|
||||
if opts.filename:
|
||||
filename = opts.filename
|
||||
else:
|
||||
filename = "embedding_lookup_idx_sve.cc"
|
||||
|
||||
options = [
|
||||
["int32_t", "int32_t", "float", "float", "float", "float"],
|
||||
["int64_t", "int64_t", "float", "float", "float", "float"],
|
||||
["int32_t", "int32_t", "half", "at::Half", "float", "float"],
|
||||
["int64_t", "int64_t", "half", "at::Half", "float", "float"],
|
||||
["int32_t", "int32_t", "bfloat16", "at::BFloat16", "float", "float"],
|
||||
["int64_t", "int64_t", "bfloat16", "at::BFloat16", "float", "float"],
|
||||
["int32_t", "int32_t", "uint8_t", "uint8_t", "float", "float"],
|
||||
["int64_t", "int64_t", "uint8_t", "uint8_t", "float", "float"],
|
||||
]
|
||||
|
||||
code = []
|
||||
# includes
|
||||
code.append("//// --------------------------")
|
||||
code.append("//// ATTENTION:")
|
||||
code.append("//// THIS CODE IS AUTOGENERATED")
|
||||
code.append(f"//// BY {' '.join(sys.argv)}")
|
||||
code.append("//// DO NOT MODIFY!!!")
|
||||
code.append("//// --------------------------\n")
|
||||
|
||||
code.append("#include <arm_sve.h>")
|
||||
code.append("#include <c10/util/BFloat16.h>")
|
||||
code.append("#include <c10/util/Half.h>")
|
||||
code.append("#include <cstdint>")
|
||||
code.append("#include <cstring>")
|
||||
|
||||
code.append("namespace caffe2 {\n")
|
||||
for o in options:
|
||||
[IndexTypeName, IndexType, InTypeName, InType, OutTypeName, OutType] = o
|
||||
|
||||
code.append("template <bool IS_WEIGHT_POSITIONAL>")
|
||||
fn_base = f"EmbeddingLookupIdx_{IndexTypeName}_{InTypeName}_{OutTypeName}"
|
||||
|
||||
suffix = "__sve"
|
||||
fn = "static bool " + fn_base + suffix
|
||||
code.append(fn + "(")
|
||||
|
||||
args = []
|
||||
args.append(" const int64_t block_size,")
|
||||
args.append(" const int64_t output_size,")
|
||||
args.append(" const int64_t index_size,")
|
||||
args.append(" const int64_t data_size,")
|
||||
args.append(" const " + InType + "* input,")
|
||||
args.append(" const " + IndexType + "* indices,")
|
||||
args.append(" const " + IndexType + "* offsets,")
|
||||
args.append(" const float* weights,")
|
||||
args.append(" const float* scale_bias,")
|
||||
args.append(" bool normalize_by_lengths,")
|
||||
args.append(" " + OutType + "* out) {")
|
||||
code += args
|
||||
|
||||
code.append(" const svbool_t svAll = svptrue_b32();")
|
||||
code.append(" const auto vLen = static_cast<int64_t>(svcntw());")
|
||||
code.append(" int64_t pos = 0;")
|
||||
|
||||
code.append(" if (block_size == 32 * vLen) {")
|
||||
code += unroll(32, IndexType, InType, OutType, True)
|
||||
code.append(" } else if (block_size == 16 * vLen) {")
|
||||
code += unroll(16, IndexType, InType, OutType, True)
|
||||
code.append(" } else if (block_size == 8 * vLen) {")
|
||||
code += unroll(8, IndexType, InType, OutType, True)
|
||||
code.append(" } else if (block_size == 4 * vLen) {")
|
||||
code += unroll(4, IndexType, InType, OutType, True)
|
||||
code.append(" } else if (block_size == 2 * vLen) {")
|
||||
code += unroll(2, IndexType, InType, OutType, True)
|
||||
code.append(" } else {")
|
||||
code.append(" // generic code:")
|
||||
code += generic(IndexType, InType, OutType, True)
|
||||
code.append(" }")
|
||||
code.append(" return pos == index_size;")
|
||||
|
||||
code.append("}")
|
||||
|
||||
for is_weight_positional in ["false", "true"]:
|
||||
code.append("bool " + fn_base + "_" + is_weight_positional + suffix + "(")
|
||||
code += args
|
||||
|
||||
# Resolve the Lint warnings: Limit of 80 characters in one line.
|
||||
extra_space = "\n "
|
||||
ret_string = " return " + fn_base + suffix \
|
||||
+ "<" + is_weight_positional + ">("
|
||||
if len(ret_string) <= 80:
|
||||
code.append(ret_string)
|
||||
else:
|
||||
code.append(" return " + fn_base + suffix + "<" + extra_space + is_weight_positional + ">(")
|
||||
|
||||
code.append(" block_size,")
|
||||
code.append(" output_size,")
|
||||
code.append(" index_size,")
|
||||
code.append(" data_size,")
|
||||
code.append(" input,")
|
||||
code.append(" indices,")
|
||||
code.append(" offsets,")
|
||||
code.append(" weights,")
|
||||
code.append(" scale_bias,")
|
||||
code.append(" normalize_by_lengths,")
|
||||
code.append(" out);")
|
||||
code.append("}")
|
||||
|
||||
code.append("")
|
||||
|
||||
code.append("} // namespace caffe2")
|
||||
|
||||
with open(filename, "w") as fout:
|
||||
fout.write("\n".join(code) + "\n")
|
||||
|
||||
print("Created " + filename)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -101,6 +101,16 @@ endif()
|
||||
# Also, we will turn off deprecated-declarations
|
||||
# due to protobuf.
|
||||
|
||||
# ---[ Check if the compiler has SVE support.
|
||||
find_package(ARM) # checks SVE
|
||||
if(CXX_SVE_FOUND)
|
||||
message(STATUS "Compiler supports SVE extension. Will build perfkernels.")
|
||||
# Also see CMakeLists.txt under caffe2/perfkernels.
|
||||
add_compile_definitions(CAFFE2_PERF_WITH_SVE=1)
|
||||
else()
|
||||
message(STATUS "Compiler does not support SVE extension. Will not build perfkernels.")
|
||||
endif()
|
||||
|
||||
if(IOS AND (${IOS_ARCH} MATCHES "armv7*"))
|
||||
add_definitions("-mfpu=neon-fp16")
|
||||
add_definitions("-arch" ${IOS_ARCH})
|
||||
|
@ -21,10 +21,10 @@ if("${ACL_VERSION_FILE}" STREQUAL "")
|
||||
message(WARNING "Build may fail: Could not determine ACL version (minimum required is ${ACL_MINIMUM_VERSION})")
|
||||
else()
|
||||
file(READ ${ACL_VERSION_FILE} ACL_VERSION_STRING)
|
||||
string(REGEX MATCH "v([0-9]+\\.[0-9]+)" ACL_VERSION ${ACL_VERSION_STRING})
|
||||
string(REGEX MATCH "v([0-9]+\\.[0-9]+)" ACL_VERSION "${ACL_VERSION_STRING}")
|
||||
set(ACL_VERSION "${CMAKE_MATCH_1}")
|
||||
|
||||
if(${ACL_VERSION} VERSION_EQUAL "0.0")
|
||||
if("${ACL_VERSION}" VERSION_EQUAL "0.0")
|
||||
# Unreleased ACL versions come with version string "v0.0-unreleased", and may not be compatible with oneDNN.
|
||||
# It is recommended to use the latest release of ACL.
|
||||
message(WARNING "Build may fail: Using unreleased ACL version (minimum required is ${ACL_MINIMUM_VERSION})")
|
||||
|
Reference in New Issue
Block a user