From f9bf9057efc1fa881be9973bfd81268839e762db Mon Sep 17 00:00:00 2001 From: cyy Date: Sat, 4 Jan 2025 04:15:01 +0000 Subject: [PATCH] Fix ruff warnings in caffe2 and functorch (#144182) In preparation for upgrading ruff config to py3.9. Pull Request resolved: https://github.com/pytorch/pytorch/pull/144182 Approved by: https://github.com/malfet --- caffe2/perfkernels/hp_emblookup_codegen.py | 75 ++++++++++++--------- caffe2/perfkernels/sve_emblookup_codegen.py | 69 +++++++++++-------- functorch/dim/__init__.py | 3 +- functorch/einops/_parsing.py | 6 +- functorch/einops/rearrange.py | 6 +- 5 files changed, 95 insertions(+), 64 deletions(-) diff --git a/caffe2/perfkernels/hp_emblookup_codegen.py b/caffe2/perfkernels/hp_emblookup_codegen.py index a3825a599d11..ca6da8ef4ffc 100644 --- a/caffe2/perfkernels/hp_emblookup_codegen.py +++ b/caffe2/perfkernels/hp_emblookup_codegen.py @@ -14,15 +14,14 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets) if InType == "float": code.append( - f" vop{regid:d} = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + ({regid:d})), vop{regid:d});" # noqa - + f" vop{regid:d} = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + ({regid:d})), vop{regid:d});" ) elif InType == "at::Half": code.append( f" vop{regid:d} = _mm256_fmadd_ps(\n" " vwgt,\n" " _mm256_cvtph_ps(\n" - f" _mm_loadu_si128(reinterpret_cast(ip + ({regid:d})))),\n" # noqa + f" _mm_loadu_si128(reinterpret_cast(ip + ({regid:d})))),\n" f" vop{regid:d});" ) elif InType == "at::BFloat16": @@ -32,7 +31,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets) " _mm256_castsi256_ps(_mm256_slli_epi32(\n" " _mm256_cvtepu16_epi32(_mm_loadu_si128(\n" f" reinterpret_cast(ip + ({regid:d})))),\n" - " 16)),\n" # noqa + " 16)),\n" f" vop{regid:d});" ) elif InType == "uint8_t": @@ -40,17 +39,16 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets) f" vop{regid:d} = _mm256_fmadd_ps(\n" " vwgt,\n" " _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(\n" - f" _mm_loadl_epi64(reinterpret_cast(ip + ({regid:d}))))),\n" # noqa + f" _mm_loadl_epi64(reinterpret_cast(ip + ({regid:d}))))),\n" f" _mm256_add_ps(vop{regid:d}, vbio));" ) else: - assert False + raise AssertionError if prefetch: code.append( " _mm_prefetch(\n" f" reinterpret_cast(&ip_next_T0[{regid:d}]), _MM_HINT_T0);" - ) else: code.append( @@ -93,7 +91,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets) code.append( " for (" + "int64_t" - + " start = dataInd; dataInd < end_offset - offsets[0];\n ++dataInd) {" # noqa + + " start = dataInd; dataInd < end_offset - offsets[0];\n ++dataInd) {" ) else: code.append( @@ -104,7 +102,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets) code.append( " for (" + IndexType - + " start = dataInd; dataInd < start + lengths[rangeIndex];\n ++dataInd) {" # noqa + + " start = dataInd; dataInd < start + lengths[rangeIndex];\n ++dataInd) {" ) code.append(" const " + IndexType + " idx = indices[dataInd];") code.append( @@ -119,7 +117,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets) code.append(" " + OutType + " bio;") code.append(" if (weights) {") code.append( - " wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];" # noqa + " wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];" ) code.append(" }") if fused: @@ -137,7 +135,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets) code.append(" " + OutType + " wgt = 1.f;") code.append(" if (weights) {") code.append( - " wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];" # noqa + " wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];" ) code.append(" }") code.append(" __m256 vwgt = _mm256_set1_ps(wgt);") @@ -182,7 +180,9 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets) if use_offsets: code.append(" __m256 vlen_inv = _mm256_set1_ps(1.0f / length);") else: - code.append(" __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);") + code.append( + " __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);" + ) for i in range(0, uf): j = 8 * i code.append( @@ -207,7 +207,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets): " _mm256_storeu_ps(\n" " &op[j],\n" " _mm256_fmadd_ps(\n" - " vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));" # noqa + " vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));" ) elif InType == "at::Half": code.append( @@ -237,12 +237,12 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets): " &op[j],\n" " _mm256_fmadd_ps(\n" " vwgt,\n" - " _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(\n" # noqa + " _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(\n" " reinterpret_cast(&ip[j])))),\n" " _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));" ) else: - assert False + raise AssertionError code.append( " _mm_prefetch(\n" @@ -257,7 +257,6 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets): if InType == "at::BFloat16": code.append(" alignas(64) at::BFloat16 vtmp1[8] = {0};") - if use_offsets: code.append( " for (" @@ -295,7 +294,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets): code.append( " for (" + "int64_t" - + " start = dataInd; dataInd < end_offset - offsets[0];\n ++dataInd) {" # noqa + + " start = dataInd; dataInd < end_offset - offsets[0];\n ++dataInd) {" ) else: code.append( @@ -306,7 +305,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets): code.append( " for (" + IndexType - + " start = dataInd; dataInd < start + lengths[rangeIndex];\n ++dataInd) {" # noqa + + " start = dataInd; dataInd < start + lengths[rangeIndex];\n ++dataInd) {" ) code.append(" const " + IndexType + " idx = indices[dataInd];") code.append( @@ -321,7 +320,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets): code.append(" " + OutType + " bio;") code.append(" if (weights) {") code.append( - " wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];" # noqa + " wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];" ) code.append(" }") if fused: @@ -339,7 +338,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets): code.append(" " + OutType + " wgt = 1.f;") code.append(" if (weights) {") code.append( - " wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];" # noqa + " wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];" ) code.append(" }") code.append(" __m256 vwgt = _mm256_set1_ps(wgt);") @@ -390,7 +389,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets): elif InType == "uint8_t": code.append(" op[j] = std::fma(wgt, (float)ip[j], bio + op[j]);") else: - assert False + raise AssertionError code.append(" }") @@ -496,13 +495,13 @@ for o in options: code += args code.append(" const " + IndexType + " prefdist_T0 = 16;") - code.append(" // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)") + code.append( + " // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)" + ) # block_size is the number of elements and fused_block_size is the size of # an entire row, including scale and bias. offset = (8 // sizeof[InType]) if opts.fused else 0 - code.append( - f" const {IndexType} fused_block_size = block_size + {offset};" - ) + code.append(f" const {IndexType} fused_block_size = block_size + {offset};") if opts.use_offsets: code.append(" int64_t dataInd = 0;") else: @@ -511,17 +510,29 @@ for o in options: # code.append("printf(\"calling " + fn + "\\n\");"); code.append(" if (block_size == 128) {") - code += unroll(16, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets) + code += unroll( + 16, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets + ) code.append(" } else if (block_size == 64) {") - code += unroll(8, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets) + code += unroll( + 8, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets + ) code.append(" } else if (block_size == 32) {") - code += unroll(4, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets) + code += unroll( + 4, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets + ) code.append(" } else if (block_size == 16) {") - code += unroll(2, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets) + code += unroll( + 2, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets + ) code.append(" } else {") code.append(" // generic code") - code.append(" // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)") - code += generic(IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets) + code.append( + " // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)" + ) + code += generic( + IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets + ) code.append(" }") code.append(" return dataInd == index_size;") @@ -558,7 +569,7 @@ for o in options: code.append("} // namespace caffe2") -with open(filename, "w") as fout: +with open(filename, "w", encoding="utf8") as fout: for c in code: # print(c, file = fout) fout.write(c + "\n") diff --git a/caffe2/perfkernels/sve_emblookup_codegen.py b/caffe2/perfkernels/sve_emblookup_codegen.py index 02f010ccc250..643b614c9081 100644 --- a/caffe2/perfkernels/sve_emblookup_codegen.py +++ b/caffe2/perfkernels/sve_emblookup_codegen.py @@ -2,6 +2,7 @@ import argparse import sys + # Unroll loops when block_size is a multiple of vector length. def unroll(num_unrolls, IndexType, InType, OutType, use_weights): def compute(regid, InType, use_weights): @@ -23,7 +24,7 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights): " svAll,\n" " svreinterpret_f16_u32(svld1uh_u32(\n" " svAll, reinterpret_cast(" - f"&ip[{regid} * vLen])))),\n" # noqa + f"&ip[{regid} * vLen])))),\n" f" vsum{regid});" ) elif InType == "at::BFloat16": @@ -36,7 +37,7 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights): " svld1uh_u32(\n" " svAll, reinterpret_cast(" f"&ip[{regid} * vLen])),\n" - " 16)),\n" # noqa + " 16)),\n" f" vsum{regid});" ) elif InType == "uint8_t": @@ -45,11 +46,11 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights): " svAll,\n" " vwgt,\n" " svcvt_f32_u32_x(svAll," - f" svld1ub_u32(svAll, &ip[{regid} * vLen])),\n" # noqa + f" svld1ub_u32(svAll, &ip[{regid} * vLen])),\n" f" svadd_f32_x(svAll, vsum{regid}, vbio));" ) else: - raise ValueError(f"Unknown datatype \"{InType}\"") + raise ValueError(f'Unknown datatype "{InType}"') return code @@ -74,9 +75,7 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights): int64_t start_offset = offsets[i]; int64_t end_offset = offsets[i + 1];""") code.append( - " for (" - + "int64_t" - + " j = start_offset; j < end_offset; ++j) {" # noqa + " for (" + "int64_t" + " j = start_offset; j < end_offset; ++j) {" ) code.append(" const auto idx = indices[pos];") @@ -91,7 +90,7 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights): code.append(" " + OutType + " bio{};") code.append(" if (weights) {") code.append( - " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa + " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" ) code.append(" }") code.append(" if (scale_bias) {") @@ -103,7 +102,7 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights): code.append(" " + OutType + " wgt = 1.f;") code.append(" if (weights) {") code.append( - " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa + " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" ) code.append(" }") @@ -124,8 +123,10 @@ def unroll(num_unrolls, IndexType, InType, OutType, use_weights): code.append(" const svfloat32_t vlen_inv = svdup_n_f32(len_inv);") for i in range(num_unrolls): - code.append(f" svst1_f32(svAll, &op[{i} * vLen]," - + f" svmul_f32_x(svAll, vsum{i}, vlen_inv));") + code.append( + f" svst1_f32(svAll, &op[{i} * vLen]," + + f" svmul_f32_x(svAll, vsum{i}, vlen_inv));" + ) code.append(" } else {") # inv of length @@ -190,20 +191,18 @@ def generic(IndexType, InType, OutType, use_weights): " pg,\n" " vwgt,\n" " svcvt_f32_u32_x(pg," - " svld1ub_u32(pg, &ip[k])),\n" # noqa + " svld1ub_u32(pg, &ip[k])),\n" " svadd_f32_x(pg," " svld1_f32(pg, &op[k]), vbio)));" ) else: - raise ValueError(f"Unknown datatype \"{InType}\"") + raise ValueError(f'Unknown datatype "{InType}"') return code code = [] - code.append( - " for (int64_t i = 0; i < output_size; ++i) {" - ) + code.append(" for (int64_t i = 0; i < output_size; ++i) {") code.append(" " + OutType + "* const op = &out[i * block_size];") @@ -221,9 +220,7 @@ def generic(IndexType, InType, OutType, use_weights): + " int64_t end_offset = offsets[i + 1];" ) code.append( - " for (" - + "int64_t" - + " j = start_offset; j < end_offset; ++j) {" # noqa + " for (" + "int64_t" + " j = start_offset; j < end_offset; ++j) {" ) code.append(" const auto idx = indices[pos];") @@ -239,7 +236,7 @@ def generic(IndexType, InType, OutType, use_weights): code.append(" " + OutType + " bio{};") code.append(" if (weights) {") code.append( - " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa + " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" ) code.append(" }") code.append(" if (scale_bias) {") @@ -251,7 +248,7 @@ def generic(IndexType, InType, OutType, use_weights): code.append(" " + OutType + " wgt = 1.f;") code.append(" if (weights) {") code.append( - " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa + " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" ) code.append(" }") @@ -261,8 +258,9 @@ def generic(IndexType, InType, OutType, use_weights): # compute and store main loop code.append(" svbool_t pg;") code.append(" for (int64_t k = 0;") - code.append(" svptest_first(svAll, pg = svwhilelt_b32_s64(" - + "k, block_size));") + code.append( + " svptest_first(svAll, pg = svwhilelt_b32_s64(" + "k, block_size));" + ) code.append(" k += vLen) {") code.extend(compute(InType, use_weights)) code.append(" }\n") @@ -274,9 +272,11 @@ def generic(IndexType, InType, OutType, use_weights): code.append(" const float len_inv = 1.0f / length;") code.append(" svfloat32_t vlen_inv = svdup_n_f32(len_inv);") code.append(" svbool_t pg;") - code.append(" for (int64_t j = 0;\n" - " svptest_first(svAll, pg = svwhilelt_b32_s64(" - "j, block_size));") + code.append( + " for (int64_t j = 0;\n" + " svptest_first(svAll, pg = svwhilelt_b32_s64(" + "j, block_size));" + ) code.append(" j += vLen) {") code.append( " svst1_f32(\n" @@ -287,6 +287,7 @@ def generic(IndexType, InType, OutType, use_weights): code.append(" }") return code + def main(): parser = argparse.ArgumentParser() parser.add_argument("-f", "--filename", help="file name") @@ -375,12 +376,21 @@ def main(): # Resolve the Lint warnings: Limit of 80 characters in one line. extra_space = "\n " - ret_string = " return " + fn_base + suffix \ - + "<" + is_weight_positional + ">(" + ret_string = ( + " return " + fn_base + suffix + "<" + is_weight_positional + ">(" + ) if len(ret_string) <= 80: code.append(ret_string) else: - code.append(" return " + fn_base + suffix + "<" + extra_space + is_weight_positional + ">(") + code.append( + " return " + + fn_base + + suffix + + "<" + + extra_space + + is_weight_positional + + ">(" + ) code.append(" block_size,") code.append(" output_size,") @@ -404,5 +414,6 @@ def main(): print("Created " + filename) + if __name__ == "__main__": main() diff --git a/functorch/dim/__init__.py b/functorch/dim/__init__.py index a6de6ad59e95..cc620c94e699 100644 --- a/functorch/dim/__init__.py +++ b/functorch/dim/__init__.py @@ -1,6 +1,7 @@ import dis import inspect -from typing import Sequence, Union +from collections.abc import Sequence +from typing import Union import functorch._C import torch diff --git a/functorch/einops/_parsing.py b/functorch/einops/_parsing.py index ee69aa60d1a5..0ef9dff72a52 100644 --- a/functorch/einops/_parsing.py +++ b/functorch/einops/_parsing.py @@ -27,7 +27,11 @@ from __future__ import annotations import keyword import warnings -from typing import Collection, List, Mapping, Optional, Set, Tuple, Union +from typing import List, Optional, Set, Tuple, TYPE_CHECKING, Union + + +if TYPE_CHECKING: + from collections.abc import Collection, Mapping _ellipsis: str = "\u2026" # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated diff --git a/functorch/einops/rearrange.py b/functorch/einops/rearrange.py index a0bceed73883..02c27f432cba 100644 --- a/functorch/einops/rearrange.py +++ b/functorch/einops/rearrange.py @@ -1,7 +1,7 @@ from __future__ import annotations import functools -from typing import Callable, Dict, List, Sequence, Tuple, Union +from typing import Callable, Dict, List, Tuple, TYPE_CHECKING, Union import torch from functorch._C import dim as _C @@ -15,6 +15,10 @@ from ._parsing import ( ) +if TYPE_CHECKING: + from collections.abc import Sequence + + __all__ = ["rearrange"] dims = _C.dims