Files
pytorch/test/distributed/test_composability.py
Howard Huang 2beead7523 [PP] move FSDP reduce scatters to end of step (#165106)
Move FSDP reduce-scatters to the end of the PP step. The reduce-scatter's compute-stream sync blocks the other stages from executing their backward passes, leading to pipeline bubbles. There should be a way to execute these reduce-scatters earlier, but this is the quick fix for now.

<img width="1056" height="463" alt="image" src="https://github.com/user-attachments/assets/b945dd55-8ab1-4acc-b862-c6e2e476b834" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165106
Approved by: https://github.com/weifengpy
ghstack dependencies: #164976
2025-10-12 13:28:02 +00:00
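
A minimal sketch of the idea described above, assuming the FSDP2 `FSDPModule` flag APIs (`set_requires_gradient_sync`, `set_is_last_backward`); it illustrates deferring per-microbatch reduce-scatters until the final backward and is not the actual scheduler change made in this PR:

```python
# Hypothetical illustration only, not the code changed by this PR: with FSDP2,
# gradient reduce-scatter can be deferred so the communication (and its compute
# stream sync) happens once, at the end of the pipeline-parallel step.
from torch.distributed.fsdp import FSDPModule


def run_stage_backwards(stage_module, microbatch_losses):
    last = len(microbatch_losses) - 1
    for i, loss in enumerate(microbatch_losses):
        if isinstance(stage_module, FSDPModule):
            # Only the final backward triggers reduce-scatter and post-backward cleanup
            stage_module.set_requires_gradient_sync(i == last)
            stage_module.set_is_last_backward(i == last)
        loss.backward()
```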


# Owner(s): ["oncall: distributed"]
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.fsdp._fully_shard._fsdp_param import ShardedState
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import (
_Action,
_ComputationType,
_PipelineScheduleRuntime,
PipelineScheduleSingle,
Schedule1F1B,
ScheduleGPipe,
ScheduleInterleaved1F1B,
ScheduleInterleavedZeroBubble,
ScheduleLoopedBFS,
)
from torch.distributed.tensor import DTensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
MultiProcContinuousTest,
requires_nccl,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
skip_but_pass_in_sandcastle_if,
TEST_WITH_ROCM,
)
device_type = "cuda"
# MLP Layer
class MLPModule(torch.nn.Module):
def __init__(self, d_hid: int):
super().__init__()
self.net1 = torch.nn.Linear(d_hid, d_hid)
self.relu = torch.nn.ReLU()
self.net2 = torch.nn.Linear(d_hid, d_hid)
self.init_weights()
def init_weights(self):
        # ensure a proper init, otherwise gradient tests are more likely to see zero grad values
torch.nn.init.kaiming_uniform_(
self.net1.weight, mode="fan_in", nonlinearity="relu"
)
torch.nn.init.kaiming_uniform_(
self.net2.weight, mode="fan_in", nonlinearity="relu"
)
def forward(self, x):
x = self.net1(x)
x = self.relu(x)
x = self.net2(x)
return x
class MLPModuleEven(torch.nn.Module):
def __init__(self, d_hid: int):
super().__init__()
self.net1 = nn.Linear(d_hid, d_hid)
self.net2 = nn.Linear(d_hid, d_hid)
self.net3 = nn.Linear(d_hid, d_hid * 2)
self.init_weights()
def init_weights(self):
torch.nn.init.kaiming_uniform_(
self.net1.weight, mode="fan_in", nonlinearity="relu"
)
torch.nn.init.kaiming_uniform_(
self.net2.weight, mode="fan_in", nonlinearity="relu"
)
torch.nn.init.kaiming_uniform_(
self.net3.weight, mode="fan_in", nonlinearity="relu"
)
def forward(self, x):
x = F.relu(self.net1(x))
x = F.relu(self.net2(x))
x = F.relu(self.net3(x))
return x
def loss_fn(y, target, scale=1e-4):
# Scale the loss to simulate a small learning rate and avoid exploding grads
return torch.nn.functional.cross_entropy(y, target) * scale
class ComposabilityTest(MultiProcContinuousTest):
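    """Composability tests: pipeline parallelism (PP) schedules combined with
    data parallelism (DDP or FSDP2) on a 2D device mesh, validated against a
    single non-parallel reference model."""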
@classmethod
def backend_str(cls) -> str:
# Testing with NCCL backend
return "nccl"
@property
def device(self) -> torch.device:
return torch.device(device_type, self.rank)
def _rand_microbatches(self, dp_mesh, num_microbatches, dim, dtype=torch.float32):
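        """Generate one full batch per DP rank; return the full list, this rank's
        local batch, and the local batch split into single-row microbatches."""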
full = [
torch.rand((num_microbatches, dim), device=self.device, dtype=dtype)
for _ in range(dp_mesh.size())
]
local = full[dp_mesh.get_local_rank()]
local_mb = [[local[i].reshape((1, dim))] for i in range(num_microbatches)]
return full, local, local_mb
# build a pipeline stage
def _build_pp_stage(
self, pp_group, full_model, total_layers, apply_dp, stage_idx, num_stages
):
# divide the model (e.g. 8 layers) by the number of stages
layers_per_stage = total_layers // num_stages
assert layers_per_stage * num_stages == total_layers
        # return the offset so validation code can map the partial model's layers back to the original model
offset = stage_idx * layers_per_stage
partial_model = nn.Sequential(
*full_model[offset : (stage_idx + 1) * layers_per_stage]
)
partial_model.to(self.device)
dp_model = apply_dp(partial_model)
stage = PipelineStage(
dp_model,
stage_idx,
num_stages,
self.device,
group=pp_group,
)
return stage, offset
def _build_pp_schedule(
self,
ScheduleClass,
num_microbatches,
pp_group,
full_model,
total_layers,
apply_dp,
loss_fn,
scale_grads=True,
):
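        # Single-stage schedules get one stage per PP rank; multi-stage (looped /
        # interleaved) schedules get n_virtual stages per PP rank.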
if issubclass(ScheduleClass, PipelineScheduleSingle):
pipeline_stage, offset = self._build_pp_stage(
pp_group,
full_model,
total_layers,
apply_dp,
pp_group.rank(),
pp_group.size(),
)
partial_models = [pipeline_stage.submod]
offsets = [offset]
pipeline_schedule = ScheduleClass(
pipeline_stage,
n_microbatches=num_microbatches,
loss_fn=loss_fn,
scale_grads=scale_grads,
)
else:
n_virtual = 2
num_stages = pp_group.size() * n_virtual
stages = []
offsets = []
for i in range(n_virtual):
stage, offset = self._build_pp_stage(
pp_group,
full_model,
total_layers,
apply_dp,
pp_group.rank() + n_virtual * i,
num_stages,
)
stages.append(stage)
offsets.append(offset)
partial_models = [pipeline_stage.submod for pipeline_stage in stages]
pipeline_schedule = ScheduleClass(
stages,
n_microbatches=num_microbatches,
loss_fn=loss_fn,
scale_grads=scale_grads,
)
return pipeline_schedule, partial_models, offsets
@requires_nccl()
@skip_if_lt_x_gpu(4)
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 4+ GPUs")
@parametrize(
"ScheduleClass",
[
ScheduleGPipe,
ScheduleInterleaved1F1B,
ScheduleInterleavedZeroBubble,
],
)
def test_pp_ddp(self, ScheduleClass):
if ScheduleClass == ScheduleInterleavedZeroBubble:
# TODO: DDP + InterleavedZeroBubble is not currently supported due to issue with DDP reducer not triggering
# https://github.com/pytorch/pytorch/issues/144530
return
torch.get_device_module(device_type).set_device(self.device)
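        # 2D mesh: DP across world_size // 2 ranks, PP across 2 ranks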
mesh_shape = (self.world_size // 2, 2)
mesh_dim_names = ("dp", "pp")
device_mesh = init_device_mesh(
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
)
pp_group = device_mesh["pp"].get_group()
dp_mesh = device_mesh["dp"]
# create "entire model"
total_layers = 8
num_microbatches = 8
dim = 10
full_model = nn.ModuleList([MLPModule(dim) for _ in range(total_layers)])
ref_model = nn.Sequential(*copy.deepcopy(full_model))
ref_model.to(self.device)
# Prepare inputs
inputs, input_local, _ = self._rand_microbatches(dp_mesh, num_microbatches, dim)
targets, target_local, _ = self._rand_microbatches(
dp_mesh, num_microbatches, dim
)
def apply_dp(partial_model):
return DDP(partial_model, process_group=dp_mesh.get_group())
# Build pipeline stages, apply data parallelism and attach to a schedule
pipeline_schedule, partial_models, offsets = self._build_pp_schedule(
ScheduleClass,
num_microbatches,
pp_group,
full_model,
total_layers,
apply_dp,
loss_fn,
)
# Run the pipeline
if pp_group.rank() == 0:
pipeline_schedule.step(input_local)
else:
pipeline_schedule.step(target=target_local)
        # The ref model runs on dp_mesh.size() different inputs, accumulating grads across them.
        # This ensures we detect if the DDP all-reduce becomes a no-op.
for sim_dp_rank in range(dp_mesh.size()):
loss_fn(ref_model(inputs[sim_dp_rank]), targets[sim_dp_rank]).backward()
ref_model.to(torch.float32)
for p in ref_model.parameters():
p.grad = p.grad.to(torch.float32)
p.grad /= dp_mesh.size()
        # Validate that the parameters we hold locally match the corresponding slice of the full ref model
ref_parameters = dict(ref_model.named_parameters())
for partial_model, offset in zip(partial_models, offsets):
for name, p in partial_model.named_parameters():
parts = name.split(".")[
1:
] # remove the DDP module. prefix (FSDP2 doesn't have one)
parts[0] = str(int(parts[0]) + offset)
name = ".".join(parts)
ref_p = ref_parameters[name]
torch.testing.assert_close(p.grad, ref_p.grad)
@requires_nccl()
@skip_if_lt_x_gpu(4)
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 4+ GPUs")
@parametrize("dp_type", ["FSDP", "FSDP_MP"])
@parametrize(
"ScheduleClass",
[
Schedule1F1B,
ScheduleInterleaved1F1B,
ScheduleLoopedBFS,
ScheduleInterleavedZeroBubble,
],
)
def test_pp_fsdp(self, dp_type, ScheduleClass):
if TEST_WITH_ROCM:
return
torch.get_device_module(device_type).set_device(self.device)
mesh_shape = (self.world_size // 2, 2)
mesh_dim_names = ("dp", "pp")
device_mesh = init_device_mesh(
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
)
pp_group = device_mesh["pp"].get_group()
dp_mesh = device_mesh["dp"]
        # FSDP mixed-precision dtype
mp_dtype = torch.bfloat16 if dp_type == "FSDP_MP" else torch.float32
# create "entire model"
total_layers = 8
num_microbatches = 8
dim = 10
full_model = nn.ModuleList([MLPModule(dim) for _ in range(total_layers)])
ref_model = nn.Sequential(*copy.deepcopy(full_model))
ref_model.to(self.device)
if dp_type == "FSDP_MP":
ref_model.to(dtype=mp_dtype)
# Prepare inputs
inputs, input_local, _ = self._rand_microbatches(
dp_mesh, num_microbatches, dim, dtype=mp_dtype
)
targets, target_local, _ = self._rand_microbatches(
dp_mesh, num_microbatches, dim, dtype=mp_dtype
)
# Apply FSDP to stage module
def apply_dp(partial_model):
mp_policy = MixedPrecisionPolicy(
param_dtype=mp_dtype,
reduce_dtype=torch.float32,
)
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
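            # Shard each MLP layer individually; keep params unsharded after
            # forward since they are reused immediately in backward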
for layer in partial_model.children():
fully_shard(
layer,
**fsdp_config,
reshard_after_forward=False,
)
return fully_shard(partial_model, **fsdp_config)
# Build pipeline stages, apply data parallelism and attach to a schedule
pipeline_schedule, partial_models, offsets = self._build_pp_schedule(
ScheduleClass,
num_microbatches,
pp_group,
full_model,
total_layers,
apply_dp,
loss_fn,
)
# Run the pipeline
if pp_group.rank() == 0:
pipeline_schedule.step(input_local)
else:
pipeline_schedule.step(target=target_local)
for m in partial_models:
for p in m.parameters():
assert p.grad is not None
                # mutate grads in place to expose a race with FSDP's reduce-scatter: if pipelining
                # does not properly synchronize with FSDP, gradients would be corrupted
p.grad.div_(2.0)
p.grad.mul_(2.0)
        # The ref model runs on dp_mesh.size() different inputs, accumulating grads across them.
        # This ensures we detect if the FSDP reduce-scatter becomes a no-op.
        # (In the FSDP case, we use one of these inputs on each DP rank.)
for sim_dp_rank in range(dp_mesh.size()):
loss_fn(ref_model(inputs[sim_dp_rank]), targets[sim_dp_rank]).backward()
ref_model.to(torch.float32)
for p in ref_model.parameters():
p.grad = p.grad.to(torch.float32)
p.grad /= dp_mesh.size()
        # Validate that the parameters we hold locally match the corresponding slice of the full ref model
        # (FSDP grads are all-gathered via .full_tensor() to keep the comparison simple)
ref_parameters = dict(ref_model.named_parameters())
for partial_model, offset in zip(partial_models, offsets):
for name, p in partial_model.named_parameters():
parts = name.split(".")
parts[0] = str(int(parts[0]) + offset)
name = ".".join(parts)
ref_p = ref_parameters[name]
self.assertTrue(isinstance(p.grad, DTensor))
torch.testing.assert_close(
p.grad.full_tensor(), ref_p.grad, atol=5e-5, rtol=2e-2
)
@requires_nccl()
@skip_if_lt_x_gpu(4)
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 4+ GPUs")
@parametrize("dp_type", ["FSDP", "FSDP_MP"])
def test_pp_fsdp_unshard_reshard_runtime(self, dp_type):
"""Test FSDP UNSHARD/RESHARD functionality using _PipelineScheduleRuntime with custom schedules."""
if TEST_WITH_ROCM:
return
torch.get_device_module(device_type).set_device(self.device)
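        # Degenerate PP dimension of size 1: a single pipeline stage per rank,
        # with all ranks in the DP dimension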
mesh_shape = (self.world_size, 1)
mesh_dim_names = ("dp", "pp")
device_mesh = init_device_mesh(
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
)
pp_group = device_mesh["pp"].get_group()
dp_mesh = device_mesh["dp"]
        # FSDP mixed-precision dtype
mp_dtype = torch.bfloat16 if dp_type == "FSDP_MP" else torch.float32
total_layers = 4
dim = 10
full_model = nn.ModuleList([MLPModule(dim) for _ in range(total_layers)])
def apply_dp(partial_model):
mp_policy = MixedPrecisionPolicy(
param_dtype=mp_dtype,
reduce_dtype=torch.float32,
)
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
for layer in partial_model.children():
fully_shard(
layer,
**fsdp_config,
reshard_after_forward=False,
)
return fully_shard(partial_model, **fsdp_config)
# Build pipeline stages
num_stages = pp_group.size()
layers_per_stage = total_layers // num_stages
stage_idx = pp_group.rank()
offset = stage_idx * layers_per_stage
partial_model = nn.Sequential(
*full_model[offset : (stage_idx + 1) * layers_per_stage]
)
partial_model.to(self.device)
fsdp_model = apply_dp(partial_model)
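        # Force FSDP lazy initialization so sharded parameter state exists
        # before we inspect it below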
distributed_state = fully_shard.state(fsdp_model)
distributed_state._lazy_init()
stage = PipelineStage(
fsdp_model,
stage_idx,
num_stages,
self.device,
group=pp_group,
)
# Helper function to check FSDP sharding state
def check_fsdp_unsharded_state(module, expected_unsharded=False):
"""Check if FSDP parameters are in expected sharding state."""
distributed_state = fully_shard.state(module)
unsharded_count = 0
total_fsdp_params = 0
for state in distributed_state._state_ctx.all_states:
if state._fsdp_param_group:
group = state._fsdp_param_group
for fsdp_param in group.fsdp_params:
total_fsdp_params += 1
if fsdp_param.sharded_state == ShardedState.UNSHARDED:
unsharded_count += 1
if expected_unsharded:
self.assertEqual(
unsharded_count,
total_fsdp_params,
f"Expected all {total_fsdp_params} FSDP parameters to be unsharded, "
f"but only {unsharded_count} are unsharded",
)
else:
self.assertEqual(
unsharded_count,
0,
f"Expected all FSDP parameters to be sharded, "
f"but {unsharded_count} out of {total_fsdp_params} are unsharded",
)
return total_fsdp_params > 0 # Return whether we found any FSDP parameters
# Test initial state - should be sharded
has_fsdp = check_fsdp_unsharded_state(stage.submod, expected_unsharded=False)
if not has_fsdp:
self.skipTest("No FSDP parameters found in the model")
def create_schedule(computation_types, microbatch_index=None):
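            """Build a rank-0-only schedule from a list of computation types;
            only FORWARD actions carry a microbatch index."""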
schedule = {
0: [
_Action(
stage_index=0, # stage 0 (the only stage)
computation_type=comp_type,
microbatch_index=microbatch_index
if comp_type == _ComputationType.FORWARD
else None,
)
for comp_type in computation_types
]
}
return schedule
unshard_schedule = create_schedule(
[
_ComputationType.UNSHARD,
_ComputationType.FORWARD,
],
microbatch_index=0,
)
unshard_reshard_schedule = create_schedule(
[
_ComputationType.UNSHARD,
_ComputationType.FORWARD,
_ComputationType.RESHARD,
],
microbatch_index=0,
)
# Test 1: Run UNSHARD + RESHARD schedule
runtime = _PipelineScheduleRuntime(
[stage], n_microbatches=1, loss_fn=None, scale_grads=False
)
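        # Inject the custom action list directly as the runtime's comm schedule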
runtime.pipeline_order_with_comms = unshard_reshard_schedule
dummy_input = torch.randn(1, dim, device=self.device, dtype=mp_dtype)
runtime.step(dummy_input)
# Verify parameters are now sharded again
check_fsdp_unsharded_state(stage.submod, expected_unsharded=False)
# Test 2: Run UNSHARD only schedule
runtime.pipeline_order_with_comms = unshard_schedule
runtime.step(dummy_input)
# Verify parameters are still sharded
check_fsdp_unsharded_state(stage.submod, expected_unsharded=False)
instantiate_parametrized_tests(ComposabilityTest)
if __name__ == "__main__":
run_tests()