Files
pytorch/test/distributed/_composable/test_replicate_with_fsdp.py
Prachi Gupta 22650c89fb [ROCm] Update skip_if_lt_x_gpu to work with MultiProcContinuous class (#167281)
- Since the MultiProcContinuous class spawns one process per GPU and runs a UT in each of those processes, we need to ensure that the exit code associated with a skip is propagated all the way to the main worker thread that spawned the child processes (see the illustrative sketch after this list).
- This commit also updates several UTs that are meant for 4 GPUs but incorrectly call skip_if_lt_x_gpu with 2 as the input. Examples:
    - test_replicate_with_fsdp.py
    - test_dtensor_resharding.py
    - test_state_dict.py
    - test_functional_api.py: Fix typo: multi-accelerator doesn't exist, replaced with multi-gpu
    - test_op_strategy.py: world_size was hardcoded
    - test_math_ops.py: UT written for 4 GPUs, so skipping for anything less
    - test_schedule_multiproc.py: All UTs in this suite are required to run on 2+ GPUs; therefore, skips are added if fewer than 4 GPUs are supplied
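
For illustration, below is a minimal sketch of the skip-propagation pattern described in the first bullet. It is not the PR's actual diff: the sentinel `TEST_SKIP_EXIT_CODE` and the helpers `_worker` and `run_ranks` are hypothetical names for this example; the real logic lives in torch.testing._internal.common_distributed, which derives its exit codes from TEST_SKIPS.

```python
# Illustrative sketch only (NOT the PR's implementation): propagate a skip
# signalled by a per-GPU child process back to the parent test process.
import multiprocessing as mp
import sys
import unittest

TEST_SKIP_EXIT_CODE = 42  # assumed sentinel; PyTorch derives real codes from TEST_SKIPS


def _worker(rank: int, required_gpus: int, available_gpus: int) -> None:
    # A child that cannot satisfy the GPU requirement exits with the skip
    # sentinel instead of failing, so the parent can tell "skip" from "error".
    if available_gpus < required_gpus:
        sys.exit(TEST_SKIP_EXIT_CODE)
    # ... real test body would run here ...
    sys.exit(0)


def run_ranks(world_size: int, required_gpus: int, available_gpus: int) -> None:
    ctx = mp.get_context("spawn")
    procs = [
        ctx.Process(target=_worker, args=(rank, required_gpus, available_gpus))
        for rank in range(world_size)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    # Propagate the skip to the main worker: if any rank signalled a skip,
    # the whole test is reported as skipped rather than failed.
    if any(p.exitcode == TEST_SKIP_EXIT_CODE for p in procs):
        raise unittest.SkipTest(f"fewer than {required_gpus} GPUs available")
    if any(p.exitcode != 0 for p in procs):
        raise RuntimeError("one or more ranks exited with a non-zero code")
```

With this pattern, a call like run_ranks(world_size=4, required_gpus=4, available_gpus=2) surfaces as a regular unittest skip in the parent rather than a hard failure.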

Fixes https://github.com/pytorch/pytorch/issues/166875

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167281
Approved by: https://github.com/jeffdaily
2025-11-07 18:11:48 +00:00

312 lines · 9.6 KiB · Python

# Owner(s): ["oncall: distributed"]
import copy
import functools
import os
from copy import deepcopy

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed._composable.contract import _get_registry
from torch.distributed._composable.replicate_with_fsdp import (
    _get_managed_modules,
    replicate,
)
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import Replicate, Shard
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    run_subtests,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_fsdp import check_sharded_parity, MLPStack
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
)


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 2)
        self.fc3 = nn.Linear(2, 2)

    def forward(self, x):
        return self.fc3(self.fc2(self.fc1(x)))


class ReplicateTest(MultiProcessTestCase):
    @property
    def world_size(self) -> int:
        return 4

    def init_replicate_tp_mesh(self) -> DeviceMesh:
        # Prefer to test with >=4 GPUs, but for 2 GPUs, use 2-way TP
        replicate_size = 2
        return init_device_mesh(
            "cuda",
            (replicate_size, 1, self.world_size // replicate_size),
            mesh_dim_names=("replicate", "shard", "tp"),
        )

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def _init_pg(self):
        # Set the device explicitly before initializing the process group
        torch.cuda.set_device(self.rank % self.world_size)
        dist.init_process_group(
            backend="nccl",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )

    @skip_if_lt_x_gpu(4)
    def test_replicate_transformer(self):
        """
        This tests that replicate works on a transformer model with fully_shard and replicate layers
        """
        self._init_pg()
        run_subtests(
            self,
            {
                "sharding_strategy": ["replicate", "fully_shard"],
            },
            self._test_replicate_transformer,
        )

    def _composable_api_module_check(self, module, sharding_strategy):
        if sharding_strategy == "replicate":
            self.assertTrue("replicate" in _get_registry(module))
        else:
            self.assertTrue("fully_shard" in _get_registry(module))

    def _test_replicate_transformer(self, sharding_strategy):
        model_args = ModelArgs()
        model = Transformer(model_args)
        replicate_model = deepcopy(model)
        for i, layer in enumerate(replicate_model.layers):
            if i % 2 == 0:
                replicate(layer)
            elif i % 2 == 1:
                fully_shard(layer)
        if sharding_strategy == "replicate":
            replicate_model = replicate(replicate_model)
        else:
            replicate_model = fully_shard(replicate_model)
        self._composable_api_module_check(replicate_model, sharding_strategy)
        for i, layer in enumerate(replicate_model.layers):
            if i % 2 == 0:
                self.assertTrue("replicate" in _get_registry(layer))
                for parameter in layer.parameters():
                    self.assertEqual(parameter.placements, (Replicate(),))
            elif i % 2 == 1:
                self.assertTrue("fully_shard" in _get_registry(layer))
                for parameter in layer.parameters():
                    self.assertEqual(parameter.placements, (Shard(dim=0),))

    @skip_if_lt_x_gpu(4)
    def test_replicate_transformer_managed_modules(self):
        """
        This tests that replicate's managed-module tracking works properly. In this test we use a
        Transformer module with 3 layers, which means there are 49 submodules. We apply replicate
        to the first layer and fully_shard to the second layer, each consisting of 14 submodules,
        leaving 21 remaining submodules. The submodule breakdown is shown below.
        1. Transformer Module
        2. tok_embeddings
        3. pos_embeddings
        4. dropout
        5. layers
        6. norm
        7. output
        In the layers we have Transformer Blocks
        1. Transformer Block
        2. attention_norm
        3. Attention
        4. resid_dropout
        5. wq
        6. wk
        7. wv
        8. wo
        9. ffn_norm
        10. Feed_forward
        11. w1
        12. gelu
        13. w2
        14. resid_dropout
        """
        self._init_pg()
        model_args = ModelArgs()
        model_args.n_layers = 3
        model = Transformer(model_args)
        replicate_model = deepcopy(model)
        self.assertEqual(len(_get_managed_modules((replicate_model,))), 49)
        for i, layer in enumerate(replicate_model.layers):
            if i % 3 == 0:
                replicate(layer)
            elif i % 3 == 1:
                fully_shard(layer)
        replicate_model = replicate(replicate_model)
        self.assertEqual(len(_get_managed_modules((replicate_model,))), 21)

    @skip_if_lt_x_gpu(4)
    def test_replicate_tp_device_mesh(self):
        """
        This tests that a user can pass in a device mesh to replicate a module
        """
        self._init_pg()
        device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}")
        model = Net().to(device)
        replicate_model = deepcopy(model)
        layers = [
            replicate_model.fc1,
            replicate_model.fc2,
            replicate_model.fc3,
        ]
        global_mesh = self.init_replicate_tp_mesh()
        replicate_mesh = global_mesh["replicate"]
        for layer in layers:
            replicate(layer, mesh=replicate_mesh)
            for parameter in layer.parameters():
                self.assertEqual(parameter.device_mesh.shape, (2,))
                self.assertEqual(parameter.placements, (Replicate(),))

    @skip_if_lt_x_gpu(4)
    def test_train_replicate_fsdp(self):
        """
        Tests that replicate_model has the same behavior as original model when training
        """
        self._init_pg()
        device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}")
        model = Net().to(device)
        replicate_model = deepcopy(model)
        layers = [
            replicate_model.fc1,
            replicate_model.fc2,
            replicate_model.fc3,
        ]
        for layer in layers:
            replicate(layer)
        replicate_model = replicate(replicate_model)
        optim = torch.optim.Adam(model.parameters(), lr=0.01)
        replicate_optim = torch.optim.Adam(replicate_model.parameters(), lr=0.01)
        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn(2, 2, device=device)
        for _ in range(10):
            loss = model(inp).sum()
            loss.backward()
            for param in model.parameters():
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            replicate_loss = replicate_model(inp).sum()
            replicate_loss.backward()
            optim.step()
            replicate_optim.step()
            optim.zero_grad()
            replicate_optim.zero_grad()
            self.assertEqual(replicate_loss, loss)
        check_sharded_parity(self, model, replicate_model)

    @skip_if_lt_x_gpu(4)
    def test_train_parity_2d_mlp(self):
        """
        Verifies when a device mesh is passed in, the model has the same behavior as the original model when training
        """
        self._init_pg()
        global_mesh = self.init_replicate_tp_mesh()
        run_subtests(
            self,
            {
                "use_activation_checkpointing": [False, True],
                "mlp_dim": [3, 16, 17],
            },
            functools.partial(self._test_train_parity_2d_mlp, global_mesh),
        )

    def _test_train_parity_2d_mlp(
        self,
        global_mesh: DeviceMesh,
        use_activation_checkpointing: bool,
        mlp_dim: int,
    ):
        replicate_shard_mesh, tp_mesh = (
            global_mesh["replicate", "shard"],
            global_mesh["tp"],
        )
        replicate_mesh = global_mesh["replicate"]
        replicate_pg = replicate_mesh.get_group()  # used for `replicate()`
        torch.manual_seed(42)
        model = MLPStack(mlp_dim)
        ref_model = copy.deepcopy(model).cuda()
        replicate(ref_model, mesh=replicate_mesh)
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
        model.parallelize(
            tp_mesh,
            replicate_shard_mesh,
            use_activation_checkpointing,
        )
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)
        torch.manual_seed(42 + replicate_pg.rank() + 1)
        device = torch.device("cuda")
        for iter_idx in range(10):
            inp = torch.randn((8, mlp_dim), device=device)
            losses: list[torch.Tensor] = []
            for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                losses.append(_model(inp).sum())
                losses[-1].backward()
                _optim.step()
            self.assertEqual(losses[0], losses[1])


if __name__ == "__main__":
    run_tests()