mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add torch/header_only_apis.txt and enforce they're tested (#153635)
This PR adds enforcement of testing header-only APIs. The benefit of torch/header_only_apis.txt is twofold: 1) it gives us a clear view of what we expect to be header-only, and 2) it allows us to enforce testing. The enforcement added in this PR is very basic--we literally string match that a symbol in `torch/header_only_apis.txt` is in a cpp test. This is meant to be a first step in verifying our APIs are properly tested and can get fancier over time. For now, I've added myself as a codeowner to learn what to look out for in terms of proper tests. Over time, I anticipate we can automate more steps, but right now let's just get something out the door. Pull Request resolved: https://github.com/pytorch/pytorch/pull/153635 Approved by: https://github.com/albanD ghstack dependencies: #153965
This commit is contained in:
committed by
PyTorch MergeBot
parent
41a9aa6564
commit
fc33da410f
@ -1736,3 +1736,15 @@ command = [
|
|||||||
include_patterns = [
|
include_patterns = [
|
||||||
'test/**/test_*.py',
|
'test/**/test_*.py',
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# 'header_only_linter' reports on properly testing header-only APIs.
|
||||||
|
[[linter]]
|
||||||
|
code = 'HEADER_ONLY_LINTER'
|
||||||
|
command = [
|
||||||
|
'python3',
|
||||||
|
'tools/linter/adapters/header_only_linter.py',
|
||||||
|
]
|
||||||
|
include_patterns = [
|
||||||
|
'torch/header_only_apis.txt',
|
||||||
|
]
|
||||||
|
is_formatter = false
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
/torch/csrc/autograd/ @albanD @soulitzer
|
/torch/csrc/autograd/ @albanD @soulitzer
|
||||||
/torch/autograd/ @albanD @soulitzer
|
/torch/autograd/ @albanD @soulitzer
|
||||||
/tools/autograd/ @albanD @soulitzer
|
/tools/autograd/ @albanD @soulitzer
|
||||||
|
/torch/header_only_apis.txt @janeyx99
|
||||||
/torch/nn/ @albanD @jbschlosser @mikaylagawarecki
|
/torch/nn/ @albanD @jbschlosser @mikaylagawarecki
|
||||||
/torch/optim/ @albanD @janeyx99
|
/torch/optim/ @albanD @janeyx99
|
||||||
/test/test_public_bindings.py @albanD
|
/test/test_public_bindings.py @albanD
|
||||||
|
140
tools/linter/adapters/header_only_linter.py
Normal file
140
tools/linter/adapters/header_only_linter.py
Normal file
@ -0,0 +1,140 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Checks that all symbols in torch/header_only_apis.txt are tested in a .cpp
|
||||||
|
test file to ensure header-only-ness. The .cpp test file must be built
|
||||||
|
without linking libtorch.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import NamedTuple, Union
|
||||||
|
|
||||||
|
|
||||||
|
LINTER_CODE = "HEADER_ONLY_LINTER"
|
||||||
|
|
||||||
|
|
||||||
|
class LintSeverity(str, Enum):
    """
    Severity levels a LintMessage can carry.

    Mixing in ``str`` makes every member compare equal to (and JSON-serialize
    as) its plain string value.
    """

    # Members are listed from most to least severe.
    ERROR = "error"
    WARNING = "warning"
    ADVICE = "advice"
    DISABLED = "disabled"
|
||||||
|
|
||||||
|
|
||||||
|
class LintMessage(NamedTuple):
    """
    One lint finding.

    The script's entry point converts each message with ``_asdict()`` and
    prints it as a single JSON object per line, so the field names below are
    also the emitted JSON keys.
    """

    # File the finding refers to.
    path: Union[str, None]
    # 1-indexed line within ``path``, or None when not tied to a line.
    line: Union[int, None]
    # Column within ``line``, or None when not tied to a column.
    char: Union[int, None]
    # Identifier of the linter that produced the finding.
    code: str
    severity: LintSeverity
    # Short label for the kind of finding.
    name: str
    # Original / replacement text; both are None when no fix is suggested.
    original: Union[str, None]
    replacement: Union[str, None]
    # Human-readable explanation of the finding.
    description: Union[str, None]
|
||||||
|
|
||||||
|
|
||||||
|
CPP_TEST_GLOBS = [
|
||||||
|
"test/cpp/aoti_abi_check/*.cpp",
|
||||||
|
]
|
||||||
|
|
||||||
|
REPO_ROOT = Path(__file__).parents[3]
|
||||||
|
|
||||||
|
|
||||||
|
def find_matched_symbols(
    symbols_regex: re.Pattern[str], test_globs: list[str] = CPP_TEST_GLOBS
) -> set[str]:
    """
    Collects every non-empty match of symbols_regex found in the cpp test
    files selected by test_globs (globs are resolved relative to REPO_ROOT).

    Blank lines and lines whose first non-whitespace characters are ``//``
    (line comments) or ``#`` (e.g. ``#include``) are skipped. Trailing
    comments on a code line and ``/* ... */`` blocks are NOT stripped, so
    text inside them can still match.

    Note that we expect the alternatives in symbols_regex to be sorted in
    reverse alphabetical order so that a symbol which is a superset of
    another (Float8_e5m2fnuz vs Float8_e5m2) is tried first.

    Args:
        symbols_regex: pattern whose full matches are the symbols of interest.
        test_globs: glob patterns, relative to REPO_ROOT, of files to scan.

    Returns:
        The set of distinct matched strings.
    """
    matched_symbols: set[str] = set()
    for cpp_test_glob in test_globs:
        for test_file in REPO_ROOT.glob(cpp_test_glob):
            # Decode explicitly as UTF-8 so results do not depend on the
            # locale of the machine running the linter.
            with open(test_file, encoding="utf-8") as tf:
                for raw_line in tf:
                    code_line = raw_line.strip()
                    if not code_line or code_line.startswith(("//", "#")):
                        continue
                    # finditer + group(0) always yields the full match, even
                    # if the pattern happens to contain capture groups
                    # (findall would return the group contents instead).
                    for match in re.finditer(symbols_regex, code_line):
                        symbol = match.group(0)
                        # An empty pattern matches "" at every position.
                        if symbol:
                            matched_symbols.add(symbol)
    return matched_symbols
|
||||||
|
|
||||||
|
|
||||||
|
def check_file(
    filename: str, test_globs: list[str] = CPP_TEST_GLOBS
) -> list[LintMessage]:
    """
    Goes through the header_only_apis.txt file and verifies that all symbols
    within the file can be found tested in an appropriately independent .cpp
    file.

    Blank lines and lines starting with ``#`` in the symbol file are ignored.
    Matching is plain substring matching of the literal symbol text.

    Note that we expect CPP_TEST_GLOBS to be passed in as test_globs--the
    only reason this is an argument at all is for ease of testing.

    Args:
        filename: path of the symbol list to check.
        test_globs: glob patterns, relative to REPO_ROOT, of .cpp test files.

    Returns:
        One ERROR LintMessage per listed symbol that was not found in any of
        the test files; an empty list when everything is covered.
    """
    symbols: dict[str, int] = {}  # symbol -> lineno (1-indexed)
    with open(filename, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            symbol = line.strip()
            # commented out and blank lines should be skipped
            if not symbol or symbol.startswith("#"):
                continue

            # symbols can in fact be duplicated and come from different headers.
            # we are aware this is a flaw in using simple string matching.
            # The last occurrence wins for line reporting.
            symbols[symbol] = lineno

    if not symbols:
        # Nothing to verify, so there is no need to scan the test files.
        return []

    # Why reverse the keys? To allow superset regexes to get matched first in
    # find_matched_symbols. For example, we want Float8_e5m2fnuz to match
    # before Float8_e5m2. Otherwise, both Float8_e5m2fnuz and Float8_e5m2 will
    # match Float8_e5m2.
    # Each symbol is escaped so characters such as "+", "(" or "*" in a symbol
    # name are matched literally instead of being read as regex syntax.
    symbols_regex = re.compile(
        "|".join(re.escape(s) for s in sorted(symbols, reverse=True))
    )
    matched_symbols = find_matched_symbols(symbols_regex, test_globs)

    # NOTE(review): the message below says "tools/linters/adapters/..." while
    # this linter is registered as "tools/linter/adapters/...", and it prints
    # CPP_TEST_GLOBS rather than test_globs. The text is asserted verbatim in
    # tools/test/test_header_only_linter.py, so change both places together.
    return [
        LintMessage(
            path=filename,
            line=lineno,
            char=None,
            code=LINTER_CODE,
            severity=LintSeverity.ERROR,
            name="[untested-symbol]",
            original=None,
            replacement=None,
            description=(
                f"{s} has been included as a header-only API "
                "but is not tested in any of CPP_TEST_GLOBS, which "
                f"contains {CPP_TEST_GLOBS}.\n"
                "Please add a .cpp test using the symbol without "
                "linking anything to verify that the symbol is in "
                "fact header-only. If you already have a test but it's"
                " not found, please add the .cpp file to CPP_TEST_GLOBS"
                " in tools/linters/adapters/header_only_linter.py."
            ),
        )
        for s, lineno in symbols.items()
        if s not in matched_symbols
    ]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # No options are declared; parsing still provides --help and rejects
    # unexpected arguments ("@file" arguments are expanded by argparse).
    parser = argparse.ArgumentParser(
        description="header only APIs linter",
        fromfile_prefix_chars="@",
    )
    args = parser.parse_args()

    # The linter always checks the single canonical symbol list.
    apis_file = f"{REPO_ROOT}/torch/header_only_apis.txt"
    lint_messages = check_file(apis_file, CPP_TEST_GLOBS)

    # Emit one JSON object per finding, flushing after each line.
    for lint_message in lint_messages:
        print(json.dumps(lint_message._asdict()), flush=True)
|
7
tools/test/header_only_linter_testdata/a.cpp
Normal file
7
tools/test/header_only_linter_testdata/a.cpp
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
// bbb
|
||||||
|
#include <a.h>
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
auto var = symC(2, 3);
|
||||||
|
return symDef() + var;
|
||||||
|
}
|
9
tools/test/header_only_linter_testdata/b.cpp
Normal file
9
tools/test/header_only_linter_testdata/b.cpp
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
#include <a.h>
|
||||||
|
#include <bbb.h>
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
auto var = a();
|
||||||
|
|
||||||
|
// bbb
|
||||||
|
return symDef() + var;
|
||||||
|
}
|
9
tools/test/header_only_linter_testdata/bad.txt
Normal file
9
tools/test/header_only_linter_testdata/bad.txt
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
# a.h
|
||||||
|
a
|
||||||
|
symC
|
||||||
|
symD
|
||||||
|
|
||||||
|
# bbb.h
|
||||||
|
bbb
|
||||||
|
symD
|
||||||
|
symDef
|
0
tools/test/header_only_linter_testdata/empty.txt
Normal file
0
tools/test/header_only_linter_testdata/empty.txt
Normal file
10
tools/test/header_only_linter_testdata/good.txt
Normal file
10
tools/test/header_only_linter_testdata/good.txt
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
# a.h
|
||||||
|
a
|
||||||
|
|
||||||
|
# indented comment should do nothing
|
||||||
|
symC
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# bbb.h
|
||||||
|
symDef
|
98
tools/test/test_header_only_linter.py
Normal file
98
tools/test/test_header_only_linter.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
import re
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from tools.linter.adapters.header_only_linter import (
|
||||||
|
check_file,
|
||||||
|
CPP_TEST_GLOBS,
|
||||||
|
find_matched_symbols,
|
||||||
|
LINTER_CODE,
|
||||||
|
LintMessage,
|
||||||
|
LintSeverity,
|
||||||
|
REPO_ROOT,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestHeaderOnlyLinter(unittest.TestCase):
    """
    Test the header only linter functionality against the fixtures in
    tools/test/header_only_linter_testdata.
    """

    # Globs (relative to REPO_ROOT) of the fixture .cpp files to scan.
    _CPP_GLOBS = ["tools/test/header_only_linter_testdata/*.cpp"]

    @staticmethod
    def _fixture(name: str) -> str:
        """Return the absolute path of a fixture file as a string."""
        return str(REPO_ROOT / f"tools/test/header_only_linter_testdata/{name}")

    @staticmethod
    def _untested(path: str, line: int, symbol: str) -> LintMessage:
        """Build the message check_file is expected to emit for `symbol`."""
        return LintMessage(
            path=path,
            line=line,
            char=None,
            code=LINTER_CODE,
            severity=LintSeverity.ERROR,
            name="[untested-symbol]",
            original=None,
            replacement=None,
            description=(
                f"{symbol} has been included as a header-only API "
                "but is not tested in any of CPP_TEST_GLOBS, which "
                f"contains {CPP_TEST_GLOBS}.\n"
                "Please add a .cpp test using the symbol without "
                "linking anything to verify that the symbol is in "
                "fact header-only. If you already have a test but it's"
                " not found, please add the .cpp file to CPP_TEST_GLOBS"
                " in tools/linters/adapters/header_only_linter.py."
            ),
        )

    def test_find_matched_symbols(self) -> None:
        pattern = re.compile("symDef|symD|symC|bbb|a")

        # "bbb" only appears in comments/#include lines and "symD" only as a
        # prefix of "symDef", so neither is reported.
        found = find_matched_symbols(pattern, self._CPP_GLOBS)
        self.assertEqual(found, {"symDef", "symC", "a"})

    def test_find_matched_symbols_empty_regex(self) -> None:
        pattern = re.compile("")

        nothing: set[str] = set()
        self.assertEqual(find_matched_symbols(pattern, self._CPP_GLOBS), nothing)

    def test_check_file_no_issues(self) -> None:
        messages = check_file(self._fixture("good.txt"), self._CPP_GLOBS)
        self.assertEqual(len(messages), 0)

    def test_check_empty_file(self) -> None:
        messages = check_file(self._fixture("empty.txt"), self._CPP_GLOBS)
        self.assertEqual(len(messages), 0)

    def test_check_file_with_untested_symbols(self) -> None:
        bad_txt = self._fixture("bad.txt")

        # symD is listed twice in bad.txt; the later line (8) is reported.
        expected = {
            self._untested(bad_txt, 7, "bbb"),
            self._untested(bad_txt, 8, "symD"),
        }
        self.assertEqual(set(check_file(bad_txt, self._CPP_GLOBS)), expected)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
49
torch/header_only_apis.txt
Normal file
49
torch/header_only_apis.txt
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
# This file contains all the header-only C++ APIs/symbols in torch.
|
||||||
|
# If a symbol is added in this file, it should be tested in a .cpp file
|
||||||
|
# to guarantee that compiling these symbols does not require linking libtorch
|
||||||
|
# to ensure header-only-ness.
|
||||||
|
|
||||||
|
# c10/util/TypeCast.h
|
||||||
|
convert
|
||||||
|
|
||||||
|
# c10/util/bit_cast.h
|
||||||
|
bit_cast
|
||||||
|
|
||||||
|
# c10/util/BFloat16-math.h, c10/util/BFloat16.h
|
||||||
|
BFloat16
|
||||||
|
|
||||||
|
# c10/util/Float8_e4m3fn.h
|
||||||
|
Float8_e4m3fn
|
||||||
|
|
||||||
|
# c10/util/Float8_e4m3fnuz.h
|
||||||
|
Float8_e4m3fnuz
|
||||||
|
|
||||||
|
# c10/util/Float8_e5m2.h
|
||||||
|
Float8_e5m2
|
||||||
|
|
||||||
|
# c10/util/Float8_e5m2fnuz.h
|
||||||
|
Float8_e5m2fnuz
|
||||||
|
|
||||||
|
# c10/util/Half.h
|
||||||
|
Half
|
||||||
|
|
||||||
|
# c10/util/complex.h
|
||||||
|
complex
|
||||||
|
|
||||||
|
# ATen/NumericUtils.h, c10/util/generic_math.h
|
||||||
|
div_floor_floating
|
||||||
|
div_floor_integer
|
||||||
|
_isnan
|
||||||
|
|
||||||
|
# ATen/core/PhiloxRNGEngine.h
|
||||||
|
Philox4_32
|
||||||
|
randn
|
||||||
|
|
||||||
|
# ATen/cpu/vec/vec.h
|
||||||
|
Vectorized
|
||||||
|
clamp_min
|
||||||
|
convert
|
||||||
|
loadu
|
||||||
|
maximum
|
||||||
|
minimum
|
||||||
|
size
|
Reference in New Issue
Block a user