mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This PR adds enforcement of testing header only APIs. The benefit of torch/header_only_apis.txt is twofold: 1) this gives us a clear view of what we expect to be header only; 2) this allows us to enforce testing. The enforcement added in this PR is very basic--we literally string match that a symbol in `torch/header_only_apis.txt` is in a cpp test. This is meant to be a first step in verifying our APIs are properly tested and can get fancier over time. For now, I've added myself as a codeowner to learn what to look out for in terms of proper tests. Over time, I anticipate we can automate more steps, but right now let's just get something out the door. Pull Request resolved: https://github.com/pytorch/pytorch/pull/153635 Approved by: https://github.com/albanD ghstack dependencies: #153965
141 lines
4.7 KiB
Python
141 lines
4.7 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Checks that all symbols in torch/header_only_apis.txt are tested in a .cpp
|
|
test file to ensure header-only-ness. The .cpp test file must be built
|
|
without linking libtorch.
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import re
|
|
from enum import Enum
|
|
from pathlib import Path
|
|
from typing import NamedTuple, Union
|
|
|
|
|
|
# Identifier placed in the `code` field of the LintMessage objects built by
# check_file.
LINTER_CODE = "HEADER_ONLY_LINTER"
|
|
|
|
|
|
class LintSeverity(str, Enum):
    """
    Severity levels that a LintMessage can carry.

    The ``str`` mixin makes every member a real string equal to its value,
    so a member can be passed to ``json.dumps`` without a custom encoder.
    """

    ERROR = "error"
    WARNING = "warning"
    ADVICE = "advice"
    DISABLED = "disabled"
|
|
|
|
|
|
class LintMessage(NamedTuple):
    """
    A single lint finding.

    The script entry point converts each instance with ``_asdict()`` and
    prints it as one JSON object, so the field names double as JSON keys.
    """

    # File the finding refers to.
    path: Union[str, None]
    # Line number within `path` (check_file reports 1-based numbers).
    line: Union[int, None]
    # Column of the finding; check_file always passes None.
    char: Union[int, None]
    # Identifier of the linter that produced the finding.
    code: str
    # How serious the finding is.
    severity: LintSeverity
    # Short label for the kind of finding, e.g. "[untested-symbol]".
    name: str
    # check_file always passes None for the next two fields; presumably they
    # hold an original/suggested text pair for fixable findings -- TODO confirm.
    original: Union[str, None]
    replacement: Union[str, None]
    # Human-readable explanation of the finding.
    description: Union[str, None]
|
|
|
|
|
|
# Glob patterns, evaluated relative to REPO_ROOT, selecting the .cpp test files
# that are searched for uses of the header-only symbols.
CPP_TEST_GLOBS = [
    "test/cpp/aoti_abi_check/*.cpp",
]

# The directory three levels above the one containing this file. It is used as
# the base for the globs above and for locating torch/header_only_apis.txt.
REPO_ROOT = Path(__file__).parents[3]
|
|
|
|
|
|
def find_matched_symbols(
    symbols_regex: re.Pattern[str], test_globs: list[str] = CPP_TEST_GLOBS
) -> set[str]:
    """
    Collects every symbol that symbols_regex matches in the cpp test files.

    Each file matching one of test_globs (relative to REPO_ROOT) is read
    line by line. Blank lines and lines whose first non-whitespace
    characters are // or # are skipped; every other line is searched with
    symbols_regex and all non-empty matches are accumulated.

    Note that we expect the alternatives of symbols_regex to be sorted in
    reverse alphabetical order to allow superset regexes to get matched
    (a longer symbol must be tried before a symbol that is its prefix).

    Args:
        symbols_regex: pattern whose matches are the symbols to look for.
        test_globs: glob patterns, relative to REPO_ROOT, of files to scan.

    Returns:
        The set of distinct non-empty strings matched in the scanned files.
    """
    # re.compile returns an already-compiled pattern unchanged, so this only
    # binds the search method once instead of re-dispatching on every line.
    find_all = re.compile(symbols_regex).findall

    matched_symbols: set[str] = set()
    for cpp_test_glob in test_globs:
        for test_file in REPO_ROOT.glob(cpp_test_glob):
            # Decode explicitly as UTF-8 so the result does not depend on the
            # platform's default locale encoding.
            with open(test_file, encoding="utf-8") as tf:
                for raw_line in tf:
                    stripped = raw_line.strip()
                    # check noncommented out lines of the test files
                    if not stripped or stripped.startswith(("//", "#")):
                        continue
                    # A pattern that can match the empty string yields ""
                    # entries; those are not symbols and are dropped.
                    matched_symbols.update(m for m in find_all(stripped) if m != "")
    return matched_symbols
|
|
|
|
|
|
def check_file(
    filename: str, test_globs: list[str] = CPP_TEST_GLOBS
) -> list[LintMessage]:
    """
    Goes through the header_only_apis.txt file and verifies that all symbols
    within the file can be found tested in an appropriately independent .cpp
    file.

    Every non-blank line of `filename` that does not start with # is treated
    as one symbol. The symbols are searched for, as literal text, in the
    files selected by test_globs; one error LintMessage is produced for each
    symbol that is never found.

    Note that we expect CPP_TEST_GLOBS to be passed in as test_globs--the
    only reason this is an argument at all is for ease of testing.

    Args:
        filename: path of the symbol list to check.
        test_globs: glob patterns of the .cpp test files to search.

    Returns:
        One LintMessage per untested symbol, pointing at the line of
        `filename` where that symbol (last) appears; empty if all are tested.
    """
    symbols: dict[str, int] = {}  # symbol -> lineno
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(filename, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            symbol = line.strip()
            # blank and commented out lines should be skipped
            if not symbol or symbol.startswith("#"):
                continue

            # symbols can in fact be duplicated and come from different headers.
            # we are aware this is a flaw in using simple string matching.
            symbols[symbol] = lineno

    # Why reverse the keys? To allow superset regexes to get matched first in
    # find_matched_symbols. For example, we want Float8_e5m2fnuz to match
    # before Float8_e5m2. Otherwise, both Float8_e5m2fnuz and Float8_e5m2 will
    # match Float8_e5m2
    #
    # Each symbol is escaped so that it is matched literally: without this a
    # symbol containing regex metacharacters (e.g. "operator+") would either
    # fail to compile or match text that is not the symbol itself.
    symbols_regex = re.compile(
        "|".join(re.escape(s) for s in sorted(symbols, reverse=True))
    )
    matched_symbols = find_matched_symbols(symbols_regex, test_globs)

    # NOTE(review): the description always prints CPP_TEST_GLOBS, even when a
    # different test_globs was passed in; also confirm that the path quoted at
    # the end of the message matches this linter's actual location.
    return [
        LintMessage(
            path=filename,
            line=lineno,
            char=None,
            code=LINTER_CODE,
            severity=LintSeverity.ERROR,
            name="[untested-symbol]",
            original=None,
            replacement=None,
            description=(
                f"{s} has been included as a header-only API "
                "but is not tested in any of CPP_TEST_GLOBS, which "
                f"contains {CPP_TEST_GLOBS}.\n"
                "Please add a .cpp test using the symbol without "
                "linking anything to verify that the symbol is in "
                "fact header-only. If you already have a test but it's"
                " not found, please add the .cpp file to CPP_TEST_GLOBS"
                " in tools/linters/adapters/header_only_linter.py."
            ),
        )
        for s, lineno in symbols.items()
        if s not in matched_symbols
    ]
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: lint torch/header_only_apis.txt under REPO_ROOT and
    # print each finding as one JSON object per line on stdout.
    parser = argparse.ArgumentParser(
        description="header only APIs linter",
        fromfile_prefix_chars="@",
    )
    # No options are declared; parsing still provides --help and rejects
    # unexpected command-line arguments.
    args = parser.parse_args()

    header_only_apis_path = str(REPO_ROOT) + "/torch/header_only_apis.txt"
    for lint_message in check_file(header_only_apis_path, CPP_TEST_GLOBS):
        print(json.dumps(lint_message._asdict()), flush=True)
|