[ci] delete old linter stuff

lintrunner has been running for a while, so delete the redundant linter
tooling it replaces

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76984

Approved by: https://github.com/janeyx99
Michael Suo
2022-05-11 00:18:46 -07:00
committed by PyTorch MergeBot
parent 533b44a280
commit 6cbe9d1f58
28 changed files with 2 additions and 3629 deletions

View File

@@ -22,38 +22,12 @@ linecount:
echo "Cloc is not available on the machine. You can install cloc with " && \
echo " sudo apt-get install cloc"
shellcheck:
@$(PYTHON) tools/actions_local_runner.py \
--file .github/workflows/lint.yml \
--job 'workflow-checks' \
--step "Regenerate workflows"
@$(PYTHON) tools/actions_local_runner.py \
--file .github/workflows/lint.yml \
--job 'workflow-checks' \
--step "Assert that regenerating the workflows didn't change them"
setup_lint:
$(PIP) install lintrunner
lintrunner init
$(PYTHON) -mpip install jinja2 --user
quick_checks:
# TODO: This is broken when 'git config submodule.recurse' is 'true' since the
# lints will descend into third_party submodules
@$(PYTHON) tools/actions_local_runner.py \
--file .github/workflows/lint.yml \
--job 'quick-checks' \
--step 'Ensure no versionless Python shebangs'
toc:
@$(PYTHON) tools/actions_local_runner.py \
--file .github/workflows/lint.yml \
--job 'toc' \
--step "Regenerate ToCs and check that they didn't change"
lint: quick_checks shellcheck
lint:
lintrunner
quicklint: CHANGED_ONLY=--changed-only
quicklint: quick_checks shellcheck
quicklint:
lintrunner

View File

@@ -33,41 +33,14 @@ Build system pieces:
Developer tools which you might find useful:
* [linter/clang_tidy](linter/clang_tidy/__main__.py) - Script for running clang-tidy
on only the lines of your code that you changed.
* [extract_scripts.py](extract_scripts.py) - Extract scripts from
`.github/workflows/*.yml` into a specified dir, on which linters such as
[linter/run_shellcheck.sh](linter/run_shellcheck.sh) can be run. Assumes that every `run`
script has `shell: bash` unless a different shell is explicitly listed on that
specific step (so `defaults` doesn't currently work), but also has some rules
for other situations such as [actions/github-script][]. Exits with nonzero
status if any of the extracted scripts contain [GitHub Actions expressions][]:
`${{<expression> }}`
* [git_add_generated_dirs.sh](git_add_generated_dirs.sh) and
[git_reset_generated_dirs.sh](git_reset_generated_dirs.sh) -
Use this to force add generated files to your Git index, so that you
can conveniently run diffs on them when working on code-generation.
(See also [generated_dirs.txt](generated_dirs.txt) which
specifies the list of directories with generated files.)
* [linter/mypy_wrapper.py](linter/mypy_wrapper.py) - Run `mypy` on a single file using the
appropriate subset of our `mypy*.ini` configs.
* [linter/run_shellcheck.sh](linter/run_shellcheck.sh) - Find `*.sh` files (recursively) in
the directories specified as arguments, and run [ShellCheck][] on all of them.
* [stats/test_history.py](stats/test_history.py) - Query S3 to display history of a single
test across multiple jobs over time.
* [linter/trailing_newlines.py](linter/trailing_newlines.py) - Take names of UTF-8 files from
stdin, print names of nonempty files whose contents don't end in exactly one
trailing newline, and exit with status 1 if no output was printed or 0 if some
filenames were printed (a minimal sketch of the check follows this list).
* [linter/translate_annotations.py](linter/translate_annotations.py) - Read [Flake8][] or
[clang-tidy][] warnings (according to a `--regex`) from a `--file`, convert to
the JSON format accepted by [pytorch/add-annotations-github-action], and
translate line numbers from `HEAD` back in time to the given `--commit` by
running `git diff-index --unified=0` appropriately.
* [vscode_settings.py](vscode_settings.py) - Merge
`.vscode/settings_recommended.json` into your workspace-local
`.vscode/settings.json`, preferring the former in case of conflicts but
otherwise preserving the latter as much as possible.
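As a concrete illustration of the trailing-newline check described above, here is a
minimal sketch of the per-file test (an approximation, not the linter itself):

```python
def ends_with_exactly_one_newline(path: str) -> bool:
    # Nonempty files must end with "\n", but not with a blank final line.
    with open(path, "rb") as f:
        data = f.read()
    return len(data) > 0 and data.endswith(b"\n") and not data.endswith(b"\n\n")
```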
Important if you want to run on AMD GPU:

View File

@@ -1,445 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import subprocess
import sys
import os
import argparse
import yaml
import asyncio
import shutil
import re
import fnmatch
import shlex
import configparser
from typing import List, Dict, Any, Optional, Union, NamedTuple, Set
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
class col:
HEADER = "\033[95m"
BLUE = "\033[94m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
RED = "\033[91m"
RESET = "\033[0m"
BOLD = "\033[1m"
UNDERLINE = "\033[4m"
def should_color() -> bool:
return hasattr(sys.stdout, "isatty") and sys.stdout.isatty()
def color(the_color: str, text: str) -> str:
if should_color():
return col.BOLD + the_color + str(text) + col.RESET
else:
return text
def cprint(the_color: str, text: str) -> None:
if should_color():
print(color(the_color, text))
else:
print(text)
def git(args: List[str]) -> List[str]:
p = subprocess.run(
["git"] + args,
cwd=REPO_ROOT,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
check=True,
)
lines = p.stdout.decode().strip().split("\n")
return [line.strip() for line in lines]
def find_changed_files(ref_branch: str = "origin/master") -> List[str]:
untracked = []
for line in git(["status", "--porcelain"]):
# Untracked files start with ??, so grab all of those
if line.startswith("?? "):
untracked.append(line.replace("?? ", ""))
# Modified, unstaged
modified = git(["diff", "--name-only"])
# Modified, staged
cached = git(["diff", "--cached", "--name-only"])
# Committed
merge_base = git(["merge-base", ref_branch, "HEAD"])[0]
diff_with_origin = git(["diff", "--name-only", merge_base, "HEAD"])
# De-duplicate
all_files = set()
for x in untracked + cached + modified + diff_with_origin:
stripped = x.strip()
if stripped != "" and os.path.exists(stripped):
all_files.add(stripped)
return list(all_files)
def print_results(job_name: str, passed: bool, streams: List[str]) -> None:
icon = color(col.GREEN, "✓") if passed else color(col.RED, "x")
print(f"{icon} {color(col.BLUE, job_name)}")
for stream in streams:
stream = stream.strip()
if stream != "":
print(stream)
class CommandResult(NamedTuple):
passed: bool
stdout: str
stderr: str
async def shell_cmd(
cmd: Union[str, List[str]],
env: Optional[Dict[str, Any]] = None,
redirect: bool = True,
) -> CommandResult:
if isinstance(cmd, list):
cmd_str = " ".join(shlex.quote(arg) for arg in cmd)
else:
cmd_str = cmd
proc = await asyncio.create_subprocess_shell(
cmd_str,
shell=True,
cwd=REPO_ROOT,
env=env,
stdout=subprocess.PIPE if redirect else None,
stderr=subprocess.PIPE if redirect else None,
executable=shutil.which("bash"),
)
stdout, stderr = await proc.communicate()
passed = proc.returncode == 0
if not redirect:
return CommandResult(passed, "", "")
return CommandResult(passed, stdout.decode().strip(), stderr.decode().strip())
class Check:
name: str
def __init__(self, files: Optional[List[str]], quiet: bool):
self.quiet = quiet
self.files = files
async def run(self) -> bool:
result = await self.run_helper()
if result is None:
return True
streams = []
if not result.passed:
streams = [
result.stderr,
result.stdout,
]
print_results(self.name, result.passed, streams)
return result.passed
async def run_helper(self) -> Optional[CommandResult]:
if self.files is not None:
relevant_files = self.filter_files(self.files)
if len(relevant_files) == 0:
# No files, do nothing
return CommandResult(passed=True, stdout="", stderr="")
return await self.quick(relevant_files)
return await self.full()
def filter_ext(self, files: List[str], extensions: Set[str]) -> List[str]:
def passes(filename: str) -> bool:
return os.path.splitext(filename)[1] in extensions
return [f for f in files if passes(f)]
def filter_files(self, files: List[str]) -> List[str]:
return files
async def quick(self, files: List[str]) -> CommandResult:
raise NotImplementedError
async def full(self) -> Optional[CommandResult]:
raise NotImplementedError
class Flake8(Check):
name = "flake8"
def filter_files(self, files: List[str]) -> List[str]:
config = configparser.ConfigParser()
config.read(os.path.join(REPO_ROOT, ".flake8"))
excludes = re.split(r",\s*", config["flake8"]["exclude"].strip())
excludes = [e.strip() for e in excludes if e.strip() != ""]
def should_include(name: str) -> bool:
for exclude in excludes:
if fnmatch.fnmatch(name, pat=exclude):
return False
if name.startswith(exclude) or f"./{name}".startswith(exclude):
return False
return True
files = self.filter_ext(files, {".py"})
return [f for f in files if should_include(f)]
async def quick(self, files: List[str]) -> CommandResult:
return await shell_cmd(["flake8"] + files)
async def full(self) -> CommandResult:
return await shell_cmd(["flake8"])
class Mypy(Check):
name = "mypy (skipped typestub generation)"
def filter_files(self, files: List[str]) -> List[str]:
return self.filter_ext(files, {".py", ".pyi"})
def env(self) -> Dict[str, Any]:
env = os.environ.copy()
if should_color():
# Secret env variable: https://github.com/python/mypy/issues/7771
env["MYPY_FORCE_COLOR"] = "1"
return env
async def quick(self, files: List[str]) -> CommandResult:
return await shell_cmd(
[sys.executable, "tools/linter/mypy_wrapper.py"]
+ [os.path.join(REPO_ROOT, f) for f in files],
env=self.env(),
)
async def full(self) -> None:
env = self.env()
# hackily change the name
self.name = "mypy"
await shell_cmd(
[
sys.executable,
"tools/actions_local_runner.py",
"--job",
"mypy",
"--file",
".github/workflows/lint.yml",
"--step",
"Run autogen",
],
redirect=False,
env=env,
)
await shell_cmd(
[
sys.executable,
"tools/actions_local_runner.py",
"--job",
"mypy",
"--file",
".github/workflows/lint.yml",
"--step",
"Run mypy",
],
redirect=False,
env=env,
)
class ShellCheck(Check):
name = "shellcheck: Run ShellCheck"
def filter_files(self, files: List[str]) -> List[str]:
return self.filter_ext(files, {".sh"})
async def quick(self, files: List[str]) -> CommandResult:
return await shell_cmd(
["tools/linter/run_shellcheck.sh"]
+ [os.path.join(REPO_ROOT, f) for f in files],
)
async def full(self) -> None:
await shell_cmd(
[
sys.executable,
"tools/actions_local_runner.py",
"--job",
"shellcheck",
"--file",
".github/workflows/lint.yml",
"--step",
"Run ShellCheck",
],
redirect=False,
)
class ClangTidy(Check):
name = "clang-tidy: Run clang-tidy"
common_options = [
"--clang-tidy-exe",
".clang-tidy-bin/clang-tidy",
]
def filter_files(self, files: List[str]) -> List[str]:
return self.filter_ext(files, {".c", ".cc", ".cpp"})
async def quick(self, files: List[str]) -> CommandResult:
return await shell_cmd(
[sys.executable, "-m", "tools.linter.clang_tidy", "--paths"]
+ [os.path.join(REPO_ROOT, f) for f in files]
+ self.common_options,
)
async def full(self) -> None:
await shell_cmd(
[sys.executable, "-m", "tools.linter.clang_tidy"] + self.common_options,
redirect=False,
)
class YamlStep(Check):
def __init__(self, step: Dict[str, Any], job_name: str, quiet: bool):
super().__init__(files=None, quiet=quiet)
self.step = step
self.name = f'{job_name}: {self.step["name"]}'
async def full(self) -> CommandResult:
env = os.environ.copy()
env["GITHUB_WORKSPACE"] = "/tmp"
script = self.step["run"]
if self.quiet:
# TODO: Either lint that GHA scripts only use 'set -eux' or make this more
# resilient
script = script.replace("set -eux", "set -eu")
script = re.sub(r"^time ", "", script, flags=re.MULTILINE)
return await shell_cmd(script, env=env)
def changed_files(ref_branch: str = "origin/master") -> Optional[List[str]]:
changed_files: Optional[List[str]] = None
try:
changed_files = sorted(find_changed_files(ref_branch))
except Exception:
# If the git commands failed for some reason, bail out and use the whole list
print(
"Could not query git for changed files, falling back to testing all files instead",
file=sys.stderr,
)
return None
return changed_files
def grab_specific_steps(
steps_to_grab: List[str], job: Dict[str, Any]
) -> List[Dict[str, Any]]:
relevant_steps = []
for step in steps_to_grab:
for actual_step in job["steps"]:
if actual_step["name"].lower().strip() == step.lower().strip():
relevant_steps.append(actual_step)
break
if len(relevant_steps) != len(steps_to_grab):
raise RuntimeError(f"Missing steps:\n{relevant_steps}\n{steps_to_grab}")
return relevant_steps
def main() -> None:
parser = argparse.ArgumentParser(
description="Pull shell scripts out of GitHub actions and run them"
)
parser.add_argument("--file", help="YAML file with actions")
parser.add_argument(
"--changed-only",
help="only run on changed files",
action="store_true",
default=False,
)
parser.add_argument("--job", help="job name", required=True)
parser.add_argument(
"--no-quiet", help="output commands", action="store_true", default=False
)
parser.add_argument("--step", action="append", help="steps to run (in order)")
parser.add_argument(
"--ref_branch",
default="origin/master",
help="remote/branch used during comparison for --changed-only (default=origin/master",
)
args = parser.parse_args()
quiet = not args.no_quiet
if args.file is None:
# If no .yml file is provided, fall back to the list of known jobs. We
# use this for flake8 and mypy since they run differently locally than
# in CI due to 'make quicklint'.
if args.job not in ad_hoc_steps:
raise RuntimeError(
f"Job {args.job} not found and no .yml file was provided"
)
files = None
if args.changed_only:
files = changed_files(args.ref_branch)
checks = [ad_hoc_steps[args.job](files, quiet)]
else:
if args.step is None:
raise RuntimeError("1+ --steps must be provided")
action = yaml.safe_load(open(args.file, "r"))
if "jobs" not in action:
raise RuntimeError(f"top level key 'jobs' not found in {args.file}")
jobs = action["jobs"]
if args.job not in jobs:
raise RuntimeError(f"job '{args.job}' not found in {args.file}")
job = jobs[args.job]
# Pull the relevant sections out of the provided .yml file and run them
relevant_steps = grab_specific_steps(args.step, job)
checks = [
YamlStep(step=step, job_name=args.job, quiet=quiet)
for step in relevant_steps
]
loop = asyncio.get_event_loop()
loop.run_until_complete(asyncio.gather(*[check.run() for check in checks]))
# These are run differently locally in order to enable quicklint, so dispatch
# out to special handlers instead of using lint.yml
ad_hoc_steps = {
"mypy": Mypy,
"flake8-py3": Flake8,
"shellcheck": ShellCheck,
"clang-tidy": ClangTidy,
}
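# Example invocations (the first mirrors the Makefile target shown earlier in
# this diff; the second uses the ad-hoc "mypy" job above):
#   python tools/actions_local_runner.py --file .github/workflows/lint.yml \
#       --job quick-checks --step 'Ensure no versionless Python shebangs'
#   python tools/actions_local_runner.py --job mypy --changed-only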
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
pass

View File

@@ -1 +0,0 @@
e1c8b97b919541a99e0a355df5c3f9e8abebc64259dbee6f8c68e1ef90582856

View File

@@ -1 +0,0 @@
1485a242a96c737ba7cdd9f259114f2201accdb46d87ac7a8650b1a814cd4d4d

View File

@@ -1,10 +0,0 @@
#!/bin/bash
set -e
echo "Running pre-commit clang-tidy"
git diff HEAD > pr.diff
python3 -m tools.linter.clang_tidy --diff-file "pr.diff"
rm pr.diff
echo "Running pre-commit clang-format"
tools/linter/git-clang-format HEAD~ --force

View File

@@ -1,195 +0,0 @@
#!/usr/bin/env python3
"""
A script that runs clang-format on all C/C++ files in CLANG_FORMAT_ALLOWLIST. There is
also a diff mode which simply checks if clang-format would make any changes, which is
useful for CI purposes.
If clang-format is not available, the script also downloads a platform-appropriate binary
from an S3 bucket and verifies it against a precommitted set of blessed binary hashes.
"""
import argparse
import asyncio
import re
import os
import sys
from typing import List, Set
from .clang_format_utils import get_and_check_clang_format, CLANG_FORMAT_PATH
# Allowlist of directories to check. All files in those directories
# (recursively) will be checked.
# If you edit this, please edit the allowlist in clang_format_ci.sh as well.
CLANG_FORMAT_ALLOWLIST = [
"c10/",
"ios/",
"torch/csrc/jit/",
"torch/csrc/deploy/",
"test/cpp/jit/",
"test/cpp/tensorexpr/",
]
CLANG_FORMAT_BLOCK_LIST = {
"torch/csrc/jit/serialization/mobile_bytecode_generated.h",
}
# Only files with names matching this regex will be formatted.
CPP_FILE_REGEX = re.compile(".*\\.(h|cpp|cc|c|hpp|m|mm)$")
def get_allowlisted_files() -> Set[str]:
"""
Parse CLANG_FORMAT_ALLOWLIST and resolve all directories.
Returns the set of allowlisted C/C++ source files.
"""
matches = []
for dir in CLANG_FORMAT_ALLOWLIST:
for root, dirnames, filenames in os.walk(dir):
for filename in filenames:
fullpath = os.path.join(root, filename)
if fullpath in CLANG_FORMAT_BLOCK_LIST:
continue
if CPP_FILE_REGEX.match(filename):
matches.append(os.path.join(root, filename))
return set(matches)
async def run_clang_format_on_file(
filename: str,
semaphore: asyncio.Semaphore,
verbose: bool = False,
) -> None:
"""
Run clang-format on the provided file.
"""
# -style=file picks up the closest .clang-format, -i formats the files inplace.
cmd = "{} -style=file -i {}".format(CLANG_FORMAT_PATH, filename)
async with semaphore:
proc = await asyncio.create_subprocess_shell(cmd)
_ = await proc.wait()
if verbose:
print("Formatted {}".format(filename))
async def file_clang_formatted_correctly(
filename: str,
semaphore: asyncio.Semaphore,
verbose: bool = False,
) -> bool:
"""
Checks if a file is formatted correctly and returns True if so.
"""
ok = True
# -style=file picks up the closest .clang-format
cmd = "{} -style=file {}".format(CLANG_FORMAT_PATH, filename)
async with semaphore:
proc = await asyncio.create_subprocess_shell(
cmd, stdout=asyncio.subprocess.PIPE
)
# Read back the formatted file.
stdout, _ = await proc.communicate()
formatted_contents = stdout.decode()
# Compare the formatted file to the original file.
with open(filename) as orig:
orig_contents = orig.read()
if formatted_contents != orig_contents:
ok = False
if verbose:
print("{} is not formatted correctly".format(filename))
return ok
async def run_clang_format(
max_processes: int,
diff: bool = False,
verbose: bool = False,
) -> bool:
"""
Run clang-format on all files in CLANG_FORMAT_ALLOWLIST that match CPP_FILE_REGEX.
"""
# Check to make sure the clang-format binary exists.
if not os.path.exists(CLANG_FORMAT_PATH):
print("clang-format binary not found")
return False
# Gather command-line options for clang-format.
args = [CLANG_FORMAT_PATH, "-style=file"]
if not diff:
args.append("-i")
ok = True
# Semaphore to bound the number of subprocesses that can be created at once to format files.
semaphore = asyncio.Semaphore(max_processes)
# Format files in parallel.
if diff:
for f in asyncio.as_completed(
[
file_clang_formatted_correctly(f, semaphore, verbose)
for f in get_allowlisted_files()
]
):
ok &= await f
if ok:
print("All files formatted correctly")
else:
print("Some files not formatted correctly")
else:
await asyncio.gather(
*[
run_clang_format_on_file(f, semaphore, verbose)
for f in get_allowlisted_files()
]
)
return ok
def parse_args(args: List[str]) -> argparse.Namespace:
"""
Parse and return command-line arguments.
"""
parser = argparse.ArgumentParser(
description="Execute clang-format on your working copy changes."
)
parser.add_argument(
"-d",
"--diff",
action="store_true",
default=False,
help="Determine whether running clang-format would produce changes",
)
parser.add_argument("--verbose", "-v", action="store_true", default=False)
parser.add_argument(
"--max-processes",
type=int,
default=50,
help="Maximum number of subprocesses to create to format files in parallel",
)
return parser.parse_args(args)
def main(args: List[str]) -> bool:
# Parse arguments.
options = parse_args(args)
# Get clang-format and make sure it is the right binary and it is in the right place.
ok = get_and_check_clang_format(options.verbose)
# Invoke clang-format on all files in the directories in the allowlist.
if ok:
loop = asyncio.get_event_loop()
ok = loop.run_until_complete(
run_clang_format(options.max_processes, options.diff, options.verbose)
)
# We have to invert because False -> 0, which is the code to be returned if everything is okay.
return not ok
if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))

View File

@@ -1,15 +0,0 @@
#!/bin/sh
set -eux
# Runs clang-format on allowlisted files.
# Requires a single argument, which is the <commit> argument to git-clang-format
# If you edit this allowlist, please edit the one in clang_format_all.py as well
find . -type f \
-path './c10/*' -or \
-path './ios/*' -or \
-path './torch/csrc/jit/!(serialization/mobile_bytecode_generated.h)' -or \
-path './torch/csrc/deploy/*' -or \
-path './test/cpp/jit/*' -or \
-path './test/cpp/tensorexpr/*' \
| xargs tools/linter/git-clang-format --verbose "$1" --

View File

@@ -1,25 +0,0 @@
import os
from install.download_bin import download, PYTORCH_ROOT # type: ignore[import]
# This dictionary maps each platform to the S3 object URL for its clang-format binary.
PLATFORM_TO_CF_URL = {
"Darwin": "https://oss-clang-format.s3.us-east-2.amazonaws.com/mac/clang-format-mojave",
"Linux": "https://oss-clang-format.s3.us-east-2.amazonaws.com/linux64/clang-format-linux64",
}
# This dictionary maps each platform to a relative path to a file containing its reference hash.
PLATFORM_TO_HASH = {
"Darwin": os.path.join("tools", "clang_format_hash", "mac", "clang-format-mojave"),
"Linux": os.path.join(
"tools", "clang_format_hash", "linux64", "clang-format-linux64"
),
}
CLANG_FORMAT_DIR = os.path.join(PYTORCH_ROOT, ".clang-format-bin")
CLANG_FORMAT_PATH = os.path.join(CLANG_FORMAT_DIR, "clang-format")
def get_and_check_clang_format(verbose: bool = False) -> bool:
return bool(
download("clang-format", CLANG_FORMAT_DIR, PLATFORM_TO_CF_URL, PLATFORM_TO_HASH)
)

View File

@@ -1,223 +0,0 @@
import argparse
import pathlib
import os
import shutil
import subprocess
import re
import sys
from sysconfig import get_paths as gp
from typing import List
from tools.linter.clang_tidy.run import run
from tools.linter.clang_tidy.generate_build_files import generate_build_files
from tools.linter.install.clang_tidy import INSTALLATION_PATH
from tools.linter.install.download_bin import PYTORCH_ROOT
# Returns '/usr/local/include/python<version number>'
def get_python_include_dir() -> str:
return gp()["include"]
def clang_search_dirs() -> List[str]:
# Compilers are ordered based on fallback preference
# We pick the first one that is available on the system
compilers = ["clang", "gcc", "cpp", "cc"]
compilers = [c for c in compilers if shutil.which(c) is not None]
if len(compilers) == 0:
raise RuntimeError(f"None of {compilers} were found")
compiler = compilers[0]
result = subprocess.run(
[compiler, "-E", "-x", "c++", "-", "-v"],
stdin=subprocess.DEVNULL,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
check=True,
)
stderr = result.stderr.decode().strip().split("\n")
search_start = r"#include.*search starts here:"
search_end = r"End of search list."
append_path = False
search_paths = []
for line in stderr:
if re.match(search_start, line):
if append_path:
continue
else:
append_path = True
elif re.match(search_end, line):
break
elif append_path:
search_paths.append(line.strip())
# There are source files that include <torch/cuda.h>, <torch/torch.h>, etc.
# under the torch/csrc/api/include folder. Since torch/csrc/api/include is not
# a search path for clang-tidy, there will be clang-diagnostic errors
# complaining that those header files cannot be found. Changing the source code
# to include the full path, like torch/csrc/api/include/torch/torch.h, does not
# work well, since torch/torch.h includes torch/all.h, which in turn includes
# more; we would need to recursively change multiple files.
# Adding the include path to the lint script is a better solution.
search_paths.append(
os.path.join(PYTORCH_ROOT, "torch/csrc/api/include"),
)
return search_paths
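# The stderr section parsed above looks like this (illustrative output from
# `clang -E -x c++ - -v`):
#   #include <...> search starts here:
#    /usr/lib/llvm-11/include
#    /usr/include
#   End of search list.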
DEFAULTS = {
"glob": [
# The negative filters below are to exclude files that include onnx_pb.h or
# caffe2_pb.h, otherwise we'd have to build protos as part of this CI job.
# FunctionsManual.cpp is excluded to keep this diff clean. It will be fixed
# in a follow up PR.
# /torch/csrc/generic/*.cpp is excluded because those files aren't actually built.
# deploy/interpreter files are excluded due to using macros and other techniques
# that are not easily converted to accepted C++
"-torch/csrc/jit/passes/onnx/helper.cpp",
"-torch/csrc/jit/passes/onnx/shape_type_inference.cpp",
"-torch/csrc/jit/serialization/onnx.cpp",
"-torch/csrc/jit/serialization/export.cpp",
"-torch/csrc/jit/serialization/import.cpp",
"-torch/csrc/jit/serialization/import_legacy.cpp",
"-torch/csrc/jit/serialization/mobile_bytecode_generated.cpp",
"-torch/csrc/init_flatbuffer_module.cpp",
"-torch/csrc/stub_with_flatbuffer.c",
"-torch/csrc/onnx/init.cpp",
"-torch/csrc/cuda/nccl.*",
"-torch/csrc/cuda/python_nccl.cpp",
"-torch/csrc/autograd/FunctionsManual.cpp",
"-torch/csrc/generic/*.cpp",
"-torch/csrc/jit/codegen/cuda/runtime/*",
"-torch/csrc/deploy/interactive_embedded_interpreter.cpp",
"-torch/csrc/deploy/interpreter/interpreter.cpp",
"-torch/csrc/deploy/interpreter/interpreter.h",
"-torch/csrc/deploy/interpreter/interpreter_impl.h",
"-torch/csrc/deploy/interpreter/test_main.cpp",
"-torch/csrc/deploy/test_deploy_python_ext.cpp",
],
"paths": ["torch/csrc/"],
"include-dir": [
"/usr/lib/llvm-11/include/openmp",
get_python_include_dir(),
os.path.join(PYTORCH_ROOT, "third_party/pybind11/include"),
]
+ clang_search_dirs(),
"clang-tidy-exe": INSTALLATION_PATH,
"compile-commands-dir": "build",
"config-file": ".clang-tidy",
"disable-progress-bar": False,
}
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="clang-tidy wrapper script")
parser.add_argument(
"-e",
"--clang-tidy-exe",
default=DEFAULTS["clang-tidy-exe"],
help="Path to clang-tidy executable",
)
parser.add_argument(
"-g",
"--glob",
action="append",
default=DEFAULTS["glob"],
help="Only lint files that match these glob patterns "
"(see documentation for `fnmatch` for supported syntax)."
"If a pattern starts with a - the search is negated for that pattern.",
)
parser.add_argument(
"-x",
"--regex",
action="append",
default=[],
help="Only lint files that match these regular expressions (from the start of the filename). "
"If a pattern starts with a - the search is negated for that pattern.",
)
parser.add_argument(
"-c",
"--compile-commands-dir",
default=DEFAULTS["compile-commands-dir"],
help="Path to the folder containing compile_commands.json",
)
parser.add_argument(
"--diff-file",
help="File containing diff to use for determining files to lint and line filters",
)
parser.add_argument(
"-p",
"--paths",
nargs="+",
default=DEFAULTS["paths"],
help="Lint only the given paths (recursively)",
)
parser.add_argument(
"-n",
"--dry-run",
action="store_true",
help="Only show the command to be executed, without running it",
)
parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output")
parser.add_argument("-q", "--quiet", action="store_true", help="Don't print output")
parser.add_argument(
"--config-file",
default=DEFAULTS["config-file"],
help="Path to a clang-tidy config file. Defaults to '.clang-tidy'.",
)
parser.add_argument(
"--print-include-paths",
action="store_true",
help="Print the search paths used for include directives",
)
parser.add_argument(
"-I",
"--include-dir",
action="append",
default=DEFAULTS["include-dir"],
help="Add the specified directory to the search path for include files",
)
parser.add_argument(
"-s",
"--suppress-diagnostics",
action="store_true",
help="Add NOLINT to suppress clang-tidy violations",
)
parser.add_argument(
"--disable-progress-bar",
action="store_true",
default=DEFAULTS["disable-progress-bar"],
help="Disable the progress bar",
)
parser.add_argument(
"extra_args", nargs="*", help="Extra arguments to forward to clang-tidy"
)
return parser.parse_args()
def main() -> None:
options = parse_args()
if not pathlib.Path("build").exists():
generate_build_files()
# Check if clang-tidy executable exists
exists = os.access(options.clang_tidy_exe, os.X_OK)
if not exists:
msg = (
f"Could not find '{options.clang_tidy_exe}'\n"
+ "We provide a custom build of clang-tidy that has additional checks.\n"
+ "You can install it by running:\n"
+ "$ python3 -m tools.linter.install.clang_tidy \n"
+ "from the pytorch folder"
)
raise RuntimeError(msg)
result, _ = run(options)
sys.exit(result.returncode)
if __name__ == "__main__":
main()

View File

@@ -1,111 +0,0 @@
import argparse
import re
from typing import List
# > Why is DEFAULT_MAX_TOKEN_COUNT set to 1?
#
# clang-tidy doesn't have a direct way to query for token counts in the
# codebase. The workaround is to set the max token count to 1. This will cause
# clang-tidy to output a warning with the actual token count of the file.
#
# A non-destructive way to set the max token count to 1 would be to pass it
# through the -fmax-tokens option. However, this flag will be overridden if there
# exists a #pragma max_tokens_total statement in the file. This necessitates a
# destructive way to set the max token count to 1.
DEFAULT_MAX_TOKEN_COUNT = 1
MAX_TOKENS_CHECK_DIAG_NAME = "misc-max-tokens"
MAX_TOKENS_PRAGMA_PATTERN = r"^#pragma\s+clang\s+max_tokens_total\s+(\d+)$"
def add_max_tokens_pragma(code: str, num_max_tokens: int) -> str:
lines = code.splitlines()
found_pragma = False
pragma = f"#pragma clang max_tokens_total {num_max_tokens}"
for idx, line in enumerate(lines):
match = re.match(MAX_TOKENS_PRAGMA_PATTERN, line.strip())
if match:
found_pragma = True
token_count = match.group(1)
if int(token_count) != num_max_tokens:
lines[idx] = pragma
if not found_pragma:
lines = [pragma] + lines
return "\n".join(lines)
def strip_max_tokens_pragmas(code: str) -> str:
lines = code.splitlines()
lines = [
line
for line in lines
if re.match(MAX_TOKENS_PRAGMA_PATTERN, line.strip()) is None
]
return "\n".join(lines)
def add_max_tokens_pragma_to_files(files: List[str], num_max_tokens: int) -> None:
for filename in files:
with open(filename, "r+") as f:
data = f.read()
data = add_max_tokens_pragma(data, num_max_tokens)
f.seek(0)
f.write(data)
f.truncate()
def strip_max_tokens_pragma_from_files(files: List[str]) -> None:
for filename in files:
with open(filename, "r+") as f:
data = f.read()
data = strip_max_tokens_pragmas(data)
f.seek(0)
f.write(data)
f.truncate()
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Add max_tokens_total pragmas to C/C++ source files"
)
parser.add_argument(
"-n",
"--num-max-tokens",
default=DEFAULT_MAX_TOKEN_COUNT,
help="Set the token count to this value",
type=int,
)
parser.add_argument(
"files", nargs="+", help="Add max_tokens_total pragmas to the specified files"
)
parser.add_argument(
"-i", "--ignore", nargs="+", default=[], help="Ignore the specified files"
)
parser.add_argument(
"-s",
"--strip",
action="store_true",
help="Remove max_tokens_total pragmas from the input files",
)
return parser.parse_args()
def main() -> None:
options = parse_args()
ignored = set(options.ignore)
files = [filename for filename in options.files if filename not in ignored]
if options.strip:
strip_max_tokens_pragma_from_files(files)
else:
add_max_tokens_pragma_to_files(files, options.num_max_tokens)
if __name__ == "__main__":
main()

View File

@@ -1 +0,0 @@
unidiff==0.6.0

View File

@@ -1,516 +0,0 @@
#!/usr/bin/env python3
"""
A driver script to run clang-tidy on changes detected via git.
By default, clang-tidy runs on all files you point it at. This means that even
if you changed only parts of that file, you will get warnings for the whole
file. This script has the ability to ask git for the exact lines that have
changed since a particular git revision, and makes clang-tidy only lint those.
This makes it much less overhead to integrate in CI and much more relevant to
developers. This git-enabled mode is optional, and full scans of a directory
tree are also possible. In both cases, the script allows filtering files via
glob or regular expressions.
"""
import collections
import fnmatch
import json
import os
import os.path
import re
import shutil
import sys
import asyncio
import shlex
import multiprocessing
from typing import Any, Dict, Iterable, List, Set, Tuple
Patterns = collections.namedtuple("Patterns", "positive, negative")
# NOTE: Clang-tidy cannot lint headers directly, because headers are not
# compiled -- translation units are, of which there is one per implementation
# (c/cc/cpp) file.
DEFAULT_FILE_PATTERN = re.compile(r"^.*\.c(c|pp)?$")
CLANG_WARNING_PATTERN = re.compile(
r"([^:]+):(\d+):\d+:\s+(warning|error):.*\[([^\]]+)\]"
)
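# Example of a line CLANG_WARNING_PATTERN matches (illustrative):
#   "aten/src/foo.cpp:42:10: warning: some message [misc-max-tokens]"
# capturing the path, line number, severity, and check name.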
# Set from command line arguments in main().
VERBOSE = False
QUIET = False
def log(*args: Any, **kwargs: Any) -> None:
if not QUIET:
print(*args, **kwargs)
class CommandResult:
def __init__(self, returncode: int, stdout: str, stderr: str):
self.returncode = returncode
self.stdout = stdout.strip()
self.stderr = stderr.strip()
def failed(self) -> bool:
return self.returncode != 0
def __add__(self, other: "CommandResult") -> "CommandResult":
return CommandResult(
self.returncode + other.returncode,
f"{self.stdout}\n{other.stdout}",
f"{self.stderr}\n{other.stderr}",
)
def __str__(self) -> str:
return f"{self.stdout}"
def __repr__(self) -> str:
return (
f"returncode: {self.returncode}\n"
+ f"stdout: {self.stdout}\n"
+ f"stderr: {self.stderr}"
)
class ProgressMeter:
def __init__(
self, num_items: int, start_msg: str = "", disable_progress_bar: bool = False
) -> None:
self.num_items = num_items
self.num_processed = 0
self.width = 80
self.disable_progress_bar = disable_progress_bar
# helper escape sequences
self._clear_to_end = "\x1b[2K"
self._move_to_previous_line = "\x1b[F"
self._move_to_start_of_line = "\r"
self._move_to_next_line = "\n"
if self.disable_progress_bar:
log(start_msg)
else:
self._write(
start_msg
+ self._move_to_next_line
+ "[>"
+ (self.width * " ")
+ "]"
+ self._move_to_start_of_line
)
self._flush()
def _write(self, s: str) -> None:
sys.stderr.write(s)
def _flush(self) -> None:
sys.stderr.flush()
def update(self, msg: str) -> None:
if self.disable_progress_bar:
return
# Once we've processed all items, clear the progress bar
if self.num_processed == self.num_items - 1:
self._write(self._clear_to_end)
return
# NOP if we've already processed all items
if self.num_processed > self.num_items:
return
self.num_processed += 1
self._write(
self._move_to_previous_line
+ self._clear_to_end
+ msg
+ self._move_to_next_line
)
progress = int((self.num_processed / self.num_items) * self.width)
padding = self.width - progress
self._write(
self._move_to_start_of_line
+ self._clear_to_end
+ f"({self.num_processed} of {self.num_items}) "
+ f"[{progress*'='}>{padding*' '}]"
+ self._move_to_start_of_line
)
self._flush()
def print(self, msg: str) -> None:
if QUIET:
return
elif self.disable_progress_bar:
print(msg)
else:
self._write(
self._clear_to_end
+ self._move_to_previous_line
+ self._clear_to_end
+ msg
+ self._move_to_next_line
+ self._move_to_next_line
)
self._flush()
class ClangTidyWarning:
def __init__(self, name: str, occurrences: List[Tuple[str, int]]):
self.name = name
self.occurrences = occurrences
def __str__(self) -> str:
base = f"[{self.name}] occurred {len(self.occurrences)} times\n"
for occ in self.occurrences:
base += f" {occ[0]}:{occ[1]}\n"
return base
async def run_shell_command(
cmd: List[str], on_completed: Any = None, *args: Any
) -> CommandResult:
"""Executes a shell command and runs an optional callback when complete"""
if VERBOSE:
log("Running: ", " ".join(cmd))
proc = await asyncio.create_subprocess_shell(
" ".join(shlex.quote(x) for x in cmd), # type: ignore[attr-defined]
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
output = await proc.communicate()
result = CommandResult(
returncode=proc.returncode if proc.returncode is not None else -1,
stdout=output[0].decode("utf-8").strip(),
stderr=output[1].decode("utf-8").strip(),
)
if on_completed:
on_completed(result, *args)
return result
async def _run_clang_tidy_in_parallel(
commands: List[Tuple[List[str], str]], disable_progress_bar: bool
) -> CommandResult:
progress_meter = ProgressMeter(
len(commands),
f"Processing {len(commands)} clang-tidy jobs",
disable_progress_bar=disable_progress_bar,
)
async def gather_with_concurrency(n: int, tasks: List[Any]) -> Any:
semaphore = asyncio.Semaphore(n)
async def sem_task(task: Any) -> Any:
async with semaphore:
return await task
return await asyncio.gather(
*(sem_task(task) for task in tasks), return_exceptions=True
)
async def helper() -> Any:
def on_completed(result: CommandResult, filename: str) -> None:
if result.failed():
msg = str(result) if not VERBOSE else repr(result)
progress_meter.print(msg)
progress_meter.update(f"Processed {filename}")
coros = [
run_shell_command(cmd, on_completed, filename)
for (cmd, filename) in commands
]
return await gather_with_concurrency(multiprocessing.cpu_count(), coros)
results = await helper()
return sum(results, CommandResult(0, "", ""))
async def _run_clang_tidy(
options: Any, line_filters: List[Dict[str, Any]], files: Iterable[str]
) -> CommandResult:
"""Executes the actual clang-tidy command in the shell."""
base = [options.clang_tidy_exe]
# Apply common options
base += ["-p", options.compile_commands_dir]
if not options.config_file and os.path.exists(".clang-tidy"):
options.config_file = ".clang-tidy"
if options.config_file:
import yaml
with open(options.config_file) as config:
# Here we convert the YAML config file to a JSON blob.
base += [
"-config",
json.dumps(yaml.load(config, Loader=yaml.SafeLoader)),
]
if options.print_include_paths:
base += ["--extra-arg", "-v"]
if options.include_dir:
for dir in options.include_dir:
base += ["--extra-arg", f"-I{dir}"]
base += options.extra_args
if line_filters:
base += ["-line-filter", json.dumps(line_filters)]
# Apply per-file options
commands = []
for f in files:
command = list(base) + [map_filename(options.compile_commands_dir, f)]
commands.append((command, f))
if options.dry_run:
return CommandResult(0, str([c for c, _ in commands]), "")
return await _run_clang_tidy_in_parallel(commands, options.disable_progress_bar)
def extract_warnings(
output: str, base_dir: str = "."
) -> Tuple[Dict[str, Dict[int, Set[str]]], List[ClangTidyWarning]]:
warn2occ: Dict[str, List[Tuple[str, int]]] = {}
fixes: Dict[str, Dict[int, Set[str]]] = {}
for line in output.splitlines():
p = CLANG_WARNING_PATTERN.match(line)
if p is None:
continue
if os.path.isabs(p.group(1)):
path = os.path.abspath(p.group(1))
else:
path = os.path.abspath(os.path.join(base_dir, p.group(1)))
line_no = int(p.group(2))
# Filter out any options (which start with '-')
warning_names = set([w for w in p.group(4).split(",") if not w.startswith("-")])
for name in warning_names:
if name not in warn2occ:
warn2occ[name] = []
warn2occ[name].append((path, line_no))
if path not in fixes:
fixes[path] = {}
if line_no not in fixes[path]:
fixes[path][line_no] = set()
fixes[path][line_no].update(warning_names)
warnings = [ClangTidyWarning(name, sorted(occ)) for name, occ in warn2occ.items()]
return fixes, warnings
def apply_nolint(fname: str, warnings: Dict[int, Set[str]]) -> None:
with open(fname, encoding="utf-8") as f:
lines = f.readlines()
line_offset = -1  # because lines in .cpp files are numbered starting from 1
for line_no in sorted(warnings.keys()):
nolint_diagnostics = ",".join(warnings[line_no])
line_no += line_offset
indent = " " * (len(lines[line_no]) - len(lines[line_no].lstrip(" ")))
lines.insert(line_no, f"{indent}// NOLINTNEXTLINE({nolint_diagnostics})\n")
line_offset += 1
with open(fname, mode="w") as f:
f.write("".join(lines))
# Functions for correct handling of "ATen/native/cpu" mapping
# Sources in that folder are not built in place but are first copied into the build folder with `.[CPUARCH].cpp` suffixes
def map_filename(build_folder: str, fname: str) -> str:
fname = os.path.relpath(fname)
native_cpu_prefix = "aten/src/ATen/native/cpu/"
build_cpu_prefix = os.path.join(build_folder, native_cpu_prefix, "")
default_arch_suffix = ".DEFAULT.cpp"
if fname.startswith(native_cpu_prefix) and fname.endswith(".cpp"):
return (
f"{build_cpu_prefix}{fname[len(native_cpu_prefix):]}{default_arch_suffix}"
)
if fname.startswith(build_cpu_prefix) and fname.endswith(default_arch_suffix):
return f"{native_cpu_prefix}{fname[len(build_cpu_prefix):-len(default_arch_suffix)]}"
return fname
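# Illustrative round trip (paths assumed for the example):
#   map_filename("build", "aten/src/ATen/native/cpu/Activation.cpp")
#     -> "build/aten/src/ATen/native/cpu/Activation.cpp.DEFAULT.cpp"
# and mapping the result back yields the original source path.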
def map_filenames(build_folder: str, fnames: Iterable[str]) -> List[str]:
return [map_filename(build_folder, fname) for fname in fnames]
def split_negative_from_positive_patterns(patterns: Iterable[str]) -> Patterns:
"""Separates negative patterns (that start with a dash) from positive patterns"""
positive, negative = [], []
for pattern in patterns:
if pattern.startswith("-"):
negative.append(pattern[1:])
else:
positive.append(pattern)
return Patterns(positive, negative)
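# e.g. (illustrative patterns):
#   split_negative_from_positive_patterns(["*.cpp", "-*/generated/*"])
#     -> Patterns(positive=["*.cpp"], negative=["*/generated/*"])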
def get_file_patterns(globs: Iterable[str], regexes: Iterable[str]) -> Patterns:
"""Returns a list of compiled regex objects from globs and regex pattern strings."""
# fnmatch.translate converts a glob into a regular expression.
# https://docs.python.org/2/library/fnmatch.html#fnmatch.translate
glob = split_negative_from_positive_patterns(globs)
regexes_ = split_negative_from_positive_patterns(regexes)
positive_regexes = regexes_.positive + [fnmatch.translate(g) for g in glob.positive]
negative_regexes = regexes_.negative + [fnmatch.translate(g) for g in glob.negative]
positive_patterns = [re.compile(regex) for regex in positive_regexes] or [
DEFAULT_FILE_PATTERN
]
negative_patterns = [re.compile(regex) for regex in negative_regexes]
return Patterns(positive_patterns, negative_patterns)
def filter_files(files: Iterable[str], file_patterns: Patterns) -> Iterable[str]:
"""Returns all files that match any of the patterns."""
if VERBOSE:
log("Filtering with these file patterns: {}".format(file_patterns))
for file in files:
if not any(n.match(file) for n in file_patterns.negative):
if any(p.match(file) for p in file_patterns.positive):
yield file
continue
if VERBOSE:
log(f"{file} omitted due to file filters")
async def get_all_files(paths: List[str]) -> List[str]:
"""Returns all files that are tracked by git in the given paths."""
output = await run_shell_command(["git", "ls-files"] + paths)
return str(output).strip().splitlines()
def find_changed_lines(diff: str) -> Dict[str, List[Tuple[int, int]]]:
# Delay import since this isn't required unless using the --diff-file
# argument, which for local runs people don't care about
try:
import unidiff # type: ignore[import]
except ImportError as e:
e.msg += ", run 'pip install unidiff'" # type: ignore[attr-defined]
raise e
files: Any = collections.defaultdict(list)
for file in unidiff.PatchSet(diff):
for hunk in file:
added_line_nos = [line.target_line_no for line in hunk if line.is_added]
if len(added_line_nos) == 0:
continue
# Convert list of line numbers to ranges
# Eg: [1, 2, 3, 12, 13, 14, 15] becomes [[1,3], [12, 15]]
i = 1
ranges = [[added_line_nos[0], added_line_nos[0]]]
while i < len(added_line_nos):
if added_line_nos[i] != added_line_nos[i - 1] + 1:
ranges[-1][1] = added_line_nos[i - 1]
ranges.append([added_line_nos[i], added_line_nos[i]])
i += 1
ranges[-1][1] = added_line_nos[-1]
files[file.path] += ranges
return dict(files)
def filter_from_diff(
paths: List[str], diffs: List[str]
) -> Tuple[List[str], List[Dict[Any, Any]]]:
files = []
line_filters = []
for diff in diffs:
changed_files = find_changed_lines(diff)
changed_files = {
filename: v
for filename, v in changed_files.items()
if any(filename.startswith(path) for path in paths)
}
line_filters += [
{"name": name, "lines": lines} for name, lines, in changed_files.items()
]
files += list(changed_files.keys())
return files, line_filters
def filter_from_diff_file(
paths: List[str], filename: str
) -> Tuple[List[str], List[Dict[Any, Any]]]:
with open(filename, "r") as f:
diff = f.read()
return filter_from_diff(paths, [diff])
async def filter_default(paths: List[str]) -> Tuple[List[str], List[Dict[Any, Any]]]:
return await get_all_files(paths), []
async def _run(options: Any) -> Tuple[CommandResult, List[ClangTidyWarning]]:
# These flags are pervasive enough to set them globally. It makes the code
# cleaner compared to threading them through every single function.
global VERBOSE
global QUIET
VERBOSE = options.verbose
QUIET = options.quiet
# Normalize the paths first
paths = [path.rstrip("/") for path in options.paths]
# Filter files
if options.diff_file:
files, line_filters = filter_from_diff_file(options.paths, options.diff_file)
else:
files, line_filters = await filter_default(options.paths)
file_patterns = get_file_patterns(options.glob, options.regex)
files = list(filter_files(files, file_patterns))
# clang-tidy errors when it does not get input files.
if not files:
log("No files detected")
return CommandResult(0, "", ""), []
result = await _run_clang_tidy(options, line_filters, files)
fixes, warnings = extract_warnings(
result.stdout, base_dir=options.compile_commands_dir
)
if options.suppress_diagnostics:
for fname in fixes.keys():
mapped_fname = map_filename(options.compile_commands_dir, fname)
log(f"Applying fixes to {mapped_fname}")
apply_nolint(fname, fixes[fname])
if os.path.relpath(fname) != mapped_fname:
shutil.copyfile(fname, mapped_fname)
if options.dry_run:
log(result)
elif result.failed():
# If you change this message, update the error checking logic in
# .github/workflows/lint.yml
msg = "Warnings detected!"
log(msg)
log("Summary:")
for w in warnings:
log(str(w))
return result, warnings
def run(options: Any) -> Tuple[CommandResult, List[ClangTidyWarning]]:
loop = asyncio.get_event_loop()
return loop.run_until_complete(_run(options))

View File

@@ -1,655 +0,0 @@
#!/usr/bin/env python3
#
# ===- git-clang-format - ClangFormat Git Integration ---------*- python -*--===#
#
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#
# ===------------------------------------------------------------------------===#
r"""
clang-format git integration
============================
This file provides a clang-format integration for git. Put it somewhere in your
path and ensure that it is executable. Then, "git clang-format" will invoke
clang-format on the changes in current files or a specific commit.
For further details, run:
git clang-format -h
Requires Python 2.7 or Python 3
"""
from __future__ import absolute_import, division, print_function
import argparse
import collections
import contextlib
import errno
import os
import re
import subprocess
import sys
from clang_format_utils import get_and_check_clang_format, CLANG_FORMAT_PATH
usage = "git clang-format [OPTIONS] [<commit>] [<commit>] [--] [<file>...]"
desc = """
If zero or one commits are given, run clang-format on all lines that differ
between the working directory and <commit>, which defaults to HEAD. Changes are
only applied to the working directory.
If two commits are given (requires --diff), run clang-format on all lines in the
second <commit> that differ from the first <commit>.
If --binary is unspecified, we will try to fetch the correct clang-format
binary for PyTorch.
The following git-config settings set the default of the corresponding option:
clangFormat.binary
clangFormat.commit
clangFormat.extension
clangFormat.style
"""
# Name of the temporary index file in which to save the output of clang-format.
# This file is created within the .git directory.
temp_index_basename = "clang-format-index"
Range = collections.namedtuple("Range", "start, count")
def main():
config = load_git_config()
# In order to keep '--' yet allow options after positionals, we need to
# check for '--' ourselves. (Setting nargs='*' throws away the '--', while
# nargs=argparse.REMAINDER disallows options after positionals.)
argv = sys.argv[1:]
try:
idx = argv.index("--")
except ValueError:
dash_dash = []
else:
dash_dash = argv[idx:]
argv = argv[:idx]
default_extensions = ",".join(
[
# From clang/lib/Frontend/FrontendOptions.cpp, all lower case
"c",
"h", # C
"m", # ObjC
"mm", # ObjC++
"cc",
"cp",
"cpp",
"c++",
"cxx",
"hh",
"hpp",
"hxx", # C++
"cu", # CUDA
# Other languages that clang-format supports
"proto",
"protodevel", # Protocol Buffers
"java", # Java
"js", # JavaScript
"ts", # TypeScript
"cs", # C Sharp
]
)
p = argparse.ArgumentParser(
usage=usage,
formatter_class=argparse.RawDescriptionHelpFormatter,
description=desc,
)
p.add_argument("--binary", default=None, help="path to clang-format"),
p.add_argument(
"--commit",
default=config.get("clangformat.commit", "HEAD"),
help="default commit to use if none is specified",
),
p.add_argument(
"--diff",
action="store_true",
help="print a diff instead of applying the changes",
)
p.add_argument(
"--extensions",
default=config.get("clangformat.extensions", default_extensions),
help=(
"comma-separated list of file extensions to format, "
"excluding the period and case-insensitive"
),
),
p.add_argument(
"-f", "--force", action="store_true", help="allow changes to unstaged files"
)
p.add_argument(
"-p", "--patch", action="store_true", help="select hunks interactively"
)
p.add_argument(
"-q", "--quiet", action="count", default=0, help="print less information"
)
p.add_argument(
"--style",
default=config.get("clangformat.style", None),
help="passed to clang-format",
),
p.add_argument(
"-v", "--verbose", action="count", default=0, help="print extra information"
)
# We gather all the remaining positional arguments into 'args' since we need
# to use some heuristics to determine whether or not <commit> was present.
# However, to print pretty messages, we make use of metavar and help.
p.add_argument(
"args",
nargs="*",
metavar="<commit>",
help="revision from which to compute the diff",
)
p.add_argument(
"ignored",
nargs="*",
metavar="<file>...",
help="if specified, only consider differences in these files",
)
opts = p.parse_args(argv)
opts.verbose -= opts.quiet
del opts.quiet
ok = get_and_check_clang_format(opts.verbose)
if not ok:
# We have to invert because False -> 0, which is the code to be returned if everything is okay.
return not ok
if opts.binary is None:
opts.binary = CLANG_FORMAT_PATH
commits, files = interpret_args(opts.args, dash_dash, opts.commit)
if len(commits) > 1:
if not opts.diff:
die("--diff is required when two commits are given")
else:
if len(commits) > 2:
die("at most two commits allowed; %d given" % len(commits))
changed_lines = compute_diff_and_extract_lines(commits, files)
if opts.verbose >= 1:
ignored_files = set(changed_lines)
filter_by_extension(changed_lines, opts.extensions.lower().split(","))
if opts.verbose >= 1:
ignored_files.difference_update(changed_lines)
if ignored_files:
print("Ignoring changes in the following files (wrong extension):")
for filename in ignored_files:
print(" %s" % filename)
if changed_lines:
print("Running clang-format on the following files:")
for filename in changed_lines:
print(" %s" % filename)
if not changed_lines:
print("no modified files to format")
return
# The computed diff outputs absolute paths, so we must cd before accessing
# those files.
cd_to_toplevel()
if len(commits) > 1:
old_tree = commits[1]
new_tree = run_clang_format_and_save_to_tree(
changed_lines, revision=commits[1], binary=opts.binary, style=opts.style
)
else:
old_tree = create_tree_from_workdir(changed_lines)
new_tree = run_clang_format_and_save_to_tree(
changed_lines, binary=opts.binary, style=opts.style
)
if opts.verbose >= 1:
print("old tree: %s" % old_tree)
print("new tree: %s" % new_tree)
if old_tree == new_tree:
if opts.verbose >= 0:
print("clang-format did not modify any files")
elif opts.diff:
print_diff(old_tree, new_tree)
else:
changed_files = apply_changes(
old_tree, new_tree, force=opts.force, patch_mode=opts.patch
)
if (opts.verbose >= 0 and not opts.patch) or opts.verbose >= 1:
print("changed files:")
for filename in changed_files:
print(" %s" % filename)
def load_git_config(non_string_options=None):
"""Return the git configuration as a dictionary.
All options are assumed to be strings unless listed in `non_string_options`,
which is a dictionary mapping option names (in lower case) to either "--bool"
or "--int"."""
if non_string_options is None:
non_string_options = {}
out = {}
for entry in run("git", "config", "--list", "--null").split("\0"):
if entry:
name, value = entry.split("\n", 1)
if name in non_string_options:
value = run("git", "config", non_string_options[name], name)
out[name] = value
return out
def interpret_args(args, dash_dash, default_commit):
"""Interpret `args` as "[commits] [--] [files]" and return (commits, files).
It is assumed that "--" and everything that follows has been removed from
args and placed in `dash_dash`.
If "--" is present (i.e., `dash_dash` is non-empty), the arguments to its
left (if present) are taken as commits. Otherwise, the arguments are checked
from left to right if they are commits or files. If commits are not given,
a list with `default_commit` is used."""
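# Illustrative (assumed) examples, given that "HEAD~1" names a valid commit:
#   interpret_args(["HEAD~1"], [], "HEAD")              -> (["HEAD~1"], [])
#   interpret_args(["HEAD~1"], ["--", "a.cpp"], "HEAD") -> (["HEAD~1"], ["a.cpp"])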
if dash_dash:
if len(args) == 0:
commits = [default_commit]
else:
commits = args
for commit in commits:
object_type = get_object_type(commit)
if object_type not in ("commit", "tag"):
if object_type is None:
die("'%s' is not a commit" % commit)
else:
die(
"'%s' is a %s, but a commit was expected"
% (commit, object_type)
)
files = dash_dash[1:]
elif args:
commits = []
while args:
if not disambiguate_revision(args[0]):
break
commits.append(args.pop(0))
if not commits:
commits = [default_commit]
files = args
else:
commits = [default_commit]
files = []
return commits, files
def disambiguate_revision(value):
"""Returns True if `value` is a revision, False if it is a file, or dies."""
# If `value` is ambiguous (neither a commit nor a file), the following
# command will die with an appropriate error message.
run("git", "rev-parse", value, verbose=False)
object_type = get_object_type(value)
if object_type is None:
return False
if object_type in ("commit", "tag"):
return True
die("`%s` is a %s, but a commit or filename was expected" % (value, object_type))
def get_object_type(value):
"""Returns a string description of an object's type, or None if it is not
a valid git object."""
cmd = ["git", "cat-file", "-t", value]
p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = p.communicate()
if p.returncode != 0:
return None
return convert_string(stdout.strip())
def compute_diff_and_extract_lines(commits, files):
"""Calls compute_diff() followed by extract_lines()."""
diff_process = compute_diff(commits, files)
changed_lines = extract_lines(diff_process.stdout)
diff_process.stdout.close()
diff_process.wait()
if diff_process.returncode != 0:
# Assume error was already printed to stderr.
sys.exit(2)
return changed_lines
def compute_diff(commits, files):
"""Return a subprocess object producing the diff from `commits`.
The return value's `stdin` file object will produce a patch with the
differences between the working directory and the first commit if a single
one was specified, or the difference between both specified commits, filtered
on `files` (if non-empty). Zero context lines are used in the patch."""
git_tool = "diff-index"
if len(commits) > 1:
git_tool = "diff-tree"
cmd = ["git", git_tool, "-p", "-U0"] + commits + ["--"]
cmd.extend(files)
p = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
p.stdin.close()
return p
def extract_lines(patch_file):
"""Extract the changed lines in `patch_file`.
The return value is a dictionary mapping filename to a list of (start_line,
line_count) pairs.
The input must have been produced with ``-U0``, meaning unidiff format with
zero lines of context. The return value is a dict mapping filename to a
list of line `Range`s."""
matches = {}
for line in patch_file:
line = convert_string(line)
match = re.search(r"^\+\+\+\ [^/]+/(.*)", line)
if match:
filename = match.group(1).rstrip("\r\n")
match = re.search(r"^@@ -[0-9,]+ \+(\d+)(,(\d+))?", line)
if match:
start_line = int(match.group(1))
line_count = 1
if match.group(3):
line_count = int(match.group(3))
if line_count > 0:
matches.setdefault(filename, []).append(Range(start_line, line_count))
return matches
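# Illustrative (assumed) patch fragment: a "+++ b/x.cpp" header followed by a
# "@@ -10,0 +12,3 @@" hunk header records Range(start=12, count=3) under "x.cpp".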
def filter_by_extension(dictionary, allowed_extensions):
"""Delete every key in `dictionary` that doesn't have an allowed extension.
`allowed_extensions` must be a collection of lowercase file extensions,
excluding the period."""
allowed_extensions = frozenset(allowed_extensions)
for filename in list(dictionary.keys()):
base_ext = filename.rsplit(".", 1)
if len(base_ext) == 1 and "" in allowed_extensions:
continue
if len(base_ext) == 1 or base_ext[1].lower() not in allowed_extensions:
del dictionary[filename]
def cd_to_toplevel():
"""Change to the top level of the git repository."""
toplevel = run("git", "rev-parse", "--show-toplevel")
os.chdir(toplevel)
def create_tree_from_workdir(filenames):
"""Create a new git tree with the given files from the working directory.
Returns the object ID (SHA-1) of the created tree."""
return create_tree(filenames, "--stdin")
def run_clang_format_and_save_to_tree(
changed_lines, revision=None, binary="clang-format", style=None
):
"""Run clang-format on each file and save the result to a git tree.
Returns the object ID (SHA-1) of the created tree."""
def index_info_generator():
for filename, line_ranges in changed_lines.items():
if revision:
git_metadata_cmd = [
"git",
"ls-tree",
"%s:%s" % (revision, os.path.dirname(filename)),
os.path.basename(filename),
]
git_metadata = subprocess.Popen(
git_metadata_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE
)
stdout = git_metadata.communicate()[0]
mode = oct(int(stdout.split()[0], 8))
else:
mode = oct(os.stat(filename).st_mode)
# Adjust python3 octal format so that it matches what git expects
if mode.startswith("0o"):
mode = "0" + mode[2:]
blob_id = clang_format_to_blob(
filename, line_ranges, revision=revision, binary=binary, style=style
)
yield "%s %s\t%s" % (mode, blob_id, filename)
return create_tree(index_info_generator(), "--index-info")
def create_tree(input_lines, mode):
"""Create a tree object from the given input.
If mode is '--stdin', it must be a list of filenames. If mode is
'--index-info', it must be a list of values suitable for "git update-index
--index-info", such as "<mode> <SP> <sha1> <TAB> <filename>". Any other mode
is invalid."""
assert mode in ("--stdin", "--index-info")
cmd = ["git", "update-index", "--add", "-z", mode]
with temporary_index_file():
p = subprocess.Popen(cmd, stdin=subprocess.PIPE)
for line in input_lines:
p.stdin.write(to_bytes("%s\0" % line))
p.stdin.close()
if p.wait() != 0:
die("`%s` failed" % " ".join(cmd))
tree_id = run("git", "write-tree")
return tree_id
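# Illustrative example (not part of the original script): a typical
# --index-info input line, as yielded by index_info_generator above, has the
# shape
#   100644 <40-hex-char blob SHA-1>	path/to/file.cpp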
def clang_format_to_blob(
filename, line_ranges, revision=None, binary="clang-format", style=None
):
"""Run clang-format on the given file and save the result to a git blob.
Runs on the file in `revision` if not None, or on the file in the working
directory if `revision` is None.
Returns the object ID (SHA-1) of the created blob."""
clang_format_cmd = [binary]
if style:
clang_format_cmd.extend(["-style=" + style])
clang_format_cmd.extend(
[
"-lines=%s:%s" % (start_line, start_line + line_count - 1)
for start_line, line_count in line_ranges
]
)
if revision:
clang_format_cmd.extend(["-assume-filename=" + filename])
git_show_cmd = ["git", "cat-file", "blob", "%s:%s" % (revision, filename)]
git_show = subprocess.Popen(
git_show_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE
)
git_show.stdin.close()
clang_format_stdin = git_show.stdout
else:
clang_format_cmd.extend([filename])
git_show = None
clang_format_stdin = subprocess.PIPE
try:
clang_format = subprocess.Popen(
clang_format_cmd, stdin=clang_format_stdin, stdout=subprocess.PIPE
)
if clang_format_stdin == subprocess.PIPE:
clang_format_stdin = clang_format.stdin
except OSError as e:
if e.errno == errno.ENOENT:
die('cannot find executable "%s"' % binary)
else:
raise
clang_format_stdin.close()
hash_object_cmd = ["git", "hash-object", "-w", "--path=" + filename, "--stdin"]
hash_object = subprocess.Popen(
hash_object_cmd, stdin=clang_format.stdout, stdout=subprocess.PIPE
)
clang_format.stdout.close()
stdout = hash_object.communicate()[0]
if hash_object.returncode != 0:
die("`%s` failed" % " ".join(hash_object_cmd))
if clang_format.wait() != 0:
die("`%s` failed" % " ".join(clang_format_cmd))
if git_show and git_show.wait() != 0:
die("`%s` failed" % " ".join(git_show_cmd))
return convert_string(stdout).rstrip("\r\n")
@contextlib.contextmanager
def temporary_index_file(tree=None):
"""Context manager for setting GIT_INDEX_FILE to a temporary file and deleting
the file afterward."""
index_path = create_temporary_index(tree)
old_index_path = os.environ.get("GIT_INDEX_FILE")
os.environ["GIT_INDEX_FILE"] = index_path
try:
yield
finally:
if old_index_path is None:
del os.environ["GIT_INDEX_FILE"]
else:
os.environ["GIT_INDEX_FILE"] = old_index_path
os.remove(index_path)
def create_temporary_index(tree=None):
"""Create a temporary index file and return the created file's path.
If `tree` is not None, use that as the tree to read in. Otherwise, an
empty index is created."""
gitdir = run("git", "rev-parse", "--git-dir")
path = os.path.join(gitdir, temp_index_basename)
if tree is None:
tree = "--empty"
run("git", "read-tree", "--index-output=" + path, tree)
return path
def print_diff(old_tree, new_tree):
"""Print the diff between the two trees to stdout."""
# We use the porcelain 'diff' and not plumbing 'diff-tree' because the output
# is expected to be viewed by the user, and only the former does nice things
# like color and pagination.
#
# We also only print modified files since `new_tree` only contains the files
# that were modified, so unmodified files would show as deleted without the
# filter.
subprocess.check_call(["git", "diff", "--diff-filter=M", old_tree, new_tree, "--"])
def apply_changes(old_tree, new_tree, force=False, patch_mode=False):
"""Apply the changes in `new_tree` to the working directory.
Bails if there are local changes in those files and `force` is not set. If
`patch_mode`, runs `git checkout --patch` to select hunks interactively."""
changed_files = (
run(
"git",
"diff-tree",
"--diff-filter=M",
"-r",
"-z",
"--name-only",
old_tree,
new_tree,
)
.rstrip("\0")
.split("\0")
)
if not force:
unstaged_files = run("git", "diff-files", "--name-status", *changed_files)
if unstaged_files:
print(
"The following files would be modified but have unstaged changes:",
file=sys.stderr,
)
print(unstaged_files, file=sys.stderr)
print("Please commit, stage, or stash them first.", file=sys.stderr)
sys.exit(2)
if patch_mode:
# In patch mode, we could just as well create an index from the new tree
# and checkout from that, but then the user will be presented with a
# message saying "Discard ... from worktree". Instead, we use the old
# tree as the index and checkout from new_tree, which gives the slightly
# better message, "Apply ... to index and worktree". This is not quite
# right, since it won't be applied to the user's index, but oh well.
with temporary_index_file(old_tree):
subprocess.check_call(["git", "checkout", "--patch", new_tree])
index_tree = old_tree
else:
with temporary_index_file(new_tree):
run("git", "checkout-index", "-a", "-f")
return changed_files
def run(*args, **kwargs):
stdin = kwargs.pop("stdin", "")
verbose = kwargs.pop("verbose", True)
strip = kwargs.pop("strip", True)
for name in kwargs:
raise TypeError("run() got an unexpected keyword argument '%s'" % name)
p = subprocess.Popen(
args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE
)
stdout, stderr = p.communicate(input=stdin)
stdout = convert_string(stdout)
stderr = convert_string(stderr)
if p.returncode == 0:
if stderr:
if verbose:
print("`%s` printed to stderr:" % " ".join(args), file=sys.stderr)
print(stderr.rstrip(), file=sys.stderr)
if strip:
stdout = stdout.rstrip("\r\n")
return stdout
if verbose:
print("`%s` returned %s" % (" ".join(args), p.returncode), file=sys.stderr)
if stderr:
print(stderr.rstrip(), file=sys.stderr)
sys.exit(2)
def die(message):
print("error:", message, file=sys.stderr)
sys.exit(2)
def to_bytes(str_input):
# Encode to UTF-8 to get binary data.
if isinstance(str_input, bytes):
return str_input
return str_input.encode("utf-8")
def to_string(bytes_input):
if isinstance(bytes_input, str):
return bytes_input
# bytes objects must be decoded (not encoded) to get a str.
return bytes_input.decode("utf-8")
def convert_string(bytes_input):
try:
return to_string(bytes_input.decode("utf-8"))
except AttributeError: # 'str' object has no attribute 'decode'.
return str(bytes_input)
except UnicodeError:
return str(bytes_input)
if __name__ == "__main__":
main()

View File

@@ -1,21 +0,0 @@
import os
import sys
from tools.linter.install.download_bin import download, PYTORCH_ROOT, HASH_PATH
PLATFORM_TO_URL = {
"Linux": "https://oss-clang-format.s3.us-east-2.amazonaws.com/linux64/clang-tidy",
"Darwin": "https://oss-clang-format.s3.us-east-2.amazonaws.com/macos/clang-tidy",
}
PLATFORM_TO_HASH = {
"Linux": os.path.join(HASH_PATH, "clang-tidy-linux64"),
"Darwin": os.path.join(HASH_PATH, "clang-tidy-macos"),
}
OUTPUT_DIR = os.path.join(PYTORCH_ROOT, ".clang-tidy-bin")
INSTALLATION_PATH = os.path.join(OUTPUT_DIR, "clang-tidy")
if __name__ == "__main__":
ok = download("clang-tidy", OUTPUT_DIR, PLATFORM_TO_URL, PLATFORM_TO_HASH)
if not ok:
print("Installation failed!")
sys.exit(1)

View File

@@ -1,180 +0,0 @@
import platform
import sys
import stat
import hashlib
import subprocess
import os
import urllib.request
import urllib.error
from typing import Dict
# String representing the host platform (e.g. Linux, Darwin).
HOST_PLATFORM = platform.system()
# PyTorch directory root
result = subprocess.run(
["git", "rev-parse", "--show-toplevel"],
stdout=subprocess.PIPE,
check=True,
)
PYTORCH_ROOT = result.stdout.decode("utf-8").strip()
HASH_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hashes")
def compute_file_sha256(path: str) -> str:
"""Compute the SHA256 hash of a file and return it as a hex string."""
# If the file doesn't exist, return an empty string.
if not os.path.exists(path):
return ""
sha = hashlib.sha256()
# Open the file in binary mode and hash it (avoid shadowing builtin `hash`).
with open(path, "rb") as f:
for b in f:
sha.update(b)
# Return the hash as a hexadecimal string.
return sha.hexdigest()
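# Illustrative usage (not part of the original module): the hex digest
# matches what `sha256sum <path>` would print for an existing file, and a
# missing path yields the empty string:
#   >>> compute_file_sha256("/no/such/file")
#   ''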
def report_download_progress(
chunk_number: int, chunk_size: int, file_size: int
) -> None:
"""
Pretty printer for file download progress.
"""
if file_size != -1:
percent = min(1, (chunk_number * chunk_size) / file_size)
bar = "#" * int(64 * percent)
sys.stdout.write("\r0% |{:<64}| {}%".format(bar, int(percent * 100)))
def download_bin(name: str, output_dir: str, platform_to_url: Dict[str, str]) -> bool:
"""
Downloads the binary appropriate for the host platform and stores it in the given output directory.
"""
if HOST_PLATFORM not in platform_to_url:
print(f"Unsupported platform: {HOST_PLATFORM}", file=sys.stderr)
return False
url = platform_to_url[HOST_PLATFORM]
filename = os.path.join(output_dir, name)
# Try to download binary.
print(f"Downloading {name} to {output_dir}", file=sys.stderr)
try:
urllib.request.urlretrieve(
url,
filename,
reporthook=report_download_progress if sys.stdout.isatty() else None,
)
except urllib.error.URLError as e:
print(f"Error downloading {filename}: {e}", file=sys.stderr)
return False
finally:
print(file=sys.stderr)
return True
def download(
name: str,
output_dir: str,
platform_to_url: Dict[str, str],
platform_to_hash: Dict[str, str],
verbose: bool = False,
) -> bool:
"""
Download a platform-appropriate binary if one doesn't already exist at the
expected location, and verify that it is the right binary by checking its
SHA256 hash against the expected hash.
"""
output_path = os.path.join(output_dir, name)
if not os.path.exists(output_dir):
# If the directory doesn't exist, try to create it.
try:
os.mkdir(output_dir)
except OSError as e:
print(
f"Unable to create directory for {name} binary: {output_dir} ({e})",
file=sys.stderr,
)
return False
# Report the successful creation only after mkdir succeeds.
if verbose:
print(
f"Created directory {output_dir} for {name} binary", file=sys.stderr
)
# If the directory didn't exist, neither did the binary, so download it.
ok = download_bin(name, output_dir, platform_to_url)
if not ok:
return False
else:
# If the directory exists but the binary doesn't, download it.
if not os.path.exists(output_path):
ok = download_bin(name, output_dir, platform_to_url)
if not ok:
return False
else:
if verbose:
print(
f"Found pre-existing {name} binary, skipping download",
file=sys.stderr,
)
# Now that the binary is where it should be, hash it.
actual_bin_hash = compute_file_sha256(output_path)
# If the host platform is not in platform_to_hash, it is unsupported.
if HOST_PLATFORM not in platform_to_hash:
print(f"Unsupported platform: {HOST_PLATFORM}", file=sys.stderr)
return False
# This is the path to the file containing the reference hash.
hashpath = os.path.join(PYTORCH_ROOT, platform_to_hash[HOST_PLATFORM])
if not os.path.exists(hashpath):
print("Unable to find reference binary hash", file=sys.stderr)
return False
# Load the reference hash and compare the actual hash to it.
with open(hashpath, "r") as f:
reference_bin_hash = f.readline().strip()
if verbose:
print(f"Reference Hash: {reference_bin_hash}", file=sys.stderr)
print(f"Actual Hash: {repr(actual_bin_hash)}", file=sys.stderr)
if reference_bin_hash != actual_bin_hash:
print("The downloaded binary is not what was expected!", file=sys.stderr)
print(
f"Downloaded hash: {repr(actual_bin_hash)} vs expected {reference_bin_hash}",
file=sys.stderr,
)
# Err on the side of caution and try to delete the downloaded binary.
try:
os.unlink(output_path)
print("The binary has been deleted just to be safe", file=sys.stderr)
except OSError as e:
print(f"Failed to delete binary: {e}", file=sys.stderr)
print(
"Delete this binary as soon as possible and do not execute it!",
file=sys.stderr,
)
return False
else:
# Make sure the binary is executable.
mode = os.stat(output_path).st_mode
mode |= stat.S_IXUSR
os.chmod(output_path, mode)
print(f"Using {name} located at {output_path}", file=sys.stderr)
return True

View File

@@ -1 +0,0 @@
49343a448fcb75cd1e0fb9d6b1f6c2ef4b008b6f91d6ff899d4ac6060f5e52a5

View File

@@ -1 +0,0 @@
541797a7b8fa795e2f3c1adcd8236cc336a40aa927028dc5bc79172e1d9eca36

View File

@@ -1,224 +0,0 @@
#!/usr/bin/env python3
"""
This module is meant to be run as a script (see the docstring of main
below) and passed the filename of any Python file in this repo, to
typecheck that file using only the subset of our mypy configs that apply
to it.
Since editors (e.g. VS Code) can be configured to use this wrapper
script in lieu of mypy itself, the idea is that this can be used to get
inline mypy results while developing, and have at least some degree of
assurance that those inline results match up with what you would get
from running the mypy lint from the .github/workflows/lint.yml file.
See also these wiki pages:
- https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch
- https://github.com/pytorch/pytorch/wiki/Lint-as-you-type
"""
import sys
from collections import defaultdict
from configparser import ConfigParser
from pathlib import Path, PurePath, PurePosixPath
from typing import Any, Dict, List, Optional, Set, Tuple
import mypy.api
# not part of the public API, but this is the easiest way to ensure that
# we agree with what mypy actually does
import mypy.config_parser
def read_config(config_path: Path) -> Set[str]:
"""
Return the set of `files` in the `mypy` ini file at config_path.
"""
config = ConfigParser()
config.read(config_path)
# hopefully on Windows this gives posix paths
return set(
mypy.config_parser.split_and_match_files(
config["mypy"]["files"],
)
)
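# Illustrative example (not part of the original module): given an ini file
# whose [mypy] section contains
#   files = tools/*.py, torch/utils/data
# read_config returns the set of paths left after mypy's own wildcard
# expansion, so the exact result depends on what exists on disk.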
# see tools/test/test_mypy_wrapper.py for examples of many of the
# following functions
def config_files() -> Dict[str, Set[str]]:
"""
Return a dict from all our `mypy` ini filenames to their `files`.
"""
return {str(ini): read_config(ini) for ini in Path().glob("mypy*.ini")}
def split_path(path: str) -> List[str]:
"""
Split a relative (not absolute) POSIX path into its segments.
"""
pure = PurePosixPath(path)
return [str(p.name) for p in list(reversed(pure.parents))[1:] + [pure]]
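# Illustrative example (not part of the original module):
#   >>> split_path("dir2/dir5/file6.def")
#   ['dir2', 'dir5', 'file6.def']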
# mypy doesn't support recursive types yet
# https://github.com/python/mypy/issues/731
# but if it did, the `Any` here would be `Union[Set[str], 'Trie']`,
# although that is not completely accurate: specifically, every `None`
# key must map to a `Set[str]`, and every `str` key must map to a `Trie`
Trie = Dict[Optional[str], Any]
def make_trie(configs: Dict[str, Set[str]]) -> Trie:
"""
Return a trie from path prefixes to their `mypy` configs.
Specifically, each layer of the trie represents a segment of a POSIX
path relative to the root of this repo. If you follow a path down
the trie and reach a `None` key, that `None` maps to the (nonempty)
set of keys in `configs` which explicitly include that path.
"""
trie: Trie = {}
for ini, files in configs.items():
for f in files:
inner = trie
for segment in split_path(f):
inner = inner.setdefault(segment, {})
inner.setdefault(None, set()).add(ini)
return trie
def lookup(trie: Trie, filename: str) -> Set[str]:
"""
Return the configs in `trie` that include a prefix of `filename`.
A path is included by a config if any of its ancestors are included
by the wildcard-expanded version of that config's `files`. Thus,
this function follows `filename`'s path down the `trie` and
accumulates all the configs it finds along the way.
"""
configs = set()
inner = trie
for segment in split_path(filename):
inner = inner.get(segment, {})
configs |= inner.get(None, set())
return configs
def make_plan(
*, configs: Dict[str, Set[str]], files: List[str]
) -> Dict[str, List[str]]:
"""
Return a dict from config names to the files to run them with.
The keys of the returned dict are a subset of the keys of `configs`.
The list of files in each value of returned dict should contain a
nonempty subset of the given `files`, in the same order as `files`.
"""
trie = make_trie(configs)
plan = defaultdict(list)
for filename in files:
for config in lookup(trie, filename):
plan[config].append(filename)
return plan
def run(
*,
args: List[str],
files: List[str],
) -> Tuple[int, List[str], List[str]]:
"""
Return the exit code and list of output lines from running `mypy`.
The given `args` are passed verbatim to `mypy`. The `files` (each of
which must be an absolute path) are converted to relative paths
(that is, relative to the root of this repo) and then classified
according to which ones need to be run with each `mypy` config.
Thus, `mypy` may be run zero, one, or multiple times, but it will be
run at most once for each `mypy` config used by this repo.
"""
repo_root = Path.cwd()
plan = make_plan(
configs=config_files(),
files=[PurePath(f).relative_to(repo_root).as_posix() for f in files],
)
mypy_results = [
mypy.api.run(
# insert custom flags after args to avoid being overridden
# by existing flags in args
args
+ [
# don't special-case the last line
"--no-error-summary",
f"--config-file={config}",
]
+ filtered
)
# by construction, filtered must be nonempty
for config, filtered in plan.items()
]
return (
# assume all mypy exit codes are nonnegative
# https://github.com/python/mypy/issues/6003
max(
[exit_code for _, _, exit_code in mypy_results],
default=0,
),
list(
dict.fromkeys( # remove duplicates, retain order
item for stdout, _, _ in mypy_results for item in stdout.splitlines()
)
),
[stderr for _, stderr, _ in mypy_results],
)
def main(args: List[str]) -> None:
"""
Run mypy on one Python file using the correct config file(s).
This function assumes the following preconditions hold:
- the cwd is set to the root of this cloned repo
- args is a valid list of CLI arguments that could be passed to mypy
- some of args are absolute paths to files to typecheck
- all the other args are config flags for mypy, rather than files
These assumptions hold, for instance, when mypy is run automatically
by VS Code's Python extension, so in your clone of this repository,
you could modify your .vscode/settings.json to look something like
this (assuming you use a conda environment named "pytorch"):
{
"python.linting.enabled": true,
"python.linting.mypyEnabled": true,
"python.linting.mypyPath":
"${env:HOME}/miniconda3/envs/pytorch/bin/python",
"python.linting.mypyArgs": [
"${workspaceFolder}/tools/linter/mypy_wrapper.py"
]
}
More generally, this should work for any editor that sets the cwd to the
repo root, runs mypy on individual files via their absolute paths,
and allows you to set the path to the mypy executable.
"""
repo_root = str(Path.cwd())
exit_code, mypy_issues, stderrs = run(
args=[arg for arg in args if not arg.startswith(repo_root)],
files=[arg for arg in args if arg.startswith(repo_root)],
)
for issue in mypy_issues:
print(issue)
for stderr in stderrs:
print(stderr, end="", file=sys.stderr)
sys.exit(exit_code)
if __name__ == "__main__":
main(sys.argv[1:])

View File

@@ -1,2 +0,0 @@
#!/usr/bin/env bash
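# Recursively collect every *.sh file under the given directories; the
# -print0/-0 pair keeps filenames containing spaces intact, and
# --external-sources lets ShellCheck follow files pulled in via `source`.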
find "$@" -name '*.sh' -print0 | xargs -0 -n1 shellcheck --external-sources

View File

@@ -1,37 +0,0 @@
#!/usr/bin/env python3
import fileinput
import os
import sys
(NEWLINE,) = b"\n"  # unpacking a one-byte bytes literal yields the int 10
def correct_trailing_newlines(filename: str) -> bool:
with open(filename, "rb") as f:
a = len(f.read(2))
if a == 0:
return True
elif a == 1:
# file is wrong whether or not the only byte is a newline
return False
else:
f.seek(-2, os.SEEK_END)
b, c = f.read(2)
# no ASCII byte is part of any non-ASCII character in UTF-8
return b != NEWLINE and c == NEWLINE
def main() -> int:
# mimic git grep exit code behavior
exit_code = 1
for line in fileinput.input():
stripped = line.rstrip()
if not correct_trailing_newlines(stripped):
exit_code = 0
print(stripped)
return exit_code
if __name__ == "__main__":
sys.exit(main())

View File

@@ -1,183 +0,0 @@
#!/usr/bin/env python3
import argparse
import json
import re
import subprocess
from bisect import bisect_right
from collections import defaultdict
from typing import (
Callable,
DefaultDict,
Generic,
List,
Optional,
Pattern,
Sequence,
TypeVar,
cast,
)
from typing_extensions import TypedDict
class Hunk(TypedDict):
old_start: int
old_count: int
new_start: int
new_count: int
class Diff(TypedDict):
old_filename: Optional[str]
hunks: List[Hunk]
# @@ -start,count +start,count @@
hunk_pattern = r"^@@\s+-(\d+)(?:,(\d+))?\s+\+(\d+)(?:,(\d+))?\s+@@"
def parse_diff(diff: str) -> Diff:
name = None
name_found = False
hunks: List[Hunk] = []
for line in diff.splitlines():
hunk_match = re.match(hunk_pattern, line)
if name_found:
if hunk_match:
old_start, old_count, new_start, new_count = hunk_match.groups()
hunks.append(
{
"old_start": int(old_start),
"old_count": int(old_count or "1"),
"new_start": int(new_start),
"new_count": int(new_count or "1"),
}
)
else:
assert not hunk_match
name_match = re.match(r"^--- (?:(?:/dev/null)|(?:a/(.*)))$", line)
if name_match:
name_found = True
(name,) = name_match.groups()
return {
"old_filename": name,
"hunks": hunks,
}
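# Illustrative example (not part of the original module): a minimal diff
# with one hunk.
#   >>> parse_diff("--- a/lao\n@@ -1,2 +0,0 @@\n")
#   {'old_filename': 'lao',
#    'hunks': [{'old_start': 1, 'old_count': 2, 'new_start': 0, 'new_count': 0}]}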
T = TypeVar("T")
U = TypeVar("U")
# we want to use bisect.bisect_right to find the closest hunk to a given
# line number, but the bisect module won't have a key function until
# Python 3.10 https://github.com/python/cpython/pull/20556 so we make an
# O(1) wrapper around the list of hunks that makes it pretend to just be
# a list of line numbers
# https://gist.github.com/ericremoreynolds/2d80300dabc70eebc790
class KeyifyList(Generic[T, U]):
def __init__(self, inner: List[T], key: Callable[[T], U]) -> None:
self.inner = inner
self.key = key
def __len__(self) -> int:
return len(self.inner)
def __getitem__(self, k: int) -> U:
return self.key(self.inner[k])
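# Illustrative example (not part of the original module): bisect over hunks
# by their new_start field without materializing a separate key list.
#   >>> hunks = [{"new_start": 2}, {"new_start": 11}]
#   >>> keyified = KeyifyList(hunks, lambda h: h["new_start"])
#   >>> bisect_right(cast(Sequence[int], keyified), 5)
#   1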
def translate(diff: Diff, line_number: int) -> Optional[int]:
if line_number < 1:
return None
hunks = diff["hunks"]
if not hunks:
return line_number
keyified = KeyifyList(
hunks, lambda hunk: hunk["new_start"] + (0 if hunk["new_count"] > 0 else 1)
)
i = bisect_right(cast(Sequence[int], keyified), line_number)
if i < 1:
return line_number
hunk = hunks[i - 1]
d = line_number - (hunk["new_start"] + (hunk["new_count"] or 1))
return None if d < 0 else hunk["old_start"] + (hunk["old_count"] or 1) + d
# we use camelCase here because this will be output as JSON and so the
# field names need to match the group names from here:
# https://github.com/pytorch/add-annotations-github-action/blob/3ab7d7345209f5299d53303f7aaca7d3bc09e250/action.yml#L23
class Annotation(TypedDict):
filename: str
lineNumber: int
columnNumber: int
errorCode: str
errorDesc: str
def parse_annotation(regex: Pattern[str], line: str) -> Optional[Annotation]:
m = re.match(regex, line)
if m:
try:
line_number = int(m.group("lineNumber"))
column_number = int(m.group("columnNumber"))
except ValueError:
return None
return {
"filename": m.group("filename"),
"lineNumber": line_number,
"columnNumber": column_number,
"errorCode": m.group("errorCode"),
"errorDesc": m.group("errorDesc"),
}
else:
return None
def translate_all(
*, lines: List[str], regex: Pattern[str], commit: str
) -> List[Annotation]:
ann_dict: DefaultDict[str, List[Annotation]] = defaultdict(list)
for line in lines:
annotation = parse_annotation(regex, line)
if annotation is not None:
ann_dict[annotation["filename"]].append(annotation)
ann_list = []
for filename, annotations in ann_dict.items():
raw_diff = subprocess.check_output(
["git", "diff-index", "--unified=0", commit, filename],
encoding="utf-8",
)
diff = parse_diff(raw_diff) if raw_diff.strip() else None
# if there is a diff but it doesn't list an old filename, that
# means the file is absent in the commit we're targeting, so we
# skip it
if not diff or diff["old_filename"]:
for annotation in annotations:
line_number: Optional[int] = annotation["lineNumber"]
if diff:
annotation["filename"] = cast(str, diff["old_filename"])
line_number = translate(diff, cast(int, line_number))
if line_number:
annotation["lineNumber"] = line_number
ann_list.append(annotation)
return ann_list
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--file")
parser.add_argument("--regex")
parser.add_argument("--commit")
args = parser.parse_args()
with open(args.file, "r") as f:
lines = f.readlines()
regex = re.compile(args.regex)
print(json.dumps(translate_all(lines=lines, regex=regex, commit=args.commit)))
if __name__ == "__main__":
main()

View File

@@ -1,95 +0,0 @@
import unittest
from tools import extract_scripts
requirements_sh = """
#!/usr/bin/env bash
set -eo pipefail
pip install -r requirements.txt
""".strip()
hello_sh = """
#!/usr/bin/env sh
set -e
echo hello world
""".strip()
class TestExtractScripts(unittest.TestCase):
def test_extract_none(self) -> None:
self.assertEqual(
extract_scripts.extract(
{
"name": "Checkout PyTorch",
"uses": "zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9",
}
),
None,
)
def test_extract_run_default_bash(self) -> None:
self.assertEqual(
extract_scripts.extract(
{
"name": "Install requirements",
"run": "pip install -r requirements.txt",
}
),
{
"extension": ".sh",
"script": requirements_sh,
},
)
def test_extract_run_sh(self) -> None:
self.assertEqual(
extract_scripts.extract(
{
"name": "Hello world",
"run": "echo hello world",
"shell": "sh",
}
),
{
"extension": ".sh",
"script": hello_sh,
},
)
def test_extract_run_py(self) -> None:
self.assertEqual(
extract_scripts.extract(
{
"name": "Hello world",
"run": 'print("Hello!")',
"shell": "python",
}
),
{
"extension": ".py",
"script": 'print("Hello!")',
},
)
def test_extract_github_script(self) -> None:
self.assertEqual(
# https://github.com/actions/github-script/tree/v3.1.1#reading-step-results
extract_scripts.extract(
{
"uses": "actions/github-script@v3",
"id": "set-result",
"with": {
"script": 'return "Hello!"',
"result-encoding": "string",
},
}
),
{
"extension": ".js",
"script": 'return "Hello!"',
},
)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,132 +0,0 @@
import unittest
from tools.linter.clang_tidy.max_tokens_pragma import (
add_max_tokens_pragma,
strip_max_tokens_pragmas,
)
def compare_code(a: str, b: str) -> bool:
a_lines = [line.strip() for line in a.splitlines()]
b_lines = [line.strip() for line in b.splitlines()]
return a_lines == b_lines
class TestMaxTokensPragma(unittest.TestCase):
def test_no_prior_pragmas(self) -> None:
input = """\
// File without any prior pragmas
int main() {
for (int i = 0; i < 10; i++);
return 0;
}
"""
expected = """\
#pragma clang max_tokens_total 42
// File without any prior pragmas
int main() {
for (int i = 0; i < 10; i++);
return 0;
}
"""
output = add_max_tokens_pragma(input, 42)
self.assertTrue(compare_code(output, expected))
output = strip_max_tokens_pragmas(output)
self.assertTrue(compare_code(output, input))
def test_single_prior_pragma(self) -> None:
input = """\
// File with prior pragmas
#pragma clang max_tokens_total 1
int main() {
for (int i = 0; i < 10; i++);
return 0;
}
"""
expected = """\
// File with prior pragmas
#pragma clang max_tokens_total 42
int main() {
for (int i = 0; i < 10; i++);
return 0;
}
"""
stripped = """\
// File with prior pragmas
int main() {
for (int i = 0; i < 10; i++);
return 0;
}
"""
output = add_max_tokens_pragma(input, 42)
self.assertTrue(compare_code(output, expected))
output = strip_max_tokens_pragmas(output)
self.assertTrue(compare_code(output, stripped))
def test_multiple_prior_pragmas(self) -> None:
input = """\
// File with multiple prior pragmas
#pragma clang max_tokens_total 1
// Different pragma; script should ignore this
#pragma clang max_tokens_here 20
int main() {
for (int i = 0; i < 10; i++);
return 0;
}
#pragma clang max_tokens_total 1
"""
expected = """\
// File with multiple prior pragmas
#pragma clang max_tokens_total 42
// Different pragma; script should ignore this
#pragma clang max_tokens_here 20
int main() {
for (int i = 0; i < 10; i++);
return 0;
}
#pragma clang max_tokens_total 42
"""
stripped = """\
// File with multiple prior pragmas
// Different pragma; script should ignore this
#pragma clang max_tokens_here 20
int main() {
for (int i = 0; i < 10; i++);
return 0;
}
"""
output = add_max_tokens_pragma(input, 42)
self.assertTrue(compare_code(output, expected))
output = strip_max_tokens_pragmas(output)
self.assertTrue(compare_code(output, stripped))
if __name__ == "__main__":
unittest.main()

View File

@@ -1,173 +0,0 @@
import unittest
from tools.linter import mypy_wrapper
class TestMypyWrapper(unittest.TestCase):
configs = {
"foo.ini": {
"file1.abc",
"dir2",
"dir3/file4.xyz",
},
"bar/baz.ini": {
"file1.abc",
"dir2/dir5/file6.def",
"dir3/file7.abc",
},
}
trie: mypy_wrapper.Trie = {
"file1.abc": {None: {"foo.ini", "bar/baz.ini"}},
"dir2": {
None: {"foo.ini"},
"dir5": {"file6.def": {None: {"bar/baz.ini"}}},
},
"dir3": {
"file4.xyz": {None: {"foo.ini"}},
"file7.abc": {None: {"bar/baz.ini"}},
},
}
def test_config_files(self) -> None:
self.assertEqual(
mypy_wrapper.config_files().keys(),
{
"mypy.ini",
"mypy-strict.ini",
},
)
def test_split_path(self) -> None:
self.assertEqual(mypy_wrapper.split_path("file1.abc"), ["file1.abc"])
self.assertEqual(
mypy_wrapper.split_path("dir3/file4.xyz"),
["dir3", "file4.xyz"],
)
self.assertEqual(
mypy_wrapper.split_path("dir2/dir5/file6.def"),
["dir2", "dir5", "file6.def"],
)
def test_make_trie(self) -> None:
self.assertEqual(mypy_wrapper.make_trie(self.configs), self.trie)
def test_lookup(self) -> None:
self.assertEqual(
mypy_wrapper.lookup(self.trie, "file1.abc"),
{"foo.ini", "bar/baz.ini"},
)
self.assertEqual(
mypy_wrapper.lookup(self.trie, "dir2/dir5/file6.def"),
{"foo.ini", "bar/baz.ini"},
)
self.assertEqual(
mypy_wrapper.lookup(self.trie, "dir3/file4.xyz"),
{"foo.ini"},
)
self.assertEqual(
mypy_wrapper.lookup(self.trie, "dir3/file7.abc"),
{"bar/baz.ini"},
)
self.assertEqual(
mypy_wrapper.lookup(self.trie, "file8.xyz"),
set(),
)
self.assertEqual(
mypy_wrapper.lookup(self.trie, "dir2/dir9/file10.abc"),
{"foo.ini"},
)
self.assertEqual(
mypy_wrapper.lookup(self.trie, "dir3/file11.abc"),
set(),
)
# non-leaves shouldn't ever be passed to lookup in practice, but
# still, good to consider/test these cases
self.assertEqual(
mypy_wrapper.lookup(self.trie, "dir2"),
{"foo.ini"},
)
self.assertEqual(
mypy_wrapper.lookup(self.trie, "dir2/dir5"),
{"foo.ini"},
)
self.assertEqual(
mypy_wrapper.lookup(self.trie, "dir3"),
set(),
)
self.assertEqual(
mypy_wrapper.lookup(self.trie, "dir2/dir9"),
{"foo.ini"},
)
self.assertEqual(
mypy_wrapper.lookup(self.trie, "dir4"),
set(),
)
def test_make_plan(self) -> None:
self.assertEqual(
mypy_wrapper.make_plan(
configs=self.configs,
files=[
"file8.xyz",
"dir3/file11.abc",
],
),
{},
)
self.assertEqual(
mypy_wrapper.make_plan(
configs=self.configs,
files=[
"file8.xyz",
"dir2/dir9/file10.abc",
"dir3/file4.xyz",
"dir3/file11.abc",
],
),
{
"foo.ini": ["dir2/dir9/file10.abc", "dir3/file4.xyz"],
},
)
self.assertEqual(
mypy_wrapper.make_plan(
configs=self.configs,
files=[
"file8.xyz",
"dir3/file11.abc",
"dir3/file7.abc",
],
),
{
"bar/baz.ini": ["dir3/file7.abc"],
},
)
self.assertEqual(
mypy_wrapper.make_plan(
configs=self.configs,
files=[
"dir2/dir9/file10.abc",
"dir2/dir5/file6.def",
"dir3/file7.abc",
"file1.abc",
"dir3/file11.abc",
],
),
{
"foo.ini": [
"dir2/dir9/file10.abc",
"dir2/dir5/file6.def",
"file1.abc",
],
"bar/baz.ini": [
"dir2/dir5/file6.def",
"dir3/file7.abc",
"file1.abc",
],
},
)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,49 +0,0 @@
from tools.linter import trailing_newlines
import unittest
import tempfile
def correct_trailing_newlines(file_contents: str) -> bool:
with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp:
filename = tmp.name
tmp.write(file_contents)
return trailing_newlines.correct_trailing_newlines(filename)
class TestTrailingNewlines(unittest.TestCase):
def test_empty(self) -> None:
self.assertTrue(correct_trailing_newlines(""))
def test_single_byte(self) -> None:
self.assertFalse(correct_trailing_newlines("a"))
def test_single_newline(self) -> None:
self.assertFalse(correct_trailing_newlines("\n"))
def test_two_newlines(self) -> None:
self.assertFalse(correct_trailing_newlines("\n\n"))
def test_three_newlines(self) -> None:
self.assertFalse(correct_trailing_newlines("\n\n\n"))
def test_hello_world(self) -> None:
self.assertFalse(correct_trailing_newlines("hello world"))
def test_hello_world_newline(self) -> None:
self.assertTrue(correct_trailing_newlines("hello world\n"))
def test_hello_world_two_newlines(self) -> None:
self.assertFalse(correct_trailing_newlines("hello world\n\n"))
def test_hello_world_three_newlines(self) -> None:
self.assertFalse(correct_trailing_newlines("hello world\n\n\n"))
def test_hello_world_multiline(self) -> None:
self.assertFalse(correct_trailing_newlines("hello\nworld"))
def test_hello_world_multiline_gap(self) -> None:
self.assertTrue(correct_trailing_newlines("hello\n\nworld\n"))
if __name__ == "__main__":
unittest.main()

View File

@@ -1,278 +0,0 @@
import re
import unittest
from tools.linter.translate_annotations import parse_annotation, parse_diff, translate
flake8_regex = r"^(?P<filename>.*?):(?P<lineNumber>\d+):(?P<columnNumber>\d+): (?P<errorCode>\w+\d+) (?P<errorDesc>.*)"
clang_tidy_regex = r"^(?P<filename>.*?):(?P<lineNumber>\d+):(?P<columnNumber>\d+): (?P<errorDesc>.*?) \[(?P<errorCode>.*)\]"
# in the below example patch, note that the filenames differ, so the
# translation should reflect that as well as the line numbers
# $ git clone -b 1.0.2 https://github.com/cscorley/whatthepatch.git
# $ cd whatthepatch/tests/casefiles
# $ git diff --no-index --unified=0 lao tzu
lao_tzu_diff = """
diff --git a/lao b/tzu
index 635ef2c..5af88a8 100644
--- a/lao
+++ b/tzu
@@ -1,2 +0,0 @@
-The Way that can be told of is not the eternal Way;
-The name that can be named is not the eternal name.
@@ -4 +2,2 @@ The Nameless is the origin of Heaven and Earth;
-The Named is the mother of all things.
+The named is the mother of all things.
+
@@ -11,0 +11,3 @@ But after they are produced,
+They both may be called deep and profound.
+Deeper and more profound,
+The door of all subtleties!
""".lstrip()
sparser_diff = """
diff --git a/foo.txt b/bar.txt
index 27a6dad..6fae323 100644
--- a/foo.txt
+++ b/bar.txt
@@ -4,3 +4,2 @@ lines
-lines
-lines
-lines
+A change!!
+Wow
@@ -10,2 +8,0 @@ more lines
-even more
-even more
""".lstrip()
new_file_diff = """
diff --git a/torch/csrc/jit/tensorexpr/operators/conv2d.h b/torch/csrc/jit/tensorexpr/operators/conv2d.h
new file mode 100644
index 0000000000..a81eeae346
--- /dev/null
+++ b/torch/csrc/jit/tensorexpr/operators/conv2d.h
@@ -0,0 +1,19 @@
+#pragma once
+
+#include <torch/csrc/jit/tensorexpr/tensor.h>
+
+namespace torch {
+namespace jit {
+namespace tensorexpr {
+
+TORCH_API Tensor* conv2d_depthwise(
+ BufHandle input,
+ BufHandle weight,
+ BufHandle bias,
+ int stride,
+ int pad,
+ int groups);
+
+} // namespace tensorexpr
+} // namespace jit
+} // namespace torch
""".lstrip()
# fun fact, this example fools VS Code's diff syntax highlighter
haskell_diff = """
diff --git a/hello.hs b/hello.hs
index ffb8d4ad14..0872ac9db6 100644
--- a/hello.hs
+++ b/hello.hs
@@ -1 +1 @@
--- a/hello/world/example
+main = putStrLn "Hello, world!"
""".lstrip()
class TestTranslateAnnotations(unittest.TestCase):
maxDiff = None
def test_parse_diff_lao_tzu(self) -> None:
self.assertEqual(
parse_diff(lao_tzu_diff),
{
"old_filename": "lao",
"hunks": [
{
"old_start": 1,
"old_count": 2,
"new_start": 0,
"new_count": 0,
},
{
"old_start": 4,
"old_count": 1,
"new_start": 2,
"new_count": 2,
},
{
"old_start": 11,
"old_count": 0,
"new_start": 11,
"new_count": 3,
},
],
},
)
def test_parse_diff_new_file(self) -> None:
self.assertEqual(
parse_diff(new_file_diff),
{
"old_filename": None,
"hunks": [
{
"old_start": 0,
"old_count": 0,
"new_start": 1,
"new_count": 19,
},
],
},
)
def test_parse_diff_haskell(self) -> None:
self.assertEqual(
parse_diff(haskell_diff),
{
"old_filename": "hello.hs",
"hunks": [
{
"old_start": 1,
"old_count": 1,
"new_start": 1,
"new_count": 1,
},
],
},
)
def test_translate_lao_tzu(self) -> None:
# we'll pretend that this diff represents the file lao being
# renamed to tzu and also modified
diff = parse_diff(lao_tzu_diff)
# line numbers less than 1 are invalid so they map to None
self.assertEqual(translate(diff, -1), None)
self.assertEqual(translate(diff, 0), None)
# the first two lines of the file were removed, so the first
# line of the new version corresponds to the third line of the
# original
self.assertEqual(translate(diff, 1), 3)
# the second and third lines of the new file were not present in
# the original version, so they map to None
self.assertEqual(translate(diff, 2), None)
self.assertEqual(translate(diff, 3), None)
# at this point, we have a stretch of lines that are identical
# in both versions of the file, but the original version of the
# file had 4 lines before this section whereas the new version
# has only 3 lines before this section
self.assertEqual(translate(diff, 4), 5)
self.assertEqual(translate(diff, 5), 6)
self.assertEqual(translate(diff, 6), 7)
self.assertEqual(translate(diff, 7), 8)
self.assertEqual(translate(diff, 8), 9)
self.assertEqual(translate(diff, 9), 10)
self.assertEqual(translate(diff, 10), 11)
# these three lines were added in the new version of the file,
# so they map to None
self.assertEqual(translate(diff, 11), None)
self.assertEqual(translate(diff, 12), None)
self.assertEqual(translate(diff, 13), None)
# the diff doesn't say how long the file is, so we keep mapping
# line numbers back; since we can look back at the original
# files, though, we can see that the original is two lines
# shorter than the new version, which explains why we are
# subtracting 2 here
self.assertEqual(translate(diff, 14), 12)
self.assertEqual(translate(diff, 15), 13)
def test_translate_empty(self) -> None:
diff = parse_diff("--- a/foo")
# again, we start numbering at 1
self.assertEqual(translate(diff, -1), None)
self.assertEqual(translate(diff, 0), None)
# this diff says there are no changes, so all line numbers
# greater than zero map to themselves
self.assertEqual(translate(diff, 1), 1)
self.assertEqual(translate(diff, 2), 2)
self.assertEqual(translate(diff, 3), 3)
self.assertEqual(translate(diff, 4), 4)
self.assertEqual(translate(diff, 5), 5)
def test_translate_sparser(self) -> None:
diff = parse_diff(sparser_diff)
# again, we start numbering at 1
self.assertEqual(translate(diff, -1), None)
self.assertEqual(translate(diff, 0), None)
# the first three lines are unchanged
self.assertEqual(translate(diff, 1), 1)
self.assertEqual(translate(diff, 2), 2)
self.assertEqual(translate(diff, 3), 3)
# we removed three lines here and added two, so the two lines we
# added don't map back to anything in the original file
self.assertEqual(translate(diff, 4), None)
self.assertEqual(translate(diff, 5), None)
# we have some unchanged lines here, but in the preceding hunk
# we removed 3 and added only 2, so we have an offset of 1
self.assertEqual(translate(diff, 6), 7)
self.assertEqual(translate(diff, 7), 8)
# since the unified diff format essentially subtracts 1 from the
# starting line number when the count is 0, and since we use
# bisect.bisect_right to decide which hunk to look at, an
# earlier version of translate had a bug that caused it to get
# confused because it would look at the second hunk (which lists
# 8 as its start line number) rather than the first hunk
self.assertEqual(translate(diff, 8), 9)
# after the two lines that we removed in the second hunk, we've
# reduced the total length of the file by 3 lines, so once we
# reach the end of the diff, we just add 3 to every line number
self.assertEqual(translate(diff, 9), 12)
self.assertEqual(translate(diff, 10), 13)
self.assertEqual(translate(diff, 11), 14)
self.assertEqual(translate(diff, 12), 15)
def test_parse_annotation_flake8(self) -> None:
regex = re.compile(flake8_regex)
self.assertEqual(
parse_annotation(regex, "README.md:1:3: R100 make a better title"),
{
"filename": "README.md",
"lineNumber": 1,
"columnNumber": 3,
"errorCode": "R100",
"errorDesc": "make a better title",
},
)
def test_parse_annotation_clang_tidy(self) -> None:
regex = re.compile(clang_tidy_regex)
self.assertEqual(
parse_annotation(regex, "README.md:2:1: improve description [R200]"),
{
"filename": "README.md",
"lineNumber": 2,
"columnNumber": 1,
"errorCode": "R200",
"errorDesc": "improve description",
},
)
if __name__ == "__main__":
unittest.main()