mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Reordering tests experiment (#106347)"
This reverts commit 7dfab082be9eaeeee95c7b0363e59c824c6a9009. Reverted https://github.com/pytorch/pytorch/pull/106347 on behalf of https://github.com/clee2000 due to probably broke sharding ([comment](https://github.com/pytorch/pytorch/pull/106347#issuecomment-1675542738))
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -19,7 +19,6 @@ coverage.xml
|
||||
**/.pytorch-disabled-tests.json
|
||||
**/.pytorch-slow-tests.json
|
||||
**/.pytorch-test-times.json
|
||||
**/.pytorch-test-file-ratings.json
|
||||
*/*.pyc
|
||||
*/*.so*
|
||||
*/**/__pycache__
|
||||
|
185
test/run_test.py
185
test/run_test.py
@ -11,10 +11,9 @@ import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from datetime import datetime
|
||||
from distutils.version import LooseVersion
|
||||
from typing import Any, cast, Dict, List, NamedTuple, Optional, Union
|
||||
from typing import Any, cast, Dict, List, Optional, Union
|
||||
|
||||
import pkg_resources
|
||||
|
||||
@ -41,11 +40,11 @@ try:
|
||||
# using tools/ to optimize test run.
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
from tools.stats.export_test_times import TEST_TIMES_FILE
|
||||
from tools.stats.upload_stats_lib import emit_metric
|
||||
from tools.testing.test_selections import (
|
||||
calculate_shards,
|
||||
get_reordered_tests,
|
||||
get_test_case_configs,
|
||||
log_time_savings,
|
||||
NUM_PROCS,
|
||||
ShardedTest,
|
||||
THRESHOLD,
|
||||
@ -1279,9 +1278,7 @@ def exclude_tests(
|
||||
return selected_tests
|
||||
|
||||
|
||||
def must_serial(file: Union[str, ShardedTest]) -> bool:
|
||||
if isinstance(file, ShardedTest):
|
||||
file = file.name
|
||||
def must_serial(file: str) -> bool:
|
||||
return (
|
||||
os.getenv("PYTORCH_TEST_RUN_EVERYTHING_IN_SERIAL", "0") == "1"
|
||||
or DISTRIBUTED_TEST_PREFIX in os.getenv("TEST_CONFIG", "")
|
||||
@ -1411,10 +1408,20 @@ def get_selected_tests(options) -> List[ShardedTest]:
|
||||
)
|
||||
|
||||
selected_tests = [parse_test_module(x) for x in selected_tests]
|
||||
return selected_tests
|
||||
|
||||
# sharding
|
||||
which_shard, num_shards = 1, 1
|
||||
if options.shard:
|
||||
assert len(options.shard) == 2, "Unexpected shard format"
|
||||
assert min(options.shard) > 0, "Shards must be positive numbers"
|
||||
which_shard, num_shards = options.shard
|
||||
assert (
|
||||
which_shard <= num_shards
|
||||
), "Selected shard must be less than or equal to total number of shards"
|
||||
assert num_shards <= len(
|
||||
selected_tests
|
||||
), f"Number of shards must be less than {len(selected_tests)}"
|
||||
|
||||
def download_test_times(file: str = TEST_TIMES_FILE) -> Dict[str, float]:
|
||||
# Download previous test times to make sharding decisions
|
||||
path = os.path.join(str(REPO_ROOT), TEST_TIMES_FILE)
|
||||
if os.path.exists(path):
|
||||
@ -1427,35 +1434,14 @@ def download_test_times(file: str = TEST_TIMES_FILE) -> Dict[str, float]:
|
||||
print(
|
||||
"::warning:: Gathered no stats from artifacts. Proceeding with default sharding plan."
|
||||
)
|
||||
return {}
|
||||
else:
|
||||
print("Found test time stats from artifacts")
|
||||
return test_file_times[test_config]
|
||||
|
||||
|
||||
def do_sharding(
|
||||
options,
|
||||
selected_tests: List[str],
|
||||
test_file_times: Dict[str, float],
|
||||
sort_by_time: bool = True,
|
||||
) -> List[ShardedTest]:
|
||||
which_shard, num_shards = 1, 1
|
||||
if options.shard:
|
||||
assert len(options.shard) == 2, "Unexpected shard format"
|
||||
assert min(options.shard) > 0, "Shards must be positive numbers"
|
||||
which_shard, num_shards = options.shard
|
||||
assert (
|
||||
which_shard <= num_shards
|
||||
), "Selected shard must be less than or equal to total number of shards"
|
||||
|
||||
if HAVE_TEST_SELECTION_TOOLS:
|
||||
# Do sharding
|
||||
test_file_times_config = test_file_times.get(test_config, {})
|
||||
shards = calculate_shards(
|
||||
num_shards,
|
||||
selected_tests,
|
||||
test_file_times,
|
||||
must_serial=must_serial,
|
||||
sort_by_time=sort_by_time,
|
||||
num_shards, selected_tests, test_file_times_config, must_serial=must_serial
|
||||
)
|
||||
_, tests_from_shard = shards[which_shard - 1]
|
||||
selected_tests = tests_from_shard
|
||||
@ -1463,14 +1449,9 @@ def do_sharding(
|
||||
return selected_tests
|
||||
|
||||
|
||||
class TestFailure(NamedTuple):
|
||||
test: str
|
||||
message: str
|
||||
|
||||
|
||||
def run_test_module(
|
||||
test: Union[ShardedTest, str], test_directory: str, options
|
||||
) -> Optional[TestFailure]:
|
||||
) -> Optional[str]:
|
||||
maybe_set_hip_visible_devies()
|
||||
|
||||
# Printing the date here can help diagnose which tests are slow
|
||||
@ -1491,24 +1472,39 @@ def run_test_module(
|
||||
# return code -N, where N is the signal number.
|
||||
signal_name = SIGNALS_TO_NAMES_DICT[-return_code]
|
||||
message += f" Received signal: {signal_name}"
|
||||
return TestFailure(test, message)
|
||||
return message
|
||||
|
||||
|
||||
def run_tests(
|
||||
selected_tests: List[ShardedTest],
|
||||
test_directory: str,
|
||||
options,
|
||||
failures: List[TestFailure],
|
||||
selected_tests: List[ShardedTest], test_directory: str, options, group_name: str
|
||||
) -> None:
|
||||
failure_messages = []
|
||||
|
||||
if len(selected_tests) == 0:
|
||||
return
|
||||
print_to_stderr(f"No tests in group `{group_name}`")
|
||||
return failure_messages
|
||||
|
||||
# parallel = in parallel with other files
|
||||
# serial = this file on it's own. The file might still be run in parallel with itself (ex test_ops)
|
||||
selected_tests_parallel = [x for x in selected_tests if not must_serial(x)]
|
||||
selected_tests_parallel = [
|
||||
x
|
||||
for x in selected_tests
|
||||
if not must_serial(x.name if isinstance(x, ShardedTest) else x)
|
||||
]
|
||||
selected_tests_serial = [
|
||||
x for x in selected_tests if x not in selected_tests_parallel
|
||||
]
|
||||
print(f"TEST GROUP: {group_name}")
|
||||
print_to_stderr(
|
||||
"parallel (file granularity) tests :\n {}".format(
|
||||
"\n".join(str(x) for x in selected_tests_parallel)
|
||||
)
|
||||
)
|
||||
print_to_stderr(
|
||||
"serial (file granularity) tests:\n {}".format(
|
||||
"\n ".join(str(x) for x in selected_tests_serial)
|
||||
)
|
||||
)
|
||||
|
||||
# See Note [ROCm parallel CI testing]
|
||||
pool = get_context("spawn").Pool(
|
||||
@ -1527,15 +1523,15 @@ def run_tests(
|
||||
# Take the conftest file from the test directory
|
||||
shutil.copy(os.path.join(test_directory, "conftest.py"), cpp_conftest_file)
|
||||
|
||||
def handle_error_messages(failure: Optional[TestFailure]):
|
||||
if failure is None:
|
||||
def handle_error_messages(err_message):
|
||||
if err_message is None:
|
||||
return False
|
||||
failures.append(failure)
|
||||
print_to_stderr(failure.message)
|
||||
failure_messages.append(err_message)
|
||||
print_to_stderr(err_message)
|
||||
return True
|
||||
|
||||
def parallel_test_completion_callback(failure):
|
||||
test_failed = handle_error_messages(failure)
|
||||
def parallel_test_completion_callback(err_message):
|
||||
test_failed = handle_error_messages(err_message)
|
||||
if (
|
||||
test_failed
|
||||
and not options.continue_through_error
|
||||
@ -1561,10 +1557,10 @@ def run_tests(
|
||||
if (
|
||||
not options.continue_through_error
|
||||
and not RERUN_DISABLED_TESTS
|
||||
and len(failures) != 0
|
||||
and len(failure_messages) != 0
|
||||
):
|
||||
raise RuntimeError(
|
||||
"\n".join(x.message for x in failures)
|
||||
"\n".join(failure_messages)
|
||||
+ "\n\nTip: You can keep running tests even on failure by "
|
||||
"passing --keep-going to run_test.py.\n"
|
||||
"If running on CI, add the 'keep-going' label to "
|
||||
@ -1575,20 +1571,20 @@ def run_tests(
|
||||
options_clone = copy.deepcopy(options)
|
||||
if can_run_in_pytest(test):
|
||||
options_clone.pytest = True
|
||||
failure = run_test_module(test, test_directory, options_clone)
|
||||
test_failed = handle_error_messages(failure)
|
||||
err_message = run_test_module(test, test_directory, options_clone)
|
||||
test_failed = handle_error_messages(err_message)
|
||||
if (
|
||||
test_failed
|
||||
and not options.continue_through_error
|
||||
and not RERUN_DISABLED_TESTS
|
||||
):
|
||||
raise RuntimeError(failure.message)
|
||||
raise RuntimeError(err_message)
|
||||
|
||||
finally:
|
||||
pool.terminate()
|
||||
pool.join()
|
||||
|
||||
return
|
||||
return failure_messages
|
||||
|
||||
|
||||
def check_pip_packages() -> None:
|
||||
@ -1615,47 +1611,30 @@ def main():
|
||||
test_directory = str(REPO_ROOT / "test")
|
||||
selected_tests = get_selected_tests(options)
|
||||
|
||||
if options.verbose:
|
||||
print_to_stderr(
|
||||
"Selected tests:\n {}".format("\n ".join(str(x) for x in selected_tests))
|
||||
)
|
||||
|
||||
if options.dry_run:
|
||||
return
|
||||
|
||||
if options.coverage and not PYTORCH_COLLECT_COVERAGE:
|
||||
shell(["coverage", "erase"])
|
||||
|
||||
prioritized_tests = []
|
||||
general_tests = selected_tests
|
||||
remaining_tests = selected_tests
|
||||
if IS_CI and HAVE_TEST_SELECTION_TOOLS:
|
||||
(prioritized_tests, remaining_tests) = get_reordered_tests(selected_tests)
|
||||
log_time_savings(
|
||||
selected_tests,
|
||||
prioritized_tests,
|
||||
is_serial_test_fn=must_serial,
|
||||
num_procs=NUM_PROCS,
|
||||
)
|
||||
|
||||
# downloading test cases configuration to local environment
|
||||
get_test_case_configs(dirpath=test_directory)
|
||||
(prioritized_tests, general_tests) = get_reordered_tests(general_tests)
|
||||
|
||||
metrics_dict = {
|
||||
"prioritized_tests": prioritized_tests,
|
||||
"general_tests": general_tests,
|
||||
"cpp": options.cpp,
|
||||
}
|
||||
|
||||
test_times_dict = download_test_times(TEST_TIMES_FILE)
|
||||
prioritized_tests = do_sharding(
|
||||
options, prioritized_tests, test_times_dict, sort_by_time=False
|
||||
)
|
||||
general_tests = do_sharding(options, general_tests, test_times_dict)
|
||||
|
||||
if options.verbose:
|
||||
|
||||
def print_tests(category, tests):
|
||||
tests_str = "\n ".join(str(x) for x in tests)
|
||||
print_to_stderr(f"{category} tests:\n {tests_str}")
|
||||
|
||||
print_tests(
|
||||
"Prioritized parallel", [x for x in prioritized_tests if not must_serial(x)]
|
||||
)
|
||||
print_tests(
|
||||
"Prioritized serial", [x for x in prioritized_tests if must_serial(x)]
|
||||
)
|
||||
print_tests(
|
||||
"General parallel", [x for x in general_tests if not must_serial(x)]
|
||||
)
|
||||
print_tests("General serial", [x for x in general_tests if must_serial(x)])
|
||||
|
||||
if options.dry_run:
|
||||
return
|
||||
|
||||
if options.dynamo:
|
||||
os.environ["PYTORCH_TEST_WITH_DYNAMO"] = "1"
|
||||
@ -1667,17 +1646,17 @@ def main():
|
||||
|
||||
os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)
|
||||
|
||||
prioritized_failures: List[TestFailure] = []
|
||||
general_failures: List[TestFailure] = []
|
||||
start_time = time.time()
|
||||
failure_messages = []
|
||||
|
||||
# First run the prioritized tests, then the remaining tests.
|
||||
try:
|
||||
run_tests(prioritized_tests, test_directory, options, prioritized_failures)
|
||||
metrics_dict["prioritized_failures"] = [x.test for x in prioritized_failures]
|
||||
metrics_dict["general_start_time"] = time.time() - start_time
|
||||
run_tests(general_tests, test_directory, options, general_failures)
|
||||
metrics_dict["general_end_time"] = time.time() - start_time
|
||||
metrics_dict["all_failures"] = [x.test for x in general_failures]
|
||||
failure_messages = run_tests(
|
||||
prioritized_tests, test_directory, options, "Prioritized tests"
|
||||
)
|
||||
|
||||
failure_messages += run_tests(
|
||||
remaining_tests, test_directory, options, "General tests"
|
||||
)
|
||||
|
||||
finally:
|
||||
if options.coverage:
|
||||
@ -1692,12 +1671,8 @@ def main():
|
||||
if not PYTORCH_COLLECT_COVERAGE:
|
||||
cov.html_report()
|
||||
|
||||
if IS_CI and HAVE_TEST_SELECTION_TOOLS:
|
||||
emit_metric("td_experiment_1", metrics_dict)
|
||||
|
||||
all_failures = prioritized_failures + general_failures
|
||||
if len(all_failures) != 0:
|
||||
for _, err in all_failures:
|
||||
if len(failure_messages) != 0:
|
||||
for err in failure_messages:
|
||||
print_to_stderr(err)
|
||||
|
||||
# A disabled test is expected to fail, so there is no need to report a failure here
|
||||
|
@ -19,8 +19,6 @@ IGNORE_DISABLED_ISSUES: List[str] = get_disabled_issues()
|
||||
|
||||
SLOW_TESTS_FILE = ".pytorch-slow-tests.json"
|
||||
DISABLED_TESTS_FILE = ".pytorch-disabled-tests.json"
|
||||
TEST_FILE_RATINGS_FILE = ".pytorch-test-file-ratings.json"
|
||||
|
||||
|
||||
FILE_CACHE_LIFESPAN_SECONDS = datetime.timedelta(hours=3).seconds
|
||||
|
||||
@ -118,14 +116,3 @@ def get_disabled_tests(
|
||||
except Exception:
|
||||
print("Couldn't download test skip set, leaving all tests enabled...")
|
||||
return {}
|
||||
|
||||
|
||||
def get_test_file_ratings(
|
||||
dirpath: str, filename: str = TEST_FILE_RATINGS_FILE
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
url = "https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/file_test_rating.json"
|
||||
try:
|
||||
return fetch_and_cache(dirpath, filename, url, lambda x: x)
|
||||
except Exception:
|
||||
print("Couldn't download test file ratings file, not reordering...")
|
||||
return {}
|
||||
|
@ -249,7 +249,7 @@ class EnvVarMetric:
|
||||
value = os.environ.get(self.env_var)
|
||||
if value is None and self.required:
|
||||
raise ValueError(
|
||||
f"Missing {self.name}. Please set the {self.env_var} "
|
||||
f"Missing {self.name}. Please set the {self.env_var}"
|
||||
"environment variable to pass in this value."
|
||||
)
|
||||
if self.type_conversion_fn:
|
||||
|
@ -394,24 +394,28 @@ class TestParsePrevTests(unittest.TestCase):
|
||||
"tools.testing.test_selections._get_modified_tests",
|
||||
return_value={"test2", "test4"},
|
||||
)
|
||||
@mock.patch(
|
||||
"tools.testing.test_selections._get_file_rating_tests", return_value=["test1"]
|
||||
)
|
||||
def test_get_reordered_tests(
|
||||
self,
|
||||
mock_get_prev_failing_tests: Any,
|
||||
mock_get_modified_tests: Any,
|
||||
mock_get_file_rating_tests: Any,
|
||||
self, mock_get_prev_failing_tests: Any, mock_get_modified_tests: Any
|
||||
) -> None:
|
||||
tests = ["test1", "test2", "test3", "test4", "test5"]
|
||||
tests = [
|
||||
ShardedTest(name="test1", shard=1, num_shards=2, time=600.0),
|
||||
ShardedTest(name="test2", shard=1, num_shards=2, time=500.0),
|
||||
ShardedTest(name="test3", shard=1, num_shards=2, time=400.0),
|
||||
ShardedTest(name="test4", shard=1, num_shards=2, time=300.0),
|
||||
ShardedTest(name="test5", shard=1, num_shards=2, time=200.0),
|
||||
]
|
||||
|
||||
expected_prioritized_tests = ["test4", "test2", "test1"]
|
||||
expected_remaining_tests = {"test3", "test5"}
|
||||
expected_prioritized_tests = {"test4", "test2"}
|
||||
expected_remaining_tests = {"test1", "test3", "test5"}
|
||||
|
||||
prioritized_tests, remaining_tests = get_reordered_tests(tests)
|
||||
|
||||
self.assertListEqual(expected_prioritized_tests, prioritized_tests)
|
||||
self.assertSetEqual(expected_remaining_tests, set(remaining_tests))
|
||||
# Just want to check the names of the tests
|
||||
prioritized_tests_name = {test.name for test in prioritized_tests}
|
||||
remaining_tests_name = {test.name for test in remaining_tests}
|
||||
|
||||
self.assertSetEqual(expected_prioritized_tests, prioritized_tests_name)
|
||||
self.assertSetEqual(expected_remaining_tests, remaining_tests_name)
|
||||
|
||||
def test_compute_prioritization_time_savings_with_multiple_threads(self) -> None:
|
||||
tests = [
|
||||
|
@ -3,24 +3,16 @@ import json
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
from typing import Callable, cast, Dict, List, NamedTuple, Optional, Set, Tuple, Union
|
||||
from typing import Callable, Dict, List, NamedTuple, Optional, Set, Tuple
|
||||
from warnings import warn
|
||||
|
||||
from tools.shared.logging_utils import duration_to_str, pluralize
|
||||
|
||||
from tools.stats.import_test_stats import (
|
||||
get_disabled_tests,
|
||||
get_slow_tests,
|
||||
get_test_file_ratings,
|
||||
TEST_FILE_RATINGS_FILE,
|
||||
)
|
||||
from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
|
||||
from tools.stats.upload_stats_lib import emit_metric
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
|
||||
IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"
|
||||
|
||||
# NUM_PROCS_FOR_SHARDING_CALC must remain consistent across all shards of a job
|
||||
@ -89,8 +81,8 @@ def get_with_pytest_shard(
|
||||
) -> List[ShardedTest]:
|
||||
sharded_tests: List[ShardedTest] = []
|
||||
for test in tests:
|
||||
duration = test_file_times.get(test, None)
|
||||
if duration and duration > THRESHOLD:
|
||||
duration = test_file_times[test]
|
||||
if duration > THRESHOLD:
|
||||
num_shards = math.ceil(duration / THRESHOLD)
|
||||
for i in range(num_shards):
|
||||
sharded_tests.append(
|
||||
@ -106,24 +98,20 @@ def calculate_shards(
|
||||
tests: List[str],
|
||||
test_file_times: Dict[str, float],
|
||||
must_serial: Optional[Callable[[str], bool]] = None,
|
||||
sort_by_time: bool = True,
|
||||
) -> List[Tuple[float, List[ShardedTest]]]:
|
||||
must_serial = must_serial or (lambda x: True)
|
||||
|
||||
known_tests = tests
|
||||
unknown_tests = []
|
||||
known_tests = [x for x in tests if x in test_file_times]
|
||||
unknown_tests: List[str] = [x for x in tests if x not in known_tests]
|
||||
|
||||
if sort_by_time:
|
||||
known_tests = [x for x in tests if x in test_file_times]
|
||||
unknown_tests = [x for x in tests if x not in known_tests]
|
||||
|
||||
known_tests = get_with_pytest_shard(known_tests, test_file_times)
|
||||
|
||||
if sort_by_time:
|
||||
known_tests = sorted(known_tests, key=lambda j: j.get_time(), reverse=True)
|
||||
sorted_tests = sorted(
|
||||
get_with_pytest_shard(known_tests, test_file_times),
|
||||
key=lambda j: j.get_time(),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
sharded_jobs: List[ShardJob] = [ShardJob() for _ in range(num_shards)]
|
||||
for test in known_tests:
|
||||
for test in sorted_tests:
|
||||
if must_serial(test.name):
|
||||
min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time())
|
||||
min_sharded_job.serial.append(test)
|
||||
@ -139,7 +127,7 @@ def calculate_shards(
|
||||
return [job.convert_to_tuple() for job in sharded_jobs]
|
||||
|
||||
|
||||
def _query_changed_files() -> List[str]:
|
||||
def _query_changed_test_files() -> List[str]:
|
||||
default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
|
||||
merge_base = (
|
||||
subprocess.check_output(["git", "merge-base", default_branch, "HEAD"])
|
||||
@ -198,7 +186,7 @@ def _parse_prev_failing_test_files(last_failed_tests: Dict[str, bool]) -> Set[st
|
||||
|
||||
def _get_modified_tests() -> Set[str]:
|
||||
try:
|
||||
changed_files = _query_changed_files()
|
||||
changed_files = _query_changed_test_files()
|
||||
except Exception as e:
|
||||
warn(f"Can't query changed test files due to {e}")
|
||||
# If unable to get changed files from git, quit without doing any sorting
|
||||
@ -283,86 +271,78 @@ def log_time_savings(
|
||||
return max_time_savings_sec
|
||||
|
||||
|
||||
def _get_file_rating_tests() -> List[str]:
|
||||
path = REPO_ROOT / "test" / TEST_FILE_RATINGS_FILE
|
||||
if not os.path.exists(path):
|
||||
print(f"could not find path {path}")
|
||||
return []
|
||||
with open(path) as f:
|
||||
test_file_ratings = cast(Dict[str, Dict[str, float]], json.load(f))
|
||||
try:
|
||||
changed_files = _query_changed_files()
|
||||
except Exception as e:
|
||||
warn(f"Can't query changed test files due to {e}")
|
||||
return []
|
||||
ratings: Dict[str, float] = defaultdict(float)
|
||||
for file in changed_files:
|
||||
for test_file, score in test_file_ratings.get(file, {}).items():
|
||||
ratings[test_file] += score
|
||||
prioritize = sorted(ratings, key=lambda x: ratings[x])
|
||||
return prioritize
|
||||
|
||||
|
||||
def get_reordered_tests(
|
||||
tests: List[str],
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
tests: List[ShardedTest],
|
||||
) -> Tuple[List[ShardedTest], List[ShardedTest]]:
|
||||
"""
|
||||
Get the reordered test filename list based on github PR history or git changed file.
|
||||
We prioritize running test files that were changed.
|
||||
"""
|
||||
prioritized_tests: List[str] = []
|
||||
|
||||
def add_tests(
|
||||
tests_to_add: Union[List[str], Set[str]], test_group_description: str
|
||||
) -> None:
|
||||
if not tests_to_add:
|
||||
def print_tests(tests: Set[str], test_group_description: str) -> None:
|
||||
if not tests:
|
||||
return
|
||||
|
||||
print(f"{test_group_description}:")
|
||||
for test in tests_to_add:
|
||||
if test in tests:
|
||||
print(f" {test}")
|
||||
if test not in prioritized_tests:
|
||||
prioritized_tests.append(test)
|
||||
for test in tests:
|
||||
print(f" {test}")
|
||||
|
||||
add_tests(
|
||||
_get_previously_failing_tests(),
|
||||
"If run, these tests will prioritized because they previously failed",
|
||||
prioritized_tests: Set[str] = set()
|
||||
|
||||
pri_test = _get_previously_failing_tests()
|
||||
print_tests(
|
||||
pri_test, "If run, these tests will prioritized because they previously failed"
|
||||
)
|
||||
prioritized_tests |= pri_test
|
||||
|
||||
add_tests(
|
||||
_get_modified_tests(),
|
||||
"If run, these tests will be prioritized because they were modified",
|
||||
pri_test |= _get_modified_tests()
|
||||
print_tests(
|
||||
pri_test, "If run, these tests will be prioritized because they were modified"
|
||||
)
|
||||
prioritized_tests |= pri_test
|
||||
|
||||
add_tests(
|
||||
_get_file_rating_tests(),
|
||||
"If run, these tests will be preioritized for an experiment in TD",
|
||||
)
|
||||
bring_to_front = []
|
||||
the_rest = []
|
||||
|
||||
prioritized_tests = [x for x in prioritized_tests if x in tests]
|
||||
the_rest = [x for x in tests if x not in prioritized_tests]
|
||||
for test in tests:
|
||||
if test.name in prioritized_tests:
|
||||
bring_to_front.append(test)
|
||||
else:
|
||||
the_rest.append(test)
|
||||
|
||||
if prioritized_tests:
|
||||
test_cnt_str = pluralize(len(tests), "test")
|
||||
if len(tests) != len(bring_to_front) + len(the_rest):
|
||||
print(
|
||||
f"Reordering tests: Prioritizing {len(prioritized_tests)} of {test_cnt_str}"
|
||||
f"Something went wrong in CI reordering, expecting total of {len(tests)}:\n"
|
||||
f"but found prioritized: {len(bring_to_front)}\nthe rest: {len(the_rest)}\n"
|
||||
)
|
||||
return ([], tests)
|
||||
|
||||
prioritized_test_names = []
|
||||
remaining_test_names = []
|
||||
if bring_to_front:
|
||||
test_cnt_str = pluralize(len(tests), "test")
|
||||
print(f"Reordering tests: Prioritizing {len(bring_to_front)} of {test_cnt_str}")
|
||||
|
||||
prioritized_test_names = [t.name for t in bring_to_front]
|
||||
print(f"Prioritized: {prioritized_test_names}")
|
||||
remaining_test_names = [t.name for t in the_rest]
|
||||
print(f"The Rest: {remaining_test_names}")
|
||||
else:
|
||||
print("Didn't find any tests to prioritize")
|
||||
|
||||
emit_metric(
|
||||
"test_reordering_prioritized_tests",
|
||||
{
|
||||
"prioritized_test_cnt": len(prioritized_tests),
|
||||
"prioritized_test_cnt": len(bring_to_front),
|
||||
"total_test_cnt": len(tests),
|
||||
"prioritized_tests": prioritized_tests,
|
||||
"remaining_tests": the_rest,
|
||||
"prioritized_tests": prioritized_test_names,
|
||||
"remaining_tests": remaining_test_names,
|
||||
},
|
||||
)
|
||||
|
||||
return (prioritized_tests, the_rest)
|
||||
return (bring_to_front, the_rest)
|
||||
|
||||
|
||||
def get_test_case_configs(dirpath: str) -> None:
|
||||
get_slow_tests(dirpath=dirpath)
|
||||
get_disabled_tests(dirpath=dirpath)
|
||||
get_test_file_ratings(dirpath=dirpath)
|
||||
|
Reference in New Issue
Block a user