[DCP] Fixes the stateless optimizer issue of distributed state_dict (#135535)
Some optimizers don't have states, which can cause get_state_dict/set_state_dict to behave incorrectly. This PR fixes the issue.

Fixes: https://github.com/pytorch/pytorch/issues/133415
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135535
Approved by: https://github.com/wz337
committed by: PyTorch MergeBot
parent: 7ec17b49cf
commit: 1d9fefff19
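For context, the sketch below (illustrative, not part of this commit) shows the symptom behind the fix: a plain SGD optimizer keeps no per-parameter state, so optim.state stays empty even after a real training step, and logic that treats an empty optimizer state as "never stepped" misbehaves.

# Illustrative sketch: SGD without momentum is stateless.
import torch

model = torch.nn.Linear(4, 4)
optim = torch.optim.SGD(model.parameters(), lr=1e-4)

model(torch.rand(2, 4)).sum().backward()
optim.step()
optim.zero_grad(set_to_none=True)

# Unlike Adam/AdamW, SGD leaves optim.state empty after stepping, so an
# emptiness check on optim.state cannot distinguish a stepped SGD from a
# never-stepped optimizer.
print(len(optim.state))  # prints 0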
@@ -86,18 +86,19 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
         model, optim, copy_optim, dist_model, dist_optim = init_model_optim()
 
         # Train 10 steps.
+        _dist_optim = [dist_optim] if not isinstance(dist_optim, list) else dist_optim
         for i in range(10):
+            optim.zero_grad()
+            for d_optim in _dist_optim:
+                d_optim.zero_grad()
+
             batch = torch.rand(8, 100, device="cuda")
             model(batch).sum().backward()
-            optim.step()
             dist_model(batch).sum().backward()
-            if not isinstance(dist_optim, list):
-                dist_optim.step()
-                dist_optim.zero_grad()
-            else:
-                for _dist_optim in dist_optim:
-                    _dist_optim.zero_grad()
-            optim.zero_grad()
+
+            optim.step()
+            for d_optim in _dist_optim:
+                d_optim.step()
 
         # Get the state_dict, and compare the result
         msd = model.state_dict()
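The reworked loop drives the plain optimizer and the distributed optimizer(s) through the same zero_grad/step points before comparing state_dicts, and it now also handles lists of in-backward optimizers. As a rough sketch of the round trip these tests ultimately exercise, assuming a single process and the public get_state_dict/set_state_dict helpers from torch.distributed.checkpoint.state_dict (the model and optimizer here are illustrative, not taken from the test):

# Hedged sketch of a state_dict round trip with a stateless optimizer.
import torch
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

model = torch.nn.Linear(4, 4)
optim = torch.optim.SGD(model.parameters(), lr=1e-4)  # stateless optimizer

model(torch.rand(2, 4)).sum().backward()
optim.step()
optim.zero_grad(set_to_none=True)

# Capture and restore; an empty SGD state previously made this round trip
# behave incorrectly.
msd, osd = get_state_dict(model, optim)
set_state_dict(model, optim, model_state_dict=msd, optim_state_dict=osd)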
@@ -176,8 +177,8 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
                 device_mesh = init_device_mesh("cuda", (self.world_size,))
 
             orig_model = CompositeParamModel(device=torch.device("cuda"))
-            orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3, foreach=True)
-            copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3, foreach=True)
+            orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4, foreach=True)
+            copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4, foreach=True)
             if wrapping:
                 strategy = set(wrapping)
             else:
@@ -204,7 +205,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
 
             if compile_model:
                 dist_model = torch.compile(dist_model)
-            dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3, foreach=True)
+            dist_optim = optimizer_class(dist_model.parameters(), lr=1e-4, foreach=True)
             return orig_model, orig_optim, copy_optim, dist_model, dist_optim
 
         self._test_save_load(init_model_optim)
@@ -218,7 +219,11 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
                 "use_composable": [True, False],
                 "use_dtensor": [True, False],
                 "wrapping": [(), (nn.Linear, UnitModule)],
-                "optimizer_class": [torch.optim.Adam, torch.optim.AdamW],
+                "optimizer_class": [
+                    torch.optim.Adam,
+                    torch.optim.AdamW,
+                    torch.optim.SGD,
+                ],
             },
             self._test_fsdp,
         )
@@ -248,10 +253,10 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
         def init_model_optim():
             orig_model = CompositeParamModel(device=torch.device("cuda"))
             orig_optim = optimizer_class(
-                orig_model.parameters(), lr=1e-3, foreach=foreach
+                orig_model.parameters(), lr=1e-4, foreach=foreach
             )
             copy_optim = optimizer_class(
-                orig_model.parameters(), lr=1e-3, foreach=foreach
+                orig_model.parameters(), lr=1e-4, foreach=foreach
             )
 
             dist_model = FSDP2(
@@ -262,7 +267,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
             if compile_model:
                 dist_model = torch.compile(dist_model)
             dist_optim = optimizer_class(
-                dist_model.parameters(), lr=1e-3, foreach=foreach
+                dist_model.parameters(), lr=1e-4, foreach=foreach
             )
 
             return orig_model, orig_optim, copy_optim, dist_model, dist_optim
@@ -284,13 +289,13 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
     def _test_ddp(self, use_composable: bool, optimizer_class: Type[Optimizer]) -> None:
         def init_model_optim():
             orig_model = CompositeParamModel(device=torch.device("cuda"))
-            orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
-            copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
+            orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4)
+            copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4)
             if use_composable:
                 dist_model = replicate(copy.deepcopy(orig_model))
             else:
                 dist_model = DDP(copy.deepcopy(orig_model))
-            dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3)
+            dist_optim = optimizer_class(dist_model.parameters(), lr=1e-4)
             return orig_model, orig_optim, copy_optim, dist_model, dist_optim
 
         self._test_save_load(init_model_optim)
@@ -301,7 +306,11 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
         self.run_subtests(
             {
                 "use_composable": [True, False],
-                "optimizer_class": [torch.optim.Adam, torch.optim.AdamW],
+                "optimizer_class": [
+                    torch.optim.Adam,
+                    torch.optim.AdamW,
+                    torch.optim.SGD,
+                ],
             },
             self._test_ddp,
         )
@@ -320,8 +329,8 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
                 orig_model.u1.parameters(), orig_model.u2.parameters()
             ):
                 param.requires_grad = False
-            orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
-            copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
+            orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4)
+            copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4)
             dist_model = copy.deepcopy(orig_model)
             if use_composable:
                 replicate(dist_model.l)
@@ -336,13 +345,13 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
             )
             if optim_in_backward:
                 _apply_optimizer_in_backward(
-                    optimizer_class, dist_model.parameters(), {"lr": 1e-3}
+                    optimizer_class, dist_model.parameters(), {"lr": 1e-4}
                 )
                 dist_optim = [
                     p._in_backward_optimizers[0] for p in dist_model.parameters()
                 ]
             else:
-                dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3)
+                dist_optim = optimizer_class(dist_model.parameters(), lr=1e-4)
             return orig_model, orig_optim, copy_optim, dist_model, dist_optim
 
         self._test_save_load(init_model_optim, test_frozen)
@@ -395,10 +404,10 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
     def _test_single_gpu(self, optimizer_class: Type[Optimizer]) -> None:
        def init_model_optim():
             orig_model = CompositeParamModel(device=torch.device("cuda"))
-            orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
-            copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
+            orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4)
+            copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4)
             model_copy = copy.deepcopy(orig_model)
-            optim_copy = optimizer_class(model_copy.parameters(), lr=1e-3)
+            optim_copy = optimizer_class(model_copy.parameters(), lr=1e-4)
             return orig_model, orig_optim, copy_optim, model_copy, optim_copy
 
         self._test_save_load(init_model_optim)
@@ -445,7 +454,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
             device_mesh=device_mesh,
         )
 
-        dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3)
+        dist_optim = optimizer_class(dist_model.parameters(), lr=1e-4)
 
         mst, ost = get_state_dict(
             dist_model,
@@ -887,10 +896,10 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
         def init_model_optim():
             device_mesh = init_device_mesh("cuda", (self.world_size,))
             orig_model = TiedEmbeddingModel(10000, 300).to(torch.device("cuda"))
-            orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3)
-            copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3)
+            orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4)
+            copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4)
             dist_model = FSDP(copy.deepcopy(orig_model), device_mesh=device_mesh)
-            dist_optim = torch.optim.AdamW(dist_model.parameters(), lr=1e-3)
+            dist_optim = torch.optim.AdamW(dist_model.parameters(), lr=1e-4)
             return orig_model, orig_optim, copy_optim, dist_model, dist_optim
 
         self._test_save_load(init_model_optim)
@@ -958,7 +967,7 @@ class TestNoComm(MultiProcessTestCase):
     @skip_if_lt_x_gpu(1)
     def test_no_dist(self) -> None:
         model = CompositeParamModel(device=torch.device("cuda"))
-        optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
+        optim = torch.optim.AdamW(model.parameters(), lr=1e-4)
 
         self.assertFalse(dist.is_initialized())
         msd = get_model_state_dict(
@@ -591,17 +591,17 @@ def _init_optim_state(optim: torch.optim.Optimizer) -> None:
         # The optimizer state is initialized.
         return
 
+    # There are some stateless optimizers like SGD. These optimizer will
+    # not return in the above condition. So if gradients exist, we should also
+    # return. If gradients do not exist, the following initialization should
+    # not disturb SGD because the gradients and lr are both zero.
     for param_group in optim.param_groups:
         for param in param_group[_PARAMS]:
             if param.grad is not None:
-                raise RuntimeError(
-                    "state_dict can only be used if the optimizer "
-                    "states are initialized (usually after one step() with "
-                    "gradients) or gradients are None. For the later case, "
-                    "state_dict will fake the gradients as zero "
-                    "to initialize the optimizer states. However, the "
-                    "gradients are not None."
-                )
+                return
+
+    for param_group in optim.param_groups:
+        for param in param_group[_PARAMS]:
             if param.requires_grad:
                 param.grad = torch.zeros_like(param)
 
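The new comment relies on an SGD step against a zero gradient (with the learning rate also zeroed by the surrounding code) being a no-op, so the faked initialization cannot disturb the parameters. A small illustrative check of that claim, not part of the commit:

# Illustrative check: stepping SGD with a zero gradient leaves the parameter
# unchanged (no momentum, no weight decay), so faking zero grads is harmless.
import torch

p = torch.nn.Parameter(torch.ones(3))
sgd = torch.optim.SGD([p], lr=1e-4)
before = p.detach().clone()

p.grad = torch.zeros_like(p)  # the "fake" zero gradient
sgd.step()                    # update is p -= lr * 0, i.e. a no-op

assert torch.equal(p.detach(), before)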
@@ -43,7 +43,12 @@ class VerifyStateDictMixin:
             dist_param = dist_msd.get(fqn, None)
             if not options.ignore_frozen_params:
                 self.assertIsNotNone(dist_param, f"{fqn=}")
-                self._compare_tensor(param, dist_param, offload_to_cpu)
+                try:
+                    self._compare_tensor(param, dist_param, offload_to_cpu)
+                except AssertionError as e:
+                    raise AssertionError(
+                        f"{fqn} has mismatched value {param} {dist_param}"
+                    ) from e
             elif dist_param is None:
                 self.assertFalse(param.requires_grad, f"{fqn=}")
 
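The try/except added to VerifyStateDictMixin re-raises the comparison failure with the parameter's fully qualified name while chaining the original assertion via "from e". A generic sketch of that pattern with hypothetical stand-ins (compare() and fqn are not the real helpers):

# Generic sketch of exception chaining for clearer test failures.
def compare(a, b):
    assert a == b, f"{a} != {b}"

fqn = "layer.weight"
try:
    compare(1, 2)
except AssertionError as e:
    # The chained traceback still shows the underlying "1 != 2" assertion.
    raise AssertionError(f"{fqn} has mismatched value") from e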