Test Reordering: Run previously failing tests first (#101123)

Makes the CI prioritize running any test files that had a failing test in a previous iteration of the given PR.

A follow-up to https://github.com/pytorch/pytorch/pull/100522, which makes the `.pytest_cache` available for use here.

A concrete example:
1. Person A pushes a new commit and creates a PR.
2. 2 hours later, test_im_now_broken.py fails
3. Person A attempts to fix the test, but the test is actually still broken
4. The CI, seeing that test_im_now_broken.py had failed on a previous run, will now prioritize running that test first. Instead of waiting another 2 hours to get a signal, Person A only needs to wait ~15 minutes (which is how long it takes for tests to start running)

# Testing
I modified a file to make the tests invoking it fail and triggered CI twice with this failure.

First run: https://github.com/pytorch/pytorch/actions/runs/4963943209/jobs/8883800811
Test step took 1h 9m to run

Second run: https://github.com/pytorch/pytorch/actions/runs/4965016776/jobs/8885657992
Test step failed within 2m 27s

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101123
Approved by: https://github.com/malfet, https://github.com/huydhn
This commit is contained in:
Zain Rizvi
2023-05-16 19:57:48 +00:00
committed by PyTorch MergeBot
parent b5ed606a8b
commit b1474019a4
3 changed files with 206 additions and 25 deletions

View File

@ -0,0 +1,19 @@
def pluralize(count: int, singular_word: str, plural_word: str = "") -> str:
    """Return ``"<count> <word>"`` using the singular or plural form of the word.

    The singular form is used only when ``count`` is exactly 1. When no
    ``plural_word`` is supplied (or it is empty), the plural is formed by
    appending ``"s"`` to ``singular_word``.
    """
    if count == 1:
        word = singular_word
    else:
        # Fall back to the naive "+s" plural when none was given.
        word = plural_word or f"{singular_word}s"
    return f"{count} {word}"
def duration_to_str(seconds: float) -> str:
    """Format a duration given in seconds as a short string with one decimal.

    Durations below 0.00001 seconds (including negative values) render as
    ``"0s"``. Otherwise the value is shown in seconds, minutes or hours
    (``"12.3s"``, ``"4.5m"``, ``"1.2h"``).

    The unit is chosen from the value *as it will be displayed*: a duration
    that would round up to ``60.0`` of a unit is promoted to the next unit, so
    the result is never ``"60.0s"`` or ``"60.0m"``.
    """
    if seconds < 0.00001:
        return "0s"

    # Compare the rounded value so that e.g. 59.96s becomes "1.0m", not "60.0s".
    if round(seconds, 1) < 60:
        return f"{seconds:.1f}s"

    minutes = seconds / 60
    if round(minutes, 1) < 60:
        return f"{minutes:.1f}m"

    return f"{seconds / 3600:.1f}h"

View File

@ -1,15 +1,24 @@
import io
import json
import pathlib
import random
import sys
import unittest
from collections import defaultdict
from typing import Dict, List, Tuple
from typing import Any, Dict, List, Set, Tuple
from unittest import mock
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
try:
# using tools/ to optimize test run.
sys.path.append(str(REPO_ROOT))
from tools.testing.test_selections import calculate_shards, ShardedTest, THRESHOLD
from tools.testing.test_selections import (
_get_previously_failing_tests,
calculate_shards,
get_reordered_tests,
ShardedTest,
THRESHOLD,
)
except ModuleNotFoundError:
print("Can't import required modules, exiting")
exit(1)
@ -328,5 +337,81 @@ class TestCalculateShards(unittest.TestCase):
self.assertEqual(sorted_tests, [x.name for x in sorted_shard_tests])
def mocked_file(contents: Dict[Any, Any]) -> io.IOBase:
    """Return an in-memory text file holding ``contents`` serialized as JSON.

    The returned file is positioned at the start, ready to be read.
    """
    serialized = json.dumps(contents)
    return io.StringIO(serialized)
class TestParsePrevTests(unittest.TestCase):
    """Tests for reading the pytest ``lastfailed`` cache and reordering tests.

    NOTE: stacked ``mock.patch`` decorators are applied bottom-up, so the mock
    created by the decorator closest to the function is passed as the FIRST
    extra argument. The parameter names below follow that order.
    """

    @mock.patch("pathlib.Path.exists", return_value=False)
    def test_cache_does_not_exist(self, mock_exists: Any) -> None:
        expected_failing_test_files: Set[str] = set()

        found_tests = _get_previously_failing_tests()

        self.assertSetEqual(expected_failing_test_files, found_tests)

    @mock.patch("pathlib.Path.exists", return_value=True)
    @mock.patch("builtins.open", return_value=mocked_file({"": True}))
    def test_empty_cache(self, mock_open: Any, mock_exists: Any) -> None:
        expected_failing_test_files: Set[str] = set()

        found_tests = _get_previously_failing_tests()

        self.assertSetEqual(expected_failing_test_files, found_tests)
        # Make sure the cache file was actually opened and parsed.
        mock_open.assert_called()

    # Several failing test cases per file, to check files are reported once.
    lastfailed_with_multiple_tests_per_file = {
        "test/test_car.py::TestCar::test_num[17]": True,
        "test/test_car.py::TestBar::test_num[25]": True,
        "test/test_far.py::TestFar::test_fun_copy[17]": True,
        "test/test_bar.py::TestBar::test_fun_copy[25]": True,
    }

    @mock.patch("pathlib.Path.exists", return_value=True)
    @mock.patch(
        "builtins.open",
        return_value=mocked_file(lastfailed_with_multiple_tests_per_file),
    )
    def test_dedupes_failing_test_files(self, mock_open: Any, mock_exists: Any) -> None:
        expected_failing_test_files = {"test_car", "test_bar", "test_far"}

        found_tests = _get_previously_failing_tests()

        self.assertSetEqual(expected_failing_test_files, found_tests)

    @mock.patch(
        "tools.testing.test_selections._get_previously_failing_tests",
        return_value={"test4"},
    )
    @mock.patch(
        "tools.testing.test_selections._get_modified_tests",
        return_value={"test2", "test4"},
    )
    def test_get_reordered_tests(
        self, mock_get_modified_tests: Any, mock_get_prev_failing_tests: Any
    ) -> None:
        tests = [
            ShardedTest(name="test1", shard=1, num_shards=2, time=600.0),
            ShardedTest(name="test2", shard=1, num_shards=2, time=500.0),
            ShardedTest(name="test3", shard=1, num_shards=2, time=400.0),
            ShardedTest(name="test4", shard=1, num_shards=2, time=300.0),
            ShardedTest(name="test5", shard=1, num_shards=2, time=200.0),
        ]

        expected_prioritized_tests = {"test4", "test2"}
        expected_remaining_tests = {"test1", "test3", "test5"}

        prioritized_tests, remaining_tests = get_reordered_tests(tests)

        # Just want to check the names of the tests
        prioritized_tests_name = {test.name for test in prioritized_tests}
        remaining_tests_name = {test.name for test in remaining_tests}

        self.assertSetEqual(expected_prioritized_tests, prioritized_tests_name)
        self.assertSetEqual(expected_remaining_tests, remaining_tests_name)
if __name__ == "__main__":
unittest.main()

View File

@ -1,8 +1,13 @@
import json
import math
import os
import subprocess
from pathlib import Path
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
from typing import Callable, Dict, List, NamedTuple, Optional, Set, Tuple
from warnings import warn
from tools.shared.logging_utils import duration_to_str, pluralize
from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
@ -37,7 +42,7 @@ class ShardedTest(NamedTuple):
name: str
shard: int
num_shards: int
time: Optional[float]
time: Optional[float] # In seconds
def __str__(self) -> str:
return f"{self.name} {self.shard}/{self.num_shards}"
@ -133,6 +138,55 @@ def _query_changed_test_files() -> List[str]:
return lines
def _get_previously_failing_tests() -> Set[str]:
    """Return the names of test files recorded as failing in the pytest cache.

    Reads ``.pytest_cache/v/cache/lastfailed`` relative to the current working
    directory, extracts the test file of every recorded entry, and converts
    those file paths to test names.

    Reordering is only an optimization, so a missing, unreadable or malformed
    cache file produces a warning and an empty set instead of an exception.
    """
    PYTEST_FAILED_TESTS_CACHE_FILE_PATH = Path(".pytest_cache/v/cache/lastfailed")

    if not PYTEST_FAILED_TESTS_CACHE_FILE_PATH.exists():
        warn(
            f"No pytest cache found at {PYTEST_FAILED_TESTS_CACHE_FILE_PATH.absolute()}"
        )
        return set()

    try:
        with open(PYTEST_FAILED_TESTS_CACHE_FILE_PATH, "r") as f:
            last_failed_tests = json.load(f)
    except (OSError, ValueError) as e:
        # json.JSONDecodeError is a ValueError; OSError covers the file
        # disappearing or being unreadable between the check and the open.
        warn(
            f"Can't read pytest cache at {PYTEST_FAILED_TESTS_CACHE_FILE_PATH.absolute()} due to {e}"
        )
        return set()

    prioritized_tests = _parse_prev_failing_test_files(last_failed_tests)
    return _python_test_file_to_test_name(prioritized_tests)
def _parse_prev_failing_test_files(last_failed_tests: Dict[str, bool]) -> Set[str]:
    """Collect the distinct test file paths named by the given cache keys.

    Keys are formatted as "test_file.py::test_class::test_method[params]";
    only the part before the first "::" is kept. Keys that contain no "::"
    separator are ignored.
    """
    separator = "::"
    return {
        key.partition(separator)[0]
        for key in last_failed_tests
        if separator in key
    }
def _get_modified_tests() -> Set[str]:
    """Return the test names for the changed test files.

    The changed files come from ``_query_changed_test_files()``. If that query
    raises, a warning is emitted and an empty set is returned, so no tests are
    prioritized on this basis.
    """
    try:
        modified_paths = _query_changed_test_files()
    except Exception as e:
        warn(f"Can't query changed test files due to {e}")
        # If unable to get changed files from git, quit without doing any sorting
        return set()
    else:
        unique_paths = set(modified_paths)
        return _python_test_file_to_test_name(unique_paths)
def _python_test_file_to_test_name(tests: Set[str]) -> Set[str]:
    """Convert test file paths into test names.

    Only paths that live under the top-level ``test`` directory and end in
    ``.py`` are kept; the leading ``test/`` and trailing ``.py`` are stripped
    (``"test/distributed/test_x.py"`` -> ``"distributed/test_x"``).

    The ``/`` separator is always accepted in addition to the OS-specific one,
    because the paths handled here do not necessarily use the native separator
    (pytest node ids use ``/`` on every platform).
    """
    suffix = ".py"
    # Both prefixes have the same length, so one slice offset serves either.
    prefixes = ("test/", f"test{os.path.sep}")
    prefix_len = len(prefixes[0])

    return {
        f[prefix_len : -len(suffix)]
        for f in tests
        if f.startswith(prefixes) and f.endswith(suffix)
    }
def get_reordered_tests(
    tests: List[ShardedTest],
) -> Tuple[List[ShardedTest], List[ShardedTest]]:
    """
    Split ``tests`` into the tests to run first and the remaining tests.

    A test is prioritized when its name is reported by either
    ``_get_previously_failing_tests()`` or ``_get_modified_tests()``.
    Both returned lists keep the relative order of ``tests``.

    Returns:
        A ``(prioritized, the_rest)`` tuple. If the internal consistency check
        fails, ``([], tests)`` is returned so nothing is reordered.
    """

    def print_tests(tests: Set[str], test_group_description: str) -> None:
        if not tests:
            return

        print(f"{test_group_description}:")
        # Sorted so the log output is stable from run to run.
        for test in sorted(tests):
            print(f" {test}")

    previously_failing_tests = _get_previously_failing_tests()
    print_tests(
        previously_failing_tests,
        "If run, these tests will be prioritized because they previously failed",
    )

    # Kept in its own set: merging it into the previously-failing set would
    # mislabel those tests as "modified" in the log and mutate the other result.
    modified_tests = _get_modified_tests()
    print_tests(
        modified_tests,
        "If run, these tests will be prioritized because they were modified",
    )

    prioritized_tests: Set[str] = previously_failing_tests | modified_tests

    bring_to_front = []
    the_rest = []

    test_time_for_regular_tests_so_far = 0.0
    # how much sooner did we run prioritized tests compared to a naive ordering
    time_savings_sec = 0.0

    for test in tests:
        if test.name in prioritized_tests:
            bring_to_front.append(test)
            # Calculate approx time saved by reordering: everything that would
            # have run before this test in the original order.
            time_savings_sec = test_time_for_regular_tests_so_far
        else:
            the_rest.append(test)
            test_time_for_regular_tests_so_far += test.get_time()

    # Defensive consistency check; fall back to the original order on mismatch.
    if len(tests) != len(bring_to_front) + len(the_rest):
        print(
            f"Something went wrong in CI reordering, expecting total of {len(tests)}:\n"
            f"but found prioritized: {len(bring_to_front)}\nthe rest: {len(the_rest)}\n"
        )
        return ([], tests)

    # TODO: Would be great to upload these stats to RDS/Rockset!
    test_cnt_str = pluralize(len(tests), "test")
    print(f"Reordering tests: Prioritizing {len(bring_to_front)} of {test_cnt_str}")
    print(
        f"Prioritized tests estimated to run up to {duration_to_str(time_savings_sec)} sooner than they would've otherwise"
    )

    prioritized_test_names = [t.name for t in bring_to_front]
    print(f"Prioritized: {prioritized_test_names}")
    remaining_test_names = [t.name for t in the_rest]
    print(f"The Rest: {remaining_test_names}")

    return (bring_to_front, the_rest)
def get_test_case_configs(dirpath: str) -> None:
get_slow_tests(dirpath=dirpath)