Files
pytorch/test/distributed/tensor/parallel/test_tp_examples.py

557 lines
20 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import itertools
from copy import deepcopy
from typing import NamedTuple, Optional
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointImpl,
)
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Replicate,
Shard,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import (
ColwiseParallel,
loss_parallel,
parallelize_module,
RowwiseParallel,
)
from torch.distributed.tensor.parallel.input_reshard import input_reshard
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
MLPModule,
ModelArgs,
NUM_DEVICES,
skip_unless_torch_gpu,
Transformer,
with_comms,
)
c10d_functional = torch.ops.c10d_functional
reduce_scatter, all_gather, all_reduce = (
c10d_functional.reduce_scatter_tensor,
c10d_functional.all_gather_into_tensor,
c10d_functional.all_reduce,
)
class ExpCommCounts(NamedTuple):
fwd: Optional[dict] = None
bwd: Optional[dict] = None
optim: Optional[dict] = None
class DistTensorParallelExampleTest(DTensorTestBase):
def _check_module(self, m1, m2, check_grad=False):
named_parameters = dict(m1.named_parameters())
for name, param_m2 in m2.named_parameters():
self.assertTrue(name in named_parameters)
param_m1 = named_parameters[name]
if check_grad:
param_m2 = param_m2.grad
param_m1 = param_m1.grad
if isinstance(param_m2, DTensor):
replicate = [Replicate()]
param_m2 = param_m2.redistribute(
device_mesh=param_m2.device_mesh, placements=replicate
).to_local()
self.assertEqual(param_m2, param_m1)
def _test_mlp_training_e2e(self, is_seq_parallel=False, recompute_activation=False):
inp_size = [8, 10]
# Ensure all tp ranks have same input.
rng_seed = self.rank if is_seq_parallel else 0
torch.manual_seed(rng_seed)
inp = torch.rand(*inp_size, device=self.device_type)
model = MLPModule(self.device_type)
model_tp = deepcopy(model)
# Ensure model are initialized the same way.
self._check_module(model, model_tp)
# Shard module and initialize optimizer.
LR = 0.25
device_mesh = DeviceMesh(
self.device_type,
torch.arange(0, NUM_DEVICES),
)
parallelize_plan = {
"net1": (
ColwiseParallel(input_layouts=Shard(0))
if is_seq_parallel
else ColwiseParallel()
),
"net2": (
RowwiseParallel(output_layouts=Shard(0))
if is_seq_parallel
else RowwiseParallel()
),
}
model_tp = parallelize_module(model_tp, device_mesh, parallelize_plan)
if recompute_activation:
model_tp = input_reshard(
checkpoint_wrapper(
model_tp, checkpoint_impl=CheckpointImpl.NO_REENTRANT
),
device_mesh,
None if is_seq_parallel else 0,
)
optim = torch.optim.SGD(model.parameters(), lr=LR)
optim_tp = torch.optim.SGD(model_tp.parameters(), lr=LR)
output = model(inp)
output.sum().backward()
comm_mode = CommDebugMode()
with comm_mode:
output_tp = model_tp(inp)
output_tp.sum().backward()
self.assertEqual(output, output_tp)
if is_seq_parallel:
self.assertEqual(
comm_mode.get_comm_counts()[c10d_functional.all_gather_into_tensor], 2
)
self.assertEqual(
comm_mode.get_comm_counts()[c10d_functional.reduce_scatter_tensor], 1
)
else:
self.assertEqual(comm_mode.get_comm_counts()[c10d_functional.all_reduce], 1)
if is_seq_parallel:
# Sum gradients from different ranks, since input
# are different across ranks for sequence parallel.
dist.all_reduce(model.net1.weight.grad)
dist.all_reduce(model.net1.bias.grad)
dist.all_reduce(model.net2.weight.grad)
dist.all_reduce(model.net2.bias.grad)
# Ensure gradients are same.
self._check_module(model, model_tp, check_grad=True)
optim.step()
optim_tp.step()
# Ensure model weights are still same after update.
# Due to the trick we use for Partial aggregation, we only check the weight when local_rank = 0.
self._check_module(model, model_tp)
inp = torch.rand(*inp_size, device=self.device_type)
output = model(inp)
output_tp = model_tp(inp)
self.assertEqual(output, output_tp)
def _test_mlp_inference(self, device_mesh):
inp_size = [8, 10]
# Ensure all tp ranks have same input.
torch.manual_seed(0)
inp = torch.rand(*inp_size, device=self.device_type)
model = MLPModule(self.device_type)
model_tp = deepcopy(model)
# Ensure model are initialized the same way.
self._check_module(model, model_tp)
# Shard module and initialize optimizer.
parallelize_plan = {
"net1": ColwiseParallel(),
"net2": RowwiseParallel(),
}
model_tp = parallelize_module(model_tp, device_mesh, parallelize_plan)
output = model(inp)
output_tp = model_tp(inp)
self.assertEqual(output, output_tp)
@with_comms
@parametrize("is_seq_parallel", [True, False])
# TODO: need to revisit input_reshard API about why it failed multi-gpu tests.
# @parametrize("recompute_activation", [True, False])
@parametrize("recompute_activation", [False])
def test_mlp_training(self, is_seq_parallel, recompute_activation):
self._test_mlp_training_e2e(
is_seq_parallel=is_seq_parallel, recompute_activation=recompute_activation
)
@with_comms
def test_mlp_inference(self):
device_mesh = DeviceMesh(
self.device_type,
torch.arange(0, NUM_DEVICES),
)
with torch.inference_mode():
self._test_mlp_inference(device_mesh)
def _setup_single_gpu_model(self, model_args, dtype):
return Transformer(model_args).to(device=self.device_type, dtype=dtype)
def _setup_tp_model(self, model, is_seq_parallel, dtype):
model_tp = deepcopy(model)
self._check_module(model, model_tp)
device_mesh = DeviceMesh(self.device_type, torch.arange(0, NUM_DEVICES))
local_output_for_attn = dtype is torch.float64
return Transformer.parallelize(
model_tp,
device_mesh,
is_seq_parallel,
local_output_for_attn=local_output_for_attn,
)
def _setup_optimizer(self, model, model_tp):
# Step 3: Run test by comparing outputs from single-gpu and multi-gpu models.
LR = 0.25
optim = torch.optim.Adam(model.parameters(), lr=LR)
optim_tp = torch.optim.Adam(model_tp.parameters(), lr=LR)
return optim, optim_tp
def _validate_fwd(
self, model, model_tp, inp, expected_comms_dict=None, check_comms=True
):
# Compare outputs on the same input.
output = model(inp)
with CommDebugMode() as comm_mode:
output_tp = model_tp(inp)
self.assertEqual(output, output_tp)
if check_comms:
self.assertDictEqual(comm_mode.get_comm_counts(), expected_comms_dict or {})
return output, output_tp
def _validate_bwd(
self,
model,
model_tp,
output,
output_tp,
expected_comms_dict=None,
check_comms=True,
):
# Ensure gradients are equal.
output.sum().backward()
with CommDebugMode() as comm_mode:
output_tp.sum().backward()
self._check_module(model, model_tp, check_grad=True)
if check_comms:
self.assertDictEqual(comm_mode.get_comm_counts(), expected_comms_dict or {})
def _validate_optim_step(
self,
model,
model_tp,
optim,
optim_tp,
expected_comms_dict=None,
check_comms=True,
):
optim.step() # Ensure model weights are still the same after update.
from torch.distributed.tensor.experimental import implicit_replication
with implicit_replication():
with CommDebugMode() as comm_mode:
optim_tp.step()
self._check_module(model, model_tp)
if check_comms:
self.assertDictEqual(comm_mode.get_comm_counts(), expected_comms_dict or {})
@staticmethod
def _thaw_params(thaw_params, model, model_tp):
if not thaw_params:
return
for target_model in [model, model_tp]:
for n, p in target_model.named_parameters():
if n not in thaw_params:
p.requires_grad_(False)
@with_comms
@skip_unless_torch_gpu
@parametrize("is_seq_parallel", [True, False])
@parametrize("dtype", [torch.float64, torch.float32])
def test_transformer_training(self, is_seq_parallel, dtype: torch.dtype):
EXP_BASE_CC = ExpCommCounts(
fwd={all_reduce: 6, all_gather: 1}, bwd={all_reduce: 9}
)
EXP_SEQ_PARALLEL_CC = ExpCommCounts(
fwd={reduce_scatter: 6, all_gather: 6},
bwd={reduce_scatter: 5, all_gather: 6},
optim={all_reduce: 30},
)
# Disable dropout in the test since we cannot reproduce the same random
# behaviors when comparing single-gpu models with multi-gpu models.
model_args = ModelArgs(dropout_p=0.0)
model = self._setup_single_gpu_model(
model_args, dtype
) # Step 1: Initialize single-gpu models.
model_tp = self._setup_tp_model(
model, is_seq_parallel, dtype
) # Step 2: Setup tp model, place onto device mesh.
optim, optim_tp = self._setup_optimizer(
model, model_tp
) # Step 3: Setup optimizers for both models
# Initialize input and make sure all ranks have the same input.
inp_size = [8, 8] # [batch_size, seq_len]
if is_seq_parallel:
assert inp_size[1] % self.world_size == 0
torch.manual_seed(0)
steps = 10 if type(model) is torch.float64 else 1
for _ in range(steps):
inp = torch.randint(
model_args.vocab_size, inp_size, device=self.device_type
)
expected_fwd_comms = (
EXP_SEQ_PARALLEL_CC.fwd if is_seq_parallel else EXP_BASE_CC.fwd
)
output, output_tp = self._validate_fwd(
model, model_tp, inp, expected_fwd_comms
)
expected_bwd_comms = (
EXP_SEQ_PARALLEL_CC.bwd if is_seq_parallel else EXP_BASE_CC.bwd
)
self._validate_bwd(model, model_tp, output, output_tp, expected_bwd_comms)
expected_optim_comms = (
EXP_SEQ_PARALLEL_CC.optim if is_seq_parallel else EXP_BASE_CC.optim
)
self._validate_optim_step(
model, model_tp, optim, optim_tp, expected_optim_comms
)
@with_comms
@skip_unless_torch_gpu
@parametrize(
"thaw_params, is_seq_parallel, dtype, exp_cnts",
[
(
None, # all require grad seq_parallel float32 baseline
True,
torch.float32,
ExpCommCounts(
bwd={reduce_scatter: 5, all_gather: 6}, optim={all_reduce: 30}
),
),
(
None, # all require grad no seq_parallel float64 baseline
False,
torch.float64,
ExpCommCounts(bwd={all_reduce: 9}),
),
# test a subset of LayerNorm bwd output_masks
(
("output.weight", "norm.weight", "norm.bias"), # [False, True, True]
True,
torch.float32,
ExpCommCounts(bwd={reduce_scatter: 1}, optim={all_reduce: 6}),
),
(
("tok_embeddings.weight", "output.weight"), # [True, False, False]
True,
torch.float32,
ExpCommCounts(bwd={reduce_scatter: 5, all_gather: 5}),
),
(
(
"tok_embeddings.weight",
"output.weight",
"norm.weight",
"norm.bias",
), # [True, True, True]
True,
torch.float32,
ExpCommCounts(
bwd={reduce_scatter: 5, all_gather: 5}, optim={all_reduce: 6}
),
),
(
(
"tok_embeddings.weight",
"output.weight",
"norm.weight",
"norm.bias",
"layers.1.ffn_norm.weight",
"layers.1.ffn_norm.bias",
), # a single transformerblock layernorm
True,
torch.float32,
ExpCommCounts(
bwd={reduce_scatter: 5, all_gather: 5}, optim={all_reduce: 12}
),
),
(
(
"tok_embeddings.weight",
"layers.0.attention.wv.weight",
"layers.0.feed_forward.w1.bias",
"layers.1.ffn_norm.bias",
"layers.1.feed_forward.w2.weight",
"output.weight",
), # varied layer/param types
True,
torch.float32,
ExpCommCounts(
bwd={reduce_scatter: 5, all_gather: 5}, optim={all_reduce: 3}
),
),
],
name_fn=lambda thaw, seq, dtype, *_: f"{'seq_parallel_' if seq else ''}"
+ f"{str(dtype).split('.')[-1]}_"
+ f"thaw_{'__'.join(sorted({n.rpartition('.')[0].replace('.', '_') for n in thaw})) if thaw else 'all'}",
)
def test_transformer_req_grad(self, thaw_params, is_seq_parallel, dtype, exp_cnts):
# Sample a subset of `requires_grad` patterns
# disabling dropout to facilitate single gpu to multi-device comparison
# disable weight-tying to enable more fine-tuning configurations
model_args = ModelArgs(dropout_p=0.0, weight_tying=False)
model = self._setup_single_gpu_model(
model_args, dtype
) # Step 1: Initialize single-gpu models.
model_tp = self._setup_tp_model(
model, is_seq_parallel, dtype
) # Step 2: Setup tp model, place onto device mesh.
optim, optim_tp = self._setup_optimizer(
model, model_tp
) # Step 3: Setup optimizers for both models
DistTensorParallelExampleTest._thaw_params(
thaw_params, model, model_tp
) # Step 4: set `requires_grad` patterns
# Initialize input and make sure all ranks have the same input.
inp_size = [8, 8] # [batch_size, seq_len]
if is_seq_parallel:
assert inp_size[1] % self.world_size == 0
torch.manual_seed(0)
inp = torch.randint(model_args.vocab_size, inp_size, device=self.device_type)
output, output_tp = self._validate_fwd(model, model_tp, inp, check_comms=False)
self._validate_bwd(
model, model_tp, output, output_tp, exp_cnts.bwd, check_comms=True
)
self._validate_optim_step(
model, model_tp, optim, optim_tp, exp_cnts.optim, check_comms=True
)
@with_comms
def test_weight_tying(self):
class TestModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
# Initialize different weights for embedding and fc.
torch.manual_seed(1)
self.embedding = torch.nn.Embedding(16, 8)
torch.manual_seed(2)
self.fc = torch.nn.Linear(8, 16)
def forward(self, x):
return self.fc(self.embedding(x))
model = TestModule().to(self.device_type)
parallelize_plan = {
"embedding": ColwiseParallel(),
"fc": RowwiseParallel(),
}
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
parallelize_module(model, device_mesh, parallelize_plan)
input_size = [5]
torch.manual_seed(0)
inp = torch.randint(16, input_size, device=self.device_type)
# Without weight tying.
self.assertNotEqual(
model.embedding.weight.to_local(), model.fc.weight.to_local()
)
output = model(inp)
output.sum().backward()
self.assertNotEqual(
model.embedding.weight.grad.to_local(), model.fc.weight.grad.to_local()
)
model.zero_grad()
# With weight tying.
model.fc.weight = model.embedding.weight
self.assertEqual(model.embedding.weight, model.fc.weight)
self.assertEqual(id(model.embedding.weight), id(model.fc.weight))
output = model(inp)
output.sum().backward()
self.assertEqual(model.embedding.weight.grad, model.fc.weight.grad)
self.assertEqual(id(model.embedding.weight.grad), id(model.fc.weight.grad))
@with_comms
def test_loss_parallel(self):
device_mesh = self.build_device_mesh()
comm_mode = CommDebugMode()
channel_size, channel_dim = 16, 1
test_setup = [
(2, (8, channel_size), (8,)), # calling aten.nll_loss_forward
(3, (8, channel_size, 12), (8, 12)), # calling aten.nll_loss2d_forward
]
weight = torch.rand(channel_size, device=self.device_type)
for input_ndim, input_size, target_size in test_setup:
x = torch.rand(*input_size, device=self.device_type, requires_grad=True)
target = torch.randint(channel_size, target_size, device=self.device_type)
shard_dims = list(range(input_ndim))
reductions = ["none", "mean", "sum"]
for shard_dim, reduction in itertools.product(shard_dims, reductions):
dist_x = distribute_tensor(x, device_mesh, [Shard(shard_dim)])
y = F.cross_entropy(x, target, weight, reduction=reduction)
with loss_parallel():
if shard_dim == channel_dim:
with comm_mode:
dist_y = F.cross_entropy(
dist_x, target, weight, reduction=reduction
)
self.assertEqual(comm_mode.get_total_counts(), 3)
self.assertEqual(
comm_mode.get_comm_counts()[c10d_functional.all_reduce],
3,
)
self.assertTrue(dist_y.placements[0].is_replicate())
self.assertEqual(dist_y.to_local(), y)
with comm_mode:
if reduction == "none":
y.sum().backward()
dist_y.sum().backward()
else:
y.backward()
dist_y.backward()
self.assertEqual(comm_mode.get_total_counts(), 0)
self.assertTrue(
dist_x.grad.placements[0].is_shard(shard_dim)
)
self.assertEqual(dist_x.grad.full_tensor(), x.grad)
x.grad.zero_()
else:
with self.assertRaisesRegex(
ValueError,
"loss_parallel",
):
dist_y = F.cross_entropy(
dist_x, target, reduction=reduction
)
instantiate_parametrized_tests(DistTensorParallelExampleTest)
if __name__ == "__main__":
run_tests()