mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Expands Pyrefly type checking to check the files outlined in the mypy-strict.ini configuration file. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165697. Approved by: https://github.com/ezyang
192 lines
6.3 KiB
Python
192 lines
6.3 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import sys
|
|
from abc import abstractmethod
|
|
from functools import cached_property
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING
|
|
from typing_extensions import Never
|
|
|
|
from . import ParseError
|
|
from .argument_parser import ArgumentParser
|
|
from .messages import LintResult
|
|
from .python_file import PythonFile
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from argparse import Namespace
|
|
from collections.abc import Iterator, Sequence
|
|
|
|
|
|
class ErrorLines:
    """Line counts for the source excerpt displayed before and after an error."""

    # Total number of source lines in the excerpt, and how many of them
    # come before the line carrying the error.
    WINDOW, BEFORE = 5, 2

    # Whatever remains of the window once the leading context and the
    # error line itself have been counted.
    AFTER = WINDOW - 1 - BEFORE
|
|
|
|
|
|
class FileLinter:
|
|
"""The base class that all token-based linters inherit from"""
|
|
|
|
description: str
|
|
linter_name: str
|
|
|
|
epilog: str | None = None
|
|
is_fixer: bool = True
|
|
report_column_numbers: bool = False
|
|
|
|
@abstractmethod
|
|
def _lint(self, python_file: PythonFile) -> Iterator[LintResult]:
|
|
raise NotImplementedError
|
|
|
|
def __init__(self, argv: Sequence[str] | None = None) -> None:
|
|
self.argv = argv
|
|
self.parser = ArgumentParser(
|
|
is_fixer=self.is_fixer,
|
|
description=self.description,
|
|
epilog=self.epilog,
|
|
)
|
|
self.result_shown = False
|
|
|
|
@classmethod
|
|
def run(cls) -> Never:
|
|
sys.exit(not cls().lint_all())
|
|
|
|
def lint_all(self) -> bool:
|
|
if self.args.fix and self.args.lintrunner:
|
|
raise ValueError("--fix and --lintrunner are incompatible")
|
|
|
|
success = True
|
|
for p in self.paths:
|
|
success = self._lint_file(p) and success
|
|
return self.args.lintrunner or success
|
|
|
|
@classmethod
|
|
def make_file(cls, pc: Path | str | None = None) -> PythonFile:
|
|
return PythonFile.make(cls.linter_name, pc)
|
|
|
|
@cached_property
|
|
def args(self) -> Namespace:
|
|
args = self.parser.parse_args(self.argv)
|
|
|
|
return args
|
|
|
|
@cached_property
|
|
def code(self) -> str:
|
|
return self.linter_name.upper()
|
|
|
|
@cached_property
|
|
def paths(self) -> list[Path]:
|
|
files = []
|
|
file_parts = (f for fp in self.args.files for f in fp.split(":"))
|
|
for f in file_parts:
|
|
if f.startswith("@"):
|
|
files.extend(Path(f[1:]).read_text().splitlines())
|
|
elif f != "--":
|
|
files.append(f)
|
|
return sorted(Path(f) for f in files)
|
|
|
|
def _lint_file(self, p: Path) -> bool:
|
|
if self.args.verbose:
|
|
print(p, "Reading", file=sys.stderr)
|
|
|
|
pf = self.make_file(p)
|
|
replacement, results = self._replace(pf)
|
|
|
|
if display := list(self._display(pf, results)):
|
|
print(*display, sep="\n")
|
|
if results and self.args.fix and pf.path and pf.contents != replacement:
|
|
pf.path.write_text(replacement)
|
|
|
|
return not results or self.args.fix and all(r.is_edit for r in results)
|
|
|
|
def _error(self, pf: PythonFile, result: LintResult) -> None:
|
|
"""Called on files that are unparsable"""
|
|
|
|
def _replace(self, pf: PythonFile) -> tuple[str, list[LintResult]]:
|
|
# Because of recursive replacements, we need to repeat replacing and reparsing
|
|
# from the inside out until all possible replacements are complete
|
|
previous_result_count = float("inf")
|
|
first_results = None
|
|
original = replacement = pf.contents
|
|
|
|
# pyrefly: ignore # bad-assignment
|
|
while True:
|
|
try:
|
|
results = sorted(self._lint(pf), key=LintResult.sort_key)
|
|
except IndentationError as e:
|
|
error, (_name, lineno, column, _line) = e.args
|
|
|
|
results = [LintResult(error, lineno, column)]
|
|
self._error(pf, *results)
|
|
|
|
except ParseError as e:
|
|
results = [LintResult(str(e), *e.token.start)]
|
|
self._error(pf, *results)
|
|
|
|
for i, ri in enumerate(results):
|
|
if not ri.is_recursive:
|
|
for rj in results[i + 1 :]:
|
|
if ri.contains(rj):
|
|
rj.is_recursive = True
|
|
else:
|
|
break
|
|
|
|
first_results = first_results or results
|
|
if not results or len(results) >= previous_result_count:
|
|
break
|
|
previous_result_count = len(results)
|
|
|
|
lines = pf.lines[:]
|
|
for r in reversed(results):
|
|
if r.is_edit and not r.is_recursive:
|
|
r.apply(lines)
|
|
replacement = "".join(lines)
|
|
|
|
if not any(r.is_recursive for r in results):
|
|
break
|
|
pf = pf.with_contents(replacement)
|
|
|
|
if first_results and self.args.lintrunner:
|
|
name = f"Suggested fixes for {self.linter_name}"
|
|
msg = LintResult(name=name, original=original, replacement=replacement)
|
|
first_results.append(msg)
|
|
|
|
return replacement, first_results
|
|
|
|
def _display(self, pf: PythonFile, results: list[LintResult]) -> Iterator[str]:
|
|
"""Emit a series of human-readable strings representing the results"""
|
|
for r in results:
|
|
if self.args.lintrunner:
|
|
msg = r.as_message(code=self.code, path=str(pf.path))
|
|
yield json.dumps(msg.asdict(), sort_keys=True)
|
|
else:
|
|
if self.result_shown:
|
|
yield ""
|
|
else:
|
|
self.result_shown = True
|
|
if r.line is None:
|
|
yield f"{pf.path}: {r.name}"
|
|
else:
|
|
yield from (i.rstrip() for i in self._display_window(pf, r))
|
|
|
|
def _display_window(self, pf: PythonFile, r: LintResult) -> Iterator[str]:
|
|
"""Display a window onto the code with an error"""
|
|
if r.char is None or not self.report_column_numbers:
|
|
yield f"{pf.path}:{r.line}: {r.name}"
|
|
else:
|
|
yield f"{pf.path}:{r.line}:{r.char + 1}: {r.name}"
|
|
|
|
begin = max((r.line or 0) - ErrorLines.BEFORE, 1)
|
|
end = min(begin + ErrorLines.WINDOW, 1 + len(pf.lines))
|
|
|
|
for lineno in range(begin, end):
|
|
source_line = pf.lines[lineno - 1].rstrip()
|
|
yield f"{lineno:5} | {source_line}"
|
|
if lineno == r.line:
|
|
spaces = 8 + (r.char or 0)
|
|
carets = len(source_line) if r.char is None else (r.length or 1)
|
|
yield spaces * " " + carets * "^"
|