mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
pytest to run test_ops, test_ops_gradients, test_ops_jit in non linux cuda environments (#79898)
This PR uses pytest to run test_ops, test_ops_gradients, and test_ops_jit in parallel in non linux cuda environments to decrease TTS. I am excluding linux cuda because running in parallel results in errors due to running out of memory Notes: * update hypothesis version for compatability with pytest * use rerun-failures to rerun tests (similar to flaky tests, although these test files generally don't have flaky tests) * reruns are denoted by a rerun tag in the xml. Failed reruns also have the failure tag. Successes (meaning that the test is flaky) do not have the failure tag. * see https://docs.google.com/spreadsheets/d/1aO0Rbg3y3ch7ghipt63PG2KNEUppl9a5b18Hmv2CZ4E/edit#gid=602543594 for info on speedup (or slowdown in the case of slow tests) * expecting windows tests to decrease by 60 minutes total * slow test infra is expected to stay the same - verified by running pytest and unittest on the same job and check the number of skipped/run tests * test reports to s3 changed - add entirely new table to keep track of invoking_file times Pull Request resolved: https://github.com/pytorch/pytorch/pull/79898 Approved by: https://github.com/malfet, https://github.com/janeyx99
This commit is contained in:
committed by
PyTorch MergeBot
parent
9f9dd4f072
commit
06a0cfc0ea
@ -41,7 +41,7 @@ flatbuffers==2.0
|
||||
#Pinned versions:
|
||||
#test that import:
|
||||
|
||||
hypothesis==4.53.2
|
||||
hypothesis==5.35.1
|
||||
# Pin hypothesis to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136
|
||||
#Description: advanced library for generating parametrized tests
|
||||
#Pinned versions: 3.44.6, 4.53.2
|
||||
@ -143,6 +143,16 @@ pytest
|
||||
#Pinned versions:
|
||||
#test that import: test_typing.py, test_cpp_extensions_aot.py, run_test.py
|
||||
|
||||
pytest-xdist
|
||||
#Description: plugin for running pytest in parallel
|
||||
#Pinned versions:
|
||||
#test that import:
|
||||
|
||||
pytest-rerunfailures
|
||||
#Description: plugin for rerunning tests in pytest
|
||||
#Pinned versions:
|
||||
#test that import:
|
||||
|
||||
#pytest-benchmark
|
||||
#Description: fixture for benchmarking code
|
||||
#Pinned versions: 3.2.3
|
||||
|
@ -10,7 +10,9 @@ pip install -q hypothesis "expecttest==0.1.3" "librosa>=0.6.2" "numba<=0.49.1" p
|
||||
# TODO move this to docker
|
||||
# Pin unittest-xml-reporting to freeze printing test summary logic, related: https://github.com/pytorch/pytorch/issues/69014
|
||||
pip install "unittest-xml-reporting<=3.2.0,>=2.0.0" \
|
||||
pytest
|
||||
pytest \
|
||||
pytest-xdist \
|
||||
pytest-rerunfailures
|
||||
|
||||
if [ -z "${CI}" ]; then
|
||||
rm -rf "${WORKSPACE_DIR}"/miniconda3/lib/python3.6/site-packages/torch*
|
||||
|
@ -36,7 +36,7 @@ popd
|
||||
=======
|
||||
:: Pin unittest-xml-reporting to freeze printing test summary logic, related: https://github.com/pytorch/pytorch/issues/69014
|
||||
|
||||
pip install "ninja==1.10.0.post1" future "hypothesis==4.53.2" "expecttest==0.1.3" "librosa>=0.6.2" "scipy==1.6.3" psutil pillow "unittest-xml-reporting<=3.2.0,>=2.0.0" pytest
|
||||
pip install "ninja==1.10.0.post1" future "hypothesis==5.35.1" "expecttest==0.1.3" "librosa>=0.6.2" "scipy==1.6.3" psutil pillow "unittest-xml-reporting<=3.2.0,>=2.0.0" pytest pytest-xdist pytest-rerunfailures
|
||||
if errorlevel 1 exit /b
|
||||
if not errorlevel 0 exit /b
|
||||
|
||||
|
@ -10,3 +10,4 @@ addopts =
|
||||
-Wd
|
||||
testpaths =
|
||||
test
|
||||
junit_logging_reruns = all
|
||||
|
146
test/conftest.py
Normal file
146
test/conftest.py
Normal file
@ -0,0 +1,146 @@
|
||||
from _pytest.junitxml import LogXML, _NodeReporter, bin_xml_escape
|
||||
from _pytest.terminal import _get_raw_skip_reason
|
||||
from _pytest.stash import StashKey
|
||||
from _pytest.reports import TestReport
|
||||
from _pytest.config.argparsing import Parser
|
||||
from _pytest.config import filename_arg
|
||||
from _pytest.config import Config
|
||||
from _pytest._code.code import ReprFileLocation
|
||||
from typing import Union
|
||||
from typing import Optional
|
||||
import xml.etree.ElementTree as ET
|
||||
import functools
|
||||
|
||||
# a lot of this file is copied from _pytest.junitxml and modified to get rerun info
|
||||
|
||||
xml_key = StashKey["LogXMLReruns"]()
|
||||
|
||||
|
||||
def pytest_addoption(parser: Parser) -> None:
|
||||
group = parser.getgroup("terminal reporting")
|
||||
group.addoption(
|
||||
"--junit-xml-reruns",
|
||||
action="store",
|
||||
dest="xmlpath_reruns",
|
||||
metavar="path",
|
||||
type=functools.partial(filename_arg, optname="--junit-xml-reruns"),
|
||||
default=None,
|
||||
help="create junit-xml style report file at given path.",
|
||||
)
|
||||
group.addoption(
|
||||
"--junit-prefix-reruns",
|
||||
action="store",
|
||||
metavar="str",
|
||||
default=None,
|
||||
help="prepend prefix to classnames in junit-xml output",
|
||||
)
|
||||
parser.addini(
|
||||
"junit_suite_name_reruns", "Test suite name for JUnit report", default="pytest"
|
||||
)
|
||||
parser.addini(
|
||||
"junit_logging_reruns",
|
||||
"Write captured log messages to JUnit report: "
|
||||
"one of no|log|system-out|system-err|out-err|all",
|
||||
default="no",
|
||||
)
|
||||
parser.addini(
|
||||
"junit_log_passing_tests_reruns",
|
||||
"Capture log information for passing tests to JUnit report: ",
|
||||
type="bool",
|
||||
default=True,
|
||||
)
|
||||
parser.addini(
|
||||
"junit_duration_report_reruns",
|
||||
"Duration time to report: one of total|call",
|
||||
default="total",
|
||||
)
|
||||
parser.addini(
|
||||
"junit_family_reruns",
|
||||
"Emit XML for schema: one of legacy|xunit1|xunit2",
|
||||
default="xunit2",
|
||||
)
|
||||
|
||||
|
||||
def pytest_configure(config: Config) -> None:
|
||||
xmlpath = config.option.xmlpath_reruns
|
||||
# Prevent opening xmllog on worker nodes (xdist).
|
||||
if xmlpath and not hasattr(config, "workerinput"):
|
||||
junit_family = config.getini("junit_family_reruns")
|
||||
config.stash[xml_key] = LogXMLReruns(
|
||||
xmlpath,
|
||||
config.option.junitprefix,
|
||||
config.getini("junit_suite_name_reruns"),
|
||||
config.getini("junit_logging_reruns"),
|
||||
config.getini("junit_duration_report_reruns"),
|
||||
junit_family,
|
||||
config.getini("junit_log_passing_tests_reruns"),
|
||||
)
|
||||
config.pluginmanager.register(config.stash[xml_key])
|
||||
|
||||
|
||||
def pytest_unconfigure(config: Config) -> None:
|
||||
xml = config.stash.get(xml_key, None)
|
||||
if xml:
|
||||
del config.stash[xml_key]
|
||||
config.pluginmanager.unregister(xml)
|
||||
|
||||
|
||||
class _NodeReporterReruns(_NodeReporter):
|
||||
def _prepare_content(self, content: str, header: str) -> str:
|
||||
return content
|
||||
|
||||
def _write_content(self, report: TestReport, content: str, jheader: str) -> None:
|
||||
if content == "":
|
||||
return
|
||||
tag = ET.Element(jheader)
|
||||
tag.text = bin_xml_escape(content)
|
||||
self.append(tag)
|
||||
|
||||
|
||||
class LogXMLReruns(LogXML):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def append_rerun(self, reporter: _NodeReporter, report: TestReport) -> None:
|
||||
if hasattr(report, "wasxfail"):
|
||||
reporter._add_simple("skipped", "xfail-marked test passes unexpectedly")
|
||||
else:
|
||||
assert report.longrepr is not None
|
||||
reprcrash: Optional[ReprFileLocation] = getattr(
|
||||
report.longrepr, "reprcrash", None
|
||||
)
|
||||
if reprcrash is not None:
|
||||
message = reprcrash.message
|
||||
else:
|
||||
message = str(report.longrepr)
|
||||
message = bin_xml_escape(message)
|
||||
reporter._add_simple("rerun", message, str(report.longrepr))
|
||||
|
||||
def pytest_runtest_logreport(self, report: TestReport) -> None:
|
||||
super().pytest_runtest_logreport(report)
|
||||
if report.outcome == "rerun":
|
||||
reporter = self._opentestcase(report)
|
||||
self.append_rerun(reporter, report)
|
||||
if report.outcome == "skipped":
|
||||
if isinstance(report.longrepr, tuple):
|
||||
fspath, lineno, reason = report.longrepr
|
||||
reason = f"{report.nodeid}: {_get_raw_skip_reason(report)}"
|
||||
report.longrepr = (fspath, lineno, reason)
|
||||
|
||||
def node_reporter(self, report: Union[TestReport, str]) -> _NodeReporterReruns:
|
||||
nodeid: Union[str, TestReport] = getattr(report, "nodeid", report)
|
||||
# Local hack to handle xdist report order.
|
||||
workernode = getattr(report, "node", None)
|
||||
|
||||
key = nodeid, workernode
|
||||
|
||||
if key in self.node_reporters:
|
||||
# TODO: breaks for --dist=each
|
||||
return self.node_reporters[key]
|
||||
|
||||
reporter = _NodeReporterReruns(nodeid, self)
|
||||
|
||||
self.node_reporters[key] = reporter
|
||||
self.node_reporters_ordered.append(reporter)
|
||||
|
||||
return reporter
|
@ -1443,7 +1443,7 @@ class TestMathBits(TestCase):
|
||||
)
|
||||
|
||||
# input strides and size may have been altered due to the result of an inplace op
|
||||
def test_inplace_view(func, input, rs, input_size, input_strides):
|
||||
def check_inplace_view(func, input, rs, input_size, input_strides):
|
||||
if func is None:
|
||||
return
|
||||
# TODO: extend this test to test ops with multiple outputs and ops like native_batch_norm.out
|
||||
@ -1470,7 +1470,7 @@ class TestTagsMode(TorchDispatchMode):
|
||||
old_size = args[0].size()
|
||||
old_stride = args[0].stride()
|
||||
rs = func(*args, **kwargs)
|
||||
test_inplace_view(func, args[0], rs, old_size, old_stride)
|
||||
check_inplace_view(func, args[0], rs, old_size, old_stride)
|
||||
else:
|
||||
rs = func(*args, **kwargs)
|
||||
return rs
|
||||
@ -1492,7 +1492,7 @@ class TestTags(TestCase):
|
||||
# TODO: add test for aliases: https://github.com/pytorch/pytorch/issues/78761
|
||||
aten_name = op.aten_name if op.aten_name is not None else op.name
|
||||
opoverloadpacket = getattr(torch.ops.aten, aten_name, None)
|
||||
test_inplace_view(opoverloadpacket, input, rs, old_size, old_stride)
|
||||
check_inplace_view(opoverloadpacket, input, rs, old_size, old_stride)
|
||||
|
||||
|
||||
class TestRefsOpsInfo(TestCase):
|
||||
|
@ -4,7 +4,7 @@ import sys
|
||||
import xml.etree.ElementTree as ET
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from tools.stats.upload_stats_lib import (
|
||||
download_gha_artifacts,
|
||||
@ -14,6 +14,15 @@ from tools.stats.upload_stats_lib import (
|
||||
)
|
||||
|
||||
|
||||
def get_job_id(report: Path) -> int:
|
||||
# [Job id in artifacts]
|
||||
# Retrieve the job id from the report path. In our GHA workflows, we append
|
||||
# the job id to the end of the report name, so `report` looks like:
|
||||
# unzipped-test-reports-foo_5596745227/test/test-reports/foo/TEST-foo.xml
|
||||
# and we want to get `5596745227` out of it.
|
||||
return int(report.parts[0].rpartition("_")[2])
|
||||
|
||||
|
||||
def parse_xml_report(
|
||||
tag: str,
|
||||
report: Path,
|
||||
@ -22,12 +31,8 @@ def parse_xml_report(
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Convert a test report xml file into a JSON-serializable list of test cases."""
|
||||
print(f"Parsing {tag}s for test report: {report}")
|
||||
# [Job id in artifacts]
|
||||
# Retrieve the job id from the report path. In our GHA workflows, we append
|
||||
# the job id to the end of the report name, so `report` looks like:
|
||||
# unzipped-test-reports-foo_5596745227/test/test-reports/foo/TEST-foo.xml
|
||||
# and we want to get `5596745227` out of it.
|
||||
job_id = int(report.parts[0].rpartition("_")[2])
|
||||
|
||||
job_id = get_job_id(report)
|
||||
print(f"Found job id: {job_id}")
|
||||
|
||||
root = ET.parse(report)
|
||||
@ -112,7 +117,22 @@ def process_xml_element(element: ET.Element) -> Dict[str, Any]:
|
||||
return ret
|
||||
|
||||
|
||||
def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> List[Dict[str, Any]]:
|
||||
def get_pytest_parallel_times() -> Dict[Any, Any]:
|
||||
pytest_parallel_times = {}
|
||||
for report in Path(".").glob("**/python-pytest/**/*.xml"):
|
||||
invoking_file = report.parent.name
|
||||
root = ET.parse(report)
|
||||
assert len(list(root.iter("testsuite"))) == 1
|
||||
for test_suite in root.iter("testsuite"):
|
||||
pytest_parallel_times[
|
||||
(invoking_file, get_job_id(report))
|
||||
] = test_suite.attrib["time"]
|
||||
return pytest_parallel_times
|
||||
|
||||
|
||||
def get_tests(
|
||||
workflow_run_id: int, workflow_run_attempt: int
|
||||
) -> Tuple[List[Dict[str, Any]], Dict[Any, Any]]:
|
||||
with TemporaryDirectory() as temp_dir:
|
||||
print("Using temporary directory:", temp_dir)
|
||||
os.chdir(temp_dir)
|
||||
@ -142,7 +162,44 @@ def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> List[Dict[str,
|
||||
)
|
||||
)
|
||||
|
||||
return test_cases
|
||||
pytest_parallel_times = get_pytest_parallel_times()
|
||||
|
||||
return test_cases, pytest_parallel_times
|
||||
|
||||
|
||||
def get_invoking_file_times(
|
||||
test_case_summaries: List[Dict[str, Any]], pytest_parallel_times: Dict[Any, Any]
|
||||
) -> List[Dict[str, Any]]:
|
||||
def get_key(summary: Dict[str, Any]) -> Any:
|
||||
return (
|
||||
summary["invoking_file"],
|
||||
summary["job_id"],
|
||||
)
|
||||
|
||||
def init_value(summary: Dict[str, Any]) -> Any:
|
||||
return {
|
||||
"job_id": summary["job_id"],
|
||||
"workflow_id": summary["workflow_id"],
|
||||
"workflow_run_attempt": summary["workflow_run_attempt"],
|
||||
"invoking_file": summary["invoking_file"],
|
||||
"time": 0.0,
|
||||
}
|
||||
|
||||
ret = {}
|
||||
for summary in test_case_summaries:
|
||||
key = get_key(summary)
|
||||
if key not in ret:
|
||||
ret[key] = init_value(summary)
|
||||
ret[key]["time"] += summary["time"]
|
||||
|
||||
for key, val in ret.items():
|
||||
# when running in parallel in pytest, adding the test times will not give the correct
|
||||
# time used to run the file, which will make the sharding incorrect, so if the test is
|
||||
# run in parallel, we take the time reported by the testsuite
|
||||
if key in pytest_parallel_times:
|
||||
val["time"] = pytest_parallel_times[key]
|
||||
|
||||
return list(ret.values())
|
||||
|
||||
|
||||
def summarize_test_cases(test_cases: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
@ -220,18 +277,32 @@ if __name__ == "__main__":
|
||||
help="Head branch of the workflow",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
test_cases = get_tests(args.workflow_run_id, args.workflow_run_attempt)
|
||||
test_cases, pytest_parallel_times = get_tests(
|
||||
args.workflow_run_id, args.workflow_run_attempt
|
||||
)
|
||||
|
||||
# Flush stdout so that any errors in rockset upload show up last in the logs.
|
||||
sys.stdout.flush()
|
||||
|
||||
# For PRs, only upload a summary of test_runs. This helps lower the
|
||||
# volume of writes we do to Rockset.
|
||||
test_case_summary = summarize_test_cases(test_cases)
|
||||
invoking_file_times = get_invoking_file_times(
|
||||
test_case_summary, pytest_parallel_times
|
||||
)
|
||||
|
||||
upload_to_s3(
|
||||
args.workflow_run_id,
|
||||
args.workflow_run_attempt,
|
||||
"test_run_summary",
|
||||
summarize_test_cases(test_cases),
|
||||
test_case_summary,
|
||||
)
|
||||
|
||||
upload_to_s3(
|
||||
args.workflow_run_id,
|
||||
args.workflow_run_attempt,
|
||||
"invoking_file_times",
|
||||
invoking_file_times,
|
||||
)
|
||||
|
||||
if args.head_branch == "master":
|
||||
|
@ -1,10 +1,10 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
IN_CI = os.environ.get("CI")
|
||||
|
||||
from tools.stats.upload_test_stats import get_tests, summarize_test_cases
|
||||
|
||||
IN_CI = os.environ.get("CI")
|
||||
|
||||
|
||||
class TestUploadTestStats(unittest.TestCase):
|
||||
@unittest.skipIf(
|
||||
@ -13,7 +13,7 @@ class TestUploadTestStats(unittest.TestCase):
|
||||
)
|
||||
def test_existing_job(self) -> None:
|
||||
"""Run on a known-good job and make sure we don't error and get basically okay reults."""
|
||||
test_cases = get_tests(2561394934, 1)
|
||||
test_cases, _ = get_tests(2561394934, 1)
|
||||
self.assertEqual(len(test_cases), 609873)
|
||||
summary = summarize_test_cases(test_cases)
|
||||
self.assertEqual(len(summary), 5068)
|
||||
|
@ -75,6 +75,8 @@ from torch.onnx import (register_custom_op_symbolic,
|
||||
unregister_custom_op_symbolic)
|
||||
torch.backends.disable_global_flags()
|
||||
|
||||
PYTEST_FILES = ["test_ops", "test_ops_gradients", "test_ops_jit"]
|
||||
|
||||
FILE_SCHEMA = "file://"
|
||||
if sys.platform == 'win32':
|
||||
FILE_SCHEMA = "file:///"
|
||||
@ -91,9 +93,6 @@ MAX_NUM_RETRIES = 3
|
||||
DISABLED_TESTS_FILE = '.pytorch-disabled-tests.json'
|
||||
SLOW_TESTS_FILE = '.pytorch-slow-tests.json'
|
||||
|
||||
slow_tests_dict: Optional[Dict[str, Any]] = None
|
||||
disabled_tests_dict: Optional[Dict[str, Any]] = None
|
||||
|
||||
NATIVE_DEVICES = ('cpu', 'cuda', 'meta')
|
||||
|
||||
|
||||
@ -590,20 +589,33 @@ def lint_test_case_extension(suite):
|
||||
succeed = False
|
||||
return succeed
|
||||
|
||||
def sanitize_pytest_xml(xml_file: str):
|
||||
# pytext xml is different from unittext xml, this function makes pytest xml more similar to unittest xml
|
||||
# consider somehow modifying the XML logger in conftest to do this instead
|
||||
import xml.etree.ElementTree as ET
|
||||
tree = ET.parse(xml_file)
|
||||
for testcase in tree.iter('testcase'):
|
||||
full_classname = testcase.attrib['classname']
|
||||
regex_result = re.search(r"^test\.(.*)\.([^\.]*)$", full_classname)
|
||||
classname = regex_result.group(2)
|
||||
file = regex_result.group(1).replace('.', "/")
|
||||
testcase.set('classname', classname)
|
||||
testcase.set('file', f"{file}.py")
|
||||
tree.write(xml_file)
|
||||
|
||||
def run_tests(argv=UNITTEST_ARGS):
|
||||
# import test files.
|
||||
if IMPORT_SLOW_TESTS:
|
||||
if os.path.exists(IMPORT_SLOW_TESTS):
|
||||
global slow_tests_dict
|
||||
with open(IMPORT_SLOW_TESTS, 'r') as fp:
|
||||
slow_tests_dict = json.load(fp)
|
||||
# use env vars so pytest-xdist subprocesses can still access them
|
||||
os.environ['SLOW_TESTS_DICT'] = fp.read()
|
||||
else:
|
||||
print(f'[WARNING] slow test file provided but not found: {IMPORT_SLOW_TESTS}')
|
||||
if IMPORT_DISABLED_TESTS:
|
||||
if os.path.exists(IMPORT_DISABLED_TESTS):
|
||||
global disabled_tests_dict
|
||||
with open(IMPORT_DISABLED_TESTS, 'r') as fp:
|
||||
disabled_tests_dict = json.load(fp)
|
||||
os.environ['DISABLED_TESTS_DICT'] = fp.read()
|
||||
else:
|
||||
print(f'[WARNING] disabled test file provided but not found: {IMPORT_DISABLED_TESTS}')
|
||||
# Determine the test launch mechanism
|
||||
@ -682,10 +694,32 @@ def run_tests(argv=UNITTEST_ARGS):
|
||||
test_filename = sanitize_test_filename(inspect.getfile(sys._getframe(1)))
|
||||
test_report_path = TEST_SAVE_XML + LOG_SUFFIX
|
||||
test_report_path = os.path.join(test_report_path, test_filename)
|
||||
if test_filename in PYTEST_FILES and not IS_SANDCASTLE and not (
|
||||
"cuda" in os.environ["BUILD_ENVIRONMENT"] and "linux" in os.environ["BUILD_ENVIRONMENT"]
|
||||
):
|
||||
# exclude linux cuda tests because we run into memory issues when running in parallel
|
||||
import pytest
|
||||
os.environ["NO_COLOR"] = "1"
|
||||
os.environ["USING_PYTEST"] = "1"
|
||||
pytest_report_path = test_report_path.replace('python-unittest', 'python-pytest')
|
||||
os.makedirs(pytest_report_path, exist_ok=True)
|
||||
# part of our xml parsing looks for grandparent folder names
|
||||
pytest_report_path = os.path.join(pytest_report_path, f"{test_filename}.xml")
|
||||
print(f'Test results will be stored in {pytest_report_path}')
|
||||
# mac slower on 4 proc than 3
|
||||
num_procs = 3 if "macos" in os.environ["BUILD_ENVIRONMENT"] else 4
|
||||
exit_code = pytest.main(args=[inspect.getfile(sys._getframe(1)), f'-n={num_procs}', '-vv', '-x',
|
||||
'--reruns=2', '-rfEsX', f'--junit-xml-reruns={pytest_report_path}'])
|
||||
del os.environ["USING_PYTEST"]
|
||||
sanitize_pytest_xml(f'{pytest_report_path}')
|
||||
# exitcode of 5 means no tests were found, which happens since some test configs don't
|
||||
# run tests from certain files
|
||||
exit(0 if exit_code == 5 else exit_code)
|
||||
else:
|
||||
os.makedirs(test_report_path, exist_ok=True)
|
||||
verbose = '--verbose' in argv or '-v' in argv
|
||||
if verbose:
|
||||
print('Test results will be stored in {}'.format(test_report_path))
|
||||
print(f'Test results will be stored in {test_report_path}')
|
||||
unittest.main(argv=argv, testRunner=xmlrunner.XMLTestRunner(
|
||||
output=test_report_path,
|
||||
verbosity=2 if verbose else 1,
|
||||
@ -1504,13 +1538,16 @@ def remove_device_and_dtype_suffixes(test_name: str) -> str:
|
||||
|
||||
def check_if_enable(test: unittest.TestCase):
|
||||
test_suite = str(test.__class__).split('\'')[1]
|
||||
if "USING_PYTEST" in os.environ:
|
||||
test_suite = f"__main__.{test_suite.split('.')[1]}"
|
||||
raw_test_name = f'{test._testMethodName} ({test_suite})'
|
||||
if slow_tests_dict is not None and raw_test_name in slow_tests_dict:
|
||||
if raw_test_name in json.loads(os.environ.get("SLOW_TESTS_DICT", "{}")):
|
||||
getattr(test, test._testMethodName).__dict__['slow_test'] = True
|
||||
if not TEST_WITH_SLOW:
|
||||
raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test")
|
||||
sanitized_test_method_name = remove_device_and_dtype_suffixes(test._testMethodName)
|
||||
if not IS_SANDCASTLE and disabled_tests_dict is not None:
|
||||
if not IS_SANDCASTLE and "DISABLED_TESTS_DICT" in os.environ:
|
||||
disabled_tests_dict = json.loads(os.environ["DISABLED_TESTS_DICT"])
|
||||
for disabled_test, (issue_url, platforms) in disabled_tests_dict.items():
|
||||
disable_test_parts = disabled_test.split()
|
||||
if len(disable_test_parts) > 1:
|
||||
|
Reference in New Issue
Block a user