Port existing heuristics to TD framework (#107071)
This PR looks big, but it's mostly just refactorings with a bit of dead code deletion.

Exceptions are:
- Some metric emissions were changed to comply with the new TD format
- Some logging changes
- We now run tests in three batches (highly_relevant, probably_relevant, unranked_relevance) instead of the previous two (prioritized and general)

Refactorings done:
- Moves all test reordering code to the new TD framework
- Refactors run_test.py to cleanly support multiple levels of test priorities
- Deletes some dead code that was originally written for logging

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107071
Approved by: https://github.com/clee2000, https://github.com/huydhn
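As a rough illustration of the three-batch ordering described above, here is a minimal sketch; the function and variable names are hypothetical, not the actual run_test.py implementation:

from typing import Dict, List

def run_in_relevance_batches(tests: List[str], ranked: Dict[str, List[str]]) -> None:
    # Hypothetical sketch: run each relevance batch in order, most relevant first.
    # "ranked" maps batch name -> test files, as produced by some TD ranking step.
    selected = set(tests)
    for batch in ("highly_relevant", "probably_relevant", "unranked_relevance"):
        for test_file in ranked.get(batch, []):
            if test_file in selected:
                print(f"[{batch}] running {test_file}")  # stand-in for the real per-test runner

run_in_relevance_batches(
    ["test_torch", "test_nn", "test_jit"],
    {"highly_relevant": ["test_nn"], "unranked_relevance": ["test_torch", "test_jit"]},
)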
Committed by: PyTorch MergeBot
Parent: d7f943ec82
Commit: 36399d067a
@@ -1,19 +1,11 @@
-import heapq
-import json
 import math
 import os
 import subprocess
-from collections import defaultdict
 from pathlib import Path
-
-from typing import Callable, cast, Dict, List, NamedTuple, Optional, Set, Tuple
-from warnings import warn
-
-from tools.shared.logging_utils import duration_to_str, pluralize
-from tools.stats.export_test_times import TEST_FILE_RATINGS_FILE
+from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
+
 from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
-from tools.stats.upload_metrics import emit_metric
 
 
 REPO_ROOT = Path(__file__).resolve().parent.parent.parent
 
@@ -135,227 +127,6 @@ def calculate_shards(
     return [job.convert_to_tuple() for job in sharded_jobs]
 
 
-def _query_changed_files() -> List[str]:
-    default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
-    merge_base = (
-        subprocess.check_output(["git", "merge-base", default_branch, "HEAD"])
-        .decode()
-        .strip()
-    )
-
-    head = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip()
-
-    base_commit = merge_base
-    if base_commit == head:
-        # We are on the default branch, so check for changes since the last commit
-        base_commit = "HEAD^"
-
-    proc = subprocess.run(
-        ["git", "diff", "--name-only", base_commit, "HEAD"], capture_output=True
-    )
-
-    if proc.returncode != 0:
-        raise RuntimeError("Unable to get changed files")
-
-    lines = proc.stdout.decode().strip().split("\n")
-    lines = [line.strip() for line in lines]
-    return lines
-
-
-def _get_previously_failing_tests() -> Set[str]:
-    PYTEST_FAILED_TESTS_CACHE_FILE_PATH = Path(".pytest_cache/v/cache/lastfailed")
-
-    if not PYTEST_FAILED_TESTS_CACHE_FILE_PATH.exists():
-        warn(
-            f"No pytest cache found at {PYTEST_FAILED_TESTS_CACHE_FILE_PATH.absolute()}"
-        )
-        return set()
-
-    with open(PYTEST_FAILED_TESTS_CACHE_FILE_PATH) as f:
-        last_failed_tests = json.load(f)
-
-    prioritized_tests = _parse_prev_failing_test_files(last_failed_tests)
-    return _python_test_file_to_test_name(prioritized_tests)
-
-
-def _parse_prev_failing_test_files(last_failed_tests: Dict[str, bool]) -> Set[str]:
-    prioritized_tests = set()
-
-    # The keys are formatted as "test_file.py::test_class::test_method[params]"
-    # We just need the test_file part
-    for test in last_failed_tests:
-        parts = test.split("::")
-        if len(parts) > 1:
-            test_file = parts[0]
-            prioritized_tests.add(test_file)
-
-    return prioritized_tests
-
-
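For context, pytest's lastfailed cache read by the two helpers above is a JSON object keyed by full test IDs; a minimal sketch with made-up entries showing the same key-splitting:

last_failed = {
    "test/test_nn.py::TestNN::test_linear": True,           # made-up test IDs
    "test/test_torch.py::TestTorch::test_add[cpu]": True,
}
files = {t.split("::")[0] for t in last_failed if "::" in t}
print(files)  # {'test/test_nn.py', 'test/test_torch.py'}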
-def _get_modified_tests() -> Set[str]:
-    try:
-        changed_files = _query_changed_files()
-    except Exception as e:
-        warn(f"Can't query changed test files due to {e}")
-        # If unable to get changed files from git, quit without doing any sorting
-        return set()
-
-    return _python_test_file_to_test_name(set(changed_files))
-
-
-def _python_test_file_to_test_name(tests: Set[str]) -> Set[str]:
-    prefix = f"test{os.path.sep}"
-    valid_tests = {f for f in tests if f.startswith(prefix) and f.endswith(".py")}
-    valid_tests = {f[len(prefix) : -len(".py")] for f in valid_tests}
-
-    return valid_tests
-
-
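The helper above keeps only files under test/ and strips the prefix and the .py suffix; a small illustration with made-up paths (os.path.sep assumed to be "/" here):

tests = {"test/test_nn.py", "torch/nn/functional.py", "test/test_jit.py"}
prefix, suffix = "test/", ".py"
names = {f[len(prefix):-len(suffix)] for f in tests if f.startswith(prefix) and f.endswith(suffix)}
print(sorted(names))  # ['test_jit', 'test_nn'] -- the non-test file is filtered out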
-class PoolTimes:
-    def __init__(self, num_procs: int) -> None:
-        self.pool_times = [0.0 for _ in range(num_procs)]
-        self.serial_times = 0.0
-
-    def next_test_start_time(self, serial: bool) -> float:
-        if serial:
-            # Serial tests are run after all parallel tests complete
-            return max(self.pool_times) + self.serial_times
-
-        return self.pool_times[0]
-
-    def schedule_test(self, test: ShardedTest, serial: bool) -> None:
-        if serial:
-            self.serial_times += test.get_time()
-        else:
-            # pool_times[0] is always the thread with the least amount of time scheduled
-            heapq.heappushpop(self.pool_times, self.pool_times[0] + test.get_time())
-
-
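The heappushpop call above keeps pool_times a min-heap, so index 0 is always the least-loaded worker; a short worked example with made-up durations:

import heapq

pool = [0.0, 0.0]  # two idle workers
for duration in (5.0, 3.0, 2.0):
    # assign the next test to the least-loaded worker (pool[0]), then restore the heap
    heapq.heappushpop(pool, pool[0] + duration)
print(pool)  # [5.0, 5.0]: 5s on one worker, 3s + 2s on the other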
-def log_time_savings(
-    selected_tests: List[ShardedTest],
-    prioritized_tests: List[ShardedTest],
-    is_serial_test_fn: Callable[[str], bool],
-    num_procs: int = NUM_PROCS,  # make this customizable for testing
-) -> float:
-    # The tests will be run in [num_procs] parallel threads, so we assume each test
-    # is allocated to the thread that'll free up first.
-    # This isn't an exact match (since other factors could change which thread
-    # pool a test gets scheduled on) but it's a good approximation.
-
-    # Simulates the scheduled tests on each thread pool
-    default_pool = PoolTimes(num_procs)  # originally scheduled run
-    prioritized_pool = PoolTimes(num_procs)  # run for prioritized tests
-    max_time_savings_sec = 0.0
-
-    # It's easier to look up prioritized tests by name
-    prioritized_test_names = {test.name for test in prioritized_tests}
-
-    for test in selected_tests:
-        serial = is_serial_test_fn(test.name)
-        if test.name in prioritized_test_names:
-            # Successive tests will always have a greater time savings
-            max_time_savings_sec = default_pool.next_test_start_time(
-                serial
-            ) - prioritized_pool.next_test_start_time(serial)
-
-            # "schedule" this test on the prioritized pool to get time savings for future prioritized tests
-            prioritized_pool.schedule_test(test, serial)
-
-        # always schedule on the default pool to know what the unprioritized timeline would've looked like
-        default_pool.schedule_test(test, serial)
-
-    print(
-        f"Prioritized tests will run about {duration_to_str(max_time_savings_sec)} sooner than they would've otherwise"
-    )
-
-    emit_metric(
-        "test_reordering_time_savings",
-        {
-            "time_savings_sec": max_time_savings_sec,
-        },
-    )
-
-    # Return value used by tests
-    return max_time_savings_sec
-
-
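Concretely (made-up numbers): if a prioritized test would have started at t = 600 s in the originally scheduled order but starts at t = 30 s once moved ahead, its savings is 570 s; since later prioritized tests always save at least as much, the last value computed is the maximum, and it is emitted as time_savings_sec.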
-def _get_file_rating_tests() -> List[str]:
-    path = REPO_ROOT / TEST_FILE_RATINGS_FILE
-    if not os.path.exists(path):
-        print(f"could not find path {path}")
-        return []
-    with open(path) as f:
-        test_file_ratings = cast(Dict[str, Dict[str, float]], json.load(f))
-    try:
-        changed_files = _query_changed_files()
-    except Exception as e:
-        warn(f"Can't query changed test files due to {e}")
-        return []
-    ratings: Dict[str, float] = defaultdict(float)
-    for file in changed_files:
-        for test_file, score in test_file_ratings.get(file, {}).items():
-            ratings[test_file] += score
-    prioritize = sorted(ratings, key=lambda x: ratings[x])
-    return prioritize
-
-
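The rating step above sums, for each test file, the scores contributed by every changed source file, then sorts by ascending score; an illustrative run with made-up ratings:

from collections import defaultdict

test_file_ratings = {  # made-up source-file -> {test_file: score} mapping
    "torch/nn/modules/linear.py": {"test_nn": 0.9, "test_jit": 0.1},
    "torch/autograd/grad_mode.py": {"test_autograd": 0.8, "test_nn": 0.2},
}
ratings = defaultdict(float)
for changed in ("torch/nn/modules/linear.py", "torch/autograd/grad_mode.py"):
    for test_file, score in test_file_ratings.get(changed, {}).items():
        ratings[test_file] += score
print(sorted(ratings, key=lambda x: ratings[x]))  # ['test_jit', 'test_autograd', 'test_nn']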
-def get_reordered_tests(
-    tests: List[str],
-) -> Tuple[List[str], List[str]]:
-    """
-    Get the reordered test filename list based on GitHub PR history or git changed files.
-    We prioritize running test files that were changed.
-    """
-    prioritized_tests: List[str] = []
-
-    def add_tests(tests_to_add: List[str], test_group_description: str) -> None:
-        if not tests_to_add:
-            return
-
-        print(f"{test_group_description}:")
-        for test in tests_to_add:
-            if test in tests:
-                print(f"  {test}")
-                if test not in prioritized_tests:
-                    prioritized_tests.append(test)
-
-    add_tests(
-        sorted(_get_previously_failing_tests()),
-        "If run, these tests will be prioritized because they previously failed",
-    )
-
-    add_tests(
-        sorted(_get_modified_tests()),
-        "If run, these tests will be prioritized because they were modified",
-    )
-
-    add_tests(
-        _get_file_rating_tests(),
-        "If run, these tests will be prioritized for an experiment in TD",
-    )
-
-    prioritized_tests = [x for x in prioritized_tests if x in tests]
-    the_rest = [x for x in tests if x not in prioritized_tests]
-
-    if prioritized_tests:
-        test_cnt_str = pluralize(len(tests), "test")
-        print(
-            f"Reordering tests: Prioritizing {len(prioritized_tests)} of {test_cnt_str}"
-        )
-
-        emit_metric(
-            "test_reordering_prioritized_tests",
-            {
-                "prioritized_test_cnt": len(prioritized_tests),
-                "total_test_cnt": len(tests),
-                "prioritized_tests": prioritized_tests,
-                "remaining_tests": the_rest,
-            },
-        )
-
-    return (prioritized_tests, the_rest)
-
-
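Stripped of the printing and metric emission, the final partitioning above reduces to the following minimal sketch (made-up test names):

def reorder(tests, prioritized_candidates):
    # Same partitioning as get_reordered_tests, minus logging and metrics.
    prioritized = [t for t in prioritized_candidates if t in tests]
    the_rest = [t for t in tests if t not in prioritized]
    return prioritized, the_rest

print(reorder(["test_torch", "test_nn", "test_jit"], ["test_jit", "test_nn", "test_cuda"]))
# (['test_jit', 'test_nn'], ['test_torch']) -- test_cuda is dropped since it isn't selected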
 def get_test_case_configs(dirpath: str) -> None:
     get_slow_tests(dirpath=dirpath)
     get_disabled_tests(dirpath=dirpath)