[FSDP] Fix exec order validation for diff ignored modules across ranks

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79533

Approved by: https://github.com/rohan-varma
Author: Andrew Gu
Date: 2022-06-15 20:59:05 +00:00
Committed by: PyTorch MergeBot
Parent: 3064982fb8
Commit: 18fcd4826f
2 changed files with 79 additions and 30 deletions


@@ -29,15 +29,22 @@ class Model(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()
         self.layer0 = torch.nn.Linear(3, 5)
-        self.layer1 = torch.nn.Sequential(
-            torch.nn.Linear(5, 5),
+        layer1_modules = [
             torch.nn.Linear(5, 4),
             torch.nn.Linear(4, 4),
-        )
-        self.layer2 = torch.nn.Linear(4, 1)
+            torch.nn.Linear(4, 4),
+        ]
+        self.layer1 = torch.nn.Sequential(*layer1_modules)
+        self.layer2 = torch.nn.Linear(4, 2)
+        self.layer3 = torch.nn.Linear(2, 2)
+        self.relu = torch.nn.ReLU()
 
     def forward(self, x):
-        return self.layer2(self.layer1(self.layer0(x)))
+        z = self.relu(self.layer0(x))
+        z = self.relu(self.layer1(z))
+        z = self.relu(self.layer2(z))
+        z = self.relu(self.layer3(z))
+        return z
 
     def get_input(self, device):
         return (torch.randn((8, 3)).to(device),)
@@ -48,7 +55,36 @@ class Model(torch.nn.Module):
     def run_backward(self, loss):
         loss.backward()
 
+
+class IgnoredModule(torch.nn.Module):
+    def __init__(self, in_dim: int, out_dim: int) -> None:
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.randn((in_dim, out_dim)))
+
+    def forward(self, x):
+        return x @ self.weight
+
+
+class ModelWithIgnoredModules(Model):
+    """Adds a variable number of :class:`IgnoredModule` to ``self.layer1``."""
+    def __init__(self, num_ignored: int) -> None:
+        assert num_ignored >= 0
+        super().__init__()
+        layer1_modules = [torch.nn.Linear(5, 4), torch.nn.Linear(4, 4)] + \
+            [IgnoredModule(4, 4) for _ in range(num_ignored)] + \
+            [torch.nn.Linear(4, 4)]
+        self.layer1 = torch.nn.Sequential(*layer1_modules)
+
+
 class TestFSDPIgnoredModules(FSDPTest):
+    def _train_model(self, model, optim, num_iters, device=torch.device("cuda")):
+        for _ in range(num_iters):
+            inp = model.module.get_input(device)
+            output = model(*inp)
+            loss = model.module.get_loss(inp, output).to(device)
+            model.module.run_backward(loss)
+            optim.step()
+
     @skip_if_lt_x_gpu(2)
     def test_ignored_modules_transformer(self):
         """Tests that ignored modules' parameters are not flattened for a
@@ -56,7 +92,9 @@ class TestFSDPIgnoredModules(FSDPTest):
         # Initialize an FSDP-wrapped transformer model that has FSDP ignore
         # the `nn.Transformer` module's parameters
         group = dist.distributed_c10d._get_default_group()
-        wrapped_model = self._get_wrapped_model(group, ignore_modules=True)
+        wrapped_model = self._get_wrapped_model(
+            group, cuda_first=True, ignore_modules=True,
+        )
         # Check that the wrapped model's flattened parameter does not include
         # the ignored transformer module's parameters
         nonwrapped_model = self._get_nonwrapped_model(group)
@@ -69,21 +107,15 @@ class TestFSDPIgnoredModules(FSDPTest):
         flat_param_numel = wrapped_model.params[0].numel()
         self.assertEqual(flat_param_numel, nonignored_numel)
         # Check that we can run a few iterations
-        device = torch.device("cuda")
         optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
-        for _ in range(3):
-            inp = wrapped_model.module.get_input(device)
-            output = wrapped_model(*inp)
-            loss = wrapped_model.module.get_loss(inp, output).to(device)
-            wrapped_model.module.run_backward(loss)
-            optim.step()
+        self._train_model(wrapped_model, optim, 3)
 
     @skip_if_lt_x_gpu(2)
     def test_ignored_modules_nested(self):
         """Tests that passing a module with nested FSDP modules does not
         error and still ignores non-FSDP modules' parameters."""
         # Initialize an FSDP-wrapped nested model that first wraps the nested
-        # sequential's middle linear layer (`layer1[1]`) and then wraps the
+        # sequential's second linear layer (`layer1[1]`) and then wraps the
         # overall model while ignoring the nested sequential (`layer1`)
         model = Model().cuda()
         model.layer1[1] = FSDP(model.layer1[1])
@@ -100,20 +132,14 @@ class TestFSDPIgnoredModules(FSDPTest):
         flat_param_numel = wrapped_model.params[0].numel()
         self.assertEqual(flat_param_numel, nonignored_numel)
         # Check that we can run a few iterations
-        device = torch.device("cuda")
         optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
-        for _ in range(3):
-            inp = wrapped_model.get_input(device)
-            output = wrapped_model(*inp)
-            loss = wrapped_model.get_loss(inp, output).to(device)
-            wrapped_model.run_backward(loss)
-            optim.step()
+        self._train_model(wrapped_model, optim, 3)
 
     @skip_if_lt_x_gpu(2)
     def test_ignored_modules_invalid(self):
         """Tests that passing an FSDP module as an ignored module or the
         top-level module itself errors."""
-        model = Model()
+        model = Model().cuda()
         model.layer1 = FSDP(model.layer1)
         # Passing an FSDP module as an ignored module should error
         with self.assertRaises(
@@ -123,12 +149,31 @@ class TestFSDPIgnoredModules(FSDPTest):
             FSDP(model, ignored_modules=[model.layer1])
         with self.assertWarnsRegex(
             expected_warning=UserWarning,
-            expected_regex="Trying to ignore the top-level module passed into the FSDP "
-            "constructor itself will result in all parameters being ignored "
-            "and is not supported",
+            expected_regex="Trying to ignore the top-level module passed into "
+            "the FSDP constructor itself will result in all parameters being "
+            "ignored and is not supported",
         ):
             FSDP(model, ignored_modules=[model])
 
+    @skip_if_lt_x_gpu(2)
+    def test_diff_ignored_modules_across_ranks(self):
+        """Tests ignoring different modules across ranks."""
+        # To exercise different `FlatParameter` enumerations across ranks,
+        # we wrap `layer3` with FSDP, where `layer3` is registered as a module
+        # after `layer1`, which has the variable number of ignored modules
+        model = ModelWithIgnoredModules(num_ignored=self.rank + 1).cuda()
+        layer1_ignored_modules = [
+            m for m in model.layer1.modules() if isinstance(m, IgnoredModule)
+        ]
+        model.layer1 = FSDP(model.layer1, ignored_modules=layer1_ignored_modules)
+        model.layer3 = FSDP(model.layer3)
+        model_ignored_modules = [
+            m for m in model.modules() if isinstance(m, IgnoredModule)
+        ]
+        wrapped_model = FSDP(model, ignored_modules=model_ignored_modules)
+        optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
+        self._train_model(wrapped_model, optim, 3)
+
 
 instantiate_parametrized_tests(TestFSDPIgnoredModules)
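
The test above gives each rank a different number of `IgnoredModule` instances, so each rank's `root_module.parameters()` yields a different mix of plain `nn.Parameter`s alongside the `FlatParameter`s. The hunk in the second file below makes `_ExecOrderData` cache only the `FlatParameter`s, which, judging from the commit title and the `get_param_index` method in that hunk, keeps the position-based indices used by execution order validation consistent across ranks. The following is a minimal sketch of that filtering idea only; the `FlatParameter` stand-in and the helper names are assumptions for illustration, not FSDP's actual internals.

# Minimal sketch of FlatParameter-only indexing (stand-in types and helper
# names, not the real torch.distributed.fsdp implementation).
from typing import List, Optional

import torch.nn as nn


class FlatParameter(nn.Parameter):
    """Stand-in for FSDP's flattened parameter class (only used for isinstance checks here)."""


def collect_flat_params(root_module: nn.Module) -> List[nn.Parameter]:
    # Ignored modules keep plain nn.Parameters, and their count can differ
    # across ranks, so only FlatParameters enter the index space.
    return [
        param for param in root_module.parameters()
        if isinstance(param, FlatParameter)
    ]


def get_flat_param_index(
    all_flat_params: List[nn.Parameter], param: nn.Parameter
) -> Optional[int]:
    # Position-based indices are only comparable across ranks when every rank
    # enumerates the same (filtered) parameter list in the same order.
    for index, candidate in enumerate(all_flat_params):
        if candidate is param:
            return index
    return None

With rank-dependent ignored modules, an unfiltered list would have a different length and ordering per rank, so the same `FlatParameter` could map to different indices on different ranks; filtering restores a consistent enumeration.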


@@ -368,14 +368,18 @@ class _ExecOrderData():
     def init(self, root_module: "FullyShardedDataParallel"):
         assert root_module._is_root, "This data structure should only be " \
             "initialized on an FSDP root module"
-        # Save `root_modules.parameters()` to `_all_flat_params` instead of
-        # re-materializing each time to avoid the result depending on the
-        # calling context (e.g. when some parameters have been rebuilt)
-        self._all_flat_params = list(root_module.parameters())
+        # Save all `FlatParameter`s in `root_module`'s hierarchy to
+        # `_all_flat_params` instead of re-materializing each time to avoid the
+        # result depending on the calling context (e.g. when some parameters
+        # have been rebuilt)
+        self._all_flat_params = [
+            param for param in root_module.parameters()
+            if isinstance(param, FlatParameter)
+        ]
         self._param_to_unflat_param_names = cast(
             Dict[FlatParameter, List[str]],
             _get_param_to_unflat_param_names(root_module)
-        )  # `root_module.parameters()` should only contain `FlatParameter`s
+        )
 
     def get_param_index(self, param: FlatParameter) -> int:
         """Returns a unique non-negative parameter index for ``param`` if it is