mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Adding run-specified-test-cases option in run_test.py (#59487)
Summary: The run-specified-test-cases option would allow us to specify a list of test cases to run by having a CSV with minimally two columns: test_filename and test_case_name. This PR also adds .json to some files we use for better clarity. Usage: `python test/run_test.py --run-specified-test-cases <csv_file>` where the csv file can look like: ``` test_filename,test_case_name,test_total_time,windows_only_failure_sha_count,total_sha_count,windows_failure_count,linux_failure_count,windows_total_count,linux_total_count test_cuda,test_cudnn_multiple_threads_same_device,8068.8409659525,46,3768,53,0,2181,6750 test_utils,test_load_standalone,8308.8062920459,14,4630,65,0,2718,8729 test_ops,test_forward_mode_AD_acosh_cuda_complex128,91.652619369806,11,1971,26,1,1197,3825 test_ops,test_forward_mode_AD_acos_cuda_complex128,91.825633094915,11,1971,26,1,1197,3825 test_profiler,test_source,60.93786725749,9,4656,21,3,2742,8805 test_profiler,test_profiler_tracing,203.09352795241,9,4662,21,3,2737,8807 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/59487 Test Plan: Without specifying the option, everything should be as they were before. Running `python test/run_test.py --run-specified-test-cases windows_smoke_tests.csv` resulted in this paste P420276949 (you can see internally). A snippet looks like: ``` (pytorch) janeyx@janeyx-mbp pytorch % python test/run_test.py --run-specified-test-cases windows_smoke_tests.csv Loading specified test cases to run from windows_smoke_tests.csv. Processed 28 test cases. Running test_cpp_extensions_jit ... [2021-06-04 17:24:41.213644] Executing ['/Users/janeyx/miniconda3/envs/pytorch/bin/python', 'test_cpp_extensions_jit.py', '-k', 'test_jit_cuda_archflags'] ... [2021-06-04 17:24:41.213781] s ---------------------------------------------------------------------- Ran 1 test in 0.000s OK (skipped=1) ... ``` With pytest, an example executable would be: `Running test_dataloader ... [2021-06-04 17:37:57.643039] Executing ['/Users/janeyx/miniconda3/envs/pytorch/bin/python', '-m', 'pytest', 'test_dataloader.py', '-v', '-k', 'test_segfault or test_timeout'] ... [2021-06-04 17:37:57.643327]` Reviewed By: samestep Differential Revision: D28913223 Pulled By: janeyx99 fbshipit-source-id: 0d1f9910973426b8756815c697b483160517b127
This commit is contained in:
committed by
Facebook GitHub Bot
parent
caf76c2445
commit
24432eaa29
4
.gitignore
vendored
4
.gitignore
vendored
@ -15,8 +15,8 @@ coverage.xml
|
||||
.hypothesis
|
||||
.mypy_cache
|
||||
/.extracted_scripts/
|
||||
**/.pytorch-test-times
|
||||
**/.pytorch-slow-tests
|
||||
**/.pytorch-test-times.json
|
||||
**/.pytorch-slow-tests.json
|
||||
*/*.pyc
|
||||
*/*.so*
|
||||
*/**/__pycache__
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import csv
|
||||
from datetime import datetime
|
||||
import json
|
||||
import modulefinder
|
||||
@ -339,7 +340,7 @@ TARGET_DET_LIST = [
|
||||
]
|
||||
|
||||
# the JSON file to store the S3 test stats
|
||||
TEST_TIMES_FILE = '.pytorch-test-times'
|
||||
TEST_TIMES_FILE = '.pytorch-test-times.json'
|
||||
|
||||
# if a test file takes longer than 5 min, we add it to TARGET_DET_LIST
|
||||
SLOW_TEST_THRESHOLD = 300
|
||||
@ -389,6 +390,21 @@ JIT_EXECUTOR_TESTS = [
|
||||
'test_jit_fuser_legacy',
|
||||
]
|
||||
|
||||
# Dictionary matching test modules (in TESTS) to lists of test cases (within that test_module) that would be run when
|
||||
# options.run_specified_test_cases is enabled.
|
||||
# For example:
|
||||
# {
|
||||
# "test_nn": ["test_doubletensor_avg_pool3d", "test_share_memory", "test_hook_requires_grad"],
|
||||
# ...
|
||||
# }
|
||||
# For test_nn.py, we would ONLY run test_doubletensor_avg_pool3d, test_share_memory, and test_hook_requires_grad.
|
||||
SPECIFIED_TEST_CASES_DICT: Dict[str, List[str]] = {}
|
||||
|
||||
# The file from which the SPECIFIED_TEST_CASES_DICT will be filled, a CSV of test cases that would be run when
|
||||
# options.run_specified_test_cases is enabled.
|
||||
SPECIFIED_TEST_CASES_FILE: str = '.pytorch_specified_test_cases.csv'
|
||||
|
||||
|
||||
def print_to_stderr(message):
|
||||
print(message, file=sys.stderr)
|
||||
|
||||
@ -515,6 +531,23 @@ def get_slow_tests_based_on_S3() -> List[str]:
|
||||
return slow_tests
|
||||
|
||||
|
||||
def get_test_case_args(test_module, using_pytest) -> List[str]:
|
||||
if test_module not in SPECIFIED_TEST_CASES_DICT:
|
||||
sys.exit(f'Warning! Test module {test_module} is not found in the specified tests dict. This should never'
|
||||
'happen as we make a check for that before entering this function.')
|
||||
args = []
|
||||
|
||||
if using_pytest:
|
||||
args.append('-k')
|
||||
args.append(' or '.join(SPECIFIED_TEST_CASES_DICT[test_module]))
|
||||
else:
|
||||
for test in SPECIFIED_TEST_CASES_DICT[test_module]:
|
||||
args.append('-k')
|
||||
args.append(test)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def get_executable_command(options, allow_pytest, disable_coverage=False):
|
||||
if options.coverage and not disable_coverage:
|
||||
executable = ['coverage', 'run', '--parallel-mode', '--source=torch']
|
||||
@ -542,10 +575,6 @@ def run_test(test_module, test_directory, options, launcher_cmd=None, extra_unit
|
||||
if options.pytest:
|
||||
unittest_args = [arg if arg != '-f' else '-x' for arg in unittest_args]
|
||||
|
||||
# Can't call `python -m unittest test_*` here because it doesn't run code
|
||||
# in `if __name__ == '__main__': `. So call `python test_*.py` instead.
|
||||
argv = [test_module + '.py'] + unittest_args
|
||||
|
||||
# Multiprocessing related tests cannot run with coverage.
|
||||
# Tracking issue: https://github.com/pytorch/pytorch/issues/50661
|
||||
disable_coverage = sys.platform == 'win32' and test_module in WINDOWS_COVERAGE_BLOCKLIST
|
||||
@ -554,6 +583,15 @@ def run_test(test_module, test_directory, options, launcher_cmd=None, extra_unit
|
||||
executable = get_executable_command(options, allow_pytest=not extra_unittest_args,
|
||||
disable_coverage=disable_coverage)
|
||||
|
||||
# The following logic for running specified tests will only run for non-distributed tests, as those are dispatched
|
||||
# to test_distributed and not run_test (this function)
|
||||
if options.run_specified_test_cases:
|
||||
unittest_args.extend(get_test_case_args(test_module, 'pytest' in executable))
|
||||
|
||||
# Can't call `python -m unittest test_*` here because it doesn't run code
|
||||
# in `if __name__ == '__main__': `. So call `python test_*.py` instead.
|
||||
argv = [test_module + '.py'] + unittest_args
|
||||
|
||||
command = (launcher_cmd or []) + executable + argv
|
||||
print_to_stderr('Executing {} ... [{}]'.format(command, datetime.now()))
|
||||
return shell(command, test_directory)
|
||||
@ -729,8 +767,7 @@ def parse_args():
|
||||
default=TESTS,
|
||||
metavar='TESTS',
|
||||
help='select a set of tests to include (defaults to ALL tests).'
|
||||
' tests can be specified with module name, module.TestClass'
|
||||
' or module.TestClass.test_method')
|
||||
' tests must be a part of the TESTS list defined in run_test.py')
|
||||
parser.add_argument(
|
||||
'-x',
|
||||
'--exclude',
|
||||
@ -796,6 +833,14 @@ def parse_args():
|
||||
action='store_true',
|
||||
help='exclude tests that are run for a specific jit config'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--run-specified-test-cases',
|
||||
nargs='?',
|
||||
type=str,
|
||||
const=SPECIFIED_TEST_CASES_FILE,
|
||||
help='runs specified test cases from previous OSS CI stats from a file, format CSV',
|
||||
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -866,6 +911,10 @@ def get_selected_tests(options):
|
||||
if options.exclude_jit_executor:
|
||||
options.exclude.extend(JIT_EXECUTOR_TESTS)
|
||||
|
||||
if options.run_specified_test_cases:
|
||||
# Filter out any unspecified test modules.
|
||||
selected_tests = [t for t in selected_tests if t in SPECIFIED_TEST_CASES_DICT]
|
||||
|
||||
selected_tests = exclude_tests(options.exclude, selected_tests)
|
||||
|
||||
if sys.platform == 'win32' and not options.ignore_win_blocklist:
|
||||
@ -1053,6 +1102,33 @@ def export_S3_test_times(test_times_filename: str, test_times: Dict[str, float])
|
||||
file.write('\n')
|
||||
|
||||
|
||||
def load_specified_test_cases(filename: str) -> None:
|
||||
if not os.path.exists(filename):
|
||||
print(f'Could not find specified tests file: {filename}. Proceeding with default behavior.')
|
||||
return
|
||||
|
||||
# The below encoding is utf-8-sig because utf-8 doesn't properly handle the byte-order-mark character
|
||||
with open(filename, mode='r', encoding="utf-8-sig") as csv_file:
|
||||
csv_reader = csv.DictReader(csv_file)
|
||||
line_count = 0
|
||||
global SPECIFIED_TEST_CASES_DICT
|
||||
for row in csv_reader:
|
||||
line_count += 1
|
||||
if line_count == 1:
|
||||
if 'test_filename' not in row or 'test_case_name' not in row:
|
||||
print('Data is missing necessary columns for test specification. Proceeding with default behavior.')
|
||||
return
|
||||
test_filename = row['test_filename']
|
||||
test_case_name = row['test_case_name']
|
||||
if test_filename not in TESTS:
|
||||
print(f'Specified test_filename {test_filename} not found in TESTS. Skipping.')
|
||||
continue
|
||||
if test_filename not in SPECIFIED_TEST_CASES_DICT:
|
||||
SPECIFIED_TEST_CASES_DICT[test_filename] = []
|
||||
SPECIFIED_TEST_CASES_DICT[test_filename].append(test_case_name)
|
||||
print(f'Processed {line_count} test cases.')
|
||||
|
||||
|
||||
def query_changed_test_files() -> List[str]:
|
||||
cmd = ["git", "diff", "--name-only", "origin/master", "HEAD"]
|
||||
proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
@ -1104,6 +1180,11 @@ def main():
|
||||
export_S3_test_times(test_times_filename, pull_job_times_from_S3())
|
||||
return
|
||||
|
||||
specified_test_cases_filename = options.run_specified_test_cases
|
||||
if specified_test_cases_filename:
|
||||
print(f'Loading specified test cases to run from {specified_test_cases_filename}.')
|
||||
load_specified_test_cases(specified_test_cases_filename)
|
||||
|
||||
test_directory = os.path.dirname(os.path.abspath(__file__))
|
||||
selected_tests = get_selected_tests(options)
|
||||
|
||||
|
@ -858,7 +858,7 @@ def check_slow_test_from_stats(test):
|
||||
if slow_tests_dict is None:
|
||||
if not IS_SANDCASTLE and os.getenv("PYTORCH_RUN_DISABLED_TESTS", "0") != "1":
|
||||
url = "https://raw.githubusercontent.com/pytorch/test-infra/master/stats/slow-tests.json"
|
||||
slow_tests_dict = fetch_and_cache(".pytorch-slow-tests", url)
|
||||
slow_tests_dict = fetch_and_cache(".pytorch-slow-tests.json", url)
|
||||
else:
|
||||
slow_tests_dict = {}
|
||||
test_suite = str(test.__class__).split('\'')[1]
|
||||
@ -1043,7 +1043,6 @@ class TestCase(expecttest.TestCase):
|
||||
result.stop()
|
||||
|
||||
def setUp(self):
|
||||
|
||||
check_slow_test_from_stats(self)
|
||||
if TEST_SKIP_FAST:
|
||||
if not getattr(self, self._testMethodName).__dict__.get('slow_test', False):
|
||||
|
Reference in New Issue
Block a user