Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-14 06:07:55 +08:00)
Move FSDP reduce-scatters to the end of the PP step. The reduce-scatter's compute-stream sync blocks the other stages from executing their backward passes, leading to bubbles. There should be a way to execute these reduce-scatters earlier, but this is done for now as a quick fix.

<img width="1056" height="463" alt="image" src="https://github.com/user-attachments/assets/b945dd55-8ab1-4acc-b862-c6e2e476b834" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165106
Approved by: https://github.com/weifengpy
ghstack dependencies: #164976
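For context, a minimal sketch of the analogous user-level pattern: defer FSDP2's gradient reduce-scatter until the last microbatch's backward so it does not block earlier backward compute. This is illustrative only and is not the change made in this PR (which reorders work inside the schedule runtime); `stage_module`, `microbatches`, and `targets` are hypothetical names, and `set_requires_gradient_sync` is assumed to be available on the `fully_shard`-wrapped module.

# Hypothetical sketch: run the gradient reduce-scatter only once, at the end of the step.
# Assumes `stage_module` was wrapped with fully_shard() and `loss_fn` scales the loss as below.
for i, (mb, tgt) in enumerate(zip(microbatches, targets)):
    is_last = i == len(microbatches) - 1
    # Skip the reduce-scatter on all but the final backward of the step.
    stage_module.set_requires_gradient_sync(is_last)
    loss_fn(stage_module(mb), tgt).backward()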
536 lines
19 KiB
Python
# Owner(s): ["oncall: distributed"]
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.fsdp._fully_shard._fsdp_param import ShardedState
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import (
    _Action,
    _ComputationType,
    _PipelineScheduleRuntime,
    PipelineScheduleSingle,
    Schedule1F1B,
    ScheduleGPipe,
    ScheduleInterleaved1F1B,
    ScheduleInterleavedZeroBubble,
    ScheduleLoopedBFS,
)
from torch.distributed.tensor import DTensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
    MultiProcContinuousTest,
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skip_but_pass_in_sandcastle_if,
    TEST_WITH_ROCM,
)


device_type = "cuda"


# MLP Layer
class MLPModule(torch.nn.Module):
    def __init__(self, d_hid: int):
        super().__init__()
        self.net1 = torch.nn.Linear(d_hid, d_hid)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(d_hid, d_hid)
        self.init_weights()

    def init_weights(self):
        # ensure a proper init, otherwise gradient tests are more likely to get zero grad values
        torch.nn.init.kaiming_uniform_(
            self.net1.weight, mode="fan_in", nonlinearity="relu"
        )
        torch.nn.init.kaiming_uniform_(
            self.net2.weight, mode="fan_in", nonlinearity="relu"
        )

    def forward(self, x):
        x = self.net1(x)
        x = self.relu(x)
        x = self.net2(x)
        return x


class MLPModuleEven(torch.nn.Module):
    def __init__(self, d_hid: int):
        super().__init__()
        self.net1 = nn.Linear(d_hid, d_hid)
        self.net2 = nn.Linear(d_hid, d_hid)
        self.net3 = nn.Linear(d_hid, d_hid * 2)
        self.init_weights()

    def init_weights(self):
        torch.nn.init.kaiming_uniform_(
            self.net1.weight, mode="fan_in", nonlinearity="relu"
        )
        torch.nn.init.kaiming_uniform_(
            self.net2.weight, mode="fan_in", nonlinearity="relu"
        )
        torch.nn.init.kaiming_uniform_(
            self.net3.weight, mode="fan_in", nonlinearity="relu"
        )

    def forward(self, x):
        x = F.relu(self.net1(x))
        x = F.relu(self.net2(x))
        x = F.relu(self.net3(x))
        return x


def loss_fn(y, target, scale=1e-4):
    # Scale the loss to simulate a small learning rate and avoid exploding grads
    return torch.nn.functional.cross_entropy(y, target) * scale


class ComposabilityTest(MultiProcContinuousTest):
    @classmethod
    def backend_str(cls) -> str:
        # Testing with NCCL backend
        return "nccl"

    @property
    def device(self) -> torch.device:
        return torch.device(device_type, self.rank)

    def _rand_microbatches(self, dp_mesh, num_microbatches, dim, dtype=torch.float32):
        full = [
            torch.rand((num_microbatches, dim), device=self.device, dtype=dtype)
            for _ in range(dp_mesh.size())
        ]
        local = full[dp_mesh.get_local_rank()]
        local_mb = [[local[i].reshape((1, dim))] for i in range(num_microbatches)]
        return full, local, local_mb

    # build a pipeline stage
    def _build_pp_stage(
        self, pp_group, full_model, total_layers, apply_dp, stage_idx, num_stages
    ):
        # divide the model (e.g. 8 layers) by the number of stages
        layers_per_stage = total_layers // num_stages
        assert layers_per_stage * num_stages == total_layers
        # return offset so validation code can match partial layer back to orig model
        offset = stage_idx * layers_per_stage
        partial_model = nn.Sequential(
            *full_model[offset : (stage_idx + 1) * layers_per_stage]
        )
        partial_model.to(self.device)
        dp_model = apply_dp(partial_model)
        stage = PipelineStage(
            dp_model,
            stage_idx,
            num_stages,
            self.device,
            group=pp_group,
        )
        return stage, offset

    def _build_pp_schedule(
        self,
        ScheduleClass,
        num_microbatches,
        pp_group,
        full_model,
        total_layers,
        apply_dp,
        loss_fn,
        scale_grads=True,
    ):
        if issubclass(ScheduleClass, PipelineScheduleSingle):
            pipeline_stage, offset = self._build_pp_stage(
                pp_group,
                full_model,
                total_layers,
                apply_dp,
                pp_group.rank(),
                pp_group.size(),
            )

            partial_models = [pipeline_stage.submod]
            offsets = [offset]
            pipeline_schedule = ScheduleClass(
                pipeline_stage,
                n_microbatches=num_microbatches,
                loss_fn=loss_fn,
                scale_grads=scale_grads,
            )
        else:
            n_virtual = 2
            num_stages = pp_group.size() * n_virtual
            stages = []
            offsets = []
            for i in range(n_virtual):
                stage, offset = self._build_pp_stage(
                    pp_group,
                    full_model,
                    total_layers,
                    apply_dp,
                    pp_group.rank() + n_virtual * i,
                    num_stages,
                )
                stages.append(stage)
                offsets.append(offset)
            partial_models = [pipeline_stage.submod for pipeline_stage in stages]
            pipeline_schedule = ScheduleClass(
                stages,
                n_microbatches=num_microbatches,
                loss_fn=loss_fn,
                scale_grads=scale_grads,
            )
        return pipeline_schedule, partial_models, offsets

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 4+ GPUs")
    @parametrize(
        "ScheduleClass",
        [
            ScheduleGPipe,
            ScheduleInterleaved1F1B,
            ScheduleInterleavedZeroBubble,
        ],
    )
    def test_pp_ddp(self, ScheduleClass):
        if ScheduleClass == ScheduleInterleavedZeroBubble:
            # TODO: DDP + InterleavedZeroBubble is not currently supported due to issue with DDP reducer not triggering
            # https://github.com/pytorch/pytorch/issues/144530
            return

        torch.get_device_module(device_type).set_device(self.device)
        mesh_shape = (self.world_size // 2, 2)
        mesh_dim_names = ("dp", "pp")
        device_mesh = init_device_mesh(
            "cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
        )
        pp_group = device_mesh["pp"].get_group()
        dp_mesh = device_mesh["dp"]

        # create "entire model"
        total_layers = 8
        num_microbatches = 8
        dim = 10
        full_model = nn.ModuleList([MLPModule(dim) for _ in range(total_layers)])
        ref_model = nn.Sequential(*copy.deepcopy(full_model))
        ref_model.to(self.device)

        # Prepare inputs
        inputs, input_local, _ = self._rand_microbatches(dp_mesh, num_microbatches, dim)
        targets, target_local, _ = self._rand_microbatches(
            dp_mesh, num_microbatches, dim
        )

        def apply_dp(partial_model):
            return DDP(partial_model, process_group=dp_mesh.get_group())

        # Build pipeline stages, apply data parallelism and attach to a schedule
        pipeline_schedule, partial_models, offsets = self._build_pp_schedule(
            ScheduleClass,
            num_microbatches,
            pp_group,
            full_model,
            total_layers,
            apply_dp,
            loss_fn,
        )

        # Run the pipeline
        if pp_group.rank() == 0:
            pipeline_schedule.step(input_local)
        else:
            pipeline_schedule.step(target=target_local)

        # Ref model runs on 2 different inputs, accumulating grads across them.
        # This ensures that we detect if the DDP all-reduce becomes a no-op.
        for sim_dp_rank in range(dp_mesh.size()):
            loss_fn(ref_model(inputs[sim_dp_rank]), targets[sim_dp_rank]).backward()
        ref_model.to(torch.float32)
        for p in ref_model.parameters():
            p.grad = p.grad.to(torch.float32)
            p.grad /= dp_mesh.size()

        # Validate that whichever weights we have locally match that part of our local/full ref model
        ref_parameters = dict(ref_model.named_parameters())
        for partial_model, offset in zip(partial_models, offsets):
            for name, p in partial_model.named_parameters():
                parts = name.split(".")[
                    1:
                ]  # remove the DDP `module.` prefix (FSDP2 doesn't have one)
                parts[0] = str(int(parts[0]) + offset)
                name = ".".join(parts)
                ref_p = ref_parameters[name]
                torch.testing.assert_close(p.grad, ref_p.grad)

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 4+ GPUs")
    @parametrize("dp_type", ["FSDP", "FSDP_MP"])
    @parametrize(
        "ScheduleClass",
        [
            Schedule1F1B,
            ScheduleInterleaved1F1B,
            ScheduleLoopedBFS,
            ScheduleInterleavedZeroBubble,
        ],
    )
    def test_pp_fsdp(self, dp_type, ScheduleClass):
        if TEST_WITH_ROCM:
            return

        torch.get_device_module(device_type).set_device(self.device)
        mesh_shape = (self.world_size // 2, 2)
        mesh_dim_names = ("dp", "pp")
        device_mesh = init_device_mesh(
            "cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
        )
        pp_group = device_mesh["pp"].get_group()
        dp_mesh = device_mesh["dp"]

        # FSDP mixed-precision dtype
        mp_dtype = torch.bfloat16 if dp_type == "FSDP_MP" else torch.float32

        # create "entire model"
        total_layers = 8
        num_microbatches = 8
        dim = 10
        full_model = nn.ModuleList([MLPModule(dim) for _ in range(total_layers)])
        ref_model = nn.Sequential(*copy.deepcopy(full_model))
        ref_model.to(self.device)
        if dp_type == "FSDP_MP":
            ref_model.to(dtype=mp_dtype)

        # Prepare inputs
        inputs, input_local, _ = self._rand_microbatches(
            dp_mesh, num_microbatches, dim, dtype=mp_dtype
        )
        targets, target_local, _ = self._rand_microbatches(
            dp_mesh, num_microbatches, dim, dtype=mp_dtype
        )

        # Apply FSDP to stage module
        def apply_dp(partial_model):
            mp_policy = MixedPrecisionPolicy(
                param_dtype=mp_dtype,
                reduce_dtype=torch.float32,
            )
            fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
            for layer in partial_model.children():
                fully_shard(
                    layer,
                    **fsdp_config,
                    reshard_after_forward=False,
                )
            return fully_shard(partial_model, **fsdp_config)

        # Build pipeline stages, apply data parallelism and attach to a schedule
        pipeline_schedule, partial_models, offsets = self._build_pp_schedule(
            ScheduleClass,
            num_microbatches,
            pp_group,
            full_model,
            total_layers,
            apply_dp,
            loss_fn,
        )

        # Run the pipeline
        if pp_group.rank() == 0:
            pipeline_schedule.step(input_local)
        else:
            pipeline_schedule.step(target=target_local)
        for m in partial_models:
            for p in m.parameters():
                assert p.grad is not None
                # introduce a race condition for FSDP's reduce-scatter which could corrupt gradients if pipelining
                # does not properly synchronize with FSDP
                p.grad.div_(2.0)
                p.grad.mul_(2.0)

        # Ref model runs on 2 different inputs, accumulating grads across them.
        # This ensures that we detect if the FSDP reduce becomes a no-op.
        # (in the FSDP case, we use one of these inputs on each DP rank)
        for sim_dp_rank in range(dp_mesh.size()):
            loss_fn(ref_model(inputs[sim_dp_rank]), targets[sim_dp_rank]).backward()
        ref_model.to(torch.float32)
        for p in ref_model.parameters():
            p.grad = p.grad.to(torch.float32)
            p.grad /= dp_mesh.size()

        # Validate that whichever weights we have locally match that part of our local/full ref model
        # (we force FSDP's grads to be all-gathered (.full_tensor) to make it simpler)
        ref_parameters = dict(ref_model.named_parameters())
        for partial_model, offset in zip(partial_models, offsets):
            for name, p in partial_model.named_parameters():
                parts = name.split(".")
                parts[0] = str(int(parts[0]) + offset)
                name = ".".join(parts)
                ref_p = ref_parameters[name]
                self.assertTrue(isinstance(p.grad, DTensor))
                torch.testing.assert_close(
                    p.grad.full_tensor(), ref_p.grad, atol=5e-5, rtol=2e-2
                )

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 4+ GPUs")
    @parametrize("dp_type", ["FSDP", "FSDP_MP"])
    def test_pp_fsdp_unshard_reshard_runtime(self, dp_type):
        """Test FSDP UNSHARD/RESHARD functionality using _PipelineScheduleRuntime with custom schedules."""
        if TEST_WITH_ROCM:
            return

        torch.get_device_module(device_type).set_device(self.device)
        mesh_shape = (self.world_size, 1)
        mesh_dim_names = ("dp", "pp")
        device_mesh = init_device_mesh(
            "cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
        )
        pp_group = device_mesh["pp"].get_group()
        dp_mesh = device_mesh["dp"]

        # FSDP mixed-precision dtype
        mp_dtype = torch.bfloat16 if dp_type == "FSDP_MP" else torch.float32
        total_layers = 4
        dim = 10
        full_model = nn.ModuleList([MLPModule(dim) for _ in range(total_layers)])

        def apply_dp(partial_model):
            mp_policy = MixedPrecisionPolicy(
                param_dtype=mp_dtype,
                reduce_dtype=torch.float32,
            )
            fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
            for layer in partial_model.children():
                fully_shard(
                    layer,
                    **fsdp_config,
                    reshard_after_forward=False,
                )
            return fully_shard(partial_model, **fsdp_config)

        # Build pipeline stages
        num_stages = pp_group.size()
        layers_per_stage = total_layers // num_stages
        stage_idx = pp_group.rank()
        offset = stage_idx * layers_per_stage

        partial_model = nn.Sequential(
            *full_model[offset : (stage_idx + 1) * layers_per_stage]
        )
        partial_model.to(self.device)
        fsdp_model = apply_dp(partial_model)
        distributed_state = fully_shard.state(fsdp_model)
        distributed_state._lazy_init()

        stage = PipelineStage(
            fsdp_model,
            stage_idx,
            num_stages,
            self.device,
            group=pp_group,
        )

        # Helper function to check FSDP sharding state
        def check_fsdp_unsharded_state(module, expected_unsharded=False):
            """Check if FSDP parameters are in the expected sharding state."""
            distributed_state = fully_shard.state(module)
            unsharded_count = 0
            total_fsdp_params = 0

            for state in distributed_state._state_ctx.all_states:
                if state._fsdp_param_group:
                    group = state._fsdp_param_group
                    for fsdp_param in group.fsdp_params:
                        total_fsdp_params += 1
                        if fsdp_param.sharded_state == ShardedState.UNSHARDED:
                            unsharded_count += 1

            if expected_unsharded:
                self.assertEqual(
                    unsharded_count,
                    total_fsdp_params,
                    f"Expected all {total_fsdp_params} FSDP parameters to be unsharded, "
                    f"but only {unsharded_count} are unsharded",
                )
            else:
                self.assertEqual(
                    unsharded_count,
                    0,
                    f"Expected all FSDP parameters to be sharded, "
                    f"but {unsharded_count} out of {total_fsdp_params} are unsharded",
                )

            return total_fsdp_params > 0  # Return whether we found any FSDP parameters

        # Test initial state - should be sharded
        has_fsdp = check_fsdp_unsharded_state(stage.submod, expected_unsharded=False)

        if not has_fsdp:
            self.skipTest("No FSDP parameters found in the model")

        def create_schedule(computation_types, microbatch_index=None):
            schedule = {
                0: [
                    _Action(
                        stage_index=0,  # stage 0 (the only stage)
                        computation_type=comp_type,
                        microbatch_index=microbatch_index
                        if comp_type == _ComputationType.FORWARD
                        else None,
                    )
                    for comp_type in computation_types
                ]
            }
            return schedule

        unshard_schedule = create_schedule(
            [
                _ComputationType.UNSHARD,
                _ComputationType.FORWARD,
            ],
            microbatch_index=0,
        )
        unshard_reshard_schedule = create_schedule(
            [
                _ComputationType.UNSHARD,
                _ComputationType.FORWARD,
                _ComputationType.RESHARD,
            ],
            microbatch_index=0,
        )

        # Test 1: Run UNSHARD + RESHARD schedule
        runtime = _PipelineScheduleRuntime(
            [stage], n_microbatches=1, loss_fn=None, scale_grads=False
        )
        runtime.pipeline_order_with_comms = unshard_reshard_schedule
        dummy_input = torch.randn(1, dim, device=self.device, dtype=mp_dtype)
        runtime.step(dummy_input)

        # Verify parameters are now sharded again
        check_fsdp_unsharded_state(stage.submod, expected_unsharded=False)

        # Test 2: Run UNSHARD only schedule
        runtime.pipeline_order_with_comms = unshard_schedule
        runtime.step(dummy_input)

        # Verify parameters are still sharded
        check_fsdp_unsharded_state(stage.submod, expected_unsharded=False)


instantiate_parametrized_tests(ComposabilityTest)

if __name__ == "__main__":
    run_tests()