[DCP] Fixes the stateless optimizer issue of distributed state_dict (#135535)

Some optimizers, such as plain SGD, keep no per-parameter state, which can cause get_state_dict/set_state_dict to behave incorrectly. This PR fixes the issue.
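To make the failure concrete, here is a minimal illustrative sketch (an assumed repro, not code from this PR): plain SGD never writes to optim.state, so even after a real step() the state stays empty, and the old _init_optim_state raised the RuntimeError removed further down whenever gradients were present.

```python
# Illustrative sketch only (assumed repro). Plain SGD is stateless: with
# momentum=0 it never populates optim.state, which is the case this PR handles.
import torch
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

model = torch.nn.Linear(4, 4)
optim = torch.optim.SGD(model.parameters(), lr=1e-4)  # stateless optimizer

model(torch.rand(8, 4)).sum().backward()
optim.step()
assert len(optim.state) == 0  # still empty after a real step()

# Before this change, the call below could raise because the optimizer state was
# never initialized even though gradients exist; now it returns an optimizer
# state_dict with empty per-parameter state instead.
model_sd, optim_sd = get_state_dict(model, optim)
set_state_dict(model, optim, model_state_dict=model_sd, optim_state_dict=optim_sd)
```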

fixes: https://github.com/pytorch/pytorch/issues/133415

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135535
Approved by: https://github.com/wz337
Author: Chien-Chin Huang
Committed by: PyTorch MergeBot
Date: 2024-09-09 15:46:24 -07:00
Parent: 7ec17b49cf
Commit: 1d9fefff19
3 changed files with 54 additions and 40 deletions

View File

@@ -86,18 +86,19 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
model, optim, copy_optim, dist_model, dist_optim = init_model_optim()
# Train 10 steps.
_dist_optim = [dist_optim] if not isinstance(dist_optim, list) else dist_optim
for i in range(10):
optim.zero_grad()
for d_optim in _dist_optim:
d_optim.zero_grad()
batch = torch.rand(8, 100, device="cuda")
model(batch).sum().backward()
optim.step()
dist_model(batch).sum().backward()
if not isinstance(dist_optim, list):
dist_optim.step()
dist_optim.zero_grad()
else:
for _dist_optim in dist_optim:
_dist_optim.zero_grad()
optim.zero_grad()
optim.step()
for d_optim in _dist_optim:
d_optim.step()
# Get the state_dict, and compare the result
msd = model.state_dict()
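Because the +/- markers were lost in this view, the restructured training loop is easier to read consolidated: dist_optim is normalized into a list up front so the single-optimizer case and the optimizer-in-backward case (a list of per-parameter optimizers) share one code path. A reconstructed reading of the loop after this change (not a verbatim copy; indentation and blank lines are approximate, the hunk above is authoritative):

```python
# Reconstructed reading of the new test loop; model, optim, dist_model and
# dist_optim come from init_model_optim() as in the surrounding test.
_dist_optim = [dist_optim] if not isinstance(dist_optim, list) else dist_optim
for i in range(10):
    optim.zero_grad()
    for d_optim in _dist_optim:
        d_optim.zero_grad()

    batch = torch.rand(8, 100, device="cuda")
    model(batch).sum().backward()
    dist_model(batch).sum().backward()

    optim.step()
    for d_optim in _dist_optim:
        d_optim.step()
```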
@@ -176,8 +177,8 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
device_mesh = init_device_mesh("cuda", (self.world_size,))
orig_model = CompositeParamModel(device=torch.device("cuda"))
orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3, foreach=True)
copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3, foreach=True)
orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4, foreach=True)
copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4, foreach=True)
if wrapping:
strategy = set(wrapping)
else:
@@ -204,7 +205,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
if compile_model:
dist_model = torch.compile(dist_model)
dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3, foreach=True)
dist_optim = optimizer_class(dist_model.parameters(), lr=1e-4, foreach=True)
return orig_model, orig_optim, copy_optim, dist_model, dist_optim
self._test_save_load(init_model_optim)
@@ -218,7 +219,11 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
"use_composable": [True, False],
"use_dtensor": [True, False],
"wrapping": [(), (nn.Linear, UnitModule)],
"optimizer_class": [torch.optim.Adam, torch.optim.AdamW],
"optimizer_class": [
torch.optim.Adam,
torch.optim.AdamW,
torch.optim.SGD,
],
},
self._test_fsdp,
)
@@ -248,10 +253,10 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
def init_model_optim():
orig_model = CompositeParamModel(device=torch.device("cuda"))
orig_optim = optimizer_class(
orig_model.parameters(), lr=1e-3, foreach=foreach
orig_model.parameters(), lr=1e-4, foreach=foreach
)
copy_optim = optimizer_class(
orig_model.parameters(), lr=1e-3, foreach=foreach
orig_model.parameters(), lr=1e-4, foreach=foreach
)
dist_model = FSDP2(
@@ -262,7 +267,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
if compile_model:
dist_model = torch.compile(dist_model)
dist_optim = optimizer_class(
dist_model.parameters(), lr=1e-3, foreach=foreach
dist_model.parameters(), lr=1e-4, foreach=foreach
)
return orig_model, orig_optim, copy_optim, dist_model, dist_optim
@@ -284,13 +289,13 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
def _test_ddp(self, use_composable: bool, optimizer_class: Type[Optimizer]) -> None:
def init_model_optim():
orig_model = CompositeParamModel(device=torch.device("cuda"))
orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4)
copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4)
if use_composable:
dist_model = replicate(copy.deepcopy(orig_model))
else:
dist_model = DDP(copy.deepcopy(orig_model))
dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3)
dist_optim = optimizer_class(dist_model.parameters(), lr=1e-4)
return orig_model, orig_optim, copy_optim, dist_model, dist_optim
self._test_save_load(init_model_optim)
@@ -301,7 +306,11 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
self.run_subtests(
{
"use_composable": [True, False],
"optimizer_class": [torch.optim.Adam, torch.optim.AdamW],
"optimizer_class": [
torch.optim.Adam,
torch.optim.AdamW,
torch.optim.SGD,
],
},
self._test_ddp,
)
@@ -320,8 +329,8 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
orig_model.u1.parameters(), orig_model.u2.parameters()
):
param.requires_grad = False
orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4)
copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4)
dist_model = copy.deepcopy(orig_model)
if use_composable:
replicate(dist_model.l)
@@ -336,13 +345,13 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
)
if optim_in_backward:
_apply_optimizer_in_backward(
optimizer_class, dist_model.parameters(), {"lr": 1e-3}
optimizer_class, dist_model.parameters(), {"lr": 1e-4}
)
dist_optim = [
p._in_backward_optimizers[0] for p in dist_model.parameters()
]
else:
dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3)
dist_optim = optimizer_class(dist_model.parameters(), lr=1e-4)
return orig_model, orig_optim, copy_optim, dist_model, dist_optim
self._test_save_load(init_model_optim, test_frozen)
@@ -395,10 +404,10 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
def _test_single_gpu(self, optimizer_class: Type[Optimizer]) -> None:
def init_model_optim():
orig_model = CompositeParamModel(device=torch.device("cuda"))
orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4)
copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4)
model_copy = copy.deepcopy(orig_model)
optim_copy = optimizer_class(model_copy.parameters(), lr=1e-3)
optim_copy = optimizer_class(model_copy.parameters(), lr=1e-4)
return orig_model, orig_optim, copy_optim, model_copy, optim_copy
self._test_save_load(init_model_optim)
@@ -445,7 +454,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
device_mesh=device_mesh,
)
dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3)
dist_optim = optimizer_class(dist_model.parameters(), lr=1e-4)
mst, ost = get_state_dict(
dist_model,
@@ -887,10 +896,10 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
def init_model_optim():
device_mesh = init_device_mesh("cuda", (self.world_size,))
orig_model = TiedEmbeddingModel(10000, 300).to(torch.device("cuda"))
orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3)
copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3)
orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4)
copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4)
dist_model = FSDP(copy.deepcopy(orig_model), device_mesh=device_mesh)
dist_optim = torch.optim.AdamW(dist_model.parameters(), lr=1e-3)
dist_optim = torch.optim.AdamW(dist_model.parameters(), lr=1e-4)
return orig_model, orig_optim, copy_optim, dist_model, dist_optim
self._test_save_load(init_model_optim)
@@ -958,7 +967,7 @@ class TestNoComm(MultiProcessTestCase):
@skip_if_lt_x_gpu(1)
def test_no_dist(self) -> None:
model = CompositeParamModel(device=torch.device("cuda"))
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
optim = torch.optim.AdamW(model.parameters(), lr=1e-4)
self.assertFalse(dist.is_initialized())
msd = get_model_state_dict(

View File

@@ -591,17 +591,17 @@ def _init_optim_state(optim: torch.optim.Optimizer) -> None:
# The optimizer state is initialized.
return
# There are some stateless optimizers like SGD. These optimizers will
# not return in the above condition. So if gradients exist, we should also
# return. If gradients do not exist, the following initialization should
# not disturb SGD because the gradients and lr are both zero.
for param_group in optim.param_groups:
for param in param_group[_PARAMS]:
if param.grad is not None:
raise RuntimeError(
"state_dict can only be used if the optimizer "
"states are initialized (usually after one step() with "
"gradients) or gradients are None. For the later case, "
"state_dict will fake the gradients as zero "
"to initialize the optimizer states. However, the "
"gradients are not None."
)
return
for param_group in optim.param_groups:
for param in param_group[_PARAMS]:
if param.requires_grad:
param.grad = torch.zeros_like(param)
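Read without diff markers, the updated beginning of _init_optim_state behaves roughly as follows: if the optimizer already has state, return; if gradients already exist (the stateless-optimizer case), return instead of raising the removed RuntimeError; otherwise fake zero gradients so the rest of the helper can materialize the state without changing the parameters. A sketch reconstructed from the hunk above (not a verbatim copy of the helper):

```python
# Sketch of the new early-exit logic in _init_optim_state (reconstructed).
import torch

_PARAMS = "params"  # key into optimizer param_groups, as used in the hunk

def _init_optim_state_sketch(optim: torch.optim.Optimizer) -> None:
    if optim.state:
        # The optimizer state is already initialized.
        return

    # Stateless optimizers such as plain SGD never populate optim.state, so the
    # check above does not return for them. If gradients already exist, just
    # return; previously this case raised a RuntimeError.
    for param_group in optim.param_groups:
        for param in param_group[_PARAMS]:
            if param.grad is not None:
                return

    # No gradients anywhere: fake zero gradients so the remainder of the real
    # helper can initialize the optimizer state without disturbing parameters.
    for param_group in optim.param_groups:
        for param in param_group[_PARAMS]:
            if param.requires_grad:
                param.grad = torch.zeros_like(param)
```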

View File

@@ -43,7 +43,12 @@ class VerifyStateDictMixin:
dist_param = dist_msd.get(fqn, None)
if not options.ignore_frozen_params:
self.assertIsNotNone(dist_param, f"{fqn=}")
self._compare_tensor(param, dist_param, offload_to_cpu)
try:
self._compare_tensor(param, dist_param, offload_to_cpu)
except AssertionError as e:
raise AssertionError(
f"{fqn} has mismatched value {param} {dist_param}"
) from e
elif dist_param is None:
self.assertFalse(param.requires_grad, f"{fqn=}")