Expand type checking to mypy strict files (#165697)

Expands Pyrefly type checking to cover the files listed in the mypy-strict.ini configuration file.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165697
Approved by: https://github.com/ezyang
This commit is contained in:
Maggie Moss
2025-10-18 04:34:41 +00:00
committed by PyTorch MergeBot
parent 9095a9dfae
commit f02e3947f6
42 changed files with 89 additions and 11 deletions

View File

@ -5,6 +5,7 @@ python-version = "3.12"
project-includes = [
"torch",
"caffe2",
"tools",
"test/test_bundled_images.py",
"test/test_bundled_inputs.py",
"test/test_complex.py",
@ -24,8 +25,11 @@ project-excludes = [
# ==== to test Pyrefly on a specific directory, simply comment it out ====
"torch/_inductor/runtime",
"torch/_inductor/codegen/triton.py",
"tools/linter/adapters/test_device_bias_linter.py",
"tools/code_analyzer/gen_operators_yaml.py",
# Excluded because of formatting issues; will be turned back on after adjusting
# where suppressions can be placed in import statements.
"tools/flight_recorder/components/types.py",
"torch/linalg/__init__.py",
"torch/package/importer.py",
"torch/package/_package_pickler.py",
@ -40,17 +44,6 @@ project-excludes = [
"torch/distributed/elastic/metrics/__init__.py",
"torch/_inductor/fx_passes/bucketing.py",
# ====
"benchmarks/instruction_counts/main.py",
"benchmarks/instruction_counts/definitions/setup.py",
"benchmarks/instruction_counts/applications/ci.py",
"benchmarks/instruction_counts/core/api.py",
"benchmarks/instruction_counts/core/expand.py",
"benchmarks/instruction_counts/core/types.py",
"benchmarks/instruction_counts/core/utils.py",
"benchmarks/instruction_counts/definitions/standard.py",
"benchmarks/instruction_counts/definitions/setup.py",
"benchmarks/instruction_counts/execution/runner.py",
"benchmarks/instruction_counts/execution/work.py",
"torch/include/**",
"torch/csrc/**",
"torch/distributed/elastic/agent/server/api.py",
@ -137,3 +130,4 @@ errors.bad-param-name-override = false
errors.implicit-import = false
permissive-ignores = true
replace-imports-with-any = ["!sympy.printing.*", "sympy.*", "onnxscript.onnx_opset.*"]
search-path = ["tools/experimental"]

View File

@ -863,6 +863,7 @@ static PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
saved_variables.append(f"{type.cpp_type()} {name};")
if type in MISC_GETTER_DEFS:
# pyrefly: ignore # index-error
getter_def, body = MISC_GETTER_DEFS[type]
getter_definitions.append(
getter_def.substitute(op=info.op, name=name, body=body)
@ -1033,6 +1034,7 @@ static PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
unpack_ivalues = []
for typ, name in zip(apply_functional_args_ref_types, apply_functional_args):
typ = typ.removesuffix("&")
# pyrefly: ignore # bad-argument-type
unpack_ivalues.append(f"auto {name} = packed_args.unpack<{typ}>();")
schema_args = [f"std::array<bool, {len(input_name_to_idx)}>"]

View File

@ -182,6 +182,7 @@ def format_trace_inputs(f: NativeFunction) -> str:
ADD_TRACE_INPUT.substitute(
name=f.func.arguments.out[i].name, input=f.func.arguments.out[i].name
)
# pyrefly: ignore # unbound-name
for i in range(num_out_args)
]

View File

@ -1495,6 +1495,7 @@ def emit_body(
else:
expr = f"SavedVariable({var}, {str(is_output).lower()})"
if foreacharg is not None and "original_selfs" not in expr:
# pyrefly: ignore # unbound-name
expr = expr.replace(src_name, name_in_expr)
elif (
type == BaseCType(tensorListT)
@ -1844,12 +1845,14 @@ def emit_body(
)
)
cur_derivative_conditions.append(
# pyrefly: ignore # bad-argument-type
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(
req_inp=inp_name + "[i]"
)
)
else:
cur_derivative_conditions.append(
# pyrefly: ignore # bad-argument-type
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp_name)
)
@ -1920,6 +1923,7 @@ def emit_body(
unpacked_arguments += FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute(
inp_name="original_self",
inp="original_self" + input_suffix,
# pyrefly: ignore # unbound-name
zeros_fn=zeros_fn,
)
unpacked_arguments += FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute(

View File

@ -95,8 +95,11 @@ def add_view_copy_derivatives(
else:
break
# prefer manually-defined derivatives if any
# pyrefly: ignore # unbound-name
if len(view_copy_differentiability_infos) > 0 and fn_schema not in infos:
# pyrefly: ignore # unbound-name
assert fn_schema is not None
# pyrefly: ignore # unbound-name
view_infos[fn_schema] = view_copy_differentiability_infos
infos.update(view_infos)
@ -398,6 +401,7 @@ def postprocess_forward_derivatives(
for arg_name in all_arg_names:
if arg_name in diff_arg_names:
arg_name = arg_name + "_t"
# pyrefly: ignore # bad-argument-type
new_args.append(arg_name)
# TODO we are trolling
@ -938,6 +942,7 @@ def saved_variables(
+ f".sym_strides(), which returned a c10::SymIntArrayRef. formula={formula}"
)
for nctype in nctypes:
# pyrefly: ignore # bad-assignment
name = (
nctype.name.name if isinstance(nctype.name, SpecialArgName) else nctype.name
)
@ -947,6 +952,7 @@ def saved_variables(
def repl(m: re.Match[str]) -> str:
suffix: str = (
# pyrefly: ignore # bad-assignment
info["suffix"](m) if callable(info["suffix"]) else info["suffix"]
)
expr: str = info["expr"](name) if "expr" in info else m.group(0)

View File

@ -67,6 +67,7 @@ def is_intrested_file(
# ignore files that do not belong to pytorch
if platform == TestPlatform.OSS:
# pyrefly: ignore # import-error
from package.oss.utils import get_pytorch_folder
if not file_path.startswith(get_pytorch_folder()):

View File

@ -24,6 +24,7 @@ def report_download_progress(
file_size: int,
) -> None:
if file_size != -1:
# pyrefly: ignore # no-matching-overload
percent = min(1, (chunk_number * chunk_size) / file_size)
bar = "#" * int(64 * percent)
sys.stdout.write(f"\r0% |{bar:<64}| {int(percent * 100)}%")

View File

@ -105,8 +105,10 @@ def extract_info_from_keyword(source: str, kw: ast.keyword) -> Any:
evaluated_context = []
for value in kw.value.values:
if isinstance(value, ast.FormattedValue):
# pyrefly: ignore # bad-argument-type
evaluated_context.append(f"{{{ast.unparse(value.value)}}}")
elif isinstance(value, ast.Constant):
# pyrefly: ignore # bad-argument-type
evaluated_context.append(value.value)
return "".join(evaluated_context)
else:
@ -152,6 +154,7 @@ def find_unimplemented_v2_calls(
for kw in node.keywords:
if kw.arg in info:
# pyrefly: ignore # unsupported-operation
info[kw.arg] = extract_info_from_keyword(source, kw)
if info["gb_type"] is None:

View File

@ -296,6 +296,7 @@ def run_multi_process_fuzzer(
)
def write_func(msg):
# pyrefly: ignore # missing-attribute
pbar.write(msg)
else:
persist_print("Progress: (install tqdm for better progress bar)")

View File

@ -111,6 +111,7 @@ class ConstantOperator(Operator):
]:
# Clamp integer values to [0, 3] to avoid index overflow in multiplication
# Even with multiplication, indices should stay in reasonable range
# pyrefly: ignore # bad-argument-type
fill_value = max(0, min(3, abs(fill_value)))
tensor_creation = (

View File

@ -78,15 +78,22 @@ class JobConfig:
def parse_args(
    self: "JobConfig", args: Optional[Sequence[str]]
) -> argparse.Namespace:
    """Parse command-line arguments and validate option combinations.

    Args:
        args: Argument strings handed to ``self.parser.parse_args``;
            ``None`` is passed through to the parser as-is.

    Returns:
        The ``argparse.Namespace`` produced by the parser.

    Raises:
        AssertionError: If ``selected_ranks`` or ``pg_filters`` is set while
            ``just_print_entries`` is falsy.
    """
    # Keep the parsed result under its own name instead of rebinding ``args``:
    # the parameter is declared as Optional[Sequence[str]], so reusing it for
    # a Namespace is what made the type checker flag every attribute access
    # and the return statement.
    parsed = self.parser.parse_args(args)
    if parsed.selected_ranks is not None:
        assert parsed.just_print_entries, (
            "Not support selecting ranks without printing entries"
        )
    if parsed.pg_filters is not None:
        assert parsed.just_print_entries, (
            "Not support selecting pg filters without printing entries"
        )
    if parsed.verbose:
        logger.set_log_level(logging.DEBUG)
    return parsed

View File

@ -41,6 +41,7 @@ def format_frame(frame: dict[str, str]) -> str:
def format_frames(frames: list[dict[str, str]]) -> str:
    """Format every frame with ``format_frame`` and join the results.

    Args:
        frames: Frame dictionaries, each passed unchanged to ``format_frame``.

    Returns:
        The formatted frames separated by newlines, in input order; an empty
        string when ``frames`` is empty.
    """
    # Joining a generator avoids the untyped empty-list accumulator that
    # required a type-checker suppression on ``append``.
    return "\n".join(format_frame(frame) for frame in frames)
@ -695,6 +696,7 @@ def check_version(version_by_ranks: dict[str, str], version: str) -> None:
def get_version_detail(version: str) -> tuple[int, int]:
# pyrefly: ignore # bad-assignment
version = version.split(".")
assert len(version) == 2, f"Invalid version {version}"
major, minor = map(int, version)

View File

@ -40,11 +40,17 @@ from tools.flight_recorder.components.types import types
def main(args: Optional[Sequence[str]] = None) -> None:
    """Parse arguments, build the database from a trace directory, and optionally pickle it.

    Args:
        args: Argument strings forwarded to ``JobConfig.parse_args``;
            defaults to ``None``.

    Raises:
        AssertionError: If the parsed arguments have a falsy ``trace_dir``.
    """
    config = JobConfig()
    # Use a dedicated name for the parsed namespace: rebinding the
    # Optional[Sequence[str]] parameter is what forced a type-checker
    # suppression on every later use of it.
    parsed_args = config.parse_args(args)
    assert parsed_args.trace_dir, "Trace directory trace_dir is required"
    details, version = read_dir(parsed_args)
    db = build_db(details, parsed_args, version)
    if parsed_args.output:
        # The output file receives the ``types`` object together with the
        # database as a single pickled tuple.
        with open(parsed_args.output, "wb") as f:
            pickle.dump((types, db), f)

View File

@ -34,6 +34,7 @@ class TensorRepr(gdb.Command): # type: ignore[misc, no-any-unimported]
on it.
"""
# pyrefly: ignore # bad-argument-type
__doc__ = textwrap.dedent(__doc__).strip()
def __init__(self) -> None:

View File

@ -118,6 +118,7 @@ def extract_filename(path: str, keep_ext: bool = True) -> Any:
# https://gist.github.com/pypt/94d747fe5180851196eb
# pyrefly: ignore # invalid-inheritance
class UniqueKeyLoader(Loader):
def construct_mapping(self, node, deep=False): # type: ignore[no-untyped-def]
if not isinstance(node, MappingNode):
@ -233,6 +234,7 @@ def preprocess(
last_indent = input_indent
while blank_lines != 0:
# pyrefly: ignore # unbound-name
python_lines.append(python_indent + "print(file=OUT_STREAM)")
blank_lines -= 1
@ -667,6 +669,7 @@ def generateShaderDispatchStr(shader_info: ShaderInfo, name: str) -> str:
" ",
)
# pyrefly: ignore # unbound-name
return shader_dispatch_str
@ -681,15 +684,18 @@ def genCppFiles(
name = getName(spvPath).replace("_spv", "")
sizeBytes, spv_bin_str = generateSpvBinStr(spvPath, name)
# pyrefly: ignore # bad-argument-type
spv_bin_strs.append(spv_bin_str)
shader_info = getShaderInfo(srcPath)
register_shader_info_strs.append(
# pyrefly: ignore # bad-argument-type
generateShaderInfoStr(shader_info, name, sizeBytes)
)
if shader_info.register_for is not None:
# pyrefly: ignore # bad-argument-type
shader_registry_strs.append(generateShaderDispatchStr(shader_info, name))
spv_bin_arrays = "\n".join(spv_bin_strs)

View File

@ -131,12 +131,14 @@ class ComputeCodegenUnboxedKernels:
else:
arg_cpp = f"c10::IValue({arg_default})"
args_code.append(
# pyrefly: ignore # bad-argument-type
f"""c10::Argument("{arg.name}", nullptr, ::std::nullopt, {arg_cpp})"""
)
returns = f.func.returns
returns_code = []
for ret in returns:
# pyrefly: ignore # bad-argument-type
returns_code.append(f"""c10::Argument("{ret.name if ret.name else ""}")""")
return f"""
// aten::{schema}

View File

@ -112,6 +112,7 @@ class FileLinter:
first_results = None
original = replacement = pf.contents
# pyrefly: ignore # bad-assignment
while True:
try:
results = sorted(self._lint(pf), key=LintResult.sort_key)

View File

@ -41,6 +41,7 @@ class LineWithSets:
t = self.tokens[i]
after = i < len(self.tokens) - 1 and self.tokens[i + 1]
if t.string == "Set" and t.type == token.NAME:
# pyrefly: ignore # bad-return
return after and after.string == "[" and after.type == token.OP
return (
(t.string == "set" and t.type == token.NAME)

View File

@ -19,11 +19,13 @@ from typing import NamedTuple
# PyTorch directory root
def scm_root() -> str:
path = os.path.abspath(os.getcwd())
# pyrefly: ignore # bad-assignment
while True:
if os.path.exists(os.path.join(path, ".git")):
return path
if os.path.isdir(os.path.join(path, ".hg")):
return path
# pyrefly: ignore # bad-argument-type
n = len(path)
path = os.path.dirname(path)
if len(path) == n:

View File

@ -101,6 +101,7 @@ def check_dictionary(filename: str) -> list[LintMessage]:
words_set = set(words)
if len(words) != len(words_set):
raise ValueError("The dictionary file contains duplicate entries.")
# pyrefly: ignore # no-matching-overload
uncased_words = list(map(str.lower, words))
if uncased_words != sorted(uncased_words):
raise ValueError(

View File

@ -12,6 +12,7 @@ from enum import Enum
from pathlib import Path
from typing import NamedTuple
# pyrefly: ignore # import-error
import isort
import usort

View File

@ -55,6 +55,7 @@ def report_download_progress(
Pretty printer for file download progress.
"""
if file_size != -1:
# pyrefly: ignore # no-matching-overload
percent = min(1, (chunk_number * chunk_size) / file_size)
bar = "#" * int(64 * percent)
sys.stdout.write(f"\r0% |{bar:<64}| {int(percent * 100)}%")

View File

@ -15,7 +15,10 @@ import multiprocessing as mp
from enum import Enum
from typing import NamedTuple
# pyrefly: ignore # import-error
import libcst as cst
# pyrefly: ignore # import-error
import libcst.matchers as m

View File

@ -69,6 +69,7 @@ def print_lint_message(path: Path, job: dict[str, Any], sync_tag: str) -> None:
lint_message = LintMessage(
path=str(path),
# pyrefly: ignore # unbound-name
line=line_number,
char=None,
code="WORKFLOWSYNC",

View File

@ -73,6 +73,7 @@ def get_selected_kernel_dtypes_code(
for kernel_tag, dtypes in selective_builder.kernel_metadata.items():
conditions = ["scalar_type == at::ScalarType::" + x for x in dtypes]
body_parts.append(
# pyrefly: ignore # bad-argument-type
if_condition_template.substitute(
kernel_tag_name=kernel_tag,
dtype_checks=" || ".join(conditions),

View File

@ -311,6 +311,7 @@ class Venv:
python=python,
capture_output=True,
).stdout
# pyrefly: ignore # no-matching-overload
candidates = list(map(Path, filter(None, map(str.strip, output.splitlines()))))
candidates = [p for p in candidates if p.is_dir() and p.name == "site-packages"]
if not candidates:
@ -480,6 +481,7 @@ class Venv:
cmd = [str(python), *args]
env = popen_kwargs.pop("env", None) or {}
check = popen_kwargs.pop("check", True)
# pyrefly: ignore # no-matching-overload
return subprocess.run(
cmd,
check=check,
@ -531,6 +533,7 @@ class Venv:
cmd = [str(self.bindir / "uv"), *args]
env = popen_kwargs.pop("env", None) or {}
check = popen_kwargs.pop("check", True)
# pyrefly: ignore # no-matching-overload
return subprocess.run(
cmd,
check=check,
@ -938,6 +941,7 @@ def _move_single(
def _copy_files(listing: list[Path], source_dir: Path, target_dir: Path) -> None:
    """Pass each path in ``listing`` to ``_move_single`` with ``shutil.copy2``.

    Every entry is forwarded together with ``source_dir``, ``target_dir`` and
    the verb ``"Copying"``; nothing is returned.
    """
    for path in listing:
        # pyrefly: ignore # bad-argument-type
        _move_single(path, source_dir, target_dir, shutil.copy2, "Copying")

View File

@ -118,6 +118,7 @@ def download_patch(pr_number: int, repo_url: str, download_dir: str) -> str:
urllib.request.urlopen(patch_url) as response,
open(patch_file, "wb") as out_file,
):
# pyrefly: ignore # bad-specialization
shutil.copyfileobj(response, out_file)
if not os.path.isfile(patch_file):
print(f"Failed to download patch for PR #{pr_number}")

View File

@ -994,6 +994,7 @@ def add_docstr_to_hint(docstr: str, hint: str) -> str:
hint = hint.removesuffix("...").rstrip() # remove "..."
content = hint + "\n" + textwrap.indent(f'r"""\n{docstr}\n"""', prefix=" ")
# Remove trailing whitespace on each line
# pyrefly: ignore # no-matching-overload
return "\n".join(map(str.rstrip, content.splitlines())).rstrip()
# attribute or property

View File

@ -100,6 +100,7 @@ class CMake:
if ver is not None:
eprint(f"Found {cmd} ({command}) version: {ver}", end="")
cmake_versions.append(f"{cmd}=={ver}")
# pyrefly: ignore # unsupported-operation
if ver >= CMAKE_MINIMUM_VERSION:
eprint(f" (>={CMAKE_MINIMUM_VERSION})")
valid_cmake_versions[cmd] = ver

View File

@ -31,7 +31,9 @@ def gen_linker_script(
text_line_start = text_line_start[0]
# ensure that parent directory exists before writing
# pyrefly: ignore # bad-assignment
fout = Path(fout)
# pyrefly: ignore # missing-attribute
fout.parent.mkdir(parents=True, exist_ok=True)
with open(fout, "w") as f:

View File

@ -60,6 +60,7 @@ class SegmentGenerator:
df[time_col_name] = pd.to_datetime(df[time_col_name], unit="s", utc=True)
# get unique cmd names
# pyrefly: ignore # bad-argument-type
unique_cmds_df = pd.DataFrame(df[cmd_col_name].unique(), columns=[cmd_col_name])
# get all detected python cmds

View File

@ -7,6 +7,7 @@ import unittest
from collections import defaultdict
from unittest.mock import Mock, patch
# pyrefly: ignore # import-error
from gen_operators_yaml import (
fill_output,
get_parser_options,
@ -241,5 +242,6 @@ class GenOperatorsYAMLTest(unittest.TestCase):
fill_output(output, options)
# pyrefly: ignore # missing-attribute
for op_val in output["operators"].values():
self.assertFalse(op_val["include_all_overloads"])

View File

@ -88,6 +88,7 @@ operators:
self.assertTrue(selector2.is_operator_selected("aten::sub.int"))
selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
# pyrefly: ignore # bad-argument-type
["aten::add", "aten::add.int", "aten::mul.int"],
False,
False,
@ -103,6 +104,7 @@ operators:
)
selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
# pyrefly: ignore # bad-argument-type
["aten::add", "aten::add.int", "aten::mul.int"],
True,
False,
@ -118,6 +120,7 @@ operators:
)
selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
# pyrefly: ignore # bad-argument-type
["aten::add", "aten::add.int", "aten::mul.int"],
False,
True,

View File

@ -83,7 +83,9 @@ def _rank_correlated_tests(
) -> list[str]:
# Find the tests failures that are correlated with the edited files.
# Filter the list to only include tests we want to run.
# pyrefly: ignore # bad-assignment
tests_to_run = set(tests_to_run)
# pyrefly: ignore # bad-argument-type
ratings = _get_ratings_for_tests(tests_to_run)
prioritize = sorted(ratings, key=lambda x: -ratings[x])
return prioritize

View File

@ -36,11 +36,13 @@ def concated_logs() -> str:
for log_file in glob.glob(
f"{REPO_ROOT}/test/test-reports/**/*.log", recursive=True
):
# pyrefly: ignore # bad-argument-type
logs.append(f"=== {log_file} ===")
with open(log_file) as f:
# For every line, prefix with fake timestamp for log classifier
for line in f:
line = line.rstrip("\n") # Remove any trailing newline
# pyrefly: ignore # bad-argument-type
logs.append(f"2020-01-01T00:00:00.0000000Z {line}")
return "\n".join(logs)

View File

@ -1739,6 +1739,7 @@ class KernelArgs:
for outer, inner in chain(
# pyrefly: ignore # bad-argument-type
self.input_buffers.items(),
# pyrefly: ignore # bad-argument-type
self.output_buffers.items(),
):
if outer in self.inplace_buffers or isinstance(inner, RemovedArg):

View File

@ -1480,6 +1480,7 @@ class CppGemmTemplate(CppTemplate):
gemm_output_buffer = ir.Buffer(
# pyrefly: ignore # missing-attribute
name=gemm_output_name,
# pyrefly: ignore # missing-attribute
layout=template_buffer.layout,
)
current_input_buffer = gemm_output_buffer
@ -1503,6 +1504,7 @@ class CppGemmTemplate(CppTemplate):
current_input_buffer = ir.Buffer(
# pyrefly: ignore # missing-attribute
name=buffer_name,
# pyrefly: ignore # missing-attribute
layout=template_buffer.layout,
)

View File

@ -824,6 +824,7 @@ class CppWrapperGpu(CppWrapperCpu):
call_args, arg_types = self.prepare_triton_wrapper_args(
# pyrefly: ignore # bad-argument-type
call_args,
# pyrefly: ignore # bad-argument-type
arg_types,
)
wrapper_name = f"call_{kernel_name}"

View File

@ -683,6 +683,7 @@ class MetalKernel(SIMDKernel):
# pyrefly: ignore # missing-argument
t
for t in self.range_tree_nodes.values()
# pyrefly: ignore # missing-argument
if t.is_reduction
)
cmp_op = ">" if reduction_type == "argmax" else "<"
@ -865,6 +866,7 @@ class MetalKernel(SIMDKernel):
# pyrefly: ignore # missing-argument
t.numel
for t in self.range_trees
# pyrefly: ignore # missing-argument
if t.is_reduction
)
# If using dynamic shapes, set the threadgroup size to be the

View File

@ -968,6 +968,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
# pyrefly: ignore # missing-argument
t
for t in self.range_trees
# pyrefly: ignore # missing-argument
if not t.is_reduction or self.inside_reduction
]

View File

@ -1004,6 +1004,7 @@ class FxConverter:
# pyrefly: ignore # missing-attribute
call_kwargs[key]
for key in signature
# pyrefly: ignore # missing-attribute
if key not in cfg.kwargs
]

View File

@ -421,6 +421,7 @@ def get_proxy_slot(
else:
# Attempt to build it from first principles.
_build_proxy_for_sym_expr(tracer, obj.node.expr, obj)
# pyrefly: ignore # no-matching-overload
value = tracker.get(obj)
if value is None: