mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Enable ruff in lintrunner (#99785)
### This change - Implements the ruff linter in pytorch lintrunner. It is adapted from https://github.com/justinchuby/lintrunner-adapters/blob/main/lintrunner_adapters/adapters/ruff_linter.py. It does **both linting and fixing**. 🔧 - Migrated all flake8 configs to the ruff config and enabled it for the repo. ✅ - **`ruff` lints the whole repo in under 2s** 🤯 Fixes https://github.com/pytorch/pytorch/issues/94737 Replaces #99280 @huydhn @Skylion007 <!-- copilot:all --> ### <samp>🤖 Generated by Copilot at 6b982dd</samp> ### Summary 🧹🛠️🎨 <!-- 1. 🧹 This emoji represents cleaning or tidying up, which is what `ruff` does by formatting and linting the code. It also suggests improving the code quality and removing unnecessary or redundant code. 2. 🛠️ This emoji represents tools or fixing, which is what `ruff` is as a code formatter and linter. It also suggests enhancing the code functionality and performance, and resolving potential issues or bugs. 3. 🎨 This emoji represents art or creativity, which is what `ruff` allows by providing a consistent and configurable style for the code. It also suggests adding some flair or personality to the code, and making it more readable and enjoyable. --> Add `[tool.ruff]` section to `pyproject.toml` to configure `ruff` code formatter and linter. This change aims to improve code quality and consistency with a single tool. > _`ruff` cleans the code_ > _like a spring breeze in the fields_ > _`pyproject.toml`_ ### Walkthrough * Configure `ruff` code formatter and linter for the whole project ([link](https://github.com/pytorch/pytorch/pull/99785/files?diff=unified&w=0#diff-50c86b7ed8ac2cf95bd48334961bf0530cdc77b5a56f852c5c61b89d735fd711R22-R79)) Pull Request resolved: https://github.com/pytorch/pytorch/pull/99785 Approved by: https://github.com/malfet, https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
dcd686f478
commit
7d2a18da0b
2
.flake8
2
.flake8
@ -1,4 +1,6 @@
|
||||
[flake8]
|
||||
# NOTE: **Mirror any changes** to this file to the [tool.ruff] config in pyproject.toml
|
||||
# before we can fully move to use ruff
|
||||
enable-extensions = G
|
||||
select = B,C,E,F,G,P,SIM1,T4,W,B9
|
||||
max-line-length = 120
|
||||
|
@ -972,3 +972,29 @@ command = [
|
||||
'python3',
|
||||
'tools/linter/adapters/lintrunner_version_linter.py'
|
||||
]
|
||||
|
||||
[[linter]]
|
||||
code = 'RUFF'
|
||||
include_patterns = ['**/*.py']
|
||||
exclude_patterns = [
|
||||
'caffe2/**',
|
||||
'functorch/docs/**',
|
||||
'functorch/notebooks/**',
|
||||
'scripts/**',
|
||||
'third_party/**',
|
||||
]
|
||||
command = [
|
||||
'python3',
|
||||
'tools/linter/adapters/ruff_linter.py',
|
||||
'--config=pyproject.toml',
|
||||
'--show-disable',
|
||||
'--',
|
||||
'@{{PATHSFILE}}'
|
||||
]
|
||||
init_command = [
|
||||
'python3',
|
||||
'tools/linter/adapters/pip_init.py',
|
||||
'--dry-run={{DRYRUN}}',
|
||||
'ruff==0.0.262',
|
||||
]
|
||||
is_formatter = true
|
||||
|
@ -19,3 +19,63 @@ build-backend = "setuptools.build_meta:__legacy__"
|
||||
# Uncomment if pyproject.toml worked fine to ensure consistency with flake8
|
||||
# line-length = 120
|
||||
target-version = ["py38", "py39", "py310", "py311"]
|
||||
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py38"
|
||||
|
||||
# NOTE: Synchronize the ignores with .flake8
|
||||
ignore = [
|
||||
# these ignores are from flake8-bugbear; please fix!
|
||||
"B007", "B008", "B017",
|
||||
"B018", # Useless expression
|
||||
"B019", "B020",
|
||||
"B022", # Allow empty context manager
|
||||
"B023", "B024", "B026",
|
||||
"B028", # No explicit `stacklevel` keyword argument found
|
||||
"B027", "B904", "B905",
|
||||
"E402",
|
||||
"C408", # C408 ignored because we like the dict keyword argument syntax
|
||||
"C419", # generators may not be supported by jit
|
||||
"E501", # E501 is not flexible enough, we're using B950 instead
|
||||
"E721",
|
||||
"E731", # Assign lambda expression
|
||||
"E741",
|
||||
"EXE001",
|
||||
"F405",
|
||||
"F821",
|
||||
"F841",
|
||||
# these ignores are from flake8-logging-format; please fix!
|
||||
"G101", "G201", "G202",
|
||||
"SIM102", "SIM103", "SIM112", # flake8-simplify code styles
|
||||
"SIM105", # these ignores are from flake8-simplify. please fix or ignore with commented reason
|
||||
"SIM108",
|
||||
"SIM109",
|
||||
"SIM110",
|
||||
"SIM114", # Combine `if` branches using logical `or` operator
|
||||
"SIM115",
|
||||
"SIM116", # Disable Use a dictionary instead of consecutive `if` statements
|
||||
"SIM117",
|
||||
"SIM118",
|
||||
]
|
||||
line-length = 120
|
||||
select = [
|
||||
"B",
|
||||
"C4",
|
||||
"G",
|
||||
"E",
|
||||
"F",
|
||||
"SIM1",
|
||||
"W",
|
||||
]
|
||||
|
||||
[tool.ruff.per-file-ignores]
|
||||
"__init__.py" = ["F401"]
|
||||
"torchgen/api/types/__init__.py" = [
|
||||
"F401",
|
||||
"F403",
|
||||
]
|
||||
"torchgen/executorch/api/types/__init__.py" = [
|
||||
"F401",
|
||||
"F403",
|
||||
]
|
||||
|
462
tools/linter/adapters/ruff_linter.py
Normal file
462
tools/linter/adapters/ruff_linter.py
Normal file
@ -0,0 +1,462 @@
|
||||
"""Adapter for https://github.com/charliermarsh/ruff."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import dataclasses
|
||||
import enum
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, BinaryIO
|
||||
|
||||
# Linter code attached to every LintMessage this adapter emits.
LINTER_CODE = "RUFF"
# True when running on Windows (os.name == "nt"); consulted by as_posix().
IS_WINDOWS: bool = os.name == "nt"
|
||||
|
||||
|
||||
def eprint(*args: Any, **kwargs: Any) -> None:
    """Print ``args`` to standard error and flush immediately.

    Extra keyword arguments are forwarded to ``print``.
    """
    print(*args, **kwargs, file=sys.stderr, flush=True)
|
||||
|
||||
|
||||
class LintSeverity(str, enum.Enum):
    """Severity levels that can be attached to a lint message.

    Members are also ``str`` instances, so they serialize to JSON as their
    plain string value.
    """

    ERROR = "error"
    WARNING = "warning"
    ADVICE = "advice"
    DISABLED = "disabled"
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class LintMessage:
    """A single lint result in the shape lintrunner consumes.

    The field layout follows
    https://docs.rs/lintrunner/latest/lintrunner/lint_message/struct.LintMessage.html.
    """

    path: str | None
    line: int | None
    char: int | None
    code: str
    severity: LintSeverity
    name: str
    original: str | None
    replacement: str | None
    description: str | None

    def asdict(self) -> dict[str, Any]:
        """Return the message's fields as a plain dictionary."""
        fields = dataclasses.asdict(self)
        return fields

    def display(self) -> None:
        """Write the message to stdout as one line of JSON for lintrunner."""
        serialized = json.dumps(self.asdict())
        print(serialized, flush=True)
|
||||
|
||||
|
||||
def as_posix(name: str) -> str:
    """Return ``name`` with backslashes replaced by ``/`` on Windows.

    On any other platform the string is returned unchanged.
    """
    if not IS_WINDOWS:
        return name
    return name.replace("\\", "/")
|
||||
|
||||
|
||||
def _run_command(
    args: list[str],
    *,
    timeout: int | None,
    stdin: BinaryIO | None,
    input: bytes | None,
    check: bool,
    cwd: os.PathLike[Any] | None,
) -> subprocess.CompletedProcess[bytes]:
    """Run ``args`` once with captured output, logging the command and its duration.

    When ``input`` is given it is fed to the process and ``stdin`` is ignored;
    otherwise ``stdin`` is passed through as the child's standard input.
    Exceptions raised by ``subprocess.run`` propagate to the caller.
    """
    logging.debug("$ %s", " ".join(args))

    # subprocess.run accepts either ``input`` or ``stdin``; pick exactly one.
    feed: dict[str, Any]
    if input is None:
        feed = {"stdin": stdin}
    else:
        feed = {"input": input}

    started = time.monotonic()
    try:
        return subprocess.run(
            args,
            capture_output=True,
            shell=False,
            timeout=timeout,
            check=check,
            cwd=cwd,
            **feed,
        )
    finally:
        # Log the elapsed time whether the command succeeded or raised.
        elapsed_ms = (time.monotonic() - started) * 1000
        logging.debug("took %dms", elapsed_ms)
|
||||
|
||||
|
||||
def run_command(
    args: list[str],
    *,
    retries: int = 0,
    timeout: int | None = None,
    stdin: BinaryIO | None = None,
    input: bytes | None = None,
    check: bool = False,
    cwd: os.PathLike[Any] | None = None,
) -> subprocess.CompletedProcess[bytes]:
    """Run ``args``, retrying when the command times out.

    Only ``subprocess.TimeoutExpired`` triggers a retry; after ``retries``
    retries the timeout error is re-raised. Any other exception propagates
    immediately. There is a one second pause between attempts.
    """
    retries_used = 0
    while True:
        try:
            return _run_command(
                args,
                timeout=timeout,
                stdin=stdin,
                input=input,
                check=check,
                cwd=cwd,
            )
        except subprocess.TimeoutExpired as err:
            if retries_used == retries:
                raise err
            retries_used += 1
            logging.warning(
                "(%s/%s) Retrying because command failed with: %r",
                retries_used,
                retries,
                err,
            )
            time.sleep(1)
|
||||
|
||||
|
||||
def add_default_options(parser: argparse.ArgumentParser) -> None:
    """Register the options shared by linter adapters on ``parser``.

    Adds ``--retries``, ``--verbose`` and the positional ``filenames``.
    Call this after every other ``add_argument`` call.
    """
    # Registered in this order: --retries, --verbose, then filenames.
    option_specs: tuple[tuple[str, dict[str, Any]], ...] = (
        (
            "--retries",
            {
                "type": int,
                "default": 3,
                "help": "number of times to retry if the linter times out.",
            },
        ),
        (
            "--verbose",
            {"action": "store_true", "help": "verbose logging"},
        ),
        (
            "filenames",
            {"nargs": "+", "help": "paths to lint"},
        ),
    )
    for option_name, settings in option_specs:
        parser.add_argument(option_name, **settings)
|
||||
|
||||
|
||||
def explain_rule(code: str) -> str:
    """Return a one-line explanation of the ruff rule ``code``.

    Runs ``ruff rule --format=json <code>`` and formats the ``linter`` and
    ``summary`` fields of its JSON output as ``"\\n<linter>: <summary>"``.

    Ruff is invoked as ``sys.executable -m ruff``, the same way the lint and
    fix passes in this file call it, so the explanation comes from the same
    ruff installation and does not depend on a ``ruff`` executable on PATH.

    Raises:
        subprocess.CalledProcessError: if ruff exits with a non-zero status.
    """
    import sys  # local import keeps this block self-contained

    proc = run_command(
        [sys.executable, "-m", "ruff", "rule", "--format=json", code],
        check=True,
    )
    rule = json.loads(str(proc.stdout, "utf-8").strip())
    return f"\n{rule['linter']}: {rule['summary']}"
|
||||
|
||||
|
||||
def get_issue_severity(code: str) -> LintSeverity:
    """Map a rule code to its default severity by prefix.

    Advice prefixes are checked first, then error prefixes; every other
    code is reported as a warning.
    """
    # "B901": `return x` inside a generator
    # "B902": Invalid first argument to a method
    # "B903": __slots__ efficiency
    # "B950": Line too long
    # "C4": Flake8 Comprehensions
    # "C9": Cyclomatic complexity
    # "E2": PEP8 horizontal whitespace "errors"
    # "E3": PEP8 blank line "errors"
    # "E5": PEP8 line length "errors"
    # "T400": type checking Notes
    # "T49": internal type checker errors or unmatched messages
    advice_prefixes = (
        "B9",
        "C4",
        "C9",
        "E2",
        "E3",
        "E5",
        "T400",
        "T49",
        "PLC",
        "PLR",
    )
    # "F821": Undefined name
    # "E999": syntax error
    error_prefixes = ("F821", "E999", "PLE")

    if code.startswith(advice_prefixes):
        return LintSeverity.ADVICE
    if code.startswith(error_prefixes):
        return LintSeverity.ERROR

    # "F": PyFlakes Error
    # "B": flake8-bugbear Error
    # "E": PEP8 "Error"
    # "W": PEP8 Warning
    # possibly other plugins...
    return LintSeverity.WARNING
|
||||
|
||||
|
||||
def format_lint_message(
    message: str, code: str, rules: dict[str, str], show_disable: bool
) -> str:
    """Build the description text shown for one lint violation.

    Starts from ``message``. When ``rules`` is non-empty, the explanation for
    ``code`` is appended (empty if absent). A link to the rule documentation
    always follows, and a ``noqa`` hint is added when ``show_disable`` is set.
    """
    pieces = [message]
    if rules:
        explanation = rules.get(code) or ""
        pieces.append(f".\n{explanation}")
    pieces.append(".\nSee https://beta.ruff.rs/docs/rules/")
    if show_disable:
        pieces.append(f".\n\nTo disable, use ` # noqa: {code}`")
    return "".join(pieces)
|
||||
|
||||
|
||||
def check_files(
    filenames: list[str],
    severities: dict[str, LintSeverity],
    *,
    config: str | None,
    retries: int,
    timeout: int,
    explain: bool,
    show_disable: bool,
) -> list[LintMessage]:
    """Run ruff over ``filenames`` and convert its JSON report to lint messages.

    Args:
        filenames: paths handed to ruff.
        severities: per-rule severity overrides; codes not listed fall back
            to ``get_issue_severity``.
        config: optional config file passed as ``--config``.
        retries: number of retries if ruff times out.
        timeout: seconds to wait for ruff per attempt.
        explain: when true, look up an explanation for every reported rule.
        show_disable: when true, append a ``noqa`` hint to each description.

    Returns:
        One ``LintMessage`` per reported violation, or a single
        ``command-failed`` message if ruff could not be run or exited non-zero.
    """
    command = [
        sys.executable,
        "-m",
        "ruff",
        "--exit-zero",
        "--quiet",
        "--format=json",
    ]
    if config:
        command.append(f"--config={config}")
    command.extend(filenames)

    try:
        proc = run_command(
            command,
            retries=retries,
            timeout=timeout,
            check=True,
        )
    except (OSError, subprocess.CalledProcessError) as err:
        if isinstance(err, subprocess.CalledProcessError):
            # Include the command line and both output streams for diagnosis.
            failure = (
                f"COMMAND (exit code {err.returncode})\n"
                f"{' '.join(as_posix(x) for x in err.cmd)}\n\n"
                f"STDERR\n{err.stderr.decode('utf-8').strip() or '(empty)'}\n\n"
                f"STDOUT\n{err.stdout.decode('utf-8').strip() or '(empty)'}"
            )
        else:
            failure = f"Failed due to {err.__class__.__name__}:\n{err}"
        return [
            LintMessage(
                path=None,
                line=None,
                char=None,
                code=LINTER_CODE,
                severity=LintSeverity.ERROR,
                name="command-failed",
                original=None,
                replacement=None,
                description=failure,
            )
        ]

    report = str(proc.stdout, "utf-8").strip()
    violations = json.loads(report)

    rules: dict[str, str] = {}
    if explain:
        # Explain each distinct rule once, however many times it fired.
        distinct_codes = {v["code"] for v in violations}
        rules = {code: explain_rule(code) for code in distinct_codes}

    messages: list[LintMessage] = []
    for violation in violations:
        messages.append(
            LintMessage(
                path=violation["filename"],
                name=violation["code"],
                description=format_lint_message(
                    violation["message"],
                    violation["code"],
                    rules,
                    show_disable,
                ),
                line=int(violation["location"]["row"]),
                char=int(violation["location"]["column"]),
                code=LINTER_CODE,
                severity=severities.get(
                    violation["code"], get_issue_severity(violation["code"])
                ),
                original=None,
                replacement=None,
            )
        )
    return messages
|
||||
|
||||
|
||||
def check_file_for_fixes(
    filename: str,
    *,
    config: str | None,
    retries: int,
    timeout: int,
) -> list[LintMessage]:
    """Ask ruff for an auto-fixed version of ``filename`` and report it as a patch.

    The file is read once, and those same bytes are both fed to
    ``ruff --fix-only`` on stdin and used as the ``original`` side of the
    patch. This keeps the patch consistent even if the file changes on disk
    while the linter runs.

    Args:
        filename: path of the file to fix; also passed as ``--stdin-filename``.
        config: optional config file passed as ``--config``.
        retries: number of retries if ruff times out.
        timeout: seconds to wait for ruff per attempt.

    Returns:
        An empty list when ruff's output equals the input, a single
        ``format`` message carrying the original and fixed text otherwise,
        or a single ``command-failed`` message if the file could not be read
        or ruff could not be run / exited non-zero.
    """
    try:
        with open(filename, "rb") as f:
            original = f.read()
        proc_fix = run_command(
            [
                sys.executable,
                "-m",
                "ruff",
                "--fix-only",
                "--exit-zero",
                *([f"--config={config}"] if config else []),
                "--stdin-filename",
                filename,
                "-",
            ],
            # Feed the bytes already read instead of reopening the file.
            input=original,
            retries=retries,
            timeout=timeout,
            check=True,
        )
    except (OSError, subprocess.CalledProcessError) as err:
        if isinstance(err, subprocess.CalledProcessError):
            # Include the command line and both output streams for diagnosis.
            failure = (
                f"COMMAND (exit code {err.returncode})\n"
                f"{' '.join(as_posix(x) for x in err.cmd)}\n\n"
                f"STDERR\n{err.stderr.decode('utf-8').strip() or '(empty)'}\n\n"
                f"STDOUT\n{err.stdout.decode('utf-8').strip() or '(empty)'}"
            )
        else:
            failure = f"Failed due to {err.__class__.__name__}:\n{err}"
        return [
            LintMessage(
                path=None,
                line=None,
                char=None,
                code=LINTER_CODE,
                severity=LintSeverity.ERROR,
                name="command-failed",
                original=None,
                replacement=None,
                description=failure,
            )
        ]

    replacement = proc_fix.stdout
    if original == replacement:
        # Nothing to fix.
        return []

    return [
        LintMessage(
            path=filename,
            name="format",
            description="Run `lintrunner -a` to apply this patch.",
            line=None,
            char=None,
            code=LINTER_CODE,
            severity=LintSeverity.WARNING,
            original=original.decode("utf-8"),
            replacement=replacement.decode("utf-8"),
        )
    ]
|
||||
|
||||
|
||||
def main() -> None:
    """Command-line entry point of the ruff adapter.

    Parses the arguments, lints all given files in one ruff invocation and
    prints each result as a JSON line. Unless ``--no-fix`` is given or nothing
    was reported, it then asks ruff for fixes for every file that had a lint,
    one file per worker thread, and prints those patches as well.

    Malformed ``--severity`` values are reported through ``parser.error``
    (usage message on stderr, exit status 2). The check therefore still runs
    under ``python -O``, where an ``assert`` would be stripped.
    """
    parser = argparse.ArgumentParser(
        description=f"Ruff linter. Linter code: {LINTER_CODE}. Use with RUFF-FIX to auto-fix issues.",
        fromfile_prefix_chars="@",
    )
    parser.add_argument(
        "--config",
        default=None,
        help="Path to the `pyproject.toml` or `ruff.toml` file to use for configuration",
    )
    parser.add_argument(
        "--explain",
        action="store_true",
        help="Explain a rule",
    )
    parser.add_argument(
        "--show-disable",
        action="store_true",
        help="Show how to disable a lint message",
    )
    parser.add_argument(
        "--timeout",
        default=90,
        type=int,
        help="Seconds to wait for ruff",
    )
    parser.add_argument(
        "--severity",
        action="append",
        help="map code to severity (e.g. `F401:advice`). This option can be used multiple times.",
    )
    parser.add_argument(
        "--no-fix",
        action="store_true",
        help="Do not suggest fixes",
    )
    add_default_options(parser)
    args = parser.parse_args()

    # Verbose runs log everything; otherwise log at DEBUG for fewer than 1000
    # files and at INFO for larger batches.
    if args.verbose:
        log_level = logging.NOTSET
    elif len(args.filenames) < 1000:
        log_level = logging.DEBUG
    else:
        log_level = logging.INFO
    logging.basicConfig(
        format="<%(threadName)s:%(levelname)s> %(message)s",
        level=log_level,
        stream=sys.stderr,
    )

    severities: dict[str, LintSeverity] = {}
    for severity in args.severity or []:
        rule_code, separator, level = severity.partition(":")
        if not separator:
            parser.error(f"invalid severity `{severity}`")
        try:
            severities[rule_code] = LintSeverity(level)
        except ValueError:
            valid_levels = ", ".join(member.value for member in LintSeverity)
            parser.error(
                f"invalid severity `{severity}`: level must be one of {valid_levels}"
            )

    lint_messages = check_files(
        args.filenames,
        severities=severities,
        config=args.config,
        retries=args.retries,
        timeout=args.timeout,
        explain=args.explain,
        show_disable=args.show_disable,
    )
    for lint_message in lint_messages:
        lint_message.display()

    if args.no_fix or not lint_messages:
        # If we're not fixing, we can exit early
        return

    # Only files that actually produced a lint are worth a fix pass.
    files_with_lints = {lint.path for lint in lint_messages if lint.path is not None}
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=os.cpu_count(),
        thread_name_prefix="Thread",
    ) as executor:
        futures = {
            executor.submit(
                check_file_for_fixes,
                path,
                config=args.config,
                retries=args.retries,
                timeout=args.timeout,
            ): path
            for path in files_with_lints
        }
        for future in concurrent.futures.as_completed(futures):
            try:
                for lint_message in future.result():
                    lint_message.display()
            except Exception:  # Catch all exceptions for lintrunner
                # Record which file failed, then let the error propagate.
                logging.critical('Failed at "%s".', futures[future])
                raise
|
||||
|
||||
|
||||
# Run the adapter only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
@ -1,6 +1,6 @@
|
||||
import unittest
|
||||
|
||||
from torchgen.selective_build.operator import *
|
||||
from torchgen.selective_build.operator import * # noqa: F403
|
||||
from torchgen.model import Location, NativeFunction
|
||||
from torchgen.selective_build.selector import (
|
||||
combine_selective_builders,
|
||||
|
@ -143,7 +143,7 @@ class Batch:
|
||||
self._values = value
|
||||
|
||||
def _setitem_by_slice(self, index: slice, value) -> None:
|
||||
if not (index.start is index.stop is index.step is None):
|
||||
if not (index.start is index.stop is index.step is None): # noqa: E714
|
||||
raise NotImplementedError("only slice [:] supported")
|
||||
|
||||
if not self.atomic:
|
||||
|
Reference in New Issue
Block a user