mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
This PR adds enforcement of testing for header-only APIs. The benefit of torch/header_only_apis.txt is twofold: 1) it gives us a clear view of what we expect to be header-only, and 2) it allows us to enforce testing. The enforcement added in this PR is very basic: we literally string-match that a symbol in `torch/header_only_apis.txt` appears in a cpp test. This is meant to be a first step in verifying that our APIs are properly tested, and it can get fancier over time. For now, I've added myself as a codeowner to learn what to look out for in terms of proper tests. Over time, I anticipate we can automate more steps, but right now let's just get something out the door. Pull Request resolved: https://github.com/pytorch/pytorch/pull/153635 Approved by: https://github.com/albanD ghstack dependencies: #153965
99 lines
3.7 KiB
Python
99 lines
3.7 KiB
Python
import re
|
|
import unittest
|
|
|
|
from tools.linter.adapters.header_only_linter import (
|
|
check_file,
|
|
CPP_TEST_GLOBS,
|
|
find_matched_symbols,
|
|
LINTER_CODE,
|
|
LintMessage,
|
|
LintSeverity,
|
|
REPO_ROOT,
|
|
)
|
|
|
|
|
|
class TestHeaderOnlyLinter(unittest.TestCase):
    """
    Test the header only linter functionality.

    Every test runs against the fixture files under
    ``tools/test/header_only_linter_testdata``. The tests pass explicit globs
    over that directory's ``*.cpp`` files, and ``.txt`` files in the same
    directory act as the header-only symbol lists to check.
    """

    # Fixture directory. The .txt fixtures are resolved against REPO_ROOT below;
    # presumably the linter resolves the relative globs the same way (confirm
    # in header_only_linter.py).
    TESTDATA_DIR = "tools/test/header_only_linter_testdata"

    def _test_globs(self) -> list[str]:
        """Return a fresh glob list selecting the fixture ``.cpp`` files."""
        # Build a new list on each call so that no test can leak a mutation of
        # shared state into another test.
        return [f"{self.TESTDATA_DIR}/*.cpp"]

    def _testdata_path(self, name: str) -> str:
        """Return the absolute path, as a ``str``, of fixture file *name*."""
        return str(REPO_ROOT / self.TESTDATA_DIR / name)

    @staticmethod
    def _untested_symbol_msg(path: str, line: int, symbol: str) -> "LintMessage":
        """Build the LintMessage expected for an untested header-only *symbol*.

        Args:
            path: The symbol-list file passed to ``check_file``.
            line: The 1-based line in *path* on which *symbol* appears.
            symbol: The header-only symbol that no test glob mentions.
        """
        return LintMessage(
            path=path,
            line=line,
            char=None,
            code=LINTER_CODE,
            severity=LintSeverity.ERROR,
            name="[untested-symbol]",
            original=None,
            replacement=None,
            # NOTE(review): this text must match the linter's message verbatim,
            # including "tools/linters/adapters". The module's actual import
            # path is tools.linter.adapters, so if the linter's wording is
            # fixed, update it here as well.
            description=(
                f"{symbol} has been included as a header-only API "
                "but is not tested in any of CPP_TEST_GLOBS, which "
                f"contains {CPP_TEST_GLOBS}.\n"
                "Please add a .cpp test using the symbol without "
                "linking anything to verify that the symbol is in "
                "fact header-only. If you already have a test but it's"
                " not found, please add the .cpp file to CPP_TEST_GLOBS"
                " in tools/linters/adapters/header_only_linter.py."
            ),
        )

    def test_find_matched_symbols(self) -> None:
        sample_regex = re.compile("symDef|symD|symC|bbb|a")
        # Only symbols that actually occur in the fixture .cpp files are
        # reported. In particular, symD is a prefix of symDef and must not be
        # reported merely because symDef is present.
        expected_matches = {"symDef", "symC", "a"}
        self.assertEqual(
            find_matched_symbols(sample_regex, self._test_globs()),
            expected_matches,
        )

    def test_find_matched_symbols_empty_regex(self) -> None:
        # An empty pattern only ever matches empty strings, and those must not
        # be reported as symbols.
        sample_regex = re.compile("")
        expected_matches: set[str] = set()
        self.assertEqual(
            find_matched_symbols(sample_regex, self._test_globs()),
            expected_matches,
        )

    def test_check_file_no_issues(self) -> None:
        msgs = check_file(self._testdata_path("good.txt"), self._test_globs())
        # Pass the messages as the failure text so that a regression shows
        # which symbols were flagged.
        self.assertEqual(len(msgs), 0, msgs)

    def test_check_empty_file(self) -> None:
        msgs = check_file(self._testdata_path("empty.txt"), self._test_globs())
        self.assertEqual(len(msgs), 0, msgs)

    def test_check_file_with_untested_symbols(self) -> None:
        sample_txt = self._testdata_path("bad.txt")
        expected_msgs = [
            self._untested_symbol_msg(sample_txt, 7, "bbb"),
            self._untested_symbol_msg(sample_txt, 8, "symD"),
        ]
        # Like a set comparison, assertCountEqual ignores ordering. Unlike a
        # set comparison, it also fails if any message is emitted twice.
        self.assertCountEqual(
            check_file(sample_txt, self._test_globs()), expected_msgs
        )
|
|
if __name__ == "__main__":
    # Running this file directly discovers and runs the TestCase above.
    unittest.main()
|