Fix ruff warnings in caffe2 and functorch (#144182)

In preparation for upgrading ruff config to py3.9.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144182
Approved by: https://github.com/malfet
Author: cyy
Date: 2025-01-04 04:15:01 +00:00
Committed by: PyTorch MergeBot
Parent: ec1f56fdcf
Commit: f9bf9057ef
5 changed files with 95 additions and 64 deletions
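
Most of the changes below are mechanical applications of a handful of ruff rules. The rule codes named here are inferred from the edits rather than stated in the PR, so treat them as likely triggers: B011 (assert False), RUF100 (unused # noqa suppressions), PLW1514 (open() without an explicit encoding), and UP035 (typing imports deprecated in favor of collections.abc). A minimal sketch of each pattern:

    # B011: a bare `assert False` disappears under `python -O`; raise instead.
    def unreachable():
        raise AssertionError  # was: assert False

    # PLW1514: pass an explicit encoding when opening text files.
    with open("out.cc", "w", encoding="utf8") as fout:
        fout.write("// generated\n")

    # UP035: on Python >= 3.9, import container ABCs from collections.abc.
    from collections.abc import Sequence  # was: from typing import Sequence

The remaining churn is reflowing of long lines and deletion of # noqa comments that no longer suppressed anything.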

View File

@@ -14,15 +14,14 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
         if InType == "float":
             code.append(
-                f" vop{regid:d} = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + ({regid:d})), vop{regid:d});"  # noqa
+                f" vop{regid:d} = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + ({regid:d})), vop{regid:d});"
             )
         elif InType == "at::Half":
             code.append(
                 f" vop{regid:d} = _mm256_fmadd_ps(\n"
                 " vwgt,\n"
                 " _mm256_cvtph_ps(\n"
-                f" _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + ({regid:d})))),\n"  # noqa
+                f" _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + ({regid:d})))),\n"
                 f" vop{regid:d});"
             )
         elif InType == "at::BFloat16":
@@ -32,7 +31,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
                 " _mm256_castsi256_ps(_mm256_slli_epi32(\n"
                 " _mm256_cvtepu16_epi32(_mm_loadu_si128(\n"
                 f" reinterpret_cast<const __m128i*>(ip + ({regid:d})))),\n"
-                " 16)),\n"  # noqa
+                " 16)),\n"
                 f" vop{regid:d});"
             )
         elif InType == "uint8_t":
@@ -40,17 +39,16 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
                 f" vop{regid:d} = _mm256_fmadd_ps(\n"
                 " vwgt,\n"
                 " _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(\n"
-                f" _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + ({regid:d}))))),\n"  # noqa
+                f" _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + ({regid:d}))))),\n"
                 f" _mm256_add_ps(vop{regid:d}, vbio));"
             )
         else:
-            assert False
+            raise AssertionError
         if prefetch:
             code.append(
                 " _mm_prefetch(\n"
                 f" reinterpret_cast<const char*>(&ip_next_T0[{regid:d}]), _MM_HINT_T0);"
             )
         else:
             code.append(
@@ -93,7 +91,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
         code.append(
             " for ("
             + "int64_t"
-            + " start = dataInd; dataInd < end_offset - offsets[0];\n ++dataInd) {"  # noqa
+            + " start = dataInd; dataInd < end_offset - offsets[0];\n ++dataInd) {"
         )
     else:
         code.append(
@@ -104,7 +102,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
         code.append(
             " for ("
             + IndexType
-            + " start = dataInd; dataInd < start + lengths[rangeIndex];\n ++dataInd) {"  # noqa
+            + " start = dataInd; dataInd < start + lengths[rangeIndex];\n ++dataInd) {"
         )
     code.append(" const " + IndexType + " idx = indices[dataInd];")
     code.append(
@@ -119,7 +117,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
         code.append(" " + OutType + " bio;")
         code.append(" if (weights) {")
         code.append(
-            " wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"  # noqa
+            " wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"
         )
         code.append(" }")
     if fused:
@@ -137,7 +135,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
         code.append(" " + OutType + " wgt = 1.f;")
         code.append(" if (weights) {")
         code.append(
-            " wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"  # noqa
+            " wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"
         )
         code.append(" }")
         code.append(" __m256 vwgt = _mm256_set1_ps(wgt);")
@@ -182,7 +180,9 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
     if use_offsets:
         code.append(" __m256 vlen_inv = _mm256_set1_ps(1.0f / length);")
     else:
-        code.append(" __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);")
+        code.append(
+            " __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);"
+        )
     for i in range(0, uf):
         j = 8 * i
         code.append(
@@ -207,7 +207,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
                 " _mm256_storeu_ps(\n"
                 " &op[j],\n"
                 " _mm256_fmadd_ps(\n"
-                " vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));"  # noqa
+                " vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));"
             )
         elif InType == "at::Half":
             code.append(
@@ -237,12 +237,12 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
                 " &op[j],\n"
                 " _mm256_fmadd_ps(\n"
                 " vwgt,\n"
-                " _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(\n"  # noqa
+                " _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(\n"
                 " reinterpret_cast<const __m128i*>(&ip[j])))),\n"
                 " _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));"
             )
         else:
-            assert False
+            raise AssertionError
         code.append(
             " _mm_prefetch(\n"
@@ -257,7 +257,6 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
     if InType == "at::BFloat16":
         code.append(" alignas(64) at::BFloat16 vtmp1[8] = {0};")
-
     if use_offsets:
         code.append(
             " for ("
@@ -295,7 +294,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
         code.append(
             " for ("
             + "int64_t"
-            + " start = dataInd; dataInd < end_offset - offsets[0];\n ++dataInd) {"  # noqa
+            + " start = dataInd; dataInd < end_offset - offsets[0];\n ++dataInd) {"
         )
     else:
         code.append(
@@ -306,7 +305,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
         code.append(
             " for ("
             + IndexType
-            + " start = dataInd; dataInd < start + lengths[rangeIndex];\n ++dataInd) {"  # noqa
+            + " start = dataInd; dataInd < start + lengths[rangeIndex];\n ++dataInd) {"
         )
     code.append(" const " + IndexType + " idx = indices[dataInd];")
     code.append(
@@ -321,7 +320,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
         code.append(" " + OutType + " bio;")
         code.append(" if (weights) {")
         code.append(
-            " wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"  # noqa
+            " wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"
         )
         code.append(" }")
     if fused:
@@ -339,7 +338,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
         code.append(" " + OutType + " wgt = 1.f;")
         code.append(" if (weights) {")
         code.append(
-            " wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"  # noqa
+            " wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"
        )
         code.append(" }")
         code.append(" __m256 vwgt = _mm256_set1_ps(wgt);")
@@ -390,7 +389,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
         elif InType == "uint8_t":
             code.append(" op[j] = std::fma(wgt, (float)ip[j], bio + op[j]);")
         else:
-            assert False
+            raise AssertionError
         code.append(" }")
@@ -496,13 +495,13 @@ for o in options:
     code += args
     code.append(" const " + IndexType + " prefdist_T0 = 16;")
-    code.append(" // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)")
+    code.append(
+        " // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)"
+    )
     # block_size is the number of elements and fused_block_size is the size of
     # an entire row, including scale and bias.
     offset = (8 // sizeof[InType]) if opts.fused else 0
-    code.append(
-        f" const {IndexType} fused_block_size = block_size + {offset};"
-    )
+    code.append(f" const {IndexType} fused_block_size = block_size + {offset};")
     if opts.use_offsets:
         code.append(" int64_t dataInd = 0;")
     else:
@@ -511,17 +510,29 @@ for o in options:
     # code.append("printf(\"calling " + fn + "\\n\");");
     code.append(" if (block_size == 128) {")
-    code += unroll(16, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
+    code += unroll(
+        16, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets
+    )
     code.append(" } else if (block_size == 64) {")
-    code += unroll(8, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
+    code += unroll(
+        8, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets
+    )
     code.append(" } else if (block_size == 32) {")
-    code += unroll(4, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
+    code += unroll(
+        4, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets
+    )
     code.append(" } else if (block_size == 16) {")
-    code += unroll(2, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
+    code += unroll(
+        2, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets
+    )
     code.append(" } else {")
     code.append(" // generic code")
-    code.append(" // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)")
-    code += generic(IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
+    code.append(
+        " // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)"
+    )
+    code += generic(
+        IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets
+    )
     code.append(" }")
     code.append(" return dataInd == index_size;")
@@ -558,7 +569,7 @@ for o in options:
 code.append("} // namespace caffe2")
-with open(filename, "w") as fout:
+with open(filename, "w", encoding="utf8") as fout:
     for c in code:
         # print(c, file = fout)
         fout.write(c + "\n")
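
For orientation: this script accumulates the generated C++ source as a flat list of strings and writes it out line by line at the end, which is why nearly every hunk above is a code.append(...) call. A condensed sketch of that flow (the option loop is elided and "generated.cc" is a placeholder name, not the script's real output):

    code = []
    code.append("  if (block_size == 128) {")
    # ... unrolled bodies for each supported block size are appended here ...
    code.append("  } else {")
    code.append("    // generic code")
    code.append("  }")
    filename = "generated.cc"  # placeholder; the script derives the real name
    with open(filename, "w", encoding="utf8") as fout:
        for c in code:
            fout.write(c + "\n")

The added encoding="utf8" in the final hunk is the only change in this file that affects runtime behavior; the rest is lint cleanup.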

View File

@@ -2,6 +2,7 @@
 import argparse
 import sys
 
+
 # Unroll loops when block_size is a multiple of vector length.
 def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
     def compute(regid, InType, use_weights):
@@ -23,7 +24,7 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
                 " svAll,\n"
                 " svreinterpret_f16_u32(svld1uh_u32(\n"
                 " svAll, reinterpret_cast<const uint16_t*>("
-                f"&ip[{regid} * vLen])))),\n"  # noqa
+                f"&ip[{regid} * vLen])))),\n"
                 f" vsum{regid});"
             )
         elif InType == "at::BFloat16":
@@ -36,7 +37,7 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
                 " svld1uh_u32(\n"
                 " svAll, reinterpret_cast<const uint16_t*>("
                 f"&ip[{regid} * vLen])),\n"
-                " 16)),\n"  # noqa
+                " 16)),\n"
                 f" vsum{regid});"
             )
         elif InType == "uint8_t":
@@ -45,11 +46,11 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
                 " svAll,\n"
                 " vwgt,\n"
                 " svcvt_f32_u32_x(svAll,"
-                f" svld1ub_u32(svAll, &ip[{regid} * vLen])),\n"  # noqa
+                f" svld1ub_u32(svAll, &ip[{regid} * vLen])),\n"
                 f" svadd_f32_x(svAll, vsum{regid}, vbio));"
             )
         else:
-            raise ValueError(f"Unknown datatype \"{InType}\"")
+            raise ValueError(f'Unknown datatype "{InType}"')
 
         return code
@@ -74,9 +75,7 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
        int64_t start_offset = offsets[i];
        int64_t end_offset = offsets[i + 1];""")
     code.append(
-        " for ("
-        + "int64_t"
-        + " j = start_offset; j < end_offset; ++j) {"  # noqa
+        " for (" + "int64_t" + " j = start_offset; j < end_offset; ++j) {"
     )
     code.append(" const auto idx = indices[pos];")
@@ -91,7 +90,7 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
         code.append(" " + OutType + " bio{};")
         code.append(" if (weights) {")
         code.append(
-            " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];"  # noqa
+            " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];"
         )
         code.append(" }")
         code.append(" if (scale_bias) {")
@@ -103,7 +102,7 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
         code.append(" " + OutType + " wgt = 1.f;")
         code.append(" if (weights) {")
         code.append(
-            " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];"  # noqa
+            " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];"
         )
         code.append(" }")
@@ -124,8 +123,10 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
         code.append(" const svfloat32_t vlen_inv = svdup_n_f32(len_inv);")
         for i in range(num_unrolls):
-            code.append(f" svst1_f32(svAll, &op[{i} * vLen],"
-                + f" svmul_f32_x(svAll, vsum{i}, vlen_inv));")
+            code.append(
+                f" svst1_f32(svAll, &op[{i} * vLen],"
+                + f" svmul_f32_x(svAll, vsum{i}, vlen_inv));"
+            )
         code.append(" } else {")
         # inv of length
@@ -190,20 +191,18 @@ def generic(IndexType, InType, OutType, use_weights):
                 " pg,\n"
                 " vwgt,\n"
                 " svcvt_f32_u32_x(pg,"
-                " svld1ub_u32(pg, &ip[k])),\n"  # noqa
+                " svld1ub_u32(pg, &ip[k])),\n"
                 " svadd_f32_x(pg,"
                 " svld1_f32(pg, &op[k]), vbio)));"
             )
         else:
-            raise ValueError(f"Unknown datatype \"{InType}\"")
+            raise ValueError(f'Unknown datatype "{InType}"')
 
         return code
 
     code = []
-    code.append(
-        " for (int64_t i = 0; i < output_size; ++i) {"
-    )
+    code.append(" for (int64_t i = 0; i < output_size; ++i) {")
     code.append(" " + OutType + "* const op = &out[i * block_size];")
@@ -221,9 +220,7 @@ def generic(IndexType, InType, OutType, use_weights):
         + " int64_t end_offset = offsets[i + 1];"
     )
     code.append(
-        " for ("
-        + "int64_t"
-        + " j = start_offset; j < end_offset; ++j) {"  # noqa
+        " for (" + "int64_t" + " j = start_offset; j < end_offset; ++j) {"
     )
     code.append(" const auto idx = indices[pos];")
@@ -239,7 +236,7 @@ def generic(IndexType, InType, OutType, use_weights):
         code.append(" " + OutType + " bio{};")
         code.append(" if (weights) {")
         code.append(
-            " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];"  # noqa
+            " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];"
         )
         code.append(" }")
         code.append(" if (scale_bias) {")
@@ -251,7 +248,7 @@ def generic(IndexType, InType, OutType, use_weights):
         code.append(" " + OutType + " wgt = 1.f;")
         code.append(" if (weights) {")
         code.append(
-            " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];"  # noqa
+            " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];"
         )
         code.append(" }")
@@ -261,8 +258,9 @@ def generic(IndexType, InType, OutType, use_weights):
     # compute and store main loop
     code.append(" svbool_t pg;")
     code.append(" for (int64_t k = 0;")
-    code.append(" svptest_first(svAll, pg = svwhilelt_b32_s64("
-        + "k, block_size));")
+    code.append(
+        " svptest_first(svAll, pg = svwhilelt_b32_s64(" + "k, block_size));"
+    )
     code.append(" k += vLen) {")
     code.extend(compute(InType, use_weights))
     code.append(" }\n")
@@ -274,9 +272,11 @@ def generic(IndexType, InType, OutType, use_weights):
     code.append(" const float len_inv = 1.0f / length;")
     code.append(" svfloat32_t vlen_inv = svdup_n_f32(len_inv);")
     code.append(" svbool_t pg;")
-    code.append(" for (int64_t j = 0;\n"
-        " svptest_first(svAll, pg = svwhilelt_b32_s64("
-        "j, block_size));")
+    code.append(
+        " for (int64_t j = 0;\n"
+        " svptest_first(svAll, pg = svwhilelt_b32_s64("
+        "j, block_size));"
+    )
     code.append(" j += vLen) {")
     code.append(
         " svst1_f32(\n"
@@ -287,6 +287,7 @@ def generic(IndexType, InType, OutType, use_weights):
     code.append(" }")
     return code
 
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("-f", "--filename", help="file name")
@@ -375,12 +376,21 @@ def main():
     # Resolve the Lint warnings: Limit of 80 characters in one line.
     extra_space = "\n "
-    ret_string = " return " + fn_base + suffix \
-        + "<" + is_weight_positional + ">("
+    ret_string = (
+        " return " + fn_base + suffix + "<" + is_weight_positional + ">("
+    )
     if len(ret_string) <= 80:
         code.append(ret_string)
     else:
-        code.append(" return " + fn_base + suffix + "<" + extra_space + is_weight_positional + ">(")
+        code.append(
+            " return "
+            + fn_base
+            + suffix
+            + "<"
+            + extra_space
+            + is_weight_positional
+            + ">("
+        )
     code.append(" block_size,")
     code.append(" output_size,")
@@ -404,5 +414,6 @@ def main():
     print("Created " + filename)
 
+
 if __name__ == "__main__":
     main()
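
One fix in this file swaps escaped double quotes inside a double-quoted f-string for an outer single-quoted string; ruff's Q003 (avoidable escaped quote) is a plausible trigger, though that is an inference from the diff. A small illustration:

    InType = "unknown"
    if InType not in ("float", "at::Half", "at::BFloat16", "uint8_t"):
        # Q003 flags the escaped form: f"Unknown datatype \"{InType}\""
        raise ValueError(f'Unknown datatype "{InType}"')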

View File

@@ -1,6 +1,7 @@
 import dis
 import inspect
-from typing import Sequence, Union
+from collections.abc import Sequence
+from typing import Union
 
 import functorch._C
 import torch

View File

@@ -27,7 +27,11 @@ from __future__ import annotations
 
 import keyword
 import warnings
-from typing import Collection, List, Mapping, Optional, Set, Tuple, Union
+from typing import List, Optional, Set, Tuple, TYPE_CHECKING, Union
+
+if TYPE_CHECKING:
+    from collections.abc import Collection, Mapping
+
 
 _ellipsis: str = "\u2026"  # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated

View File

@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import functools
-from typing import Callable, Dict, List, Sequence, Tuple, Union
+from typing import Callable, Dict, List, Tuple, TYPE_CHECKING, Union
 
 import torch
 from functorch._C import dim as _C
@@ -15,6 +15,10 @@ from ._parsing import (
 )
 
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
 __all__ = ["rearrange"]
 
 dims = _C.dims
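
The functorch hunks above move ABC imports from typing to collections.abc and, where a name is used only in annotations, guard the import behind TYPE_CHECKING so it costs nothing at runtime. A minimal sketch of the pattern, using a hypothetical function that is not taken from the diff:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from collections.abc import Sequence  # seen only by the type checker

    def head(xs: Sequence[int]) -> int:
        # The __future__ import keeps annotations lazy, so Sequence
        # does not need to be importable at runtime.
        return xs[0]

The guard is safe here because the two files that use it already begin with from __future__ import annotations.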