mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Companion with https://github.com/pytorch/test-infra/pull/4424. Uses the file ratings generated by the test-infra PR to reorder tests. For each test file, sum the file ratings from the changed files in the PR, and order the tests by that sum. A lot of tests will probably end up "prioritized", since anything with a rating > 0 is currently included. Sharding is done twice: once on the prioritized tests and once on the general (non-prioritized) tests. Prioritized tests have an order, so they should be sharded according to that order, while general tests have no order and are sharded by test time, which should result in more balanced shards. I'll change the metric name before I merge; I want to quarantine my testing data from the actual results. Pull Request resolved: https://github.com/pytorch/pytorch/pull/106347 Approved by: https://github.com/ZainRizvi
362 lines
12 KiB
Python
362 lines
12 KiB
Python
import heapq
|
|
import json
|
|
import math
|
|
import os
|
|
import subprocess
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
|
|
from typing import Callable, cast, Dict, List, NamedTuple, Optional, Set, Tuple
|
|
from warnings import warn
|
|
|
|
from tools.shared.logging_utils import duration_to_str, pluralize
|
|
from tools.stats.export_test_times import TEST_FILE_RATINGS_FILE
|
|
|
|
from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
|
|
from tools.stats.upload_stats_lib import emit_metric
|
|
|
|
REPO_ROOT = Path(__file__).resolve().parent.parent.parent

IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"

# NUM_PROCS_FOR_SHARDING_CALC must remain consistent across all shards of a job
# to ensure that sharding is consistent, NUM_PROCS is the actual number of procs
# used to run tests.  If they are not equal, the only consequence should be
# unequal shards.
IS_ROCM = os.path.exists("/opt/rocm")
NUM_PROCS = 1 if IS_MEM_LEAK_CHECK else 2
NUM_PROCS_FOR_SHARDING_CALC = NUM_PROCS if not IS_ROCM or IS_MEM_LEAK_CHECK else 2
THRESHOLD = 60 * 10  # 10 minutes

# See Note [ROCm parallel CI testing]
# Special logic for ROCm GHA runners to query number of GPUs available.
# torch.version.hip was not available to check if this was a ROCm self-hosted runner.
# Must check for ROCm runner in another way. We look for /opt/rocm directory.
if IS_ROCM and not IS_MEM_LEAK_CHECK:
    try:
        # This is the same logic used in GHA health check, see .github/templates/common.yml.j2
        lines = (
            subprocess.check_output(["rocminfo"], encoding="ascii").strip().split("\n")
        )
        # Each line containing " gfx" is counted as one GPU agent.
        count = sum(1 for line in lines if " gfx" in line)
        assert count > 0  # there must be at least 1 GPU
        # Limiting to 8 GPUs(PROCS)
        NUM_PROCS = min(count, 8)
    except (subprocess.CalledProcessError, FileNotFoundError):
        # The safe default for ROCm GHA runners is to run tests serially.
        # FileNotFoundError: /opt/rocm exists but the rocminfo binary is not
        # on PATH; previously this aborted the import instead of falling back.
        NUM_PROCS = 1
|
|
|
|
|
|
class ShardedTest(NamedTuple):
    """One piece of a test file scheduled on a CI shard.

    ``shard`` counts from 1 up to ``num_shards``; ``time`` is the estimated
    duration of this piece in seconds, or ``None`` when it is unknown.
    """

    name: str
    shard: int
    num_shards: int
    time: Optional[float]  # In seconds

    def __str__(self) -> str:
        return f"{self.name} {self.shard}/{self.num_shards}"

    def get_time(self) -> float:
        """Return the estimated duration, using 0 when ``time`` is falsy (e.g. None)."""
        return self.time if self.time else 0
|
|
|
|
|
|
class ShardJob:
    """The tests assigned to a single shard, split into serial and parallel lists."""

    def __init__(self) -> None:
        self.serial: List[ShardedTest] = []
        self.parallel: List[ShardedTest] = []

    def get_total_time(self) -> float:
        """Estimate this shard's wall-clock time in seconds.

        Parallel tests are greedily placed, in list order, on the least-loaded
        of NUM_PROCS_FOR_SHARDING_CALC simulated processes; the busiest
        process bounds the parallel phase.  The serial tests' times are then
        added on top.
        """
        loads = [0.0] * NUM_PROCS_FOR_SHARDING_CALC
        for job in self.parallel:
            # First index holding the smallest load, same as list.index(min(...)).
            idlest = min(range(len(loads)), key=loads.__getitem__)
            loads[idlest] += job.get_time()
        serial_time = sum(job.get_time() for job in self.serial)
        return max(loads) + serial_time

    def convert_to_tuple(self) -> Tuple[float, List[ShardedTest]]:
        """Return ``(estimated_total_time, serial + parallel tests)``."""
        every_test = self.serial + self.parallel
        return (self.get_total_time(), every_test)
|
|
|
|
|
|
def get_with_pytest_shard(
    tests: List[str], test_file_times: Dict[str, float]
) -> List[ShardedTest]:
    """Expand test file names into ShardedTest entries.

    A file whose recorded duration exceeds THRESHOLD is split into
    ``ceil(duration / THRESHOLD)`` equally long pytest shards.  Every other
    file becomes a single ``1/1`` entry, including a file with no recorded
    time (whose ``time`` stays ``None``).
    """
    result: List[ShardedTest] = []
    for name in tests:
        duration = test_file_times.get(name)
        if not (duration and duration > THRESHOLD):
            result.append(ShardedTest(name, 1, 1, duration))
            continue
        pieces = math.ceil(duration / THRESHOLD)
        per_piece = duration / pieces
        result.extend(
            ShardedTest(name, index, pieces, per_piece)
            for index in range(1, pieces + 1)
        )
    return result
|
|
|
|
|
|
def calculate_shards(
    num_shards: int,
    tests: List[str],
    test_file_times: Dict[str, float],
    must_serial: Optional[Callable[[str], bool]] = None,
    sort_by_time: bool = True,
) -> List[Tuple[float, List[ShardedTest]]]:
    """Distribute ``tests`` across ``num_shards`` shards.

    Args:
        num_shards: number of shards to produce; must be positive.
        tests: test file names to distribute.
        test_file_times: recorded duration in seconds per test file.
        must_serial: predicate returning True for tests that must run
            serially; when omitted every test is treated as serial.
        sort_by_time: when True, tests with a recorded time are placed
            longest first and tests without one are dealt round-robin at the
            end.  When False, ``tests`` are placed in the given order and a
            missing time counts as 0.

    Returns:
        One ``(estimated_time, tests)`` tuple per shard.

    Raises:
        ValueError: if ``num_shards`` is not positive.
    """
    if num_shards < 1:
        # The original code also failed with ValueError here (min() of an
        # empty sequence); make the cause explicit.
        raise ValueError(f"num_shards must be positive, got {num_shards}")

    must_serial = must_serial or (lambda x: True)

    if sort_by_time:
        # Partition via dict lookups; checking membership against the
        # known-tests list made this quadratic in len(tests).
        known_names = [x for x in tests if x in test_file_times]
        unknown_tests = [x for x in tests if x not in test_file_times]
    else:
        known_names = tests
        unknown_tests = []

    known_tests = get_with_pytest_shard(known_names, test_file_times)

    if sort_by_time:
        known_tests = sorted(known_tests, key=lambda j: j.get_time(), reverse=True)

    sharded_jobs: List[ShardJob] = [ShardJob() for _ in range(num_shards)]
    for test in known_tests:
        # Greedy balancing: each test goes to the currently cheapest shard.
        min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time())
        if must_serial(test.name):
            min_sharded_job.serial.append(test)
        else:
            min_sharded_job.parallel.append(test)

    # Round robin the unknown jobs starting with the smallest shard
    index = min(range(num_shards), key=lambda i: sharded_jobs[i].get_total_time())
    for unknown_test in unknown_tests:
        sharded_jobs[index].serial.append(ShardedTest(unknown_test, 1, 1, None))
        index = (index + 1) % num_shards

    return [job.convert_to_tuple() for job in sharded_jobs]
|
|
|
|
|
|
def _query_changed_files() -> List[str]:
    """Return repo-relative paths of files changed by the current commit(s).

    Diffs HEAD against its merge base with ``origin/<GIT_DEFAULT_BRANCH>``
    (``main`` when the variable is unset).  When that merge base is HEAD
    itself (we are on the default branch), diffs against ``HEAD^`` instead.

    Raises:
        subprocess.CalledProcessError: if ``git merge-base`` or
            ``git rev-parse`` fails.
        RuntimeError: if ``git diff`` fails.
    """
    default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
    merge_base = (
        subprocess.check_output(["git", "merge-base", default_branch, "HEAD"])
        .decode()
        .strip()
    )

    head = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip()

    base_commit = merge_base
    if base_commit == head:
        # We are on the default branch, so check for changes since the last commit
        base_commit = "HEAD^"

    proc = subprocess.run(
        ["git", "diff", "--name-only", base_commit, "HEAD"], capture_output=True
    )

    if proc.returncode != 0:
        stderr = proc.stderr.decode(errors="replace").strip()
        raise RuntimeError(f"Unable to get changed files: {stderr}")

    # An empty diff decodes to "", which split("\n") turns into [""];
    # drop blank entries so callers get [] instead.
    lines = (line.strip() for line in proc.stdout.decode().split("\n"))
    return [line for line in lines if line]
|
|
|
|
|
|
def _get_previously_failing_tests() -> Set[str]:
    """Return names of test files that failed in the last recorded pytest run.

    Reads pytest's ``lastfailed`` cache relative to the current working
    directory.  A missing or unreadable cache is not fatal: a warning is
    issued and an empty set is returned, so no tests are prioritized for
    this reason.
    """
    PYTEST_FAILED_TESTS_CACHE_FILE_PATH = Path(".pytest_cache/v/cache/lastfailed")

    if not PYTEST_FAILED_TESTS_CACHE_FILE_PATH.exists():
        warn(
            f"No pytest cache found at {PYTEST_FAILED_TESTS_CACHE_FILE_PATH.absolute()}"
        )
        return set()

    try:
        with open(PYTEST_FAILED_TESTS_CACHE_FILE_PATH) as f:
            last_failed_tests = json.load(f)
    except (OSError, ValueError) as e:
        # json.JSONDecodeError and UnicodeDecodeError both subclass
        # ValueError.  A truncated or corrupt cache should not abort test
        # reordering, just as a missing one doesn't.
        warn(
            f"Unable to read pytest cache at "
            f"{PYTEST_FAILED_TESTS_CACHE_FILE_PATH.absolute()}: {e}"
        )
        return set()

    prioritized_tests = _parse_prev_failing_test_files(last_failed_tests)

    return _python_test_file_to_test_name(prioritized_tests)
|
|
|
|
|
|
def _parse_prev_failing_test_files(last_failed_tests: Dict[str, bool]) -> Set[str]:
|
|
prioritized_tests = set()
|
|
|
|
# The keys are formatted as "test_file.py::test_class::test_method[params]"
|
|
# We just need the test_file part
|
|
for test in last_failed_tests:
|
|
parts = test.split("::")
|
|
if len(parts) > 1:
|
|
test_file = parts[0]
|
|
prioritized_tests.add(test_file)
|
|
|
|
return prioritized_tests
|
|
|
|
|
|
def _get_modified_tests() -> Set[str]:
    """Return names of test files touched by the current change.

    Best effort: if the changed files cannot be queried, a warning is
    issued and an empty set is returned.
    """
    try:
        changed_files = _query_changed_files()
    except Exception as e:
        warn(f"Can't query changed test files due to {e}")
        # Without the changed-file list there is nothing to prioritize.
        return set()
    else:
        return _python_test_file_to_test_name(set(changed_files))
|
|
|
|
|
|
def _python_test_file_to_test_name(tests: Set[str]) -> Set[str]:
|
|
prefix = f"test{os.path.sep}"
|
|
valid_tests = {f for f in tests if f.startswith(prefix) and f.endswith(".py")}
|
|
valid_tests = {f[len(prefix) : -len(".py")] for f in valid_tests}
|
|
|
|
return valid_tests
|
|
|
|
|
|
class PoolTimes:
    """Simulated timeline for a pool of test worker processes.

    ``pool_times`` holds each worker's busy time, kept as a min-heap so index
    0 is the worker that frees up first.  ``serial_times`` accumulates the
    time of tests that run one at a time.
    """

    def __init__(self, num_procs: int) -> None:
        # A list of equal values is already a valid heap.
        self.pool_times = [0.0] * num_procs
        self.serial_times = 0.0

    def next_test_start_time(self, serial: bool) -> float:
        """Return when the next test of the given kind would start."""
        if not serial:
            # Heap root: the worker with the least scheduled time.
            return self.pool_times[0]
        # Serial tests are run after all parallel tests complete
        return max(self.pool_times) + self.serial_times

    def schedule_test(self, test: ShardedTest, serial: bool) -> None:
        """Add ``test``'s time to the serial track or to the idlest worker."""
        duration = test.get_time()
        if not serial:
            # Replace the idlest worker's time with its new, later free time.
            heapq.heappushpop(self.pool_times, self.pool_times[0] + duration)
            return
        self.serial_times += duration
|
|
|
|
|
|
def log_time_savings(
    selected_tests: List[ShardedTest],
    prioritized_tests: List[ShardedTest],
    is_serial_test_fn: Callable[[str], bool],
    num_procs: int = NUM_PROCS,  # make this customizable for testing
) -> float:
    """Estimate how much sooner prioritized tests start because of reordering.

    Replays ``selected_tests`` in order on two simulated pools of
    ``num_procs`` workers.  The baseline pool receives every test, i.e. the
    unprioritized timeline.  The reordered pool receives only the prioritized
    tests.  Each prioritized test's saving is the difference between its
    start times on the two pools.

    Prints the estimate, emits it as the ``test_reordering_time_savings``
    metric, and returns it in seconds.
    """
    # Each test is assumed to land on whichever worker frees up first.  Real
    # scheduling can place a test differently, but this is a good
    # approximation.
    baseline = PoolTimes(num_procs)
    reordered = PoolTimes(num_procs)
    max_time_savings_sec = 0.0

    prioritized_names = set(t.name for t in prioritized_tests)

    for test in selected_tests:
        serial = is_serial_test_fn(test.name)
        if test.name in prioritized_names:
            # Assumes successive prioritized tests save at least as much time
            # as earlier ones, so the latest difference is kept as the maximum.
            baseline_start = baseline.next_test_start_time(serial)
            reordered_start = reordered.next_test_start_time(serial)
            max_time_savings_sec = baseline_start - reordered_start
            reordered.schedule_test(test, serial)
        # The baseline always gets the test, recreating the unprioritized run.
        baseline.schedule_test(test, serial)

    print(
        f"Prioritized tests will run about {duration_to_str(max_time_savings_sec)} sooner than they would've otherwise"
    )

    emit_metric(
        "test_reordering_time_savings",
        {"time_savings_sec": max_time_savings_sec},
    )

    # Return value used by tests
    return max_time_savings_sec
|
|
|
|
|
|
def _get_file_rating_tests() -> List[str]:
    """Rank test files by their relevance to the files changed in this PR.

    The ratings file (``REPO_ROOT / TEST_FILE_RATINGS_FILE``) maps a changed
    file to ``{test_file: score}``.  Each test file's scores are summed over
    all changed files, and test files are returned highest total first.

    Returns an empty list if the ratings file is missing or the changed
    files cannot be queried.
    """
    path = REPO_ROOT / TEST_FILE_RATINGS_FILE
    if not os.path.exists(path):
        print(f"could not find path {path}")
        return []
    with open(path) as f:
        test_file_ratings = cast(Dict[str, Dict[str, float]], json.load(f))
    try:
        changed_files = _query_changed_files()
    except Exception as e:
        warn(f"Can't query changed test files due to {e}")
        return []
    ratings: Dict[str, float] = defaultdict(float)
    for file in changed_files:
        for test_file, score in test_file_ratings.get(file, {}).items():
            ratings[test_file] += score
    # The caller treats list order as priority order, so the most relevant
    # (highest combined rating) test files must come first.  An ascending
    # sort put the least relevant ones first.
    prioritize = sorted(ratings, key=lambda x: ratings[x], reverse=True)
    return prioritize
|
|
|
|
|
|
def get_reordered_tests(
    tests: List[str],
) -> Tuple[List[str], List[str]]:
    """
    Get the reordered test filename list based on github PR history or git changed file.
    We prioritize running test files that were changed.

    Prioritized tests are gathered, in order, from:
      1. tests that failed in the last pytest run,
      2. tests whose files were modified,
      3. tests ranked by the file-rating experiment.
    Only names present in ``tests`` are kept, and each appears at most once.

    Returns:
        ``(prioritized_tests, the_rest)``, where ``the_rest`` preserves the
        original order of ``tests``.
    """
    prioritized_tests: List[str] = []
    # Sets keep membership checks O(1); the lists alone made this quadratic.
    selected = set(tests)
    already_prioritized: Set[str] = set()

    def add_tests(tests_to_add: List[str], test_group_description: str) -> None:
        if not tests_to_add:
            return

        print(f"{test_group_description}:")
        for test in tests_to_add:
            if test in selected:
                print(f" {test}")
                if test not in already_prioritized:
                    already_prioritized.add(test)
                    prioritized_tests.append(test)

    add_tests(
        sorted(_get_previously_failing_tests()),
        "If run, these tests will be prioritized because they previously failed",
    )

    add_tests(
        sorted(_get_modified_tests()),
        "If run, these tests will be prioritized because they were modified",
    )

    add_tests(
        _get_file_rating_tests(),
        "If run, these tests will be prioritized for an experiment in TD",
    )

    # add_tests only admits names present in `tests`, so prioritized_tests
    # needs no further filtering.
    the_rest = [x for x in tests if x not in already_prioritized]

    if prioritized_tests:
        test_cnt_str = pluralize(len(tests), "test")
        print(
            f"Reordering tests: Prioritizing {len(prioritized_tests)} of {test_cnt_str}"
        )

        emit_metric(
            "test_reordering_prioritized_tests",
            {
                "prioritized_test_cnt": len(prioritized_tests),
                "total_test_cnt": len(tests),
                "prioritized_tests": prioritized_tests,
                "remaining_tests": the_rest,
            },
        )

    return (prioritized_tests, the_rest)
|
|
|
|
|
|
def get_test_case_configs(dirpath: str) -> None:
    """Call ``get_slow_tests`` and then ``get_disabled_tests`` with ``dirpath``.

    Presumably this stores the slow/disabled test configs under ``dirpath``.
    That is an unverified assumption: both helpers are defined in
    tools.stats.import_test_stats and are not visible here.
    """
    for fetch_config in (get_slow_tests, get_disabled_tests):
        fetch_config(dirpath=dirpath)
|