mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
pytest to run test_ops, test_ops_gradients, test_ops_jit in non linux cuda environments (#79898)
This PR uses pytest to run test_ops, test_ops_gradients, and test_ops_jit in parallel in non linux cuda environments to decrease TTS. I am excluding linux cuda because running in parallel results in errors due to running out of memory Notes: * update hypothesis version for compatability with pytest * use rerun-failures to rerun tests (similar to flaky tests, although these test files generally don't have flaky tests) * reruns are denoted by a rerun tag in the xml. Failed reruns also have the failure tag. Successes (meaning that the test is flaky) do not have the failure tag. * see https://docs.google.com/spreadsheets/d/1aO0Rbg3y3ch7ghipt63PG2KNEUppl9a5b18Hmv2CZ4E/edit#gid=602543594 for info on speedup (or slowdown in the case of slow tests) * expecting windows tests to decrease by 60 minutes total * slow test infra is expected to stay the same - verified by running pytest and unittest on the same job and check the number of skipped/run tests * test reports to s3 changed - add entirely new table to keep track of invoking_file times Pull Request resolved: https://github.com/pytorch/pytorch/pull/79898 Approved by: https://github.com/malfet, https://github.com/janeyx99
This commit is contained in:
committed by
PyTorch MergeBot
parent
9f9dd4f072
commit
06a0cfc0ea
@ -41,7 +41,7 @@ flatbuffers==2.0
|
|||||||
#Pinned versions:
|
#Pinned versions:
|
||||||
#test that import:
|
#test that import:
|
||||||
|
|
||||||
hypothesis==4.53.2
|
hypothesis==5.35.1
|
||||||
# Pin hypothesis to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136
|
# Pin hypothesis to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136
|
||||||
#Description: advanced library for generating parametrized tests
|
#Description: advanced library for generating parametrized tests
|
||||||
#Pinned versions: 3.44.6, 4.53.2
|
#Pinned versions: 3.44.6, 4.53.2
|
||||||
@ -143,6 +143,16 @@ pytest
|
|||||||
#Pinned versions:
|
#Pinned versions:
|
||||||
#test that import: test_typing.py, test_cpp_extensions_aot.py, run_test.py
|
#test that import: test_typing.py, test_cpp_extensions_aot.py, run_test.py
|
||||||
|
|
||||||
|
pytest-xdist
|
||||||
|
#Description: plugin for running pytest in parallel
|
||||||
|
#Pinned versions:
|
||||||
|
#test that import:
|
||||||
|
|
||||||
|
pytest-rerunfailures
|
||||||
|
#Description: plugin for rerunning tests in pytest
|
||||||
|
#Pinned versions:
|
||||||
|
#test that import:
|
||||||
|
|
||||||
#pytest-benchmark
|
#pytest-benchmark
|
||||||
#Description: fixture for benchmarking code
|
#Description: fixture for benchmarking code
|
||||||
#Pinned versions: 3.2.3
|
#Pinned versions: 3.2.3
|
||||||
|
@ -10,7 +10,9 @@ pip install -q hypothesis "expecttest==0.1.3" "librosa>=0.6.2" "numba<=0.49.1" p
|
|||||||
# TODO move this to docker
|
# TODO move this to docker
|
||||||
# Pin unittest-xml-reporting to freeze printing test summary logic, related: https://github.com/pytorch/pytorch/issues/69014
|
# Pin unittest-xml-reporting to freeze printing test summary logic, related: https://github.com/pytorch/pytorch/issues/69014
|
||||||
pip install "unittest-xml-reporting<=3.2.0,>=2.0.0" \
|
pip install "unittest-xml-reporting<=3.2.0,>=2.0.0" \
|
||||||
pytest
|
pytest \
|
||||||
|
pytest-xdist \
|
||||||
|
pytest-rerunfailures
|
||||||
|
|
||||||
if [ -z "${CI}" ]; then
|
if [ -z "${CI}" ]; then
|
||||||
rm -rf "${WORKSPACE_DIR}"/miniconda3/lib/python3.6/site-packages/torch*
|
rm -rf "${WORKSPACE_DIR}"/miniconda3/lib/python3.6/site-packages/torch*
|
||||||
|
@ -36,7 +36,7 @@ popd
|
|||||||
=======
|
=======
|
||||||
:: Pin unittest-xml-reporting to freeze printing test summary logic, related: https://github.com/pytorch/pytorch/issues/69014
|
:: Pin unittest-xml-reporting to freeze printing test summary logic, related: https://github.com/pytorch/pytorch/issues/69014
|
||||||
|
|
||||||
pip install "ninja==1.10.0.post1" future "hypothesis==4.53.2" "expecttest==0.1.3" "librosa>=0.6.2" "scipy==1.6.3" psutil pillow "unittest-xml-reporting<=3.2.0,>=2.0.0" pytest
|
pip install "ninja==1.10.0.post1" future "hypothesis==5.35.1" "expecttest==0.1.3" "librosa>=0.6.2" "scipy==1.6.3" psutil pillow "unittest-xml-reporting<=3.2.0,>=2.0.0" pytest pytest-xdist pytest-rerunfailures
|
||||||
if errorlevel 1 exit /b
|
if errorlevel 1 exit /b
|
||||||
if not errorlevel 0 exit /b
|
if not errorlevel 0 exit /b
|
||||||
|
|
||||||
|
@ -10,3 +10,4 @@ addopts =
|
|||||||
-Wd
|
-Wd
|
||||||
testpaths =
|
testpaths =
|
||||||
test
|
test
|
||||||
|
junit_logging_reruns = all
|
||||||
|
146
test/conftest.py
Normal file
146
test/conftest.py
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
from _pytest.junitxml import LogXML, _NodeReporter, bin_xml_escape
|
||||||
|
from _pytest.terminal import _get_raw_skip_reason
|
||||||
|
from _pytest.stash import StashKey
|
||||||
|
from _pytest.reports import TestReport
|
||||||
|
from _pytest.config.argparsing import Parser
|
||||||
|
from _pytest.config import filename_arg
|
||||||
|
from _pytest.config import Config
|
||||||
|
from _pytest._code.code import ReprFileLocation
|
||||||
|
from typing import Union
|
||||||
|
from typing import Optional
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
import functools
|
||||||
|
|
||||||
|
# a lot of this file is copied from _pytest.junitxml and modified to get rerun info
|
||||||
|
|
||||||
|
xml_key = StashKey["LogXMLReruns"]()
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_addoption(parser: Parser) -> None:
|
||||||
|
group = parser.getgroup("terminal reporting")
|
||||||
|
group.addoption(
|
||||||
|
"--junit-xml-reruns",
|
||||||
|
action="store",
|
||||||
|
dest="xmlpath_reruns",
|
||||||
|
metavar="path",
|
||||||
|
type=functools.partial(filename_arg, optname="--junit-xml-reruns"),
|
||||||
|
default=None,
|
||||||
|
help="create junit-xml style report file at given path.",
|
||||||
|
)
|
||||||
|
group.addoption(
|
||||||
|
"--junit-prefix-reruns",
|
||||||
|
action="store",
|
||||||
|
metavar="str",
|
||||||
|
default=None,
|
||||||
|
help="prepend prefix to classnames in junit-xml output",
|
||||||
|
)
|
||||||
|
parser.addini(
|
||||||
|
"junit_suite_name_reruns", "Test suite name for JUnit report", default="pytest"
|
||||||
|
)
|
||||||
|
parser.addini(
|
||||||
|
"junit_logging_reruns",
|
||||||
|
"Write captured log messages to JUnit report: "
|
||||||
|
"one of no|log|system-out|system-err|out-err|all",
|
||||||
|
default="no",
|
||||||
|
)
|
||||||
|
parser.addini(
|
||||||
|
"junit_log_passing_tests_reruns",
|
||||||
|
"Capture log information for passing tests to JUnit report: ",
|
||||||
|
type="bool",
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
|
parser.addini(
|
||||||
|
"junit_duration_report_reruns",
|
||||||
|
"Duration time to report: one of total|call",
|
||||||
|
default="total",
|
||||||
|
)
|
||||||
|
parser.addini(
|
||||||
|
"junit_family_reruns",
|
||||||
|
"Emit XML for schema: one of legacy|xunit1|xunit2",
|
||||||
|
default="xunit2",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_configure(config: Config) -> None:
|
||||||
|
xmlpath = config.option.xmlpath_reruns
|
||||||
|
# Prevent opening xmllog on worker nodes (xdist).
|
||||||
|
if xmlpath and not hasattr(config, "workerinput"):
|
||||||
|
junit_family = config.getini("junit_family_reruns")
|
||||||
|
config.stash[xml_key] = LogXMLReruns(
|
||||||
|
xmlpath,
|
||||||
|
config.option.junitprefix,
|
||||||
|
config.getini("junit_suite_name_reruns"),
|
||||||
|
config.getini("junit_logging_reruns"),
|
||||||
|
config.getini("junit_duration_report_reruns"),
|
||||||
|
junit_family,
|
||||||
|
config.getini("junit_log_passing_tests_reruns"),
|
||||||
|
)
|
||||||
|
config.pluginmanager.register(config.stash[xml_key])
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_unconfigure(config: Config) -> None:
|
||||||
|
xml = config.stash.get(xml_key, None)
|
||||||
|
if xml:
|
||||||
|
del config.stash[xml_key]
|
||||||
|
config.pluginmanager.unregister(xml)
|
||||||
|
|
||||||
|
|
||||||
|
class _NodeReporterReruns(_NodeReporter):
|
||||||
|
def _prepare_content(self, content: str, header: str) -> str:
|
||||||
|
return content
|
||||||
|
|
||||||
|
def _write_content(self, report: TestReport, content: str, jheader: str) -> None:
|
||||||
|
if content == "":
|
||||||
|
return
|
||||||
|
tag = ET.Element(jheader)
|
||||||
|
tag.text = bin_xml_escape(content)
|
||||||
|
self.append(tag)
|
||||||
|
|
||||||
|
|
||||||
|
class LogXMLReruns(LogXML):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def append_rerun(self, reporter: _NodeReporter, report: TestReport) -> None:
|
||||||
|
if hasattr(report, "wasxfail"):
|
||||||
|
reporter._add_simple("skipped", "xfail-marked test passes unexpectedly")
|
||||||
|
else:
|
||||||
|
assert report.longrepr is not None
|
||||||
|
reprcrash: Optional[ReprFileLocation] = getattr(
|
||||||
|
report.longrepr, "reprcrash", None
|
||||||
|
)
|
||||||
|
if reprcrash is not None:
|
||||||
|
message = reprcrash.message
|
||||||
|
else:
|
||||||
|
message = str(report.longrepr)
|
||||||
|
message = bin_xml_escape(message)
|
||||||
|
reporter._add_simple("rerun", message, str(report.longrepr))
|
||||||
|
|
||||||
|
def pytest_runtest_logreport(self, report: TestReport) -> None:
|
||||||
|
super().pytest_runtest_logreport(report)
|
||||||
|
if report.outcome == "rerun":
|
||||||
|
reporter = self._opentestcase(report)
|
||||||
|
self.append_rerun(reporter, report)
|
||||||
|
if report.outcome == "skipped":
|
||||||
|
if isinstance(report.longrepr, tuple):
|
||||||
|
fspath, lineno, reason = report.longrepr
|
||||||
|
reason = f"{report.nodeid}: {_get_raw_skip_reason(report)}"
|
||||||
|
report.longrepr = (fspath, lineno, reason)
|
||||||
|
|
||||||
|
def node_reporter(self, report: Union[TestReport, str]) -> _NodeReporterReruns:
|
||||||
|
nodeid: Union[str, TestReport] = getattr(report, "nodeid", report)
|
||||||
|
# Local hack to handle xdist report order.
|
||||||
|
workernode = getattr(report, "node", None)
|
||||||
|
|
||||||
|
key = nodeid, workernode
|
||||||
|
|
||||||
|
if key in self.node_reporters:
|
||||||
|
# TODO: breaks for --dist=each
|
||||||
|
return self.node_reporters[key]
|
||||||
|
|
||||||
|
reporter = _NodeReporterReruns(nodeid, self)
|
||||||
|
|
||||||
|
self.node_reporters[key] = reporter
|
||||||
|
self.node_reporters_ordered.append(reporter)
|
||||||
|
|
||||||
|
return reporter
|
@ -1443,7 +1443,7 @@ class TestMathBits(TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# input strides and size may have been altered due to the result of an inplace op
|
# input strides and size may have been altered due to the result of an inplace op
|
||||||
def test_inplace_view(func, input, rs, input_size, input_strides):
|
def check_inplace_view(func, input, rs, input_size, input_strides):
|
||||||
if func is None:
|
if func is None:
|
||||||
return
|
return
|
||||||
# TODO: extend this test to test ops with multiple outputs and ops like native_batch_norm.out
|
# TODO: extend this test to test ops with multiple outputs and ops like native_batch_norm.out
|
||||||
@ -1470,7 +1470,7 @@ class TestTagsMode(TorchDispatchMode):
|
|||||||
old_size = args[0].size()
|
old_size = args[0].size()
|
||||||
old_stride = args[0].stride()
|
old_stride = args[0].stride()
|
||||||
rs = func(*args, **kwargs)
|
rs = func(*args, **kwargs)
|
||||||
test_inplace_view(func, args[0], rs, old_size, old_stride)
|
check_inplace_view(func, args[0], rs, old_size, old_stride)
|
||||||
else:
|
else:
|
||||||
rs = func(*args, **kwargs)
|
rs = func(*args, **kwargs)
|
||||||
return rs
|
return rs
|
||||||
@ -1492,7 +1492,7 @@ class TestTags(TestCase):
|
|||||||
# TODO: add test for aliases: https://github.com/pytorch/pytorch/issues/78761
|
# TODO: add test for aliases: https://github.com/pytorch/pytorch/issues/78761
|
||||||
aten_name = op.aten_name if op.aten_name is not None else op.name
|
aten_name = op.aten_name if op.aten_name is not None else op.name
|
||||||
opoverloadpacket = getattr(torch.ops.aten, aten_name, None)
|
opoverloadpacket = getattr(torch.ops.aten, aten_name, None)
|
||||||
test_inplace_view(opoverloadpacket, input, rs, old_size, old_stride)
|
check_inplace_view(opoverloadpacket, input, rs, old_size, old_stride)
|
||||||
|
|
||||||
|
|
||||||
class TestRefsOpsInfo(TestCase):
|
class TestRefsOpsInfo(TestCase):
|
||||||
|
@ -4,7 +4,7 @@ import sys
|
|||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
from tools.stats.upload_stats_lib import (
|
from tools.stats.upload_stats_lib import (
|
||||||
download_gha_artifacts,
|
download_gha_artifacts,
|
||||||
@ -14,6 +14,15 @@ from tools.stats.upload_stats_lib import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_job_id(report: Path) -> int:
|
||||||
|
# [Job id in artifacts]
|
||||||
|
# Retrieve the job id from the report path. In our GHA workflows, we append
|
||||||
|
# the job id to the end of the report name, so `report` looks like:
|
||||||
|
# unzipped-test-reports-foo_5596745227/test/test-reports/foo/TEST-foo.xml
|
||||||
|
# and we want to get `5596745227` out of it.
|
||||||
|
return int(report.parts[0].rpartition("_")[2])
|
||||||
|
|
||||||
|
|
||||||
def parse_xml_report(
|
def parse_xml_report(
|
||||||
tag: str,
|
tag: str,
|
||||||
report: Path,
|
report: Path,
|
||||||
@ -22,12 +31,8 @@ def parse_xml_report(
|
|||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""Convert a test report xml file into a JSON-serializable list of test cases."""
|
"""Convert a test report xml file into a JSON-serializable list of test cases."""
|
||||||
print(f"Parsing {tag}s for test report: {report}")
|
print(f"Parsing {tag}s for test report: {report}")
|
||||||
# [Job id in artifacts]
|
|
||||||
# Retrieve the job id from the report path. In our GHA workflows, we append
|
job_id = get_job_id(report)
|
||||||
# the job id to the end of the report name, so `report` looks like:
|
|
||||||
# unzipped-test-reports-foo_5596745227/test/test-reports/foo/TEST-foo.xml
|
|
||||||
# and we want to get `5596745227` out of it.
|
|
||||||
job_id = int(report.parts[0].rpartition("_")[2])
|
|
||||||
print(f"Found job id: {job_id}")
|
print(f"Found job id: {job_id}")
|
||||||
|
|
||||||
root = ET.parse(report)
|
root = ET.parse(report)
|
||||||
@ -112,7 +117,22 @@ def process_xml_element(element: ET.Element) -> Dict[str, Any]:
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> List[Dict[str, Any]]:
|
def get_pytest_parallel_times() -> Dict[Any, Any]:
|
||||||
|
pytest_parallel_times = {}
|
||||||
|
for report in Path(".").glob("**/python-pytest/**/*.xml"):
|
||||||
|
invoking_file = report.parent.name
|
||||||
|
root = ET.parse(report)
|
||||||
|
assert len(list(root.iter("testsuite"))) == 1
|
||||||
|
for test_suite in root.iter("testsuite"):
|
||||||
|
pytest_parallel_times[
|
||||||
|
(invoking_file, get_job_id(report))
|
||||||
|
] = test_suite.attrib["time"]
|
||||||
|
return pytest_parallel_times
|
||||||
|
|
||||||
|
|
||||||
|
def get_tests(
|
||||||
|
workflow_run_id: int, workflow_run_attempt: int
|
||||||
|
) -> Tuple[List[Dict[str, Any]], Dict[Any, Any]]:
|
||||||
with TemporaryDirectory() as temp_dir:
|
with TemporaryDirectory() as temp_dir:
|
||||||
print("Using temporary directory:", temp_dir)
|
print("Using temporary directory:", temp_dir)
|
||||||
os.chdir(temp_dir)
|
os.chdir(temp_dir)
|
||||||
@ -142,7 +162,44 @@ def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> List[Dict[str,
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return test_cases
|
pytest_parallel_times = get_pytest_parallel_times()
|
||||||
|
|
||||||
|
return test_cases, pytest_parallel_times
|
||||||
|
|
||||||
|
|
||||||
|
def get_invoking_file_times(
|
||||||
|
test_case_summaries: List[Dict[str, Any]], pytest_parallel_times: Dict[Any, Any]
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
def get_key(summary: Dict[str, Any]) -> Any:
|
||||||
|
return (
|
||||||
|
summary["invoking_file"],
|
||||||
|
summary["job_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def init_value(summary: Dict[str, Any]) -> Any:
|
||||||
|
return {
|
||||||
|
"job_id": summary["job_id"],
|
||||||
|
"workflow_id": summary["workflow_id"],
|
||||||
|
"workflow_run_attempt": summary["workflow_run_attempt"],
|
||||||
|
"invoking_file": summary["invoking_file"],
|
||||||
|
"time": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
ret = {}
|
||||||
|
for summary in test_case_summaries:
|
||||||
|
key = get_key(summary)
|
||||||
|
if key not in ret:
|
||||||
|
ret[key] = init_value(summary)
|
||||||
|
ret[key]["time"] += summary["time"]
|
||||||
|
|
||||||
|
for key, val in ret.items():
|
||||||
|
# when running in parallel in pytest, adding the test times will not give the correct
|
||||||
|
# time used to run the file, which will make the sharding incorrect, so if the test is
|
||||||
|
# run in parallel, we take the time reported by the testsuite
|
||||||
|
if key in pytest_parallel_times:
|
||||||
|
val["time"] = pytest_parallel_times[key]
|
||||||
|
|
||||||
|
return list(ret.values())
|
||||||
|
|
||||||
|
|
||||||
def summarize_test_cases(test_cases: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
def summarize_test_cases(test_cases: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||||
@ -220,18 +277,32 @@ if __name__ == "__main__":
|
|||||||
help="Head branch of the workflow",
|
help="Head branch of the workflow",
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
test_cases = get_tests(args.workflow_run_id, args.workflow_run_attempt)
|
test_cases, pytest_parallel_times = get_tests(
|
||||||
|
args.workflow_run_id, args.workflow_run_attempt
|
||||||
|
)
|
||||||
|
|
||||||
# Flush stdout so that any errors in rockset upload show up last in the logs.
|
# Flush stdout so that any errors in rockset upload show up last in the logs.
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
|
||||||
# For PRs, only upload a summary of test_runs. This helps lower the
|
# For PRs, only upload a summary of test_runs. This helps lower the
|
||||||
# volume of writes we do to Rockset.
|
# volume of writes we do to Rockset.
|
||||||
|
test_case_summary = summarize_test_cases(test_cases)
|
||||||
|
invoking_file_times = get_invoking_file_times(
|
||||||
|
test_case_summary, pytest_parallel_times
|
||||||
|
)
|
||||||
|
|
||||||
upload_to_s3(
|
upload_to_s3(
|
||||||
args.workflow_run_id,
|
args.workflow_run_id,
|
||||||
args.workflow_run_attempt,
|
args.workflow_run_attempt,
|
||||||
"test_run_summary",
|
"test_run_summary",
|
||||||
summarize_test_cases(test_cases),
|
test_case_summary,
|
||||||
|
)
|
||||||
|
|
||||||
|
upload_to_s3(
|
||||||
|
args.workflow_run_id,
|
||||||
|
args.workflow_run_attempt,
|
||||||
|
"invoking_file_times",
|
||||||
|
invoking_file_times,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.head_branch == "master":
|
if args.head_branch == "master":
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
IN_CI = os.environ.get("CI")
|
|
||||||
|
|
||||||
from tools.stats.upload_test_stats import get_tests, summarize_test_cases
|
from tools.stats.upload_test_stats import get_tests, summarize_test_cases
|
||||||
|
|
||||||
|
IN_CI = os.environ.get("CI")
|
||||||
|
|
||||||
|
|
||||||
class TestUploadTestStats(unittest.TestCase):
|
class TestUploadTestStats(unittest.TestCase):
|
||||||
@unittest.skipIf(
|
@unittest.skipIf(
|
||||||
@ -13,7 +13,7 @@ class TestUploadTestStats(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
def test_existing_job(self) -> None:
|
def test_existing_job(self) -> None:
|
||||||
"""Run on a known-good job and make sure we don't error and get basically okay reults."""
|
"""Run on a known-good job and make sure we don't error and get basically okay reults."""
|
||||||
test_cases = get_tests(2561394934, 1)
|
test_cases, _ = get_tests(2561394934, 1)
|
||||||
self.assertEqual(len(test_cases), 609873)
|
self.assertEqual(len(test_cases), 609873)
|
||||||
summary = summarize_test_cases(test_cases)
|
summary = summarize_test_cases(test_cases)
|
||||||
self.assertEqual(len(summary), 5068)
|
self.assertEqual(len(summary), 5068)
|
||||||
|
@ -75,6 +75,8 @@ from torch.onnx import (register_custom_op_symbolic,
|
|||||||
unregister_custom_op_symbolic)
|
unregister_custom_op_symbolic)
|
||||||
torch.backends.disable_global_flags()
|
torch.backends.disable_global_flags()
|
||||||
|
|
||||||
|
PYTEST_FILES = ["test_ops", "test_ops_gradients", "test_ops_jit"]
|
||||||
|
|
||||||
FILE_SCHEMA = "file://"
|
FILE_SCHEMA = "file://"
|
||||||
if sys.platform == 'win32':
|
if sys.platform == 'win32':
|
||||||
FILE_SCHEMA = "file:///"
|
FILE_SCHEMA = "file:///"
|
||||||
@ -91,9 +93,6 @@ MAX_NUM_RETRIES = 3
|
|||||||
DISABLED_TESTS_FILE = '.pytorch-disabled-tests.json'
|
DISABLED_TESTS_FILE = '.pytorch-disabled-tests.json'
|
||||||
SLOW_TESTS_FILE = '.pytorch-slow-tests.json'
|
SLOW_TESTS_FILE = '.pytorch-slow-tests.json'
|
||||||
|
|
||||||
slow_tests_dict: Optional[Dict[str, Any]] = None
|
|
||||||
disabled_tests_dict: Optional[Dict[str, Any]] = None
|
|
||||||
|
|
||||||
NATIVE_DEVICES = ('cpu', 'cuda', 'meta')
|
NATIVE_DEVICES = ('cpu', 'cuda', 'meta')
|
||||||
|
|
||||||
|
|
||||||
@ -590,20 +589,33 @@ def lint_test_case_extension(suite):
|
|||||||
succeed = False
|
succeed = False
|
||||||
return succeed
|
return succeed
|
||||||
|
|
||||||
|
def sanitize_pytest_xml(xml_file: str):
|
||||||
|
# pytext xml is different from unittext xml, this function makes pytest xml more similar to unittest xml
|
||||||
|
# consider somehow modifying the XML logger in conftest to do this instead
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
tree = ET.parse(xml_file)
|
||||||
|
for testcase in tree.iter('testcase'):
|
||||||
|
full_classname = testcase.attrib['classname']
|
||||||
|
regex_result = re.search(r"^test\.(.*)\.([^\.]*)$", full_classname)
|
||||||
|
classname = regex_result.group(2)
|
||||||
|
file = regex_result.group(1).replace('.', "/")
|
||||||
|
testcase.set('classname', classname)
|
||||||
|
testcase.set('file', f"{file}.py")
|
||||||
|
tree.write(xml_file)
|
||||||
|
|
||||||
def run_tests(argv=UNITTEST_ARGS):
|
def run_tests(argv=UNITTEST_ARGS):
|
||||||
# import test files.
|
# import test files.
|
||||||
if IMPORT_SLOW_TESTS:
|
if IMPORT_SLOW_TESTS:
|
||||||
if os.path.exists(IMPORT_SLOW_TESTS):
|
if os.path.exists(IMPORT_SLOW_TESTS):
|
||||||
global slow_tests_dict
|
|
||||||
with open(IMPORT_SLOW_TESTS, 'r') as fp:
|
with open(IMPORT_SLOW_TESTS, 'r') as fp:
|
||||||
slow_tests_dict = json.load(fp)
|
# use env vars so pytest-xdist subprocesses can still access them
|
||||||
|
os.environ['SLOW_TESTS_DICT'] = fp.read()
|
||||||
else:
|
else:
|
||||||
print(f'[WARNING] slow test file provided but not found: {IMPORT_SLOW_TESTS}')
|
print(f'[WARNING] slow test file provided but not found: {IMPORT_SLOW_TESTS}')
|
||||||
if IMPORT_DISABLED_TESTS:
|
if IMPORT_DISABLED_TESTS:
|
||||||
if os.path.exists(IMPORT_DISABLED_TESTS):
|
if os.path.exists(IMPORT_DISABLED_TESTS):
|
||||||
global disabled_tests_dict
|
|
||||||
with open(IMPORT_DISABLED_TESTS, 'r') as fp:
|
with open(IMPORT_DISABLED_TESTS, 'r') as fp:
|
||||||
disabled_tests_dict = json.load(fp)
|
os.environ['DISABLED_TESTS_DICT'] = fp.read()
|
||||||
else:
|
else:
|
||||||
print(f'[WARNING] disabled test file provided but not found: {IMPORT_DISABLED_TESTS}')
|
print(f'[WARNING] disabled test file provided but not found: {IMPORT_DISABLED_TESTS}')
|
||||||
# Determine the test launch mechanism
|
# Determine the test launch mechanism
|
||||||
@ -682,10 +694,32 @@ def run_tests(argv=UNITTEST_ARGS):
|
|||||||
test_filename = sanitize_test_filename(inspect.getfile(sys._getframe(1)))
|
test_filename = sanitize_test_filename(inspect.getfile(sys._getframe(1)))
|
||||||
test_report_path = TEST_SAVE_XML + LOG_SUFFIX
|
test_report_path = TEST_SAVE_XML + LOG_SUFFIX
|
||||||
test_report_path = os.path.join(test_report_path, test_filename)
|
test_report_path = os.path.join(test_report_path, test_filename)
|
||||||
|
if test_filename in PYTEST_FILES and not IS_SANDCASTLE and not (
|
||||||
|
"cuda" in os.environ["BUILD_ENVIRONMENT"] and "linux" in os.environ["BUILD_ENVIRONMENT"]
|
||||||
|
):
|
||||||
|
# exclude linux cuda tests because we run into memory issues when running in parallel
|
||||||
|
import pytest
|
||||||
|
os.environ["NO_COLOR"] = "1"
|
||||||
|
os.environ["USING_PYTEST"] = "1"
|
||||||
|
pytest_report_path = test_report_path.replace('python-unittest', 'python-pytest')
|
||||||
|
os.makedirs(pytest_report_path, exist_ok=True)
|
||||||
|
# part of our xml parsing looks for grandparent folder names
|
||||||
|
pytest_report_path = os.path.join(pytest_report_path, f"{test_filename}.xml")
|
||||||
|
print(f'Test results will be stored in {pytest_report_path}')
|
||||||
|
# mac slower on 4 proc than 3
|
||||||
|
num_procs = 3 if "macos" in os.environ["BUILD_ENVIRONMENT"] else 4
|
||||||
|
exit_code = pytest.main(args=[inspect.getfile(sys._getframe(1)), f'-n={num_procs}', '-vv', '-x',
|
||||||
|
'--reruns=2', '-rfEsX', f'--junit-xml-reruns={pytest_report_path}'])
|
||||||
|
del os.environ["USING_PYTEST"]
|
||||||
|
sanitize_pytest_xml(f'{pytest_report_path}')
|
||||||
|
# exitcode of 5 means no tests were found, which happens since some test configs don't
|
||||||
|
# run tests from certain files
|
||||||
|
exit(0 if exit_code == 5 else exit_code)
|
||||||
|
else:
|
||||||
os.makedirs(test_report_path, exist_ok=True)
|
os.makedirs(test_report_path, exist_ok=True)
|
||||||
verbose = '--verbose' in argv or '-v' in argv
|
verbose = '--verbose' in argv or '-v' in argv
|
||||||
if verbose:
|
if verbose:
|
||||||
print('Test results will be stored in {}'.format(test_report_path))
|
print(f'Test results will be stored in {test_report_path}')
|
||||||
unittest.main(argv=argv, testRunner=xmlrunner.XMLTestRunner(
|
unittest.main(argv=argv, testRunner=xmlrunner.XMLTestRunner(
|
||||||
output=test_report_path,
|
output=test_report_path,
|
||||||
verbosity=2 if verbose else 1,
|
verbosity=2 if verbose else 1,
|
||||||
@ -1504,13 +1538,16 @@ def remove_device_and_dtype_suffixes(test_name: str) -> str:
|
|||||||
|
|
||||||
def check_if_enable(test: unittest.TestCase):
|
def check_if_enable(test: unittest.TestCase):
|
||||||
test_suite = str(test.__class__).split('\'')[1]
|
test_suite = str(test.__class__).split('\'')[1]
|
||||||
|
if "USING_PYTEST" in os.environ:
|
||||||
|
test_suite = f"__main__.{test_suite.split('.')[1]}"
|
||||||
raw_test_name = f'{test._testMethodName} ({test_suite})'
|
raw_test_name = f'{test._testMethodName} ({test_suite})'
|
||||||
if slow_tests_dict is not None and raw_test_name in slow_tests_dict:
|
if raw_test_name in json.loads(os.environ.get("SLOW_TESTS_DICT", "{}")):
|
||||||
getattr(test, test._testMethodName).__dict__['slow_test'] = True
|
getattr(test, test._testMethodName).__dict__['slow_test'] = True
|
||||||
if not TEST_WITH_SLOW:
|
if not TEST_WITH_SLOW:
|
||||||
raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test")
|
raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test")
|
||||||
sanitized_test_method_name = remove_device_and_dtype_suffixes(test._testMethodName)
|
sanitized_test_method_name = remove_device_and_dtype_suffixes(test._testMethodName)
|
||||||
if not IS_SANDCASTLE and disabled_tests_dict is not None:
|
if not IS_SANDCASTLE and "DISABLED_TESTS_DICT" in os.environ:
|
||||||
|
disabled_tests_dict = json.loads(os.environ["DISABLED_TESTS_DICT"])
|
||||||
for disabled_test, (issue_url, platforms) in disabled_tests_dict.items():
|
for disabled_test, (issue_url, platforms) in disabled_tests_dict.items():
|
||||||
disable_test_parts = disabled_test.split()
|
disable_test_parts = disabled_test.split()
|
||||||
if len(disable_test_parts) > 1:
|
if len(disable_test_parts) > 1:
|
||||||
|
Reference in New Issue
Block a user