Files
pytorch/tools/linter/adapters/_linter/file_linter.py
Maggie Moss f02e3947f6 Expand type checking to mypy strict files (#165697)
Expands Pyrefly type checking to check the files outlined in the mypy-strict.ini configuration file:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165697
Approved by: https://github.com/ezyang
2025-10-18 04:34:45 +00:00

192 lines
6.3 KiB
Python

from __future__ import annotations
import json
import sys
from abc import abstractmethod
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING
from typing_extensions import Never
from . import ParseError
from .argument_parser import ArgumentParser
from .messages import LintResult
from .python_file import PythonFile
if TYPE_CHECKING:
from argparse import Namespace
from collections.abc import Iterator, Sequence
class ErrorLines:
"""How many lines to display before and after an error"""
WINDOW = 5
BEFORE = 2
AFTER = WINDOW - BEFORE - 1
class FileLinter:
"""The base class that all token-based linters inherit from"""
description: str
linter_name: str
epilog: str | None = None
is_fixer: bool = True
report_column_numbers: bool = False
@abstractmethod
def _lint(self, python_file: PythonFile) -> Iterator[LintResult]:
raise NotImplementedError
def __init__(self, argv: Sequence[str] | None = None) -> None:
self.argv = argv
self.parser = ArgumentParser(
is_fixer=self.is_fixer,
description=self.description,
epilog=self.epilog,
)
self.result_shown = False
@classmethod
def run(cls) -> Never:
sys.exit(not cls().lint_all())
def lint_all(self) -> bool:
if self.args.fix and self.args.lintrunner:
raise ValueError("--fix and --lintrunner are incompatible")
success = True
for p in self.paths:
success = self._lint_file(p) and success
return self.args.lintrunner or success
@classmethod
def make_file(cls, pc: Path | str | None = None) -> PythonFile:
return PythonFile.make(cls.linter_name, pc)
@cached_property
def args(self) -> Namespace:
args = self.parser.parse_args(self.argv)
return args
@cached_property
def code(self) -> str:
return self.linter_name.upper()
@cached_property
def paths(self) -> list[Path]:
files = []
file_parts = (f for fp in self.args.files for f in fp.split(":"))
for f in file_parts:
if f.startswith("@"):
files.extend(Path(f[1:]).read_text().splitlines())
elif f != "--":
files.append(f)
return sorted(Path(f) for f in files)
def _lint_file(self, p: Path) -> bool:
if self.args.verbose:
print(p, "Reading", file=sys.stderr)
pf = self.make_file(p)
replacement, results = self._replace(pf)
if display := list(self._display(pf, results)):
print(*display, sep="\n")
if results and self.args.fix and pf.path and pf.contents != replacement:
pf.path.write_text(replacement)
return not results or self.args.fix and all(r.is_edit for r in results)
def _error(self, pf: PythonFile, result: LintResult) -> None:
"""Called on files that are unparsable"""
def _replace(self, pf: PythonFile) -> tuple[str, list[LintResult]]:
# Because of recursive replacements, we need to repeat replacing and reparsing
# from the inside out until all possible replacements are complete
previous_result_count = float("inf")
first_results = None
original = replacement = pf.contents
# pyrefly: ignore # bad-assignment
while True:
try:
results = sorted(self._lint(pf), key=LintResult.sort_key)
except IndentationError as e:
error, (_name, lineno, column, _line) = e.args
results = [LintResult(error, lineno, column)]
self._error(pf, *results)
except ParseError as e:
results = [LintResult(str(e), *e.token.start)]
self._error(pf, *results)
for i, ri in enumerate(results):
if not ri.is_recursive:
for rj in results[i + 1 :]:
if ri.contains(rj):
rj.is_recursive = True
else:
break
first_results = first_results or results
if not results or len(results) >= previous_result_count:
break
previous_result_count = len(results)
lines = pf.lines[:]
for r in reversed(results):
if r.is_edit and not r.is_recursive:
r.apply(lines)
replacement = "".join(lines)
if not any(r.is_recursive for r in results):
break
pf = pf.with_contents(replacement)
if first_results and self.args.lintrunner:
name = f"Suggested fixes for {self.linter_name}"
msg = LintResult(name=name, original=original, replacement=replacement)
first_results.append(msg)
return replacement, first_results
def _display(self, pf: PythonFile, results: list[LintResult]) -> Iterator[str]:
"""Emit a series of human-readable strings representing the results"""
for r in results:
if self.args.lintrunner:
msg = r.as_message(code=self.code, path=str(pf.path))
yield json.dumps(msg.asdict(), sort_keys=True)
else:
if self.result_shown:
yield ""
else:
self.result_shown = True
if r.line is None:
yield f"{pf.path}: {r.name}"
else:
yield from (i.rstrip() for i in self._display_window(pf, r))
def _display_window(self, pf: PythonFile, r: LintResult) -> Iterator[str]:
"""Display a window onto the code with an error"""
if r.char is None or not self.report_column_numbers:
yield f"{pf.path}:{r.line}: {r.name}"
else:
yield f"{pf.path}:{r.line}:{r.char + 1}: {r.name}"
begin = max((r.line or 0) - ErrorLines.BEFORE, 1)
end = min(begin + ErrorLines.WINDOW, 1 + len(pf.lines))
for lineno in range(begin, end):
source_line = pf.lines[lineno - 1].rstrip()
yield f"{lineno:5} | {source_line}"
if lineno == r.line:
spaces = 8 + (r.char or 0)
carets = len(source_line) if r.char is None else (r.length or 1)
yield spaces * " " + carets * "^"