Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
This PR looks big, but it is mostly refactoring with a bit of dead code deletion. Exceptions are:

- Some metric emissions were changed to comply with the new TD format
- Some logging changes
- We now run tests in three batches (highly_relevant, probably_relevant, unranked_relevance) instead of the previous two (prioritized and general); see the sketch below

Refactorings done:

- Moves all test reordering code to the new TD framework
- Refactors run_test.py to cleanly support multiple levels of test priorities
- Deletes some dead code that was originally written for logging

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107071
Approved by: https://github.com/clee2000, https://github.com/huydhn
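The three-batch split can be pictured with a small, purely hypothetical sketch; the function name, score dictionary, and 0.5 threshold below are invented for illustration and are not the TD framework's actual API.

# Hypothetical sketch only: bucket tests by an assumed per-test relevance score.
from typing import Dict, List, Tuple


def split_into_batches(
    tests: List[str], relevance: Dict[str, float]
) -> Tuple[List[str], List[str], List[str]]:
    # Tests with a score are ranked; tests missing from the score map are unranked.
    highly_relevant = [t for t in tests if relevance.get(t, 0.0) >= 0.5]
    probably_relevant = [t for t in tests if t in relevance and relevance[t] < 0.5]
    unranked_relevance = [t for t in tests if t not in relevance]
    return highly_relevant, probably_relevant, unranked_relevance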
133 lines · 4.7 KiB · Python
import math
import os
import subprocess
from pathlib import Path

from typing import Callable, Dict, List, NamedTuple, Optional, Tuple

from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests

REPO_ROOT = Path(__file__).resolve().parent.parent.parent

IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"

# NUM_PROCS_FOR_SHARDING_CALC must remain consistent across all shards of a job
# to ensure that sharding is consistent. NUM_PROCS is the actual number of procs
# used to run tests. If they are not equal, the only consequence should be
# unequal shards.
IS_ROCM = os.path.exists("/opt/rocm")
NUM_PROCS = 1 if IS_MEM_LEAK_CHECK else 2
NUM_PROCS_FOR_SHARDING_CALC = NUM_PROCS if not IS_ROCM or IS_MEM_LEAK_CHECK else 2
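# Test files whose recorded duration exceeds THRESHOLD (seconds) are split
# into multiple pytest shards; see get_with_pytest_shard below.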
THRESHOLD = 60 * 10 # 10 minutes

# See Note [ROCm parallel CI testing]
# Special logic for ROCm GHA runners to query number of GPUs available.
# torch.version.hip was not available to check if this was a ROCm self-hosted runner.
# Must check for ROCm runner in another way. We look for /opt/rocm directory.
if IS_ROCM and not IS_MEM_LEAK_CHECK:
    try:
        # This is the same logic used in GHA health check, see .github/templates/common.yml.j2
        lines = (
            subprocess.check_output(["rocminfo"], encoding="ascii").strip().split("\n")
        )
        count = 0
        for line in lines:
            if " gfx" in line:
                count += 1
        assert count > 0  # there must be at least 1 GPU
        # Limiting to 8 GPUs (PROCS)
        NUM_PROCS = 8 if count > 8 else count
    except subprocess.CalledProcessError as e:
        # The safe default for ROCm GHA runners is to run tests serially.
        NUM_PROCS = 1


class ShardedTest(NamedTuple):
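    """One schedulable piece of work: a test file, or a single pytest shard of
    a test file that was split up (shard ``shard`` of ``num_shards``), with its
    estimated duration in seconds (None when unknown)."""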
    name: str
    shard: int
    num_shards: int
    time: Optional[float]  # In seconds

    def __str__(self) -> str:
        return f"{self.name} {self.shard}/{self.num_shards}"

    def get_time(self) -> float:
        return self.time or 0


class ShardJob:
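    """The tests assigned to one CI shard, split into tests that must run
    serially and tests that may run in parallel."""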
    def __init__(self) -> None:
        self.serial: List[ShardedTest] = []
        self.parallel: List[ShardedTest] = []

    def get_total_time(self) -> float:
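        # Estimate this shard's wall-clock time: greedily assign each parallel
        # test to the least-loaded of NUM_PROCS_FOR_SHARDING_CALC workers, then
        # add the serial tests, which run one after another.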
        procs = [0.0 for _ in range(NUM_PROCS_FOR_SHARDING_CALC)]
        for test in self.parallel:
            min_index = procs.index(min(procs))
            procs[min_index] += test.get_time()
        time = max(procs) + sum(test.get_time() for test in self.serial)
        return time

    def convert_to_tuple(self) -> Tuple[float, List[ShardedTest]]:
        return (self.get_total_time(), self.serial + self.parallel)


def get_with_pytest_shard(
    tests: List[str], test_file_times: Dict[str, float]
) -> List[ShardedTest]:
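    """Expand each test into one or more ShardedTests. A test whose recorded
    duration exceeds THRESHOLD is split into ceil(duration / THRESHOLD) pytest
    shards of equal estimated time; all other tests become a single 1/1 shard
    (with time=None if the duration is unknown)."""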
    sharded_tests: List[ShardedTest] = []
    for test in tests:
        duration = test_file_times.get(test, None)
        if duration and duration > THRESHOLD:
            num_shards = math.ceil(duration / THRESHOLD)
            for i in range(num_shards):
                sharded_tests.append(
                    ShardedTest(test, i + 1, num_shards, duration / num_shards)
                )
        else:
            sharded_tests.append(ShardedTest(test, 1, 1, duration))
    return sharded_tests


def calculate_shards(
    num_shards: int,
    tests: List[str],
    test_file_times: Dict[str, float],
    must_serial: Optional[Callable[[str], bool]] = None,
    sort_by_time: bool = True,
) -> List[Tuple[float, List[ShardedTest]]]:
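    """Distribute tests across num_shards ShardJobs and return one
    (estimated total time, tests) tuple per shard. When sort_by_time is True,
    tests with recorded times are placed greedily, longest first, onto the
    currently smallest shard, and tests with no recorded time are
    round-robined across shards afterwards."""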
    must_serial = must_serial or (lambda x: True)

    known_tests = tests
    unknown_tests = []

    if sort_by_time:
        known_tests = [x for x in tests if x in test_file_times]
        unknown_tests = [x for x in tests if x not in known_tests]

    known_tests = get_with_pytest_shard(known_tests, test_file_times)

    if sort_by_time:
        known_tests = sorted(known_tests, key=lambda j: j.get_time(), reverse=True)

    sharded_jobs: List[ShardJob] = [ShardJob() for _ in range(num_shards)]
    for test in known_tests:
        min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time())
        if must_serial(test.name):
            min_sharded_job.serial.append(test)
        else:
            min_sharded_job.parallel.append(test)

    # Round robin the unknown tests starting with the smallest shard
    index = min(range(num_shards), key=lambda i: sharded_jobs[i].get_total_time())
    for unknown_test in unknown_tests:
        sharded_jobs[index].serial.append(ShardedTest(unknown_test, 1, 1, None))
        index = (index + 1) % num_shards
    return [job.convert_to_tuple() for job in sharded_jobs]


def get_test_case_configs(dirpath: str) -> None:
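    """Fetch the slow-test and disabled-test lists into dirpath (see
    tools.stats.import_test_stats) so the test runner can use them."""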
    get_slow_tests(dirpath=dirpath)
    get_disabled_tests(dirpath=dirpath)
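

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): the test names
# and timings below are invented to show how calculate_shards behaves. Files
# longer than THRESHOLD are split into pytest shards, known tests are packed
# greedily, and tests with no recorded time are round-robined with time=None.
if __name__ == "__main__":
    example_times = {
        "test_linalg": 900.0,  # > THRESHOLD (600s), so split into 2 pytest shards
        "test_nn": 300.0,
        "test_ops": 120.0,
    }
    example_shards = calculate_shards(
        num_shards=2,
        tests=["test_linalg", "test_nn", "test_ops", "test_brand_new"],
        test_file_times=example_times,
        must_serial=lambda name: name == "test_nn",  # force one file to run serially
    )
    for total_time, sharded_tests in example_shards:
        print(f"{total_time:.1f}s: {[str(t) for t in sharded_tests]}")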