Allow sharding for distributed tests

Addresses a mistake I introduced in https://github.com/pytorch/pytorch/pull/76536#issuecomment-1112657429

Also allows a shard count of 1 in run_test.py (see the usage sketch below).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76570
Approved by: https://github.com/jeffdaily, https://github.com/seemethere
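For context on the shard count of 1: run_test.py picks its slice of the selected tests from a shard index and a total shard count, so a single-shard run simply executes the whole selection in one job. A rough usage sketch follows; the exact --shard spelling is an assumption for illustration, not something this commit shows:

    python test/run_test.py --shard 1 1  # shard 1 of 1, i.e. run the full selection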
Author: Jane Xu
Date: 2022-04-29 03:55:07 +00:00
Committed by: PyTorch MergeBot
Parent: fe1968dea0
Commit: 0708630d9f
3 changed files with 36 additions and 6 deletions

@@ -65,6 +65,29 @@ class TestCalculateShards(unittest.TestCase):
            expected_shards, calculate_shards(2, self.tests, self.test_times)
        )

    def test_calculate_1_shard_with_complete_test_times(self) -> None:
        expected_shards = [
            (
                118.31,
                [
                    "super_long_test",
                    "long_test1",
                    "long_test2",
                    "normal_test1",
                    "normal_test2",
                    "normal_test3",
                    "short_test1",
                    "short_test2",
                    "short_test3",
                    "short_test4",
                    "short_test5",
                ],
            ),
        ]
        self.assert_shards_equal(
            expected_shards, calculate_shards(1, self.tests, self.test_times)
        )
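
The single-shard case exercised above is the degenerate one: every test lands in the one shard and the expected total of 118.31 is simply the sum of all the test times. A minimal sketch of a greedy longest-test-first sharder that produces results of this shape is given below; the name greedy_calculate_shards and its tie-breaking are illustrative assumptions, not the calculate_shards implementation under test.

from typing import Dict, List, Tuple

def greedy_calculate_shards(
    num_shards: int,
    tests: List[str],
    test_times: Dict[str, float],
) -> List[Tuple[float, List[str]]]:
    # Illustrative sketch only: hand out the slowest tests first, always to the
    # shard with the smallest running total (a longest-processing-time greedy).
    shards: List[Tuple[float, List[str]]] = [(0.0, []) for _ in range(num_shards)]
    for test in sorted(tests, key=lambda t: test_times.get(t, 0.0), reverse=True):
        idx = min(range(num_shards), key=lambda i: shards[i][0])
        total, names = shards[idx]
        shards[idx] = (total + test_times.get(test, 0.0), names + [test])
    return shards

With num_shards=1 the loop reduces to sorting the tests by descending runtime, which matches the ordering asserted in expected_shards above.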

    def test_calculate_5_shards_with_complete_test_times(self) -> None:
        expected_shards = [
            (55.0, ["super_long_test"]),