mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Test][2D] Turn on 2D state_dict tests for uneven sharding (#124255)
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/124255 Approved by: https://github.com/wanchaol
This commit is contained in:
@ -61,8 +61,6 @@ class SimpleModel(nn.Module):
|
||||
return torch.rand(4, 5, device="cuda")
|
||||
|
||||
|
||||
# TODO: Temporarily disabled tests related SimpleModelUneven due to size mismatch problem.
|
||||
# TODO: Let's change back the tests after corresponding fixes are made.
|
||||
class SimpleModelUneven(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -246,9 +244,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
# TODO: See the TODO item for SimpleModelUneven.
|
||||
# @parametrize("is_even_sharded_model", [True, False])
|
||||
@parametrize("is_even_sharded_model", [True])
|
||||
@parametrize("is_even_sharded_model", [True, False])
|
||||
def test_2d_state_dict(self, is_even_sharded_model):
|
||||
simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
|
||||
|
||||
@ -302,9 +298,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
# TODO: See the TODO item for SimpleModelUneven.
|
||||
# @parametrize("is_even_sharded_model", [True, False])
|
||||
@parametrize("is_even_sharded_model", [True])
|
||||
@parametrize("is_even_sharded_model", [True, False])
|
||||
def test_2d_load_state_dict(self, is_even_sharded_model):
|
||||
simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
|
||||
|
||||
@ -357,9 +351,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
# TODO: See the TODO item for SimpleModelUneven.
|
||||
# @parametrize("is_even_sharded_model", [True, False])
|
||||
@parametrize("is_even_sharded_model", [True])
|
||||
@parametrize("is_even_sharded_model", [True, False])
|
||||
def test_2d_optim_state_dict(self, is_even_sharded_model):
|
||||
simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
|
||||
|
||||
|
Reference in New Issue
Block a user