Reordering tests experiment (#106347)

Companion with https://github.com/pytorch/test-infra/pull/4424

Uses the file rating generated by the test infra PR to re order tests.  For each test file, sum the file ratings from the changed files in the PR, and put the tests in order of sum.

A lot of tests are probably going to end up as "prioritized" since it takes anything with a rating > 0 right now.

Sharding is done twice, once on the prioritized tests, and once on the general/non prioritized tests.  Prioritized tests have an order, so they should be sharded according to that order, while general tests don't have an order and are sharded by test time, which should result in more balanced shards.

I'll change the metric name before I merge, i want to quarantine my testing stuff from actual results

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106347
Approved by: https://github.com/ZainRizvi
This commit is contained in:
Catherine Lee
2023-08-09 20:11:09 +00:00
committed by PyTorch MergeBot
parent a44c072c89
commit 7dfab082be
6 changed files with 211 additions and 156 deletions

View File

@ -3,16 +3,24 @@ import json
import math
import os
import subprocess
from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, List, NamedTuple, Optional, Set, Tuple
from typing import Callable, cast, Dict, List, NamedTuple, Optional, Set, Tuple, Union
from warnings import warn
from tools.shared.logging_utils import duration_to_str, pluralize
from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
from tools.stats.import_test_stats import (
get_disabled_tests,
get_slow_tests,
get_test_file_ratings,
TEST_FILE_RATINGS_FILE,
)
from tools.stats.upload_stats_lib import emit_metric
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"
# NUM_PROCS_FOR_SHARDING_CALC must remain consistent across all shards of a job
@ -81,8 +89,8 @@ def get_with_pytest_shard(
) -> List[ShardedTest]:
sharded_tests: List[ShardedTest] = []
for test in tests:
duration = test_file_times[test]
if duration > THRESHOLD:
duration = test_file_times.get(test, None)
if duration and duration > THRESHOLD:
num_shards = math.ceil(duration / THRESHOLD)
for i in range(num_shards):
sharded_tests.append(
@ -98,20 +106,24 @@ def calculate_shards(
tests: List[str],
test_file_times: Dict[str, float],
must_serial: Optional[Callable[[str], bool]] = None,
sort_by_time: bool = True,
) -> List[Tuple[float, List[ShardedTest]]]:
must_serial = must_serial or (lambda x: True)
known_tests = [x for x in tests if x in test_file_times]
unknown_tests: List[str] = [x for x in tests if x not in known_tests]
known_tests = tests
unknown_tests = []
sorted_tests = sorted(
get_with_pytest_shard(known_tests, test_file_times),
key=lambda j: j.get_time(),
reverse=True,
)
if sort_by_time:
known_tests = [x for x in tests if x in test_file_times]
unknown_tests = [x for x in tests if x not in known_tests]
known_tests = get_with_pytest_shard(known_tests, test_file_times)
if sort_by_time:
known_tests = sorted(known_tests, key=lambda j: j.get_time(), reverse=True)
sharded_jobs: List[ShardJob] = [ShardJob() for _ in range(num_shards)]
for test in sorted_tests:
for test in known_tests:
if must_serial(test.name):
min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time())
min_sharded_job.serial.append(test)
@ -127,7 +139,7 @@ def calculate_shards(
return [job.convert_to_tuple() for job in sharded_jobs]
def _query_changed_test_files() -> List[str]:
def _query_changed_files() -> List[str]:
default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
merge_base = (
subprocess.check_output(["git", "merge-base", default_branch, "HEAD"])
@ -186,7 +198,7 @@ def _parse_prev_failing_test_files(last_failed_tests: Dict[str, bool]) -> Set[st
def _get_modified_tests() -> Set[str]:
try:
changed_files = _query_changed_test_files()
changed_files = _query_changed_files()
except Exception as e:
warn(f"Can't query changed test files due to {e}")
# If unable to get changed files from git, quit without doing any sorting
@ -271,78 +283,86 @@ def log_time_savings(
return max_time_savings_sec
def _get_file_rating_tests() -> List[str]:
path = REPO_ROOT / "test" / TEST_FILE_RATINGS_FILE
if not os.path.exists(path):
print(f"could not find path {path}")
return []
with open(path) as f:
test_file_ratings = cast(Dict[str, Dict[str, float]], json.load(f))
try:
changed_files = _query_changed_files()
except Exception as e:
warn(f"Can't query changed test files due to {e}")
return []
ratings: Dict[str, float] = defaultdict(float)
for file in changed_files:
for test_file, score in test_file_ratings.get(file, {}).items():
ratings[test_file] += score
prioritize = sorted(ratings, key=lambda x: ratings[x])
return prioritize
def get_reordered_tests(
tests: List[ShardedTest],
) -> Tuple[List[ShardedTest], List[ShardedTest]]:
tests: List[str],
) -> Tuple[List[str], List[str]]:
"""
Get the reordered test filename list based on github PR history or git changed file.
We prioritize running test files that were changed.
"""
prioritized_tests: List[str] = []
def print_tests(tests: Set[str], test_group_description: str) -> None:
if not tests:
def add_tests(
tests_to_add: Union[List[str], Set[str]], test_group_description: str
) -> None:
if not tests_to_add:
return
print(f"{test_group_description}:")
for test in tests:
print(f" {test}")
for test in tests_to_add:
if test in tests:
print(f" {test}")
if test not in prioritized_tests:
prioritized_tests.append(test)
prioritized_tests: Set[str] = set()
pri_test = _get_previously_failing_tests()
print_tests(
pri_test, "If run, these tests will prioritized because they previously failed"
add_tests(
_get_previously_failing_tests(),
"If run, these tests will prioritized because they previously failed",
)
prioritized_tests |= pri_test
pri_test |= _get_modified_tests()
print_tests(
pri_test, "If run, these tests will be prioritized because they were modified"
add_tests(
_get_modified_tests(),
"If run, these tests will be prioritized because they were modified",
)
prioritized_tests |= pri_test
bring_to_front = []
the_rest = []
add_tests(
_get_file_rating_tests(),
"If run, these tests will be preioritized for an experiment in TD",
)
for test in tests:
if test.name in prioritized_tests:
bring_to_front.append(test)
else:
the_rest.append(test)
prioritized_tests = [x for x in prioritized_tests if x in tests]
the_rest = [x for x in tests if x not in prioritized_tests]
if len(tests) != len(bring_to_front) + len(the_rest):
print(
f"Something went wrong in CI reordering, expecting total of {len(tests)}:\n"
f"but found prioritized: {len(bring_to_front)}\nthe rest: {len(the_rest)}\n"
)
return ([], tests)
prioritized_test_names = []
remaining_test_names = []
if bring_to_front:
if prioritized_tests:
test_cnt_str = pluralize(len(tests), "test")
print(f"Reordering tests: Prioritizing {len(bring_to_front)} of {test_cnt_str}")
prioritized_test_names = [t.name for t in bring_to_front]
print(f"Prioritized: {prioritized_test_names}")
remaining_test_names = [t.name for t in the_rest]
print(f"The Rest: {remaining_test_names}")
else:
print("Didn't find any tests to prioritize")
print(
f"Reordering tests: Prioritizing {len(prioritized_tests)} of {test_cnt_str}"
)
emit_metric(
"test_reordering_prioritized_tests",
{
"prioritized_test_cnt": len(bring_to_front),
"prioritized_test_cnt": len(prioritized_tests),
"total_test_cnt": len(tests),
"prioritized_tests": prioritized_test_names,
"remaining_tests": remaining_test_names,
"prioritized_tests": prioritized_tests,
"remaining_tests": the_rest,
},
)
return (bring_to_front, the_rest)
return (prioritized_tests, the_rest)
def get_test_case_configs(dirpath: str) -> None:
get_slow_tests(dirpath=dirpath)
get_disabled_tests(dirpath=dirpath)
get_test_file_ratings(dirpath=dirpath)