diff --git a/test/distributed/_composable/test_replicate_training.py b/test/distributed/_composable/test_replicate_training.py
index d9a633ad9922..c8a827d33c12 100644
--- a/test/distributed/_composable/test_replicate_training.py
+++ b/test/distributed/_composable/test_replicate_training.py
@@ -1,11 +1,16 @@
 # Owner(s): ["oncall: distributed"]
 
+import copy
+from collections.abc import Iterable
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed._composable.replicate_with_fsdp import replicate
+from torch.distributed.fsdp import FSDPModule
+from torch.distributed.tensor import DTensor
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
-from torch.testing._internal.common_fsdp import FSDPTestMultiThread, get_devtype
+from torch.testing._internal.common_fsdp import FSDPTestMultiThread, get_devtype, MLP
 from torch.testing._internal.common_utils import run_tests
 
 
 device_type = torch.device(get_devtype())
@@ -49,5 +54,120 @@ class TestReplicateForwardInputs(FSDPTestMultiThread):
         model(x, ys)
 
 
+class TestReplicateRegisteredParams(FSDPTestMultiThread):
+    @property
+    def world_size(self) -> int:
+        return 4
+
+    @skip_if_lt_x_gpu(1)
+    def test_param_registration_after_forward(self):
+        """Tests the parameter registration after forward."""
+        device = torch.device(device_type.type, 0)
+        # Single Replicate group
+        for reshard_after_forward in (True, False, None):
+            torch.manual_seed(42)
+            model = MLP(3, device)
+            # Since seed is per process, not per thread, we broadcast to ensure
+            # the same parameters across ranks
+            for param in model.parameters():
+                dist.broadcast(param, src=0)
+            ref_model = copy.deepcopy(model)
+            replicate(model, reshard_after_forward=reshard_after_forward)  # root only
+            inp = torch.randn((2, 3), device=device_type.type)
+            self._assert_dtensor_params(model.parameters())
+            self._assert_same_params(model.parameters(), ref_model.parameters())
+            model(inp)
+            if reshard_after_forward:
+                self._assert_dtensor_params(model.parameters())
+            else:
+                self._assert_tensor_params(model.parameters())
+            self._assert_same_params(model.parameters(), ref_model.parameters())
+            model.reshard()  # however, we can manually reshard
+            self._assert_dtensor_params(model.parameters())
+            self._assert_same_params(model.parameters(), ref_model.parameters())
+
+        # Multiple Replicate groups
+        for reshard_after_forward in (True, False, None):
+            torch.manual_seed(42)
+            model = nn.Sequential(MLP(3, device), MLP(3, device))
+            for param in model.parameters():
+                dist.broadcast(param, src=0)
+            ref_model = copy.deepcopy(model)
+            replicate(model[0].in_proj, reshard_after_forward=reshard_after_forward)
+            replicate(model[0].out_proj, reshard_after_forward=reshard_after_forward)
+            replicate(model, reshard_after_forward=reshard_after_forward)
+
+            self._assert_dtensor_params(model.parameters())
+            self._assert_same_params(model.parameters(), ref_model.parameters())
+            model(inp)
+            non_root_params = list(model[0].in_proj.parameters()) + list(
+                model[0].out_proj.parameters()
+            )
+            root_params = list(set(model.parameters()) - set(non_root_params))
+            if reshard_after_forward is None:
+                self._assert_dtensor_params(non_root_params)
+                self._assert_tensor_params(root_params)
+            elif reshard_after_forward:
+                self._assert_dtensor_params(non_root_params)
+                self._assert_dtensor_params(root_params)
+            else:
+                self._assert_tensor_params(non_root_params)
+                self._assert_tensor_params(root_params)
+            self._assert_same_params(model.parameters(), ref_model.parameters())
+            for module in model.modules():
+                if isinstance(module, FSDPModule):
+                    module.reshard()  # however, we can manually reshard
+            self._assert_dtensor_params(model.parameters())
+            self._assert_same_params(model.parameters(), ref_model.parameters())
+
+    @skip_if_lt_x_gpu(1)
+    def test_param_registration_after_backward(self):
+        """Tests the parameter registration after backward."""
+        device = torch.device(device_type.type, 0)
+        # Single Replicate group
+        for reshard_after_forward in (True, False):
+            model = MLP(8, device)
+            replicate(model, reshard_after_forward=reshard_after_forward)  # root only
+            inp = torch.randn((2, 8), device=device_type.type)
+            self._assert_dtensor_params(model.parameters())
+            model(inp).sum().backward()
+            self._assert_dtensor_params(model.parameters())
+
+        # Multiple Replicate groups
+        for reshard_after_forward in (True, False):
+            model = MLP(8, device)
+            replicate(model.in_proj, reshard_after_forward=reshard_after_forward)
+            replicate(model.out_proj, reshard_after_forward=reshard_after_forward)
+            replicate(model, reshard_after_forward=reshard_after_forward)
+            self._assert_dtensor_params(model.parameters())
+            model(inp).sum().backward()
+            self._assert_dtensor_params(model.parameters())
+
+    def _assert_tensor_params(self, params: Iterable[nn.Parameter]):
+        # need to iterate over the list multiple times
+        params = list(params)
+        self.assertGreater(len(params), 0)
+        for param in params:
+            self.assertNotIsInstance(param, DTensor)
+            self.assertIsInstance(param, torch.Tensor)
+
+    def _assert_dtensor_params(self, params: Iterable[nn.Parameter]):
+        params = list(params)
+        self.assertGreater(len(params), 0)
+        for param in params:
+            self.assertIsInstance(param, DTensor)
+
+    def _assert_same_params(
+        self, params: Iterable[nn.Parameter], ref_params: Iterable[nn.Parameter]
+    ):
+        params, ref_params = list(params), list(ref_params)
+        self.assertEqual(len(params), len(ref_params))
+        for param, ref_param in zip(params, ref_params):
+            if isinstance(param, DTensor):
+                param = param.full_tensor()
+            self.assertEqual(param.shape, ref_param.shape)
+            self.assertEqual(param, ref_param)
+
+
 if __name__ == "__main__":
     run_tests()