mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix round robin sharding (#121022)
Fix round robin sharding when there are no test times and sort_by_time=False Adds more tests to test_test_selections for sort_by_time=False Adds more checks to test_split_shards_random for serial/parallel ordering + ordering of tests Refactoring of dup code Tested locally by running `python test/run_test.py --shard 3 5` with no test times downloaded and checked that it wasn't an empty list. Pull Request resolved: https://github.com/pytorch/pytorch/pull/121022 Approved by: https://github.com/huydhn, https://github.com/osalpekar
This commit is contained in:
committed by
PyTorch MergeBot
parent
e2ac2dc13a
commit
6801595349
@ -1,3 +1,4 @@
|
||||
import functools
|
||||
import pathlib
|
||||
import random
|
||||
import sys
|
||||
@ -72,6 +73,109 @@ class TestCalculateShards(unittest.TestCase):
|
||||
self.assertAlmostEqual(expected[0], actual[0])
|
||||
self.assertListEqual(expected[1], actual[1])
|
||||
|
||||
def test_no_times(self) -> None:
|
||||
# Check that round robin sharding is used when no times are provided
|
||||
expected_shards = [
|
||||
(
|
||||
0.0,
|
||||
[
|
||||
ShardedTest(
|
||||
test="super_long_test", shard=1, num_shards=1, time=None
|
||||
),
|
||||
ShardedTest(test="long_test2", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="normal_test2", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="short_test1", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="short_test3", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="short_test5", shard=1, num_shards=1, time=None),
|
||||
],
|
||||
),
|
||||
(
|
||||
0.0,
|
||||
[
|
||||
ShardedTest(test="long_test1", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="normal_test1", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="normal_test3", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="short_test2", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="short_test4", shard=1, num_shards=1, time=None),
|
||||
],
|
||||
),
|
||||
]
|
||||
self.assert_shards_equal(
|
||||
expected_shards,
|
||||
calculate_shards(2, self.tests, {}, {}, sort_by_time=False),
|
||||
)
|
||||
|
||||
def test_some_times_with_not_sort_by_time(self) -> None:
|
||||
expected_shards = [
|
||||
(
|
||||
400.0,
|
||||
[
|
||||
ShardedTest(test="test_1", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="test_2", shard=1, num_shards=1, time=400),
|
||||
ShardedTest(test="test_5", shard=1, num_shards=1, time=None),
|
||||
],
|
||||
),
|
||||
(
|
||||
300.0,
|
||||
[
|
||||
ShardedTest(test="test_3", shard=1, num_shards=1, time=300),
|
||||
ShardedTest(test="test_4", shard=1, num_shards=1, time=None),
|
||||
],
|
||||
),
|
||||
]
|
||||
self.assert_shards_equal(
|
||||
expected_shards,
|
||||
calculate_shards(
|
||||
2,
|
||||
[
|
||||
TestRun("test_1"),
|
||||
TestRun("test_2"),
|
||||
TestRun("test_3"),
|
||||
TestRun("test_4"),
|
||||
TestRun("test_5"),
|
||||
],
|
||||
{"test_2": 400, "test_3": 300},
|
||||
{},
|
||||
sort_by_time=False,
|
||||
),
|
||||
)
|
||||
|
||||
def test_serial_parallel_interleaving(self) -> None:
|
||||
expected_shards = [
|
||||
(
|
||||
300.0,
|
||||
[
|
||||
ShardedTest(test="test_1", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="test_3", shard=1, num_shards=1, time=300),
|
||||
ShardedTest(test="test_4", shard=1, num_shards=1, time=None),
|
||||
],
|
||||
),
|
||||
(
|
||||
400.0,
|
||||
[
|
||||
ShardedTest(test="test_2", shard=1, num_shards=1, time=400),
|
||||
ShardedTest(test="test_5", shard=1, num_shards=1, time=None),
|
||||
],
|
||||
),
|
||||
]
|
||||
self.assert_shards_equal(
|
||||
expected_shards,
|
||||
calculate_shards(
|
||||
2,
|
||||
[
|
||||
TestRun("test_1"),
|
||||
TestRun("test_2"),
|
||||
TestRun("test_3"),
|
||||
TestRun("test_4"),
|
||||
TestRun("test_5"),
|
||||
],
|
||||
{"test_2": 400, "test_3": 300},
|
||||
{},
|
||||
must_serial=lambda x: x in ["test_1", "test_3"],
|
||||
sort_by_time=False,
|
||||
),
|
||||
)
|
||||
|
||||
def test_calculate_2_shards_with_complete_test_times(self) -> None:
|
||||
expected_shards = [
|
||||
(
|
||||
@ -174,10 +278,12 @@ class TestCalculateShards(unittest.TestCase):
|
||||
22.0,
|
||||
[
|
||||
ShardedTest(test="long_test1", shard=1, num_shards=1, time=22),
|
||||
ShardedTest(test="long_test2", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="normal_test3", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="short_test3", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="short_test5", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(
|
||||
test="super_long_test", shard=1, num_shards=1, time=None
|
||||
),
|
||||
ShardedTest(test="normal_test2", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="short_test2", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="short_test4", shard=1, num_shards=1, time=None),
|
||||
],
|
||||
),
|
||||
(
|
||||
@ -185,12 +291,10 @@ class TestCalculateShards(unittest.TestCase):
|
||||
[
|
||||
ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9),
|
||||
ShardedTest(test="short_test1", shard=1, num_shards=1, time=1),
|
||||
ShardedTest(
|
||||
test="super_long_test", shard=1, num_shards=1, time=None
|
||||
),
|
||||
ShardedTest(test="normal_test2", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="short_test2", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="short_test4", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="long_test2", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="normal_test3", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="short_test3", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="short_test5", shard=1, num_shards=1, time=None),
|
||||
],
|
||||
),
|
||||
]
|
||||
@ -213,27 +317,6 @@ class TestCalculateShards(unittest.TestCase):
|
||||
22.0,
|
||||
[
|
||||
ShardedTest(test="long_test1", shard=1, num_shards=1, time=22),
|
||||
ShardedTest(test="normal_test2", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="short_test5", shard=1, num_shards=1, time=None),
|
||||
],
|
||||
),
|
||||
(
|
||||
9.0,
|
||||
[
|
||||
ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9),
|
||||
ShardedTest(test="normal_test3", shard=1, num_shards=1, time=None),
|
||||
],
|
||||
),
|
||||
(
|
||||
1.0,
|
||||
[
|
||||
ShardedTest(test="short_test1", shard=1, num_shards=1, time=1),
|
||||
ShardedTest(test="short_test2", shard=1, num_shards=1, time=None),
|
||||
],
|
||||
),
|
||||
(
|
||||
0.0,
|
||||
[
|
||||
ShardedTest(
|
||||
test="super_long_test", shard=1, num_shards=1, time=None
|
||||
),
|
||||
@ -241,12 +324,33 @@ class TestCalculateShards(unittest.TestCase):
|
||||
],
|
||||
),
|
||||
(
|
||||
0.0,
|
||||
9.0,
|
||||
[
|
||||
ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9),
|
||||
ShardedTest(test="long_test2", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="short_test4", shard=1, num_shards=1, time=None),
|
||||
],
|
||||
),
|
||||
(
|
||||
1.0,
|
||||
[
|
||||
ShardedTest(test="short_test1", shard=1, num_shards=1, time=1),
|
||||
ShardedTest(test="normal_test2", shard=1, num_shards=1, time=None),
|
||||
ShardedTest(test="short_test5", shard=1, num_shards=1, time=None),
|
||||
],
|
||||
),
|
||||
(
|
||||
0.0,
|
||||
[
|
||||
ShardedTest(test="normal_test3", shard=1, num_shards=1, time=None),
|
||||
],
|
||||
),
|
||||
(
|
||||
0.0,
|
||||
[
|
||||
ShardedTest(test="short_test2", shard=1, num_shards=1, time=None),
|
||||
],
|
||||
),
|
||||
]
|
||||
self.assert_shards_equal(
|
||||
expected_shards,
|
||||
@ -330,23 +434,27 @@ class TestCalculateShards(unittest.TestCase):
|
||||
for _ in range(100):
|
||||
num_shards = random.randint(1, 10)
|
||||
num_tests = random.randint(1, 100)
|
||||
test_names = [str(i) for i in range(num_tests)]
|
||||
tests = [TestRun(x) for x in test_names]
|
||||
serial = [x for x in test_names if random.randint(0, 1) == 0]
|
||||
has_times = [x for x in test_names if random.randint(0, 1) == 0]
|
||||
random_times: Dict[str, float] = {
|
||||
str(i): random.randint(0, THRESHOLD * 10) for i in range(num_tests)
|
||||
i: random.randint(0, THRESHOLD * 10) for i in has_times
|
||||
}
|
||||
serial = [str(i) for i in range(num_tests) if random.randint(0, 1) == 0]
|
||||
sort_by_time = random.randint(0, 1) == 0
|
||||
|
||||
shards = calculate_shards(
|
||||
num_shards,
|
||||
[TestRun(t) for t in random_times.keys()],
|
||||
tests,
|
||||
random_times,
|
||||
None,
|
||||
must_serial=lambda x: x in serial,
|
||||
sort_by_time=random.randint(0, 1) == 0,
|
||||
sort_by_time=sort_by_time,
|
||||
)
|
||||
|
||||
times = [x[0] for x in shards]
|
||||
max_diff = max(times) - min(times)
|
||||
self.assertTrue(max_diff <= THRESHOLD)
|
||||
self.assertTrue(max_diff <= THRESHOLD + (num_tests - len(has_times)) * 60)
|
||||
|
||||
all_sharded_tests: Dict[str, List[ShardedTest]] = defaultdict(list)
|
||||
for _, sharded_tests in shards:
|
||||
@ -354,19 +462,56 @@ class TestCalculateShards(unittest.TestCase):
|
||||
all_sharded_tests[sharded_test.name].append(sharded_test)
|
||||
|
||||
# Check that all test files are represented in the shards
|
||||
self.assertListEqual(
|
||||
sorted(random_times.keys()), sorted(all_sharded_tests.keys())
|
||||
)
|
||||
self.assertListEqual(sorted(test_names), sorted(all_sharded_tests.keys()))
|
||||
# Check that for each test file, the pytest shards' times adds up to
|
||||
# original and all shards are present
|
||||
for test, sharded_tests in all_sharded_tests.items():
|
||||
self.assertAlmostEqual(
|
||||
random_times[test], sum(x.time or 0 for x in sharded_tests)
|
||||
)
|
||||
if random_times.get(test) is None:
|
||||
self.assertTrue(len(sharded_tests) == 1)
|
||||
self.assertTrue(sharded_tests[0].time is None)
|
||||
else:
|
||||
# x.time is not None because of the above check
|
||||
self.assertAlmostEqual(
|
||||
random_times[test], sum(x.time for x in sharded_tests) # type: ignore[misc]
|
||||
)
|
||||
self.assertListEqual(
|
||||
list(range(sharded_tests[0].num_shards)),
|
||||
sorted(x.shard - 1 for x in sharded_tests),
|
||||
)
|
||||
# Check that sort_by_time is respected
|
||||
if sort_by_time:
|
||||
|
||||
def comparator(a: ShardedTest, b: ShardedTest) -> int:
|
||||
# serial comes first
|
||||
if a.name in serial and b.name not in serial:
|
||||
return -1
|
||||
if a.name not in serial and b.name in serial:
|
||||
return 1
|
||||
# known test times come first
|
||||
if a.time is not None and b.time is None:
|
||||
return -1
|
||||
if a.time is None and b.time is not None:
|
||||
return 1
|
||||
if a.time == b.time:
|
||||
return 0
|
||||
# not None due to the above checks
|
||||
return -1 if a.time > b.time else 1 # type: ignore[operator]
|
||||
|
||||
else:
|
||||
|
||||
def comparator(a: ShardedTest, b: ShardedTest) -> int:
|
||||
# serial comes first
|
||||
if a.name in serial and b.name not in serial:
|
||||
return -1
|
||||
if a.name not in serial and b.name in serial:
|
||||
return 1
|
||||
return test_names.index(a.name) - test_names.index(b.name)
|
||||
|
||||
for _, sharded_tests in shards:
|
||||
self.assertListEqual(
|
||||
sorted(sharded_tests, key=functools.cmp_to_key(comparator)),
|
||||
sharded_tests,
|
||||
)
|
||||
|
||||
def test_calculate_2_shards_against_optimal_shards(self) -> None:
|
||||
random.seed(120)
|
||||
|
Reference in New Issue
Block a user