Reordering tests experiment (#106347)

Companion with https://github.com/pytorch/test-infra/pull/4424

Uses the file rating generated by the test infra PR to re order tests.  For each test file, sum the file ratings from the changed files in the PR, and put the tests in order of sum.

A lot of tests are probably going to end up as "prioritized" since it takes anything with a rating > 0 right now.

Sharding is done twice, once on the prioritized tests, and once on the general/non prioritized tests.  Prioritized tests have an order, so they should be sharded according to that order, while general tests don't have an order and are sharded by test time, which should result in more balanced shards.

I'll change the metric name before I merge, i want to quarantine my testing stuff from actual results

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106347
Approved by: https://github.com/ZainRizvi
This commit is contained in:
Catherine Lee
2023-08-16 18:23:09 +00:00
committed by PyTorch MergeBot
parent 884c03d240
commit f16be5e0d4
14 changed files with 210 additions and 160 deletions

View File

@ -128,6 +128,7 @@ python -c "import os, glob; os.system('python -mpip install --no-index --no-deps
:: export test times so that potential sharded tests that'll branch off this build will use consistent data
python tools/stats/export_test_times.py
copy /Y ".pytorch-test-times.json" "%PYTORCH_FINAL_PACKAGE_DIR%"
copy /Y ".pytorch-test-file-ratings.json" "%PYTORCH_FINAL_PACKAGE_DIR%"
:: Also save build/.ninja_log as an artifact
copy /Y "build\.ninja_log" "%PYTORCH_FINAL_PACKAGE_DIR%\"

View File

@ -2,6 +2,7 @@ call %SCRIPT_HELPERS_DIR%\setup_pytorch_env.bat
echo Copying over test times file
copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times.json" "%PROJECT_DIR_WIN%"
copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-file-ratings.json" "%PROJECT_DIR_WIN%"
pushd test

View File

@ -23,6 +23,7 @@ if "%SHARD_NUMBER%" == "1" (
echo Copying over test times file
copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times.json" "%PROJECT_DIR_WIN%"
copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-file-ratings.json" "%PROJECT_DIR_WIN%"
echo Run nn tests
python run_test.py --exclude-jit-executor --exclude-distributed-tests --shard "%SHARD_NUMBER%" "%NUM_TEST_SHARDS%" --verbose

2
.circleci/config.yml generated
View File

@ -652,7 +652,7 @@ jobs:
- run:
name: Archive artifacts into zip
command: |
zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json
zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json .pytorch-test-file-ratings.json
cp artifacts.zip /Users/distiller/workspace
- persist_to_workspace:

View File

@ -177,7 +177,7 @@
- run:
name: Archive artifacts into zip
command: |
zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json
zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json .pytorch-test-file-ratings.json
cp artifacts.zip /Users/distiller/workspace
- persist_to_workspace:

View File

@ -170,7 +170,7 @@ jobs:
- name: Archive artifacts into zip
if: inputs.build-generates-artifacts && steps.build.outcome != 'skipped'
run: |
zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json
zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json .pytorch-test-file-ratings.json
- name: Store PyTorch Build Artifacts on S3
uses: seemethere/upload-artifact-s3@v5

View File

@ -182,7 +182,7 @@ jobs:
- name: Archive artifacts into zip
if: inputs.build-generates-artifacts && steps.build.outcome != 'skipped'
run: |
zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json
zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json .pytorch-test-file-ratings.json
- name: Store PyTorch Build Artifacts on GHA
uses: actions/upload-artifact@v3

1
.gitignore vendored
View File

@ -19,6 +19,7 @@ coverage.xml
**/.pytorch-disabled-tests.json
**/.pytorch-slow-tests.json
**/.pytorch-test-times.json
**/.pytorch-test-file-ratings.json
*/*.pyc
*/*.so*
*/**/__pycache__

View File

@ -11,9 +11,10 @@ import signal
import subprocess
import sys
import tempfile
import time
from datetime import datetime
from distutils.version import LooseVersion
from typing import Any, cast, Dict, List, Optional, Union
from typing import Any, cast, Dict, List, NamedTuple, Optional, Union
import pkg_resources
@ -40,11 +41,11 @@ try:
# using tools/ to optimize test run.
sys.path.insert(0, str(REPO_ROOT))
from tools.stats.export_test_times import TEST_TIMES_FILE
from tools.stats.upload_stats_lib import emit_metric
from tools.testing.test_selections import (
calculate_shards,
get_reordered_tests,
get_test_case_configs,
log_time_savings,
NUM_PROCS,
ShardedTest,
THRESHOLD,
@ -1278,7 +1279,9 @@ def exclude_tests(
return selected_tests
def must_serial(file: str) -> bool:
def must_serial(file: Union[str, ShardedTest]) -> bool:
if isinstance(file, ShardedTest):
file = file.name
return (
os.getenv("PYTORCH_TEST_RUN_EVERYTHING_IN_SERIAL", "0") == "1"
or DISTRIBUTED_TEST_PREFIX in os.getenv("TEST_CONFIG", "")
@ -1408,20 +1411,10 @@ def get_selected_tests(options) -> List[ShardedTest]:
)
selected_tests = [parse_test_module(x) for x in selected_tests]
return selected_tests
# sharding
which_shard, num_shards = 1, 1
if options.shard:
assert len(options.shard) == 2, "Unexpected shard format"
assert min(options.shard) > 0, "Shards must be positive numbers"
which_shard, num_shards = options.shard
assert (
which_shard <= num_shards
), "Selected shard must be less than or equal to total number of shards"
assert num_shards <= len(
selected_tests
), f"Number of shards must be less than {len(selected_tests)}"
def download_test_times(file: str = TEST_TIMES_FILE) -> Dict[str, float]:
# Download previous test times to make sharding decisions
path = os.path.join(str(REPO_ROOT), TEST_TIMES_FILE)
if os.path.exists(path):
@ -1434,14 +1427,35 @@ def get_selected_tests(options) -> List[ShardedTest]:
print(
"::warning:: Gathered no stats from artifacts. Proceeding with default sharding plan."
)
return {}
else:
print("Found test time stats from artifacts")
return test_file_times[test_config]
def do_sharding(
options,
selected_tests: List[str],
test_file_times: Dict[str, float],
sort_by_time: bool = True,
) -> List[ShardedTest]:
which_shard, num_shards = 1, 1
if options.shard:
assert len(options.shard) == 2, "Unexpected shard format"
assert min(options.shard) > 0, "Shards must be positive numbers"
which_shard, num_shards = options.shard
assert (
which_shard <= num_shards
), "Selected shard must be less than or equal to total number of shards"
if HAVE_TEST_SELECTION_TOOLS:
# Do sharding
test_file_times_config = test_file_times.get(test_config, {})
shards = calculate_shards(
num_shards, selected_tests, test_file_times_config, must_serial=must_serial
num_shards,
selected_tests,
test_file_times,
must_serial=must_serial,
sort_by_time=sort_by_time,
)
_, tests_from_shard = shards[which_shard - 1]
selected_tests = tests_from_shard
@ -1449,9 +1463,14 @@ def get_selected_tests(options) -> List[ShardedTest]:
return selected_tests
class TestFailure(NamedTuple):
test: str
message: str
def run_test_module(
test: Union[ShardedTest, str], test_directory: str, options
) -> Optional[str]:
) -> Optional[TestFailure]:
maybe_set_hip_visible_devies()
# Printing the date here can help diagnose which tests are slow
@ -1472,39 +1491,24 @@ def run_test_module(
# return code -N, where N is the signal number.
signal_name = SIGNALS_TO_NAMES_DICT[-return_code]
message += f" Received signal: {signal_name}"
return message
return TestFailure(test, message)
def run_tests(
selected_tests: List[ShardedTest], test_directory: str, options, group_name: str
selected_tests: List[ShardedTest],
test_directory: str,
options,
failures: List[TestFailure],
) -> None:
failure_messages = []
if len(selected_tests) == 0:
print_to_stderr(f"No tests in group `{group_name}`")
return failure_messages
return
# parallel = in parallel with other files
# serial = this file on it's own. The file might still be run in parallel with itself (ex test_ops)
selected_tests_parallel = [
x
for x in selected_tests
if not must_serial(x.name if isinstance(x, ShardedTest) else x)
]
selected_tests_parallel = [x for x in selected_tests if not must_serial(x)]
selected_tests_serial = [
x for x in selected_tests if x not in selected_tests_parallel
]
print(f"TEST GROUP: {group_name}")
print_to_stderr(
"parallel (file granularity) tests :\n {}".format(
"\n".join(str(x) for x in selected_tests_parallel)
)
)
print_to_stderr(
"serial (file granularity) tests:\n {}".format(
"\n ".join(str(x) for x in selected_tests_serial)
)
)
# See Note [ROCm parallel CI testing]
pool = get_context("spawn").Pool(
@ -1523,15 +1527,15 @@ def run_tests(
# Take the conftest file from the test directory
shutil.copy(os.path.join(test_directory, "conftest.py"), cpp_conftest_file)
def handle_error_messages(err_message):
if err_message is None:
def handle_error_messages(failure: Optional[TestFailure]):
if failure is None:
return False
failure_messages.append(err_message)
print_to_stderr(err_message)
failures.append(failure)
print_to_stderr(failure.message)
return True
def parallel_test_completion_callback(err_message):
test_failed = handle_error_messages(err_message)
def parallel_test_completion_callback(failure):
test_failed = handle_error_messages(failure)
if (
test_failed
and not options.continue_through_error
@ -1557,10 +1561,10 @@ def run_tests(
if (
not options.continue_through_error
and not RERUN_DISABLED_TESTS
and len(failure_messages) != 0
and len(failures) != 0
):
raise RuntimeError(
"\n".join(failure_messages)
"\n".join(x.message for x in failures)
+ "\n\nTip: You can keep running tests even on failure by "
"passing --keep-going to run_test.py.\n"
"If running on CI, add the 'keep-going' label to "
@ -1571,20 +1575,20 @@ def run_tests(
options_clone = copy.deepcopy(options)
if can_run_in_pytest(test):
options_clone.pytest = True
err_message = run_test_module(test, test_directory, options_clone)
test_failed = handle_error_messages(err_message)
failure = run_test_module(test, test_directory, options_clone)
test_failed = handle_error_messages(failure)
if (
test_failed
and not options.continue_through_error
and not RERUN_DISABLED_TESTS
):
raise RuntimeError(err_message)
raise RuntimeError(failure.message)
finally:
pool.terminate()
pool.join()
return failure_messages
return
def check_pip_packages() -> None:
@ -1611,30 +1615,47 @@ def main():
test_directory = str(REPO_ROOT / "test")
selected_tests = get_selected_tests(options)
if options.verbose:
print_to_stderr(
"Selected tests:\n {}".format("\n ".join(str(x) for x in selected_tests))
)
if options.dry_run:
return
if options.coverage and not PYTORCH_COLLECT_COVERAGE:
shell(["coverage", "erase"])
prioritized_tests = []
remaining_tests = selected_tests
general_tests = selected_tests
if IS_CI and HAVE_TEST_SELECTION_TOOLS:
(prioritized_tests, remaining_tests) = get_reordered_tests(selected_tests)
log_time_savings(
selected_tests,
prioritized_tests,
is_serial_test_fn=must_serial,
num_procs=NUM_PROCS,
)
# downloading test cases configuration to local environment
get_test_case_configs(dirpath=test_directory)
(prioritized_tests, general_tests) = get_reordered_tests(general_tests)
metrics_dict = {
"prioritized_tests": prioritized_tests,
"general_tests": general_tests,
"cpp": options.cpp,
}
test_times_dict = download_test_times(TEST_TIMES_FILE)
prioritized_tests = do_sharding(
options, prioritized_tests, test_times_dict, sort_by_time=False
)
general_tests = do_sharding(options, general_tests, test_times_dict)
if options.verbose:
def print_tests(category, tests):
tests_str = "\n ".join(str(x) for x in tests)
print_to_stderr(f"{category} tests:\n {tests_str}")
print_tests(
"Prioritized parallel", [x for x in prioritized_tests if not must_serial(x)]
)
print_tests(
"Prioritized serial", [x for x in prioritized_tests if must_serial(x)]
)
print_tests(
"General parallel", [x for x in general_tests if not must_serial(x)]
)
print_tests("General serial", [x for x in general_tests if must_serial(x)])
if options.dry_run:
return
if options.dynamo:
os.environ["PYTORCH_TEST_WITH_DYNAMO"] = "1"
@ -1646,17 +1667,17 @@ def main():
os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)
failure_messages = []
prioritized_failures: List[TestFailure] = []
general_failures: List[TestFailure] = []
start_time = time.time()
# First run the prioritized tests, then the remaining tests.
try:
failure_messages = run_tests(
prioritized_tests, test_directory, options, "Prioritized tests"
)
failure_messages += run_tests(
remaining_tests, test_directory, options, "General tests"
)
run_tests(prioritized_tests, test_directory, options, prioritized_failures)
metrics_dict["prioritized_failures"] = [x.test for x in prioritized_failures]
metrics_dict["general_start_time"] = time.time() - start_time
run_tests(general_tests, test_directory, options, general_failures)
metrics_dict["general_end_time"] = time.time() - start_time
metrics_dict["all_failures"] = [x.test for x in general_failures]
finally:
if options.coverage:
@ -1671,8 +1692,12 @@ def main():
if not PYTORCH_COLLECT_COVERAGE:
cov.html_report()
if len(failure_messages) != 0:
for err in failure_messages:
if IS_CI and HAVE_TEST_SELECTION_TOOLS:
emit_metric("td_experiment_1", metrics_dict)
all_failures = prioritized_failures + general_failures
if len(all_failures) != 0:
for _, err in all_failures:
print_to_stderr(err)
# A disabled test is expected to fail, so there is no need to report a failure here

View File

@ -3,14 +3,16 @@ import sys
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
sys.path.append(str(REPO_ROOT))
from tools.stats.import_test_stats import get_test_times
from tools.stats.import_test_stats import get_test_file_ratings, get_test_times
TEST_TIMES_FILE = ".pytorch-test-times.json"
TEST_FILE_RATINGS_FILE = ".pytorch-test-file-ratings.json"
def main() -> None:
print(f"Exporting test times from test-infra to {TEST_TIMES_FILE}")
get_test_times(str(REPO_ROOT), filename=TEST_TIMES_FILE)
get_test_file_ratings(str(REPO_ROOT), filename=TEST_FILE_RATINGS_FILE)
if __name__ == "__main__":

View File

@ -20,6 +20,7 @@ IGNORE_DISABLED_ISSUES: List[str] = get_disabled_issues()
SLOW_TESTS_FILE = ".pytorch-slow-tests.json"
DISABLED_TESTS_FILE = ".pytorch-disabled-tests.json"
FILE_CACHE_LIFESPAN_SECONDS = datetime.timedelta(hours=3).seconds
@ -116,3 +117,12 @@ def get_disabled_tests(
except Exception:
print("Couldn't download test skip set, leaving all tests enabled...")
return {}
def get_test_file_ratings(dirpath: str, filename: str) -> Optional[Dict[str, Any]]:
url = "https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/file_test_rating.json"
try:
return fetch_and_cache(dirpath, filename, url, lambda x: x)
except Exception:
print("Couldn't download test file ratings file, not reordering...")
return {}

View File

@ -263,7 +263,7 @@ class EnvVarMetric:
value = os.environ.get(self.env_var)
if value is None and self.required:
raise ValueError(
f"Missing {self.name}. Please set the {self.env_var}"
f"Missing {self.name}. Please set the {self.env_var} "
"environment variable to pass in this value."
)
if self.type_conversion_fn:

View File

@ -394,28 +394,24 @@ class TestParsePrevTests(unittest.TestCase):
"tools.testing.test_selections._get_modified_tests",
return_value={"test2", "test4"},
)
@mock.patch(
"tools.testing.test_selections._get_file_rating_tests", return_value=["test1"]
)
def test_get_reordered_tests(
self, mock_get_prev_failing_tests: Any, mock_get_modified_tests: Any
self,
mock_get_prev_failing_tests: Any,
mock_get_modified_tests: Any,
mock_get_file_rating_tests: Any,
) -> None:
tests = [
ShardedTest(name="test1", shard=1, num_shards=2, time=600.0),
ShardedTest(name="test2", shard=1, num_shards=2, time=500.0),
ShardedTest(name="test3", shard=1, num_shards=2, time=400.0),
ShardedTest(name="test4", shard=1, num_shards=2, time=300.0),
ShardedTest(name="test5", shard=1, num_shards=2, time=200.0),
]
tests = ["test1", "test2", "test3", "test4", "test5"]
expected_prioritized_tests = {"test4", "test2"}
expected_remaining_tests = {"test1", "test3", "test5"}
expected_prioritized_tests = ["test4", "test2", "test1"]
expected_remaining_tests = {"test3", "test5"}
prioritized_tests, remaining_tests = get_reordered_tests(tests)
# Just want to check the names of the tests
prioritized_tests_name = {test.name for test in prioritized_tests}
remaining_tests_name = {test.name for test in remaining_tests}
self.assertSetEqual(expected_prioritized_tests, prioritized_tests_name)
self.assertSetEqual(expected_remaining_tests, remaining_tests_name)
self.assertListEqual(expected_prioritized_tests, prioritized_tests)
self.assertSetEqual(expected_remaining_tests, set(remaining_tests))
def test_compute_prioritization_time_savings_with_multiple_threads(self) -> None:
tests = [

View File

@ -3,16 +3,20 @@ import json
import math
import os
import subprocess
from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, List, NamedTuple, Optional, Set, Tuple
from typing import Callable, cast, Dict, List, NamedTuple, Optional, Set, Tuple
from warnings import warn
from tools.shared.logging_utils import duration_to_str, pluralize
from tools.stats.export_test_times import TEST_FILE_RATINGS_FILE
from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
from tools.stats.upload_stats_lib import emit_metric
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"
# NUM_PROCS_FOR_SHARDING_CALC must remain consistent across all shards of a job
@ -81,8 +85,8 @@ def get_with_pytest_shard(
) -> List[ShardedTest]:
sharded_tests: List[ShardedTest] = []
for test in tests:
duration = test_file_times[test]
if duration > THRESHOLD:
duration = test_file_times.get(test, None)
if duration and duration > THRESHOLD:
num_shards = math.ceil(duration / THRESHOLD)
for i in range(num_shards):
sharded_tests.append(
@ -98,20 +102,24 @@ def calculate_shards(
tests: List[str],
test_file_times: Dict[str, float],
must_serial: Optional[Callable[[str], bool]] = None,
sort_by_time: bool = True,
) -> List[Tuple[float, List[ShardedTest]]]:
must_serial = must_serial or (lambda x: True)
known_tests = [x for x in tests if x in test_file_times]
unknown_tests: List[str] = [x for x in tests if x not in known_tests]
known_tests = tests
unknown_tests = []
sorted_tests = sorted(
get_with_pytest_shard(known_tests, test_file_times),
key=lambda j: j.get_time(),
reverse=True,
)
if sort_by_time:
known_tests = [x for x in tests if x in test_file_times]
unknown_tests = [x for x in tests if x not in known_tests]
known_tests = get_with_pytest_shard(known_tests, test_file_times)
if sort_by_time:
known_tests = sorted(known_tests, key=lambda j: j.get_time(), reverse=True)
sharded_jobs: List[ShardJob] = [ShardJob() for _ in range(num_shards)]
for test in sorted_tests:
for test in known_tests:
if must_serial(test.name):
min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time())
min_sharded_job.serial.append(test)
@ -127,7 +135,7 @@ def calculate_shards(
return [job.convert_to_tuple() for job in sharded_jobs]
def _query_changed_test_files() -> List[str]:
def _query_changed_files() -> List[str]:
default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
merge_base = (
subprocess.check_output(["git", "merge-base", default_branch, "HEAD"])
@ -186,7 +194,7 @@ def _parse_prev_failing_test_files(last_failed_tests: Dict[str, bool]) -> Set[st
def _get_modified_tests() -> Set[str]:
try:
changed_files = _query_changed_test_files()
changed_files = _query_changed_files()
except Exception as e:
warn(f"Can't query changed test files due to {e}")
# If unable to get changed files from git, quit without doing any sorting
@ -271,76 +279,81 @@ def log_time_savings(
return max_time_savings_sec
def _get_file_rating_tests() -> List[str]:
path = REPO_ROOT / TEST_FILE_RATINGS_FILE
if not os.path.exists(path):
print(f"could not find path {path}")
return []
with open(path) as f:
test_file_ratings = cast(Dict[str, Dict[str, float]], json.load(f))
try:
changed_files = _query_changed_files()
except Exception as e:
warn(f"Can't query changed test files due to {e}")
return []
ratings: Dict[str, float] = defaultdict(float)
for file in changed_files:
for test_file, score in test_file_ratings.get(file, {}).items():
ratings[test_file] += score
prioritize = sorted(ratings, key=lambda x: ratings[x])
return prioritize
def get_reordered_tests(
tests: List[ShardedTest],
) -> Tuple[List[ShardedTest], List[ShardedTest]]:
tests: List[str],
) -> Tuple[List[str], List[str]]:
"""
Get the reordered test filename list based on github PR history or git changed file.
We prioritize running test files that were changed.
"""
prioritized_tests: List[str] = []
def print_tests(tests: Set[str], test_group_description: str) -> None:
if not tests:
def add_tests(tests_to_add: List[str], test_group_description: str) -> None:
if not tests_to_add:
return
print(f"{test_group_description}:")
for test in tests:
print(f" {test}")
for test in tests_to_add:
if test in tests:
print(f" {test}")
if test not in prioritized_tests:
prioritized_tests.append(test)
prioritized_tests: Set[str] = set()
pri_test = _get_previously_failing_tests()
print_tests(
pri_test, "If run, these tests will prioritized because they previously failed"
add_tests(
sorted(_get_previously_failing_tests()),
"If run, these tests will prioritized because they previously failed",
)
prioritized_tests |= pri_test
pri_test |= _get_modified_tests()
print_tests(
pri_test, "If run, these tests will be prioritized because they were modified"
add_tests(
sorted(_get_modified_tests()),
"If run, these tests will be prioritized because they were modified",
)
prioritized_tests |= pri_test
bring_to_front = []
the_rest = []
add_tests(
_get_file_rating_tests(),
"If run, these tests will be preioritized for an experiment in TD",
)
for test in tests:
if test.name in prioritized_tests:
bring_to_front.append(test)
else:
the_rest.append(test)
prioritized_tests = [x for x in prioritized_tests if x in tests]
the_rest = [x for x in tests if x not in prioritized_tests]
if len(tests) != len(bring_to_front) + len(the_rest):
print(
f"Something went wrong in CI reordering, expecting total of {len(tests)}:\n"
f"but found prioritized: {len(bring_to_front)}\nthe rest: {len(the_rest)}\n"
)
return ([], tests)
prioritized_test_names = []
remaining_test_names = []
if bring_to_front:
if prioritized_tests:
test_cnt_str = pluralize(len(tests), "test")
print(f"Reordering tests: Prioritizing {len(bring_to_front)} of {test_cnt_str}")
prioritized_test_names = [t.name for t in bring_to_front]
print(f"Prioritized: {prioritized_test_names}")
remaining_test_names = [t.name for t in the_rest]
print(f"The Rest: {remaining_test_names}")
else:
print("Didn't find any tests to prioritize")
print(
f"Reordering tests: Prioritizing {len(prioritized_tests)} of {test_cnt_str}"
)
emit_metric(
"test_reordering_prioritized_tests",
{
"prioritized_test_cnt": len(bring_to_front),
"prioritized_test_cnt": len(prioritized_tests),
"total_test_cnt": len(tests),
"prioritized_tests": prioritized_test_names,
"remaining_tests": remaining_test_names,
"prioritized_tests": prioritized_tests,
"remaining_tests": the_rest,
},
)
return (bring_to_front, the_rest)
return (prioritized_tests, the_rest)
def get_test_case_configs(dirpath: str) -> None: