mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Reordering tests experiment (#106347)
Companion with https://github.com/pytorch/test-infra/pull/4424 Uses the file rating generated by the test infra PR to re order tests. For each test file, sum the file ratings from the changed files in the PR, and put the tests in order of sum. A lot of tests are probably going to end up as "prioritized" since it takes anything with a rating > 0 right now. Sharding is done twice, once on the prioritized tests, and once on the general/non prioritized tests. Prioritized tests have an order, so they should be sharded according to that order, while general tests don't have an order and are sharded by test time, which should result in more balanced shards. I'll change the metric name before I merge, i want to quarantine my testing stuff from actual results Pull Request resolved: https://github.com/pytorch/pytorch/pull/106347 Approved by: https://github.com/ZainRizvi
This commit is contained in:
committed by
PyTorch MergeBot
parent
a44c072c89
commit
7dfab082be
@ -3,16 +3,24 @@ import json
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
from typing import Callable, Dict, List, NamedTuple, Optional, Set, Tuple
|
||||
from typing import Callable, cast, Dict, List, NamedTuple, Optional, Set, Tuple, Union
|
||||
from warnings import warn
|
||||
|
||||
from tools.shared.logging_utils import duration_to_str, pluralize
|
||||
|
||||
from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
|
||||
from tools.stats.import_test_stats import (
|
||||
get_disabled_tests,
|
||||
get_slow_tests,
|
||||
get_test_file_ratings,
|
||||
TEST_FILE_RATINGS_FILE,
|
||||
)
|
||||
from tools.stats.upload_stats_lib import emit_metric
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
|
||||
IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"
|
||||
|
||||
# NUM_PROCS_FOR_SHARDING_CALC must remain consistent across all shards of a job
|
||||
@ -81,8 +89,8 @@ def get_with_pytest_shard(
|
||||
) -> List[ShardedTest]:
|
||||
sharded_tests: List[ShardedTest] = []
|
||||
for test in tests:
|
||||
duration = test_file_times[test]
|
||||
if duration > THRESHOLD:
|
||||
duration = test_file_times.get(test, None)
|
||||
if duration and duration > THRESHOLD:
|
||||
num_shards = math.ceil(duration / THRESHOLD)
|
||||
for i in range(num_shards):
|
||||
sharded_tests.append(
|
||||
@ -98,20 +106,24 @@ def calculate_shards(
|
||||
tests: List[str],
|
||||
test_file_times: Dict[str, float],
|
||||
must_serial: Optional[Callable[[str], bool]] = None,
|
||||
sort_by_time: bool = True,
|
||||
) -> List[Tuple[float, List[ShardedTest]]]:
|
||||
must_serial = must_serial or (lambda x: True)
|
||||
|
||||
known_tests = [x for x in tests if x in test_file_times]
|
||||
unknown_tests: List[str] = [x for x in tests if x not in known_tests]
|
||||
known_tests = tests
|
||||
unknown_tests = []
|
||||
|
||||
sorted_tests = sorted(
|
||||
get_with_pytest_shard(known_tests, test_file_times),
|
||||
key=lambda j: j.get_time(),
|
||||
reverse=True,
|
||||
)
|
||||
if sort_by_time:
|
||||
known_tests = [x for x in tests if x in test_file_times]
|
||||
unknown_tests = [x for x in tests if x not in known_tests]
|
||||
|
||||
known_tests = get_with_pytest_shard(known_tests, test_file_times)
|
||||
|
||||
if sort_by_time:
|
||||
known_tests = sorted(known_tests, key=lambda j: j.get_time(), reverse=True)
|
||||
|
||||
sharded_jobs: List[ShardJob] = [ShardJob() for _ in range(num_shards)]
|
||||
for test in sorted_tests:
|
||||
for test in known_tests:
|
||||
if must_serial(test.name):
|
||||
min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time())
|
||||
min_sharded_job.serial.append(test)
|
||||
@ -127,7 +139,7 @@ def calculate_shards(
|
||||
return [job.convert_to_tuple() for job in sharded_jobs]
|
||||
|
||||
|
||||
def _query_changed_test_files() -> List[str]:
|
||||
def _query_changed_files() -> List[str]:
|
||||
default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
|
||||
merge_base = (
|
||||
subprocess.check_output(["git", "merge-base", default_branch, "HEAD"])
|
||||
@ -186,7 +198,7 @@ def _parse_prev_failing_test_files(last_failed_tests: Dict[str, bool]) -> Set[st
|
||||
|
||||
def _get_modified_tests() -> Set[str]:
|
||||
try:
|
||||
changed_files = _query_changed_test_files()
|
||||
changed_files = _query_changed_files()
|
||||
except Exception as e:
|
||||
warn(f"Can't query changed test files due to {e}")
|
||||
# If unable to get changed files from git, quit without doing any sorting
|
||||
@ -271,78 +283,86 @@ def log_time_savings(
|
||||
return max_time_savings_sec
|
||||
|
||||
|
||||
def _get_file_rating_tests() -> List[str]:
|
||||
path = REPO_ROOT / "test" / TEST_FILE_RATINGS_FILE
|
||||
if not os.path.exists(path):
|
||||
print(f"could not find path {path}")
|
||||
return []
|
||||
with open(path) as f:
|
||||
test_file_ratings = cast(Dict[str, Dict[str, float]], json.load(f))
|
||||
try:
|
||||
changed_files = _query_changed_files()
|
||||
except Exception as e:
|
||||
warn(f"Can't query changed test files due to {e}")
|
||||
return []
|
||||
ratings: Dict[str, float] = defaultdict(float)
|
||||
for file in changed_files:
|
||||
for test_file, score in test_file_ratings.get(file, {}).items():
|
||||
ratings[test_file] += score
|
||||
prioritize = sorted(ratings, key=lambda x: ratings[x])
|
||||
return prioritize
|
||||
|
||||
|
||||
def get_reordered_tests(
|
||||
tests: List[ShardedTest],
|
||||
) -> Tuple[List[ShardedTest], List[ShardedTest]]:
|
||||
tests: List[str],
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
Get the reordered test filename list based on github PR history or git changed file.
|
||||
We prioritize running test files that were changed.
|
||||
"""
|
||||
prioritized_tests: List[str] = []
|
||||
|
||||
def print_tests(tests: Set[str], test_group_description: str) -> None:
|
||||
if not tests:
|
||||
def add_tests(
|
||||
tests_to_add: Union[List[str], Set[str]], test_group_description: str
|
||||
) -> None:
|
||||
if not tests_to_add:
|
||||
return
|
||||
|
||||
print(f"{test_group_description}:")
|
||||
for test in tests:
|
||||
print(f" {test}")
|
||||
for test in tests_to_add:
|
||||
if test in tests:
|
||||
print(f" {test}")
|
||||
if test not in prioritized_tests:
|
||||
prioritized_tests.append(test)
|
||||
|
||||
prioritized_tests: Set[str] = set()
|
||||
|
||||
pri_test = _get_previously_failing_tests()
|
||||
print_tests(
|
||||
pri_test, "If run, these tests will prioritized because they previously failed"
|
||||
add_tests(
|
||||
_get_previously_failing_tests(),
|
||||
"If run, these tests will prioritized because they previously failed",
|
||||
)
|
||||
prioritized_tests |= pri_test
|
||||
|
||||
pri_test |= _get_modified_tests()
|
||||
print_tests(
|
||||
pri_test, "If run, these tests will be prioritized because they were modified"
|
||||
add_tests(
|
||||
_get_modified_tests(),
|
||||
"If run, these tests will be prioritized because they were modified",
|
||||
)
|
||||
prioritized_tests |= pri_test
|
||||
|
||||
bring_to_front = []
|
||||
the_rest = []
|
||||
add_tests(
|
||||
_get_file_rating_tests(),
|
||||
"If run, these tests will be preioritized for an experiment in TD",
|
||||
)
|
||||
|
||||
for test in tests:
|
||||
if test.name in prioritized_tests:
|
||||
bring_to_front.append(test)
|
||||
else:
|
||||
the_rest.append(test)
|
||||
prioritized_tests = [x for x in prioritized_tests if x in tests]
|
||||
the_rest = [x for x in tests if x not in prioritized_tests]
|
||||
|
||||
if len(tests) != len(bring_to_front) + len(the_rest):
|
||||
print(
|
||||
f"Something went wrong in CI reordering, expecting total of {len(tests)}:\n"
|
||||
f"but found prioritized: {len(bring_to_front)}\nthe rest: {len(the_rest)}\n"
|
||||
)
|
||||
return ([], tests)
|
||||
|
||||
prioritized_test_names = []
|
||||
remaining_test_names = []
|
||||
if bring_to_front:
|
||||
if prioritized_tests:
|
||||
test_cnt_str = pluralize(len(tests), "test")
|
||||
print(f"Reordering tests: Prioritizing {len(bring_to_front)} of {test_cnt_str}")
|
||||
|
||||
prioritized_test_names = [t.name for t in bring_to_front]
|
||||
print(f"Prioritized: {prioritized_test_names}")
|
||||
remaining_test_names = [t.name for t in the_rest]
|
||||
print(f"The Rest: {remaining_test_names}")
|
||||
else:
|
||||
print("Didn't find any tests to prioritize")
|
||||
print(
|
||||
f"Reordering tests: Prioritizing {len(prioritized_tests)} of {test_cnt_str}"
|
||||
)
|
||||
|
||||
emit_metric(
|
||||
"test_reordering_prioritized_tests",
|
||||
{
|
||||
"prioritized_test_cnt": len(bring_to_front),
|
||||
"prioritized_test_cnt": len(prioritized_tests),
|
||||
"total_test_cnt": len(tests),
|
||||
"prioritized_tests": prioritized_test_names,
|
||||
"remaining_tests": remaining_test_names,
|
||||
"prioritized_tests": prioritized_tests,
|
||||
"remaining_tests": the_rest,
|
||||
},
|
||||
)
|
||||
|
||||
return (bring_to_front, the_rest)
|
||||
return (prioritized_tests, the_rest)
|
||||
|
||||
|
||||
def get_test_case_configs(dirpath: str) -> None:
|
||||
get_slow_tests(dirpath=dirpath)
|
||||
get_disabled_tests(dirpath=dirpath)
|
||||
get_test_file_ratings(dirpath=dirpath)
|
||||
|
Reference in New Issue
Block a user