# Owner(s): ["oncall: distributed"]
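
# These tests exercise composability of pipeline-parallel schedules with data
# parallelism: each pipeline stage is wrapped in DDP or FSDP2 (fully_shard),
# and the resulting gradients are checked against a single-process reference
# model that accumulates grads over every DP rank's input.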

import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import (
    PipelineScheduleSingle,
    Schedule1F1B,
    ScheduleGPipe,
    ScheduleInterleaved1F1B,
    ScheduleInterleavedZeroBubble,
    ScheduleLoopedBFS,
)
from torch.distributed.tensor import DTensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
    MultiProcContinuousTest,
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skip_but_pass_in_sandcastle_if,
    TEST_WITH_ROCM,
)


device_type = "cuda"


# MLP Layer
class MLPModule(torch.nn.Module):
    def __init__(self, d_hid: int):
        super().__init__()
        self.net1 = torch.nn.Linear(d_hid, d_hid)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(d_hid, d_hid)
        self.init_weights()

    def init_weights(self):
        # Ensure a proper init; otherwise gradient tests are more likely to see zero grad values.
        torch.nn.init.kaiming_uniform_(
            self.net1.weight, mode="fan_in", nonlinearity="relu"
        )
        torch.nn.init.kaiming_uniform_(
            self.net2.weight, mode="fan_in", nonlinearity="relu"
        )

    def forward(self, x):
        x = self.net1(x)
        x = self.relu(x)
        x = self.net2(x)
        return x


class MLPModuleEven(torch.nn.Module):
    def __init__(self, d_hid: int):
        super().__init__()
        self.net1 = nn.Linear(d_hid, d_hid)
        self.net2 = nn.Linear(d_hid, d_hid)
        self.net3 = nn.Linear(d_hid, d_hid * 2)
        self.init_weights()

    def init_weights(self):
        torch.nn.init.kaiming_uniform_(
            self.net1.weight, mode="fan_in", nonlinearity="relu"
        )
        torch.nn.init.kaiming_uniform_(
            self.net2.weight, mode="fan_in", nonlinearity="relu"
        )
        torch.nn.init.kaiming_uniform_(
            self.net3.weight, mode="fan_in", nonlinearity="relu"
        )

    def forward(self, x):
        x = F.relu(self.net1(x))
        x = F.relu(self.net2(x))
        x = F.relu(self.net3(x))
        return x


def loss_fn(y, target, scale=1e-4):
    # Scale the loss to simulate a small learning rate and avoid exploding grads.
    return torch.nn.functional.cross_entropy(y, target) * scale


class ComposabilityTest(MultiProcContinuousTest):
    @classmethod
    def backend_str(cls) -> str:
        # Testing with the NCCL backend
        return "nccl"

    @property
    def device(self) -> torch.device:
        return torch.device(device_type, self.rank)

    def _rand_microbatches(self, dp_mesh, num_microbatches, dim, dtype=torch.float32):
        # Generate a full batch for every DP rank, keep this rank's batch as the
        # local input, and split it into single-row microbatches for the schedule.
        # Retaining all ranks' batches lets the reference model replay them later.
        full = [
            torch.rand((num_microbatches, dim), device=self.device, dtype=dtype)
            for _ in range(dp_mesh.size())
        ]
        local = full[dp_mesh.get_local_rank()]
        local_mb = [[local[i].reshape((1, dim))] for i in range(num_microbatches)]
        return full, local, local_mb

    # Build a pipeline stage
    def _build_pp_stage(
        self, pp_group, full_model, total_layers, apply_dp, stage_idx, num_stages
    ):
        # Divide the model (e.g. 8 layers) evenly across the stages
        layers_per_stage = total_layers // num_stages
        assert layers_per_stage * num_stages == total_layers
        # Return the offset so validation code can match a partial layer back to the original model
        offset = stage_idx * layers_per_stage
        partial_model = nn.Sequential(
            *full_model[offset : (stage_idx + 1) * layers_per_stage]
        )
        partial_model.to(self.device)
        dp_model = apply_dp(partial_model)
        stage = PipelineStage(
            dp_model,
            stage_idx,
            num_stages,
            self.device,
            group=pp_group,
        )
        return stage, offset

    def _build_pp_schedule(
        self,
        ScheduleClass,
        num_microbatches,
        pp_group,
        full_model,
        total_layers,
        apply_dp,
        loss_fn,
    ):
        if issubclass(ScheduleClass, PipelineScheduleSingle):
            # Single-stage schedules (e.g. GPipe, 1F1B): one stage per pipeline rank.
            pipeline_stage, offset = self._build_pp_stage(
                pp_group,
                full_model,
                total_layers,
                apply_dp,
                pp_group.rank(),
                pp_group.size(),
            )

            partial_models = [pipeline_stage.submod]
            offsets = [offset]
            pipeline_schedule = ScheduleClass(
                pipeline_stage,
                n_microbatches=num_microbatches,
                loss_fn=loss_fn,
            )
        else:
            # Multi-stage (interleaved) schedules: each pipeline rank hosts
            # n_virtual stages.
            n_virtual = 2
            num_stages = pp_group.size() * n_virtual
            stages = []
            offsets = []
            for i in range(n_virtual):
                stage, offset = self._build_pp_stage(
                    pp_group,
                    full_model,
                    total_layers,
                    apply_dp,
                    pp_group.rank() + n_virtual * i,
                    num_stages,
                )
                stages.append(stage)
                offsets.append(offset)
            partial_models = [pipeline_stage.submod for pipeline_stage in stages]
            pipeline_schedule = ScheduleClass(
                stages,
                n_microbatches=num_microbatches,
                loss_fn=loss_fn,
            )
        return pipeline_schedule, partial_models, offsets
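
    # Illustrative stage-to-rank layout for the interleaved branch above: with
    # pp_group.size() == 2 and n_virtual == 2 there are 4 stages, and since
    # stage_idx = rank + n_virtual * i, rank 0 builds stages 0 and 2 while
    # rank 1 builds stages 1 and 3.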

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 4+ GPUs")
    @parametrize(
        "ScheduleClass",
        [
            ScheduleGPipe,
            ScheduleInterleaved1F1B,
            ScheduleInterleavedZeroBubble,
        ],
    )
    def test_pp_ddp(self, ScheduleClass):
        if ScheduleClass == ScheduleInterleavedZeroBubble:
            # TODO: DDP + InterleavedZeroBubble is not currently supported due to an
            # issue with the DDP reducer not triggering:
            # https://github.com/pytorch/pytorch/issues/144530
            return

        torch.get_device_module(device_type).set_device(self.device)
        mesh_shape = (self.world_size // 2, 2)
        mesh_dim_names = ("dp", "pp")
        device_mesh = init_device_mesh(
            "cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
        )
        pp_group = device_mesh["pp"].get_group()
        dp_mesh = device_mesh["dp"]
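        # With 4 GPUs, mesh_shape == (2, 2): the mesh is [[0, 1], [2, 3]], so
        # each row {0, 1} / {2, 3} forms a "pp" group and each column
        # {0, 2} / {1, 3} forms a "dp" group.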

        # create "entire model"
        total_layers = 8
        num_microbatches = 8
        dim = 10
        full_model = nn.ModuleList([MLPModule(dim) for _ in range(total_layers)])
        ref_model = nn.Sequential(*copy.deepcopy(full_model))
        ref_model.to(self.device)

        # Prepare inputs
        inputs, input_local, _ = self._rand_microbatches(dp_mesh, num_microbatches, dim)
        targets, target_local, _ = self._rand_microbatches(
            dp_mesh, num_microbatches, dim
        )

        def apply_dp(partial_model):
            return DDP(partial_model, process_group=dp_mesh.get_group())

        # Build pipeline stages, apply data parallelism, and attach to a schedule
        pipeline_schedule, partial_models, offsets = self._build_pp_schedule(
            ScheduleClass,
            num_microbatches,
            pp_group,
            full_model,
            total_layers,
            apply_dp,
            loss_fn,
        )

        # Run the pipeline
        if pp_group.rank() == 0:
            pipeline_schedule.step(input_local)
        else:
            pipeline_schedule.step(target=target_local)
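        # The first pipeline rank feeds the full local batch; the last rank
        # supplies targets so the schedule can compute the loss. With a 2-rank
        # pp group, every nonzero rank here is the last stage.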

        # The reference model runs on every DP rank's input, accumulating grads
        # across them. This ensures that we detect if the DDP all-reduce becomes
        # a no-op.
        for sim_dp_rank in range(dp_mesh.size()):
            loss_fn(ref_model(inputs[sim_dp_rank]), targets[sim_dp_rank]).backward()
        ref_model.to(torch.float32)
        for p in ref_model.parameters():
            p.grad = p.grad.to(torch.float32)
            p.grad /= dp_mesh.size()
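        # Dividing by the DP world size mirrors DDP's mean reduction, so the
        # accumulated reference grads are comparable to the all-reduced ones.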

        # Validate that whichever weights we have locally match the corresponding
        # part of the full reference model
        ref_parameters = dict(ref_model.named_parameters())
        for partial_model, offset in zip(partial_models, offsets):
            for name, p in partial_model.named_parameters():
                parts = name.split(".")[
                    1:
                ]  # remove the DDP "module." prefix (FSDP2 doesn't have one)
                parts[0] = str(int(parts[0]) + offset)
                name = ".".join(parts)
                ref_p = ref_parameters[name]
                torch.testing.assert_close(p.grad, ref_p.grad)

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 4+ GPUs")
    @parametrize("dp_type", ["FSDP", "FSDP_MP"])
    @parametrize(
        "ScheduleClass",
        [
            Schedule1F1B,
            ScheduleInterleaved1F1B,
            ScheduleLoopedBFS,
            ScheduleInterleavedZeroBubble,
        ],
    )
    def test_pp_fsdp(self, dp_type, ScheduleClass):
        if TEST_WITH_ROCM:
            # Skipped on ROCm
            return

        torch.get_device_module(device_type).set_device(self.device)
        mesh_shape = (self.world_size // 2, 2)
        mesh_dim_names = ("dp", "pp")
        device_mesh = init_device_mesh(
            "cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
        )
        pp_group = device_mesh["pp"].get_group()
        dp_mesh = device_mesh["dp"]

        # FSDP mixed-precision dtype
        mp_dtype = torch.bfloat16 if dp_type == "FSDP_MP" else torch.float32

        # create "entire model"
        total_layers = 8
        num_microbatches = 8
        dim = 10
        full_model = nn.ModuleList([MLPModule(dim) for _ in range(total_layers)])
        ref_model = nn.Sequential(*copy.deepcopy(full_model))
        ref_model.to(self.device)
        if dp_type == "FSDP_MP":
            ref_model.to(dtype=mp_dtype)

        # Prepare inputs
        inputs, input_local, _ = self._rand_microbatches(
            dp_mesh, num_microbatches, dim, dtype=mp_dtype
        )
        targets, target_local, _ = self._rand_microbatches(
            dp_mesh, num_microbatches, dim, dtype=mp_dtype
        )

        # Apply FSDP to the stage module
        def apply_dp(partial_model):
            mp_policy = MixedPrecisionPolicy(
                param_dtype=mp_dtype,
                reduce_dtype=torch.float32,
            )
            fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
            for layer in partial_model.children():
                fully_shard(
                    layer,
                    **fsdp_config,
                    reshard_after_forward=False,
                )
            return fully_shard(partial_model, **fsdp_config)
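
        # Note: sharding each child layer before the root module gives
        # per-layer FSDP2 param groups; the final fully_shard call on the
        # root wraps any remaining parameters.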

        # Build pipeline stages, apply data parallelism, and attach to a schedule
        pipeline_schedule, partial_models, offsets = self._build_pp_schedule(
            ScheduleClass,
            num_microbatches,
            pp_group,
            full_model,
            total_layers,
            apply_dp,
            loss_fn,
        )

        # Run the pipeline
        if pp_group.rank() == 0:
            pipeline_schedule.step(input_local)
        else:
            pipeline_schedule.step(target=target_local)
        for m in partial_models:
            for p in m.parameters():
                assert p.grad is not None
                # Introduce a race condition for FSDP's reduce-scatter, which could
                # corrupt gradients if pipelining does not properly synchronize with FSDP
                p.grad.div_(2.0)
                p.grad.mul_(2.0)
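                # The div_/mul_ pair is numerically a no-op, but it touches the
                # gradient memory from the default stream; a missed sync with
                # FSDP's reduce-scatter stream would show up as corrupted grads.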

        # The reference model runs on every DP rank's input, accumulating grads
        # across them. This ensures that we detect if the FSDP reduce-scatter
        # becomes a no-op. (In the FSDP case, each DP rank consumes one of these
        # inputs.)
        for sim_dp_rank in range(dp_mesh.size()):
            loss_fn(ref_model(inputs[sim_dp_rank]), targets[sim_dp_rank]).backward()
        ref_model.to(torch.float32)
        for p in ref_model.parameters():
            p.grad = p.grad.to(torch.float32)
            p.grad /= dp_mesh.size()

        # Validate that whichever weights we have locally match the corresponding
        # part of the full reference model (FSDP's sharded grads are all-gathered
        # via .full_tensor() to keep the comparison simple)
        ref_parameters = dict(ref_model.named_parameters())
        for partial_model, offset in zip(partial_models, offsets):
            for name, p in partial_model.named_parameters():
                parts = name.split(".")
                parts[0] = str(int(parts[0]) + offset)
                name = ".".join(parts)
                ref_p = ref_parameters[name]
                self.assertTrue(isinstance(p.grad, DTensor))
                torch.testing.assert_close(
                    p.grad.full_tensor(), ref_p.grad, atol=5e-5, rtol=2e-2
                )


instantiate_parametrized_tests(ComposabilityTest)

if __name__ == "__main__":
    run_tests()