mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Reordering tests experiment (#106347)"
This reverts commit 7dfab082be9eaeeee95c7b0363e59c824c6a9009. Reverted https://github.com/pytorch/pytorch/pull/106347 on behalf of https://github.com/clee2000 because it probably broke sharding ([comment](https://github.com/pytorch/pytorch/pull/106347#issuecomment-1675542738)).
This commit is contained in:
@@ -3,24 +3,16 @@ import json
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
from typing import Callable, cast, Dict, List, NamedTuple, Optional, Set, Tuple, Union
|
||||
from typing import Callable, Dict, List, NamedTuple, Optional, Set, Tuple
|
||||
from warnings import warn
|
||||
|
||||
from tools.shared.logging_utils import duration_to_str, pluralize
|
||||
|
||||
from tools.stats.import_test_stats import (
|
||||
get_disabled_tests,
|
||||
get_slow_tests,
|
||||
get_test_file_ratings,
|
||||
TEST_FILE_RATINGS_FILE,
|
||||
)
|
||||
from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
|
||||
from tools.stats.upload_stats_lib import emit_metric
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
|
||||
IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"
|
||||
|
||||
# NUM_PROCS_FOR_SHARDING_CALC must remain consistent across all shards of a job
|
||||
@@ -89,8 +81,8 @@ def get_with_pytest_shard(
|
||||
) -> List[ShardedTest]:
|
||||
sharded_tests: List[ShardedTest] = []
|
||||
for test in tests:
|
||||
duration = test_file_times.get(test, None)
|
||||
if duration and duration > THRESHOLD:
|
||||
duration = test_file_times[test]
|
||||
if duration > THRESHOLD:
|
||||
num_shards = math.ceil(duration / THRESHOLD)
|
||||
for i in range(num_shards):
|
||||
sharded_tests.append(
|
||||
@@ -106,24 +98,20 @@ def calculate_shards(
|
||||
tests: List[str],
|
||||
test_file_times: Dict[str, float],
|
||||
must_serial: Optional[Callable[[str], bool]] = None,
|
||||
sort_by_time: bool = True,
|
||||
) -> List[Tuple[float, List[ShardedTest]]]:
|
||||
must_serial = must_serial or (lambda x: True)
|
||||
|
||||
known_tests = tests
|
||||
unknown_tests = []
|
||||
known_tests = [x for x in tests if x in test_file_times]
|
||||
unknown_tests: List[str] = [x for x in tests if x not in known_tests]
|
||||
|
||||
if sort_by_time:
|
||||
known_tests = [x for x in tests if x in test_file_times]
|
||||
unknown_tests = [x for x in tests if x not in known_tests]
|
||||
|
||||
known_tests = get_with_pytest_shard(known_tests, test_file_times)
|
||||
|
||||
if sort_by_time:
|
||||
known_tests = sorted(known_tests, key=lambda j: j.get_time(), reverse=True)
|
||||
sorted_tests = sorted(
|
||||
get_with_pytest_shard(known_tests, test_file_times),
|
||||
key=lambda j: j.get_time(),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
sharded_jobs: List[ShardJob] = [ShardJob() for _ in range(num_shards)]
|
||||
for test in known_tests:
|
||||
for test in sorted_tests:
|
||||
if must_serial(test.name):
|
||||
min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time())
|
||||
min_sharded_job.serial.append(test)
|
||||
@@ -139,7 +127,7 @@ def calculate_shards(
|
||||
return [job.convert_to_tuple() for job in sharded_jobs]
|
||||
|
||||
|
||||
def _query_changed_files() -> List[str]:
|
||||
def _query_changed_test_files() -> List[str]:
|
||||
default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
|
||||
merge_base = (
|
||||
subprocess.check_output(["git", "merge-base", default_branch, "HEAD"])
|
||||
@@ -198,7 +186,7 @@ def _parse_prev_failing_test_files(last_failed_tests: Dict[str, bool]) -> Set[st
|
||||
|
||||
def _get_modified_tests() -> Set[str]:
|
||||
try:
|
||||
changed_files = _query_changed_files()
|
||||
changed_files = _query_changed_test_files()
|
||||
except Exception as e:
|
||||
warn(f"Can't query changed test files due to {e}")
|
||||
# If unable to get changed files from git, quit without doing any sorting
|
||||
@@ -283,86 +271,78 @@ def log_time_savings(
|
||||
return max_time_savings_sec
|
||||
|
||||
|
||||
def _get_file_rating_tests() -> List[str]:
|
||||
path = REPO_ROOT / "test" / TEST_FILE_RATINGS_FILE
|
||||
if not os.path.exists(path):
|
||||
print(f"could not find path {path}")
|
||||
return []
|
||||
with open(path) as f:
|
||||
test_file_ratings = cast(Dict[str, Dict[str, float]], json.load(f))
|
||||
try:
|
||||
changed_files = _query_changed_files()
|
||||
except Exception as e:
|
||||
warn(f"Can't query changed test files due to {e}")
|
||||
return []
|
||||
ratings: Dict[str, float] = defaultdict(float)
|
||||
for file in changed_files:
|
||||
for test_file, score in test_file_ratings.get(file, {}).items():
|
||||
ratings[test_file] += score
|
||||
prioritize = sorted(ratings, key=lambda x: ratings[x])
|
||||
return prioritize
|
||||
|
||||
|
||||
def get_reordered_tests(
|
||||
tests: List[str],
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
tests: List[ShardedTest],
|
||||
) -> Tuple[List[ShardedTest], List[ShardedTest]]:
|
||||
"""
|
||||
Get the reordered test filename list based on github PR history or git changed file.
|
||||
We prioritize running test files that were changed.
|
||||
"""
|
||||
prioritized_tests: List[str] = []
|
||||
|
||||
def add_tests(
|
||||
tests_to_add: Union[List[str], Set[str]], test_group_description: str
|
||||
) -> None:
|
||||
if not tests_to_add:
|
||||
def print_tests(tests: Set[str], test_group_description: str) -> None:
|
||||
if not tests:
|
||||
return
|
||||
|
||||
print(f"{test_group_description}:")
|
||||
for test in tests_to_add:
|
||||
if test in tests:
|
||||
print(f" {test}")
|
||||
if test not in prioritized_tests:
|
||||
prioritized_tests.append(test)
|
||||
for test in tests:
|
||||
print(f" {test}")
|
||||
|
||||
add_tests(
|
||||
_get_previously_failing_tests(),
|
||||
"If run, these tests will prioritized because they previously failed",
|
||||
prioritized_tests: Set[str] = set()
|
||||
|
||||
pri_test = _get_previously_failing_tests()
|
||||
print_tests(
|
||||
pri_test, "If run, these tests will prioritized because they previously failed"
|
||||
)
|
||||
prioritized_tests |= pri_test
|
||||
|
||||
add_tests(
|
||||
_get_modified_tests(),
|
||||
"If run, these tests will be prioritized because they were modified",
|
||||
pri_test |= _get_modified_tests()
|
||||
print_tests(
|
||||
pri_test, "If run, these tests will be prioritized because they were modified"
|
||||
)
|
||||
prioritized_tests |= pri_test
|
||||
|
||||
add_tests(
|
||||
_get_file_rating_tests(),
|
||||
"If run, these tests will be preioritized for an experiment in TD",
|
||||
)
|
||||
bring_to_front = []
|
||||
the_rest = []
|
||||
|
||||
prioritized_tests = [x for x in prioritized_tests if x in tests]
|
||||
the_rest = [x for x in tests if x not in prioritized_tests]
|
||||
for test in tests:
|
||||
if test.name in prioritized_tests:
|
||||
bring_to_front.append(test)
|
||||
else:
|
||||
the_rest.append(test)
|
||||
|
||||
if prioritized_tests:
|
||||
test_cnt_str = pluralize(len(tests), "test")
|
||||
if len(tests) != len(bring_to_front) + len(the_rest):
|
||||
print(
|
||||
f"Reordering tests: Prioritizing {len(prioritized_tests)} of {test_cnt_str}"
|
||||
f"Something went wrong in CI reordering, expecting total of {len(tests)}:\n"
|
||||
f"but found prioritized: {len(bring_to_front)}\nthe rest: {len(the_rest)}\n"
|
||||
)
|
||||
return ([], tests)
|
||||
|
||||
prioritized_test_names = []
|
||||
remaining_test_names = []
|
||||
if bring_to_front:
|
||||
test_cnt_str = pluralize(len(tests), "test")
|
||||
print(f"Reordering tests: Prioritizing {len(bring_to_front)} of {test_cnt_str}")
|
||||
|
||||
prioritized_test_names = [t.name for t in bring_to_front]
|
||||
print(f"Prioritized: {prioritized_test_names}")
|
||||
remaining_test_names = [t.name for t in the_rest]
|
||||
print(f"The Rest: {remaining_test_names}")
|
||||
else:
|
||||
print("Didn't find any tests to prioritize")
|
||||
|
||||
emit_metric(
|
||||
"test_reordering_prioritized_tests",
|
||||
{
|
||||
"prioritized_test_cnt": len(prioritized_tests),
|
||||
"prioritized_test_cnt": len(bring_to_front),
|
||||
"total_test_cnt": len(tests),
|
||||
"prioritized_tests": prioritized_tests,
|
||||
"remaining_tests": the_rest,
|
||||
"prioritized_tests": prioritized_test_names,
|
||||
"remaining_tests": remaining_test_names,
|
||||
},
|
||||
)
|
||||
|
||||
return (prioritized_tests, the_rest)
|
||||
return (bring_to_front, the_rest)
|
||||
|
||||
|
||||
def get_test_case_configs(dirpath: str) -> None:
|
||||
get_slow_tests(dirpath=dirpath)
|
||||
get_disabled_tests(dirpath=dirpath)
|
||||
get_test_file_ratings(dirpath=dirpath)
|
||||
|
||||
Reference in New Issue
Block a user