[FSDP] Fix exec order validation for diff ignored modules across ranks
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79533
Approved by: https://github.com/rohan-varma
Committed by: PyTorch MergeBot
Parent: 3064982fb8
Commit: 18fcd4826f
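For context, a minimal CPU-only sketch (not part of this commit) of the failure mode being fixed: when ranks ignore different numbers of modules, iterating over all of a module's parameters yields a different count and ordering on each rank, so any validation that indexes into that full list disagrees across ranks. The `make_layer1` helper below is illustrative only; it mirrors `IgnoredModule` and `ModelWithIgnoredModules.layer1` from the test diff that follows.

import torch

class IgnoredModule(torch.nn.Module):
    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn((in_dim, out_dim)))

def make_layer1(num_ignored: int) -> torch.nn.Sequential:
    # Mirrors `ModelWithIgnoredModules.layer1`: two Linears, a variable number
    # of ignored modules, then one more Linear.
    modules = (
        [torch.nn.Linear(5, 4), torch.nn.Linear(4, 4)]
        + [IgnoredModule(4, 4) for _ in range(num_ignored)]
        + [torch.nn.Linear(4, 4)]
    )
    return torch.nn.Sequential(*modules)

# As in the new test, "rank r" ignores r + 1 modules.
rank0 = make_layer1(num_ignored=1)
rank1 = make_layer1(num_ignored=2)
print(len(list(rank0.parameters())))  # 7: three Linears (weight + bias) + 1 ignored weight
print(len(list(rank1.parameters())))  # 8: three Linears (weight + bias) + 2 ignored weights
# Ignored modules' parameters are left unflattened by FSDP, so enumerating only
# `FlatParameter`s (see the `_ExecOrderData.init` hunk below) yields parameter
# indices that agree across ranks.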
@@ -29,15 +29,22 @@ class Model(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()
         self.layer0 = torch.nn.Linear(3, 5)
-        self.layer1 = torch.nn.Sequential(
-            torch.nn.Linear(5, 5),
+        layer1_modules = [
             torch.nn.Linear(5, 4),
             torch.nn.Linear(4, 4),
-        )
-        self.layer2 = torch.nn.Linear(4, 1)
+            torch.nn.Linear(4, 4),
+        ]
+        self.layer1 = torch.nn.Sequential(*layer1_modules)
+        self.layer2 = torch.nn.Linear(4, 2)
+        self.layer3 = torch.nn.Linear(2, 2)
+        self.relu = torch.nn.ReLU()
 
     def forward(self, x):
-        return self.layer2(self.layer1(self.layer0(x)))
+        z = self.relu(self.layer0(x))
+        z = self.relu(self.layer1(z))
+        z = self.relu(self.layer2(z))
+        z = self.relu(self.layer3(z))
+        return z
 
     def get_input(self, device):
         return (torch.randn((8, 3)).to(device),)
@@ -48,7 +55,36 @@ class Model(torch.nn.Module):
     def run_backward(self, loss):
         loss.backward()
 
 
+class IgnoredModule(torch.nn.Module):
+    def __init__(self, in_dim: int, out_dim: int) -> None:
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.randn((in_dim, out_dim)))
+
+    def forward(self, x):
+        return x @ self.weight
+
+
+class ModelWithIgnoredModules(Model):
+    """Adds a variable number of :class:`IgnoredModule` to ``self.layer1``."""
+    def __init__(self, num_ignored: int) -> None:
+        assert num_ignored >= 0
+        super().__init__()
+        layer1_modules = [torch.nn.Linear(5, 4), torch.nn.Linear(4, 4)] + \
+            [IgnoredModule(4, 4) for _ in range(num_ignored)] + \
+            [torch.nn.Linear(4, 4)]
+        self.layer1 = torch.nn.Sequential(*layer1_modules)
+
+
 class TestFSDPIgnoredModules(FSDPTest):
+    def _train_model(self, model, optim, num_iters, device=torch.device("cuda")):
+        for _ in range(num_iters):
+            inp = model.module.get_input(device)
+            output = model(*inp)
+            loss = model.module.get_loss(inp, output).to(device)
+            model.module.run_backward(loss)
+            optim.step()
+
     @skip_if_lt_x_gpu(2)
     def test_ignored_modules_transformer(self):
         """Tests that ignored modules' parameters are not flattened for a
@@ -56,7 +92,9 @@ class TestFSDPIgnoredModules(FSDPTest):
         # Initialize an FSDP-wrapped transformer model that has FSDP ignore
         # the `nn.Transformer` module's parameters
         group = dist.distributed_c10d._get_default_group()
-        wrapped_model = self._get_wrapped_model(group, ignore_modules=True)
+        wrapped_model = self._get_wrapped_model(
+            group, cuda_first=True, ignore_modules=True,
+        )
         # Check that the wrapped model's flattened parameter does not include
         # the ignored transformer module's parameters
         nonwrapped_model = self._get_nonwrapped_model(group)
@@ -69,21 +107,15 @@ class TestFSDPIgnoredModules(FSDPTest):
         flat_param_numel = wrapped_model.params[0].numel()
         self.assertEqual(flat_param_numel, nonignored_numel)
         # Check that we can run a few iterations
-        device = torch.device("cuda")
         optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
-        for _ in range(3):
-            inp = wrapped_model.module.get_input(device)
-            output = wrapped_model(*inp)
-            loss = wrapped_model.module.get_loss(inp, output).to(device)
-            wrapped_model.module.run_backward(loss)
-            optim.step()
+        self._train_model(wrapped_model, optim, 3)
 
     @skip_if_lt_x_gpu(2)
     def test_ignored_modules_nested(self):
         """Tests that passing a module with nested FSDP modules does not
         error and still ignores non-FSDP modules' parameters."""
         # Initialize an FSDP-wrapped nested model that first wraps the nested
-        # sequential's middle linear layer (`layer1[1]`) and then wraps the
+        # sequential's second linear layer (`layer1[1]`) and then wraps the
         # overall model while ignoring the nested sequential (`layer1`)
         model = Model().cuda()
         model.layer1[1] = FSDP(model.layer1[1])
@@ -100,20 +132,14 @@ class TestFSDPIgnoredModules(FSDPTest):
         flat_param_numel = wrapped_model.params[0].numel()
         self.assertEqual(flat_param_numel, nonignored_numel)
         # Check that we can run a few iterations
-        device = torch.device("cuda")
         optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
-        for _ in range(3):
-            inp = wrapped_model.get_input(device)
-            output = wrapped_model(*inp)
-            loss = wrapped_model.get_loss(inp, output).to(device)
-            wrapped_model.run_backward(loss)
-            optim.step()
+        self._train_model(wrapped_model, optim, 3)
 
     @skip_if_lt_x_gpu(2)
     def test_ignored_modules_invalid(self):
         """Tests that passing an FSDP module as an ignored module or the
         top-level module itself errors."""
-        model = Model()
+        model = Model().cuda()
         model.layer1 = FSDP(model.layer1)
         # Passing an FSDP module as an ignored module should error
         with self.assertRaises(
@@ -123,12 +149,31 @@ class TestFSDPIgnoredModules(FSDPTest):
             FSDP(model, ignored_modules=[model.layer1])
         with self.assertWarnsRegex(
             expected_warning=UserWarning,
-            expected_regex="Trying to ignore the top-level module passed into the FSDP "
-            "constructor itself will result in all parameters being ignored "
-            "and is not supported",
+            expected_regex="Trying to ignore the top-level module passed into "
+            "the FSDP constructor itself will result in all parameters being "
+            "ignored and is not supported",
         ):
             FSDP(model, ignored_modules=[model])
 
+    @skip_if_lt_x_gpu(2)
+    def test_diff_ignored_modules_across_ranks(self):
+        """Tests ignoring different modules across ranks."""
+        # To exercise different `FlatParameter` enumerations across ranks,
+        # we wrap `layer3` with FSDP, where `layer3` is registered as a module
+        # after `layer1`, which has the variable number of ignored modules
+        model = ModelWithIgnoredModules(num_ignored=self.rank + 1).cuda()
+        layer1_ignored_modules = [
+            m for m in model.layer1.modules() if isinstance(m, IgnoredModule)
+        ]
+        model.layer1 = FSDP(model.layer1, ignored_modules=layer1_ignored_modules)
+        model.layer3 = FSDP(model.layer3)
+        model_ignored_modules = [
+            m for m in model.modules() if isinstance(m, IgnoredModule)
+        ]
+        wrapped_model = FSDP(model, ignored_modules=model_ignored_modules)
+        optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
+        self._train_model(wrapped_model, optim, 3)
+
 
 instantiate_parametrized_tests(TestFSDPIgnoredModules)
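For a standalone view of the wrapping pattern the new test exercises, here is a hedged usage sketch (not from the PR). It assumes at least two GPUs, a `torchrun` launch, and that `IgnoredModule` and `ModelWithIgnoredModules` from the test above are in scope; each rank ignores a different set of modules and the nested FSDP wrapping still trains.

import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def main() -> None:
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    rank = dist.get_rank()

    # Each rank ignores a different number of modules, mirroring
    # `test_diff_ignored_modules_across_ranks` above.
    model = ModelWithIgnoredModules(num_ignored=rank + 1).cuda()
    ignored = [m for m in model.modules() if isinstance(m, IgnoredModule)]
    model.layer1 = FSDP(model.layer1, ignored_modules=ignored)
    model.layer3 = FSDP(model.layer3)
    wrapped = FSDP(model, ignored_modules=ignored)

    optim = torch.optim.Adam(wrapped.parameters(), lr=1e-3)
    for _ in range(3):
        x = torch.randn(8, 3, device="cuda")
        loss = wrapped(x).sum()
        loss.backward()
        optim.step()
        optim.zero_grad()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()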
@@ -368,14 +368,18 @@ class _ExecOrderData():
     def init(self, root_module: "FullyShardedDataParallel"):
         assert root_module._is_root, "This data structure should only be " \
             "initialized on an FSDP root module"
-        # Save `root_modules.parameters()` to `_all_flat_params` instead of
-        # re-materializing each time to avoid the result depending on the
-        # calling context (e.g. when some parameters have been rebuilt)
-        self._all_flat_params = list(root_module.parameters())
+        # Save all `FlatParameter`s in `root_module`'s hierarchy to
+        # `_all_flat_params` instead of re-materializing each time to avoid the
+        # result depending on the calling context (e.g. when some parameters
+        # have been rebuilt)
+        self._all_flat_params = [
+            param for param in root_module.parameters()
+            if isinstance(param, FlatParameter)
+        ]
         self._param_to_unflat_param_names = cast(
             Dict[FlatParameter, List[str]],
             _get_param_to_unflat_param_names(root_module)
-        )  # `root_module.parameters()` should only contain `FlatParameter`s
+        )
 
     def get_param_index(self, param: FlatParameter) -> int:
         """Returns a unique non-negative parameter index for ``param`` if it is
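The hunk above ends at the docstring of `get_param_index`, whose body lies outside the diff. A hypothetical sketch of such a lookup over `_all_flat_params` (not the actual implementation) clarifies why the saved list must be rank-consistent: the returned index is what execution-order validation compares across ranks, so it must only range over `FlatParameter`s that exist on every rank.

from typing import List

def get_param_index(all_flat_params: List["FlatParameter"], param: "FlatParameter") -> int:
    # Hypothetical sketch: identity comparison against the saved,
    # rank-consistent list of `FlatParameter`s; -1 signals an untracked
    # parameter (e.g. one belonging to an ignored module).
    for i, flat_param in enumerate(all_flat_params):
        if flat_param is param:
            return i
    return -1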