mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add torch/header_only_apis.txt and enforce they're tested (#153635)
This PR adds enforcement of testing header only APIs. The benefit of torch/header_only_apis.txt is twofold: 1) this gives us a clear view of what we expect to be header only 2) this allows us to enforce testing The enforcement added in this PR is very basic--we literally string match that a symbol in `torch/header_only_apis.txt` is in a cpp test. This is meant to be a first step in verifying our APIs are properly tested and can get fancier over time. For now, I've added myself as a codeowner to learn what to look out for in terms of proper tests. Over time, I anticipate we can automate more steps, but right now let's just get something out the door. Pull Request resolved: https://github.com/pytorch/pytorch/pull/153635 Approved by: https://github.com/albanD ghstack dependencies: #153965
This commit is contained in:
committed by
PyTorch MergeBot
parent
41a9aa6564
commit
fc33da410f
@ -1736,3 +1736,15 @@ command = [
|
||||
include_patterns = [
|
||||
'test/**/test_*.py',
|
||||
]
|
||||
|
||||
# 'header_only_linter' reports on properly testing header-only APIs.
|
||||
[[linter]]
|
||||
code = 'HEADER_ONLY_LINTER'
|
||||
command = [
|
||||
'python3',
|
||||
'tools/linter/adapters/header_only_linter.py',
|
||||
]
|
||||
include_patterns = [
|
||||
'torch/header_only_apis.txt',
|
||||
]
|
||||
is_formatter = false
|
||||
|
@ -14,6 +14,7 @@
|
||||
/torch/csrc/autograd/ @albanD @soulitzer
|
||||
/torch/autograd/ @albanD @soulitzer
|
||||
/tools/autograd/ @albanD @soulitzer
|
||||
/torch/header_only_apis.txt @janeyx99
|
||||
/torch/nn/ @albanD @jbschlosser @mikaylagawarecki
|
||||
/torch/optim/ @albanD @janeyx99
|
||||
/test/test_public_bindings.py @albanD
|
||||
|
140
tools/linter/adapters/header_only_linter.py
Normal file
140
tools/linter/adapters/header_only_linter.py
Normal file
@ -0,0 +1,140 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Checks that all symbols in torch/header_only_apis.txt are tested in a .cpp
|
||||
test file to ensure header-only-ness. The .cpp test file must be built
|
||||
without linking libtorch.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import NamedTuple, Union
|
||||
|
||||
|
||||
LINTER_CODE = "HEADER_ONLY_LINTER"
|
||||
|
||||
|
||||
class LintSeverity(str, Enum):
|
||||
ERROR = "error"
|
||||
WARNING = "warning"
|
||||
ADVICE = "advice"
|
||||
DISABLED = "disabled"
|
||||
|
||||
|
||||
class LintMessage(NamedTuple):
|
||||
path: Union[str, None]
|
||||
line: Union[int, None]
|
||||
char: Union[int, None]
|
||||
code: str
|
||||
severity: LintSeverity
|
||||
name: str
|
||||
original: Union[str, None]
|
||||
replacement: Union[str, None]
|
||||
description: Union[str, None]
|
||||
|
||||
|
||||
CPP_TEST_GLOBS = [
|
||||
"test/cpp/aoti_abi_check/*.cpp",
|
||||
]
|
||||
|
||||
REPO_ROOT = Path(__file__).parents[3]
|
||||
|
||||
|
||||
def find_matched_symbols(
|
||||
symbols_regex: re.Pattern[str], test_globs: list[str] = CPP_TEST_GLOBS
|
||||
) -> set[str]:
|
||||
"""
|
||||
Goes through all lines not starting with // in the cpp files and
|
||||
accumulates a list of matches with the symbols_regex. Note that
|
||||
we expect symbols_regex to be sorted in reverse alphabetical
|
||||
order to allow superset regexes to get matched.
|
||||
"""
|
||||
matched_symbols = set()
|
||||
# check noncommented out lines of the test files
|
||||
for cpp_test_glob in test_globs:
|
||||
for test_file in REPO_ROOT.glob(cpp_test_glob):
|
||||
with open(test_file) as tf:
|
||||
for test_file_line in tf:
|
||||
test_file_line = test_file_line.strip()
|
||||
if test_file_line.startswith(("//", "#")) or test_file_line == "":
|
||||
continue
|
||||
matches = re.findall(symbols_regex, test_file_line)
|
||||
for m in matches:
|
||||
if m != "":
|
||||
matched_symbols.add(m)
|
||||
return matched_symbols
|
||||
|
||||
|
||||
def check_file(
|
||||
filename: str, test_globs: list[str] = CPP_TEST_GLOBS
|
||||
) -> list[LintMessage]:
|
||||
"""
|
||||
Goes through the header_only_apis.txt file and verifies that all symbols
|
||||
within the file can be found tested in an appropriately independent .cpp
|
||||
file.
|
||||
|
||||
Note that we expect CPP_TEST_GLOBS to be passed in as test_globs--the
|
||||
only reason this is an argument at all is for ease of testing.
|
||||
"""
|
||||
lint_messages: list[LintMessage] = []
|
||||
|
||||
symbols: dict[str, int] = {} # symbol -> lineno
|
||||
with open(filename) as f:
|
||||
for idx, line in enumerate(f):
|
||||
# commented out lines should be skipped
|
||||
symbol = line.strip()
|
||||
if not symbol or symbol[0] == "#":
|
||||
continue
|
||||
|
||||
# symbols can in fact be duplicated and come from different headers.
|
||||
# we are aware this is a flaw in using simple string matching.
|
||||
symbols[symbol] = idx + 1
|
||||
|
||||
# Why reverse the keys? To allow superset regexes to get matched first in
|
||||
# find_matched_symbols. For example, we want Float8_e5m2fnuz to match
|
||||
# before Float8_e5m2. Otherwise, both Float8_e5m2fnuz and Float8_e5m2 will
|
||||
# match Float8_e5m2
|
||||
symbols_regex = re.compile("|".join(sorted(symbols.keys(), reverse=True)))
|
||||
matched_symbols = find_matched_symbols(symbols_regex, test_globs)
|
||||
|
||||
for s, lineno in symbols.items():
|
||||
if s not in matched_symbols:
|
||||
lint_messages.append(
|
||||
LintMessage(
|
||||
path=filename,
|
||||
line=lineno,
|
||||
char=None,
|
||||
code=LINTER_CODE,
|
||||
severity=LintSeverity.ERROR,
|
||||
name="[untested-symbol]",
|
||||
original=None,
|
||||
replacement=None,
|
||||
description=(
|
||||
f"{s} has been included as a header-only API "
|
||||
"but is not tested in any of CPP_TEST_GLOBS, which "
|
||||
f"contains {CPP_TEST_GLOBS}.\n"
|
||||
"Please add a .cpp test using the symbol without "
|
||||
"linking anything to verify that the symbol is in "
|
||||
"fact header-only. If you already have a test but it's"
|
||||
" not found, please add the .cpp file to CPP_TEST_GLOBS"
|
||||
" in tools/linters/adapters/header_only_linter.py."
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return lint_messages
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="header only APIs linter",
|
||||
fromfile_prefix_chars="@",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
for lint_message in check_file(
|
||||
str(REPO_ROOT) + "/torch/header_only_apis.txt", CPP_TEST_GLOBS
|
||||
):
|
||||
print(json.dumps(lint_message._asdict()), flush=True)
|
7
tools/test/header_only_linter_testdata/a.cpp
Normal file
7
tools/test/header_only_linter_testdata/a.cpp
Normal file
@ -0,0 +1,7 @@
|
||||
// bbb
|
||||
#include <a.h>
|
||||
|
||||
int main() {
|
||||
auto var = symC(2, 3);
|
||||
return symDef() + var;
|
||||
}
|
9
tools/test/header_only_linter_testdata/b.cpp
Normal file
9
tools/test/header_only_linter_testdata/b.cpp
Normal file
@ -0,0 +1,9 @@
|
||||
#include <a.h>
|
||||
#include <bbb.h>
|
||||
|
||||
int main() {
|
||||
auto var = a();
|
||||
|
||||
// bbb
|
||||
return symDef() + var;
|
||||
}
|
9
tools/test/header_only_linter_testdata/bad.txt
Normal file
9
tools/test/header_only_linter_testdata/bad.txt
Normal file
@ -0,0 +1,9 @@
|
||||
# a.h
|
||||
a
|
||||
symC
|
||||
symD
|
||||
|
||||
# bbb.h
|
||||
bbb
|
||||
symD
|
||||
symDef
|
0
tools/test/header_only_linter_testdata/empty.txt
Normal file
0
tools/test/header_only_linter_testdata/empty.txt
Normal file
10
tools/test/header_only_linter_testdata/good.txt
Normal file
10
tools/test/header_only_linter_testdata/good.txt
Normal file
@ -0,0 +1,10 @@
|
||||
# a.h
|
||||
a
|
||||
|
||||
# indented comment should do nothing
|
||||
symC
|
||||
|
||||
|
||||
|
||||
# bbb.h
|
||||
symDef
|
98
tools/test/test_header_only_linter.py
Normal file
98
tools/test/test_header_only_linter.py
Normal file
@ -0,0 +1,98 @@
|
||||
import re
|
||||
import unittest
|
||||
|
||||
from tools.linter.adapters.header_only_linter import (
|
||||
check_file,
|
||||
CPP_TEST_GLOBS,
|
||||
find_matched_symbols,
|
||||
LINTER_CODE,
|
||||
LintMessage,
|
||||
LintSeverity,
|
||||
REPO_ROOT,
|
||||
)
|
||||
|
||||
|
||||
class TestHeaderOnlyLinter(unittest.TestCase):
|
||||
"""
|
||||
Test the header only linter functionality
|
||||
"""
|
||||
|
||||
def test_find_matched_symbols(self) -> None:
|
||||
sample_regex = re.compile("symDef|symD|symC|bbb|a")
|
||||
test_globs = ["tools/test/header_only_linter_testdata/*.cpp"]
|
||||
|
||||
expected_matches = {"symDef", "symC", "a"}
|
||||
self.assertEqual(
|
||||
find_matched_symbols(sample_regex, test_globs), expected_matches
|
||||
)
|
||||
|
||||
def test_find_matched_symbols_empty_regex(self) -> None:
|
||||
sample_regex = re.compile("")
|
||||
test_globs = ["tools/test/header_only_linter_testdata/*.cpp"]
|
||||
|
||||
expected_matches: set[str] = set()
|
||||
self.assertEqual(
|
||||
find_matched_symbols(sample_regex, test_globs), expected_matches
|
||||
)
|
||||
|
||||
def test_check_file_no_issues(self) -> None:
|
||||
sample_txt = str(REPO_ROOT / "tools/test/header_only_linter_testdata/good.txt")
|
||||
test_globs = ["tools/test/header_only_linter_testdata/*.cpp"]
|
||||
self.assertEqual(len(check_file(sample_txt, test_globs)), 0)
|
||||
|
||||
def test_check_empty_file(self) -> None:
|
||||
sample_txt = str(REPO_ROOT / "tools/test/header_only_linter_testdata/empty.txt")
|
||||
test_globs = ["tools/test/header_only_linter_testdata/*.cpp"]
|
||||
self.assertEqual(len(check_file(sample_txt, test_globs)), 0)
|
||||
|
||||
def test_check_file_with_untested_symbols(self) -> None:
|
||||
sample_txt = str(REPO_ROOT / "tools/test/header_only_linter_testdata/bad.txt")
|
||||
test_globs = ["tools/test/header_only_linter_testdata/*.cpp"]
|
||||
|
||||
expected_msgs = [
|
||||
LintMessage(
|
||||
path=sample_txt,
|
||||
line=7,
|
||||
char=None,
|
||||
code=LINTER_CODE,
|
||||
severity=LintSeverity.ERROR,
|
||||
name="[untested-symbol]",
|
||||
original=None,
|
||||
replacement=None,
|
||||
description=(
|
||||
f"bbb has been included as a header-only API "
|
||||
"but is not tested in any of CPP_TEST_GLOBS, which "
|
||||
f"contains {CPP_TEST_GLOBS}.\n"
|
||||
"Please add a .cpp test using the symbol without "
|
||||
"linking anything to verify that the symbol is in "
|
||||
"fact header-only. If you already have a test but it's"
|
||||
" not found, please add the .cpp file to CPP_TEST_GLOBS"
|
||||
" in tools/linters/adapters/header_only_linter.py."
|
||||
),
|
||||
),
|
||||
LintMessage(
|
||||
path=sample_txt,
|
||||
line=8,
|
||||
char=None,
|
||||
code=LINTER_CODE,
|
||||
severity=LintSeverity.ERROR,
|
||||
name="[untested-symbol]",
|
||||
original=None,
|
||||
replacement=None,
|
||||
description=(
|
||||
f"symD has been included as a header-only API "
|
||||
"but is not tested in any of CPP_TEST_GLOBS, which "
|
||||
f"contains {CPP_TEST_GLOBS}.\n"
|
||||
"Please add a .cpp test using the symbol without "
|
||||
"linking anything to verify that the symbol is in "
|
||||
"fact header-only. If you already have a test but it's"
|
||||
" not found, please add the .cpp file to CPP_TEST_GLOBS"
|
||||
" in tools/linters/adapters/header_only_linter.py."
|
||||
),
|
||||
),
|
||||
]
|
||||
self.assertEqual(set(check_file(sample_txt, test_globs)), set(expected_msgs))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
49
torch/header_only_apis.txt
Normal file
49
torch/header_only_apis.txt
Normal file
@ -0,0 +1,49 @@
|
||||
# This file contains all the header-only C++ APIs/symbols in torch.
|
||||
# If a symbol is added in this file, it should be tested in a .cpp file
|
||||
# to guarantee that compiling these symbols do not require linking libtorch
|
||||
# to ensure header-only-ness.
|
||||
|
||||
# c10/util/TypeCast.h
|
||||
convert
|
||||
|
||||
# c10/util/bit_cast.h
|
||||
bit_cast
|
||||
|
||||
# c10/util/BFloat16-math.h, c10/util/BFloat16.h
|
||||
BFloat16
|
||||
|
||||
# c10/util/Float8_e4m3fn.h
|
||||
Float8_e4m3fn
|
||||
|
||||
# c10/util/Float8_e4m3fnuz.h
|
||||
Float8_e4m3fnuz
|
||||
|
||||
# c10/util/Float8_e5m2.h
|
||||
Float8_e5m2
|
||||
|
||||
# c10/util/Float8_e5m2fnuz.h
|
||||
Float8_e5m2fnuz
|
||||
|
||||
# c10/util/Half.h
|
||||
Half
|
||||
|
||||
# c10/util/complex.h
|
||||
complex
|
||||
|
||||
# ATen/NumericUtils.h, c10/util/generic_math.h
|
||||
div_floor_floating
|
||||
div_floor_integer
|
||||
_isnan
|
||||
|
||||
# ATen/core/PhiloxRNGEngine.h
|
||||
Philox4_32
|
||||
randn
|
||||
|
||||
# ATen/cpu/vec/vec.h
|
||||
Vectorized
|
||||
clamp_min
|
||||
convert
|
||||
loadu
|
||||
maximum
|
||||
minimum
|
||||
size
|
Reference in New Issue
Block a user