Use pytest to run test_ops, test_ops_gradients, test_ops_jit in non-Linux-CUDA environments (#79898)

This PR uses pytest to run test_ops, test_ops_gradients, and test_ops_jit in parallel in non-Linux-CUDA environments to decrease TTS (time to signal). Linux CUDA is excluded because running these files in parallel there results in out-of-memory errors.

Notes:
* update the hypothesis version for compatibility with pytest
* use pytest-rerunfailures to rerun failed tests (similar to the flaky-test handling, although these test files generally don't have flaky tests); see the invocation sketch after this list
  * reruns are denoted by a rerun tag in the XML; failed reruns also carry the failure tag, while successes (meaning the test is flaky) do not
* see https://docs.google.com/spreadsheets/d/1aO0Rbg3y3ch7ghipt63PG2KNEUppl9a5b18Hmv2CZ4E/edit#gid=602543594 for info on the speedup (or slowdown in the case of slow tests)
  * expecting Windows test time to decrease by 60 minutes in total
* the slow-test infra is expected to stay the same; verified by running pytest and unittest on the same job and checking the number of skipped/run tests
* test reports uploaded to S3 changed: an entirely new table is added to keep track of invoking_file times
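
As referenced above, a minimal sketch of the parallel invocation that run_tests performs for these files outside Linux CUDA. It assumes the new test/conftest.py plugin is picked up (that is what registers --junit-xml-reruns); the test file and report path below are illustrative only.

import pytest

# Sketch only: mirrors the flags added to common_utils.run_tests in this PR.
exit_code = pytest.main([
    "test/test_ops.py",      # one of PYTEST_FILES
    "-n=4",                  # pytest-xdist: 4 worker processes (3 on macOS)
    "-vv",
    "-x",                    # stop the whole run on the first hard failure
    "--reruns=2",            # pytest-rerunfailures: retry failing tests up to twice
    "-rfEsX",                # summarize failures, errors, skips, and xpasses
    "--junit-xml-reruns=test-reports/python-pytest/test_ops/test_ops.xml",  # illustrative path
])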
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79898
Approved by: https://github.com/malfet, https://github.com/janeyx99
Author: Catherine Lee
Date: 2022-07-19 19:50:57 +00:00
Committed by: PyTorch MergeBot
parent 9f9dd4f072
commit 06a0cfc0ea
9 changed files with 304 additions and 37 deletions

View File

@@ -41,7 +41,7 @@ flatbuffers==2.0
#Pinned versions:
#test that import:
hypothesis==4.53.2
hypothesis==5.35.1
# Pin hypothesis to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136
#Description: advanced library for generating parametrized tests
#Pinned versions: 3.44.6, 4.53.2
@@ -143,6 +143,16 @@ pytest
#Pinned versions:
#test that import: test_typing.py, test_cpp_extensions_aot.py, run_test.py
pytest-xdist
#Description: plugin for running pytest in parallel
#Pinned versions:
#test that import:
pytest-rerunfailures
#Description: plugin for rerunning tests in pytest
#Pinned versions:
#test that import:
#pytest-benchmark
#Description: fixture for benchmarking code
#Pinned versions: 3.2.3

View File

@@ -10,7 +10,9 @@ pip install -q hypothesis "expecttest==0.1.3" "librosa>=0.6.2" "numba<=0.49.1" p
# TODO move this to docker
# Pin unittest-xml-reporting to freeze printing test summary logic, related: https://github.com/pytorch/pytorch/issues/69014
pip install "unittest-xml-reporting<=3.2.0,>=2.0.0" \
  pytest
  pytest \
  pytest-xdist \
  pytest-rerunfailures

if [ -z "${CI}" ]; then
  rm -rf "${WORKSPACE_DIR}"/miniconda3/lib/python3.6/site-packages/torch*

View File

@@ -36,7 +36,7 @@ popd
=======
:: Pin unittest-xml-reporting to freeze printing test summary logic, related: https://github.com/pytorch/pytorch/issues/69014
pip install "ninja==1.10.0.post1" future "hypothesis==4.53.2" "expecttest==0.1.3" "librosa>=0.6.2" "scipy==1.6.3" psutil pillow "unittest-xml-reporting<=3.2.0,>=2.0.0" pytest
pip install "ninja==1.10.0.post1" future "hypothesis==5.35.1" "expecttest==0.1.3" "librosa>=0.6.2" "scipy==1.6.3" psutil pillow "unittest-xml-reporting<=3.2.0,>=2.0.0" pytest pytest-xdist pytest-rerunfailures
if errorlevel 1 exit /b
if not errorlevel 0 exit /b

View File

@@ -10,3 +10,4 @@ addopts =
    -Wd
testpaths =
    test
junit_logging_reruns = all

test/conftest.py (new file, 146 lines)
View File

@@ -0,0 +1,146 @@
from _pytest.junitxml import LogXML, _NodeReporter, bin_xml_escape
from _pytest.terminal import _get_raw_skip_reason
from _pytest.stash import StashKey
from _pytest.reports import TestReport
from _pytest.config.argparsing import Parser
from _pytest.config import filename_arg
from _pytest.config import Config
from _pytest._code.code import ReprFileLocation
from typing import Union
from typing import Optional
import xml.etree.ElementTree as ET
import functools

# a lot of this file is copied from _pytest.junitxml and modified to get rerun info

xml_key = StashKey["LogXMLReruns"]()


def pytest_addoption(parser: Parser) -> None:
    group = parser.getgroup("terminal reporting")
    group.addoption(
        "--junit-xml-reruns",
        action="store",
        dest="xmlpath_reruns",
        metavar="path",
        type=functools.partial(filename_arg, optname="--junit-xml-reruns"),
        default=None,
        help="create junit-xml style report file at given path.",
    )
    group.addoption(
        "--junit-prefix-reruns",
        action="store",
        metavar="str",
        default=None,
        help="prepend prefix to classnames in junit-xml output",
    )
    parser.addini(
        "junit_suite_name_reruns", "Test suite name for JUnit report", default="pytest"
    )
    parser.addini(
        "junit_logging_reruns",
        "Write captured log messages to JUnit report: "
        "one of no|log|system-out|system-err|out-err|all",
        default="no",
    )
    parser.addini(
        "junit_log_passing_tests_reruns",
        "Capture log information for passing tests to JUnit report: ",
        type="bool",
        default=True,
    )
    parser.addini(
        "junit_duration_report_reruns",
        "Duration time to report: one of total|call",
        default="total",
    )
    parser.addini(
        "junit_family_reruns",
        "Emit XML for schema: one of legacy|xunit1|xunit2",
        default="xunit2",
    )


def pytest_configure(config: Config) -> None:
    xmlpath = config.option.xmlpath_reruns
    # Prevent opening xmllog on worker nodes (xdist).
    if xmlpath and not hasattr(config, "workerinput"):
        junit_family = config.getini("junit_family_reruns")
        config.stash[xml_key] = LogXMLReruns(
            xmlpath,
            config.option.junitprefix,
            config.getini("junit_suite_name_reruns"),
            config.getini("junit_logging_reruns"),
            config.getini("junit_duration_report_reruns"),
            junit_family,
            config.getini("junit_log_passing_tests_reruns"),
        )
        config.pluginmanager.register(config.stash[xml_key])


def pytest_unconfigure(config: Config) -> None:
    xml = config.stash.get(xml_key, None)
    if xml:
        del config.stash[xml_key]
        config.pluginmanager.unregister(xml)


class _NodeReporterReruns(_NodeReporter):
    def _prepare_content(self, content: str, header: str) -> str:
        return content

    def _write_content(self, report: TestReport, content: str, jheader: str) -> None:
        if content == "":
            return
        tag = ET.Element(jheader)
        tag.text = bin_xml_escape(content)
        self.append(tag)


class LogXMLReruns(LogXML):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def append_rerun(self, reporter: _NodeReporter, report: TestReport) -> None:
        if hasattr(report, "wasxfail"):
            reporter._add_simple("skipped", "xfail-marked test passes unexpectedly")
        else:
            assert report.longrepr is not None
            reprcrash: Optional[ReprFileLocation] = getattr(
                report.longrepr, "reprcrash", None
            )
            if reprcrash is not None:
                message = reprcrash.message
            else:
                message = str(report.longrepr)
            message = bin_xml_escape(message)
            reporter._add_simple("rerun", message, str(report.longrepr))

    def pytest_runtest_logreport(self, report: TestReport) -> None:
        super().pytest_runtest_logreport(report)
        if report.outcome == "rerun":
            reporter = self._opentestcase(report)
            self.append_rerun(reporter, report)
        if report.outcome == "skipped":
            if isinstance(report.longrepr, tuple):
                fspath, lineno, reason = report.longrepr
                reason = f"{report.nodeid}: {_get_raw_skip_reason(report)}"
                report.longrepr = (fspath, lineno, reason)

    def node_reporter(self, report: Union[TestReport, str]) -> _NodeReporterReruns:
        nodeid: Union[str, TestReport] = getattr(report, "nodeid", report)
        # Local hack to handle xdist report order.
        workernode = getattr(report, "node", None)
        key = nodeid, workernode
        if key in self.node_reporters:
            # TODO: breaks for --dist=each
            return self.node_reporters[key]
        reporter = _NodeReporterReruns(nodeid, self)
        self.node_reporters[key] = reporter
        self.node_reporters_ordered.append(reporter)
        return reporter
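
As a usage sketch for the plugin above (the test file and report name are illustrative, and it assumes this conftest.py sits under the pytest rootdir so it gets loaded): drive it through --junit-xml-reruns, then inspect the rerun/failure tags described in the notes at the top.

import xml.etree.ElementTree as ET
import pytest

# Illustrative run; --junit-xml-reruns is the option registered by the plugin above.
pytest.main(["test/test_ops.py", "--reruns=2", "--junit-xml-reruns=results.xml"])

root = ET.parse("results.xml").getroot()
for case in root.iter("testcase"):
    reruns = case.findall("rerun")
    failures = case.findall("failure")
    if reruns and not failures:
        print(f"flaky (passed on rerun): {case.attrib['classname']}.{case.attrib['name']}")
    elif reruns and failures:
        print(f"still failing after reruns: {case.attrib['classname']}.{case.attrib['name']}")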

View File

@@ -1443,7 +1443,7 @@ class TestMathBits(TestCase):
        )

# input strides and size may have been altered due to the result of an inplace op
def test_inplace_view(func, input, rs, input_size, input_strides):
def check_inplace_view(func, input, rs, input_size, input_strides):
    if func is None:
        return
    # TODO: extend this test to test ops with multiple outputs and ops like native_batch_norm.out
@@ -1470,7 +1470,7 @@ class TestTagsMode(TorchDispatchMode):
            old_size = args[0].size()
            old_stride = args[0].stride()
            rs = func(*args, **kwargs)
            test_inplace_view(func, args[0], rs, old_size, old_stride)
            check_inplace_view(func, args[0], rs, old_size, old_stride)
        else:
            rs = func(*args, **kwargs)
        return rs
@@ -1492,7 +1492,7 @@ class TestTags(TestCase):
            # TODO: add test for aliases: https://github.com/pytorch/pytorch/issues/78761
            aten_name = op.aten_name if op.aten_name is not None else op.name
            opoverloadpacket = getattr(torch.ops.aten, aten_name, None)
            test_inplace_view(opoverloadpacket, input, rs, old_size, old_stride)
            check_inplace_view(opoverloadpacket, input, rs, old_size, old_stride)

class TestRefsOpsInfo(TestCase):
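
The rename above is presumably needed because pytest collects module-level functions matching its default python_functions pattern (test_*), so a helper named test_inplace_view would itself be picked up as a test and error out over its fixture-less arguments; check_inplace_view sidesteps that. A minimal sketch of the distinction (signatures copied from the diff, bodies elided):

# With pytest's default python_functions = "test_*":

def test_inplace_view(func, input, rs, input_size, input_strides):   # would be collected as a test
    ...

def check_inplace_view(func, input, rs, input_size, input_strides):  # plain helper, not collected
    ...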

View File

@@ -4,7 +4,7 @@ import sys
import xml.etree.ElementTree as ET
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, List
from typing import Any, Dict, List, Tuple

from tools.stats.upload_stats_lib import (
    download_gha_artifacts,
@@ -14,6 +14,15 @@ from tools.stats.upload_stats_lib import (
)


def get_job_id(report: Path) -> int:
    # [Job id in artifacts]
    # Retrieve the job id from the report path. In our GHA workflows, we append
    # the job id to the end of the report name, so `report` looks like:
    # unzipped-test-reports-foo_5596745227/test/test-reports/foo/TEST-foo.xml
    # and we want to get `5596745227` out of it.
    return int(report.parts[0].rpartition("_")[2])
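
A worked example of the path convention described in that comment, using the job id from the docstring itself:

from pathlib import Path

# The job id is the suffix after the last "_" in the artifact's top-level folder.
report = Path("unzipped-test-reports-foo_5596745227/test/test-reports/foo/TEST-foo.xml")
assert report.parts[0] == "unzipped-test-reports-foo_5596745227"
assert report.parts[0].rpartition("_")[2] == "5596745227"
assert int(report.parts[0].rpartition("_")[2]) == 5596745227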
def parse_xml_report(
    tag: str,
    report: Path,
@@ -22,12 +31,8 @@ def parse_xml_report(
) -> List[Dict[str, Any]]:
    """Convert a test report xml file into a JSON-serializable list of test cases."""
    print(f"Parsing {tag}s for test report: {report}")
    # [Job id in artifacts]
    # Retrieve the job id from the report path. In our GHA workflows, we append
    # the job id to the end of the report name, so `report` looks like:
    # unzipped-test-reports-foo_5596745227/test/test-reports/foo/TEST-foo.xml
    # and we want to get `5596745227` out of it.
    job_id = int(report.parts[0].rpartition("_")[2])
    job_id = get_job_id(report)
    print(f"Found job id: {job_id}")

    root = ET.parse(report)
@@ -112,7 +117,22 @@ def process_xml_element(element: ET.Element) -> Dict[str, Any]:
    return ret


def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> List[Dict[str, Any]]:
def get_pytest_parallel_times() -> Dict[Any, Any]:
    pytest_parallel_times = {}
    for report in Path(".").glob("**/python-pytest/**/*.xml"):
        invoking_file = report.parent.name
        root = ET.parse(report)
        assert len(list(root.iter("testsuite"))) == 1
        for test_suite in root.iter("testsuite"):
            pytest_parallel_times[
                (invoking_file, get_job_id(report))
            ] = test_suite.attrib["time"]
    return pytest_parallel_times
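
For context, a minimal sketch of the kind of report this walks (XML contents are illustrative): the per-file wall-clock time is the time attribute of the single testsuite element, keyed by (invoking_file, job_id).

import xml.etree.ElementTree as ET

# Illustrative report produced by the pytest run for one invoking file.
sample = """<testsuites>
  <testsuite name="pytest" tests="3" time="123.45">
    <testcase classname="TestCommonCPU" name="test_out_add_cpu" time="60.0"/>
  </testsuite>
</testsuites>"""

root = ET.fromstring(sample)
(suite,) = list(root.iter("testsuite"))
print(suite.attrib["time"])  # "123.45" -- the file's wall-clock time, not a sum of testcase times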
def get_tests(
    workflow_run_id: int, workflow_run_attempt: int
) -> Tuple[List[Dict[str, Any]], Dict[Any, Any]]:
    with TemporaryDirectory() as temp_dir:
        print("Using temporary directory:", temp_dir)
        os.chdir(temp_dir)
@@ -142,7 +162,44 @@ def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> List[Dict[str,
                )
            )

        return test_cases

        pytest_parallel_times = get_pytest_parallel_times()

        return test_cases, pytest_parallel_times


def get_invoking_file_times(
    test_case_summaries: List[Dict[str, Any]], pytest_parallel_times: Dict[Any, Any]
) -> List[Dict[str, Any]]:
    def get_key(summary: Dict[str, Any]) -> Any:
        return (
            summary["invoking_file"],
            summary["job_id"],
        )

    def init_value(summary: Dict[str, Any]) -> Any:
        return {
            "job_id": summary["job_id"],
            "workflow_id": summary["workflow_id"],
            "workflow_run_attempt": summary["workflow_run_attempt"],
            "invoking_file": summary["invoking_file"],
            "time": 0.0,
        }

    ret = {}
    for summary in test_case_summaries:
        key = get_key(summary)
        if key not in ret:
            ret[key] = init_value(summary)
        ret[key]["time"] += summary["time"]

    for key, val in ret.items():
        # when running in parallel in pytest, adding the test times will not give the correct
        # time used to run the file, which will make the sharding incorrect, so if the test is
        # run in parallel, we take the time reported by the testsuite
        if key in pytest_parallel_times:
            val["time"] = pytest_parallel_times[key]

    return list(ret.values())
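
A toy illustration of the correction described in the comment above: when a file's tests run concurrently under pytest-xdist, summing per-test durations overstates the file's wall-clock time, so the testsuite time wins when available. The numbers and keys below are made up.

# Two tests that each took ~60s but ran concurrently on two xdist workers:
summed = 60.0 + 60.0                                        # 120.0 -- what summing testcase times yields
pytest_parallel_times = {("test_ops", 5596745227): 63.5}    # from the testsuite "time" attribute

key = ("test_ops", 5596745227)
file_time = pytest_parallel_times.get(key, summed)
print(file_time)  # 63.5 -- the wall-clock value that should drive sharding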
def summarize_test_cases(test_cases: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
@@ -220,18 +277,32 @@ if __name__ == "__main__":
        help="Head branch of the workflow",
    )
    args = parser.parse_args()
    test_cases = get_tests(args.workflow_run_id, args.workflow_run_attempt)
    test_cases, pytest_parallel_times = get_tests(
        args.workflow_run_id, args.workflow_run_attempt
    )

    # Flush stdout so that any errors in rockset upload show up last in the logs.
    sys.stdout.flush()

    # For PRs, only upload a summary of test_runs. This helps lower the
    # volume of writes we do to Rockset.
    test_case_summary = summarize_test_cases(test_cases)
    invoking_file_times = get_invoking_file_times(
        test_case_summary, pytest_parallel_times
    )

    upload_to_s3(
        args.workflow_run_id,
        args.workflow_run_attempt,
        "test_run_summary",
        summarize_test_cases(test_cases),
        test_case_summary,
    )

    upload_to_s3(
        args.workflow_run_id,
        args.workflow_run_attempt,
        "invoking_file_times",
        invoking_file_times,
    )

    if args.head_branch == "master":

View File

@@ -1,10 +1,10 @@
import os
import unittest

IN_CI = os.environ.get("CI")

from tools.stats.upload_test_stats import get_tests, summarize_test_cases

IN_CI = os.environ.get("CI")


class TestUploadTestStats(unittest.TestCase):
    @unittest.skipIf(
@@ -13,7 +13,7 @@ class TestUploadTestStats(unittest.TestCase):
    )
    def test_existing_job(self) -> None:
        """Run on a known-good job and make sure we don't error and get basically okay results."""
        test_cases = get_tests(2561394934, 1)
        test_cases, _ = get_tests(2561394934, 1)
        self.assertEqual(len(test_cases), 609873)

        summary = summarize_test_cases(test_cases)
        self.assertEqual(len(summary), 5068)

View File

@@ -75,6 +75,8 @@ from torch.onnx import (register_custom_op_symbolic,
                        unregister_custom_op_symbolic)
torch.backends.disable_global_flags()

PYTEST_FILES = ["test_ops", "test_ops_gradients", "test_ops_jit"]

FILE_SCHEMA = "file://"
if sys.platform == 'win32':
    FILE_SCHEMA = "file:///"
@@ -91,9 +93,6 @@ MAX_NUM_RETRIES = 3
DISABLED_TESTS_FILE = '.pytorch-disabled-tests.json'
SLOW_TESTS_FILE = '.pytorch-slow-tests.json'

slow_tests_dict: Optional[Dict[str, Any]] = None
disabled_tests_dict: Optional[Dict[str, Any]] = None

NATIVE_DEVICES = ('cpu', 'cuda', 'meta')
@@ -590,20 +589,33 @@ def lint_test_case_extension(suite):
            succeed = False
    return succeed


def sanitize_pytest_xml(xml_file: str):
    # pytest xml is different from unittest xml, this function makes pytest xml more similar to unittest xml
    # consider somehow modifying the XML logger in conftest to do this instead
    import xml.etree.ElementTree as ET
    tree = ET.parse(xml_file)
    for testcase in tree.iter('testcase'):
        full_classname = testcase.attrib['classname']
        regex_result = re.search(r"^test\.(.*)\.([^\.]*)$", full_classname)
        classname = regex_result.group(2)
        file = regex_result.group(1).replace('.', "/")
        testcase.set('classname', classname)
        testcase.set('file', f"{file}.py")
    tree.write(xml_file)
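
A worked example of the classname rewrite above (the input classname is illustrative of what pytest emits for tests under test/ when run from the repo root):

import re

# pytest reports e.g. classname="test.test_ops.TestCommonCPU";
# unittest-style reports want classname="TestCommonCPU" and file="test_ops.py".
full_classname = "test.test_ops.TestCommonCPU"
m = re.search(r"^test\.(.*)\.([^\.]*)$", full_classname)
print(m.group(2))                            # TestCommonCPU
print(m.group(1).replace(".", "/") + ".py")  # test_ops.py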
def run_tests(argv=UNITTEST_ARGS):
    # import test files.
    if IMPORT_SLOW_TESTS:
        if os.path.exists(IMPORT_SLOW_TESTS):
            global slow_tests_dict
            with open(IMPORT_SLOW_TESTS, 'r') as fp:
                slow_tests_dict = json.load(fp)
                # use env vars so pytest-xdist subprocesses can still access them
                os.environ['SLOW_TESTS_DICT'] = fp.read()
        else:
            print(f'[WARNING] slow test file provided but not found: {IMPORT_SLOW_TESTS}')
    if IMPORT_DISABLED_TESTS:
        if os.path.exists(IMPORT_DISABLED_TESTS):
            global disabled_tests_dict
            with open(IMPORT_DISABLED_TESTS, 'r') as fp:
                disabled_tests_dict = json.load(fp)
                os.environ['DISABLED_TESTS_DICT'] = fp.read()
        else:
            print(f'[WARNING] disabled test file provided but not found: {IMPORT_DISABLED_TESTS}')
    # Determine the test launch mechanism
@@ -682,10 +694,32 @@ def run_tests(argv=UNITTEST_ARGS):
        test_filename = sanitize_test_filename(inspect.getfile(sys._getframe(1)))
        test_report_path = TEST_SAVE_XML + LOG_SUFFIX
        test_report_path = os.path.join(test_report_path, test_filename)
        if test_filename in PYTEST_FILES and not IS_SANDCASTLE and not (
            "cuda" in os.environ["BUILD_ENVIRONMENT"] and "linux" in os.environ["BUILD_ENVIRONMENT"]
        ):
            # exclude linux cuda tests because we run into memory issues when running in parallel
            import pytest
            os.environ["NO_COLOR"] = "1"
            os.environ["USING_PYTEST"] = "1"
            pytest_report_path = test_report_path.replace('python-unittest', 'python-pytest')
            os.makedirs(pytest_report_path, exist_ok=True)
            # part of our xml parsing looks for grandparent folder names
            pytest_report_path = os.path.join(pytest_report_path, f"{test_filename}.xml")
            print(f'Test results will be stored in {pytest_report_path}')
            # mac slower on 4 proc than 3
            num_procs = 3 if "macos" in os.environ["BUILD_ENVIRONMENT"] else 4
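            # Flags below: -n (pytest-xdist) sets the number of worker processes,
            # -x stops the whole run on the first unrecoverable failure,
            # --reruns=2 (pytest-rerunfailures) retries each failing test up to twice,
            # -rfEsX prints a short summary of failures, errors, skips, and xpasses,
            # and --junit-xml-reruns is the report path handled by test/conftest.py.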
            exit_code = pytest.main(args=[inspect.getfile(sys._getframe(1)), f'-n={num_procs}', '-vv', '-x',
                                          '--reruns=2', '-rfEsX', f'--junit-xml-reruns={pytest_report_path}'])
            del os.environ["USING_PYTEST"]
            sanitize_pytest_xml(f'{pytest_report_path}')
            # exitcode of 5 means no tests were found, which happens since some test configs don't
            # run tests from certain files
            exit(0 if exit_code == 5 else exit_code)
        else:
            os.makedirs(test_report_path, exist_ok=True)
            verbose = '--verbose' in argv or '-v' in argv
            if verbose:
                print('Test results will be stored in {}'.format(test_report_path))
                print(f'Test results will be stored in {test_report_path}')
            unittest.main(argv=argv, testRunner=xmlrunner.XMLTestRunner(
                output=test_report_path,
                verbosity=2 if verbose else 1,
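
For context on the exit(0 if exit_code == 5 else exit_code) line above: pytest exit status 5 means "no tests were collected", which legitimately happens when a shard's test config selects nothing from one of these files. A small sketch using the enum recent pytest versions expose as pytest.ExitCode:

import pytest

# NO_TESTS_COLLECTED is exit status 5; treat it as success, mirroring run_tests above.
exit_code = pytest.ExitCode.NO_TESTS_COLLECTED
print(int(exit_code))                            # 5
print(0 if exit_code == 5 else int(exit_code))   # 0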
@@ -1504,13 +1538,16 @@ def remove_device_and_dtype_suffixes(test_name: str) -> str:

def check_if_enable(test: unittest.TestCase):
    test_suite = str(test.__class__).split('\'')[1]
    if "USING_PYTEST" in os.environ:
        test_suite = f"__main__.{test_suite.split('.')[1]}"
    raw_test_name = f'{test._testMethodName} ({test_suite})'
    if slow_tests_dict is not None and raw_test_name in slow_tests_dict:
    if raw_test_name in json.loads(os.environ.get("SLOW_TESTS_DICT", "{}")):
        getattr(test, test._testMethodName).__dict__['slow_test'] = True
        if not TEST_WITH_SLOW:
            raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test")

    sanitized_test_method_name = remove_device_and_dtype_suffixes(test._testMethodName)
    if not IS_SANDCASTLE and disabled_tests_dict is not None:
    if not IS_SANDCASTLE and "DISABLED_TESTS_DICT" in os.environ:
        disabled_tests_dict = json.loads(os.environ["DISABLED_TESTS_DICT"])
        for disabled_test, (issue_url, platforms) in disabled_tests_dict.items():
            disable_test_parts = disabled_test.split()
            if len(disable_test_parts) > 1:
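
For illustration, a standalone sketch of the pattern the diff above switches to: the parent process stashes the slow/disabled-test JSON in environment variables so that pytest-xdist worker processes, which re-import the test module, can read them without module-level globals. The test name and value here are hypothetical.

import json
import os

# Parent process (run_tests): stash the raw JSON in the environment.
os.environ["SLOW_TESTS_DICT"] = json.dumps(
    {"test_foo_cuda (__main__.TestFooCUDA)": 120.0}  # hypothetical entry: test name -> runtime in seconds
)

# Any process (including an xdist worker) can later do what check_if_enable does:
raw_test_name = "test_foo_cuda (__main__.TestFooCUDA)"
slow_tests = json.loads(os.environ.get("SLOW_TESTS_DICT", "{}"))
if raw_test_name in slow_tests:
    print("would mark as slow_test and skip unless PYTORCH_TEST_WITH_SLOW is set")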