Files
pytorch/tools/testing/test_selections.py
Zain Rizvi 95f191a248 Always run prioritized tests first, even if they're expected to run serially (#100748)
Today, we prioritize running test files that were edited in the user's PR, with the idea being to run them before we run any other test.

Except, if the modified test is supposed to run serially, then we still end up running it after all the parallelized tests have finished running.

This PR fixes that to _always_ run the prioritized tests before the regular tests, regardless of whether the test is supposed to run serially or in parallel.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100748
Approved by: https://github.com/huydhn
2023-05-08 20:23:46 +00:00

178 lines
6.1 KiB
Python

import math
import os
import subprocess
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
# True when the PYTORCH_TEST_CUDA_MEM_LEAK_CHECK environment variable is "1";
# in that mode NUM_PROCS is pinned to 1 and the ROCm probe below is skipped.
IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"

# Number of processes ShardJob.get_total_time() spreads parallel tests over.
NUM_PROCS = 1 if IS_MEM_LEAK_CHECK else 2

# Duration (10 minutes, in seconds) above which get_with_pytest_shard()
# splits a test file into several pytest shards.
THRESHOLD = 60 * 10

# See Note [ROCm parallel CI testing]
# Special logic for ROCm GHA runners to query number of GPUs available.
# torch.version.hip was not available to check if this was a ROCm self-hosted runner.
# Must check for ROCm runner in another way. We look for /opt/rocm directory.
if os.path.exists("/opt/rocm") and not IS_MEM_LEAK_CHECK:
    try:
        # This is the same logic used in GHA health check, see .github/templates/common.yml.j2
        lines = (
            subprocess.check_output(["rocminfo"], encoding="ascii").strip().split("\n")
        )
    except (subprocess.CalledProcessError, OSError):
        # The safe default for ROCm GHA runners is to run tests serially.
        # OSError covers rocminfo being absent or not executable, which would
        # otherwise abort the import instead of taking this fallback.
        NUM_PROCS = 1
    else:
        count = sum(1 for line in lines if " gfx" in line)
        # There must be at least 1 GPU. This is an explicit raise rather than
        # an assert so the check survives `python -O`; a silent NUM_PROCS of 0
        # would break ShardJob.get_total_time().
        if count < 1:
            raise RuntimeError("rocminfo did not report any gfx device")
        # Limiting to 8 GPUs(PROCS)
        NUM_PROCS = min(count, 8)
class ShardedTest(NamedTuple):
    """One pytest shard of a test file together with its expected duration.

    As a NamedTuple, an instance is immutable and compares/unpacks as
    ``(name, shard, num_shards, time)``.
    """

    # Name of the test file this shard belongs to.
    name: str
    # Index of this shard; get_with_pytest_shard() numbers shards from 1.
    shard: int
    # Total number of shards the test file is split into.
    num_shards: int
    # Expected duration of this shard, or None when no timing is known.
    time: Optional[float]

    def __str__(self) -> str:
        """Render the shard as ``"<name> <shard>/<num_shards>"``."""
        return "{} {}/{}".format(self.name, self.shard, self.num_shards)

    def get_time(self) -> float:
        """Return the expected duration, or 0 when ``time`` is None (or zero)."""
        if self.time:
            return self.time
        return 0
class ShardJob:
    """The tests assigned to a single shard, kept in serial and parallel groups."""

    def __init__(self) -> None:
        # Tests whose durations are simply added together.
        self.serial: List[ShardedTest] = []
        # Tests whose durations are spread over NUM_PROCS processes.
        self.parallel: List[ShardedTest] = []

    def get_total_time(self) -> float:
        """Estimate how long this shard takes to run.

        Parallel tests are handed out, in list order, to whichever of the
        NUM_PROCS simulated processes currently has the least work; the
        busiest process determines the parallel time. The serial tests'
        durations are then added on top.
        """
        loads = [0.0] * NUM_PROCS
        for item in self.parallel:
            # min() returns the first smallest index, matching list.index(min(...)).
            lightest = min(range(len(loads)), key=loads.__getitem__)
            loads[lightest] += item.get_time()
        parallel_time = max(loads)
        serial_time = sum(item.get_time() for item in self.serial)
        return parallel_time + serial_time

    def convert_to_tuple(self) -> Tuple[float, List[ShardedTest]]:
        """Return ``(estimated_time, tests)`` with serial tests listed first."""
        estimated = self.get_total_time()
        ordered = self.serial + self.parallel
        return (estimated, ordered)
def get_with_pytest_shard(
    tests: List[str], test_file_times: Dict[str, float]
) -> List[ShardedTest]:
    """Expand test files into ShardedTest entries, splitting the long ones.

    A file whose recorded duration exceeds THRESHOLD is split into
    ``ceil(duration / THRESHOLD)`` shards, each given an equal share of the
    duration; any other file becomes a single ``1/1`` shard.

    Args:
        tests: Test file names; each one must be a key of ``test_file_times``.
        test_file_times: Recorded duration for each test file.

    Returns:
        The shards in the order of ``tests``, with a file's shards numbered
        from 1 and kept adjacent.

    Raises:
        KeyError: If a name in ``tests`` has no entry in ``test_file_times``.
    """
    result: List[ShardedTest] = []
    for name in tests:
        total = test_file_times[name]
        if total > THRESHOLD:
            pieces = math.ceil(total / THRESHOLD)
            result.extend(
                ShardedTest(name, idx, pieces, total / pieces)
                for idx in range(1, pieces + 1)
            )
            continue
        result.append(ShardedTest(name, 1, 1, total))
    return result
def calculate_shards(
    num_shards: int,
    tests: List[str],
    test_file_times: Dict[str, float],
    must_serial: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[float, List[ShardedTest]]]:
    """Distribute test files over ``num_shards`` shards of roughly equal run time.

    Tests with a recorded duration are expanded via get_with_pytest_shard(),
    sorted from longest to shortest, and each is placed on the shard whose
    estimated total time is currently the smallest. Tests without a recorded
    duration are then dealt out round-robin, starting at the lightest shard.

    Args:
        num_shards: Number of shards to produce; must be at least 1.
        tests: Names of the test files to distribute.
        test_file_times: Recorded duration per test file. Files missing from
            this mapping are treated as having unknown duration.
        must_serial: Predicate on the test name; tests for which it returns
            True go to the shard's serial list, others to its parallel list.
            When omitted, every test is treated as serial.

    Returns:
        One ``(estimated_time, tests)`` tuple per shard, in shard order.

    Raises:
        ValueError: If ``num_shards`` is less than 1.
    """
    if num_shards < 1:
        # Previously surfaced as an opaque "min() arg is an empty sequence".
        raise ValueError(f"num_shards must be at least 1, got {num_shards}")

    must_serial = must_serial or (lambda x: True)

    known_tests = [x for x in tests if x in test_file_times]
    # Look up in the dict directly: scanning the known_tests list for every
    # test would be quadratic and yields the same result.
    unknown_tests: List[str] = [x for x in tests if x not in test_file_times]

    sorted_tests = sorted(
        get_with_pytest_shard(known_tests, test_file_times),
        key=lambda j: j.get_time(),
        reverse=True,
    )

    sharded_jobs: List[ShardJob] = [ShardJob() for _ in range(num_shards)]
    for test in sorted_tests:
        run_serially = must_serial(test.name)
        min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time())
        if run_serially:
            min_sharded_job.serial.append(test)
        else:
            min_sharded_job.parallel.append(test)

    # Round robin the unknown jobs starting with the smallest shard
    index = min(range(num_shards), key=lambda i: sharded_jobs[i].get_total_time())
    for unknown_test in unknown_tests:
        sharded_jobs[index].serial.append(ShardedTest(unknown_test, 1, 1, None))
        index = (index + 1) % num_shards

    return [job.convert_to_tuple() for job in sharded_jobs]
def _query_changed_test_files() -> List[str]:
    """Return the paths that differ between the default branch and HEAD.

    Runs ``git diff --name-only origin/<branch> HEAD``, where ``<branch>`` is
    taken from the GIT_DEFAULT_BRANCH environment variable and defaults to
    ``main``.

    Returns:
        The whitespace-stripped, non-empty lines of git's output; an empty
        list when git reports no changed files.

    Raises:
        RuntimeError: If ``git diff`` exits with a non-zero status.
    """
    default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
    cmd = ["git", "diff", "--name-only", default_branch, "HEAD"]
    proc = subprocess.run(cmd, capture_output=True)

    if proc.returncode != 0:
        raise RuntimeError("Unable to get changed files")

    stripped = (line.strip() for line in proc.stdout.decode().strip().split("\n"))
    # "".split("\n") is [""], so drop blanks to make an empty diff come back
    # as [] rather than a list holding one empty path.
    return [line for line in stripped if line]
def get_reordered_tests(
    tests: List[ShardedTest],
) -> Tuple[List[ShardedTest], List[ShardedTest]]:
    """
    Split the tests into those whose test file was changed and all the others.

    The changed files come from _query_changed_test_files(). A changed path of
    the form ``test/<name>.py`` prioritizes every entry of ``tests`` whose
    ``name`` equals ``<name>``. Relative order within each returned list
    matches the order of ``tests``.

    Args:
        tests: The tests to partition.

    Returns:
        ``(prioritized, the_rest)``. If the changed files cannot be
        determined, returns ``([], tests)`` so that nothing is reordered.
    """
    try:
        changed_files = _query_changed_test_files()
    except Exception:
        # If unable to get changed files from git, quit without doing any sorting
        return ([], tests)

    # git reports paths with forward slashes on every platform, so an
    # os.path.sep-based prefix would never match on Windows.
    prefix = "test/"
    suffix = ".py"
    # A set makes the per-test membership check below constant time.
    prioritized_tests = {
        f[len(prefix) : -len(suffix)]
        for f in changed_files
        if f.startswith(prefix) and f.endswith(suffix)
    }
    print("Prioritized test from test file changes.")

    bring_to_front: List[ShardedTest] = []
    the_rest: List[ShardedTest] = []
    for test in tests:
        if test.name in prioritized_tests:
            bring_to_front.append(test)
        else:
            the_rest.append(test)

    # Every test lands in exactly one of the two lists, so no separate
    # length sanity check is needed before returning.
    print(
        f"reordering tests for PR:\n"
        f"prioritized: {bring_to_front}\nthe rest: {the_rest}\n"
    )
    return (bring_to_front, the_rest)
def get_test_case_configs(dirpath: str) -> None:
    """Call get_slow_tests and then get_disabled_tests for ``dirpath``.

    Both helpers are imported from tools.stats.import_test_stats and their
    return values are discarded, so this exists only for their side effects
    (presumably fetching/caching the lists under ``dirpath`` -- confirm
    against import_test_stats).
    """
    for fetch in (get_slow_tests, get_disabled_tests):
        fetch(dirpath=dirpath)