Files
pytorch/tools/linter/adapters/header_only_linter.py
Jane Xu fc33da410f Add torch/header_only_apis.txt and enforce they're tested (#153635)
This PR adds enforcement of testing header only APIs.

The benefit of torch/header_only_apis.txt is twofold:
1) this gives us a clear view of what we expect to be header only
2) this allows us to enforce testing

The enforcement added in this PR is very basic--we literally string match that a symbol in `torch/header_only_apis.txt` is in a cpp test. This is meant to be a first step in verifying our APIs are properly tested and can get fancier over time. For now, I've added myself as a codeowner to learn what to look out for in terms of proper tests. Over time, I anticipate we can automate more steps, but right now let's just get something out the door.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153635
Approved by: https://github.com/albanD
ghstack dependencies: #153965
2025-05-20 23:42:24 +00:00

141 lines
4.7 KiB
Python

#!/usr/bin/env python3
"""
Checks that all symbols in torch/header_only_apis.txt are tested in a .cpp
test file to ensure header-only-ness. The .cpp test file must be built
without linking libtorch.
"""
import argparse
import json
import re
from enum import Enum
from pathlib import Path
from typing import NamedTuple, Union
# Value placed in the `code` field of every LintMessage produced by this file.
LINTER_CODE = "HEADER_ONLY_LINTER"
class LintSeverity(str, Enum):
    """
    Severity levels that can be attached to a LintMessage.

    The str mix-in makes every member a real string, so json.dumps
    serializes a member as its value (e.g. "error"). This file only ever
    emits ERROR.
    """

    ERROR = "error"
    WARNING = "warning"
    ADVICE = "advice"
    DISABLED = "disabled"
class LintMessage(NamedTuple):
    """
    A single lint finding.

    The script entry point converts each message with _asdict() and prints
    it as one JSON object per line, so the field names below are the keys
    of the emitted JSON.
    """

    # File the finding refers to; check_file passes the symbols file name.
    path: Union[str, None]
    # 1-based line number of the finding within `path`.
    line: Union[int, None]
    # Always None in this linter.
    # NOTE(review): presumably a column offset -- confirm against the consumer.
    char: Union[int, None]
    # Identifier of the linter; this file always uses LINTER_CODE.
    code: str
    severity: LintSeverity
    # Short tag for the kind of finding, e.g. "[untested-symbol]".
    name: str
    # Both always None in this linter.
    # NOTE(review): presumably used by linters that suggest a fix -- confirm.
    original: Union[str, None]
    replacement: Union[str, None]
    # Human-readable explanation of the finding.
    description: Union[str, None]
# Glob patterns, relative to REPO_ROOT, of the C++ test files that are
# searched for uses of the header-only symbols.
CPP_TEST_GLOBS = [
    "test/cpp/aoti_abi_check/*.cpp",
]
# The directory three levels above this file's own directory. With this file
# located in tools/linter/adapters/, that is the repository root.
REPO_ROOT = Path(__file__).parents[3]
def find_matched_symbols(
    symbols_regex: re.Pattern[str], test_globs: list[str] = CPP_TEST_GLOBS
) -> set[str]:
    """
    Collects every symbol that symbols_regex matches in the cpp test files.

    Each glob in test_globs is expanded relative to REPO_ROOT. Within every
    resulting file, blank lines and lines whose first non-whitespace
    characters are // (line comments) or # (preprocessor directives such as
    #include) are skipped; every other line is searched.

    Note that we expect the alternatives in symbols_regex to be sorted in
    reverse alphabetical order, so that a longer symbol is tried before any
    symbol that is a prefix of it.

    Args:
        symbols_regex: compiled pattern whose matches are the symbols.
        test_globs: glob patterns, relative to REPO_ROOT, of files to scan.

    Returns:
        The set of distinct, non-empty strings matched by symbols_regex.
    """
    matched_symbols: set[str] = set()
    for cpp_test_glob in test_globs:
        for test_file in REPO_ROOT.glob(cpp_test_glob):
            # Decode as UTF-8 explicitly rather than relying on the
            # platform's default encoding.
            with open(test_file, encoding="utf-8") as tf:
                for raw_line in tf:
                    line = raw_line.strip()
                    if not line or line.startswith(("//", "#")):
                        continue
                    # group(0) is always the full matched text, even if the
                    # pattern contains capture groups (re.findall would
                    # return the group contents instead). Empty matches,
                    # e.g. from an empty pattern, are dropped.
                    matched_symbols.update(
                        match.group(0)
                        for match in symbols_regex.finditer(line)
                        if match.group(0)
                    )
    return matched_symbols
def check_file(
    filename: str, test_globs: list[str] = CPP_TEST_GLOBS
) -> list[LintMessage]:
    """
    Goes through the header_only_apis.txt file and verifies that all symbols
    within the file can be found tested in an appropriately independent .cpp
    file.

    Every non-blank line of the file that does not start with # is treated
    as one literal symbol. A symbol counts as tested when its exact text is
    matched on a searched line of a file selected by test_globs (see
    find_matched_symbols).

    Note that we expect CPP_TEST_GLOBS to be passed in as test_globs--the
    only reason this is an argument at all is for ease of testing.

    Args:
        filename: path of the symbols file to check.
        test_globs: glob patterns, relative to REPO_ROOT, of the cpp tests.

    Returns:
        One ERROR LintMessage per symbol that was not found, pointing at the
        line of the symbols file that lists it; empty if all were found.
    """
    symbols: dict[str, int] = {}  # symbol -> lineno
    with open(filename, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            symbol = line.strip()
            # blank and commented out lines should be skipped
            if not symbol or symbol.startswith("#"):
                continue
            # symbols can in fact be duplicated and come from different
            # headers; the last occurrence's line number is kept. we are
            # aware this is a flaw in using simple string matching.
            symbols[symbol] = lineno

    # Nothing to verify, so there is no need to scan the test files.
    if not symbols:
        return []

    # Why reverse the keys? To allow superset symbols to get matched first in
    # find_matched_symbols. For example, we want Float8_e5m2fnuz to match
    # before Float8_e5m2. Otherwise, both Float8_e5m2fnuz and Float8_e5m2 will
    # match Float8_e5m2.
    #
    # Why re.escape? Symbols are literal text, not patterns. Unescaped, a
    # symbol containing regex metacharacters (e.g. operator+ or operator[])
    # would either fail to compile or match text different from itself, and
    # so could never be found in matched_symbols below.
    symbols_regex = re.compile(
        "|".join(re.escape(s) for s in sorted(symbols, reverse=True))
    )
    matched_symbols = find_matched_symbols(symbols_regex, test_globs)

    lint_messages: list[LintMessage] = []
    for s, lineno in symbols.items():
        if s in matched_symbols:
            continue
        lint_messages.append(
            LintMessage(
                path=filename,
                line=lineno,
                char=None,
                code=LINTER_CODE,
                severity=LintSeverity.ERROR,
                name="[untested-symbol]",
                original=None,
                replacement=None,
                description=(
                    f"{s} has been included as a header-only API "
                    "but is not tested in any of CPP_TEST_GLOBS, which "
                    f"contains {CPP_TEST_GLOBS}.\n"
                    "Please add a .cpp test using the symbol without "
                    "linking anything to verify that the symbol is in "
                    "fact header-only. If you already have a test but it's"
                    " not found, please add the .cpp file to CPP_TEST_GLOBS"
                    " in tools/linter/adapters/header_only_linter.py."
                ),
            )
        )
    return lint_messages
if __name__ == "__main__":
    # The linter defines no options of its own. Arguments are still parsed so
    # that --help works, unknown arguments are rejected, and "@file" argument
    # files are accepted.
    arg_parser = argparse.ArgumentParser(
        fromfile_prefix_chars="@",
        description="header only APIs linter",
    )
    args = arg_parser.parse_args()

    header_only_apis_path = str(REPO_ROOT) + "/torch/header_only_apis.txt"
    messages = check_file(header_only_apis_path, CPP_TEST_GLOBS)

    # Emit one JSON object per line, flushing after each message.
    for message in messages:
        serialized = json.dumps(message._asdict())
        print(serialized, flush=True)