Port existing heuristics to TD framework (#107071)

This PR looks big, but it's mostly just refactorings with a bit of dead code deletion. Exceptions are:
- Some metric emissions were changed to comply with the new TD format
- Some logging changes
- We now run tests in three batches (highly_relevant, probably_relevant, unranked_relevance) instead of the previous two (prioritized and general)

Refactorings done:
- Moves all test reordering code to the new TD framework
- Refactors run_test.py to cleanly support multiple levels of test priorities
- Deletes some dead code that was originally written for logging
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107071
Approved by: https://github.com/clee2000, https://github.com/huydhn
This commit is contained in:
Zain Rizvi
2023-08-22 20:04:05 -05:00
committed by PyTorch MergeBot
parent d7f943ec82
commit 36399d067a
11 changed files with 431 additions and 497 deletions

View File

@ -1,25 +1,15 @@
import io
import json
import pathlib
import random
import sys
import unittest
from collections import defaultdict
from typing import Any, Dict, List, Set, Tuple
from unittest import mock
from typing import Dict, List, Tuple
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
try:
# using tools/ to optimize test run.
sys.path.append(str(REPO_ROOT))
from tools.testing.test_selections import (
_get_previously_failing_tests,
calculate_shards,
get_reordered_tests,
log_time_savings,
ShardedTest,
THRESHOLD,
)
from tools.testing.test_selections import calculate_shards, ShardedTest, THRESHOLD
except ModuleNotFoundError:
print("Can't import required modules, exiting")
exit(1)
@ -338,173 +328,5 @@ class TestCalculateShards(unittest.TestCase):
self.assertEqual(sorted_tests, [x.name for x in sorted_shard_tests])
def mocked_file(contents: Dict[Any, Any]) -> io.IOBase:
file_object = io.StringIO()
json.dump(contents, file_object)
file_object.seek(0)
return file_object
def never_serial(test_name: str) -> bool:
return False
class TestParsePrevTests(unittest.TestCase):
@mock.patch("pathlib.Path.exists", return_value=False)
def test_cache_does_not_exist(self, mock_exists: Any) -> None:
expected_failing_test_files: Set[str] = set()
found_tests = _get_previously_failing_tests()
self.assertSetEqual(expected_failing_test_files, found_tests)
@mock.patch("pathlib.Path.exists", return_value=True)
@mock.patch("builtins.open", return_value=mocked_file({"": True}))
def test_empty_cache(self, mock_exists: Any, mock_open: Any) -> None:
expected_failing_test_files: Set[str] = set()
found_tests = _get_previously_failing_tests()
self.assertSetEqual(expected_failing_test_files, found_tests)
mock_open.assert_called()
lastfailed_with_multiple_tests_per_file = {
"test/test_car.py::TestCar::test_num[17]": True,
"test/test_car.py::TestBar::test_num[25]": True,
"test/test_far.py::TestFar::test_fun_copy[17]": True,
"test/test_bar.py::TestBar::test_fun_copy[25]": True,
}
@mock.patch("pathlib.Path.exists", return_value=True)
@mock.patch(
"builtins.open",
return_value=mocked_file(lastfailed_with_multiple_tests_per_file),
)
def test_dedupes_failing_test_files(self, mock_exists: Any, mock_open: Any) -> None:
expected_failing_test_files = {"test_car", "test_bar", "test_far"}
found_tests = _get_previously_failing_tests()
self.assertSetEqual(expected_failing_test_files, found_tests)
@mock.patch(
"tools.testing.test_selections._get_previously_failing_tests",
return_value={"test4"},
)
@mock.patch(
"tools.testing.test_selections._get_modified_tests",
return_value={"test2", "test4"},
)
@mock.patch(
"tools.testing.test_selections._get_file_rating_tests", return_value=["test1"]
)
def test_get_reordered_tests(
self,
mock_get_prev_failing_tests: Any,
mock_get_modified_tests: Any,
mock_get_file_rating_tests: Any,
) -> None:
tests = ["test1", "test2", "test3", "test4", "test5"]
expected_prioritized_tests = ["test4", "test2", "test1"]
expected_remaining_tests = {"test3", "test5"}
prioritized_tests, remaining_tests = get_reordered_tests(tests)
self.assertListEqual(expected_prioritized_tests, prioritized_tests)
self.assertSetEqual(expected_remaining_tests, set(remaining_tests))
def test_compute_prioritization_time_savings_with_multiple_threads(self) -> None:
tests = [
ShardedTest(name="test1", shard=1, num_shards=2, time=7.0),
ShardedTest(name="test2", shard=1, num_shards=2, time=5.0),
ShardedTest(name="test3", shard=1, num_shards=2, time=4.0),
ShardedTest(name="test4", shard=1, num_shards=2, time=3.0),
ShardedTest(name="test5", shard=1, num_shards=2, time=2.0),
ShardedTest(name="test6", shard=1, num_shards=2, time=1.0),
]
prioritized_tests = [
test for test in tests if test.name in ["test4", "test5", "test8"]
]
expected_time_savings = 9.0
time_savings = log_time_savings(
tests, prioritized_tests, is_serial_test_fn=never_serial, num_procs=2
)
self.assertEqual(
time_savings, expected_time_savings, "Received an unexpected time savings"
)
def test_compute_prioritization_time_savings_with_multiple_threads_and_many_prioritized_tests(
self,
) -> None:
tests = [
ShardedTest(name="test1", shard=1, num_shards=2, time=4.0),
ShardedTest(name="test2", shard=1, num_shards=2, time=3.0),
ShardedTest(name="test3", shard=1, num_shards=2, time=2.0),
ShardedTest(name="test4", shard=1, num_shards=2, time=3.0),
ShardedTest(name="test5", shard=1, num_shards=2, time=4.0),
ShardedTest(name="test6", shard=1, num_shards=2, time=3.0),
ShardedTest(name="test7", shard=1, num_shards=2, time=5.0),
]
prioritized_tests = [
test for test in tests if test.name in ["test2", "test3", "test7"]
]
# Drawing out the math here since this is a complicated example
# Logic for original execution assuming 2 procs
# Test | Proc 1 | Proc 2
# test1 | 4 |
# test2 | | 3
# test3 | | 2
# test4 | 3 |
# test5 | | 4
# test6 | 3 |
# test7 | | 5 <- starts at time 9 ( 3 + 2 + 4)
# Logic for new execution's prioritized pool:
# Test | Proc 1 | Proc 2
# test3 | 2 |
# test4 | | 3
# test7 | 5 | <- now starts at time 2
# Time savings = 9 - 2 = 7
expected_time_savings = 7.0
time_savings = log_time_savings(
tests, prioritized_tests, is_serial_test_fn=never_serial, num_procs=2
)
self.assertEqual(
time_savings, expected_time_savings, "Received an unexpected time savings"
)
pass
def test_compute_prioritization_time_savings_with_serialized_test(self) -> None:
tests = [
ShardedTest(name="test1", shard=1, num_shards=2, time=7.0),
ShardedTest(name="test2", shard=1, num_shards=2, time=5.0),
ShardedTest(name="test3", shard=1, num_shards=2, time=4.0),
ShardedTest(name="test4", shard=1, num_shards=2, time=3.0),
ShardedTest(name="test5", shard=1, num_shards=2, time=2.0),
ShardedTest(name="test6", shard=1, num_shards=2, time=1.0),
]
prioritized_tests = [test for test in tests if test.name in ["test3", "test6"]]
def serialized(test: str) -> bool:
return test in ["test4", "test6"]
expected_time_savings = 8.0
time_savings = log_time_savings(
tests, prioritized_tests, is_serial_test_fn=serialized, num_procs=2
)
self.assertEqual(
time_savings, expected_time_savings, "Received an unexpected time savings"
)
pass
if __name__ == "__main__":
unittest.main()