Support MoE for pipeline models (#5338)

This PR enhances DeepSpeed to support MoE for pipeline models (e.g.
GPTModelPipe from Megatron-DeepSpeed).
Main changes:

- Enhance expert groups creation for pipeline (enhance both flavors:
DP/PP/EP and DP/TP/PP/EP)
- Fix MoE save/load checkpoint for PipelineModule based models.
- Display MoE loss for PipelineModule based models.
- Support gradients reduce for BF16_Optimizer for
PipelineModule.<br>Note that same commit also fixes gradients reduction
error when using Megatron-DeepSpeed GPTModelPipe with BF16_Optimizer
also for a dense (no MOE) model.
- When using no-drop tokens, all-reduce the capacity (op=max) using
expert parallel group instead of world group

---------

Signed-off-by: Moshe Island <misland@habana.ai>
Co-authored-by: Moshe Island <misland@habana.ai>
This commit is contained in:
Moshe Island
2024-04-08 18:35:53 +03:00
committed by GitHub
parent 42a8eaa705
commit 08e0733e4a
14 changed files with 326 additions and 141 deletions

View File

@ -71,7 +71,7 @@ class MoE(nn.Module):
experts = Experts(expert, self.num_local_experts, self.expert_group_name)
self.deepspeed_moe = MOELayer(TopKGate(hidden_size, num_experts, k, capacity_factor, eval_capacity_factor,
min_capacity, noisy_gate_policy, drop_tokens, use_rts,
min_capacity, noisy_gate_policy, drop_tokens, use_rts, None,
top2_2nd_expert_sampling),
experts,
self.expert_group_name,

View File

@ -23,6 +23,8 @@
import torch
import deepspeed
from deepspeed.utils.bwc import (bwc_tensor_model_parallel_world_size, bwc_tensor_model_parallel_rank,
bwc_tensor_model_parallel_group)
def _gather_tokens(input_, dim=0):
@ -31,11 +33,11 @@ def _gather_tokens(input_, dim=0):
input_ = input_.contiguous()
# Size and dimension.
rank = mpu.get_tensor_model_parallel_rank()
rank = bwc_tensor_model_parallel_rank(mpu)
tensor_list = [torch.empty_like(input_) for _ in range(mpu.get_tensor_model_parallel_world_size())]
tensor_list = [torch.empty_like(input_) for _ in range(bwc_tensor_model_parallel_world_size(mpu))]
tensor_list[rank] = input_
deepspeed.comm.all_gather(tensor_list, input_, group=mpu.get_tensor_model_parallel_group())
deepspeed.comm.all_gather(tensor_list, input_, group=bwc_tensor_model_parallel_group(mpu))
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=dim).contiguous()
@ -47,8 +49,8 @@ def _drop_tokens(input_, dim=0):
"""Divide a tensor among the tensor parallel ranks"""
mpu = deepspeed.utils.groups.mpu
total_chunks = mpu.get_tensor_model_parallel_world_size()
this_chunk = mpu.get_tensor_model_parallel_rank()
total_chunks = bwc_tensor_model_parallel_world_size(mpu)
this_chunk = bwc_tensor_model_parallel_rank(mpu)
assert input_.shape[
dim] % total_chunks == 0, f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
chunk_size = input_.shape[dim] // total_chunks
@ -92,7 +94,7 @@ class _DropTokens(torch.autograd.Function):
def gather_tokens(input_, dim=0):
mpu = deepspeed.utils.groups.mpu
if mpu is None or mpu.get_tensor_model_parallel_world_size() == 1:
if mpu is None or bwc_tensor_model_parallel_world_size(mpu) == 1:
# no tensor parallelism for non-experts
return input_
return _GatherTokens.apply(input_, dim)
@ -100,7 +102,7 @@ def gather_tokens(input_, dim=0):
def drop_tokens(input_, dim=0):
mpu = deepspeed.utils.groups.mpu
if mpu is None or mpu.get_tensor_model_parallel_world_size() == 1:
if mpu is None or bwc_tensor_model_parallel_world_size(mpu) == 1:
# no tensor parallelism for non-experts
return input_
return _DropTokens.apply(input_, dim)

View File

@ -17,7 +17,8 @@ The file has been adapted from two fairscale files:
from deepspeed.utils.timer import SynchronizedWallClockTimer
from deepspeed.utils import logger
from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple
from deepspeed.utils.bwc import bwc_tensor_model_parallel_world_size
from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple, Union
import torch
from torch import Tensor
@ -184,6 +185,7 @@ def top1gating(logits: Tensor,
noisy_gate_policy: Optional[str] = None,
drop_tokens: bool = True,
use_rts: bool = True,
ep_group: Union[torch.distributed.ProcessGroup, None] = None,
use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Implements Top1Gating on logits."""
if noisy_gate_policy == 'RSample':
@ -209,12 +211,13 @@ def top1gating(logits: Tensor,
# if we don't want to drop any tokens
if not drop_tokens:
new_capacity = torch.max(exp_counts).to(logits.device)
# Communicate across all processes to pick the maximum capacity.
dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
# Communicate across expert processes to pick the maximum capacity.
if ep_group is not None:
dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group)
if groups._get_expert_model_parallel_world_size() == 1:
# If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
# This is since we are going to activate drop_tokens() to drop duplicate tokens.
tp = 1 if groups.mpu is None else groups.mpu.get_tensor_model_parallel_world_size()
tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu)
new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
# Make sure the capacity value does not exceed the number of tokens.
capacity = min(new_capacity, torch.tensor(mask1.size(0)))
@ -286,6 +289,7 @@ def top2gating(logits: Tensor,
capacity_factor: float,
min_capacity: int,
drop_tokens: bool = True,
ep_group: Union[torch.distributed.ProcessGroup, None] = None,
top2_2nd_expert_sampling: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Implements Top2Gating on logits."""
# everything is in fp32 in this function
@ -328,11 +332,12 @@ def top2gating(logits: Tensor,
else:
# Do not drop tokens - set capacity according to current expert assignments
new_capacity = torch.max(exp_counts)
dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
if ep_group is not None:
dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group)
if groups._get_expert_model_parallel_world_size() == 1:
# If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
# This is since we are going to activate drop_tokens() to drop duplicate tokens.
tp = 1 if groups.mpu is None else groups.mpu.get_tensor_model_parallel_world_size()
tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu)
new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
capacity = new_capacity
@ -376,7 +381,7 @@ class TopKGate(Module):
Args:
model_dim (int):
size of model embedding dimension
num_experts (ints):
num_experts (int):
number of experts in model
"""
@ -392,6 +397,7 @@ class TopKGate(Module):
noisy_gate_policy: Optional[str] = None,
drop_tokens: bool = True,
use_rts: bool = True,
ep_group: Union[torch.distributed.ProcessGroup, None] = None,
top2_2nd_expert_sampling: bool = True) -> None:
super().__init__()
@ -399,6 +405,7 @@ class TopKGate(Module):
if k != 1 and k != 2:
raise ValueError('Only top-1 and top-2 gatings are supported.')
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
self.ep_group = ep_group
self.k = k
self.capacity_factor = capacity_factor
self.eval_capacity_factor = eval_capacity_factor
@ -411,6 +418,10 @@ class TopKGate(Module):
self.use_rts = use_rts
self.top2_2nd_expert_sampling = top2_2nd_expert_sampling
def _set_ep_group(self, ep_group):
assert self.ep_group is None, f'Attempting to override an existing ep_group'
self.ep_group = ep_group
def forward(self,
input: torch.Tensor,
used_token: torch.Tensor = None,
@ -428,11 +439,11 @@ class TopKGate(Module):
if self.k == 1:
gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
self.min_capacity, used_token, self.noisy_gate_policy if self.training else None,
self.drop_tokens, self.use_rts, use_tutel)
self.drop_tokens, self.use_rts, self.ep_group, use_tutel)
else:
gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
self.min_capacity, self.drop_tokens, self.top2_2nd_expert_sampling)
self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling)
if self.wall_clock_breakdown:
self.timers(TOPK_GATE_TIMER).stop()
@ -492,6 +503,7 @@ class MOELayer(Base):
def _set_ep_group(self, ep_group):
self.ep_group = ep_group
self.gate._set_ep_group(ep_group)
def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:

View File

@ -226,7 +226,7 @@ class DeepSpeedMoEInference(nn.Module):
self.moe_gate = TopKGate(self.config.hidden_size, self.config.global_experts, self.config.k,
self.config.capacity_factor, self.config.eval_capacity_factor,
self.config.min_capacity, self.config.noisy_gate_policy, self.config.drop_tokens,
self.config.use_rts)
self.config.use_rts, self.ep_group)
self.ep_group = ep_group
self.mp_group = mp_group

View File

@ -25,8 +25,9 @@ from torch import _C
from deepspeed.runtime.config import DeepSpeedConfig
from deepspeed.utils import logger
from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage, bwc_tensor_model_parallel_rank
from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage
from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers, FORWARD_GLOBAL_TIMER
from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank
from deepspeed.accelerator import get_accelerator
# DeepSpeed Checkpointing Enabled or Disabled

View File

@ -13,11 +13,11 @@ from deepspeed.runtime.base_optimizer import ZeROOptimizer
from packaging import version as pkg_version
from deepspeed.git_version_info import version
from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim,
align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
is_model_parallel_parameter, see_memory_usage, graph_process,
get_norm_with_moe_layers)
align_dense_tensors, all_gather_dp_groups, is_model_parallel_parameter,
see_memory_usage, graph_process, get_norm_with_moe_layers)
from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, groups
from deepspeed.moe.utils import is_moe_param, is_moe_param_group
from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank
from deepspeed.checkpoint import enable_universal_checkpoint
from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,

View File

@ -3228,9 +3228,12 @@ class DeepSpeedEngine(Module):
# Load flow uses below saved file for model parameters, RNG and more
if groups._get_data_parallel_rank() == 0:
# get non-moe parameters
# Get non-moe parameters
# Classes DeepSpeedEngine and PipelineEngine have different behavior for method module_state_dict.
# DeepSpeedEngine returns the state dict, where PipelineEngine saves the state dict and returns None.
# We need to get the state dict, therefore, call to DeepSpeedEngine (base class for PipelineEngine)
model_state_dict = self._get_non_moe_state_dict(
self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters))
DeepSpeedEngine.module_state_dict(self, exclude_frozen_parameters=exclude_frozen_parameters))
# TODO: update num experts info,.. in checkpoint
state = {

View File

@ -4,6 +4,7 @@
# DeepSpeed Team
from types import MethodType
from collections import OrderedDict
import torch
from deepspeed import comm as dist
@ -194,9 +195,15 @@ class PipelineEngine(DeepSpeedEngine):
#stores the loss for the entire batch
self.total_loss = None
self.total_additional_losses = None
self.agg_loss = torch.tensor(0.0, requires_grad=False).to(self.device)
self.dp_group_loss = torch.tensor(0.0, requires_grad=False).to(self.device)
# stores aggregated-DP train final loss and aggregated-DP additional losses, if any
# additional losses are stored as dict: {loss-name: agg-loss}
self.agg_train_loss = None
self.agg_additional_losses = None
if self._config.pipeline['activation_checkpoint_interval'] > 0:
self.module.activation_checkpoint_interval = self._config.pipeline['activation_checkpoint_interval']
# set use_reentrant default to True.
@ -284,10 +291,7 @@ class PipelineEngine(DeepSpeedEngine):
self._force_grad_boundary = False
def _bf16_reduce_grads(self):
# Make our own list of gradients from the optimizer's FP32 grads
grads = []
self.buffered_allreduce_fallback(grads=self.optimizer.get_grads_for_reduction(),
elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE)
self.buffered_allreduce_fallback(grads=None, elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE)
def _reserve_pipe_buffers(self, num_buffers):
"""Ensure that each pipeline buffer has at least ``num_buffers`` slots.
@ -363,6 +367,7 @@ class PipelineEngine(DeepSpeedEngine):
self.module.train()
self.total_loss = None
self.total_additional_losses = None
self._compute_loss = True
# Do the work
@ -371,7 +376,9 @@ class PipelineEngine(DeepSpeedEngine):
stages=self.num_stages,
stage_id=self.stage_id)
self._exec_schedule(sched)
self.agg_train_loss = self._aggregate_total_loss()
with torch.no_grad():
self.agg_train_loss = self._aggregate_total_loss()
self.timers(TRAIN_BATCH_TIMER).stop()
@ -380,10 +387,12 @@ class PipelineEngine(DeepSpeedEngine):
elapsed = self.timers(TRAIN_BATCH_TIMER).elapsed(reset=True) / 1000.0
iter_time = elapsed / self.steps_per_print()
tput = self.train_batch_size() / iter_time
print(f'steps: {self.global_steps} '
f'loss: {self.agg_train_loss:0.4f} '
f'iter time (s): {iter_time:0.3f} '
f'samples/sec: {tput:0.3f}')
log_str = f'steps: {self.global_steps} loss: {self.agg_train_loss:0.4f} '
if self.agg_additional_losses is not None:
for loss_name, loss_value in self.agg_additional_losses.items():
log_str += f'{loss_name}: {loss_value.item():0.4f} '
log_str += f'iter time (s): {iter_time:0.3f} samples/sec: {tput:0.3f}'
print(log_str)
else:
self.timers(TRAIN_BATCH_TIMER).elapsed(reset=True)
@ -565,29 +574,66 @@ class PipelineEngine(DeepSpeedEngine):
def _aggregate_total_loss(self):
# Scale loss, average among DP ranks, and bcast loss to the rest of my DP group
if self.is_last_stage():
# Scale loss and additional losses, if any
loss = self._scale_loss_by_gas(self.total_loss)
self.dp_group_loss = loss.clone().detach()
self.agg_additional_losses = self.total_additional_losses
if self.agg_additional_losses is not None:
self.agg_additional_losses = OrderedDict({
loss_name: self._scale_loss_by_gas(_loss.clone().detach())
for loss_name, _loss in self.agg_additional_losses.items()
})
## Average loss across all data-parallel groups
self.dp_group_loss = loss.clone().detach()
agg_loss = self.dp_group_loss.clone().detach()
#print(f'RANK={self.global_rank} bcast SENDER src={self.global_rank} group={self.grid.pp_group}', flush=True)
# Average loss across all data-parallel groups
if self.is_data_parallel:
dist.all_reduce(agg_loss, group=self.mpu.get_data_parallel_group())
agg_loss /= self.dp_world_size
if self.agg_additional_losses is None:
dist.all_reduce(agg_loss, group=self.mpu.get_data_parallel_group())
agg_loss /= self.dp_world_size
else:
# use a single reduce op for agg_loss and additional losses, if any
assert '__train_loss__' not in self.agg_additional_losses.keys()
tensors = OrderedDict({'__train_loss__': agg_loss})
tensors.update(self.agg_additional_losses.items())
flat_tensor = torch.cat([t.clone().reshape(-1).detach() for t in tensors.values()])
dist.all_reduce(flat_tensor, group=self.mpu.get_data_parallel_group())
flat_tensor /= self.dp_world_size
offset = 0
reduced_tensor = {}
for name, t in tensors.items():
n_elem = t.numel()
reduced_tensor[name] = flat_tensor[offset:offset + n_elem].clone().detach().reshape(t.shape)
offset += n_elem
agg_loss = reduced_tensor['__train_loss__']
self.agg_additional_losses = OrderedDict(
{name: reduced_tensor[name]
for name in self.agg_additional_losses.keys()})
assert self.global_rank in self.grid.pp_group
losses = torch.stack([self.dp_group_loss, agg_loss]).float()
losses = [self.dp_group_loss, agg_loss]
if self.agg_additional_losses is not None:
losses += list(self.agg_additional_losses.values())
losses = torch.stack(losses).float()
if self.is_pipe_parallel:
dist.broadcast(tensor=losses, src=self.global_rank, group=self.mpu.get_pipe_parallel_group())
else:
# Get loss from last stage
src_rank = self.grid.stage_to_global(self.num_stages - 1)
assert src_rank in self.grid.pp_group
losses = torch.Tensor([0., 0.]).to(self.device)
# losses to reduce are: dp_group_loss, agg_loss, model additional losses
# therefore: 2 + n_additional_losses
additional_losses = self.module.get_additional_losses()
n_additional_losses = 0 if additional_losses is None else len(additional_losses)
losses = torch.Tensor([0.] * (2 + n_additional_losses)).to(self.device)
dist.broadcast(tensor=losses, src=src_rank, group=self.grid.get_pipe_parallel_group())
self.dp_group_loss = losses[0].clone().detach()
agg_loss = losses[1].clone().detach()
if additional_losses is not None:
self.agg_additional_losses = OrderedDict(
{name: losses[2 + i].clone().detach()
for i, name in enumerate(additional_losses.keys())})
return agg_loss
def set_dataloader(self, loader):
@ -715,19 +761,34 @@ class PipelineEngine(DeepSpeedEngine):
self.loss = outputs
if self.eval_return_logits:
self.outputs = outputs
if isinstance(self.loss, torch.Tensor):
self.fwd_outputs.append(self.loss.detach())
if self.total_loss is None:
self.total_loss = torch.zeros_like(self.loss)
self.total_loss += self.loss.detach()
else:
self.fwd_outputs.append([l.detach() for l in self.loss])
if self.total_loss is None:
self.total_loss = [torch.zeros_like(l) for l in self.loss]
for idx, l in enumerate(self.loss):
self.total_loss[idx] += l.detach()
def add_to_total_loss(_total_loss, _loss):
if isinstance(_loss, torch.Tensor):
if _total_loss is None:
_total_loss = torch.zeros_like(_loss)
_total_loss += _loss.detach()
else:
if _total_loss is None:
_total_loss = [torch.zeros_like(_l) for _l in _loss]
for _idx, _l in enumerate(_loss):
_total_loss[_idx] += _l.detach()
return _total_loss
self.total_loss = add_to_total_loss(self.total_loss, self.loss)
# aggregate additional losses across gradient accumulation steps
additional_losses = self.module.get_additional_losses()
if additional_losses is not None:
if self.total_additional_losses is None:
self.total_additional_losses = OrderedDict()
for name, loss in additional_losses.items():
total = self.total_additional_losses[name] if name in self.total_additional_losses else None
self.total_additional_losses[name] = add_to_total_loss(total, loss)
def _exec_backward_pass(self, buffer_id):
assert self.optimizer is not None, "must provide optimizer during " \
@ -1332,7 +1393,7 @@ class PipelineEngine(DeepSpeedEngine):
strict (bool, optional): Strict state loading. Defaults to True.
"""
assert custom_load_fn is None, "custom_load_fn not supported w. pipeline parallelism"
state_dict = checkpoint['module']
state_dict = checkpoint if self.has_moe_layers else checkpoint['module']
if (state_dict is not None) and (not isinstance(state_dict, str)):
super().load_module_state_dict(state_dict, strict)
return
@ -1371,3 +1432,6 @@ class PipelineEngine(DeepSpeedEngine):
# Equivalent to: self._exec_forward_pass(buffer_id=0)
self._exec_instr = MethodType(self._INSTRUCTION_MAP[type(cmd)], self)
self._exec_instr(**cmd.kwargs)
def get_additional_losses(self):
return self.agg_additional_losses

View File

@ -634,3 +634,10 @@ class PipelineModule(nn.Module):
return all(f.__class__.__name__ in self.checkpointable_layers for f in funcs)
params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]
return any(len(list(p)) > 0 for p in params)
def get_additional_losses(self):
""" Returns model specific additional losses for reporting
Return a dictionary of {"loss name": loss_value} or None if no additional losses.
"""
return None

View File

@ -25,6 +25,8 @@ except ModuleNotFoundError:
from torch import inf
from deepspeed.utils import groups, logger
from deepspeed.utils.bwc import (bwc_tensor_model_parallel_rank, bwc_pipeline_parallel_world_size,
bwc_pipeline_parallel_group)
from deepspeed.runtime.constants import PIPE_REPLICATED
from numpy import prod
from deepspeed.accelerator import get_accelerator
@ -117,44 +119,6 @@ def is_model_parallel_parameter(p) -> bool:
return False
def bwc_tensor_model_parallel_rank(mpu=None):
"""Backwards-compatible way of querying the tensor model parallel rank from
an ``mpu`` object.
*Tensor* model parallelism means that tensors are physically split across
processes. This contrasts with *pipeline* model parallelism, in which the
layers are partitioned but tensors left intact.
The API for tensor model parallelism has changed across versions and this
helper provides a best-effort implementation across versions of ``mpu``
objects. The preferred mechanism is
``mpu.get_tensor_model_parallel_rank()``.
This should "just work" with both Megatron-LM and DeepSpeed's pipeline
parallelism.
Args:
mpu (model parallel unit, optional): The tensor model parallel rank.
If ``mpu=None``, returns 0. Defaults to ``None``.
Returns:
int: the rank
"""
if mpu is None:
# No model parallelism in easy :)
return 0
if hasattr(mpu, 'get_tensor_model_parallel_rank'):
# New Megatron and DeepSpeed convention (post pipeline-parallelism release)
return mpu.get_tensor_model_parallel_rank()
elif hasattr(mpu, 'get_slice_parallel_rank'):
# Some DeepSpeed + pipeline parallelism versions
return mpu.get_slice_parallel_rank()
else:
# Deprecated Megatron and DeepSpeed convention
return mpu.get_model_parallel_rank()
def copy_to_device(item, device, criterion_func):
"""
Return a copy of tensor on specified device.
@ -894,8 +858,16 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=F
all_norms.append(t.data.abs().max().float())
total_norm = torch.stack(all_norms).max()
device_total_norm = total_norm.to(get_accelerator().current_device_name())
# Max across model parallel
if mpu is not None:
dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
# For MoE grads, max over model parallel only if MoE-TP is enabled
if moe_ep_group is None or groups._get_expert_model_parallel_world_size() > 1:
dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
# If MoE grads and MoE-TP disabled, max over pipeline parallel
elif bwc_pipeline_parallel_world_size(mpu) > 1:
dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=bwc_pipeline_parallel_group(mpu))
# MoE grads: max across expert parallel group
if moe_ep_group is not None:
dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=moe_ep_group)
total_norm = device_total_norm.to(input_tensors[0].device)
@ -922,8 +894,16 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=F
device_total_norm = compute_buffer[0].float().detach()
# Sum across model parallel
if mpu is not None:
dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
# For MoE grads, sum over model parallel only if MoE-TP is enabled
if moe_ep_group is None or groups._get_expert_model_parallel_world_size() > 1:
dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
# If MoE grads and MoE-TP disabled, sum over pipeline parallel
elif bwc_pipeline_parallel_world_size(mpu) > 1:
dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=bwc_pipeline_parallel_group(mpu))
# MoE grads: sum across expert parallel group
if moe_ep_group is not None:
dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=moe_ep_group)
total_norm = device_total_norm.to(input_tensors[0].device).pow(1. / norm_type)

View File

@ -11,13 +11,13 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from deepspeed.runtime.base_optimizer import ZeROOptimizer
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
from deepspeed.runtime.utils import (bwc_tensor_model_parallel_rank, empty_cache, see_memory_usage, inf,
is_model_parallel_parameter, align_dense_tensors, all_gather_dp_groups)
from deepspeed.runtime.utils import (empty_cache, see_memory_usage, inf, is_model_parallel_parameter,
align_dense_tensors, all_gather_dp_groups)
from deepspeed.runtime.zero.config import ZeroStageEnum
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.utils import logger
from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank
from deepspeed.moe.utils import is_moe_param
from deepspeed.git_version_info import version

104
deepspeed/utils/bwc.py Normal file
View File

@ -0,0 +1,104 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
def bwc_tensor_model_parallel_rank(mpu=None):
"""Backwards-compatible way of querying the tensor model parallel rank from
an ``mpu`` object.
*Tensor* model parallelism means that tensors are physically split across
processes. This contrasts with *pipeline* model parallelism, in which the
layers are partitioned but tensors left intact.
The API for tensor model parallelism has changed across versions and this
helper provides a best-effort implementation across versions of ``mpu``
objects. The preferred mechanism is
``mpu.get_tensor_model_parallel_rank()``.
This should "just work" with both Megatron-LM and DeepSpeed's pipeline
parallelism.
Args:
mpu (model parallel unit, optional): The tensor model parallel rank.
If ``mpu=None``, returns 0. Defaults to ``None``.
Returns:
int: the rank
"""
if mpu is None:
# No model parallelism in easy :)
return 0
if hasattr(mpu, 'get_tensor_model_parallel_rank'):
# New Megatron and DeepSpeed convention (post pipeline-parallelism release)
return mpu.get_tensor_model_parallel_rank()
elif hasattr(mpu, 'get_slice_parallel_rank'):
# Some DeepSpeed + pipeline parallelism versions
return mpu.get_slice_parallel_rank()
else:
# Deprecated Megatron and DeepSpeed convention
return mpu.get_model_parallel_rank()
def bwc_tensor_model_parallel_world_size(mpu=None):
"""Backwards-compatible way of querying the tensor model parallel world size.
Similar to bwc_tensor_model_parallel_rank.
"""
if mpu is None:
return 1
if hasattr(mpu, 'get_tensor_model_parallel_world_size'):
# New Megatron and DeepSpeed convention (post pipeline-parallelism release)
return mpu.get_tensor_model_parallel_world_size()
elif hasattr(mpu, 'get_slice_parallel_world_size'):
# Some DeepSpeed + pipeline parallelism versions
return mpu.get_slice_parallel_world_size()
else:
# Deprecated Megatron and DeepSpeed convention
return mpu.get_model_parallel_world_size()
def bwc_tensor_model_parallel_group(mpu=None):
"""Backwards-compatible way of querying the tensor model parallel group.
Similar to bwc_tensor_model_parallel_rank.
"""
if mpu is None:
return None
if hasattr(mpu, 'get_tensor_model_parallel_group'):
# New Megatron and DeepSpeed convention (post pipeline-parallelism release)
return mpu.get_tensor_model_parallel_group()
elif hasattr(mpu, 'get_slice_parallel_group'):
# Some DeepSpeed + pipeline parallelism versions
return mpu.get_slice_parallel_group()
else:
# Deprecated Megatron and DeepSpeed convention
return mpu.get_model_parallel_group()
def bwc_pipeline_parallel_world_size(mpu=None):
"""Backwards-compatible way of querying the pipeline parallel world size."""
world_size = 1
if mpu is not None:
if hasattr(mpu, 'get_pipeline_model_parallel_world_size'):
# New Megatron and DeepSpeed convention (post pipeline-parallelism release)
world_size = mpu.get_pipeline_model_parallel_world_size()
elif hasattr(mpu, 'get_pipe_parallel_world_size'):
# DeepSpeed Topology
world_size = mpu.get_pipe_parallel_world_size()
return world_size
def bwc_pipeline_parallel_group(mpu=None):
"""Backwards-compatible way of querying the pipeline parallel group."""
if mpu is None:
return None
if hasattr(mpu, 'get_pipeline_model_parallel_group'):
# Megatron
return mpu.get_pipeline_model_parallel_group()
elif hasattr(mpu, 'get_pipe_parallel_group'):
# DeepSpeed Topology
return mpu.get_pipe_parallel_group()
assert False, 'mpu does not support pipeline parallel group'

View File

@ -27,6 +27,7 @@
from deepspeed import comm as dist
from deepspeed.utils import log_dist
from deepspeed.utils.bwc import bwc_tensor_model_parallel_world_size, bwc_pipeline_parallel_world_size
from deepspeed.utils.exceptions import DeprecatedException
from deepspeed.accelerator import get_accelerator
# Expert parallel group that the current rank belongs to.
@ -128,31 +129,32 @@ def _create_expert_and_data_parallel(expert_parallel_size_, use_data_before_expe
log_dist(f'Creating expert and data parallel groups with size {expert_parallel_size_}', ranks=[0])
world_size = dist.get_world_size()
pp_world_size = 1 if mpu is None else bwc_pipeline_parallel_world_size(mpu)
rank = dist.get_rank()
_ensure_divisibility(world_size, expert_parallel_size_)
pp_stride = world_size // pp_world_size
_ensure_divisibility(pp_stride, expert_parallel_size_)
group_name = f"ep_size_{expert_parallel_size_}"
# Build the expert data parallel groups.
global _EXPERT_DATA_PARALLEL_GROUP
ep_stride = world_size // expert_parallel_size_
ep_stride = pp_stride // expert_parallel_size_
# Only create group if it does not already exist
if group_name not in _EXPERT_DATA_PARALLEL_GROUP:
for i in range(expert_parallel_size_):
if use_data_before_expert_parallel_:
ranks = range(i * ep_stride, (i + 1) * ep_stride)
else:
ranks = range(i, world_size, expert_parallel_size_)
group = dist.new_group(ranks)
log_dist(f'Creating expert data parallel process group named {group_name} with ranks: {list(ranks)}', [0])
if use_data_before_expert_parallel_:
if i == (rank // ep_stride):
_EXPERT_DATA_PARALLEL_GROUP[group_name] = group
else:
if i == (rank % expert_parallel_size_):
for pp_stage_start in range(0, world_size, pp_stride):
for i in range(expert_parallel_size_):
if use_data_before_expert_parallel_:
ranks = range(pp_stage_start + i * ep_stride, pp_stage_start + (i + 1) * ep_stride)
else:
ranks = range(pp_stage_start + i, pp_stage_start + pp_stride, expert_parallel_size_)
group = dist.new_group(ranks)
log_dist(
f'Creating expert data parallel process group named {group_name} '
f'with ranks: {list(ranks)}', [0])
if rank in ranks:
_EXPERT_DATA_PARALLEL_GROUP[group_name] = group
# Build the expert parallel groups.
@ -161,24 +163,29 @@ def _create_expert_and_data_parallel(expert_parallel_size_, use_data_before_expe
# Only create group if it does not already exist
if group_name not in _EXPERT_PARALLEL_GROUP:
if use_data_before_expert_parallel_:
for i in range(ep_stride):
ranks = range(i, world_size, ep_stride)
group = dist.new_group(ranks)
log_dist(f'creating expert parallel process group named {group_name} with ranks: {list(ranks)}', [0])
if i == (rank % ep_stride):
_EXPERT_PARALLEL_GROUP[group_name] = group
for pp_stage_start in range(0, world_size, pp_stride):
for i in range(ep_stride):
ranks = range(pp_stage_start + i, pp_stage_start + pp_stride, ep_stride)
group = dist.new_group(ranks)
log_dist(
f'creating expert parallel process group named {group_name} '
f'with ranks: {list(ranks)}', [0])
if rank in ranks:
_EXPERT_PARALLEL_GROUP[group_name] = group
else:
for i in range(world_size // expert_parallel_size_):
ranks = range(i * expert_parallel_size_, (i + 1) * expert_parallel_size_)
group = dist.new_group(ranks)
log_dist(f'creating expert parallel process group named {group_name} with ranks: {list(ranks)}', [0])
if i == (rank // expert_parallel_size_):
log_dist(f'creating expert parallel process group named {group_name} '
f'with ranks: {list(ranks)}', [0])
if rank in ranks:
_EXPERT_PARALLEL_GROUP[group_name] = group
def _get_expert_parallel_ranks(world_size,
model_parallel_size_,
tensor_parallel_size_,
expert_parallel_size_,
pipeline_parallel_size_=1,
use_data_before_expert_parallel_=False):
"""Generate expert parallel and expert data parallel group ranks list.
@ -193,32 +200,40 @@ def _get_expert_parallel_ranks(world_size,
Args:
world_size (int): Distributed world size.
model_parallel_size_ (int): Model parallel group size.
tensor_parallel_size_ (int): Tensor parallel group size.
expert_parallel_size_ (int): Expert parallel group size.
pipeline_parallel_size_ (int): Pipeline parallel group size
use_data_before_expert_parallel_ (bool): Use the D + E instead of E + D topology
Returns:
Expert parallel group ranks and Expert data parallel group ranks list.
"""
_ensure_divisibility(world_size, model_parallel_size_)
dp_world_size = world_size // model_parallel_size_
_ensure_divisibility(world_size, tensor_parallel_size_ * pipeline_parallel_size_)
dp_world_size = world_size // (tensor_parallel_size_ * pipeline_parallel_size_)
_ensure_divisibility(dp_world_size, expert_parallel_size_)
# Generate data parallel groups
data_parallel_groups = []
dp_group_size = model_parallel_size_
dp_group_size = tensor_parallel_size_
pp_stride = world_size // pipeline_parallel_size_
if use_data_before_expert_parallel_:
dp_stride = world_size // expert_parallel_size_ // model_parallel_size_
for i in range(dp_group_size):
data_parallel_groups.append(list())
for ds in range(dp_stride):
# [0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30]
# [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31]
data_parallel_groups[-1].extend(
list(range(i + ds * model_parallel_size_, world_size, dp_stride * model_parallel_size_)))
dp_stride = world_size // expert_parallel_size_ // tensor_parallel_size_ // pipeline_parallel_size_
for pp_stage_start in range(0, world_size, pp_stride):
pp_stage_next = pp_stage_start + pp_stride
for i in range(dp_group_size):
data_parallel_groups.append(list())
for ds in range(dp_stride):
# [0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30]
# [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31]
data_parallel_groups[-1].extend(
list(
range(pp_stage_start + i + ds * tensor_parallel_size_, pp_stage_next,
dp_stride * tensor_parallel_size_)))
else:
for i in range(dp_group_size):
data_parallel_groups.append(list(range(i, world_size, dp_group_size)))
for pp_stage_start in range(0, world_size, pp_stride):
pp_stage_next = pp_stage_start + pp_stride
for i in range(dp_group_size):
data_parallel_groups.append(list(range(pp_stage_start + i, pp_stage_next, dp_group_size)))
expert_parallel_groups = []
expert_data_parallel_groups = []
@ -252,36 +267,33 @@ def _create_expert_data_and_model_parallel(expert_parallel_size_, mpu, use_data_
expert_data_parallel_group = [0,8],[2,10],[4,12],[6,14], [1,9],[3,11],[5,13],[7,15]
"""
assert dist.is_initialized(), "dist is not initialized"
model_parallel_size_ = mpu.get_model_parallel_world_size()
tensor_parallel_size_ = bwc_tensor_model_parallel_world_size(mpu)
global expert_tensor_parallel_world_size
expert_tensor_parallel_world_size = model_parallel_size_
expert_tensor_parallel_world_size = tensor_parallel_size_
world_size = dist.get_world_size()
rank = dist.get_rank()
dp_world_size = mpu.get_data_parallel_world_size()
dp_rank = mpu.get_data_parallel_rank()
pp_world_size = 1 if mpu is None else bwc_pipeline_parallel_world_size(mpu)
_ensure_divisibility(world_size, model_parallel_size_)
_ensure_divisibility(world_size, tensor_parallel_size_)
_ensure_divisibility(dp_world_size, expert_parallel_size_)
log_dist(
f"Creating deepspeed groups with model parallel size {model_parallel_size_}, expert parallel size {expert_parallel_size_}, world size {world_size}, dp world size {dp_world_size}",
[0])
f"Creating deepspeed groups with model parallel size {tensor_parallel_size_}, "
f"pipeline parallel size {pp_world_size}, expert parallel size {expert_parallel_size_}, "
f"world size {world_size}, dp world size {dp_world_size}", [0])
global _EXPERT_PARALLEL_GROUP, _EXPERT_DATA_PARALLEL_GROUP
# Get world size and rank. Ensure some consistencies.
_DATA_PARALLEL_GROUP = mpu.get_data_parallel_group()
_MODEL_PARALLEL_GROUP = mpu.get_model_parallel_group()
group_name = f"ep_size_{expert_parallel_size_}"
# Only create groups if they don't already exist
# Need to check conditions outside the group creation loop because of the way torch.dist group creation works
if group_name not in _EXPERT_DATA_PARALLEL_GROUP and group_name not in _EXPERT_PARALLEL_GROUP:
expert_parallel_groups, expert_data_parallel_groups = _get_expert_parallel_ranks(
world_size, model_parallel_size_, expert_parallel_size_, use_data_before_expert_parallel_)
world_size, tensor_parallel_size_, expert_parallel_size_, pp_world_size, use_data_before_expert_parallel_)
for ranks in expert_parallel_groups:
group = dist.new_group(ranks)
if rank in list(ranks):

View File

@ -18,7 +18,7 @@ def test_get_expert_parallel_ranks():
expert_data_parallel_group = [0,8],[2,10],[4,12],[6,14], [1,9],[3,11],[5,13],[7,15]
"""
expert_parallel_groups, expert_data_parallel_groups = _get_expert_parallel_ranks(world_size=16,
model_parallel_size_=2,
tensor_parallel_size_=2,
expert_parallel_size_=4)
assert expert_parallel_groups == [
[0, 2, 4, 6],