mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Test Reordering: Run previously failing tests first (#101123)
Makes the CI prioritize running any test files that had a failing test in a previous iteration of the given PR. A follow up to https://github.com/pytorch/pytorch/pull/100522 which makes the `.pytest_cache` available to use here A concrete example: 1. Person A pushes a new commit and creates a PR. 2. 2 hours later, test_im_now_broken.py fails 3. Person A attempts to fix the test, but the test is actually still broken 4. The CI, seeing that test_im_now_broken.py had failed on a previous run, will now prioritize running that test first. Instead of waiting another 2 hours to get a signal, Person A only needs to wait ~15 minutes (which is how long it takes for tests to start running) # Testing I modified a file to make the tests invoking it fail and triggered CI twice with this failure. First run: https://github.com/pytorch/pytorch/actions/runs/4963943209/jobs/8883800811 Test step took 1h 9m to run Second run: https://github.com/pytorch/pytorch/actions/runs/4965016776/jobs/8885657992 Test step failed within 2m 27s Pull Request resolved: https://github.com/pytorch/pytorch/pull/101123 Approved by: https://github.com/malfet, https://github.com/huydhn
This commit is contained in:
committed by
PyTorch MergeBot
parent
b5ed606a8b
commit
b1474019a4
19
tools/shared/logging_utils.py
Normal file
19
tools/shared/logging_utils.py
Normal file
@ -0,0 +1,19 @@
|
||||
def pluralize(count: int, singular_word: str, plural_word: str = "") -> str:
    """Format a count together with an appropriately pluralized noun.

    A count of exactly 1 uses ``singular_word``; any other count uses
    ``plural_word``, defaulting to ``singular_word`` + "s" when no explicit
    plural is supplied.
    """
    if count == 1:
        return f"{count} {singular_word}"
    # Fall back to the regular "+s" plural unless an irregular form was given.
    plural_word = plural_word or f"{singular_word}s"
    return f"{count} {plural_word}"
|
||||
|
||||
|
||||
def duration_to_str(seconds: float) -> str:
    """Render a duration in seconds as a short human-readable string.

    Durations below 10 microseconds render as "0s"; below a minute as
    fractional seconds; below an hour as fractional minutes; otherwise as
    fractional hours (one decimal place in each case).
    """
    # Guard clauses from smallest to largest unit; each returns immediately.
    if seconds < 0.00001:
        return "0s"
    if seconds < 60:
        return f"{seconds:.1f}s"
    if seconds < 3600:
        return f"{seconds / 60:.1f}m"
    return f"{seconds / 3600:.1f}h"
|
@ -1,15 +1,24 @@
|
||||
import io
|
||||
import json
|
||||
import pathlib
|
||||
import random
|
||||
import sys
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Any, Dict, List, Set, Tuple
|
||||
from unittest import mock
|
||||
|
||||
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
|
||||
try:
|
||||
# using tools/ to optimize test run.
|
||||
sys.path.append(str(REPO_ROOT))
|
||||
from tools.testing.test_selections import calculate_shards, ShardedTest, THRESHOLD
|
||||
from tools.testing.test_selections import (
|
||||
_get_previously_failing_tests,
|
||||
calculate_shards,
|
||||
get_reordered_tests,
|
||||
ShardedTest,
|
||||
THRESHOLD,
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
print("Can't import required modules, exiting")
|
||||
exit(1)
|
||||
@ -328,5 +337,81 @@ class TestCalculateShards(unittest.TestCase):
|
||||
self.assertEqual(sorted_tests, [x.name for x in sorted_shard_tests])
|
||||
|
||||
|
||||
def mocked_file(contents: Dict[Any, Any]) -> io.IOBase:
    """Return an in-memory, read-ready file containing *contents* as JSON.

    The returned object is positioned at the start of the stream, so it can
    stand in for the result of ``open()`` in tests.
    """
    # StringIO seeded with an initial value starts with its position at 0,
    # so no explicit seek is needed.
    return io.StringIO(json.dumps(contents))
|
||||
|
||||
|
||||
class TestParsePrevTests(unittest.TestCase):
    """Tests for previously-failing-test detection and test reordering.

    NOTE: stacked ``mock.patch`` decorators inject their mocks bottom-up: the
    decorator closest to the function supplies the FIRST mock argument.  The
    parameter names below were previously swapped, which made
    ``mock_open.assert_called()`` in ``test_empty_cache`` assert on the
    ``Path.exists`` mock instead of ``open``; the orders are now correct.
    """

    @mock.patch("pathlib.Path.exists", return_value=False)
    def test_cache_does_not_exist(self, mock_exists: Any) -> None:
        # With no pytest cache file present, no tests should be reported.
        expected_failing_test_files: Set[str] = set()

        found_tests = _get_previously_failing_tests()

        self.assertSetEqual(expected_failing_test_files, found_tests)

    @mock.patch("pathlib.Path.exists", return_value=True)
    @mock.patch("builtins.open", return_value=mocked_file({"": True}))
    def test_empty_cache(self, mock_open: Any, mock_exists: Any) -> None:
        # A cache file with no real entries yields an empty set, and the
        # cache file must actually have been opened.
        expected_failing_test_files: Set[str] = set()

        found_tests = _get_previously_failing_tests()

        self.assertSetEqual(expected_failing_test_files, found_tests)
        mock_open.assert_called()

    # Multiple failing test cases spread across three files; duplicates within
    # a file must collapse to a single entry per file.
    lastfailed_with_multiple_tests_per_file = {
        "test/test_car.py::TestCar::test_num[17]": True,
        "test/test_car.py::TestBar::test_num[25]": True,
        "test/test_far.py::TestFar::test_fun_copy[17]": True,
        "test/test_bar.py::TestBar::test_fun_copy[25]": True,
    }

    @mock.patch("pathlib.Path.exists", return_value=True)
    @mock.patch(
        "builtins.open",
        return_value=mocked_file(lastfailed_with_multiple_tests_per_file),
    )
    def test_dedupes_failing_test_files(self, mock_open: Any, mock_exists: Any) -> None:
        expected_failing_test_files = {"test_car", "test_bar", "test_far"}
        found_tests = _get_previously_failing_tests()

        self.assertSetEqual(expected_failing_test_files, found_tests)

    @mock.patch(
        "tools.testing.test_selections._get_previously_failing_tests",
        return_value={"test4"},
    )
    @mock.patch(
        "tools.testing.test_selections._get_modified_tests",
        return_value={"test2", "test4"},
    )
    def test_get_reordered_tests(
        self, mock_get_modified_tests: Any, mock_get_prev_failing_tests: Any
    ) -> None:
        # test4 failed previously; test2 and test4 were modified -> both
        # should be moved to the front, the rest keep their relative order.
        tests = [
            ShardedTest(name="test1", shard=1, num_shards=2, time=600.0),
            ShardedTest(name="test2", shard=1, num_shards=2, time=500.0),
            ShardedTest(name="test3", shard=1, num_shards=2, time=400.0),
            ShardedTest(name="test4", shard=1, num_shards=2, time=300.0),
            ShardedTest(name="test5", shard=1, num_shards=2, time=200.0),
        ]

        expected_prioritized_tests = {"test4", "test2"}
        expected_remaining_tests = {"test1", "test3", "test5"}

        prioritized_tests, remaining_tests = get_reordered_tests(tests)

        # Just want to check the names of the tests
        prioritized_tests_name = {test.name for test in prioritized_tests}
        remaining_tests_name = {test.name for test in remaining_tests}

        self.assertSetEqual(expected_prioritized_tests, prioritized_tests_name)
        self.assertSetEqual(expected_remaining_tests, remaining_tests_name)
|
||||
|
||||
|
||||
# Allow running this test file directly; CI normally invokes it via discovery.
if __name__ == "__main__":
    unittest.main()
|
||||
|
@ -1,8 +1,13 @@
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
|
||||
from typing import Callable, Dict, List, NamedTuple, Optional, Set, Tuple
|
||||
from warnings import warn
|
||||
|
||||
from tools.shared.logging_utils import duration_to_str, pluralize
|
||||
|
||||
from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
|
||||
|
||||
@ -37,7 +42,7 @@ class ShardedTest(NamedTuple):
|
||||
name: str
|
||||
shard: int
|
||||
num_shards: int
|
||||
time: Optional[float]
|
||||
time: Optional[float] # In seconds
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.name} {self.shard}/{self.num_shards}"
|
||||
@ -133,6 +138,55 @@ def _query_changed_test_files() -> List[str]:
|
||||
return lines
|
||||
|
||||
|
||||
def _get_previously_failing_tests() -> Set[str]:
    """Return the test module names (e.g. "test_car") that failed on a
    previous pytest run, as recorded in pytest's last-failed cache.

    Returns an empty set (with a warning) when the cache file does not exist,
    e.g. on a fresh checkout or when no test has failed yet.
    """
    PYTEST_FAILED_TESTS_CACHE_FILE_PATH = Path(".pytest_cache/v/cache/lastfailed")

    if not PYTEST_FAILED_TESTS_CACHE_FILE_PATH.exists():
        # Fixed message: this is pytest's cache, not a "pytorch cache".
        warn(
            f"No pytest cache found at {PYTEST_FAILED_TESTS_CACHE_FILE_PATH.absolute()}"
        )
        return set()

    # The cache is a JSON object mapping pytest node ids to True.
    with open(PYTEST_FAILED_TESTS_CACHE_FILE_PATH, "r") as f:
        last_failed_tests = json.load(f)

    prioritized_tests = _parse_prev_failing_test_files(last_failed_tests)
    return _python_test_file_to_test_name(prioritized_tests)
|
||||
|
||||
|
||||
def _parse_prev_failing_test_files(last_failed_tests: Dict[str, bool]) -> Set[str]:
|
||||
prioritized_tests = set()
|
||||
|
||||
# The keys are formatted as "test_file.py::test_class::test_method[params]"
|
||||
# We just need the test_file part
|
||||
for test in last_failed_tests:
|
||||
parts = test.split("::")
|
||||
if len(parts) > 1:
|
||||
test_file = parts[0]
|
||||
prioritized_tests.add(test_file)
|
||||
|
||||
return prioritized_tests
|
||||
|
||||
|
||||
def _get_modified_tests() -> Set[str]:
    """Return the names of test modules touched by the current change.

    Falls back to an empty set (with a warning) when git cannot be queried,
    so reordering degrades gracefully to the original test order.
    """
    try:
        changed_files = _query_changed_test_files()
    except Exception as e:
        warn(f"Can't query changed test files due to {e}")
        # If unable to get changed files from git, quit without doing any sorting
        return set()
    else:
        # Only reached when the git query succeeded.
        return _python_test_file_to_test_name(set(changed_files))
|
||||
|
||||
|
||||
def _python_test_file_to_test_name(tests: Set[str]) -> Set[str]:
|
||||
prefix = f"test{os.path.sep}"
|
||||
valid_tests = {f for f in tests if f.startswith(prefix) and f.endswith(".py")}
|
||||
valid_tests = {f[len(prefix) : -len(".py")] for f in valid_tests}
|
||||
|
||||
return valid_tests
|
||||
|
||||
|
||||
def get_reordered_tests(
    tests: List[ShardedTest],
) -> Tuple[List[ShardedTest], List[ShardedTest]]:
    """
    Get the reordered test filename list based on github PR history or git changed file.
    We prioritize running test files that previously failed or were changed.

    Returns ``(prioritized, the_rest)``; relative order within each group
    matches the input. If the bookkeeping sanity check fails, returns
    ``([], tests)`` so callers fall back to the original ordering.
    """

    def print_tests(tests: Set[str], test_group_description: str) -> None:
        # Print one indented line per test under a heading; skip empty groups.
        if not tests:
            return

        print(f"{test_group_description}:")
        for test in tests:
            print(f"  {test}")

    prioritized_tests: Set[str] = set()

    failing_tests = _get_previously_failing_tests()
    print_tests(
        failing_tests,
        "If run, these tests will be prioritized because they previously failed",
    )
    prioritized_tests |= failing_tests

    # Use a fresh set so this printout lists only tests that were actually
    # modified (previously the failing set was reused here, so already-failing
    # tests leaked into the "modified" listing).
    modified_tests = _get_modified_tests()
    print_tests(
        modified_tests,
        "If run, these tests will be prioritized because they were modified",
    )
    prioritized_tests |= modified_tests

    bring_to_front = []
    the_rest = []

    # Accumulated runtime of non-prioritized tests seen so far; used to
    # estimate how much sooner prioritized tests run than under a naive order.
    test_time_for_regular_tests_so_far = 0.0
    # how much sooner did we run prioritized tests compared to a naive ordering
    time_savings_sec = 0.0

    for test in tests:
        if test.name in prioritized_tests:
            bring_to_front.append(test)
            # Calculate approx time saved by reordering: this test jumped
            # ahead of all regular-test time accumulated so far.
            time_savings_sec = test_time_for_regular_tests_so_far
        else:
            the_rest.append(test)
            test_time_for_regular_tests_so_far += test.get_time()

    # Sanity check: every input test must land in exactly one bucket.
    if len(tests) != len(bring_to_front) + len(the_rest):
        print(
            f"Something went wrong in CI reordering, expecting total of {len(tests)}:\n"
            f"but found prioritized: {len(bring_to_front)}\nthe rest: {len(the_rest)}\n"
        )
        return ([], tests)

    # TODO: Would be great to upload these stats to RDS/Rockset!
    test_cnt_str = pluralize(len(tests), "test")
    print(f"Reordering tests: Prioritizing {len(bring_to_front)} of {test_cnt_str}")
    print(
        f"Prioritized tests estimated to run up to {duration_to_str(time_savings_sec)} sooner than they would've otherwise"
    )

    prioritized_test_names = [t.name for t in bring_to_front]
    print(f"Prioritized: {prioritized_test_names}")
    remaining_test_names = [t.name for t in the_rest]
    print(f"The Rest: {remaining_test_names}")

    return (bring_to_front, the_rest)
|
||||
|
||||
|
||||
def get_test_case_configs(dirpath: str) -> None:
|
||||
get_slow_tests(dirpath=dirpath)
|
||||
|
Reference in New Issue
Block a user