Compare commits

...

6 Commits

Author SHA1 Message Date
b3f6c3090b Update on "[FSDP][Replicate] got rid of reshard_after_forward and updated test cases"
**Summary:** I have gotten rid of reshard_after_forward and shard_placement_fn as inputs to replicate, since there will be no sharding. I have also updated all the necessary tests.
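For context, a minimal sketch of what a call site looks like after this change (the import path for the new replicate is an assumption, not confirmed by this diff): reshard_after_forward and shard_placement_fn are no longer accepted, and the data-parallel mesh is passed via `mesh`.

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
# NOTE: module path assumed for illustration; use wherever this stack exposes replicate()
from torch.distributed._composable.replicate_with_fsdp import replicate


def apply_replicate(model: nn.Sequential, world_size: int, device_type: str = "cuda"):
    # 1D mesh: replication only, no shard dimension anymore
    mesh = init_device_mesh(device_type, (world_size,), mesh_dim_names=("replicate",))
    for layer in model:
        # previously: replicate(layer, device_mesh=..., reshard_after_forward=False)
        replicate(layer, mesh=mesh)
    return replicate(model, mesh=mesh)
```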




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-28 16:54:27 -07:00
ab2665d3cc [FSDP][Replicate] got rid of reshard_after_forward and updated test cases
[ghstack-poisoned]
2025-10-28 16:43:09 -07:00
d7d2e8731b [FSDP][Replicate] added two replicate overload declarations and changed device_mesh to mesh
[ghstack-poisoned]
2025-10-28 15:20:25 -07:00
84cad15e82 Update on "[FSDP][Replicate] final version integrating 1D device mesh replicate into fsdp"
**Summary:** I have created a new composable replicate API that is integrated into FSDP's codebase with minimal changes. The key changes: when we use DDPMeshInfo, we use Replicate placements, prevent the initial sharding of parameters, and set the world size to 1 to skip all-gathers and reduce-scatters.
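As a rough illustration of the placement change (a hedged sketch; the import path and helper below are assumptions, not part of this PR), parameters managed by the new replicate become DTensors with a single Replicate() placement on a 1D ("replicate",) mesh, matching the updated test assertions in the diff below.

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate
# NOTE: module path assumed for illustration
from torch.distributed._composable.replicate_with_fsdp import replicate


def check_replicate_placements(model: nn.Module, world_size: int, device_type: str = "cuda"):
    mesh = init_device_mesh(device_type, (world_size,), mesh_dim_names=("replicate",))
    replicate(model, mesh=mesh)
    for p in model.parameters():
        assert isinstance(p, DTensor)
        assert p.placements == (Replicate(),)  # no Shard(dim=0) placement anymore
        assert p.device_mesh.mesh_dim_names == ("replicate",)
```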

**Test Cases**
1. pytest test/distributed/_composable/test_replicate_training.py
2. pytest test_pp_composability.py
3. pytest test_replicate_with_fsdp.py




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-28 13:49:30 -07:00
f4a9ac120c Update on "[FSDP][Replicate] final version integrating 1D device mesh replicate into fsdp"
**Summary:** I have created a new composable replicate API that is integrated into FSDP's codebase with minimal changes. The key changes: when we use DDPMeshInfo, we use Replicate placements, prevent the initial sharding of parameters, and set the world size to 1 to skip all-gathers and reduce-scatters.
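To make the "skip all-gathers and reduce-scatters" point concrete, here is a hedged standalone sketch (not the PR's code verbatim) of the sizing logic the collectives diff below adopts: with no reduce-scatter group (pure replication), the data-parallel size comes from the all-reduce group alone.

```python
from typing import Optional

import torch.distributed as dist


def data_parallel_size(
    reduce_scatter_group: Optional[dist.ProcessGroup],
    all_reduce_group: Optional[dist.ProcessGroup],
) -> int:
    # No reduce-scatter group means no shard dimension, so it contributes a factor of 1.
    size = reduce_scatter_group.size() if reduce_scatter_group is not None else 1
    if all_reduce_group is not None:  # HSDP or DDP/replicate
        size *= all_reduce_group.size()
    return size
```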

**Test Cases**
1. pytest test/distributed/_composable/test_replicate_training.py
2. pytest test_pp_composability.py
3. pytest test_replicate_with_fsdp.py




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-28 13:16:52 -07:00
bdd80556c7 [FSDP][Replicate] final version integrating 1D device mesh replicate into fsdp
[ghstack-poisoned]
2025-10-28 11:47:54 -07:00
8 changed files with 236 additions and 280 deletions

View File

@ -392,11 +392,11 @@ class ComposabilityTest(MultiProcessTestCase):
replicate_size = self.world_size // (pp_size)
device_mesh = init_device_mesh(
device_type,
mesh_shape=(replicate_size, 1, pp_size),
mesh_dim_names=("replicate", "shard", "pp"),
mesh_shape=(replicate_size, pp_size),
mesh_dim_names=("replicate", "pp"),
)
torch.manual_seed(42)
dp_mesh = device_mesh["replicate", "shard"]
dp_mesh = device_mesh["replicate"]
pp_mesh = device_mesh["pp"]
pp_group = device_mesh["pp"].get_group()
@ -416,15 +416,13 @@ class ComposabilityTest(MultiProcessTestCase):
param_dtype=MixedPrecisionParam,
reduce_dtype=torch.float32,
)
replicate_config = {"mp_policy": mp_policy}
replicate_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
for layer_id in range(len(partial_model)):
replicate(
partial_model[layer_id],
device_mesh=dp_mesh,
**replicate_config,
reshard_after_forward=False,
)
dp_model = replicate(partial_model, device_mesh=dp_mesh, **replicate_config)
dp_model = replicate(partial_model, **replicate_config)
return dp_model
# Apply same precision to reference model (without replicate)
@ -582,11 +580,11 @@ class ComposabilityTest(MultiProcessTestCase):
replicate_size = self.world_size // (pp_size)
device_mesh = init_device_mesh(
device_type,
mesh_shape=(replicate_size, 1, pp_size),
mesh_dim_names=("replicate", "shard", "pp"),
mesh_shape=(replicate_size, pp_size),
mesh_dim_names=("replicate", "pp"),
)
torch.manual_seed(42)
dp_mesh = device_mesh["replicate", "shard"]
dp_mesh = device_mesh["replicate"]
pp_mesh = device_mesh["pp"]
pp_group = device_mesh["pp"].get_group()
dp_group = device_mesh["replicate"].get_group()
@ -648,10 +646,9 @@ class ComposabilityTest(MultiProcessTestCase):
for layer_id in range(len(partial_model)):
replicate(
partial_model[layer_id],
device_mesh=dp_mesh,
reshard_after_forward=False,
mesh=dp_mesh,
)
dp_model = replicate(partial_model, device_mesh=dp_mesh)
dp_model = replicate(partial_model, mesh=dp_mesh)
return dp_model
def pipelined_models_parameters(start_layer, model):

View File

@ -3,7 +3,7 @@
import copy
import dataclasses
import functools
from typing import Optional, Union
from typing import Optional
import torch
import torch.distributed as dist
@ -14,7 +14,6 @@ from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.fsdp._fully_shard._fsdp_collectives import (
_get_gradient_divide_factors,
)
from torch.distributed.tensor import Shard
from torch.testing._internal.common_distributed import (
requires_nccl_version,
SaveForwardInputsModel,
@ -46,35 +45,20 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
def _init_models_and_optims(
self,
reshard_after_forward: Union[bool, int],
param_dtype: Optional[torch.dtype],
reduce_dtype: Optional[torch.dtype],
use_shard_placement_fn,
):
torch.manual_seed(42)
model = nn.Sequential(*[MLP(16, torch.device("cpu")) for _ in range(3)])
ref_model = copy.deepcopy(model).to(device_type)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
largest_dim = -1
largest_dim_size = -1
for dim, dim_size in enumerate(param.shape):
if dim_size > largest_dim_size:
largest_dim = dim
largest_dim_size = dim_size
assert largest_dim >= 0, f"{param.shape}"
return Shard(largest_dim)
mp_policy = MixedPrecisionPolicy(
param_dtype=param_dtype, reduce_dtype=reduce_dtype
)
shard_placement_fn = _shard_placement_fn if use_shard_placement_fn else None
replicate_fn = functools.partial(
replicate,
reshard_after_forward=reshard_after_forward,
mp_policy=mp_policy,
shard_placement_fn=shard_placement_fn,
)
for mlp in model:
replicate_fn(mlp)
@ -82,27 +66,13 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
return ref_model, ref_optim, model, optim
def _get_use_shard_placement_fn_vals_for_bf16_reduce(self):
use_shard_placement_fn_vals = [False]
if self.world_size == 2:
# For world size >2, gradient elements get reduced in different
# orders for the baseline vs. dim-1 sharding, leading to numeric
# differences for bf16 reduction, so only test world size 2.
use_shard_placement_fn_vals.append(True)
return use_shard_placement_fn_vals
@skipIfRocmVersionLessThan((7, 0))
@skip_if_lt_x_gpu(2)
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
def test_compute_dtype(self):
use_shard_placement_fn_vals = (
self._get_use_shard_placement_fn_vals_for_bf16_reduce()
)
self.run_subtests(
{
"param_dtype": [torch.bfloat16, torch.float16],
"reshard_after_forward": [False, True],
"use_shard_placement_fn": use_shard_placement_fn_vals,
},
self._test_compute_dtype,
)
@ -110,14 +80,10 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
def _test_compute_dtype(
self,
param_dtype: torch.dtype,
reshard_after_forward: Union[bool, int],
use_shard_placement_fn: bool,
):
ref_model, ref_optim, model, optim = self._init_models_and_optims(
reshard_after_forward,
param_dtype=param_dtype,
reduce_dtype=None,
use_shard_placement_fn=use_shard_placement_fn,
)
ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
orig_reduce_scatter = dist.reduce_scatter_tensor
@ -175,39 +141,14 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
@skip_if_lt_x_gpu(2)
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
def test_reduce_dtype(self):
self.run_subtests(
{
"reshard_after_forward": [False, True],
"use_shard_placement_fn": [False, True],
},
self._test_reduce_dtype_fp32_reduce,
)
use_shard_placement_fn_vals = (
self._get_use_shard_placement_fn_vals_for_bf16_reduce()
)
self.run_subtests(
{
"reshard_after_forward": [False, True],
"use_shard_placement_fn": use_shard_placement_fn_vals,
},
self._test_reduce_dtype_bf16_reduce,
)
self._test_reduce_dtype_fp32_reduce()
self._test_reduce_dtype_bf16_reduce()
def _test_reduce_dtype_fp32_reduce(
self, reshard_after_forward: Union[bool, int], use_shard_placement_fn: bool
):
if (
self.world_size > 2
and isinstance(reshard_after_forward, int)
and use_shard_placement_fn
):
return
def _test_reduce_dtype_fp32_reduce(self):
param_dtype, reduce_dtype = torch.bfloat16, torch.float32
ref_model, ref_optim, model, optim = self._init_models_and_optims(
reshard_after_forward,
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
use_shard_placement_fn=use_shard_placement_fn,
)
ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
orig_reduce_scatter = dist.reduce_scatter_tensor
@ -249,14 +190,12 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
check_sharded_parity(self, ref_model, model)
def _test_reduce_dtype_bf16_reduce(
self, reshard_after_forward: Union[bool, int], use_shard_placement_fn: bool
self,
):
param_dtype, reduce_dtype = torch.float32, torch.bfloat16
ref_model, ref_optim, model, optim = self._init_models_and_optims(
reshard_after_forward,
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
use_shard_placement_fn=use_shard_placement_fn,
)
group = dist.distributed_c10d._get_default_group()
orig_reduce_scatter = dist.reduce_scatter_tensor
@ -321,12 +260,8 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
ref_model_compute = copy.deepcopy(ref_model).to(param_dtype)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
for mlp in model:
replicate(
mlp, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
)
replicate(
model, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
)
replicate(mlp, mp_policy=mp_policy)
replicate(model, mp_policy=mp_policy)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
orig_reduce_scatter = dist.reduce_scatter_tensor

View File

@ -108,84 +108,70 @@ class TestReplicateRegisteredParams(FSDPTestMultiThread):
"""Tests the parameter registration after forward."""
device = torch.device(device_type.type, 0)
# Single Replicate group
for reshard_after_forward in (True, False, None):
torch.manual_seed(42)
model = MLP(3, device)
# Since seed is per process, not per thread, we broadcast to ensure
# the same parameters across ranks
for param in model.parameters():
dist.broadcast(param, src=0)
ref_model = copy.deepcopy(model)
replicate(model, reshard_after_forward=reshard_after_forward) # root only
inp = torch.randn((2, 3), device=device_type.type)
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model(inp)
if reshard_after_forward:
self._assert_dtensor_params(model.parameters())
else:
self._assert_tensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model.reshard() # however, we can manually reshard
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
torch.manual_seed(42)
model = MLP(3, device)
# Since seed is per process, not per thread, we broadcast to ensure
# the same parameters across ranks
for param in model.parameters():
dist.broadcast(param, src=0)
ref_model = copy.deepcopy(model)
replicate(model) # root only
inp = torch.randn((2, 3), device=device_type.type)
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model(inp)
self._assert_tensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model.reshard() # however, we can manually reshard
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
# Multiple Replicate groups
for reshard_after_forward in (True, False, None):
torch.manual_seed(42)
model = nn.Sequential(MLP(3, device), MLP(3, device))
for param in model.parameters():
dist.broadcast(param, src=0)
ref_model = copy.deepcopy(model)
replicate(model[0].in_proj, reshard_after_forward=reshard_after_forward)
replicate(model[0].out_proj, reshard_after_forward=reshard_after_forward)
replicate(model, reshard_after_forward=reshard_after_forward)
torch.manual_seed(42)
model = nn.Sequential(MLP(3, device), MLP(3, device))
for param in model.parameters():
dist.broadcast(param, src=0)
ref_model = copy.deepcopy(model)
replicate(model[0].in_proj)
replicate(model[0].out_proj)
replicate(model)
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model(inp)
non_root_params = list(model[0].in_proj.parameters()) + list(
model[0].out_proj.parameters()
)
root_params = list(set(model.parameters()) - set(non_root_params))
if reshard_after_forward is None:
self._assert_dtensor_params(non_root_params)
self._assert_tensor_params(root_params)
elif reshard_after_forward:
self._assert_dtensor_params(non_root_params)
self._assert_dtensor_params(root_params)
else:
self._assert_tensor_params(non_root_params)
self._assert_tensor_params(root_params)
self._assert_same_params(model.parameters(), ref_model.parameters())
for module in model.modules():
if isinstance(module, FSDPModule):
module.reshard() # however, we can manually reshard
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model(inp)
non_root_params = list(model[0].in_proj.parameters()) + list(
model[0].out_proj.parameters()
)
root_params = list(set(model.parameters()) - set(non_root_params))
self._assert_tensor_params(non_root_params)
self._assert_tensor_params(root_params)
self._assert_same_params(model.parameters(), ref_model.parameters())
for module in model.modules():
if isinstance(module, FSDPModule):
module.reshard() # however, we can manually reshard
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
@skip_if_lt_x_gpu(1)
def test_param_registration_after_backward(self):
"""Tests the parameter registration after backward."""
device = torch.device(device_type.type, 0)
# Single Replicate group
for reshard_after_forward in (True, False):
model = MLP(8, device)
replicate(model, reshard_after_forward=reshard_after_forward) # root only
inp = torch.randn((2, 8), device=device_type.type)
self._assert_dtensor_params(model.parameters())
model(inp).sum().backward()
self._assert_dtensor_params(model.parameters())
model = MLP(8, device)
replicate(model) # root only
inp = torch.randn((2, 8), device=device_type.type)
self._assert_dtensor_params(model.parameters())
model(inp).sum().backward()
self._assert_dtensor_params(model.parameters())
# Multiple Replicate groups
for reshard_after_forward in (True, False):
model = MLP(8, device)
replicate(model.in_proj, reshard_after_forward=reshard_after_forward)
replicate(model.out_proj, reshard_after_forward=reshard_after_forward)
replicate(model, reshard_after_forward=reshard_after_forward)
self._assert_dtensor_params(model.parameters())
model(inp).sum().backward()
self._assert_dtensor_params(model.parameters())
model = MLP(8, device)
replicate(model.in_proj)
replicate(model.out_proj)
replicate(model)
self._assert_dtensor_params(model.parameters())
model(inp).sum().backward()
self._assert_dtensor_params(model.parameters())
def _assert_tensor_params(self, params: Iterable[nn.Parameter]):
# need to iterate over the list multiple times
@ -287,14 +273,11 @@ class TestReplicate1DTrainingCore(FSDPTest):
[(7, 15), (15, 3)],
[(16, 17), (17, 8)],
],
"use_shard_placement_fn": [False],
},
self._test_train_parity_single_group,
)
def _test_train_parity_single_group(
self, lin_shapes: list[tuple[int, int]], use_shard_placement_fn: bool
):
def _test_train_parity_single_group(self, lin_shapes: list[tuple[int, int]]):
torch.manual_seed(42)
model = nn.Sequential(
nn.Linear(*lin_shapes[0]), nn.ReLU(), nn.Linear(*lin_shapes[1])
@ -333,7 +316,6 @@ class TestReplicate1DTrainingCore(FSDPTest):
"""
self.run_subtests(
{
"reshard_after_forward": [True, False],
"test_device_type": [device_type.type],
"offload_policy": [OffloadPolicy()],
"delay_after_forward": [False, True],
@ -354,7 +336,6 @@ class TestReplicate1DTrainingCore(FSDPTest):
"""
self.run_subtests(
{
"reshard_after_forward": [True], # save CI time
"offload_policy": [
CPUOffloadPolicy(pin_memory=True),
CPUOffloadPolicy(pin_memory=False),
@ -371,7 +352,6 @@ class TestReplicate1DTrainingCore(FSDPTest):
def _test_train_parity_multi_group(
self,
reshard_after_forward: Union[bool, int],
offload_policy: OffloadPolicy,
test_device_type: str,
delay_after_forward: bool,
@ -405,13 +385,12 @@ class TestReplicate1DTrainingCore(FSDPTest):
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
mesh = init_device_mesh(
test_device_type,
(self.world_size, 1),
mesh_dim_names=("replicate", "shard"),
(self.world_size,),
mesh_dim_names=("replicate",),
)
fully_shard_fn = functools.partial(
replicate,
device_mesh=mesh,
reshard_after_forward=reshard_after_forward,
mesh=mesh,
offload_policy=offload_policy,
)
for module in model.modules():
@ -527,12 +506,10 @@ class TestReplicate1DTrainingCore(FSDPTest):
Tests parity when running a module that participates multiple
times in forward.
"""
self.run_subtests(
{"reshard_after_forward": [True, False]},
self._test_multi_forward_module,
)
def _test_multi_forward_module(self, reshard_after_forward: Union[bool, int]):
self._test_multi_forward_module()
def _test_multi_forward_module(self):
class MultiForwardModule(nn.Module):
def __init__(self, device: torch.device):
super().__init__()
@ -687,7 +664,6 @@ class TestReplicateTrainingCompose(FSDPTest):
"""
self.run_subtests(
{
"reshard_after_forward": [True, False],
"checkpoint_impl": ["composable", "utils", "wrapper"],
"module_grouping": ["block", "mem_eff", "mem_eff_weight_tied"],
"test_device_type": [device_type.type],
@ -697,7 +673,6 @@ class TestReplicateTrainingCompose(FSDPTest):
def _test_train_parity_with_activation_checkpointing(
self,
reshard_after_forward: Union[bool, int],
checkpoint_impl: str,
module_grouping: str,
test_device_type: str,
@ -740,12 +715,11 @@ class TestReplicateTrainingCompose(FSDPTest):
# Apply Replicate
device_mesh = init_device_mesh(
test_device_type,
(self.world_size, 1),
mesh_dim_names=("replicate", "shard"),
(self.world_size,),
mesh_dim_names=("replicate",),
)
fsdp_kwargs = {
"reshard_after_forward": reshard_after_forward,
"device_mesh": device_mesh,
"mesh": device_mesh,
}
if module_grouping == "mem_eff":
assert model_args.n_layers == 3
@ -809,7 +783,6 @@ class TestReplicateSharedParams(FSDPTest):
def test_train_parity_with_shared_params(self):
self.run_subtests(
{
"reshard_after_forward": [False, True],
"use_activation_checkpointing": [False, True],
},
self._test_train_shared_params,
@ -817,7 +790,6 @@ class TestReplicateSharedParams(FSDPTest):
def _test_train_shared_params(
self,
reshard_after_forward: bool,
use_activation_checkpointing: bool,
):
torch.manual_seed(42)
@ -830,8 +802,8 @@ class TestReplicateSharedParams(FSDPTest):
if isinstance(module, TransformerBlock):
if use_activation_checkpointing:
checkpoint(module)
replicate(module, reshard_after_forward=reshard_after_forward)
replicate(model, reshard_after_forward=reshard_after_forward)
replicate(module)
replicate(model)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
torch.manual_seed(42 + self.rank + 1)
@ -868,11 +840,11 @@ class TestReplicateGradientAccumulation(FSDPTest):
with/without resharding after backward.
"""
shard_size, replicate_size = 1, self.world_size
replicate_size = self.world_size
meshes = init_device_mesh(
device_type.type,
(replicate_size, shard_size),
mesh_dim_names=("replicate", "shard"),
(replicate_size,),
mesh_dim_names=("replicate",),
)
self.run_subtests(
{
@ -928,8 +900,7 @@ class TestReplicateGradientAccumulation(FSDPTest):
ref_model = copy.deepcopy(model).to(device_type)
replicate_fn = functools.partial(
replicate,
device_mesh=mesh,
reshard_after_forward=reshard_after_forward,
mesh=mesh,
offload_policy=offload_policy,
)
for mlp in model[1:]:
@ -1040,8 +1011,8 @@ class TestReplicateGradientAccumulation(FSDPTest):
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
for module in model.modules():
if isinstance(module, TransformerBlock):
replicate(module, reshard_after_forward=False)
replicate(model, reshard_after_forward=False)
replicate(module)
replicate(model)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
num_microbatches = 3
@ -1145,8 +1116,8 @@ class TestReplicateTPTraining(FSDPTest):
def init_global_mesh(self) -> DeviceMesh:
return init_device_mesh(
device_type.type,
(2, 1, 2),
mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
(2, 2),
mesh_dim_names=("dp_replicate", "tp"),
)
@skip_if_lt_x_gpu(8)
@ -1154,7 +1125,6 @@ class TestReplicateTPTraining(FSDPTest):
global_mesh = self.init_global_mesh()
self.run_subtests(
{
"reshard_after_forward": [False, True],
"use_activation_checkpointing": [False, True],
"mlp_dim": [3, 5, 16, 17],
"foreach": [False],
@ -1165,12 +1135,11 @@ class TestReplicateTPTraining(FSDPTest):
def _test_replicate_tp(
self,
global_mesh: DeviceMesh,
reshard_after_forward: bool,
use_activation_checkpointing: bool,
mlp_dim: int,
foreach: bool,
):
dp_mesh, tp_mesh = global_mesh["dp_replicate", "dp_shard"], global_mesh["tp"]
dp_mesh, tp_mesh = global_mesh["dp_replicate"], global_mesh["tp"]
dp_pg = dp_mesh._flatten().get_group() # used for `replicate()`
torch.manual_seed(42)
@ -1197,8 +1166,8 @@ class TestReplicateTPTraining(FSDPTest):
continue
if use_activation_checkpointing:
checkpoint(module)
replicate(module, device_mesh=dp_mesh)
replicate(model, device_mesh=dp_mesh)
replicate(module, mesh=dp_mesh)
replicate(model, mesh=dp_mesh)
# Checking parameters match orig model is critical to validate .full_tensor correctly replicates the
# strided-sharded layers.
@ -1229,11 +1198,9 @@ class TestReplicateTPTraining(FSDPTest):
for _, p in model.named_parameters():
self.assertIsInstance(p, DTensor)
self.assertEqual(p.device_mesh.ndim, 3)
self.assertEqual(len(p.placements), 3)
self.assertEqual(
p.device_mesh.mesh_dim_names, ("dp_replicate", "dp_shard", "tp")
)
self.assertEqual(p.device_mesh.ndim, 2)
self.assertEqual(len(p.placements), 2)
self.assertEqual(p.device_mesh.mesh_dim_names, ("dp_replicate", "tp"))
if __name__ == "__main__":

View File

@ -120,7 +120,7 @@ class ReplicateTest(MultiProcessTestCase):
if i % 2 == 0:
self.assertTrue("replicate" in _get_registry(layer))
for parameter in layer.parameters():
self.assertEqual(parameter.placements, (Replicate(), Shard(dim=0)))
self.assertEqual(parameter.placements, (Replicate(),))
elif i % 2 == 1:
self.assertTrue("fully_shard" in _get_registry(layer))
for parameter in layer.parameters():
@ -197,14 +197,14 @@ class ReplicateTest(MultiProcessTestCase):
]
global_mesh = self.init_replicate_tp_mesh()
replicate_mesh = global_mesh["replicate", "shard"]
replicate_mesh = global_mesh["replicate"]
for layer in layers:
replicate(layer, device_mesh=replicate_mesh)
replicate(layer, mesh=replicate_mesh)
for parameter in layer.parameters():
self.assertEqual(parameter.device_mesh.shape, (2, 1))
self.assertEqual(parameter.placements, (Replicate(), Shard(dim=0)))
self.assertEqual(parameter.device_mesh.shape, (2,))
self.assertEqual(parameter.placements, (Replicate(),))
@skip_if_lt_x_gpu(2)
def test_train_replicate_fsdp(self):
@ -263,7 +263,6 @@ class ReplicateTest(MultiProcessTestCase):
run_subtests(
self,
{
"reshard_after_forward": [False, True],
"use_activation_checkpointing": [False, True],
"mlp_dim": [3, 16, 17],
},
@ -273,7 +272,6 @@ class ReplicateTest(MultiProcessTestCase):
def _test_train_parity_2d_mlp(
self,
global_mesh: DeviceMesh,
reshard_after_forward: bool,
use_activation_checkpointing: bool,
mlp_dim: int,
):
@ -287,13 +285,12 @@ class ReplicateTest(MultiProcessTestCase):
torch.manual_seed(42)
model = MLPStack(mlp_dim)
ref_model = copy.deepcopy(model).cuda()
replicate(ref_model, device_mesh=replicate_shard_mesh)
replicate(ref_model, mesh=replicate_mesh)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
model.parallelize(
tp_mesh,
replicate_shard_mesh,
use_activation_checkpointing,
reshard_after_forward=reshard_after_forward,
)
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import logging
from typing import Optional, TYPE_CHECKING, Union
from typing import Optional, overload, Union
import torch
import torch.distributed as dist
@ -14,13 +14,12 @@ from torch.distributed.fsdp._fully_shard._fsdp_api import (
OffloadPolicy,
)
from torch.distributed.fsdp._fully_shard._fsdp_common import (
DDPMeshInfo,
detect_compiled_autograd,
HSDPMeshInfo,
)
from torch.distributed.fsdp._fully_shard._fsdp_init import (
_get_device_from_mesh,
_get_managed_states,
_get_post_forward_mesh_info,
_init_default_fully_shard_mesh,
_move_states_to_device,
)
@ -39,12 +38,6 @@ from torch.distributed.utils import _get_root_modules
from .contract import _get_registry, contract
if TYPE_CHECKING:
from collections.abc import Callable
from torch.distributed.tensor import Shard
cls_to_replicate_cls: dict[type, type] = {}
_ROOT_MODULE_PREFIX = ""
@ -95,7 +88,7 @@ class _ReplicateState(FSDPState):
modules: tuple[nn.Module, ...],
device: torch.device,
mp_policy: MixedPrecisionPolicy,
auto_reshard_after_forward: bool,
auto_reshard_after_forward: bool = False,
) -> None:
for module in modules:
_insert_module_state(module, self)
@ -171,8 +164,6 @@ def replicate_impl(
mesh: DeviceMesh,
*,
device_id: Optional[Union[int, torch.device]] = None,
reshard_after_forward: Optional[Union[bool, int]] = None,
shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = None,
mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
offload_policy: OffloadPolicy = OffloadPolicy(),
ignored_params: Optional[set[nn.Parameter]] = None,
@ -184,30 +175,25 @@ def replicate_impl(
)
mesh = mesh or _init_default_fully_shard_mesh()
if mesh.ndim != 2:
raise ValueError(f"replicate expects a 2D DeviceMesh but got {mesh}")
if mesh.ndim != 1:
raise ValueError(f"replicate expects a 1D DeviceMesh but got {mesh}")
else:
if mesh.mesh_dim_names is None:
raise AssertionError(
"Please init the 2D mesh for HSDP with mesh_dim_names specified"
)
mesh_info = HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0)
mesh_info = DDPMeshInfo(mesh, replicate_mesh_dim=0)
device = _get_device_from_mesh(mesh)
auto_reshard_after_forward = reshard_after_forward is None
# If the user does not provide ``reshard_after_forward``, we set it to True.
# During lazy_init, we identify which module is the root and override its value to False
post_forward_mesh_info = _get_post_forward_mesh_info(
reshard_after_forward if not auto_reshard_after_forward else True, # type: ignore[arg-type]
mesh_info,
)
post_forward_mesh_info = None
arg_module = module
modules = (
(module,) if isinstance(module, nn.Module) else tuple(_get_root_modules(module))
)
state = replicate.state(modules[0]) # type: ignore[attr-defined] # see [1]
state.init(modules, device, mp_policy, auto_reshard_after_forward)
state.init(modules, device, mp_policy)
managed_modules = _get_managed_modules(modules, ignored_params)
params, buffers = _get_managed_states(managed_modules, ignored_params)
@ -217,10 +203,10 @@ def replicate_impl(
state._fsdp_param_group = FSDPParamGroup(
params,
modules,
mesh_info,
mesh_info, # type: ignore[arg-type]
post_forward_mesh_info,
device,
shard_placement_fn,
None,
mp_policy,
offload_policy,
)
@ -237,11 +223,39 @@ def replicate_impl(
return arg_module
@contract(state_cls=_ReplicateState)
@overload
# pyrefly: ignore [inconsistent-overload]
def replicate(
module: nn.Module,
**kwargs,
) -> nn.Module:
*,
mesh: Optional[DeviceMesh] = ...,
mp_policy: MixedPrecisionPolicy = ...,
offload_policy: OffloadPolicy = ...,
ignored_params: Optional[set[nn.Parameter]] = ...,
) -> ReplicateModule: ...
@overload
# pyrefly: ignore [inconsistent-overload]
def replicate(
module: list[nn.Module],
*,
mesh: Optional[DeviceMesh] = ...,
mp_policy: MixedPrecisionPolicy = ...,
offload_policy: OffloadPolicy = ...,
ignored_params: Optional[set[nn.Parameter]] = ...,
) -> list[ReplicateModule]: ...
@contract(state_cls=_ReplicateState) # type: ignore[misc]
def replicate(
module: nn.Module,
*,
mesh: Optional[DeviceMesh] = None,
mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
offload_policy: OffloadPolicy = OffloadPolicy(),
ignored_params: Optional[set[nn.Parameter]] = None,
):
r"""Replicates a module
Args:
@ -253,24 +267,21 @@ def replicate(
>>> replicate(module)
"""
if "device_id" in kwargs:
if not isinstance(kwargs["device_id"], (int, torch.device)):
raise RuntimeError(
"Expected device_id to be int or torch.device, "
f"but got {type(kwargs['device_id'])}"
)
if not is_composable_with_replicate(module):
raise RuntimeError(
"Cannot apply `replicate()` on a Module already managed by `fully_shard`"
)
device_mesh = kwargs.pop("device_mesh", None)
if device_mesh is None:
device_mesh = replicate_mesh()
if mesh is None:
mesh = replicate_mesh()
module = replicate_impl(module, mesh=device_mesh, **kwargs)
return module
return replicate_impl(
module,
mesh,
mp_policy=mp_policy,
offload_policy=offload_policy,
ignored_params=ignored_params,
)
class ReplicateModule(FSDPModule):
@ -341,8 +352,8 @@ def replicate_mesh():
device = torch._C._get_accelerator()
mesh = init_device_mesh(
device.type,
mesh_shape=(default_pg.size(), 1),
mesh_dim_names=("replicate", "shard"),
mesh_shape=(default_pg.size(),),
mesh_dim_names=("replicate",),
)
return mesh

View File

@ -492,7 +492,11 @@ def foreach_reduce(
force_sum_reduction_for_comms,
)
)
world_size = reduce_scatter_group.size()
if reduce_scatter_group is None:
world_size = 1
else:
world_size = reduce_scatter_group.size()
device_handle = _get_device_handle(device.type)
current_stream = device_handle.current_stream()
@ -547,7 +551,7 @@ def foreach_reduce(
reduce_output.copy_(reduce_scatter_input)
reduce_scatter_event = reduce_scatter_stream.record_event()
post_reduce_stream = reduce_scatter_stream
if all_reduce_group is not None: # HSDP
if all_reduce_group is not None: # HSDP or DDP/replicate
# Accumulations must run in the reduce-scatter stream
if not all_reduce_grads:
if partial_reduce_output is not None:
@ -690,7 +694,7 @@ def _get_all_gather_input_metadatas(
def _get_gradient_divide_factors(
reduce_scatter_group: dist.ProcessGroup,
reduce_scatter_group: Optional[dist.ProcessGroup],
all_reduce_group: Optional[dist.ProcessGroup],
reduce_dtype: torch.dtype,
device_type: str = "",
@ -709,8 +713,11 @@ def _get_gradient_divide_factors(
# For fp32/bf16, we do not need to worry about overflow/underflow, so we
# use NCCL's built-in division to avoid separate div kernels
overflow_risk = reduce_dtype not in (torch.float32, torch.bfloat16)
if reduce_scatter_group is not None:
data_parallel_size = reduce_scatter_group.size()
else:
data_parallel_size = 1
data_parallel_size = reduce_scatter_group.size()
if all_reduce_group is not None:
data_parallel_size *= all_reduce_group.size()

View File

@ -11,6 +11,7 @@ import torch.nn as nn
from torch._prims_common import make_contiguous_strides_for
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp._fully_shard._fsdp_common import DDPMeshInfo
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor.placement_types import _StridedShard, Placement
@ -306,22 +307,29 @@ class FSDPParam:
f"or 4 (HSDP+EP+TP) but got {self._spmd_mesh.ndim}."
)
self._spmd_placements: tuple[Placement, ...]
dp_shard_tp_placement = (
(
_StridedShard(shard_dim, split_factor=split_factor)
if split_factor > 1
else fsdp_placement
),
*self._tp_spec.placements,
)
if dp_mesh.ndim == 1: # FSDP
self._spmd_placements = dp_shard_tp_placement
else: # HSDP
if isinstance(self.mesh_info, FSDPMeshInfo): # FSDP or HSDP
dp_shard_tp_placement = (
(
_StridedShard(shard_dim, split_factor=split_factor)
if split_factor > 1
else fsdp_placement
),
*self._tp_spec.placements,
)
else: # DDP
dp_shard_tp_placement = (
(Replicate()),
*self._tp_spec.placements,
)
if isinstance(self.mesh_info, HSDPMeshInfo): # HSDP
if self.mesh_info.replicate_mesh_dim != 0:
raise AssertionError(
f"Expected replicate_mesh_dim to be 0, got {self.mesh_info.replicate_mesh_dim}"
)
self._spmd_placements = (Replicate(),) + dp_shard_tp_placement
else: # FSDP or DDP
self._spmd_placements = dp_shard_tp_placement
self._sharding_spec = DTensorSpec(
self._spmd_mesh,
self._spmd_placements,
@ -330,10 +338,12 @@ class FSDPParam:
param_data = cast(DTensor, param)._local_tensor
else:
self._spmd_mesh = self.mesh_info.mesh
if isinstance(self.mesh_info, HSDPMeshInfo):
if isinstance(self.mesh_info, HSDPMeshInfo): # HSDP
self._spmd_placements = (Replicate(), fsdp_placement)
else:
elif isinstance(self.mesh_info, FSDPMeshInfo): # FSDP
self._spmd_placements = (fsdp_placement,)
elif isinstance(self.mesh_info, DDPMeshInfo): # DDP
self._spmd_placements = (Replicate(),)
self._sharding_spec = DTensorSpec(
self._spmd_mesh,
self._spmd_placements,
@ -351,8 +361,13 @@ class FSDPParam:
)
self._orig_size = param_data.size()
self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size)
shard_rank = self.mesh_info.shard_mesh_rank
shard_world_size = self.mesh_info.shard_mesh_size
if isinstance(self.mesh_info, FSDPMeshInfo): # FSDP or HSDP
shard_rank = self.mesh_info.shard_mesh_rank
shard_world_size = self.mesh_info.shard_mesh_size
else: # DDP
shard_rank = 0
shard_world_size = 1
if shard_dim > 0 and param_data.size(shard_dim) % shard_world_size != 0:
# If sharding on nonzero dim, require even sharding for now because
# the uneven sharding (1) requires extra copies before/after FSDP
@ -401,12 +416,20 @@ class FSDPParam:
if mesh_info is None:
raise AssertionError("Expected post_forward_mesh_info to not be None")
param_data = param._local_tensor if isinstance(param, DTensor) else param
chunks = _chunk_with_empty(param_data, mesh_info.shard_mesh_size, dim=0)
self.sharded_post_forward_size = _get_dim_chunked_size(
chunks[mesh_info.shard_mesh_rank],
param_data.size(),
dim=self.fsdp_placement.dim,
)
if isinstance(mesh_info, FSDPMeshInfo):
chunks = _chunk_with_empty(param_data, mesh_info.shard_mesh_size, dim=0)
self.sharded_post_forward_size = _get_dim_chunked_size(
chunks[mesh_info.shard_mesh_rank],
param_data.size(),
dim=self.fsdp_placement.dim,
)
else: # DDP
chunks = _chunk_with_empty(param_data, 1, dim=0)
self.sharded_post_forward_size = _get_dim_chunked_size(
chunks[0],
param_data.size(),
dim=self.fsdp_placement.dim,
)
self.contiguous_sharded_post_forward_stride = make_contiguous_strides_for(
self.sharded_post_forward_size
)

View File

@ -29,6 +29,7 @@ from ._fsdp_collectives import (
)
from ._fsdp_common import (
compiled_autograd_enabled,
DDPMeshInfo,
FSDPMeshInfo,
HSDPMeshInfo,
is_bw,
@ -315,7 +316,10 @@ class FSDPParamGroup:
self._wait_all_gather_streams_on_event(self._reshard_after_forward_event)
self._reshard_after_forward_event = None
world_size = self._all_gather_process_group.size()
if isinstance(self.mesh_info, FSDPMeshInfo):
world_size = self._all_gather_process_group.size()
else:
world_size = 1
if world_size == 1:
# can't skip due to early return in wait_for_unshard if
# no self._all_gather_result
@ -356,7 +360,10 @@ class FSDPParamGroup:
if prev_all_gather_state := self.comm_ctx.all_gather_state:
self._wait_all_gather_streams_on_event(prev_all_gather_state.event)
self.comm_ctx.all_gather_state = None # free the all-gather result
world_size = self._all_gather_process_group.size()
if isinstance(self.mesh_info, FSDPMeshInfo):
world_size = self._all_gather_process_group.size()
else:
world_size = 1
if world_size == 1:
# directly initialize unsharded parameters from sharded parameters
@ -531,7 +538,11 @@ class FSDPParamGroup:
self.comm_ctx.reduce_scatter_state.event
)
self.comm_ctx.reduce_scatter_state = None
all_reduce_pg = self._all_reduce_process_group if self._is_hsdp else None
all_reduce_pg = (
self._all_reduce_process_group
if isinstance(self.mesh_info, DDPMeshInfo)
else None
)
all_reduce_stream: torch.cuda.Stream
if all_reduce_pg is None and self._all_reduce_hook_stream is not None:
# this means the native HSDP is not enabled,
@ -555,14 +566,22 @@ class FSDPParamGroup:
) = foreach_reduce(
fsdp_params_with_grad,
unsharded_grads,
self._reduce_scatter_process_group,
(
self._reduce_scatter_process_group
if isinstance(self.mesh_info, FSDPMeshInfo)
else None
),
self.comm_ctx.reduce_scatter_stream,
self._reduce_scatter_comm,
self._orig_dtype,
self._reduce_dtype,
self.device,
self.gradient_divide_factor,
self._all_reduce_process_group if self._is_hsdp else None,
(
self._all_reduce_process_group
if isinstance(self.mesh_info, DDPMeshInfo)
else None
),
all_reduce_stream,
self.all_reduce_grads,
self._partial_reduce_output,
@ -776,9 +795,9 @@ class FSDPParamGroup:
@property
def _all_reduce_process_group(self) -> dist.ProcessGroup:
if not isinstance(self.mesh_info, HSDPMeshInfo):
if not isinstance(self.mesh_info, DDPMeshInfo):
raise AssertionError(
f"Expected mesh_info to be HSDPMeshInfo, got {type(self.mesh_info)}"
f"Expected mesh_info to be DDPMeshInfo or HSDPMeshInfo, got {type(self.mesh_info)}"
)
return self.mesh_info.replicate_process_group