pytest to run test_ops, test_ops_gradients, test_ops_jit in non linux cuda environments (#79898)

This PR uses pytest to run test_ops, test_ops_gradients, and test_ops_jit in parallel in non linux cuda environments to decrease TTS.  I am excluding linux cuda because running in parallel results in errors due to running out of memory

Notes:
* update hypothesis version for compatability with pytest
* use rerun-failures to rerun tests (similar to flaky tests, although these test files generally don't have flaky tests)
  * reruns are denoted by a rerun tag in the xml.  Failed reruns also have the failure tag.  Successes (meaning that the test is flaky) do not have the failure tag.
* see https://docs.google.com/spreadsheets/d/1aO0Rbg3y3ch7ghipt63PG2KNEUppl9a5b18Hmv2CZ4E/edit#gid=602543594 for info on speedup (or slowdown in the case of slow tests)
  * expecting windows tests to decrease by 60 minutes total
* slow test infra is expected to stay the same - verified by running pytest and unittest on the same job and check the number of skipped/run tests
* test reports to s3 changed - add entirely new table to keep track of invoking_file times
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79898
Approved by: https://github.com/malfet, https://github.com/janeyx99
This commit is contained in:
Catherine Lee
2022-07-19 19:50:57 +00:00
committed by PyTorch MergeBot
parent 9f9dd4f072
commit 06a0cfc0ea
9 changed files with 304 additions and 37 deletions

View File

@ -41,7 +41,7 @@ flatbuffers==2.0
#Pinned versions: #Pinned versions:
#test that import: #test that import:
hypothesis==4.53.2 hypothesis==5.35.1
# Pin hypothesis to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136 # Pin hypothesis to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136
#Description: advanced library for generating parametrized tests #Description: advanced library for generating parametrized tests
#Pinned versions: 3.44.6, 4.53.2 #Pinned versions: 3.44.6, 4.53.2
@ -143,6 +143,16 @@ pytest
#Pinned versions: #Pinned versions:
#test that import: test_typing.py, test_cpp_extensions_aot.py, run_test.py #test that import: test_typing.py, test_cpp_extensions_aot.py, run_test.py
pytest-xdist
#Description: plugin for running pytest in parallel
#Pinned versions:
#test that import:
pytest-rerunfailures
#Description: plugin for rerunning tests in pytest
#Pinned versions:
#test that import:
#pytest-benchmark #pytest-benchmark
#Description: fixture for benchmarking code #Description: fixture for benchmarking code
#Pinned versions: 3.2.3 #Pinned versions: 3.2.3

View File

@ -10,7 +10,9 @@ pip install -q hypothesis "expecttest==0.1.3" "librosa>=0.6.2" "numba<=0.49.1" p
# TODO move this to docker # TODO move this to docker
# Pin unittest-xml-reporting to freeze printing test summary logic, related: https://github.com/pytorch/pytorch/issues/69014 # Pin unittest-xml-reporting to freeze printing test summary logic, related: https://github.com/pytorch/pytorch/issues/69014
pip install "unittest-xml-reporting<=3.2.0,>=2.0.0" \ pip install "unittest-xml-reporting<=3.2.0,>=2.0.0" \
pytest pytest \
pytest-xdist \
pytest-rerunfailures
if [ -z "${CI}" ]; then if [ -z "${CI}" ]; then
rm -rf "${WORKSPACE_DIR}"/miniconda3/lib/python3.6/site-packages/torch* rm -rf "${WORKSPACE_DIR}"/miniconda3/lib/python3.6/site-packages/torch*

View File

@ -36,7 +36,7 @@ popd
======= =======
:: Pin unittest-xml-reporting to freeze printing test summary logic, related: https://github.com/pytorch/pytorch/issues/69014 :: Pin unittest-xml-reporting to freeze printing test summary logic, related: https://github.com/pytorch/pytorch/issues/69014
pip install "ninja==1.10.0.post1" future "hypothesis==4.53.2" "expecttest==0.1.3" "librosa>=0.6.2" "scipy==1.6.3" psutil pillow "unittest-xml-reporting<=3.2.0,>=2.0.0" pytest pip install "ninja==1.10.0.post1" future "hypothesis==5.35.1" "expecttest==0.1.3" "librosa>=0.6.2" "scipy==1.6.3" psutil pillow "unittest-xml-reporting<=3.2.0,>=2.0.0" pytest pytest-xdist pytest-rerunfailures
if errorlevel 1 exit /b if errorlevel 1 exit /b
if not errorlevel 0 exit /b if not errorlevel 0 exit /b

View File

@ -10,3 +10,4 @@ addopts =
-Wd -Wd
testpaths = testpaths =
test test
junit_logging_reruns = all

146
test/conftest.py Normal file
View File

@ -0,0 +1,146 @@
from _pytest.junitxml import LogXML, _NodeReporter, bin_xml_escape
from _pytest.terminal import _get_raw_skip_reason
from _pytest.stash import StashKey
from _pytest.reports import TestReport
from _pytest.config.argparsing import Parser
from _pytest.config import filename_arg
from _pytest.config import Config
from _pytest._code.code import ReprFileLocation
from typing import Union
from typing import Optional
import xml.etree.ElementTree as ET
import functools
# a lot of this file is copied from _pytest.junitxml and modified to get rerun info
xml_key = StashKey["LogXMLReruns"]()
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("terminal reporting")
group.addoption(
"--junit-xml-reruns",
action="store",
dest="xmlpath_reruns",
metavar="path",
type=functools.partial(filename_arg, optname="--junit-xml-reruns"),
default=None,
help="create junit-xml style report file at given path.",
)
group.addoption(
"--junit-prefix-reruns",
action="store",
metavar="str",
default=None,
help="prepend prefix to classnames in junit-xml output",
)
parser.addini(
"junit_suite_name_reruns", "Test suite name for JUnit report", default="pytest"
)
parser.addini(
"junit_logging_reruns",
"Write captured log messages to JUnit report: "
"one of no|log|system-out|system-err|out-err|all",
default="no",
)
parser.addini(
"junit_log_passing_tests_reruns",
"Capture log information for passing tests to JUnit report: ",
type="bool",
default=True,
)
parser.addini(
"junit_duration_report_reruns",
"Duration time to report: one of total|call",
default="total",
)
parser.addini(
"junit_family_reruns",
"Emit XML for schema: one of legacy|xunit1|xunit2",
default="xunit2",
)
def pytest_configure(config: Config) -> None:
xmlpath = config.option.xmlpath_reruns
# Prevent opening xmllog on worker nodes (xdist).
if xmlpath and not hasattr(config, "workerinput"):
junit_family = config.getini("junit_family_reruns")
config.stash[xml_key] = LogXMLReruns(
xmlpath,
config.option.junitprefix,
config.getini("junit_suite_name_reruns"),
config.getini("junit_logging_reruns"),
config.getini("junit_duration_report_reruns"),
junit_family,
config.getini("junit_log_passing_tests_reruns"),
)
config.pluginmanager.register(config.stash[xml_key])
def pytest_unconfigure(config: Config) -> None:
xml = config.stash.get(xml_key, None)
if xml:
del config.stash[xml_key]
config.pluginmanager.unregister(xml)
class _NodeReporterReruns(_NodeReporter):
def _prepare_content(self, content: str, header: str) -> str:
return content
def _write_content(self, report: TestReport, content: str, jheader: str) -> None:
if content == "":
return
tag = ET.Element(jheader)
tag.text = bin_xml_escape(content)
self.append(tag)
class LogXMLReruns(LogXML):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def append_rerun(self, reporter: _NodeReporter, report: TestReport) -> None:
if hasattr(report, "wasxfail"):
reporter._add_simple("skipped", "xfail-marked test passes unexpectedly")
else:
assert report.longrepr is not None
reprcrash: Optional[ReprFileLocation] = getattr(
report.longrepr, "reprcrash", None
)
if reprcrash is not None:
message = reprcrash.message
else:
message = str(report.longrepr)
message = bin_xml_escape(message)
reporter._add_simple("rerun", message, str(report.longrepr))
def pytest_runtest_logreport(self, report: TestReport) -> None:
super().pytest_runtest_logreport(report)
if report.outcome == "rerun":
reporter = self._opentestcase(report)
self.append_rerun(reporter, report)
if report.outcome == "skipped":
if isinstance(report.longrepr, tuple):
fspath, lineno, reason = report.longrepr
reason = f"{report.nodeid}: {_get_raw_skip_reason(report)}"
report.longrepr = (fspath, lineno, reason)
def node_reporter(self, report: Union[TestReport, str]) -> _NodeReporterReruns:
nodeid: Union[str, TestReport] = getattr(report, "nodeid", report)
# Local hack to handle xdist report order.
workernode = getattr(report, "node", None)
key = nodeid, workernode
if key in self.node_reporters:
# TODO: breaks for --dist=each
return self.node_reporters[key]
reporter = _NodeReporterReruns(nodeid, self)
self.node_reporters[key] = reporter
self.node_reporters_ordered.append(reporter)
return reporter

View File

@ -1443,7 +1443,7 @@ class TestMathBits(TestCase):
) )
# input strides and size may have been altered due to the result of an inplace op # input strides and size may have been altered due to the result of an inplace op
def test_inplace_view(func, input, rs, input_size, input_strides): def check_inplace_view(func, input, rs, input_size, input_strides):
if func is None: if func is None:
return return
# TODO: extend this test to test ops with multiple outputs and ops like native_batch_norm.out # TODO: extend this test to test ops with multiple outputs and ops like native_batch_norm.out
@ -1470,7 +1470,7 @@ class TestTagsMode(TorchDispatchMode):
old_size = args[0].size() old_size = args[0].size()
old_stride = args[0].stride() old_stride = args[0].stride()
rs = func(*args, **kwargs) rs = func(*args, **kwargs)
test_inplace_view(func, args[0], rs, old_size, old_stride) check_inplace_view(func, args[0], rs, old_size, old_stride)
else: else:
rs = func(*args, **kwargs) rs = func(*args, **kwargs)
return rs return rs
@ -1492,7 +1492,7 @@ class TestTags(TestCase):
# TODO: add test for aliases: https://github.com/pytorch/pytorch/issues/78761 # TODO: add test for aliases: https://github.com/pytorch/pytorch/issues/78761
aten_name = op.aten_name if op.aten_name is not None else op.name aten_name = op.aten_name if op.aten_name is not None else op.name
opoverloadpacket = getattr(torch.ops.aten, aten_name, None) opoverloadpacket = getattr(torch.ops.aten, aten_name, None)
test_inplace_view(opoverloadpacket, input, rs, old_size, old_stride) check_inplace_view(opoverloadpacket, input, rs, old_size, old_stride)
class TestRefsOpsInfo(TestCase): class TestRefsOpsInfo(TestCase):

View File

@ -4,7 +4,7 @@ import sys
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Any, Dict, List from typing import Any, Dict, List, Tuple
from tools.stats.upload_stats_lib import ( from tools.stats.upload_stats_lib import (
download_gha_artifacts, download_gha_artifacts,
@ -14,6 +14,15 @@ from tools.stats.upload_stats_lib import (
) )
def get_job_id(report: Path) -> int:
# [Job id in artifacts]
# Retrieve the job id from the report path. In our GHA workflows, we append
# the job id to the end of the report name, so `report` looks like:
# unzipped-test-reports-foo_5596745227/test/test-reports/foo/TEST-foo.xml
# and we want to get `5596745227` out of it.
return int(report.parts[0].rpartition("_")[2])
def parse_xml_report( def parse_xml_report(
tag: str, tag: str,
report: Path, report: Path,
@ -22,12 +31,8 @@ def parse_xml_report(
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""Convert a test report xml file into a JSON-serializable list of test cases.""" """Convert a test report xml file into a JSON-serializable list of test cases."""
print(f"Parsing {tag}s for test report: {report}") print(f"Parsing {tag}s for test report: {report}")
# [Job id in artifacts]
# Retrieve the job id from the report path. In our GHA workflows, we append job_id = get_job_id(report)
# the job id to the end of the report name, so `report` looks like:
# unzipped-test-reports-foo_5596745227/test/test-reports/foo/TEST-foo.xml
# and we want to get `5596745227` out of it.
job_id = int(report.parts[0].rpartition("_")[2])
print(f"Found job id: {job_id}") print(f"Found job id: {job_id}")
root = ET.parse(report) root = ET.parse(report)
@ -112,7 +117,22 @@ def process_xml_element(element: ET.Element) -> Dict[str, Any]:
return ret return ret
def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> List[Dict[str, Any]]: def get_pytest_parallel_times() -> Dict[Any, Any]:
pytest_parallel_times = {}
for report in Path(".").glob("**/python-pytest/**/*.xml"):
invoking_file = report.parent.name
root = ET.parse(report)
assert len(list(root.iter("testsuite"))) == 1
for test_suite in root.iter("testsuite"):
pytest_parallel_times[
(invoking_file, get_job_id(report))
] = test_suite.attrib["time"]
return pytest_parallel_times
def get_tests(
workflow_run_id: int, workflow_run_attempt: int
) -> Tuple[List[Dict[str, Any]], Dict[Any, Any]]:
with TemporaryDirectory() as temp_dir: with TemporaryDirectory() as temp_dir:
print("Using temporary directory:", temp_dir) print("Using temporary directory:", temp_dir)
os.chdir(temp_dir) os.chdir(temp_dir)
@ -142,7 +162,44 @@ def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> List[Dict[str,
) )
) )
return test_cases pytest_parallel_times = get_pytest_parallel_times()
return test_cases, pytest_parallel_times
def get_invoking_file_times(
test_case_summaries: List[Dict[str, Any]], pytest_parallel_times: Dict[Any, Any]
) -> List[Dict[str, Any]]:
def get_key(summary: Dict[str, Any]) -> Any:
return (
summary["invoking_file"],
summary["job_id"],
)
def init_value(summary: Dict[str, Any]) -> Any:
return {
"job_id": summary["job_id"],
"workflow_id": summary["workflow_id"],
"workflow_run_attempt": summary["workflow_run_attempt"],
"invoking_file": summary["invoking_file"],
"time": 0.0,
}
ret = {}
for summary in test_case_summaries:
key = get_key(summary)
if key not in ret:
ret[key] = init_value(summary)
ret[key]["time"] += summary["time"]
for key, val in ret.items():
# when running in parallel in pytest, adding the test times will not give the correct
# time used to run the file, which will make the sharding incorrect, so if the test is
# run in parallel, we take the time reported by the testsuite
if key in pytest_parallel_times:
val["time"] = pytest_parallel_times[key]
return list(ret.values())
def summarize_test_cases(test_cases: List[Dict[str, Any]]) -> List[Dict[str, Any]]: def summarize_test_cases(test_cases: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
@ -220,18 +277,32 @@ if __name__ == "__main__":
help="Head branch of the workflow", help="Head branch of the workflow",
) )
args = parser.parse_args() args = parser.parse_args()
test_cases = get_tests(args.workflow_run_id, args.workflow_run_attempt) test_cases, pytest_parallel_times = get_tests(
args.workflow_run_id, args.workflow_run_attempt
)
# Flush stdout so that any errors in rockset upload show up last in the logs. # Flush stdout so that any errors in rockset upload show up last in the logs.
sys.stdout.flush() sys.stdout.flush()
# For PRs, only upload a summary of test_runs. This helps lower the # For PRs, only upload a summary of test_runs. This helps lower the
# volume of writes we do to Rockset. # volume of writes we do to Rockset.
test_case_summary = summarize_test_cases(test_cases)
invoking_file_times = get_invoking_file_times(
test_case_summary, pytest_parallel_times
)
upload_to_s3( upload_to_s3(
args.workflow_run_id, args.workflow_run_id,
args.workflow_run_attempt, args.workflow_run_attempt,
"test_run_summary", "test_run_summary",
summarize_test_cases(test_cases), test_case_summary,
)
upload_to_s3(
args.workflow_run_id,
args.workflow_run_attempt,
"invoking_file_times",
invoking_file_times,
) )
if args.head_branch == "master": if args.head_branch == "master":

View File

@ -1,10 +1,10 @@
import os import os
import unittest import unittest
IN_CI = os.environ.get("CI")
from tools.stats.upload_test_stats import get_tests, summarize_test_cases from tools.stats.upload_test_stats import get_tests, summarize_test_cases
IN_CI = os.environ.get("CI")
class TestUploadTestStats(unittest.TestCase): class TestUploadTestStats(unittest.TestCase):
@unittest.skipIf( @unittest.skipIf(
@ -13,7 +13,7 @@ class TestUploadTestStats(unittest.TestCase):
) )
def test_existing_job(self) -> None: def test_existing_job(self) -> None:
"""Run on a known-good job and make sure we don't error and get basically okay reults.""" """Run on a known-good job and make sure we don't error and get basically okay reults."""
test_cases = get_tests(2561394934, 1) test_cases, _ = get_tests(2561394934, 1)
self.assertEqual(len(test_cases), 609873) self.assertEqual(len(test_cases), 609873)
summary = summarize_test_cases(test_cases) summary = summarize_test_cases(test_cases)
self.assertEqual(len(summary), 5068) self.assertEqual(len(summary), 5068)

View File

@ -75,6 +75,8 @@ from torch.onnx import (register_custom_op_symbolic,
unregister_custom_op_symbolic) unregister_custom_op_symbolic)
torch.backends.disable_global_flags() torch.backends.disable_global_flags()
PYTEST_FILES = ["test_ops", "test_ops_gradients", "test_ops_jit"]
FILE_SCHEMA = "file://" FILE_SCHEMA = "file://"
if sys.platform == 'win32': if sys.platform == 'win32':
FILE_SCHEMA = "file:///" FILE_SCHEMA = "file:///"
@ -91,9 +93,6 @@ MAX_NUM_RETRIES = 3
DISABLED_TESTS_FILE = '.pytorch-disabled-tests.json' DISABLED_TESTS_FILE = '.pytorch-disabled-tests.json'
SLOW_TESTS_FILE = '.pytorch-slow-tests.json' SLOW_TESTS_FILE = '.pytorch-slow-tests.json'
slow_tests_dict: Optional[Dict[str, Any]] = None
disabled_tests_dict: Optional[Dict[str, Any]] = None
NATIVE_DEVICES = ('cpu', 'cuda', 'meta') NATIVE_DEVICES = ('cpu', 'cuda', 'meta')
@ -590,20 +589,33 @@ def lint_test_case_extension(suite):
succeed = False succeed = False
return succeed return succeed
def sanitize_pytest_xml(xml_file: str):
# pytext xml is different from unittext xml, this function makes pytest xml more similar to unittest xml
# consider somehow modifying the XML logger in conftest to do this instead
import xml.etree.ElementTree as ET
tree = ET.parse(xml_file)
for testcase in tree.iter('testcase'):
full_classname = testcase.attrib['classname']
regex_result = re.search(r"^test\.(.*)\.([^\.]*)$", full_classname)
classname = regex_result.group(2)
file = regex_result.group(1).replace('.', "/")
testcase.set('classname', classname)
testcase.set('file', f"{file}.py")
tree.write(xml_file)
def run_tests(argv=UNITTEST_ARGS): def run_tests(argv=UNITTEST_ARGS):
# import test files. # import test files.
if IMPORT_SLOW_TESTS: if IMPORT_SLOW_TESTS:
if os.path.exists(IMPORT_SLOW_TESTS): if os.path.exists(IMPORT_SLOW_TESTS):
global slow_tests_dict
with open(IMPORT_SLOW_TESTS, 'r') as fp: with open(IMPORT_SLOW_TESTS, 'r') as fp:
slow_tests_dict = json.load(fp) # use env vars so pytest-xdist subprocesses can still access them
os.environ['SLOW_TESTS_DICT'] = fp.read()
else: else:
print(f'[WARNING] slow test file provided but not found: {IMPORT_SLOW_TESTS}') print(f'[WARNING] slow test file provided but not found: {IMPORT_SLOW_TESTS}')
if IMPORT_DISABLED_TESTS: if IMPORT_DISABLED_TESTS:
if os.path.exists(IMPORT_DISABLED_TESTS): if os.path.exists(IMPORT_DISABLED_TESTS):
global disabled_tests_dict
with open(IMPORT_DISABLED_TESTS, 'r') as fp: with open(IMPORT_DISABLED_TESTS, 'r') as fp:
disabled_tests_dict = json.load(fp) os.environ['DISABLED_TESTS_DICT'] = fp.read()
else: else:
print(f'[WARNING] disabled test file provided but not found: {IMPORT_DISABLED_TESTS}') print(f'[WARNING] disabled test file provided but not found: {IMPORT_DISABLED_TESTS}')
# Determine the test launch mechanism # Determine the test launch mechanism
@ -682,10 +694,32 @@ def run_tests(argv=UNITTEST_ARGS):
test_filename = sanitize_test_filename(inspect.getfile(sys._getframe(1))) test_filename = sanitize_test_filename(inspect.getfile(sys._getframe(1)))
test_report_path = TEST_SAVE_XML + LOG_SUFFIX test_report_path = TEST_SAVE_XML + LOG_SUFFIX
test_report_path = os.path.join(test_report_path, test_filename) test_report_path = os.path.join(test_report_path, test_filename)
if test_filename in PYTEST_FILES and not IS_SANDCASTLE and not (
"cuda" in os.environ["BUILD_ENVIRONMENT"] and "linux" in os.environ["BUILD_ENVIRONMENT"]
):
# exclude linux cuda tests because we run into memory issues when running in parallel
import pytest
os.environ["NO_COLOR"] = "1"
os.environ["USING_PYTEST"] = "1"
pytest_report_path = test_report_path.replace('python-unittest', 'python-pytest')
os.makedirs(pytest_report_path, exist_ok=True)
# part of our xml parsing looks for grandparent folder names
pytest_report_path = os.path.join(pytest_report_path, f"{test_filename}.xml")
print(f'Test results will be stored in {pytest_report_path}')
# mac slower on 4 proc than 3
num_procs = 3 if "macos" in os.environ["BUILD_ENVIRONMENT"] else 4
exit_code = pytest.main(args=[inspect.getfile(sys._getframe(1)), f'-n={num_procs}', '-vv', '-x',
'--reruns=2', '-rfEsX', f'--junit-xml-reruns={pytest_report_path}'])
del os.environ["USING_PYTEST"]
sanitize_pytest_xml(f'{pytest_report_path}')
# exitcode of 5 means no tests were found, which happens since some test configs don't
# run tests from certain files
exit(0 if exit_code == 5 else exit_code)
else:
os.makedirs(test_report_path, exist_ok=True) os.makedirs(test_report_path, exist_ok=True)
verbose = '--verbose' in argv or '-v' in argv verbose = '--verbose' in argv or '-v' in argv
if verbose: if verbose:
print('Test results will be stored in {}'.format(test_report_path)) print(f'Test results will be stored in {test_report_path}')
unittest.main(argv=argv, testRunner=xmlrunner.XMLTestRunner( unittest.main(argv=argv, testRunner=xmlrunner.XMLTestRunner(
output=test_report_path, output=test_report_path,
verbosity=2 if verbose else 1, verbosity=2 if verbose else 1,
@ -1504,13 +1538,16 @@ def remove_device_and_dtype_suffixes(test_name: str) -> str:
def check_if_enable(test: unittest.TestCase): def check_if_enable(test: unittest.TestCase):
test_suite = str(test.__class__).split('\'')[1] test_suite = str(test.__class__).split('\'')[1]
if "USING_PYTEST" in os.environ:
test_suite = f"__main__.{test_suite.split('.')[1]}"
raw_test_name = f'{test._testMethodName} ({test_suite})' raw_test_name = f'{test._testMethodName} ({test_suite})'
if slow_tests_dict is not None and raw_test_name in slow_tests_dict: if raw_test_name in json.loads(os.environ.get("SLOW_TESTS_DICT", "{}")):
getattr(test, test._testMethodName).__dict__['slow_test'] = True getattr(test, test._testMethodName).__dict__['slow_test'] = True
if not TEST_WITH_SLOW: if not TEST_WITH_SLOW:
raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test") raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test")
sanitized_test_method_name = remove_device_and_dtype_suffixes(test._testMethodName) sanitized_test_method_name = remove_device_and_dtype_suffixes(test._testMethodName)
if not IS_SANDCASTLE and disabled_tests_dict is not None: if not IS_SANDCASTLE and "DISABLED_TESTS_DICT" in os.environ:
disabled_tests_dict = json.loads(os.environ["DISABLED_TESTS_DICT"])
for disabled_test, (issue_url, platforms) in disabled_tests_dict.items(): for disabled_test, (issue_url, platforms) in disabled_tests_dict.items():
disable_test_parts = disabled_test.split() disable_test_parts = disabled_test.split()
if len(disable_test_parts) > 1: if len(disable_test_parts) > 1: