mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Port existing heuristics to TD framework (#107071)
This PR looks big, but it's mostly just refactorings with a bit of dead code deletion. Exceptions are: - Some metric emissions were changed to comply with the new TD format - Some logging changes - We now run tests in three batches (highly_relevant, probably_relevant, unranked_relevance) instead of the previous two (prioritized and general) Refactorings done: - Moves all test reordering code to the new TD framework - Refactors run_test.py to cleanly support multiple levels of test priorities - Deletes some dead code that was originally written for logging Pull Request resolved: https://github.com/pytorch/pytorch/pull/107071 Approved by: https://github.com/clee2000, https://github.com/huydhn
This commit is contained in:
committed by
PyTorch MergeBot
parent
d7f943ec82
commit
36399d067a
@ -1,25 +1,15 @@
|
||||
import io
|
||||
import json
|
||||
import pathlib
|
||||
import random
|
||||
import sys
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Set, Tuple
|
||||
from unittest import mock
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
|
||||
try:
|
||||
# using tools/ to optimize test run.
|
||||
sys.path.append(str(REPO_ROOT))
|
||||
from tools.testing.test_selections import (
|
||||
_get_previously_failing_tests,
|
||||
calculate_shards,
|
||||
get_reordered_tests,
|
||||
log_time_savings,
|
||||
ShardedTest,
|
||||
THRESHOLD,
|
||||
)
|
||||
from tools.testing.test_selections import calculate_shards, ShardedTest, THRESHOLD
|
||||
except ModuleNotFoundError:
|
||||
print("Can't import required modules, exiting")
|
||||
exit(1)
|
||||
@ -338,173 +328,5 @@ class TestCalculateShards(unittest.TestCase):
|
||||
self.assertEqual(sorted_tests, [x.name for x in sorted_shard_tests])
|
||||
|
||||
|
||||
def mocked_file(contents: Dict[Any, Any]) -> io.IOBase:
|
||||
file_object = io.StringIO()
|
||||
json.dump(contents, file_object)
|
||||
file_object.seek(0)
|
||||
return file_object
|
||||
|
||||
|
||||
def never_serial(test_name: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class TestParsePrevTests(unittest.TestCase):
|
||||
@mock.patch("pathlib.Path.exists", return_value=False)
|
||||
def test_cache_does_not_exist(self, mock_exists: Any) -> None:
|
||||
expected_failing_test_files: Set[str] = set()
|
||||
|
||||
found_tests = _get_previously_failing_tests()
|
||||
|
||||
self.assertSetEqual(expected_failing_test_files, found_tests)
|
||||
|
||||
@mock.patch("pathlib.Path.exists", return_value=True)
|
||||
@mock.patch("builtins.open", return_value=mocked_file({"": True}))
|
||||
def test_empty_cache(self, mock_exists: Any, mock_open: Any) -> None:
|
||||
expected_failing_test_files: Set[str] = set()
|
||||
|
||||
found_tests = _get_previously_failing_tests()
|
||||
|
||||
self.assertSetEqual(expected_failing_test_files, found_tests)
|
||||
mock_open.assert_called()
|
||||
|
||||
lastfailed_with_multiple_tests_per_file = {
|
||||
"test/test_car.py::TestCar::test_num[17]": True,
|
||||
"test/test_car.py::TestBar::test_num[25]": True,
|
||||
"test/test_far.py::TestFar::test_fun_copy[17]": True,
|
||||
"test/test_bar.py::TestBar::test_fun_copy[25]": True,
|
||||
}
|
||||
|
||||
@mock.patch("pathlib.Path.exists", return_value=True)
|
||||
@mock.patch(
|
||||
"builtins.open",
|
||||
return_value=mocked_file(lastfailed_with_multiple_tests_per_file),
|
||||
)
|
||||
def test_dedupes_failing_test_files(self, mock_exists: Any, mock_open: Any) -> None:
|
||||
expected_failing_test_files = {"test_car", "test_bar", "test_far"}
|
||||
found_tests = _get_previously_failing_tests()
|
||||
|
||||
self.assertSetEqual(expected_failing_test_files, found_tests)
|
||||
|
||||
@mock.patch(
|
||||
"tools.testing.test_selections._get_previously_failing_tests",
|
||||
return_value={"test4"},
|
||||
)
|
||||
@mock.patch(
|
||||
"tools.testing.test_selections._get_modified_tests",
|
||||
return_value={"test2", "test4"},
|
||||
)
|
||||
@mock.patch(
|
||||
"tools.testing.test_selections._get_file_rating_tests", return_value=["test1"]
|
||||
)
|
||||
def test_get_reordered_tests(
|
||||
self,
|
||||
mock_get_prev_failing_tests: Any,
|
||||
mock_get_modified_tests: Any,
|
||||
mock_get_file_rating_tests: Any,
|
||||
) -> None:
|
||||
tests = ["test1", "test2", "test3", "test4", "test5"]
|
||||
|
||||
expected_prioritized_tests = ["test4", "test2", "test1"]
|
||||
expected_remaining_tests = {"test3", "test5"}
|
||||
|
||||
prioritized_tests, remaining_tests = get_reordered_tests(tests)
|
||||
|
||||
self.assertListEqual(expected_prioritized_tests, prioritized_tests)
|
||||
self.assertSetEqual(expected_remaining_tests, set(remaining_tests))
|
||||
|
||||
def test_compute_prioritization_time_savings_with_multiple_threads(self) -> None:
|
||||
tests = [
|
||||
ShardedTest(name="test1", shard=1, num_shards=2, time=7.0),
|
||||
ShardedTest(name="test2", shard=1, num_shards=2, time=5.0),
|
||||
ShardedTest(name="test3", shard=1, num_shards=2, time=4.0),
|
||||
ShardedTest(name="test4", shard=1, num_shards=2, time=3.0),
|
||||
ShardedTest(name="test5", shard=1, num_shards=2, time=2.0),
|
||||
ShardedTest(name="test6", shard=1, num_shards=2, time=1.0),
|
||||
]
|
||||
prioritized_tests = [
|
||||
test for test in tests if test.name in ["test4", "test5", "test8"]
|
||||
]
|
||||
|
||||
expected_time_savings = 9.0
|
||||
|
||||
time_savings = log_time_savings(
|
||||
tests, prioritized_tests, is_serial_test_fn=never_serial, num_procs=2
|
||||
)
|
||||
self.assertEqual(
|
||||
time_savings, expected_time_savings, "Received an unexpected time savings"
|
||||
)
|
||||
|
||||
def test_compute_prioritization_time_savings_with_multiple_threads_and_many_prioritized_tests(
|
||||
self,
|
||||
) -> None:
|
||||
tests = [
|
||||
ShardedTest(name="test1", shard=1, num_shards=2, time=4.0),
|
||||
ShardedTest(name="test2", shard=1, num_shards=2, time=3.0),
|
||||
ShardedTest(name="test3", shard=1, num_shards=2, time=2.0),
|
||||
ShardedTest(name="test4", shard=1, num_shards=2, time=3.0),
|
||||
ShardedTest(name="test5", shard=1, num_shards=2, time=4.0),
|
||||
ShardedTest(name="test6", shard=1, num_shards=2, time=3.0),
|
||||
ShardedTest(name="test7", shard=1, num_shards=2, time=5.0),
|
||||
]
|
||||
prioritized_tests = [
|
||||
test for test in tests if test.name in ["test2", "test3", "test7"]
|
||||
]
|
||||
|
||||
# Drawing out the math here since this is a complicated example
|
||||
|
||||
# Logic for original execution assuming 2 procs
|
||||
# Test | Proc 1 | Proc 2
|
||||
# test1 | 4 |
|
||||
# test2 | | 3
|
||||
# test3 | | 2
|
||||
# test4 | 3 |
|
||||
# test5 | | 4
|
||||
# test6 | 3 |
|
||||
# test7 | | 5 <- starts at time 9 ( 3 + 2 + 4)
|
||||
|
||||
# Logic for new execution's prioritized pool:
|
||||
# Test | Proc 1 | Proc 2
|
||||
# test3 | 2 |
|
||||
# test4 | | 3
|
||||
# test7 | 5 | <- now starts at time 2
|
||||
|
||||
# Time savings = 9 - 2 = 7
|
||||
|
||||
expected_time_savings = 7.0
|
||||
|
||||
time_savings = log_time_savings(
|
||||
tests, prioritized_tests, is_serial_test_fn=never_serial, num_procs=2
|
||||
)
|
||||
self.assertEqual(
|
||||
time_savings, expected_time_savings, "Received an unexpected time savings"
|
||||
)
|
||||
pass
|
||||
|
||||
def test_compute_prioritization_time_savings_with_serialized_test(self) -> None:
|
||||
tests = [
|
||||
ShardedTest(name="test1", shard=1, num_shards=2, time=7.0),
|
||||
ShardedTest(name="test2", shard=1, num_shards=2, time=5.0),
|
||||
ShardedTest(name="test3", shard=1, num_shards=2, time=4.0),
|
||||
ShardedTest(name="test4", shard=1, num_shards=2, time=3.0),
|
||||
ShardedTest(name="test5", shard=1, num_shards=2, time=2.0),
|
||||
ShardedTest(name="test6", shard=1, num_shards=2, time=1.0),
|
||||
]
|
||||
prioritized_tests = [test for test in tests if test.name in ["test3", "test6"]]
|
||||
|
||||
def serialized(test: str) -> bool:
|
||||
return test in ["test4", "test6"]
|
||||
|
||||
expected_time_savings = 8.0
|
||||
|
||||
time_savings = log_time_savings(
|
||||
tests, prioritized_tests, is_serial_test_fn=serialized, num_procs=2
|
||||
)
|
||||
self.assertEqual(
|
||||
time_savings, expected_time_savings, "Received an unexpected time savings"
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user