# Owner(s): ["oncall: distributed"]
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import (
    PipelineScheduleSingle,
    Schedule1F1B,
    ScheduleGPipe,
    ScheduleInterleaved1F1B,
    ScheduleInterleavedZeroBubble,
    ScheduleLoopedBFS,
)
from torch.distributed.tensor import DTensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
    MultiProcContinuousTest,
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skip_but_pass_in_sandcastle_if,
    TEST_WITH_ROCM,
)

device_type = "cuda"


# MLP Layer
class MLPModule(torch.nn.Module):
    def __init__(self, d_hid: int):
        super().__init__()
        self.net1 = torch.nn.Linear(d_hid, d_hid)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(d_hid, d_hid)
        self.init_weights()

    def init_weights(self):
        # ensure a proper init; otherwise gradient tests are more likely to see zero grad values
        torch.nn.init.kaiming_uniform_(
            self.net1.weight, mode="fan_in", nonlinearity="relu"
        )
        torch.nn.init.kaiming_uniform_(
            self.net2.weight, mode="fan_in", nonlinearity="relu"
        )

    def forward(self, x):
        x = self.net1(x)
        x = self.relu(x)
        x = self.net2(x)
        return x


class MLPModuleEven(torch.nn.Module):
    def __init__(self, d_hid: int):
        super().__init__()
        self.net1 = nn.Linear(d_hid, d_hid)
        self.net2 = nn.Linear(d_hid, d_hid)
        self.net3 = nn.Linear(d_hid, d_hid * 2)
        self.init_weights()

    def init_weights(self):
        torch.nn.init.kaiming_uniform_(
            self.net1.weight, mode="fan_in", nonlinearity="relu"
        )
        torch.nn.init.kaiming_uniform_(
            self.net2.weight, mode="fan_in", nonlinearity="relu"
        )
        torch.nn.init.kaiming_uniform_(
            self.net3.weight, mode="fan_in", nonlinearity="relu"
        )

    def forward(self, x):
        x = F.relu(self.net1(x))
        x = F.relu(self.net2(x))
        x = F.relu(self.net3(x))
        return x


def loss_fn(y, target, scale=1e-4):
    # Scale the loss to simulate a small learning rate and avoid exploding grads
    return torch.nn.functional.cross_entropy(y, target) * scale


class ComposabilityTest(MultiProcContinuousTest):
    @classmethod
    def backend_str(cls) -> str:
        # Testing with NCCL backend
        return "nccl"

    @property
    def device(self) -> torch.device:
        return torch.device(device_type, self.rank)
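
    # Helper: each simulated DP rank draws its own full batch; the local batch is
    # split row by row into single-sample microbatches for the pipeline schedule.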
    def _rand_microbatches(self, dp_mesh, num_microbatches, dim, dtype=torch.float32):
        full = [
            torch.rand((num_microbatches, dim), device=self.device, dtype=dtype)
            for _ in range(dp_mesh.size())
        ]
        local = full[dp_mesh.get_local_rank()]
        local_mb = [[local[i].reshape((1, dim))] for i in range(num_microbatches)]
        return full, local, local_mb

    # build a pipeline stage
    def _build_pp_stage(
        self, pp_group, full_model, total_layers, apply_dp, stage_idx, num_stages
    ):
        # divide the model (e.g. 8 layers) by the number of stages
        layers_per_stage = total_layers // num_stages
        assert layers_per_stage * num_stages == total_layers
        # return the offset so validation code can match the partial layers back to the original model
        offset = stage_idx * layers_per_stage
        partial_model = nn.Sequential(
            *full_model[offset : (stage_idx + 1) * layers_per_stage]
        )
        partial_model.to(self.device)
        dp_model = apply_dp(partial_model)
        stage = PipelineStage(
            dp_model,
            stage_idx,
            num_stages,
            self.device,
            group=pp_group,
        )
        return stage, offset
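
    # Build a schedule with one stage per rank (single-stage schedules) or with
    # n_virtual stages per rank (interleaved/looped schedules).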
    def _build_pp_schedule(
        self,
        ScheduleClass,
        num_microbatches,
        pp_group,
        full_model,
        total_layers,
        apply_dp,
        loss_fn,
    ):
        if issubclass(ScheduleClass, PipelineScheduleSingle):
            pipeline_stage, offset = self._build_pp_stage(
                pp_group,
                full_model,
                total_layers,
                apply_dp,
                pp_group.rank(),
                pp_group.size(),
            )
            partial_models = [pipeline_stage.submod]
            offsets = [offset]
            pipeline_schedule = ScheduleClass(
                pipeline_stage,
                n_microbatches=num_microbatches,
                loss_fn=loss_fn,
            )
        else:
            n_virtual = 2
            num_stages = pp_group.size() * n_virtual
            stages = []
            offsets = []
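            # stage ids interleave across pipeline ranks: with pp_group.size() == 2 and
            # n_virtual == 2, rank 0 builds stages 0 and 2, rank 1 builds stages 1 and 3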
            for i in range(n_virtual):
                stage, offset = self._build_pp_stage(
                    pp_group,
                    full_model,
                    total_layers,
                    apply_dp,
                    pp_group.rank() + n_virtual * i,
                    num_stages,
                )
                stages.append(stage)
                offsets.append(offset)
            partial_models = [pipeline_stage.submod for pipeline_stage in stages]
            pipeline_schedule = ScheduleClass(
                stages,
                n_microbatches=num_microbatches,
                loss_fn=loss_fn,
            )
        return pipeline_schedule, partial_models, offsets

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 4+ GPUs")
    @parametrize(
        "ScheduleClass",
        [
            ScheduleGPipe,
            ScheduleInterleaved1F1B,
            ScheduleInterleavedZeroBubble,
        ],
    )
    def test_pp_ddp(self, ScheduleClass):
        if ScheduleClass == ScheduleInterleavedZeroBubble:
            # TODO: DDP + InterleavedZeroBubble is not currently supported due to an issue
            # with the DDP reducer not triggering: https://github.com/pytorch/pytorch/issues/144530
            return
        torch.get_device_module(device_type).set_device(self.device)
        mesh_shape = (self.world_size // 2, 2)
        mesh_dim_names = ("dp", "pp")
        device_mesh = init_device_mesh(
            "cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
        )
        pp_group = device_mesh["pp"].get_group()
        dp_mesh = device_mesh["dp"]
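        # e.g. with the required 4 GPUs this is a 2x2 mesh: 2-way data parallel x 2 pipeline ranks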
        # create "entire model"
        total_layers = 8
        num_microbatches = 8
        dim = 10
        full_model = nn.ModuleList([MLPModule(dim) for _ in range(total_layers)])
        ref_model = nn.Sequential(*copy.deepcopy(full_model))
        ref_model.to(self.device)
        # Prepare inputs
        inputs, input_local, _ = self._rand_microbatches(dp_mesh, num_microbatches, dim)
        targets, target_local, _ = self._rand_microbatches(
            dp_mesh, num_microbatches, dim
        )
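
        # Each pipeline stage's submodule is wrapped in DDP over the DP mesh, so the
        # gradient all-reduce runs independently per stage.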
        def apply_dp(partial_model):
            return DDP(partial_model, process_group=dp_mesh.get_group())

        # Build pipeline stages, apply data parallelism, and attach to a schedule
        pipeline_schedule, partial_models, offsets = self._build_pp_schedule(
            ScheduleClass,
            num_microbatches,
            pp_group,
            full_model,
            total_layers,
            apply_dp,
            loss_fn,
        )
        # Run the pipeline
        if pp_group.rank() == 0:
            pipeline_schedule.step(input_local)
        else:
            pipeline_schedule.step(target=target_local)
        # The ref model runs on 2 different inputs, accumulating grads across them.
        # This ensures that we detect if the DDP all-reduce becomes a no-op.
        for sim_dp_rank in range(dp_mesh.size()):
            loss_fn(ref_model(inputs[sim_dp_rank]), targets[sim_dp_rank]).backward()
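        # DDP averages gradients across the DP group, so average the ref model's
        # accumulated grads over the DP world size before comparing.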
        ref_model.to(torch.float32)
        for p in ref_model.parameters():
            p.grad = p.grad.to(torch.float32)
            p.grad /= dp_mesh.size()
        # Validate that whichever weights we hold locally match the corresponding part of the full ref model
        ref_parameters = dict(ref_model.named_parameters())
        for partial_model, offset in zip(partial_models, offsets):
            for name, p in partial_model.named_parameters():
                # strip the DDP "module." prefix (FSDP2 doesn't add one)
                parts = name.split(".")[1:]
                parts[0] = str(int(parts[0]) + offset)
                name = ".".join(parts)
                ref_p = ref_parameters[name]
                torch.testing.assert_close(p.grad, ref_p.grad)

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 4+ GPUs")
    @parametrize("dp_type", ["FSDP", "FSDP_MP"])
    @parametrize(
        "ScheduleClass",
        [
            Schedule1F1B,
            ScheduleInterleaved1F1B,
            ScheduleLoopedBFS,
            ScheduleInterleavedZeroBubble,
        ],
    )
    def test_pp_fsdp(self, dp_type, ScheduleClass):
        if TEST_WITH_ROCM:
            return
        torch.get_device_module(device_type).set_device(self.device)
        mesh_shape = (self.world_size // 2, 2)
        mesh_dim_names = ("dp", "pp")
        device_mesh = init_device_mesh(
            "cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
        )
        pp_group = device_mesh["pp"].get_group()
        dp_mesh = device_mesh["dp"]
        # FSDP mixed-precision dtype
        mp_dtype = torch.bfloat16 if dp_type == "FSDP_MP" else torch.float32
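        # Under FSDP_MP, params and activations run in bf16 while gradients are reduced
        # in fp32 (see the MixedPrecisionPolicy below).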
        # create "entire model"
        total_layers = 8
        num_microbatches = 8
        dim = 10
        full_model = nn.ModuleList([MLPModule(dim) for _ in range(total_layers)])
        ref_model = nn.Sequential(*copy.deepcopy(full_model))
        ref_model.to(self.device)
        if dp_type == "FSDP_MP":
            ref_model.to(dtype=mp_dtype)
        # Prepare inputs
        inputs, input_local, _ = self._rand_microbatches(
            dp_mesh, num_microbatches, dim, dtype=mp_dtype
        )
        targets, target_local, _ = self._rand_microbatches(
            dp_mesh, num_microbatches, dim, dtype=mp_dtype
        )

        # Apply FSDP to the stage module
        def apply_dp(partial_model):
            mp_policy = MixedPrecisionPolicy(
                param_dtype=mp_dtype,
                reduce_dtype=torch.float32,
            )
            fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
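            # Shard each layer individually, then wrap the whole stage module as the FSDP root.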
            for layer in partial_model.children():
                fully_shard(
                    layer,
                    **fsdp_config,
                    reshard_after_forward=False,
                )
            return fully_shard(partial_model, **fsdp_config)

        # Build pipeline stages, apply data parallelism, and attach to a schedule
        pipeline_schedule, partial_models, offsets = self._build_pp_schedule(
            ScheduleClass,
            num_microbatches,
            pp_group,
            full_model,
            total_layers,
            apply_dp,
            loss_fn,
        )
        # Run the pipeline
        if pp_group.rank() == 0:
            pipeline_schedule.step(input_local)
        else:
            pipeline_schedule.step(target=target_local)
        for m in partial_models:
            for p in m.parameters():
                assert p.grad is not None
                # Introduce a race condition for FSDP's reduce-scatter, which could corrupt
                # gradients if pipelining does not properly synchronize with FSDP.
                p.grad.div_(2.0)
                p.grad.mul_(2.0)
        # The ref model runs on 2 different inputs, accumulating grads across them.
        # This ensures that we detect if the FSDP reduce becomes a no-op.
        # (In the FSDP case, we use one of these inputs on each DP rank.)
        for sim_dp_rank in range(dp_mesh.size()):
            loss_fn(ref_model(inputs[sim_dp_rank]), targets[sim_dp_rank]).backward()
        ref_model.to(torch.float32)
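        # FSDP reduces gradients in fp32 and averages across the DP group; mirror that
        # by casting the ref grads to fp32 and dividing by the DP world size.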
        for p in ref_model.parameters():
            p.grad = p.grad.to(torch.float32)
            p.grad /= dp_mesh.size()
        # Validate that whichever weights we hold locally match the corresponding part of the full ref model
        # (FSDP grads are all-gathered via .full_tensor() to simplify the comparison)
        ref_parameters = dict(ref_model.named_parameters())
        for partial_model, offset in zip(partial_models, offsets):
            for name, p in partial_model.named_parameters():
                parts = name.split(".")
                parts[0] = str(int(parts[0]) + offset)
                name = ".".join(parts)
                ref_p = ref_parameters[name]
                self.assertTrue(isinstance(p.grad, DTensor))
                torch.testing.assert_close(
                    p.grad.full_tensor(), ref_p.grad, atol=5e-5, rtol=2e-2
                )


instantiate_parametrized_tests(ComposabilityTest)

if __name__ == "__main__":
    run_tests()
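
# Typical local invocation (an assumption, not part of the original file; these tests
# require at least 4 CUDA GPUs and the NCCL backend):
#   python test/distributed/test_composability.py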