# mypy: allow-untyped-defs
import argparse
import sys
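
# This script generates embedding_lookup_idx_sve.cc: Arm SVE (Scalable Vector
# Extension) implementations of the caffe2 EmbeddingLookupIdx kernels for each
# index type (int32_t/int64_t) and input type (float, at::Half, at::BFloat16,
# uint8_t) combination listed in `options` below, accumulating into float.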


# Unroll loops when block_size is a multiple of vector length.
def unroll(num_unrolls, IndexType, InType, OutType):
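    """Emit C++ that consumes `num_unrolls` indices per iteration of the
    generated `j` loop; `num_unrolls == 1` emits the tail `if` instead."""
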
    def compute_output(num_unrolls, InType, is_main):
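        """Emit the multiply-accumulate body for one vector of the output row.

        Emits one svmla per unrolled input row, converting at::Half,
        at::BFloat16, or uint8_t inputs to float32 first (the uint8_t path
        also adds the accumulated bias `bio`). `is_main` selects the all-true
        predicate for full vectors, `pg` for the tail.
        """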
        code = []

        pred = "svAll" if is_main else "pg"
        if InType == "float":
            for i in range(num_unrolls):
                code.append(
                    f"        output = svmla_x({pred}, output, svld1(svAll, &ip{i}[k]), wgt{i});"
                )
        elif InType == "at::Half":
            for i in range(num_unrolls):
                code.append(
                    f"        auto input{i} = svcvt_f32_x({pred}, svreinterpret_f16(\n"
                    f"            svld1uh_u32({pred}, reinterpret_cast<const uint16_t*>(&ip{i}[k]))));"
                )
            for i in range(num_unrolls):
                code.append(f"        output = svmla_x({pred}, output, input{i}, wgt{i});")
        elif InType == "at::BFloat16":
            for i in range(num_unrolls):
                code.append(
                    f"        auto input{i} = svreinterpret_f32(svlsl_x({pred},\n"
                    f"            svld1uh_u32({pred}, reinterpret_cast<const uint16_t*>(&ip{i}[k])), 16));"
                )
            for i in range(num_unrolls):
                code.append(f"        output = svmla_x({pred}, output, input{i}, wgt{i});")
        elif InType == "uint8_t":
            code.append(f"        output = svadd_x({pred}, output, bio);")
            for i in range(num_unrolls):
                code.append(
                    f"        auto input{i} = svcvt_f32_x({pred}, svld1ub_u32({pred}, &ip{i}[k]));"
                )
            for i in range(num_unrolls):
                code.append(f"        output = svmla_x({pred}, output, input{i}, wgt{i});")
        else:
            raise ValueError(f'Unknown datatype "{InType}"')

        return code

    code = []

    if num_unrolls == 1:
        code.append("    // tail loop")
        code.append("    if (j < end_offset) {")
    else:
        code.append(f"    // unrolling {num_unrolls} times")
        code.append(f"    while (j + {num_unrolls - 1} < end_offset) {{")
    for i in range(num_unrolls):
        code.append(f"      const auto idx{i} = indices[pos + {i}];")

    # check indices
    for i in range(num_unrolls):
        code.append(
            f"      if (idx{i} < 0 || idx{i} >= data_size) {{\n"
            + "        return false;\n"
            + "      }"
        )

    if InType == "uint8_t":
        for i in range(num_unrolls):
            code.append(f"      {OutType} wgt{i} = 1.f;")
        code.append(f"      {OutType} bio = 0.f;")
    else:
        for i in range(num_unrolls):
            code.append(f"      {OutType} wgt{i} = 1.f;")

    code.append("      if (weights) {")
    for i in range(num_unrolls):
        code.append(
            f"        wgt{i} = weights[IS_WEIGHT_POSITIONAL ? (j + {i} - start_offset) : pos + {i}];"
        )
    code.append("      }")
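
    # For quantized uint8_t rows, scale_bias holds (scale, bias) pairs per row:
    # fold the scale into the weight and accumulate weight * bias into `bio`.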
    if InType == "uint8_t":
        code.append("      if (scale_bias) {")
        for i in range(num_unrolls):
            code.append(f"        bio += wgt{i} * scale_bias[2 * idx{i} + 1];")
            code.append(f"        wgt{i} = wgt{i} * scale_bias[2 * idx{i}];")
        code.append("      }")

    for i in range(num_unrolls):
        code.append(f"      const {InType}* const ip{i} = &input[idx{i} * block_size];")

    # compute and store
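    # Full vectors use the all-true predicate; the tail is handled by a single
    # whilelt-predicated iteration instead of a scalar remainder loop.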
    code.append("      svbool_t pg;")
    code.append("      int64_t k = 0;")
    # main loop
    code.append("      while (k + vLen - 1 < block_size) {")
    code.append("        auto output = svld1(svAll, &op[k]);")
    code.extend(compute_output(num_unrolls, InType, True))
    code.append("        svst1(svAll, &op[k], output);")
    code.append("        k += vLen;")
    code.append("      }")
    # tail loop
    code.append("      if (k < block_size) {")
    code.append("        pg = svwhilelt_b32_s64(k, block_size);")
    code.append("        auto output = svld1(pg, &op[k]);")
    code.extend(compute_output(num_unrolls, InType, False))
    code.append("        svst1(pg, &op[k], output);")
    code.append("        k += vLen;")
    code.append("      }")
    if num_unrolls == 1:
        code.append("      pos ++;")
    else:
        code.append(f"      j += {num_unrolls};")
        code.append(f"      pos += {num_unrolls};")

    code.append("    }")

    return code


def main():
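    """Generate the SVE embedding-lookup kernels and write them to the file
    given by -f/--filename (default: embedding_lookup_idx_sve.cc)."""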
    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--filename", help="file name")
    opts = parser.parse_args()
    if opts.filename:
        filename = opts.filename
    else:
        filename = "embedding_lookup_idx_sve.cc"
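
    # Each row: [IndexTypeName, IndexType, InTypeName, InType, OutTypeName, OutType]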
    options = [
        ["int32_t", "int32_t", "float", "float", "float", "float"],
        ["int64_t", "int64_t", "float", "float", "float", "float"],
        ["int32_t", "int32_t", "half", "at::Half", "float", "float"],
        ["int64_t", "int64_t", "half", "at::Half", "float", "float"],
        ["int32_t", "int32_t", "bfloat16", "at::BFloat16", "float", "float"],
        ["int64_t", "int64_t", "bfloat16", "at::BFloat16", "float", "float"],
        ["int32_t", "int32_t", "uint8_t", "uint8_t", "float", "float"],
        ["int64_t", "int64_t", "uint8_t", "uint8_t", "float", "float"],
    ]

    code = []
    # includes
    code.append("//// --------------------------")
    code.append("//// ATTENTION:")
    code.append("//// THIS CODE IS AUTOGENERATED")
    code.append(f"//// BY {' '.join(sys.argv)}")
    code.append("//// DO NOT MODIFY!!!")
    code.append("//// --------------------------\n")

    code.append("#include <arm_sve.h>")
    code.append("#include <c10/util/BFloat16.h>")
    code.append("#include <c10/util/Half.h>")
    code.append("#include <cstdint>")
    code.append("#include <cstring>")

    code.append("namespace caffe2 {\n")
    for o in options:
        [IndexTypeName, IndexType, InTypeName, InType, OutTypeName, OutType] = o

        code.append("template <bool IS_WEIGHT_POSITIONAL>")
        fn_base = f"EmbeddingLookupIdx_{IndexTypeName}_{InTypeName}_{OutTypeName}"

        suffix = "__sve"
        fn = "static bool " + fn_base + suffix
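        # e.g. "static bool EmbeddingLookupIdx_int32_t_float_float__sve" for the first option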
        code.append(fn + "(")

        args = []
        args.append("    const int64_t block_size,")
        args.append("    const int64_t output_size,")
        args.append("    const int64_t index_size,")
        args.append("    const int64_t data_size,")
        args.append("    const " + InType + "* input,")
        args.append("    const " + IndexType + "* indices,")
        args.append("    const " + IndexType + "* offsets,")
        args.append("    const float* weights,")
        args.append("    const float* scale_bias,")
        args.append("    bool normalize_by_lengths,")
        args.append("    " + OutType + "* out) {")
        code += args
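
        # Function prologue: svAll is the all-true predicate and vLen the
        # number of 32-bit lanes per SVE vector (svcntw).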
        code.append("  const svbool_t svAll = svptrue_b32();")
        code.append("  const auto vLen = static_cast<int64_t>(svcntw());")
        code.append("  int64_t pos = 0;")

        code.append("  for (int64_t i = 0; i < output_size; ++i) {")
        code.append("    " + OutType + "* const op = &out[i * block_size];")

        # initialize to 0
        code.append("    memset(op, 0, sizeof(float) * block_size);")

        # inner loop
        code.append(
            "    if (pos != offsets[i] - offsets[0]) {\n"
            + "      return false;\n"
            + "    }"
        )
        code.append(
            "    int64_t start_offset = offsets[i];\n"
            + "    int64_t end_offset = offsets[i + 1];"
        )
        code.append("    int64_t j = start_offset;")
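
        # Emit progressively smaller unroll factors; the final unroll(1) call
        # emits the tail `if` that handles a last leftover index.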
        code += unroll(16, IndexType, InType, OutType)
        code += unroll(8, IndexType, InType, OutType)
        code += unroll(4, IndexType, InType, OutType)
        code += unroll(2, IndexType, InType, OutType)
        code += unroll(1, IndexType, InType, OutType)

        code.append("    const int64_t length = end_offset - start_offset;\n")
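        # Optionally scale the accumulated row by 1 / length (mean pooling).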
        code.append("    if (normalize_by_lengths && length != 0) {")
        code.append("      const float len_inv = 1.0f / length;")
        code.append("      svbool_t pg;")
        code.append("      int64_t j = 0;")
        code.append("      while (j + vLen - 1 < block_size) {")
        code.append(
            "        svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));"
        )
        code.append("        j += vLen;")
        code.append("      }")
        code.append("      if (j < block_size) {")
        code.append("        pg = svwhilelt_b32_s64(j, block_size);")
        code.append(
            "        svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));"
        )
        code.append("      }")
        code.append("    }")

        code.append("  }")
        code.append("  return pos == index_size;")
        code.append("}")
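
        # Emit non-templated wrappers (..._false__sve / ..._true__sve) that
        # forward to the templated kernel above.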
        for is_weight_positional in ["false", "true"]:
            code.append("bool " + fn_base + "_" + is_weight_positional + suffix + "(")
            code += args

            # Keep the generated return statement under the 80-character lint limit.
            extra_space = "\n      "
            ret_string = (
                "  return " + fn_base + suffix + "<" + is_weight_positional + ">("
            )
            if len(ret_string) <= 80:
                code.append(ret_string)
            else:
                code.append(
                    "  return "
                    + fn_base
                    + suffix
                    + "<"
                    + extra_space
                    + is_weight_positional
                    + ">("
                )

            code.append("      block_size,")
            code.append("      output_size,")
            code.append("      index_size,")
            code.append("      data_size,")
            code.append("      input,")
            code.append("      indices,")
            code.append("      offsets,")
            code.append("      weights,")
            code.append("      scale_bias,")
            code.append("      normalize_by_lengths,")
            code.append("      out);")
            code.append("}")

        code.append("")

    code.append("} // namespace caffe2")

    with open(filename, "w") as fout:
        fout.write("\n".join(code) + "\n")

    print("Created " + filename)


if __name__ == "__main__":
    main()