Add torch/header_only_apis.txt and enforce they're tested (#153635)

This PR adds enforcement of testing header only APIs.

The benefit of torch/header_only_apis.txt is twofold:
1) this gives us a clear view of what we expect to be header only
2) this allows us to enforce testing

The enforcement added in this PR is very basic--we literally string match that a symbol in `torch/header_only_apis.txt` is in a cpp test. This is meant to be a first step in verifying our APIs are properly tested and can get fancier over time. For now, I've added myself as a codeowner to learn what to look out for in terms of proper tests. Over time, I anticipate we can automate more steps, but right now let's just get something out the door.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153635
Approved by: https://github.com/albanD
ghstack dependencies: #153965
This commit is contained in:
Jane Xu
2025-05-20 13:39:32 -07:00
committed by PyTorch MergeBot
parent 41a9aa6564
commit fc33da410f
10 changed files with 335 additions and 0 deletions

View File

@ -1736,3 +1736,15 @@ command = [
include_patterns = [ include_patterns = [
'test/**/test_*.py', 'test/**/test_*.py',
] ]
# 'header_only_linter' reports on properly testing header-only APIs.
[[linter]]
code = 'HEADER_ONLY_LINTER'
command = [
'python3',
'tools/linter/adapters/header_only_linter.py',
]
include_patterns = [
'torch/header_only_apis.txt',
]
is_formatter = false

View File

@ -14,6 +14,7 @@
/torch/csrc/autograd/ @albanD @soulitzer /torch/csrc/autograd/ @albanD @soulitzer
/torch/autograd/ @albanD @soulitzer /torch/autograd/ @albanD @soulitzer
/tools/autograd/ @albanD @soulitzer /tools/autograd/ @albanD @soulitzer
/torch/header_only_apis.txt @janeyx99
/torch/nn/ @albanD @jbschlosser @mikaylagawarecki /torch/nn/ @albanD @jbschlosser @mikaylagawarecki
/torch/optim/ @albanD @janeyx99 /torch/optim/ @albanD @janeyx99
/test/test_public_bindings.py @albanD /test/test_public_bindings.py @albanD

View File

@ -0,0 +1,140 @@
#!/usr/bin/env python3
"""
Checks that all symbols in torch/header_only_apis.txt are tested in a .cpp
test file to ensure header-only-ness. The .cpp test file must be built
without linking libtorch.
"""
import argparse
import json
import re
from enum import Enum
from pathlib import Path
from typing import NamedTuple, Union
LINTER_CODE = "HEADER_ONLY_LINTER"
class LintSeverity(str, Enum):
ERROR = "error"
WARNING = "warning"
ADVICE = "advice"
DISABLED = "disabled"
class LintMessage(NamedTuple):
path: Union[str, None]
line: Union[int, None]
char: Union[int, None]
code: str
severity: LintSeverity
name: str
original: Union[str, None]
replacement: Union[str, None]
description: Union[str, None]
CPP_TEST_GLOBS = [
"test/cpp/aoti_abi_check/*.cpp",
]
REPO_ROOT = Path(__file__).parents[3]
def find_matched_symbols(
symbols_regex: re.Pattern[str], test_globs: list[str] = CPP_TEST_GLOBS
) -> set[str]:
"""
Goes through all lines not starting with // in the cpp files and
accumulates a list of matches with the symbols_regex. Note that
we expect symbols_regex to be sorted in reverse alphabetical
order to allow superset regexes to get matched.
"""
matched_symbols = set()
# check noncommented out lines of the test files
for cpp_test_glob in test_globs:
for test_file in REPO_ROOT.glob(cpp_test_glob):
with open(test_file) as tf:
for test_file_line in tf:
test_file_line = test_file_line.strip()
if test_file_line.startswith(("//", "#")) or test_file_line == "":
continue
matches = re.findall(symbols_regex, test_file_line)
for m in matches:
if m != "":
matched_symbols.add(m)
return matched_symbols
def check_file(
filename: str, test_globs: list[str] = CPP_TEST_GLOBS
) -> list[LintMessage]:
"""
Goes through the header_only_apis.txt file and verifies that all symbols
within the file can be found tested in an appropriately independent .cpp
file.
Note that we expect CPP_TEST_GLOBS to be passed in as test_globs--the
only reason this is an argument at all is for ease of testing.
"""
lint_messages: list[LintMessage] = []
symbols: dict[str, int] = {} # symbol -> lineno
with open(filename) as f:
for idx, line in enumerate(f):
# commented out lines should be skipped
symbol = line.strip()
if not symbol or symbol[0] == "#":
continue
# symbols can in fact be duplicated and come from different headers.
# we are aware this is a flaw in using simple string matching.
symbols[symbol] = idx + 1
# Why reverse the keys? To allow superset regexes to get matched first in
# find_matched_symbols. For example, we want Float8_e5m2fnuz to match
# before Float8_e5m2. Otherwise, both Float8_e5m2fnuz and Float8_e5m2 will
# match Float8_e5m2
symbols_regex = re.compile("|".join(sorted(symbols.keys(), reverse=True)))
matched_symbols = find_matched_symbols(symbols_regex, test_globs)
for s, lineno in symbols.items():
if s not in matched_symbols:
lint_messages.append(
LintMessage(
path=filename,
line=lineno,
char=None,
code=LINTER_CODE,
severity=LintSeverity.ERROR,
name="[untested-symbol]",
original=None,
replacement=None,
description=(
f"{s} has been included as a header-only API "
"but is not tested in any of CPP_TEST_GLOBS, which "
f"contains {CPP_TEST_GLOBS}.\n"
"Please add a .cpp test using the symbol without "
"linking anything to verify that the symbol is in "
"fact header-only. If you already have a test but it's"
" not found, please add the .cpp file to CPP_TEST_GLOBS"
" in tools/linters/adapters/header_only_linter.py."
),
)
)
return lint_messages
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="header only APIs linter",
fromfile_prefix_chars="@",
)
args = parser.parse_args()
for lint_message in check_file(
str(REPO_ROOT) + "/torch/header_only_apis.txt", CPP_TEST_GLOBS
):
print(json.dumps(lint_message._asdict()), flush=True)

View File

@ -0,0 +1,7 @@
// bbb
#include <a.h>
int main() {
auto var = symC(2, 3);
return symDef() + var;
}

View File

@ -0,0 +1,9 @@
#include <a.h>
#include <bbb.h>
int main() {
auto var = a();
// bbb
return symDef() + var;
}

View File

@ -0,0 +1,9 @@
# a.h
a
symC
symD
# bbb.h
bbb
symD
symDef

View File

@ -0,0 +1,10 @@
# a.h
a
# indented comment should do nothing
symC
# bbb.h
symDef

View File

@ -0,0 +1,98 @@
import re
import unittest
from tools.linter.adapters.header_only_linter import (
check_file,
CPP_TEST_GLOBS,
find_matched_symbols,
LINTER_CODE,
LintMessage,
LintSeverity,
REPO_ROOT,
)
class TestHeaderOnlyLinter(unittest.TestCase):
"""
Test the header only linter functionality
"""
def test_find_matched_symbols(self) -> None:
sample_regex = re.compile("symDef|symD|symC|bbb|a")
test_globs = ["tools/test/header_only_linter_testdata/*.cpp"]
expected_matches = {"symDef", "symC", "a"}
self.assertEqual(
find_matched_symbols(sample_regex, test_globs), expected_matches
)
def test_find_matched_symbols_empty_regex(self) -> None:
sample_regex = re.compile("")
test_globs = ["tools/test/header_only_linter_testdata/*.cpp"]
expected_matches: set[str] = set()
self.assertEqual(
find_matched_symbols(sample_regex, test_globs), expected_matches
)
def test_check_file_no_issues(self) -> None:
sample_txt = str(REPO_ROOT / "tools/test/header_only_linter_testdata/good.txt")
test_globs = ["tools/test/header_only_linter_testdata/*.cpp"]
self.assertEqual(len(check_file(sample_txt, test_globs)), 0)
def test_check_empty_file(self) -> None:
sample_txt = str(REPO_ROOT / "tools/test/header_only_linter_testdata/empty.txt")
test_globs = ["tools/test/header_only_linter_testdata/*.cpp"]
self.assertEqual(len(check_file(sample_txt, test_globs)), 0)
def test_check_file_with_untested_symbols(self) -> None:
sample_txt = str(REPO_ROOT / "tools/test/header_only_linter_testdata/bad.txt")
test_globs = ["tools/test/header_only_linter_testdata/*.cpp"]
expected_msgs = [
LintMessage(
path=sample_txt,
line=7,
char=None,
code=LINTER_CODE,
severity=LintSeverity.ERROR,
name="[untested-symbol]",
original=None,
replacement=None,
description=(
f"bbb has been included as a header-only API "
"but is not tested in any of CPP_TEST_GLOBS, which "
f"contains {CPP_TEST_GLOBS}.\n"
"Please add a .cpp test using the symbol without "
"linking anything to verify that the symbol is in "
"fact header-only. If you already have a test but it's"
" not found, please add the .cpp file to CPP_TEST_GLOBS"
" in tools/linters/adapters/header_only_linter.py."
),
),
LintMessage(
path=sample_txt,
line=8,
char=None,
code=LINTER_CODE,
severity=LintSeverity.ERROR,
name="[untested-symbol]",
original=None,
replacement=None,
description=(
f"symD has been included as a header-only API "
"but is not tested in any of CPP_TEST_GLOBS, which "
f"contains {CPP_TEST_GLOBS}.\n"
"Please add a .cpp test using the symbol without "
"linking anything to verify that the symbol is in "
"fact header-only. If you already have a test but it's"
" not found, please add the .cpp file to CPP_TEST_GLOBS"
" in tools/linters/adapters/header_only_linter.py."
),
),
]
self.assertEqual(set(check_file(sample_txt, test_globs)), set(expected_msgs))
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,49 @@
# This file contains all the header-only C++ APIs/symbols in torch.
# If a symbol is added in this file, it should be tested in a .cpp file
# to guarantee that compiling these symbols do not require linking libtorch
# to ensure header-only-ness.
# c10/util/TypeCast.h
convert
# c10/util/bit_cast.h
bit_cast
# c10/util/BFloat16-math.h, c10/util/BFloat16.h
BFloat16
# c10/util/Float8_e4m3fn.h
Float8_e4m3fn
# c10/util/Float8_e4m3fnuz.h
Float8_e4m3fnuz
# c10/util/Float8_e5m2.h
Float8_e5m2
# c10/util/Float8_e5m2fnuz.h
Float8_e5m2fnuz
# c10/util/Half.h
Half
# c10/util/complex.h
complex
# ATen/NumericUtils.h, c10/util/generic_math.h
div_floor_floating
div_floor_integer
_isnan
# ATen/core/PhiloxRNGEngine.h
Philox4_32
randn
# ATen/cpu/vec/vec.h
Vectorized
clamp_min
convert
loadu
maximum
minimum
size