# mypy: allow-untyped-defs
"""Generate the ARM SVE ``EmbeddingLookupIdx`` kernels for caffe2.

Running the script writes a C++ source file (``embedding_lookup_idx_sve.cc``
unless ``-f/--filename`` is given). For every row of ``KERNEL_OPTIONS`` it
emits one ``static`` kernel templated on ``IS_WEIGHT_POSITIONAL`` plus two
non-template wrappers (``..._false__sve`` / ``..._true__sve``) that
instantiate it.
"""
import argparse
import sys


# One row per generated kernel family:
#   [IndexTypeName, IndexType, InTypeName, InType, OutTypeName, OutType]
# The *Name columns only appear in the generated function names; the other
# columns are the C++ types spliced into the kernel signature and body.
KERNEL_OPTIONS = [
    ["int32_t", "int32_t", "float", "float", "float", "float"],
    ["int64_t", "int64_t", "float", "float", "float", "float"],
    ["int32_t", "int32_t", "half", "at::Half", "float", "float"],
    ["int64_t", "int64_t", "half", "at::Half", "float", "float"],
    ["int32_t", "int32_t", "bfloat16", "at::BFloat16", "float", "float"],
    ["int64_t", "int64_t", "bfloat16", "at::BFloat16", "float", "float"],
    ["int32_t", "int32_t", "uint8_t", "uint8_t", "float", "float"],
    ["int64_t", "int64_t", "uint8_t", "uint8_t", "float", "float"],
]


def _compute_output(num_unrolls, InType, is_main):
    """Return C++ lines that accumulate ``num_unrolls`` input rows into ``output``.

    Args:
        num_unrolls: number of rows (``ip0`` .. ``ip{num_unrolls-1}``) combined
            in one step.
        InType: C++ element type of ``input``; selects the load/convert sequence.
        is_main: ``True`` for the full-vector loop (predicate ``svAll``),
            ``False`` for the predicated tail (predicate ``pg``).

    Raises:
        ValueError: if ``InType`` is not one of the supported types.
    """
    pred = "svAll" if is_main else "pg"

    if InType == "float":
        # Load with the same predicate as the accumulation: an all-true load in
        # the tail would read past the end of the row (and, for the last row,
        # past the end of ``input``).
        return [
            f" output = svmla_x({pred}, output, svld1({pred}, &ip{i}[k]), wgt{i});"
            for i in range(num_unrolls)
        ]

    if InType == "at::Half":
        # 16-bit values are loaded into 32-bit lanes, reinterpreted as f16 and
        # converted to f32.
        code = [
            f" auto input{i} = svcvt_f32_x({pred}, svreinterpret_f16(\n"
            f" svld1uh_u32({pred}, reinterpret_cast<const uint16_t*>(&ip{i}[k]))));"
            for i in range(num_unrolls)
        ]
    elif InType == "at::BFloat16":
        # bf16 bits shifted left by 16 form the upper half of an f32.
        code = [
            f" auto input{i} = svreinterpret_f32(svlsl_x({pred},\n"
            f" svld1uh_u32({pred}, reinterpret_cast<const uint16_t*>(&ip{i}[k])), 16));"
            for i in range(num_unrolls)
        ]
    elif InType == "uint8_t":
        # The accumulated per-row bias (``bio``) is added once per vector chunk.
        code = [f" output = svadd_x({pred}, output, bio);"]
        code += [
            f" auto input{i} = svcvt_f32_x({pred}, svld1ub_u32({pred}, &ip{i}[k]));"
            for i in range(num_unrolls)
        ]
    else:
        raise ValueError(f'Unknown datatype "{InType}"')

    code += [
        f" output = svmla_x({pred}, output, input{i}, wgt{i});"
        for i in range(num_unrolls)
    ]
    return code


def unroll(num_unrolls, IndexType, InType, OutType):
    """Return C++ lines for one pass that consumes ``num_unrolls`` indices at a time.

    For ``num_unrolls > 1`` the pass is a ``while`` loop; for ``num_unrolls == 1``
    it is a single ``if`` that handles the one index left over by the larger
    passes. Each step:

    * bounds-checks every index against ``data_size`` (``return false`` if out
      of range);
    * loads the per-index weight (positional or per-lookup, depending on
      ``IS_WEIGHT_POSITIONAL``);
    * for ``uint8_t`` input, folds ``scale_bias[2 * idx]`` into the weight and
      accumulates ``weight * scale_bias[2 * idx + 1]`` into ``bio``;
    * accumulates the rows into ``op`` with a full-vector loop followed by a
      predicated tail, so ``block_size`` need not be a multiple of the vector
      length.

    The emitted code refers to names declared by the enclosing generated kernel
    (``j``, ``pos``, ``start_offset``, ``end_offset``, ``indices``, ``weights``,
    ``scale_bias``, ``input``, ``op``, ``block_size``, ``vLen``, ``svAll``, ...).

    Args:
        num_unrolls: indices processed per step.
        IndexType: unused; kept for signature compatibility.
        InType: C++ element type of ``input``.
        OutType: C++ element type of the output/weights.

    Raises:
        ValueError: if ``InType`` is not supported.
    """
    code = []
    if num_unrolls == 1:
        code.append(" // tail loop")
        code.append(" if (j < end_offset) {")
    else:
        code.append(f" // unrolling {num_unrolls} times")
        code.append(f" while (j + {num_unrolls - 1} < end_offset) {{")
    for i in range(num_unrolls):
        code.append(f" const auto idx{i} = indices[pos + {i}];")

    # check indices
    for i in range(num_unrolls):
        code.append(
            f" if (idx{i} < 0 || idx{i} >= data_size) {{\n"
            + " return false;\n"
            + " }"
        )

    for i in range(num_unrolls):
        code.append(f" {OutType} wgt{i} = 1.f;")
    if InType == "uint8_t":
        code.append(f" {OutType} bio = 0.f;")

    code.append(" if (weights) {")
    for i in range(num_unrolls):
        code.append(
            f" wgt{i} = weights[IS_WEIGHT_POSITIONAL ? (j + {i} - start_offset) : pos + {i}];"
        )
    code.append(" }")
    if InType == "uint8_t":
        code.append(" if (scale_bias) {")
        for i in range(num_unrolls):
            # Bias uses the unscaled weight, so accumulate it before scaling.
            code.append(f" bio += wgt{i} * scale_bias[2 * idx{i} + 1];")
            code.append(f" wgt{i} = wgt{i} * scale_bias[2 * idx{i}];")
        code.append(" }")
    for i in range(num_unrolls):
        code.append(f" const {InType}* const ip{i} = &input[idx{i} * block_size];")

    # compute and store
    code.append(" svbool_t pg;")
    code.append(" int64_t k = 0;")
    # main loop: whole vectors
    code.append(" while (k + vLen - 1 < block_size) {")
    code.append(" auto output = svld1(svAll, &op[k]);")
    code.extend(_compute_output(num_unrolls, InType, True))
    code.append(" svst1(svAll, &op[k], output);")
    code.append(" k += vLen;")
    code.append(" }")
    # tail loop: remaining block_size % vLen lanes under predicate pg
    code.append(" if (k < block_size) {")
    code.append(" pg = svwhilelt_b32_s64(k, block_size);")
    code.append(" auto output = svld1(pg, &op[k]);")
    code.extend(_compute_output(num_unrolls, InType, False))
    code.append(" svst1(pg, &op[k], output);")
    code.append(" k += vLen;")
    code.append(" }")

    if num_unrolls == 1:
        # Final single-index pass: j is not read again afterwards.
        code.append(" pos ++;")
    else:
        code.append(f" j += {num_unrolls};")
        code.append(f" pos += {num_unrolls};")
    code.append(" }")
    return code


def _kernel_lines(IndexTypeName, IndexType, InTypeName, InType, OutTypeName, OutType):
    """Return the templated kernel plus its two non-template wrappers."""
    code = []
    code.append("template <bool IS_WEIGHT_POSITIONAL>")
    fn_base = f"EmbeddingLookupIdx_{IndexTypeName}_{InTypeName}_{OutTypeName}"
    suffix = "__sve"
    fn = "static bool " + fn_base + suffix
    code.append(fn + "(")

    args = [
        " const int64_t block_size,",
        " const int64_t output_size,",
        " const int64_t index_size,",
        " const int64_t data_size,",
        " const " + InType + "* input,",
        " const " + IndexType + "* indices,",
        " const " + IndexType + "* offsets,",
        " const float* weights,",
        " const float* scale_bias,",
        " bool normalize_by_lengths,",
        " " + OutType + "* out) {",
    ]
    code += args

    code.append(" const svbool_t svAll = svptrue_b32();")
    code.append(" const auto vLen = static_cast<int64_t>(svcntw());")
    code.append(" int64_t pos = 0;")
    code.append(" for (int64_t i = 0; i < output_size; ++i) {")
    code.append(" " + OutType + "* const op = &out[i * block_size];")
    # initialize to 0 (sized by the element type actually stored in op)
    code.append(f" memset(op, 0, sizeof({OutType}) * block_size);")

    # inner loop
    code.append(
        " if (pos != offsets[i] - offsets[0]) {\n"
        + " return false;\n"
        + " }"
    )
    code.append(
        " int64_t start_offset = offsets[i];\n"
        + " int64_t end_offset = offsets[i + 1];"
    )
    code.append(" int64_t j = start_offset;")
    # Descending passes: after the 16/8/4/2 loops at most one index remains,
    # which the single-index (num_unrolls == 1) `if` handles.
    for num_unrolls in (16, 8, 4, 2, 1):
        code += unroll(num_unrolls, IndexType, InType, OutType)

    code.append(" const int64_t length = end_offset - start_offset;\n")
    code.append(" if (normalize_by_lengths && length != 0) {")
    code.append(" const float len_inv = 1.0f / length;")
    code.append(" svbool_t pg;")
    code.append(" int64_t j = 0;")
    code.append(" while (j + vLen - 1 < block_size) {")
    code.append(" svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));")
    code.append(" j += vLen;")
    code.append(" }")
    code.append(" if (j < block_size) {")
    code.append(" pg = svwhilelt_b32_s64(j, block_size);")
    code.append(" svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));")
    code.append(" }")
    code.append(" }")
    code.append(" }")
    code.append(" return pos == index_size;")
    code.append("}")

    for is_weight_positional in ["false", "true"]:
        code.append("bool " + fn_base + "_" + is_weight_positional + suffix + "(")
        code += args
        # Resolve the Lint warnings: Limit of 80 characters in one line.
        extra_space = "\n "
        ret_string = (
            " return " + fn_base + suffix + "<" + is_weight_positional + ">("
        )
        if len(ret_string) <= 80:
            code.append(ret_string)
        else:
            code.append(
                " return "
                + fn_base
                + suffix
                + "<"
                + extra_space
                + is_weight_positional
                + ">("
            )
        code.append(" block_size,")
        code.append(" output_size,")
        code.append(" index_size,")
        code.append(" data_size,")
        code.append(" input,")
        code.append(" indices,")
        code.append(" offsets,")
        code.append(" weights,")
        code.append(" scale_bias,")
        code.append(" normalize_by_lengths,")
        code.append(" out);")
        code.append("}")
    code.append("")
    return code


def generate_code(command=None):
    """Return the complete generated C++ file as a list of lines.

    Args:
        command: text recorded in the "AUTOGENERATED BY" banner; defaults to
            the current ``sys.argv`` joined with spaces.
    """
    if command is None:
        command = " ".join(sys.argv)
    code = []
    # includes
    code.append("//// --------------------------")
    code.append("//// ATTENTION:")
    code.append("//// THIS CODE IS AUTOGENERATED")
    code.append(f"//// BY {command}")
    code.append("//// DO NOT MODIFY!!!")
    code.append("//// --------------------------\n")
    code.append("#include <arm_sve.h>")
    code.append("#include <c10/util/BFloat16.h>")
    code.append("#include <c10/util/Half.h>")
    code.append("#include <cstdint>")
    code.append("#include <cstring>")
    code.append("namespace caffe2 {\n")
    for option in KERNEL_OPTIONS:
        code += _kernel_lines(*option)
    code.append("} // namespace caffe2")
    return code


def main():
    """CLI entry point: write the generated kernels to ``-f/--filename``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--filename", help="file name")
    opts = parser.parse_args()
    filename = opts.filename or "embedding_lookup_idx_sve.cc"

    code = generate_code()
    with open(filename, "w", encoding="utf-8") as fout:
        fout.write("\n".join(code) + "\n")
    print("Created " + filename)


if __name__ == "__main__":
    main()