mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix ruff warnings in caffe2 and functorch (#144182)
In preparation for upgrading ruff config to py3.9. Pull Request resolved: https://github.com/pytorch/pytorch/pull/144182 Approved by: https://github.com/malfet
This commit is contained in:
@ -14,15 +14,14 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
|
||||
|
||||
if InType == "float":
|
||||
code.append(
|
||||
f" vop{regid:d} = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + ({regid:d})), vop{regid:d});" # noqa
|
||||
|
||||
f" vop{regid:d} = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + ({regid:d})), vop{regid:d});"
|
||||
)
|
||||
elif InType == "at::Half":
|
||||
code.append(
|
||||
f" vop{regid:d} = _mm256_fmadd_ps(\n"
|
||||
" vwgt,\n"
|
||||
" _mm256_cvtph_ps(\n"
|
||||
f" _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + ({regid:d})))),\n" # noqa
|
||||
f" _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + ({regid:d})))),\n"
|
||||
f" vop{regid:d});"
|
||||
)
|
||||
elif InType == "at::BFloat16":
|
||||
@ -32,7 +31,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
|
||||
" _mm256_castsi256_ps(_mm256_slli_epi32(\n"
|
||||
" _mm256_cvtepu16_epi32(_mm_loadu_si128(\n"
|
||||
f" reinterpret_cast<const __m128i*>(ip + ({regid:d})))),\n"
|
||||
" 16)),\n" # noqa
|
||||
" 16)),\n"
|
||||
f" vop{regid:d});"
|
||||
)
|
||||
elif InType == "uint8_t":
|
||||
@ -40,17 +39,16 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
|
||||
f" vop{regid:d} = _mm256_fmadd_ps(\n"
|
||||
" vwgt,\n"
|
||||
" _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(\n"
|
||||
f" _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + ({regid:d}))))),\n" # noqa
|
||||
f" _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + ({regid:d}))))),\n"
|
||||
f" _mm256_add_ps(vop{regid:d}, vbio));"
|
||||
)
|
||||
else:
|
||||
assert False
|
||||
raise AssertionError
|
||||
|
||||
if prefetch:
|
||||
code.append(
|
||||
" _mm_prefetch(\n"
|
||||
f" reinterpret_cast<const char*>(&ip_next_T0[{regid:d}]), _MM_HINT_T0);"
|
||||
|
||||
)
|
||||
else:
|
||||
code.append(
|
||||
@ -93,7 +91,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
|
||||
code.append(
|
||||
" for ("
|
||||
+ "int64_t"
|
||||
+ " start = dataInd; dataInd < end_offset - offsets[0];\n ++dataInd) {" # noqa
|
||||
+ " start = dataInd; dataInd < end_offset - offsets[0];\n ++dataInd) {"
|
||||
)
|
||||
else:
|
||||
code.append(
|
||||
@ -104,7 +102,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
|
||||
code.append(
|
||||
" for ("
|
||||
+ IndexType
|
||||
+ " start = dataInd; dataInd < start + lengths[rangeIndex];\n ++dataInd) {" # noqa
|
||||
+ " start = dataInd; dataInd < start + lengths[rangeIndex];\n ++dataInd) {"
|
||||
)
|
||||
code.append(" const " + IndexType + " idx = indices[dataInd];")
|
||||
code.append(
|
||||
@ -119,7 +117,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
|
||||
code.append(" " + OutType + " bio;")
|
||||
code.append(" if (weights) {")
|
||||
code.append(
|
||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];" # noqa
|
||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"
|
||||
)
|
||||
code.append(" }")
|
||||
if fused:
|
||||
@ -137,7 +135,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
|
||||
code.append(" " + OutType + " wgt = 1.f;")
|
||||
code.append(" if (weights) {")
|
||||
code.append(
|
||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];" # noqa
|
||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"
|
||||
)
|
||||
code.append(" }")
|
||||
code.append(" __m256 vwgt = _mm256_set1_ps(wgt);")
|
||||
@ -182,7 +180,9 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
|
||||
if use_offsets:
|
||||
code.append(" __m256 vlen_inv = _mm256_set1_ps(1.0f / length);")
|
||||
else:
|
||||
code.append(" __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);")
|
||||
code.append(
|
||||
" __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);"
|
||||
)
|
||||
for i in range(0, uf):
|
||||
j = 8 * i
|
||||
code.append(
|
||||
@ -207,7 +207,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
|
||||
" _mm256_storeu_ps(\n"
|
||||
" &op[j],\n"
|
||||
" _mm256_fmadd_ps(\n"
|
||||
" vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));" # noqa
|
||||
" vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));"
|
||||
)
|
||||
elif InType == "at::Half":
|
||||
code.append(
|
||||
@ -237,12 +237,12 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
|
||||
" &op[j],\n"
|
||||
" _mm256_fmadd_ps(\n"
|
||||
" vwgt,\n"
|
||||
" _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(\n" # noqa
|
||||
" _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(\n"
|
||||
" reinterpret_cast<const __m128i*>(&ip[j])))),\n"
|
||||
" _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));"
|
||||
)
|
||||
else:
|
||||
assert False
|
||||
raise AssertionError
|
||||
|
||||
code.append(
|
||||
" _mm_prefetch(\n"
|
||||
@ -257,7 +257,6 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
|
||||
if InType == "at::BFloat16":
|
||||
code.append(" alignas(64) at::BFloat16 vtmp1[8] = {0};")
|
||||
|
||||
|
||||
if use_offsets:
|
||||
code.append(
|
||||
" for ("
|
||||
@ -295,7 +294,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
|
||||
code.append(
|
||||
" for ("
|
||||
+ "int64_t"
|
||||
+ " start = dataInd; dataInd < end_offset - offsets[0];\n ++dataInd) {" # noqa
|
||||
+ " start = dataInd; dataInd < end_offset - offsets[0];\n ++dataInd) {"
|
||||
)
|
||||
else:
|
||||
code.append(
|
||||
@ -306,7 +305,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
|
||||
code.append(
|
||||
" for ("
|
||||
+ IndexType
|
||||
+ " start = dataInd; dataInd < start + lengths[rangeIndex];\n ++dataInd) {" # noqa
|
||||
+ " start = dataInd; dataInd < start + lengths[rangeIndex];\n ++dataInd) {"
|
||||
)
|
||||
code.append(" const " + IndexType + " idx = indices[dataInd];")
|
||||
code.append(
|
||||
@ -321,7 +320,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
|
||||
code.append(" " + OutType + " bio;")
|
||||
code.append(" if (weights) {")
|
||||
code.append(
|
||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];" # noqa
|
||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"
|
||||
)
|
||||
code.append(" }")
|
||||
if fused:
|
||||
@ -339,7 +338,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
|
||||
code.append(" " + OutType + " wgt = 1.f;")
|
||||
code.append(" if (weights) {")
|
||||
code.append(
|
||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];" # noqa
|
||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"
|
||||
)
|
||||
code.append(" }")
|
||||
code.append(" __m256 vwgt = _mm256_set1_ps(wgt);")
|
||||
@ -390,7 +389,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
|
||||
elif InType == "uint8_t":
|
||||
code.append(" op[j] = std::fma(wgt, (float)ip[j], bio + op[j]);")
|
||||
else:
|
||||
assert False
|
||||
raise AssertionError
|
||||
|
||||
code.append(" }")
|
||||
|
||||
@ -496,13 +495,13 @@ for o in options:
|
||||
code += args
|
||||
|
||||
code.append(" const " + IndexType + " prefdist_T0 = 16;")
|
||||
code.append(" // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)")
|
||||
code.append(
|
||||
" // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)"
|
||||
)
|
||||
# block_size is the number of elements and fused_block_size is the size of
|
||||
# an entire row, including scale and bias.
|
||||
offset = (8 // sizeof[InType]) if opts.fused else 0
|
||||
code.append(
|
||||
f" const {IndexType} fused_block_size = block_size + {offset};"
|
||||
)
|
||||
code.append(f" const {IndexType} fused_block_size = block_size + {offset};")
|
||||
if opts.use_offsets:
|
||||
code.append(" int64_t dataInd = 0;")
|
||||
else:
|
||||
@ -511,17 +510,29 @@ for o in options:
|
||||
# code.append("printf(\"calling " + fn + "\\n\");");
|
||||
|
||||
code.append(" if (block_size == 128) {")
|
||||
code += unroll(16, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
|
||||
code += unroll(
|
||||
16, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets
|
||||
)
|
||||
code.append(" } else if (block_size == 64) {")
|
||||
code += unroll(8, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
|
||||
code += unroll(
|
||||
8, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets
|
||||
)
|
||||
code.append(" } else if (block_size == 32) {")
|
||||
code += unroll(4, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
|
||||
code += unroll(
|
||||
4, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets
|
||||
)
|
||||
code.append(" } else if (block_size == 16) {")
|
||||
code += unroll(2, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
|
||||
code += unroll(
|
||||
2, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets
|
||||
)
|
||||
code.append(" } else {")
|
||||
code.append(" // generic code")
|
||||
code.append(" // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)")
|
||||
code += generic(IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
|
||||
code.append(
|
||||
" // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)"
|
||||
)
|
||||
code += generic(
|
||||
IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets
|
||||
)
|
||||
code.append(" }")
|
||||
code.append(" return dataInd == index_size;")
|
||||
|
||||
@ -558,7 +569,7 @@ for o in options:
|
||||
|
||||
code.append("} // namespace caffe2")
|
||||
|
||||
with open(filename, "w") as fout:
|
||||
with open(filename, "w", encoding="utf8") as fout:
|
||||
for c in code:
|
||||
# print(c, file = fout)
|
||||
fout.write(c + "\n")
|
||||
|
@ -2,6 +2,7 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
|
||||
# Unroll loops when block_size is a multiple of vector length.
|
||||
def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
|
||||
def compute(regid, InType, use_weights):
|
||||
@ -23,7 +24,7 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
|
||||
" svAll,\n"
|
||||
" svreinterpret_f16_u32(svld1uh_u32(\n"
|
||||
" svAll, reinterpret_cast<const uint16_t*>("
|
||||
f"&ip[{regid} * vLen])))),\n" # noqa
|
||||
f"&ip[{regid} * vLen])))),\n"
|
||||
f" vsum{regid});"
|
||||
)
|
||||
elif InType == "at::BFloat16":
|
||||
@ -36,7 +37,7 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
|
||||
" svld1uh_u32(\n"
|
||||
" svAll, reinterpret_cast<const uint16_t*>("
|
||||
f"&ip[{regid} * vLen])),\n"
|
||||
" 16)),\n" # noqa
|
||||
" 16)),\n"
|
||||
f" vsum{regid});"
|
||||
)
|
||||
elif InType == "uint8_t":
|
||||
@ -45,11 +46,11 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
|
||||
" svAll,\n"
|
||||
" vwgt,\n"
|
||||
" svcvt_f32_u32_x(svAll,"
|
||||
f" svld1ub_u32(svAll, &ip[{regid} * vLen])),\n" # noqa
|
||||
f" svld1ub_u32(svAll, &ip[{regid} * vLen])),\n"
|
||||
f" svadd_f32_x(svAll, vsum{regid}, vbio));"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown datatype \"{InType}\"")
|
||||
raise ValueError(f'Unknown datatype "{InType}"')
|
||||
|
||||
return code
|
||||
|
||||
@ -74,9 +75,7 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
|
||||
int64_t start_offset = offsets[i];
|
||||
int64_t end_offset = offsets[i + 1];""")
|
||||
code.append(
|
||||
" for ("
|
||||
+ "int64_t"
|
||||
+ " j = start_offset; j < end_offset; ++j) {" # noqa
|
||||
" for (" + "int64_t" + " j = start_offset; j < end_offset; ++j) {"
|
||||
)
|
||||
|
||||
code.append(" const auto idx = indices[pos];")
|
||||
@ -91,7 +90,7 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
|
||||
code.append(" " + OutType + " bio{};")
|
||||
code.append(" if (weights) {")
|
||||
code.append(
|
||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa
|
||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];"
|
||||
)
|
||||
code.append(" }")
|
||||
code.append(" if (scale_bias) {")
|
||||
@ -103,7 +102,7 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
|
||||
code.append(" " + OutType + " wgt = 1.f;")
|
||||
code.append(" if (weights) {")
|
||||
code.append(
|
||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa
|
||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];"
|
||||
)
|
||||
code.append(" }")
|
||||
|
||||
@ -124,8 +123,10 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights):
|
||||
code.append(" const svfloat32_t vlen_inv = svdup_n_f32(len_inv);")
|
||||
|
||||
for i in range(num_unrolls):
|
||||
code.append(f" svst1_f32(svAll, &op[{i} * vLen],"
|
||||
+ f" svmul_f32_x(svAll, vsum{i}, vlen_inv));")
|
||||
code.append(
|
||||
f" svst1_f32(svAll, &op[{i} * vLen],"
|
||||
+ f" svmul_f32_x(svAll, vsum{i}, vlen_inv));"
|
||||
)
|
||||
|
||||
code.append(" } else {")
|
||||
# inv of length
|
||||
@ -190,20 +191,18 @@ def generic(IndexType, InType, OutType, use_weights):
|
||||
" pg,\n"
|
||||
" vwgt,\n"
|
||||
" svcvt_f32_u32_x(pg,"
|
||||
" svld1ub_u32(pg, &ip[k])),\n" # noqa
|
||||
" svld1ub_u32(pg, &ip[k])),\n"
|
||||
" svadd_f32_x(pg,"
|
||||
" svld1_f32(pg, &op[k]), vbio)));"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown datatype \"{InType}\"")
|
||||
raise ValueError(f'Unknown datatype "{InType}"')
|
||||
|
||||
return code
|
||||
|
||||
code = []
|
||||
|
||||
code.append(
|
||||
" for (int64_t i = 0; i < output_size; ++i) {"
|
||||
)
|
||||
code.append(" for (int64_t i = 0; i < output_size; ++i) {")
|
||||
|
||||
code.append(" " + OutType + "* const op = &out[i * block_size];")
|
||||
|
||||
@ -221,9 +220,7 @@ def generic(IndexType, InType, OutType, use_weights):
|
||||
+ " int64_t end_offset = offsets[i + 1];"
|
||||
)
|
||||
code.append(
|
||||
" for ("
|
||||
+ "int64_t"
|
||||
+ " j = start_offset; j < end_offset; ++j) {" # noqa
|
||||
" for (" + "int64_t" + " j = start_offset; j < end_offset; ++j) {"
|
||||
)
|
||||
|
||||
code.append(" const auto idx = indices[pos];")
|
||||
@ -239,7 +236,7 @@ def generic(IndexType, InType, OutType, use_weights):
|
||||
code.append(" " + OutType + " bio{};")
|
||||
code.append(" if (weights) {")
|
||||
code.append(
|
||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa
|
||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];"
|
||||
)
|
||||
code.append(" }")
|
||||
code.append(" if (scale_bias) {")
|
||||
@ -251,7 +248,7 @@ def generic(IndexType, InType, OutType, use_weights):
|
||||
code.append(" " + OutType + " wgt = 1.f;")
|
||||
code.append(" if (weights) {")
|
||||
code.append(
|
||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa
|
||||
" wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];"
|
||||
)
|
||||
code.append(" }")
|
||||
|
||||
@ -261,8 +258,9 @@ def generic(IndexType, InType, OutType, use_weights):
|
||||
# compute and store main loop
|
||||
code.append(" svbool_t pg;")
|
||||
code.append(" for (int64_t k = 0;")
|
||||
code.append(" svptest_first(svAll, pg = svwhilelt_b32_s64("
|
||||
+ "k, block_size));")
|
||||
code.append(
|
||||
" svptest_first(svAll, pg = svwhilelt_b32_s64(" + "k, block_size));"
|
||||
)
|
||||
code.append(" k += vLen) {")
|
||||
code.extend(compute(InType, use_weights))
|
||||
code.append(" }\n")
|
||||
@ -274,9 +272,11 @@ def generic(IndexType, InType, OutType, use_weights):
|
||||
code.append(" const float len_inv = 1.0f / length;")
|
||||
code.append(" svfloat32_t vlen_inv = svdup_n_f32(len_inv);")
|
||||
code.append(" svbool_t pg;")
|
||||
code.append(" for (int64_t j = 0;\n"
|
||||
" svptest_first(svAll, pg = svwhilelt_b32_s64("
|
||||
"j, block_size));")
|
||||
code.append(
|
||||
" for (int64_t j = 0;\n"
|
||||
" svptest_first(svAll, pg = svwhilelt_b32_s64("
|
||||
"j, block_size));"
|
||||
)
|
||||
code.append(" j += vLen) {")
|
||||
code.append(
|
||||
" svst1_f32(\n"
|
||||
@ -287,6 +287,7 @@ def generic(IndexType, InType, OutType, use_weights):
|
||||
code.append(" }")
|
||||
return code
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-f", "--filename", help="file name")
|
||||
@ -375,12 +376,21 @@ def main():
|
||||
|
||||
# Resolve the Lint warnings: Limit of 80 characters in one line.
|
||||
extra_space = "\n "
|
||||
ret_string = " return " + fn_base + suffix \
|
||||
+ "<" + is_weight_positional + ">("
|
||||
ret_string = (
|
||||
" return " + fn_base + suffix + "<" + is_weight_positional + ">("
|
||||
)
|
||||
if len(ret_string) <= 80:
|
||||
code.append(ret_string)
|
||||
else:
|
||||
code.append(" return " + fn_base + suffix + "<" + extra_space + is_weight_positional + ">(")
|
||||
code.append(
|
||||
" return "
|
||||
+ fn_base
|
||||
+ suffix
|
||||
+ "<"
|
||||
+ extra_space
|
||||
+ is_weight_positional
|
||||
+ ">("
|
||||
)
|
||||
|
||||
code.append(" block_size,")
|
||||
code.append(" output_size,")
|
||||
@ -404,5 +414,6 @@ def main():
|
||||
|
||||
print("Created " + filename)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -1,6 +1,7 @@
|
||||
import dis
|
||||
import inspect
|
||||
from typing import Sequence, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
import functorch._C
|
||||
import torch
|
||||
|
@ -27,7 +27,11 @@ from __future__ import annotations
|
||||
|
||||
import keyword
|
||||
import warnings
|
||||
from typing import Collection, List, Mapping, Optional, Set, Tuple, Union
|
||||
from typing import List, Optional, Set, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Collection, Mapping
|
||||
|
||||
|
||||
_ellipsis: str = "\u2026" # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated
|
||||
|
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from typing import Callable, Dict, List, Sequence, Tuple, Union
|
||||
from typing import Callable, Dict, List, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from functorch._C import dim as _C
|
||||
@ -15,6 +15,10 @@ from ._parsing import (
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
__all__ = ["rearrange"]
|
||||
|
||||
dims = _C.dims
|
||||
|
Reference in New Issue
Block a user