mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Changes sharding to attempt to put all serial tests on as few shards as possible. Parallel tests are then distributed across all shards, with most of them likely ending up on the non-serial shards.
Example: 8 minutes of serial tests, 20 minutes of parallel tests, 2 proc per machine, 6 machines
-> 8 + 20/2 = 18 total minutes of tests
-> 18 / 6 machines = 3 min per machine
-> all serial tests should fit on 3 machines (3 min, 3 min, 2 min)
-> majority of parallel tests should go on last 4 machines, one of which is shared with the serial tests
Move serial tests to run first
If I want to move to a purely numbers based sharding, this ensures that parallel tests are run with parallel tests as much as possible instead of interleaving serial + parallel tests, which decreases effectiveness of parallelization, while also ensuring that test reordering is still mostly effective.
See 73e816ee80
for example logs
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119078
Approved by: https://github.com/huydhn
406 lines
16 KiB
Python
406 lines
16 KiB
Python
import pathlib
|
|
import random
|
|
import sys
|
|
import unittest
|
|
from collections import defaultdict
|
|
from typing import Dict, List, Tuple
|
|
|
|
# Third ancestor directory of this file's resolved path (presumably the
# repository root, given that the ``tools`` package is imported from it below).
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
try:
    # using tools/ to optimize test run.
    # REPO_ROOT has to be on sys.path before the ``tools.*`` imports below run,
    # so the order of these statements matters.
    sys.path.append(str(REPO_ROOT))
    from tools.testing.test_run import ShardedTest, TestRun
    from tools.testing.test_selections import calculate_shards, THRESHOLD
except ModuleNotFoundError:
    # The helpers under test could not be imported: report it and exit with
    # status 1 instead of surfacing an import traceback.
    print("Can't import required modules, exiting")
    sys.exit(1)
|
|
|
|
|
|
def gen_class_times(test_times: Dict[str, float]) -> Dict[str, Dict[str, float]]:
    """Build a per-class timing table from per-file timings.

    Every file in ``test_times`` is mapped to a one-entry dict that attributes
    the file's whole duration to a single class named ``"class1"``.

    Args:
        test_times: Mapping of test file name to its duration.

    Returns:
        Mapping of test file name to ``{"class1": duration}``, in the same
        key order as ``test_times``.
    """
    class_times: Dict[str, Dict[str, float]] = {}
    for test_file, duration in test_times.items():
        class_times[test_file] = {"class1": duration}
    return class_times
|
|
|
|
|
|
class TestCalculateShards(unittest.TestCase):
    """Tests for ``calculate_shards``.

    Each test feeds a list of ``TestRun`` objects plus per-file (and optionally
    per-class) timing tables into ``calculate_shards`` and checks the returned
    list of ``(total_time, [ShardedTest, ...])`` tuples.
    """

    # Default set of test files handed to calculate_shards.
    tests: List[TestRun] = [
        TestRun("super_long_test"),
        TestRun("long_test1"),
        TestRun("long_test2"),
        TestRun("normal_test1"),
        TestRun("normal_test2"),
        TestRun("normal_test3"),
        TestRun("short_test1"),
        TestRun("short_test2"),
        TestRun("short_test3"),
        TestRun("short_test4"),
        TestRun("short_test5"),
    ]

    # Per-file durations for every entry in ``tests``.
    test_times: Dict[str, float] = {
        "super_long_test": 55,
        "long_test1": 22,
        "long_test2": 18,
        "normal_test1": 9,
        "normal_test2": 7,
        "normal_test3": 5,
        "short_test1": 1,
        "short_test2": 0.6,
        "short_test3": 0.4,
        "short_test4": 0.3,
        "short_test5": 0.01,
    }

    # Per-class durations; for each file they add up to its ``test_times`` entry.
    test_class_times: Dict[str, Dict[str, float]] = {
        "super_long_test": {"class1": 55},
        "long_test1": {"class1": 1, "class2": 21},
        "long_test2": {"class1": 10, "class2": 8},
        "normal_test1": {"class1": 9},
        "normal_test2": {"class1": 7},
        "normal_test3": {"class1": 5},
        "short_test1": {"class1": 1},
        "short_test2": {"class1": 0.6},
        "short_test3": {"class1": 0.4},
        "short_test4": {"class1": 0.3},
        "short_test5": {"class1": 0.01},
    }

    def assert_shards_equal(
        self,
        expected_shards: List[Tuple[float, List[ShardedTest]]],
        actual_shards: List[Tuple[float, List[ShardedTest]]],
    ) -> None:
        """Assert that two shard lists match shard by shard.

        Args:
            expected_shards: Expected ``(total_time, tests)`` tuples, in order.
            actual_shards: Tuples to check, e.g. the result of
                ``calculate_shards``.

        Raises:
            AssertionError: If the number of shards differs, a shard's total
                time is not almost equal, or a shard's test list differs.
        """
        # zip() stops at the shorter input, so without this check a result
        # with missing or extra shards would be accepted silently.
        self.assertEqual(
            len(expected_shards),
            len(actual_shards),
            "number of shards differs from the expected number",
        )
        for expected, actual in zip(expected_shards, actual_shards):
            self.assertAlmostEqual(expected[0], actual[0])
            self.assertListEqual(expected[1], actual[1])

    def test_calculate_2_shards_with_complete_test_times(self) -> None:
        """Two shards, with a known time for every test file."""
        expected_shards = [
            (
                60.0,
                [
                    ShardedTest(test="super_long_test", shard=1, num_shards=1, time=55),
                    ShardedTest(test="normal_test3", shard=1, num_shards=1, time=5),
                ],
            ),
            (
                58.31,
                [
                    ShardedTest(test="long_test1", shard=1, num_shards=1, time=22),
                    ShardedTest(test="long_test2", shard=1, num_shards=1, time=18),
                    ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9),
                    ShardedTest(test="normal_test2", shard=1, num_shards=1, time=7),
                    ShardedTest(test="short_test1", shard=1, num_shards=1, time=1),
                    ShardedTest(test="short_test2", shard=1, num_shards=1, time=0.6),
                    ShardedTest(test="short_test3", shard=1, num_shards=1, time=0.4),
                    ShardedTest(test="short_test4", shard=1, num_shards=1, time=0.3),
                    ShardedTest(test="short_test5", shard=1, num_shards=1, time=0.01),
                ],
            ),
        ]
        self.assert_shards_equal(
            expected_shards,
            calculate_shards(2, self.tests, self.test_times, self.test_class_times),
        )

    def test_calculate_1_shard_with_complete_test_times(self) -> None:
        """One shard; the input also holds two class-filtered ``long_test1`` runs."""
        tests = self.tests.copy()
        # Two extra runs of long_test1: one excluding class2, one with only class2.
        class_test1 = TestRun("long_test1", excluded=["class2"])
        class_test2 = TestRun("long_test1", included=["class2"])
        tests.append(class_test1)
        tests.append(class_test2)

        expected_shards = [
            (
                140.31,
                [
                    ShardedTest(test="super_long_test", shard=1, num_shards=1, time=55),
                    ShardedTest(test="long_test1", shard=1, num_shards=1, time=22),
                    ShardedTest(class_test2, shard=1, num_shards=1, time=21),
                    ShardedTest(test="long_test2", shard=1, num_shards=1, time=18),
                    ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9),
                    ShardedTest(test="normal_test2", shard=1, num_shards=1, time=7),
                    ShardedTest(test="normal_test3", shard=1, num_shards=1, time=5),
                    ShardedTest(test="short_test1", shard=1, num_shards=1, time=1),
                    ShardedTest(class_test1, shard=1, num_shards=1, time=1),
                    ShardedTest(test="short_test2", shard=1, num_shards=1, time=0.6),
                    ShardedTest(test="short_test3", shard=1, num_shards=1, time=0.4),
                    ShardedTest(test="short_test4", shard=1, num_shards=1, time=0.3),
                    ShardedTest(test="short_test5", shard=1, num_shards=1, time=0.01),
                ],
            )
        ]
        self.assert_shards_equal(
            expected_shards,
            calculate_shards(1, tests, self.test_times, self.test_class_times),
        )

    def test_calculate_5_shards_with_complete_test_times(self) -> None:
        """Five shards, with a known time for every test file."""
        expected_shards = [
            (
                55.0,
                [ShardedTest(test="super_long_test", shard=1, num_shards=1, time=55)],
            ),
            (22.0, [ShardedTest(test="long_test1", shard=1, num_shards=1, time=22)]),
            (18.0, [ShardedTest(test="long_test2", shard=1, num_shards=1, time=18)]),
            (
                11.31,
                [
                    ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9),
                    ShardedTest(test="short_test1", shard=1, num_shards=1, time=1),
                    ShardedTest(test="short_test2", shard=1, num_shards=1, time=0.6),
                    ShardedTest(test="short_test3", shard=1, num_shards=1, time=0.4),
                    ShardedTest(test="short_test4", shard=1, num_shards=1, time=0.3),
                    ShardedTest(test="short_test5", shard=1, num_shards=1, time=0.01),
                ],
            ),
            (
                12.0,
                [
                    ShardedTest(test="normal_test2", shard=1, num_shards=1, time=7),
                    ShardedTest(test="normal_test3", shard=1, num_shards=1, time=5),
                ],
            ),
        ]
        self.assert_shards_equal(
            expected_shards,
            calculate_shards(5, self.tests, self.test_times, self.test_class_times),
        )

    def test_calculate_2_shards_with_incomplete_test_times(self) -> None:
        """Two shards when only the ``*test1`` files have a recorded time."""
        # Keeps long_test1, normal_test1 and short_test1; every other file has
        # no time and is expected to come back with time=None.
        incomplete_test_times = {
            k: v for k, v in self.test_times.items() if "test1" in k
        }
        expected_shards = [
            (
                22.0,
                [
                    ShardedTest(test="long_test1", shard=1, num_shards=1, time=22),
                    ShardedTest(test="long_test2", shard=1, num_shards=1, time=None),
                    ShardedTest(test="normal_test3", shard=1, num_shards=1, time=None),
                    ShardedTest(test="short_test3", shard=1, num_shards=1, time=None),
                    ShardedTest(test="short_test5", shard=1, num_shards=1, time=None),
                ],
            ),
            (
                10.0,
                [
                    ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9),
                    ShardedTest(test="short_test1", shard=1, num_shards=1, time=1),
                    ShardedTest(
                        test="super_long_test", shard=1, num_shards=1, time=None
                    ),
                    ShardedTest(test="normal_test2", shard=1, num_shards=1, time=None),
                    ShardedTest(test="short_test2", shard=1, num_shards=1, time=None),
                    ShardedTest(test="short_test4", shard=1, num_shards=1, time=None),
                ],
            ),
        ]
        self.assert_shards_equal(
            expected_shards,
            calculate_shards(
                2,
                self.tests,
                incomplete_test_times,
                gen_class_times(incomplete_test_times),
            ),
        )

    def test_calculate_5_shards_with_incomplete_test_times(self) -> None:
        """Five shards when only the ``*test1`` files have a recorded time."""
        # Same three timed files as in the 2-shard incomplete-times test.
        incomplete_test_times = {
            k: v for k, v in self.test_times.items() if "test1" in k
        }
        expected_shards = [
            (
                22.0,
                [
                    ShardedTest(test="long_test1", shard=1, num_shards=1, time=22),
                    ShardedTest(test="normal_test2", shard=1, num_shards=1, time=None),
                    ShardedTest(test="short_test5", shard=1, num_shards=1, time=None),
                ],
            ),
            (
                9.0,
                [
                    ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9),
                    ShardedTest(test="normal_test3", shard=1, num_shards=1, time=None),
                ],
            ),
            (
                1.0,
                [
                    ShardedTest(test="short_test1", shard=1, num_shards=1, time=1),
                    ShardedTest(test="short_test2", shard=1, num_shards=1, time=None),
                ],
            ),
            (
                0.0,
                [
                    ShardedTest(
                        test="super_long_test", shard=1, num_shards=1, time=None
                    ),
                    ShardedTest(test="short_test3", shard=1, num_shards=1, time=None),
                ],
            ),
            (
                0.0,
                [
                    ShardedTest(test="long_test2", shard=1, num_shards=1, time=None),
                    ShardedTest(test="short_test4", shard=1, num_shards=1, time=None),
                ],
            ),
        ]
        self.assert_shards_equal(
            expected_shards,
            calculate_shards(
                5,
                self.tests,
                incomplete_test_times,
                gen_class_times(incomplete_test_times),
            ),
        )

    def test_split_shards(self) -> None:
        """Check how files at, above and below ``THRESHOLD`` are split over 2 shards.

        NOTE(review): the literal totals and per-piece times below (600.0,
        500.0, ...) appear to assume THRESHOLD == 600 -- confirm against
        tools.testing.test_selections.
        """
        # Scenario 1: two files that each take exactly THRESHOLD.
        test_times: Dict[str, float] = {"test1": THRESHOLD, "test2": THRESHOLD}
        expected_shards = [
            (600.0, [ShardedTest(test="test1", shard=1, num_shards=1, time=THRESHOLD)]),
            (600.0, [ShardedTest(test="test2", shard=1, num_shards=1, time=THRESHOLD)]),
        ]
        self.assert_shards_equal(
            expected_shards,
            calculate_shards(
                2,
                [TestRun(t) for t in test_times.keys()],
                test_times,
                gen_class_times(test_times),
            ),
        )

        # Scenario 2: files several times longer than THRESHOLD are expected
        # to be cut into multiple pieces (num_shards 4 and 3 respectively).
        test_times = {"test1": THRESHOLD * 4, "test2": THRESHOLD * 2.5}
        expected_shards = [
            (
                2200.0,
                [
                    ShardedTest(test="test1", shard=1, num_shards=4, time=600.0),
                    ShardedTest(test="test1", shard=3, num_shards=4, time=600.0),
                    ShardedTest(test="test2", shard=1, num_shards=3, time=500.0),
                    ShardedTest(test="test2", shard=3, num_shards=3, time=500.0),
                ],
            ),
            (
                1700.0,
                [
                    ShardedTest(test="test1", shard=2, num_shards=4, time=600.0),
                    ShardedTest(test="test1", shard=4, num_shards=4, time=600.0),
                    ShardedTest(test="test2", shard=2, num_shards=3, time=500.0),
                ],
            ),
        ]
        self.assert_shards_equal(
            expected_shards,
            calculate_shards(
                2,
                [TestRun(t) for t in test_times.keys()],
                test_times,
                gen_class_times(test_times),
            ),
        )

        # Scenario 3: one file at THRESHOLD and one at half of it; neither is
        # expected to be split, and the longer file comes first.
        test_times = {"test1": THRESHOLD / 2, "test2": THRESHOLD}
        expected_shards = [
            (600.0, [ShardedTest(test="test2", shard=1, num_shards=1, time=THRESHOLD)]),
            (
                300.0,
                [ShardedTest(test="test1", shard=1, num_shards=1, time=THRESHOLD / 2)],
            ),
        ]
        self.assert_shards_equal(
            expected_shards,
            calculate_shards(
                2,
                [TestRun(t) for t in test_times.keys()],
                test_times,
                gen_class_times(test_times),
            ),
        )

    def test_zero_tests(self) -> None:
        """With no tests, two shards are still returned, each empty with time 0."""
        self.assertListEqual([(0.0, []), (0.0, [])], calculate_shards(2, [], {}, None))

    def test_split_shards_random(self) -> None:
        """Check invariants of ``calculate_shards`` on seeded random inputs."""
        # Fixed seed keeps the 100 random scenarios reproducible.
        random.seed(120)
        for _ in range(100):
            num_shards = random.randint(1, 10)
            num_tests = random.randint(1, 100)
            random_times: Dict[str, float] = {
                str(i): random.randint(0, THRESHOLD * 10) for i in range(num_tests)
            }
            # Each test is independently chosen to be serial by a coin flip.
            serial = [str(i) for i in range(num_tests) if random.randint(0, 1) == 0]

            shards = calculate_shards(
                num_shards,
                [TestRun(t) for t in random_times.keys()],
                random_times,
                None,
                must_serial=lambda x: x in serial,
                sort_by_time=random.randint(0, 1) == 0,
            )

            times = [x[0] for x in shards]
            max_diff = max(times) - min(times)
            # assertLessEqual reports both values on failure, unlike
            # assertTrue(max_diff <= THRESHOLD).
            self.assertLessEqual(max_diff, THRESHOLD)

            all_sharded_tests: Dict[str, List[ShardedTest]] = defaultdict(list)
            for _, sharded_tests in shards:
                for sharded_test in sharded_tests:
                    all_sharded_tests[sharded_test.name].append(sharded_test)

            # Check that all test files are represented in the shards
            self.assertListEqual(
                sorted(random_times.keys()), sorted(all_sharded_tests.keys())
            )
            # Check that for each test file, the pytest shards' times adds up to
            # original and all shards are present
            for test, sharded_tests in all_sharded_tests.items():
                self.assertAlmostEqual(
                    random_times[test], sum(x.time or 0 for x in sharded_tests)
                )
                self.assertListEqual(
                    list(range(sharded_tests[0].num_shards)),
                    sorted(x.shard - 1 for x in sharded_tests),
                )

    def test_calculate_2_shards_against_optimal_shards(self) -> None:
        """Compare 2-shard results with a known perfectly balanced split."""
        random.seed(120)
        for _ in range(100):
            random_times = {k.test_file: random.random() * 10 for k in self.tests}
            # all test times except first two
            rest_of_tests = [
                i
                for k, i in random_times.items()
                if k != "super_long_test" and k != "long_test1"
            ]
            sum_of_rest = sum(rest_of_tests)
            # Overwrite the first two times so that together they equal
            # sum_of_rest, which makes a perfectly even split possible.
            random_times["super_long_test"] = max(sum_of_rest / 2, *rest_of_tests)
            random_times["long_test1"] = sum_of_rest - random_times["super_long_test"]
            # An optimal sharding would look like the below, but we don't need to compute this for the test:
            # optimal_shards = [
            #     (sum_of_rest, ['super_long_test', 'long_test1']),
            #     (sum_of_rest, [i for i in self.tests if i != 'super_long_test' and i != 'long_test1']),
            # ]
            calculated_shards = calculate_shards(
                2, self.tests, random_times, gen_class_times(random_times)
            )
            max_shard_time = max(calculated_shards[0][0], calculated_shards[1][0])
            # Guard against dividing by zero in the ratio below.
            if sum_of_rest != 0:
                # The calculated shard should not have a ratio worse than 7/6 for num_shards = 2
                self.assertGreaterEqual(7.0 / 6.0, max_shard_time / sum_of_rest)
                sorted_tests = sorted([t.test_file for t in self.tests])
                sorted_shard_tests = sorted(
                    calculated_shards[0][1] + calculated_shards[1][1]
                )
                # All the tests should be represented by some shard
                self.assertEqual(sorted_tests, [x.name for x in sorted_shard_tests])
|
|
|
|
|
|
# Run this module's tests through unittest's runner when executed as a script.
if __name__ == "__main__":
    unittest.main()
|