mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Allow sharding for distributed tests
Addresses my mistake introduced in https://github.com/pytorch/pytorch/pull/76536#issuecomment-1112657429 Also allows for sharding 1 in run_test.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/76570 Approved by: https://github.com/jeffdaily, https://github.com/seemethere
This commit is contained in:
committed by
PyTorch MergeBot
parent
fe1968dea0
commit
0708630d9f
@ -65,6 +65,29 @@ class TestCalculateShards(unittest.TestCase):
|
||||
expected_shards, calculate_shards(2, self.tests, self.test_times)
|
||||
)
|
||||
|
||||
def test_calculate_1_shard_with_complete_test_times(self) -> None:
|
||||
expected_shards = [
|
||||
(
|
||||
118.31,
|
||||
[
|
||||
"super_long_test",
|
||||
"long_test1",
|
||||
"long_test2",
|
||||
"normal_test1",
|
||||
"normal_test2",
|
||||
"normal_test3",
|
||||
"short_test1",
|
||||
"short_test2",
|
||||
"short_test3",
|
||||
"short_test4",
|
||||
"short_test5",
|
||||
],
|
||||
),
|
||||
]
|
||||
self.assert_shards_equal(
|
||||
expected_shards, calculate_shards(1, self.tests, self.test_times)
|
||||
)
|
||||
|
||||
def test_calculate_5_shards_with_complete_test_times(self) -> None:
|
||||
expected_shards = [
|
||||
(55.0, ["super_long_test"]),
|
||||
|
Reference in New Issue
Block a user