mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix ruff warnings in caffe2 and functorch (#144182)
In preparation for upgrading ruff config to py3.9. Pull Request resolved: https://github.com/pytorch/pytorch/pull/144182 Approved by: https://github.com/malfet
This commit is contained in:
@ -14,15 +14,14 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
|
|||||||
|
|
||||||
if InType == "float":
|
if InType == "float":
|
||||||
code.append(
|
code.append(
|
||||||
f" vop{regid:d} = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + ({regid:d})), vop{regid:d});" # noqa
|
f" vop{regid:d} = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + ({regid:d})), vop{regid:d});"
|
||||||
|
|
||||||
)
|
)
|
||||||
elif InType == "at::Half":
|
elif InType == "at::Half":
|
||||||
code.append(
|
code.append(
|
||||||
f" vop{regid:d} = _mm256_fmadd_ps(\n"
|
f" vop{regid:d} = _mm256_fmadd_ps(\n"
|
||||||
" vwgt,\n"
|
" vwgt,\n"
|
||||||
" _mm256_cvtph_ps(\n"
|
" _mm256_cvtph_ps(\n"
|
||||||
f" _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + ({regid:d})))),\n" # noqa
|
f" _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + ({regid:d})))),\n"
|
||||||
f" vop{regid:d});"
|
f" vop{regid:d});"
|
||||||
)
|
)
|
||||||
elif InType == "at::BFloat16":
|
elif InType == "at::BFloat16":
|
||||||
@ -32,7 +31,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
|
|||||||
" _mm256_castsi256_ps(_mm256_slli_epi32(\n"
|
" _mm256_castsi256_ps(_mm256_slli_epi32(\n"
|
||||||
" _mm256_cvtepu16_epi32(_mm_loadu_si128(\n"
|
" _mm256_cvtepu16_epi32(_mm_loadu_si128(\n"
|
||||||
f" reinterpret_cast<const __m128i*>(ip + ({regid:d})))),\n"
|
f" reinterpret_cast<const __m128i*>(ip + ({regid:d})))),\n"
|
||||||
" 16)),\n" # noqa
|
" 16)),\n"
|
||||||
f" vop{regid:d});"
|
f" vop{regid:d});"
|
||||||
)
|
)
|
||||||
elif InType == "uint8_t":
|
elif InType == "uint8_t":
|
||||||
@ -40,17 +39,16 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
|
|||||||
f" vop{regid:d} = _mm256_fmadd_ps(\n"
|
f" vop{regid:d} = _mm256_fmadd_ps(\n"
|
||||||
" vwgt,\n"
|
" vwgt,\n"
|
||||||
" _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(\n"
|
" _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(\n"
|
||||||
f" _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + ({regid:d}))))),\n" # noqa
|
f" _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + ({regid:d}))))),\n"
|
||||||
f" _mm256_add_ps(vop{regid:d}, vbio));"
|
f" _mm256_add_ps(vop{regid:d}, vbio));"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert False
|
raise AssertionError
|
||||||
|
|
||||||
if prefetch:
|
if prefetch:
|
||||||
code.append(
|
code.append(
|
||||||
" _mm_prefetch(\n"
|
" _mm_prefetch(\n"
|
||||||
f" reinterpret_cast<const char*>(&ip_next_T0[{regid:d}]), _MM_HINT_T0);"
|
f" reinterpret_cast<const char*>(&ip_next_T0[{regid:d}]), _MM_HINT_T0);"
|
||||||
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
code.append(
|
code.append(
|
||||||
@ -93,7 +91,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
|
|||||||
code.append(
|
code.append(
|
||||||
" for ("
|
" for ("
|
||||||
+ "int64_t"
|
+ "int64_t"
|
||||||
+ " start = dataInd; dataInd < end_offset - offsets[0];\n ++dataInd) {" # noqa
|
+ " start = dataInd; dataInd < end_offset - offsets[0];\n ++dataInd) {"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
code.append(
|
code.append(
|
||||||
@ -104,7 +102,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
|
|||||||
code.append(
|
code.append(
|
||||||
" for ("
|
" for ("
|
||||||
+ IndexType
|
+ IndexType
|
||||||
+ " start = dataInd; dataInd < start + lengths[rangeIndex];\n ++dataInd) {" # noqa
|
+ " start = dataInd; dataInd < start + lengths[rangeIndex];\n ++dataInd) {"
|
||||||
)
|
)
|
||||||
code.append(" const " + IndexType + " idx = indices[dataInd];")
|
code.append(" const " + IndexType + " idx = indices[dataInd];")
|
||||||
code.append(
|
code.append(
|
||||||
@ -119,7 +117,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
|
|||||||
code.append(" " + OutType + " bio;")
|
code.append(" " + OutType + " bio;")
|
||||||
code.append(" if (weights) {")
|
code.append(" if (weights) {")
|
||||||
code.append(
|
code.append(
|
||||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];" # noqa
|
" wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"
|
||||||
)
|
)
|
||||||
code.append(" }")
|
code.append(" }")
|
||||||
if fused:
|
if fused:
|
||||||
@ -137,7 +135,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
|
|||||||
code.append(" " + OutType + " wgt = 1.f;")
|
code.append(" " + OutType + " wgt = 1.f;")
|
||||||
code.append(" if (weights) {")
|
code.append(" if (weights) {")
|
||||||
code.append(
|
code.append(
|
||||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];" # noqa
|
" wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"
|
||||||
)
|
)
|
||||||
code.append(" }")
|
code.append(" }")
|
||||||
code.append(" __m256 vwgt = _mm256_set1_ps(wgt);")
|
code.append(" __m256 vwgt = _mm256_set1_ps(wgt);")
|
||||||
@ -182,7 +180,9 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
|
|||||||
if use_offsets:
|
if use_offsets:
|
||||||
code.append(" __m256 vlen_inv = _mm256_set1_ps(1.0f / length);")
|
code.append(" __m256 vlen_inv = _mm256_set1_ps(1.0f / length);")
|
||||||
else:
|
else:
|
||||||
code.append(" __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);")
|
code.append(
|
||||||
|
" __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);"
|
||||||
|
)
|
||||||
for i in range(0, uf):
|
for i in range(0, uf):
|
||||||
j = 8 * i
|
j = 8 * i
|
||||||
code.append(
|
code.append(
|
||||||
@ -207,7 +207,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
|
|||||||
" _mm256_storeu_ps(\n"
|
" _mm256_storeu_ps(\n"
|
||||||
" &op[j],\n"
|
" &op[j],\n"
|
||||||
" _mm256_fmadd_ps(\n"
|
" _mm256_fmadd_ps(\n"
|
||||||
" vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));" # noqa
|
" vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));"
|
||||||
)
|
)
|
||||||
elif InType == "at::Half":
|
elif InType == "at::Half":
|
||||||
code.append(
|
code.append(
|
||||||
@ -237,12 +237,12 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
|
|||||||
" &op[j],\n"
|
" &op[j],\n"
|
||||||
" _mm256_fmadd_ps(\n"
|
" _mm256_fmadd_ps(\n"
|
||||||
" vwgt,\n"
|
" vwgt,\n"
|
||||||
" _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(\n" # noqa
|
" _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(\n"
|
||||||
" reinterpret_cast<const __m128i*>(&ip[j])))),\n"
|
" reinterpret_cast<const __m128i*>(&ip[j])))),\n"
|
||||||
" _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));"
|
" _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert False
|
raise AssertionError
|
||||||
|
|
||||||
code.append(
|
code.append(
|
||||||
" _mm_prefetch(\n"
|
" _mm_prefetch(\n"
|
||||||
@ -257,7 +257,6 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
|
|||||||
if InType == "at::BFloat16":
|
if InType == "at::BFloat16":
|
||||||
code.append(" alignas(64) at::BFloat16 vtmp1[8] = {0};")
|
code.append(" alignas(64) at::BFloat16 vtmp1[8] = {0};")
|
||||||
|
|
||||||
|
|
||||||
if use_offsets:
|
if use_offsets:
|
||||||
code.append(
|
code.append(
|
||||||
" for ("
|
" for ("
|
||||||
@ -295,7 +294,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
|
|||||||
code.append(
|
code.append(
|
||||||
" for ("
|
" for ("
|
||||||
+ "int64_t"
|
+ "int64_t"
|
||||||
+ " start = dataInd; dataInd < end_offset - offsets[0];\n ++dataInd) {" # noqa
|
+ " start = dataInd; dataInd < end_offset - offsets[0];\n ++dataInd) {"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
code.append(
|
code.append(
|
||||||
@ -306,7 +305,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
|
|||||||
code.append(
|
code.append(
|
||||||
" for ("
|
" for ("
|
||||||
+ IndexType
|
+ IndexType
|
||||||
+ " start = dataInd; dataInd < start + lengths[rangeIndex];\n ++dataInd) {" # noqa
|
+ " start = dataInd; dataInd < start + lengths[rangeIndex];\n ++dataInd) {"
|
||||||
)
|
)
|
||||||
code.append(" const " + IndexType + " idx = indices[dataInd];")
|
code.append(" const " + IndexType + " idx = indices[dataInd];")
|
||||||
code.append(
|
code.append(
|
||||||
@ -321,7 +320,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
|
|||||||
code.append(" " + OutType + " bio;")
|
code.append(" " + OutType + " bio;")
|
||||||
code.append(" if (weights) {")
|
code.append(" if (weights) {")
|
||||||
code.append(
|
code.append(
|
||||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];" # noqa
|
" wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"
|
||||||
)
|
)
|
||||||
code.append(" }")
|
code.append(" }")
|
||||||
if fused:
|
if fused:
|
||||||
@ -339,7 +338,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
|
|||||||
code.append(" " + OutType + " wgt = 1.f;")
|
code.append(" " + OutType + " wgt = 1.f;")
|
||||||
code.append(" if (weights) {")
|
code.append(" if (weights) {")
|
||||||
code.append(
|
code.append(
|
||||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];" # noqa
|
" wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"
|
||||||
)
|
)
|
||||||
code.append(" }")
|
code.append(" }")
|
||||||
code.append(" __m256 vwgt = _mm256_set1_ps(wgt);")
|
code.append(" __m256 vwgt = _mm256_set1_ps(wgt);")
|
||||||
@ -390,7 +389,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
|
|||||||
elif InType == "uint8_t":
|
elif InType == "uint8_t":
|
||||||
code.append(" op[j] = std::fma(wgt, (float)ip[j], bio + op[j]);")
|
code.append(" op[j] = std::fma(wgt, (float)ip[j], bio + op[j]);")
|
||||||
else:
|
else:
|
||||||
assert False
|
raise AssertionError
|
||||||
|
|
||||||
code.append(" }")
|
code.append(" }")
|
||||||
|
|
||||||
@ -496,13 +495,13 @@ for o in options:
|
|||||||
code += args
|
code += args
|
||||||
|
|
||||||
code.append(" const " + IndexType + " prefdist_T0 = 16;")
|
code.append(" const " + IndexType + " prefdist_T0 = 16;")
|
||||||
code.append(" // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)")
|
code.append(
|
||||||
|
" // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)"
|
||||||
|
)
|
||||||
# block_size is the number of elements and fused_block_size is the size of
|
# block_size is the number of elements and fused_block_size is the size of
|
||||||
# an entire row, including scale and bias.
|
# an entire row, including scale and bias.
|
||||||
offset = (8 // sizeof[InType]) if opts.fused else 0
|
offset = (8 // sizeof[InType]) if opts.fused else 0
|
||||||
code.append(
|
code.append(f" const {IndexType} fused_block_size = block_size + {offset};")
|
||||||
f" const {IndexType} fused_block_size = block_size + {offset};"
|
|
||||||
)
|
|
||||||
if opts.use_offsets:
|
if opts.use_offsets:
|
||||||
code.append(" int64_t dataInd = 0;")
|
code.append(" int64_t dataInd = 0;")
|
||||||
else:
|
else:
|
||||||
@ -511,17 +510,29 @@ for o in options:
|
|||||||
# code.append("printf(\"calling " + fn + "\\n\");");
|
# code.append("printf(\"calling " + fn + "\\n\");");
|
||||||
|
|
||||||
code.append(" if (block_size == 128) {")
|
code.append(" if (block_size == 128) {")
|
||||||
code += unroll(16, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
|
code += unroll(
|
||||||
|
16, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets
|
||||||
|
)
|
||||||
code.append(" } else if (block_size == 64) {")
|
code.append(" } else if (block_size == 64) {")
|
||||||
code += unroll(8, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
|
code += unroll(
|
||||||
|
8, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets
|
||||||
|
)
|
||||||
code.append(" } else if (block_size == 32) {")
|
code.append(" } else if (block_size == 32) {")
|
||||||
code += unroll(4, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
|
code += unroll(
|
||||||
|
4, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets
|
||||||
|
)
|
||||||
code.append(" } else if (block_size == 16) {")
|
code.append(" } else if (block_size == 16) {")
|
||||||
code += unroll(2, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
|
code += unroll(
|
||||||
|
2, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets
|
||||||
|
)
|
||||||
code.append(" } else {")
|
code.append(" } else {")
|
||||||
code.append(" // generic code")
|
code.append(" // generic code")
|
||||||
code.append(" // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)")
|
code.append(
|
||||||
code += generic(IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
|
" // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)"
|
||||||
|
)
|
||||||
|
code += generic(
|
||||||
|
IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets
|
||||||
|
)
|
||||||
code.append(" }")
|
code.append(" }")
|
||||||
code.append(" return dataInd == index_size;")
|
code.append(" return dataInd == index_size;")
|
||||||
|
|
||||||
@ -558,7 +569,7 @@ for o in options:
|
|||||||
|
|
||||||
code.append("} // namespace caffe2")
|
code.append("} // namespace caffe2")
|
||||||
|
|
||||||
with open(filename, "w") as fout:
|
with open(filename, "w", encoding="utf8") as fout:
|
||||||
for c in code:
|
for c in code:
|
||||||
# print(c, file = fout)
|
# print(c, file = fout)
|
||||||
fout.write(c + "\n")
|
fout.write(c + "\n")
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
||||||
# Unroll loops when block_size is a multiple of vector length.
|
# Unroll loops when block_size is a multiple of vector length.
|
||||||
def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
|
def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
|
||||||
def compute(regid, InType, use_weights):
|
def compute(regid, InType, use_weights):
|
||||||
@ -23,7 +24,7 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
|
|||||||
" svAll,\n"
|
" svAll,\n"
|
||||||
" svreinterpret_f16_u32(svld1uh_u32(\n"
|
" svreinterpret_f16_u32(svld1uh_u32(\n"
|
||||||
" svAll, reinterpret_cast<const uint16_t*>("
|
" svAll, reinterpret_cast<const uint16_t*>("
|
||||||
f"&ip[{regid} * vLen])))),\n" # noqa
|
f"&ip[{regid} * vLen])))),\n"
|
||||||
f" vsum{regid});"
|
f" vsum{regid});"
|
||||||
)
|
)
|
||||||
elif InType == "at::BFloat16":
|
elif InType == "at::BFloat16":
|
||||||
@ -36,7 +37,7 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
|
|||||||
" svld1uh_u32(\n"
|
" svld1uh_u32(\n"
|
||||||
" svAll, reinterpret_cast<const uint16_t*>("
|
" svAll, reinterpret_cast<const uint16_t*>("
|
||||||
f"&ip[{regid} * vLen])),\n"
|
f"&ip[{regid} * vLen])),\n"
|
||||||
" 16)),\n" # noqa
|
" 16)),\n"
|
||||||
f" vsum{regid});"
|
f" vsum{regid});"
|
||||||
)
|
)
|
||||||
elif InType == "uint8_t":
|
elif InType == "uint8_t":
|
||||||
@ -45,11 +46,11 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
|
|||||||
" svAll,\n"
|
" svAll,\n"
|
||||||
" vwgt,\n"
|
" vwgt,\n"
|
||||||
" svcvt_f32_u32_x(svAll,"
|
" svcvt_f32_u32_x(svAll,"
|
||||||
f" svld1ub_u32(svAll, &ip[{regid} * vLen])),\n" # noqa
|
f" svld1ub_u32(svAll, &ip[{regid} * vLen])),\n"
|
||||||
f" svadd_f32_x(svAll, vsum{regid}, vbio));"
|
f" svadd_f32_x(svAll, vsum{regid}, vbio));"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown datatype \"{InType}\"")
|
raise ValueError(f'Unknown datatype "{InType}"')
|
||||||
|
|
||||||
return code
|
return code
|
||||||
|
|
||||||
@ -74,9 +75,7 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
|
|||||||
int64_t start_offset = offsets[i];
|
int64_t start_offset = offsets[i];
|
||||||
int64_t end_offset = offsets[i + 1];""")
|
int64_t end_offset = offsets[i + 1];""")
|
||||||
code.append(
|
code.append(
|
||||||
" for ("
|
" for (" + "int64_t" + " j = start_offset; j < end_offset; ++j) {"
|
||||||
+ "int64_t"
|
|
||||||
+ " j = start_offset; j < end_offset; ++j) {" # noqa
|
|
||||||
)
|
)
|
||||||
|
|
||||||
code.append(" const auto idx = indices[pos];")
|
code.append(" const auto idx = indices[pos];")
|
||||||
@ -91,7 +90,7 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
|
|||||||
code.append(" " + OutType + " bio{};")
|
code.append(" " + OutType + " bio{};")
|
||||||
code.append(" if (weights) {")
|
code.append(" if (weights) {")
|
||||||
code.append(
|
code.append(
|
||||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa
|
" wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];"
|
||||||
)
|
)
|
||||||
code.append(" }")
|
code.append(" }")
|
||||||
code.append(" if (scale_bias) {")
|
code.append(" if (scale_bias) {")
|
||||||
@ -103,7 +102,7 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
|
|||||||
code.append(" " + OutType + " wgt = 1.f;")
|
code.append(" " + OutType + " wgt = 1.f;")
|
||||||
code.append(" if (weights) {")
|
code.append(" if (weights) {")
|
||||||
code.append(
|
code.append(
|
||||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa
|
" wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];"
|
||||||
)
|
)
|
||||||
code.append(" }")
|
code.append(" }")
|
||||||
|
|
||||||
@ -124,8 +123,10 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
|
|||||||
code.append(" const svfloat32_t vlen_inv = svdup_n_f32(len_inv);")
|
code.append(" const svfloat32_t vlen_inv = svdup_n_f32(len_inv);")
|
||||||
|
|
||||||
for i in range(num_unrolls):
|
for i in range(num_unrolls):
|
||||||
code.append(f" svst1_f32(svAll, &op[{i} * vLen],"
|
code.append(
|
||||||
+ f" svmul_f32_x(svAll, vsum{i}, vlen_inv));")
|
f" svst1_f32(svAll, &op[{i} * vLen],"
|
||||||
|
+ f" svmul_f32_x(svAll, vsum{i}, vlen_inv));"
|
||||||
|
)
|
||||||
|
|
||||||
code.append(" } else {")
|
code.append(" } else {")
|
||||||
# inv of length
|
# inv of length
|
||||||
@ -190,20 +191,18 @@ def generic(IndexType, InType, OutType, use_weights):
|
|||||||
" pg,\n"
|
" pg,\n"
|
||||||
" vwgt,\n"
|
" vwgt,\n"
|
||||||
" svcvt_f32_u32_x(pg,"
|
" svcvt_f32_u32_x(pg,"
|
||||||
" svld1ub_u32(pg, &ip[k])),\n" # noqa
|
" svld1ub_u32(pg, &ip[k])),\n"
|
||||||
" svadd_f32_x(pg,"
|
" svadd_f32_x(pg,"
|
||||||
" svld1_f32(pg, &op[k]), vbio)));"
|
" svld1_f32(pg, &op[k]), vbio)));"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown datatype \"{InType}\"")
|
raise ValueError(f'Unknown datatype "{InType}"')
|
||||||
|
|
||||||
return code
|
return code
|
||||||
|
|
||||||
code = []
|
code = []
|
||||||
|
|
||||||
code.append(
|
code.append(" for (int64_t i = 0; i < output_size; ++i) {")
|
||||||
" for (int64_t i = 0; i < output_size; ++i) {"
|
|
||||||
)
|
|
||||||
|
|
||||||
code.append(" " + OutType + "* const op = &out[i * block_size];")
|
code.append(" " + OutType + "* const op = &out[i * block_size];")
|
||||||
|
|
||||||
@ -221,9 +220,7 @@ def generic(IndexType, InType, OutType, use_weights):
|
|||||||
+ " int64_t end_offset = offsets[i + 1];"
|
+ " int64_t end_offset = offsets[i + 1];"
|
||||||
)
|
)
|
||||||
code.append(
|
code.append(
|
||||||
" for ("
|
" for (" + "int64_t" + " j = start_offset; j < end_offset; ++j) {"
|
||||||
+ "int64_t"
|
|
||||||
+ " j = start_offset; j < end_offset; ++j) {" # noqa
|
|
||||||
)
|
)
|
||||||
|
|
||||||
code.append(" const auto idx = indices[pos];")
|
code.append(" const auto idx = indices[pos];")
|
||||||
@ -239,7 +236,7 @@ def generic(IndexType, InType, OutType, use_weights):
|
|||||||
code.append(" " + OutType + " bio{};")
|
code.append(" " + OutType + " bio{};")
|
||||||
code.append(" if (weights) {")
|
code.append(" if (weights) {")
|
||||||
code.append(
|
code.append(
|
||||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa
|
" wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];"
|
||||||
)
|
)
|
||||||
code.append(" }")
|
code.append(" }")
|
||||||
code.append(" if (scale_bias) {")
|
code.append(" if (scale_bias) {")
|
||||||
@ -251,7 +248,7 @@ def generic(IndexType, InType, OutType, use_weights):
|
|||||||
code.append(" " + OutType + " wgt = 1.f;")
|
code.append(" " + OutType + " wgt = 1.f;")
|
||||||
code.append(" if (weights) {")
|
code.append(" if (weights) {")
|
||||||
code.append(
|
code.append(
|
||||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa
|
" wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];"
|
||||||
)
|
)
|
||||||
code.append(" }")
|
code.append(" }")
|
||||||
|
|
||||||
@ -261,8 +258,9 @@ def generic(IndexType, InType, OutType, use_weights):
|
|||||||
# compute and store main loop
|
# compute and store main loop
|
||||||
code.append(" svbool_t pg;")
|
code.append(" svbool_t pg;")
|
||||||
code.append(" for (int64_t k = 0;")
|
code.append(" for (int64_t k = 0;")
|
||||||
code.append(" svptest_first(svAll, pg = svwhilelt_b32_s64("
|
code.append(
|
||||||
+ "k, block_size));")
|
" svptest_first(svAll, pg = svwhilelt_b32_s64(" + "k, block_size));"
|
||||||
|
)
|
||||||
code.append(" k += vLen) {")
|
code.append(" k += vLen) {")
|
||||||
code.extend(compute(InType, use_weights))
|
code.extend(compute(InType, use_weights))
|
||||||
code.append(" }\n")
|
code.append(" }\n")
|
||||||
@ -274,9 +272,11 @@ def generic(IndexType, InType, OutType, use_weights):
|
|||||||
code.append(" const float len_inv = 1.0f / length;")
|
code.append(" const float len_inv = 1.0f / length;")
|
||||||
code.append(" svfloat32_t vlen_inv = svdup_n_f32(len_inv);")
|
code.append(" svfloat32_t vlen_inv = svdup_n_f32(len_inv);")
|
||||||
code.append(" svbool_t pg;")
|
code.append(" svbool_t pg;")
|
||||||
code.append(" for (int64_t j = 0;\n"
|
code.append(
|
||||||
" svptest_first(svAll, pg = svwhilelt_b32_s64("
|
" for (int64_t j = 0;\n"
|
||||||
"j, block_size));")
|
" svptest_first(svAll, pg = svwhilelt_b32_s64("
|
||||||
|
"j, block_size));"
|
||||||
|
)
|
||||||
code.append(" j += vLen) {")
|
code.append(" j += vLen) {")
|
||||||
code.append(
|
code.append(
|
||||||
" svst1_f32(\n"
|
" svst1_f32(\n"
|
||||||
@ -287,6 +287,7 @@ def generic(IndexType, InType, OutType, use_weights):
|
|||||||
code.append(" }")
|
code.append(" }")
|
||||||
return code
|
return code
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("-f", "--filename", help="file name")
|
parser.add_argument("-f", "--filename", help="file name")
|
||||||
@ -375,12 +376,21 @@ def main():
|
|||||||
|
|
||||||
# Resolve the Lint warnings: Limit of 80 characters in one line.
|
# Resolve the Lint warnings: Limit of 80 characters in one line.
|
||||||
extra_space = "\n "
|
extra_space = "\n "
|
||||||
ret_string = " return " + fn_base + suffix \
|
ret_string = (
|
||||||
+ "<" + is_weight_positional + ">("
|
" return " + fn_base + suffix + "<" + is_weight_positional + ">("
|
||||||
|
)
|
||||||
if len(ret_string) <= 80:
|
if len(ret_string) <= 80:
|
||||||
code.append(ret_string)
|
code.append(ret_string)
|
||||||
else:
|
else:
|
||||||
code.append(" return " + fn_base + suffix + "<" + extra_space + is_weight_positional + ">(")
|
code.append(
|
||||||
|
" return "
|
||||||
|
+ fn_base
|
||||||
|
+ suffix
|
||||||
|
+ "<"
|
||||||
|
+ extra_space
|
||||||
|
+ is_weight_positional
|
||||||
|
+ ">("
|
||||||
|
)
|
||||||
|
|
||||||
code.append(" block_size,")
|
code.append(" block_size,")
|
||||||
code.append(" output_size,")
|
code.append(" output_size,")
|
||||||
@ -404,5 +414,6 @@ def main():
|
|||||||
|
|
||||||
print("Created " + filename)
|
print("Created " + filename)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import dis
|
import dis
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Sequence, Union
|
from collections.abc import Sequence
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import functorch._C
|
import functorch._C
|
||||||
import torch
|
import torch
|
||||||
|
@ -27,7 +27,11 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import keyword
|
import keyword
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Collection, List, Mapping, Optional, Set, Tuple, Union
|
from typing import List, Optional, Set, Tuple, TYPE_CHECKING, Union
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Collection, Mapping
|
||||||
|
|
||||||
|
|
||||||
_ellipsis: str = "\u2026" # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated
|
_ellipsis: str = "\u2026" # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
from typing import Callable, Dict, List, Sequence, Tuple, Union
|
from typing import Callable, Dict, List, Tuple, TYPE_CHECKING, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from functorch._C import dim as _C
|
from functorch._C import dim as _C
|
||||||
@ -15,6 +15,10 @@ from ._parsing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["rearrange"]
|
__all__ = ["rearrange"]
|
||||||
|
|
||||||
dims = _C.dims
|
dims = _C.dims
|
||||||
|
Reference in New Issue
Block a user