[Test][2D] Turn on 2D state_dict tests for uneven sharding (#124255)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124255
Approved by: https://github.com/wanchaol
This commit is contained in:
wz337
2024-04-17 20:45:30 +00:00
committed by PyTorch MergeBot
parent 93e249969b
commit cdc855af97

View File

@ -61,8 +61,6 @@ class SimpleModel(nn.Module):
return torch.rand(4, 5, device="cuda")
# TODO: Temporarily disabled tests related SimpleModelUneven due to size mismatch problem.
# TODO: Let's change back the tests after corresponding fixes are made.
class SimpleModelUneven(nn.Module):
def __init__(self):
super().__init__()
@ -246,9 +244,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
@with_comms
@skip_if_lt_x_gpu(4)
# TODO: See the TODO item for SimpleModelUneven.
# @parametrize("is_even_sharded_model", [True, False])
@parametrize("is_even_sharded_model", [True])
@parametrize("is_even_sharded_model", [True, False])
def test_2d_state_dict(self, is_even_sharded_model):
simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
@ -302,9 +298,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
@with_comms
@skip_if_lt_x_gpu(4)
# TODO: See the TODO item for SimpleModelUneven.
# @parametrize("is_even_sharded_model", [True, False])
@parametrize("is_even_sharded_model", [True])
@parametrize("is_even_sharded_model", [True, False])
def test_2d_load_state_dict(self, is_even_sharded_model):
simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
@ -357,9 +351,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
@with_comms
@skip_if_lt_x_gpu(4)
# TODO: See the TODO item for SimpleModelUneven.
# @parametrize("is_even_sharded_model", [True, False])
@parametrize("is_even_sharded_model", [True])
@parametrize("is_even_sharded_model", [True, False])
def test_2d_optim_state_dict(self, is_even_sharded_model):
simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven