Alternate sharding (#119078)

Changes sharding to attempt to put all serial tests on as few shards as possible. Parallel tests are then distributed across all shards, with most of them likely ending up on the non-serial shards.

Example: 8 minutes of serial tests, 20 minutes of parallel tests, 2 proc per machine, 6 machines
-> 8 + 20/2 = 18 total minutes of tests
-> 18 / 6 machines = 3 min per machine
-> all serial tests should fit on 3 machines (3 min, 3 min, 2 min)
-> majority of parallel tests should go on last 4 machines, one of which is shared with the serial tests

Move serial tests to run first

If I want to move to a purely numbers-based sharding, this ensures that parallel tests are run alongside other parallel tests as much as possible, instead of interleaving serial and parallel tests (interleaving decreases the effectiveness of parallelization), while also ensuring that test reordering is still mostly effective.

See 73e816ee80 for example logs
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119078
Approved by: https://github.com/huydhn
This commit is contained in:
Catherine Lee
2024-02-21 16:40:27 +00:00
committed by PyTorch MergeBot
parent a24cba35b0
commit cfddfce0d3
3 changed files with 197 additions and 100 deletions

View File

@ -322,6 +322,9 @@ class TestCalculateShards(unittest.TestCase):
),
)
def test_zero_tests(self) -> None:
# With an empty test list and empty timing data, each of the 2 requested shards should come back as (0.0, []).
self.assertListEqual([(0.0, []), (0.0, [])], calculate_shards(2, [], {}, None))
def test_split_shards_random(self) -> None:
random.seed(120)
for _ in range(100):
@ -330,27 +333,32 @@ class TestCalculateShards(unittest.TestCase):
random_times: Dict[str, float] = {
str(i): random.randint(0, THRESHOLD * 10) for i in range(num_tests)
}
serial = [str(i) for i in range(num_tests) if random.randint(0, 1) == 0]
shards = calculate_shards(
num_shards,
[TestRun(t) for t in random_times.keys()],
random_times,
gen_class_times(random_times),
None,
must_serial=lambda x: x in serial,
sort_by_time=random.randint(0, 1) == 0,
)
times = [x[0] for x in shards]
max_diff = max(times) - min(times)
self.assertTrue(max_diff <= THRESHOLD)
all_sharded_tests = defaultdict(list)
for time, sharded_tests in shards:
self.assertEqual(time, sum(x.time for x in sharded_tests))
all_sharded_tests: Dict[str, List[ShardedTest]] = defaultdict(list)
for _, sharded_tests in shards:
for sharded_test in sharded_tests:
all_sharded_tests[sharded_test.name].append(sharded_test)
# Check that all test files are represented in the shards
self.assertListEqual(
sorted(random_times.keys()), sorted(all_sharded_tests.keys())
)
# Check that for each test file, the pytest shards' times adds up to
# original and all shards are present
for test, sharded_tests in all_sharded_tests.items():
self.assertAlmostEqual(
random_times[test], sum(x.time or 0 for x in sharded_tests)