mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[Reland] Adding run specified tests option to run_test.py (#59649)
Summary: Reland of https://github.com/pytorch/pytorch/issues/59487 Pull Request resolved: https://github.com/pytorch/pytorch/pull/59649 Reviewed By: samestep Differential Revision: D28970751 Pulled By: janeyx99 fbshipit-source-id: 6e28d4dcfdab8a49da4b6a02c57516b08bacd7b5
This commit is contained in:
committed by
Facebook GitHub Bot
parent
51884c6479
commit
97dfc7e300
4
.gitignore
vendored
4
.gitignore
vendored
@ -15,8 +15,8 @@ coverage.xml
|
||||
.hypothesis
|
||||
.mypy_cache
|
||||
/.extracted_scripts/
|
||||
**/.pytorch-test-times
|
||||
**/.pytorch-slow-tests
|
||||
**/.pytorch-test-times.json
|
||||
**/.pytorch-slow-tests.json
|
||||
*/*.pyc
|
||||
*/*.so*
|
||||
*/**/__pycache__
|
||||
|
@ -123,6 +123,6 @@ python setup.py install --cmake && sccache --show-stats && (
|
||||
7z a %TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torch %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\caffe2 && copy /Y "%TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z" "%PYTORCH_FINAL_PACKAGE_DIR%\"
|
||||
|
||||
:: export test times so that potential sharded tests that'll branch off this build will use consistent data
|
||||
python test/run_test.py --export-past-test-times %PYTORCH_FINAL_PACKAGE_DIR%/.pytorch-test-times
|
||||
python test/run_test.py --export-past-test-times %PYTORCH_FINAL_PACKAGE_DIR%/.pytorch-test-times.json
|
||||
)
|
||||
)
|
||||
|
@ -1,7 +1,7 @@
|
||||
call %SCRIPT_HELPERS_DIR%\setup_pytorch_env.bat
|
||||
|
||||
echo Copying over test times file
|
||||
copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times" "%TEST_DIR_WIN%"
|
||||
copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times.json" "%TEST_DIR_WIN%"
|
||||
|
||||
pushd test
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
call %SCRIPT_HELPERS_DIR%\setup_pytorch_env.bat
|
||||
|
||||
echo Copying over test times file
|
||||
copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times" "%TEST_DIR_WIN%"
|
||||
copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times.json" "%TEST_DIR_WIN%"
|
||||
|
||||
pushd test
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
call %SCRIPT_HELPERS_DIR%\setup_pytorch_env.bat
|
||||
|
||||
echo Copying over test times file
|
||||
copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times" "%TEST_DIR_WIN%"
|
||||
copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times.json" "%TEST_DIR_WIN%"
|
||||
|
||||
cd test && python run_test.py --exclude-jit-executor --shard 2 2 --verbose --determine-from="%1" && cd ..
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import csv
|
||||
from datetime import datetime
|
||||
import json
|
||||
import modulefinder
|
||||
@ -339,7 +340,7 @@ TARGET_DET_LIST = [
|
||||
]
|
||||
|
||||
# the JSON file to store the S3 test stats
|
||||
TEST_TIMES_FILE = '.pytorch-test-times'
|
||||
TEST_TIMES_FILE = '.pytorch-test-times.json'
|
||||
|
||||
# if a test file takes longer than 5 min, we add it to TARGET_DET_LIST
|
||||
SLOW_TEST_THRESHOLD = 300
|
||||
@ -389,6 +390,21 @@ JIT_EXECUTOR_TESTS = [
|
||||
'test_jit_fuser_legacy',
|
||||
]
|
||||
|
||||
# Dictionary matching test modules (in TESTS) to lists of test cases (within that test_module) that would be run when
|
||||
# options.run_specified_test_cases is enabled.
|
||||
# For example:
|
||||
# {
|
||||
# "test_nn": ["test_doubletensor_avg_pool3d", "test_share_memory", "test_hook_requires_grad"],
|
||||
# ...
|
||||
# }
|
||||
# For test_nn.py, we would ONLY run test_doubletensor_avg_pool3d, test_share_memory, and test_hook_requires_grad.
|
||||
SPECIFIED_TEST_CASES_DICT: Dict[str, List[str]] = {}
|
||||
|
||||
# The file from which the SPECIFIED_TEST_CASES_DICT will be filled, a CSV of test cases that would be run when
|
||||
# options.run_specified_test_cases is enabled.
|
||||
SPECIFIED_TEST_CASES_FILE: str = '.pytorch_specified_test_cases.csv'
|
||||
|
||||
|
||||
def print_to_stderr(message):
|
||||
print(message, file=sys.stderr)
|
||||
|
||||
@ -515,6 +531,23 @@ def get_slow_tests_based_on_S3() -> List[str]:
|
||||
return slow_tests
|
||||
|
||||
|
||||
def get_test_case_args(test_module, using_pytest) -> List[str]:
|
||||
if test_module not in SPECIFIED_TEST_CASES_DICT:
|
||||
sys.exit(f'Warning! Test module {test_module} is not found in the specified tests dict. This should never'
|
||||
'happen as we make a check for that before entering this function.')
|
||||
args = []
|
||||
|
||||
if using_pytest:
|
||||
args.append('-k')
|
||||
args.append(' or '.join(SPECIFIED_TEST_CASES_DICT[test_module]))
|
||||
else:
|
||||
for test in SPECIFIED_TEST_CASES_DICT[test_module]:
|
||||
args.append('-k')
|
||||
args.append(test)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def get_executable_command(options, allow_pytest, disable_coverage=False):
|
||||
if options.coverage and not disable_coverage:
|
||||
executable = ['coverage', 'run', '--parallel-mode', '--source=torch']
|
||||
@ -542,10 +575,6 @@ def run_test(test_module, test_directory, options, launcher_cmd=None, extra_unit
|
||||
if options.pytest:
|
||||
unittest_args = [arg if arg != '-f' else '-x' for arg in unittest_args]
|
||||
|
||||
# Can't call `python -m unittest test_*` here because it doesn't run code
|
||||
# in `if __name__ == '__main__': `. So call `python test_*.py` instead.
|
||||
argv = [test_module + '.py'] + unittest_args
|
||||
|
||||
# Multiprocessing related tests cannot run with coverage.
|
||||
# Tracking issue: https://github.com/pytorch/pytorch/issues/50661
|
||||
disable_coverage = sys.platform == 'win32' and test_module in WINDOWS_COVERAGE_BLOCKLIST
|
||||
@ -554,6 +583,15 @@ def run_test(test_module, test_directory, options, launcher_cmd=None, extra_unit
|
||||
executable = get_executable_command(options, allow_pytest=not extra_unittest_args,
|
||||
disable_coverage=disable_coverage)
|
||||
|
||||
# The following logic for running specified tests will only run for non-distributed tests, as those are dispatched
|
||||
# to test_distributed and not run_test (this function)
|
||||
if options.run_specified_test_cases:
|
||||
unittest_args.extend(get_test_case_args(test_module, 'pytest' in executable))
|
||||
|
||||
# Can't call `python -m unittest test_*` here because it doesn't run code
|
||||
# in `if __name__ == '__main__': `. So call `python test_*.py` instead.
|
||||
argv = [test_module + '.py'] + unittest_args
|
||||
|
||||
command = (launcher_cmd or []) + executable + argv
|
||||
print_to_stderr('Executing {} ... [{}]'.format(command, datetime.now()))
|
||||
return shell(command, test_directory)
|
||||
@ -729,8 +767,7 @@ def parse_args():
|
||||
default=TESTS,
|
||||
metavar='TESTS',
|
||||
help='select a set of tests to include (defaults to ALL tests).'
|
||||
' tests can be specified with module name, module.TestClass'
|
||||
' or module.TestClass.test_method')
|
||||
' tests must be a part of the TESTS list defined in run_test.py')
|
||||
parser.add_argument(
|
||||
'-x',
|
||||
'--exclude',
|
||||
@ -796,6 +833,14 @@ def parse_args():
|
||||
action='store_true',
|
||||
help='exclude tests that are run for a specific jit config'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--run-specified-test-cases',
|
||||
nargs='?',
|
||||
type=str,
|
||||
const=SPECIFIED_TEST_CASES_FILE,
|
||||
help='runs specified test cases from previous OSS CI stats from a file, format CSV',
|
||||
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -866,6 +911,10 @@ def get_selected_tests(options):
|
||||
if options.exclude_jit_executor:
|
||||
options.exclude.extend(JIT_EXECUTOR_TESTS)
|
||||
|
||||
if options.run_specified_test_cases:
|
||||
# Filter out any unspecified test modules.
|
||||
selected_tests = [t for t in selected_tests if t in SPECIFIED_TEST_CASES_DICT]
|
||||
|
||||
selected_tests = exclude_tests(options.exclude, selected_tests)
|
||||
|
||||
if sys.platform == 'win32' and not options.ignore_win_blocklist:
|
||||
@ -1053,6 +1102,33 @@ def export_S3_test_times(test_times_filename: str, test_times: Dict[str, float])
|
||||
file.write('\n')
|
||||
|
||||
|
||||
def load_specified_test_cases(filename: str) -> None:
|
||||
if not os.path.exists(filename):
|
||||
print(f'Could not find specified tests file: {filename}. Proceeding with default behavior.')
|
||||
return
|
||||
|
||||
# The below encoding is utf-8-sig because utf-8 doesn't properly handle the byte-order-mark character
|
||||
with open(filename, mode='r', encoding="utf-8-sig") as csv_file:
|
||||
csv_reader = csv.DictReader(csv_file)
|
||||
line_count = 0
|
||||
global SPECIFIED_TEST_CASES_DICT
|
||||
for row in csv_reader:
|
||||
line_count += 1
|
||||
if line_count == 1:
|
||||
if 'test_filename' not in row or 'test_case_name' not in row:
|
||||
print('Data is missing necessary columns for test specification. Proceeding with default behavior.')
|
||||
return
|
||||
test_filename = row['test_filename']
|
||||
test_case_name = row['test_case_name']
|
||||
if test_filename not in TESTS:
|
||||
print(f'Specified test_filename {test_filename} not found in TESTS. Skipping.')
|
||||
continue
|
||||
if test_filename not in SPECIFIED_TEST_CASES_DICT:
|
||||
SPECIFIED_TEST_CASES_DICT[test_filename] = []
|
||||
SPECIFIED_TEST_CASES_DICT[test_filename].append(test_case_name)
|
||||
print(f'Processed {line_count} test cases.')
|
||||
|
||||
|
||||
def query_changed_test_files() -> List[str]:
|
||||
cmd = ["git", "diff", "--name-only", "origin/master", "HEAD"]
|
||||
proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
@ -1104,6 +1180,11 @@ def main():
|
||||
export_S3_test_times(test_times_filename, pull_job_times_from_S3())
|
||||
return
|
||||
|
||||
specified_test_cases_filename = options.run_specified_test_cases
|
||||
if specified_test_cases_filename:
|
||||
print(f'Loading specified test cases to run from {specified_test_cases_filename}.')
|
||||
load_specified_test_cases(specified_test_cases_filename)
|
||||
|
||||
test_directory = os.path.dirname(os.path.abspath(__file__))
|
||||
selected_tests = get_selected_tests(options)
|
||||
|
||||
|
@ -8,7 +8,7 @@ from collections import defaultdict
|
||||
from tools.stats_utils.s3_stat_parser import get_previous_reports_for_branch, Report, Version2Report
|
||||
from typing import cast, DefaultDict, Dict, List
|
||||
|
||||
SLOW_TESTS_FILE = '.pytorch-slow-tests'
|
||||
SLOW_TESTS_FILE = '.pytorch-slow-tests.json'
|
||||
SLOW_TEST_CASE_THRESHOLD_SEC = 60.0
|
||||
|
||||
|
||||
@ -56,7 +56,7 @@ def parse_args() -> argparse.Namespace:
|
||||
type=str,
|
||||
default=SLOW_TESTS_FILE,
|
||||
const=SLOW_TESTS_FILE,
|
||||
help='Specify a file path to dump slow test times from previous S3 stats. Default file path: .pytorch-slow-tests',
|
||||
help='Specify a file path to dump slow test times from previous S3 stats. Default file path: .pytorch-slow-tests.json',
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
@ -859,7 +859,7 @@ def check_slow_test_from_stats(test):
|
||||
if slow_tests_dict is None:
|
||||
if not IS_SANDCASTLE and os.getenv("PYTORCH_RUN_DISABLED_TESTS", "0") != "1":
|
||||
url = "https://raw.githubusercontent.com/pytorch/test-infra/master/stats/slow-tests.json"
|
||||
slow_tests_dict = fetch_and_cache(".pytorch-slow-tests", url)
|
||||
slow_tests_dict = fetch_and_cache(".pytorch-slow-tests.json", url)
|
||||
else:
|
||||
slow_tests_dict = {}
|
||||
test_suite = str(test.__class__).split('\'')[1]
|
||||
@ -1044,7 +1044,6 @@ class TestCase(expecttest.TestCase):
|
||||
result.stop()
|
||||
|
||||
def setUp(self):
|
||||
|
||||
check_slow_test_from_stats(self)
|
||||
if TEST_SKIP_FAST:
|
||||
if not getattr(self, self._testMethodName).__dict__.get('slow_test', False):
|
||||
|
Reference in New Issue
Block a user