mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135363 Approved by: https://github.com/oulgen
73 lines
2.3 KiB
Python
73 lines
2.3 KiB
Python
import sys
|
|
import textwrap
|
|
from pathlib import Path
|
|
|
|
|
|
def check(path):
|
|
"""Check a test file for common issues with pytest->pytorch conversion."""
|
|
print(path.name)
|
|
print("=" * len(path.name), "\n")
|
|
|
|
src = path.read_text().split("\n")
|
|
for num, line in enumerate(src):
|
|
if is_comment(line):
|
|
continue
|
|
|
|
# module level test functions
|
|
if line.startswith("def test"):
|
|
report_violation(line, num, header="Module-level test function")
|
|
|
|
# test classes must inherit from TestCase
|
|
if line.startswith("class Test") and "TestCase" not in line:
|
|
report_violation(
|
|
line, num, header="Test class does not inherit from TestCase"
|
|
)
|
|
|
|
# last vestiges of pytest-specific stuff
|
|
if "pytest.mark" in line:
|
|
report_violation(line, num, header="pytest.mark.something")
|
|
|
|
for part in ["pytest.xfail", "pytest.skip", "pytest.param"]:
|
|
if part in line:
|
|
report_violation(line, num, header=f"stray {part}")
|
|
|
|
if textwrap.dedent(line).startswith("@parametrize"):
|
|
# backtrack to check
|
|
nn = num
|
|
for nn in range(num, -1, -1):
|
|
ln = src[nn]
|
|
if "class Test" in ln:
|
|
# hack: large indent => likely an inner class
|
|
if len(ln) - len(ln.lstrip()) < 8:
|
|
break
|
|
else:
|
|
report_violation(line, num, "off-class parametrize")
|
|
if not src[nn - 1].startswith("@instantiate_parametrized_tests"):
|
|
report_violation(
|
|
line, num, f"missing instantiation of parametrized tests in {ln}?"
|
|
)
|
|
|
|
|
|
def is_comment(line):
|
|
return textwrap.dedent(line).startswith("#")
|
|
|
|
|
|
def report_violation(line, lineno, header):
|
|
print(f">>>> line {lineno} : {header}\n {line}\n")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
argv = sys.argv
|
|
if len(argv) != 2:
|
|
raise ValueError("Usage : python check_tests_conform path/to/file/or/dir")
|
|
|
|
path = Path(argv[1])
|
|
|
|
if path.is_dir():
|
|
# run for all files in the directory (no subdirs)
|
|
for this_path in path.glob("test*.py"):
|
|
# breakpoint()
|
|
check(this_path)
|
|
else:
|
|
check(path)
|