Mirror of https://github.com/deepspeedai/DeepSpeed.git (synced 2025-10-20 15:33:51 +08:00)
Support MoE for pipeline models (#5338)
This PR enhances DeepSpeed to support MoE for pipeline models (e.g. GPTModelPipe from Megatron-DeepSpeed). Main changes:

- Enhance expert group creation for pipeline parallelism (both flavors: DP/PP/EP and DP/TP/PP/EP).
- Fix MoE save/load checkpoint for PipelineModule-based models.
- Display the MoE loss for PipelineModule-based models.
- Support gradient reduction in BF16_Optimizer for PipelineModule. Note that the same commit also fixes a gradient-reduction error when using Megatron-DeepSpeed GPTModelPipe with BF16_Optimizer for a dense (non-MoE) model.
- When using no-drop tokens, all-reduce the capacity (op=max) over the expert parallel group instead of the world group.

---------

Signed-off-by: Moshe Island <misland@habana.ai>
Co-authored-by: Moshe Island <misland@habana.ai>
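For context on the last item: with `drop_tokens=False`, every rank that dispatches tokens to the same set of experts must agree on a single capacity, so the per-rank maximum is synchronized only across the expert-parallel group. A minimal, self-contained sketch of that pattern (illustrative only; the function name and group handling are assumptions, not the PR's code):

```python
import torch
import torch.distributed as dist


def sync_capacity(exp_counts: torch.Tensor, ep_group=None) -> torch.Tensor:
    """Pick the max per-expert token count across the expert-parallel group.

    With no-drop tokens, ranks sharing the same experts must agree on one
    capacity, so take the local max and all-reduce it with op=MAX over
    ep_group (not the world group).
    """
    new_capacity = torch.max(exp_counts)
    if ep_group is not None:
        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group)
    return new_capacity
```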
@@ -71,7 +71,7 @@ class MoE(nn.Module):
        experts = Experts(expert, self.num_local_experts, self.expert_group_name)
        self.deepspeed_moe = MOELayer(TopKGate(hidden_size, num_experts, k, capacity_factor, eval_capacity_factor,
-                                              min_capacity, noisy_gate_policy, drop_tokens, use_rts,
+                                              min_capacity, noisy_gate_policy, drop_tokens, use_rts, None,
                                               top2_2nd_expert_sampling),
                                      experts,
                                      self.expert_group_name,
@@ -23,6 +23,8 @@

import torch
import deepspeed
+from deepspeed.utils.bwc import (bwc_tensor_model_parallel_world_size, bwc_tensor_model_parallel_rank,
+                                 bwc_tensor_model_parallel_group)


def _gather_tokens(input_, dim=0):
@@ -31,11 +33,11 @@ def _gather_tokens(input_, dim=0):

    input_ = input_.contiguous()
    # Size and dimension.
-    rank = mpu.get_tensor_model_parallel_rank()
+    rank = bwc_tensor_model_parallel_rank(mpu)

-    tensor_list = [torch.empty_like(input_) for _ in range(mpu.get_tensor_model_parallel_world_size())]
+    tensor_list = [torch.empty_like(input_) for _ in range(bwc_tensor_model_parallel_world_size(mpu))]
    tensor_list[rank] = input_
-    deepspeed.comm.all_gather(tensor_list, input_, group=mpu.get_tensor_model_parallel_group())
+    deepspeed.comm.all_gather(tensor_list, input_, group=bwc_tensor_model_parallel_group(mpu))

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=dim).contiguous()
@@ -47,8 +49,8 @@ def _drop_tokens(input_, dim=0):
    """Divide a tensor among the tensor parallel ranks"""
    mpu = deepspeed.utils.groups.mpu

-    total_chunks = mpu.get_tensor_model_parallel_world_size()
-    this_chunk = mpu.get_tensor_model_parallel_rank()
+    total_chunks = bwc_tensor_model_parallel_world_size(mpu)
+    this_chunk = bwc_tensor_model_parallel_rank(mpu)
    assert input_.shape[
        dim] % total_chunks == 0, f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
    chunk_size = input_.shape[dim] // total_chunks
@@ -92,7 +94,7 @@ class _DropTokens(torch.autograd.Function):

def gather_tokens(input_, dim=0):
    mpu = deepspeed.utils.groups.mpu
-    if mpu is None or mpu.get_tensor_model_parallel_world_size() == 1:
+    if mpu is None or bwc_tensor_model_parallel_world_size(mpu) == 1:
        # no tensor parallelism for non-experts
        return input_
    return _GatherTokens.apply(input_, dim)
@@ -100,7 +102,7 @@ def gather_tokens(input_, dim=0):

def drop_tokens(input_, dim=0):
    mpu = deepspeed.utils.groups.mpu
-    if mpu is None or mpu.get_tensor_model_parallel_world_size() == 1:
+    if mpu is None or bwc_tensor_model_parallel_world_size(mpu) == 1:
        # no tensor parallelism for non-experts
        return input_
    return _DropTokens.apply(input_, dim)
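As a reference for the `_gather_tokens` path above, here is a minimal torch.distributed sketch of gathering a tensor-parallel-sharded tensor along a dimension (the group handle is passed in explicitly here; the real code obtains it via the bwc helpers):

```python
import torch
import torch.distributed as dist


def gather_along_dim(shard: torch.Tensor, group, dim: int = 0) -> torch.Tensor:
    """All-gather equally sized shards from every rank in `group` and
    concatenate them along `dim`, mirroring the _gather_tokens pattern."""
    shard = shard.contiguous()
    world_size = dist.get_world_size(group=group)
    parts = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(parts, shard, group=group)
    # torch.cat already returns a contiguous tensor
    return torch.cat(parts, dim=dim)
```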
@@ -17,7 +17,8 @@ The file has been adapted from two fairscale files:

from deepspeed.utils.timer import SynchronizedWallClockTimer
from deepspeed.utils import logger
-from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple
+from deepspeed.utils.bwc import bwc_tensor_model_parallel_world_size
+from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple, Union

import torch
from torch import Tensor
@@ -184,6 +185,7 @@ def top1gating(logits: Tensor,
               noisy_gate_policy: Optional[str] = None,
               drop_tokens: bool = True,
               use_rts: bool = True,
+               ep_group: Union[torch.distributed.ProcessGroup, None] = None,
               use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top1Gating on logits."""
    if noisy_gate_policy == 'RSample':
@@ -209,12 +211,13 @@ def top1gating(logits: Tensor,
    # if we don't want to drop any tokens
    if not drop_tokens:
        new_capacity = torch.max(exp_counts).to(logits.device)
-        # Communicate across all processes to pick the maximum capacity.
-        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
+        # Communicate across expert processes to pick the maximum capacity.
+        if ep_group is not None:
+            dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group)
        if groups._get_expert_model_parallel_world_size() == 1:
            # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
            # This is since we are going to activate drop_tokens() to drop duplicate tokens.
-            tp = 1 if groups.mpu is None else groups.mpu.get_tensor_model_parallel_world_size()
+            tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu)
            new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
        # Make sure the capacity value does not exceed the number of tokens.
        capacity = min(new_capacity, torch.tensor(mask1.size(0)))
@@ -286,6 +289,7 @@ def top2gating(logits: Tensor,
               capacity_factor: float,
               min_capacity: int,
               drop_tokens: bool = True,
+               ep_group: Union[torch.distributed.ProcessGroup, None] = None,
               top2_2nd_expert_sampling: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top2Gating on logits."""
    # everything is in fp32 in this function
@@ -328,11 +332,12 @@ def top2gating(logits: Tensor,
    else:
        # Do not drop tokens - set capacity according to current expert assignments
        new_capacity = torch.max(exp_counts)
-        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
+        if ep_group is not None:
+            dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group)
        if groups._get_expert_model_parallel_world_size() == 1:
            # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
            # This is since we are going to activate drop_tokens() to drop duplicate tokens.
-            tp = 1 if groups.mpu is None else groups.mpu.get_tensor_model_parallel_world_size()
+            tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu)
            new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
        capacity = new_capacity
@@ -376,7 +381,7 @@ class TopKGate(Module):
    Args:
        model_dim (int):
            size of model embedding dimension
-        num_experts (ints):
+        num_experts (int):
            number of experts in model
    """
@@ -392,6 +397,7 @@ class TopKGate(Module):
                 noisy_gate_policy: Optional[str] = None,
                 drop_tokens: bool = True,
                 use_rts: bool = True,
+                 ep_group: Union[torch.distributed.ProcessGroup, None] = None,
                 top2_2nd_expert_sampling: bool = True) -> None:
        super().__init__()
@@ -399,6 +405,7 @@ class TopKGate(Module):
        if k != 1 and k != 2:
            raise ValueError('Only top-1 and top-2 gatings are supported.')
        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
+        self.ep_group = ep_group
        self.k = k
        self.capacity_factor = capacity_factor
        self.eval_capacity_factor = eval_capacity_factor
@@ -411,6 +418,10 @@ class TopKGate(Module):
        self.use_rts = use_rts
        self.top2_2nd_expert_sampling = top2_2nd_expert_sampling

+    def _set_ep_group(self, ep_group):
+        assert self.ep_group is None, f'Attempting to override an existing ep_group'
+        self.ep_group = ep_group
+
    def forward(self,
                input: torch.Tensor,
                used_token: torch.Tensor = None,
@@ -428,11 +439,11 @@ class TopKGate(Module):
        if self.k == 1:
            gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
                                     self.min_capacity, used_token, self.noisy_gate_policy if self.training else None,
-                                     self.drop_tokens, self.use_rts, use_tutel)
+                                     self.drop_tokens, self.use_rts, self.ep_group, use_tutel)

        else:
            gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
-                                     self.min_capacity, self.drop_tokens, self.top2_2nd_expert_sampling)
+                                     self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling)

        if self.wall_clock_breakdown:
            self.timers(TOPK_GATE_TIMER).stop()
@@ -492,6 +503,7 @@ class MOELayer(Base):

    def _set_ep_group(self, ep_group):
        self.ep_group = ep_group
+        self.gate._set_ep_group(ep_group)

    def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
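The "pad the capacity to 'tp'" comments above boil down to rounding the capacity up to the next multiple of the tensor-parallel degree, so that drop_tokens() can later split the tokens evenly across TP ranks. A small standalone sketch with illustrative numbers:

```python
import torch


def pad_capacity_to_tp(new_capacity: torch.Tensor, tp: int) -> torch.Tensor:
    """Round capacity up to a multiple of the tensor-parallel world size."""
    return torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)


# e.g. a capacity of 10 with tp=4 becomes 12, so each of the 4 TP ranks can
# later keep an equal 3-token slice after duplicate tokens are dropped.
print(pad_capacity_to_tp(torch.tensor(10), tp=4))  # tensor(12)
```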
@@ -226,7 +226,7 @@ class DeepSpeedMoEInference(nn.Module):
            self.moe_gate = TopKGate(self.config.hidden_size, self.config.global_experts, self.config.k,
                                     self.config.capacity_factor, self.config.eval_capacity_factor,
                                     self.config.min_capacity, self.config.noisy_gate_policy, self.config.drop_tokens,
-                                     self.config.use_rts)
+                                     self.config.use_rts, self.ep_group)

        self.ep_group = ep_group
        self.mp_group = mp_group
@@ -25,8 +25,9 @@ from torch import _C

from deepspeed.runtime.config import DeepSpeedConfig
from deepspeed.utils import logger
-from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage, bwc_tensor_model_parallel_rank
+from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage
from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers, FORWARD_GLOBAL_TIMER
+from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank
from deepspeed.accelerator import get_accelerator

# DeepSpeed Checkpointing Enabled or Disabled
@@ -13,11 +13,11 @@ from deepspeed.runtime.base_optimizer import ZeROOptimizer
from packaging import version as pkg_version
from deepspeed.git_version_info import version
from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim,
-                                     align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
-                                     is_model_parallel_parameter, see_memory_usage, graph_process,
-                                     get_norm_with_moe_layers)
+                                     align_dense_tensors, all_gather_dp_groups, is_model_parallel_parameter,
+                                     see_memory_usage, graph_process, get_norm_with_moe_layers)
from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, groups
from deepspeed.moe.utils import is_moe_param, is_moe_param_group
+from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank
from deepspeed.checkpoint import enable_universal_checkpoint
from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
                                            SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
@@ -3228,9 +3228,12 @@ class DeepSpeedEngine(Module):

        # Load flow uses below saved file for model parameters, RNG and more
        if groups._get_data_parallel_rank() == 0:
-            # get non-moe parameters
+            # Get non-moe parameters
+            # Classes DeepSpeedEngine and PipelineEngine have different behavior for method module_state_dict.
+            # DeepSpeedEngine returns the state dict, where PipelineEngine saves the state dict and returns None.
+            # We need to get the state dict, therefore, call to DeepSpeedEngine (base class for PipelineEngine)
            model_state_dict = self._get_non_moe_state_dict(
-                self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters))
+                DeepSpeedEngine.module_state_dict(self, exclude_frozen_parameters=exclude_frozen_parameters))

            # TODO: update num experts info,.. in checkpoint
            state = {
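The comment block above relies on a plain Python idiom: calling the base-class implementation through the class so that a subclass override (here PipelineEngine's) is bypassed. A generic sketch of that idiom (class names are illustrative, not DeepSpeed's):

```python
class Base:
    def state(self):
        return {"w": 1}


class Derived(Base):
    def state(self):
        # Override that persists elsewhere and returns None, analogous to
        # PipelineEngine's module_state_dict described in the comment above.
        return None


d = Derived()
print(d.state())       # None -> the override runs
print(Base.state(d))   # {'w': 1} -> the base implementation runs explicitly
```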
@@ -4,6 +4,7 @@
# DeepSpeed Team

from types import MethodType
+from collections import OrderedDict

import torch
from deepspeed import comm as dist
@@ -194,9 +195,15 @@ class PipelineEngine(DeepSpeedEngine):

        #stores the loss for the entire batch
        self.total_loss = None
+        self.total_additional_losses = None
        self.agg_loss = torch.tensor(0.0, requires_grad=False).to(self.device)
        self.dp_group_loss = torch.tensor(0.0, requires_grad=False).to(self.device)

+        # stores aggregated-DP train final loss and aggregated-DP additional losses, if any
+        # additional losses are stored as dict: {loss-name: agg-loss}
+        self.agg_train_loss = None
+        self.agg_additional_losses = None
+
        if self._config.pipeline['activation_checkpoint_interval'] > 0:
            self.module.activation_checkpoint_interval = self._config.pipeline['activation_checkpoint_interval']
            # set use_reentrant default to True.
@@ -284,10 +291,7 @@ class PipelineEngine(DeepSpeedEngine):
        self._force_grad_boundary = False

    def _bf16_reduce_grads(self):
-        # Make our own list of gradients from the optimizer's FP32 grads
-        grads = []
-        self.buffered_allreduce_fallback(grads=self.optimizer.get_grads_for_reduction(),
-                                         elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE)
+        self.buffered_allreduce_fallback(grads=None, elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE)

    def _reserve_pipe_buffers(self, num_buffers):
        """Ensure that each pipeline buffer has at least ``num_buffers`` slots.
@@ -363,6 +367,7 @@ class PipelineEngine(DeepSpeedEngine):

        self.module.train()
        self.total_loss = None
+        self.total_additional_losses = None
        self._compute_loss = True

        # Do the work
@@ -371,7 +376,9 @@ class PipelineEngine(DeepSpeedEngine):
                                           stages=self.num_stages,
                                           stage_id=self.stage_id)
        self._exec_schedule(sched)
-        self.agg_train_loss = self._aggregate_total_loss()
+
+        with torch.no_grad():
+            self.agg_train_loss = self._aggregate_total_loss()

        self.timers(TRAIN_BATCH_TIMER).stop()
@@ -380,10 +387,12 @@ class PipelineEngine(DeepSpeedEngine):
                elapsed = self.timers(TRAIN_BATCH_TIMER).elapsed(reset=True) / 1000.0
                iter_time = elapsed / self.steps_per_print()
                tput = self.train_batch_size() / iter_time
-                print(f'steps: {self.global_steps} '
-                      f'loss: {self.agg_train_loss:0.4f} '
-                      f'iter time (s): {iter_time:0.3f} '
-                      f'samples/sec: {tput:0.3f}')
+                log_str = f'steps: {self.global_steps} loss: {self.agg_train_loss:0.4f} '
+                if self.agg_additional_losses is not None:
+                    for loss_name, loss_value in self.agg_additional_losses.items():
+                        log_str += f'{loss_name}: {loss_value.item():0.4f} '
+                log_str += f'iter time (s): {iter_time:0.3f} samples/sec: {tput:0.3f}'
+                print(log_str)
            else:
                self.timers(TRAIN_BATCH_TIMER).elapsed(reset=True)
@@ -565,29 +574,66 @@ class PipelineEngine(DeepSpeedEngine):
    def _aggregate_total_loss(self):
        # Scale loss, average among DP ranks, and bcast loss to the rest of my DP group
        if self.is_last_stage():
+            # Scale loss and additional losses, if any
            loss = self._scale_loss_by_gas(self.total_loss)
-            self.dp_group_loss = loss.clone().detach()
+            self.agg_additional_losses = self.total_additional_losses
+            if self.agg_additional_losses is not None:
+                self.agg_additional_losses = OrderedDict({
+                    loss_name: self._scale_loss_by_gas(_loss.clone().detach())
+                    for loss_name, _loss in self.agg_additional_losses.items()
+                })

-            ## Average loss across all data-parallel groups
+            self.dp_group_loss = loss.clone().detach()
            agg_loss = self.dp_group_loss.clone().detach()
            #print(f'RANK={self.global_rank} bcast SENDER src={self.global_rank} group={self.grid.pp_group}', flush=True)
+
+            # Average loss across all data-parallel groups
            if self.is_data_parallel:
-                dist.all_reduce(agg_loss, group=self.mpu.get_data_parallel_group())
-                agg_loss /= self.dp_world_size
+                if self.agg_additional_losses is None:
+                    dist.all_reduce(agg_loss, group=self.mpu.get_data_parallel_group())
+                    agg_loss /= self.dp_world_size
+                else:
+                    # use a single reduce op for agg_loss and additional losses, if any
+                    assert '__train_loss__' not in self.agg_additional_losses.keys()
+                    tensors = OrderedDict({'__train_loss__': agg_loss})
+                    tensors.update(self.agg_additional_losses.items())
+                    flat_tensor = torch.cat([t.clone().reshape(-1).detach() for t in tensors.values()])
+                    dist.all_reduce(flat_tensor, group=self.mpu.get_data_parallel_group())
+                    flat_tensor /= self.dp_world_size
+                    offset = 0
+                    reduced_tensor = {}
+                    for name, t in tensors.items():
+                        n_elem = t.numel()
+                        reduced_tensor[name] = flat_tensor[offset:offset + n_elem].clone().detach().reshape(t.shape)
+                        offset += n_elem
+                    agg_loss = reduced_tensor['__train_loss__']
+                    self.agg_additional_losses = OrderedDict(
+                        {name: reduced_tensor[name]
+                         for name in self.agg_additional_losses.keys()})

            assert self.global_rank in self.grid.pp_group
-            losses = torch.stack([self.dp_group_loss, agg_loss]).float()
+            losses = [self.dp_group_loss, agg_loss]
+            if self.agg_additional_losses is not None:
+                losses += list(self.agg_additional_losses.values())
+            losses = torch.stack(losses).float()
            if self.is_pipe_parallel:
                dist.broadcast(tensor=losses, src=self.global_rank, group=self.mpu.get_pipe_parallel_group())
        else:
            # Get loss from last stage
            src_rank = self.grid.stage_to_global(self.num_stages - 1)
            assert src_rank in self.grid.pp_group
-            losses = torch.Tensor([0., 0.]).to(self.device)
+            # losses to reduce are: dp_group_loss, agg_loss, model additional losses
+            # therefore: 2 + n_additional_losses
+            additional_losses = self.module.get_additional_losses()
+            n_additional_losses = 0 if additional_losses is None else len(additional_losses)
+            losses = torch.Tensor([0.] * (2 + n_additional_losses)).to(self.device)
            dist.broadcast(tensor=losses, src=src_rank, group=self.grid.get_pipe_parallel_group())
            self.dp_group_loss = losses[0].clone().detach()
            agg_loss = losses[1].clone().detach()
+            if additional_losses is not None:
+                self.agg_additional_losses = OrderedDict(
+                    {name: losses[2 + i].clone().detach()
+                     for i, name in enumerate(additional_losses.keys())})
        return agg_loss

    def set_dataloader(self, loader):
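The single-reduce branch above follows a common pattern: concatenate several scalar losses into one flat tensor, issue one all_reduce, then slice the averaged values back out. A standalone sketch of that pattern (the function name and arguments are assumptions, not the engine's API):

```python
from collections import OrderedDict

import torch
import torch.distributed as dist


def reduce_losses_once(losses: "OrderedDict[str, torch.Tensor]", group, world_size: int):
    """Average several scalar losses across a process group with one all_reduce."""
    flat = torch.cat([t.reshape(-1).detach().clone() for t in losses.values()])
    dist.all_reduce(flat, group=group)
    flat /= world_size
    out, offset = OrderedDict(), 0
    for name, t in losses.items():
        out[name] = flat[offset:offset + t.numel()].reshape(t.shape).clone()
        offset += t.numel()
    return out
```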
@@ -715,19 +761,34 @@ class PipelineEngine(DeepSpeedEngine):
            self.loss = outputs
            if self.eval_return_logits:
                self.outputs = outputs

            if isinstance(self.loss, torch.Tensor):
                self.fwd_outputs.append(self.loss.detach())
-
-                if self.total_loss is None:
-                    self.total_loss = torch.zeros_like(self.loss)
-                self.total_loss += self.loss.detach()
            else:
                self.fwd_outputs.append([l.detach() for l in self.loss])
-
-                if self.total_loss is None:
-                    self.total_loss = [torch.zeros_like(l) for l in self.loss]
-                for idx, l in enumerate(self.loss):
-                    self.total_loss[idx] += l.detach()
+
+            def add_to_total_loss(_total_loss, _loss):
+                if isinstance(_loss, torch.Tensor):
+                    if _total_loss is None:
+                        _total_loss = torch.zeros_like(_loss)
+                    _total_loss += _loss.detach()
+                else:
+                    if _total_loss is None:
+                        _total_loss = [torch.zeros_like(_l) for _l in _loss]
+                    for _idx, _l in enumerate(_loss):
+                        _total_loss[_idx] += _l.detach()
+                return _total_loss
+
+            self.total_loss = add_to_total_loss(self.total_loss, self.loss)
+
+            # aggregate additional losses across gradient accumulation steps
+            additional_losses = self.module.get_additional_losses()
+            if additional_losses is not None:
+                if self.total_additional_losses is None:
+                    self.total_additional_losses = OrderedDict()
+                for name, loss in additional_losses.items():
+                    total = self.total_additional_losses[name] if name in self.total_additional_losses else None
+                    self.total_additional_losses[name] = add_to_total_loss(total, loss)

    def _exec_backward_pass(self, buffer_id):
        assert self.optimizer is not None, "must provide optimizer during " \
@@ -1332,7 +1393,7 @@ class PipelineEngine(DeepSpeedEngine):
            strict (bool, optional): Strict state loading. Defaults to True.
        """
        assert custom_load_fn is None, "custom_load_fn not supported w. pipeline parallelism"
-        state_dict = checkpoint['module']
+        state_dict = checkpoint if self.has_moe_layers else checkpoint['module']
        if (state_dict is not None) and (not isinstance(state_dict, str)):
            super().load_module_state_dict(state_dict, strict)
            return
@@ -1371,3 +1432,6 @@ class PipelineEngine(DeepSpeedEngine):
        # Equivalent to: self._exec_forward_pass(buffer_id=0)
        self._exec_instr = MethodType(self._INSTRUCTION_MAP[type(cmd)], self)
        self._exec_instr(**cmd.kwargs)
+
+    def get_additional_losses(self):
+        return self.agg_additional_losses
@@ -634,3 +634,10 @@ class PipelineModule(nn.Module):
            return all(f.__class__.__name__ in self.checkpointable_layers for f in funcs)

        params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]
        return any(len(list(p)) > 0 for p in params)
+
+    def get_additional_losses(self):
+        """ Returns model specific additional losses for reporting
+
+        Return a dictionary of {"loss name": loss_value} or None if no additional losses.
+        """
+        return None
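A pipeline model can surface its MoE auxiliary loss through this hook so the engine can aggregate and print it next to the training loss. A hedged sketch of what a subclass might return (the subclass and its attribute are hypothetical, not part of the PR):

```python
from collections import OrderedDict

from deepspeed.pipe import PipelineModule


class MyMoEPipeModule(PipelineModule):  # hypothetical subclass
    def get_additional_losses(self):
        # Report the load-balancing loss accumulated by the MoE layers during
        # the forward pass; return None when there is nothing to report.
        aux_loss = getattr(self, "_moe_aux_loss", None)
        if aux_loss is None:
            return None
        return OrderedDict({"moe loss": aux_loss})
```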
@@ -25,6 +25,8 @@ except ModuleNotFoundError:
    from torch import inf

from deepspeed.utils import groups, logger
+from deepspeed.utils.bwc import (bwc_tensor_model_parallel_rank, bwc_pipeline_parallel_world_size,
+                                 bwc_pipeline_parallel_group)
from deepspeed.runtime.constants import PIPE_REPLICATED
from numpy import prod
from deepspeed.accelerator import get_accelerator
@@ -117,44 +119,6 @@ def is_model_parallel_parameter(p) -> bool:
    return False


-def bwc_tensor_model_parallel_rank(mpu=None):
-    """Backwards-compatible way of querying the tensor model parallel rank from
-    an ``mpu`` object.
-
-    *Tensor* model parallelism means that tensors are physically split across
-    processes. This contrasts with *pipeline* model parallelism, in which the
-    layers are partitioned but tensors left intact.
-
-    The API for tensor model parallelism has changed across versions and this
-    helper provides a best-effort implementation across versions of ``mpu``
-    objects. The preferred mechanism is
-    ``mpu.get_tensor_model_parallel_rank()``.
-
-    This should "just work" with both Megatron-LM and DeepSpeed's pipeline
-    parallelism.
-
-    Args:
-        mpu (model parallel unit, optional): The tensor model parallel rank.
-            If ``mpu=None``, returns 0. Defaults to ``None``.
-
-    Returns:
-        int: the rank
-    """
-    if mpu is None:
-        # No model parallelism in easy :)
-        return 0
-
-    if hasattr(mpu, 'get_tensor_model_parallel_rank'):
-        # New Megatron and DeepSpeed convention (post pipeline-parallelism release)
-        return mpu.get_tensor_model_parallel_rank()
-    elif hasattr(mpu, 'get_slice_parallel_rank'):
-        # Some DeepSpeed + pipeline parallelism versions
-        return mpu.get_slice_parallel_rank()
-    else:
-        # Deprecated Megatron and DeepSpeed convention
-        return mpu.get_model_parallel_rank()
-
-
def copy_to_device(item, device, criterion_func):
    """
    Return a copy of tensor on specified device.
@@ -894,8 +858,16 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=F
            all_norms.append(t.data.abs().max().float())
        total_norm = torch.stack(all_norms).max()
        device_total_norm = total_norm.to(get_accelerator().current_device_name())
-        # Max across model parallel
        if mpu is not None:
-            dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
+            # For MoE grads, max over model parallel only if MoE-TP is enabled
+            if moe_ep_group is None or groups._get_expert_model_parallel_world_size() > 1:
+                dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
+            # If MoE grads and MoE-TP disabled, max over pipeline parallel
+            elif bwc_pipeline_parallel_world_size(mpu) > 1:
+                dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=bwc_pipeline_parallel_group(mpu))
+
+        # MoE grads: max across expert parallel group
+        if moe_ep_group is not None:
+            dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=moe_ep_group)
        total_norm = device_total_norm.to(input_tensors[0].device)
@@ -922,8 +894,16 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=F

            device_total_norm = compute_buffer[0].float().detach()

-        # Sum across model parallel
        if mpu is not None:
-            dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
+            # For MoE grads, sum over model parallel only if MoE-TP is enabled
+            if moe_ep_group is None or groups._get_expert_model_parallel_world_size() > 1:
+                dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
+            # If MoE grads and MoE-TP disabled, sum over pipeline parallel
+            elif bwc_pipeline_parallel_world_size(mpu) > 1:
+                dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=bwc_pipeline_parallel_group(mpu))
+
+        # MoE grads: sum across expert parallel group
+        if moe_ep_group is not None:
+            dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=moe_ep_group)
        total_norm = device_total_norm.to(input_tensors[0].device).pow(1. / norm_type)
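For the L2 (norm_type=2) path above, the overall recipe is: accumulate the local sum of squares, all-reduce that sum over each group that partitions the parameters, then take the root. A minimal sketch of that recipe (the helper and the list of groups passed in are assumptions):

```python
import torch
import torch.distributed as dist


def global_l2_norm(shard_tensors, reduce_groups=()):
    """Global L2 norm of parameters whose shards live on different ranks:
    sum local squared norms, all-reduce the sum over every group that
    partitions the parameters, then take the square root."""
    local_sq = torch.zeros(1, device=shard_tensors[0].device)
    for t in shard_tensors:
        local_sq += t.detach().float().norm(2)**2
    for group in reduce_groups:
        if group is not None:
            dist.all_reduce(local_sq, op=dist.ReduceOp.SUM, group=group)
    return local_sq.sqrt()
```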
@@ -11,13 +11,13 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from deepspeed.runtime.base_optimizer import ZeROOptimizer
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
-from deepspeed.runtime.utils import (bwc_tensor_model_parallel_rank, empty_cache, see_memory_usage, inf,
-                                     is_model_parallel_parameter, align_dense_tensors, all_gather_dp_groups)
-
+from deepspeed.runtime.utils import (empty_cache, see_memory_usage, inf, is_model_parallel_parameter,
+                                     align_dense_tensors, all_gather_dp_groups)
from deepspeed.runtime.zero.config import ZeroStageEnum
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.utils import logger
+from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank
from deepspeed.moe.utils import is_moe_param
from deepspeed.git_version_info import version
deepspeed/utils/bwc.py (new file, 104 lines)
@@ -0,0 +1,104 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+
+def bwc_tensor_model_parallel_rank(mpu=None):
+    """Backwards-compatible way of querying the tensor model parallel rank from
+    an ``mpu`` object.
+
+    *Tensor* model parallelism means that tensors are physically split across
+    processes. This contrasts with *pipeline* model parallelism, in which the
+    layers are partitioned but tensors left intact.
+
+    The API for tensor model parallelism has changed across versions and this
+    helper provides a best-effort implementation across versions of ``mpu``
+    objects. The preferred mechanism is
+    ``mpu.get_tensor_model_parallel_rank()``.
+
+    This should "just work" with both Megatron-LM and DeepSpeed's pipeline
+    parallelism.
+
+    Args:
+        mpu (model parallel unit, optional): The tensor model parallel rank.
+            If ``mpu=None``, returns 0. Defaults to ``None``.
+
+    Returns:
+        int: the rank
+    """
+    if mpu is None:
+        # No model parallelism in easy :)
+        return 0
+
+    if hasattr(mpu, 'get_tensor_model_parallel_rank'):
+        # New Megatron and DeepSpeed convention (post pipeline-parallelism release)
+        return mpu.get_tensor_model_parallel_rank()
+    elif hasattr(mpu, 'get_slice_parallel_rank'):
+        # Some DeepSpeed + pipeline parallelism versions
+        return mpu.get_slice_parallel_rank()
+    else:
+        # Deprecated Megatron and DeepSpeed convention
+        return mpu.get_model_parallel_rank()
+
+
+def bwc_tensor_model_parallel_world_size(mpu=None):
+    """Backwards-compatible way of querying the tensor model parallel world size.
+       Similar to bwc_tensor_model_parallel_rank.
+    """
+    if mpu is None:
+        return 1
+
+    if hasattr(mpu, 'get_tensor_model_parallel_world_size'):
+        # New Megatron and DeepSpeed convention (post pipeline-parallelism release)
+        return mpu.get_tensor_model_parallel_world_size()
+    elif hasattr(mpu, 'get_slice_parallel_world_size'):
+        # Some DeepSpeed + pipeline parallelism versions
+        return mpu.get_slice_parallel_world_size()
+    else:
+        # Deprecated Megatron and DeepSpeed convention
+        return mpu.get_model_parallel_world_size()
+
+
+def bwc_tensor_model_parallel_group(mpu=None):
+    """Backwards-compatible way of querying the tensor model parallel group.
+       Similar to bwc_tensor_model_parallel_rank.
+    """
+    if mpu is None:
+        return None
+
+    if hasattr(mpu, 'get_tensor_model_parallel_group'):
+        # New Megatron and DeepSpeed convention (post pipeline-parallelism release)
+        return mpu.get_tensor_model_parallel_group()
+    elif hasattr(mpu, 'get_slice_parallel_group'):
+        # Some DeepSpeed + pipeline parallelism versions
+        return mpu.get_slice_parallel_group()
+    else:
+        # Deprecated Megatron and DeepSpeed convention
+        return mpu.get_model_parallel_group()
+
+
+def bwc_pipeline_parallel_world_size(mpu=None):
+    """Backwards-compatible way of querying the pipeline parallel world size."""
+    world_size = 1
+    if mpu is not None:
+        if hasattr(mpu, 'get_pipeline_model_parallel_world_size'):
+            # New Megatron and DeepSpeed convention (post pipeline-parallelism release)
+            world_size = mpu.get_pipeline_model_parallel_world_size()
+        elif hasattr(mpu, 'get_pipe_parallel_world_size'):
+            # DeepSpeed Topology
+            world_size = mpu.get_pipe_parallel_world_size()
+    return world_size
+
+
+def bwc_pipeline_parallel_group(mpu=None):
+    """Backwards-compatible way of querying the pipeline parallel group."""
+    if mpu is None:
+        return None
+    if hasattr(mpu, 'get_pipeline_model_parallel_group'):
+        # Megatron
+        return mpu.get_pipeline_model_parallel_group()
+    elif hasattr(mpu, 'get_pipe_parallel_group'):
+        # DeepSpeed Topology
+        return mpu.get_pipe_parallel_group()
+    assert False, 'mpu does not support pipeline parallel group'
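To illustrate how these helpers stay backwards compatible, here is a small usage sketch probing an mpu-like object that only implements the older Megatron naming (the stub class is hypothetical):

```python
from deepspeed.utils.bwc import (bwc_tensor_model_parallel_rank, bwc_tensor_model_parallel_world_size,
                                 bwc_pipeline_parallel_world_size)


class OldStyleMPU:
    """Stub exposing only the deprecated get_model_parallel_* API."""

    def get_model_parallel_rank(self):
        return 0

    def get_model_parallel_world_size(self):
        return 2


mpu = OldStyleMPU()
print(bwc_tensor_model_parallel_rank(mpu))        # 0, falls back to get_model_parallel_rank()
print(bwc_tensor_model_parallel_world_size(mpu))  # 2, falls back to get_model_parallel_world_size()
print(bwc_pipeline_parallel_world_size(mpu))      # 1, no pipeline API exposed
print(bwc_tensor_model_parallel_rank(None))       # 0, no model parallelism at all
```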
@@ -27,6 +27,7 @@

from deepspeed import comm as dist
from deepspeed.utils import log_dist
+from deepspeed.utils.bwc import bwc_tensor_model_parallel_world_size, bwc_pipeline_parallel_world_size
from deepspeed.utils.exceptions import DeprecatedException
from deepspeed.accelerator import get_accelerator
# Expert parallel group that the current rank belongs to.
@@ -128,31 +129,32 @@ def _create_expert_and_data_parallel(expert_parallel_size_, use_data_before_expe

    log_dist(f'Creating expert and data parallel groups with size {expert_parallel_size_}', ranks=[0])
    world_size = dist.get_world_size()
+    pp_world_size = 1 if mpu is None else bwc_pipeline_parallel_world_size(mpu)
    rank = dist.get_rank()

-    _ensure_divisibility(world_size, expert_parallel_size_)
+    pp_stride = world_size // pp_world_size
+    _ensure_divisibility(pp_stride, expert_parallel_size_)

    group_name = f"ep_size_{expert_parallel_size_}"

    # Build the expert data parallel groups.
    global _EXPERT_DATA_PARALLEL_GROUP

-    ep_stride = world_size // expert_parallel_size_
+    ep_stride = pp_stride // expert_parallel_size_

    # Only create group if it does not already exist
    if group_name not in _EXPERT_DATA_PARALLEL_GROUP:
-        for i in range(expert_parallel_size_):
-            if use_data_before_expert_parallel_:
-                ranks = range(i * ep_stride, (i + 1) * ep_stride)
-            else:
-                ranks = range(i, world_size, expert_parallel_size_)
-            group = dist.new_group(ranks)
-            log_dist(f'Creating expert data parallel process group named {group_name} with ranks: {list(ranks)}', [0])
-            if use_data_before_expert_parallel_:
-                if i == (rank // ep_stride):
-                    _EXPERT_DATA_PARALLEL_GROUP[group_name] = group
-            else:
-                if i == (rank % expert_parallel_size_):
+        for pp_stage_start in range(0, world_size, pp_stride):
+            for i in range(expert_parallel_size_):
+                if use_data_before_expert_parallel_:
+                    ranks = range(pp_stage_start + i * ep_stride, pp_stage_start + (i + 1) * ep_stride)
+                else:
+                    ranks = range(pp_stage_start + i, pp_stage_start + pp_stride, expert_parallel_size_)
+                group = dist.new_group(ranks)
+                log_dist(
+                    f'Creating expert data parallel process group named {group_name} '
+                    f'with ranks: {list(ranks)}', [0])
+                if rank in ranks:
                    _EXPERT_DATA_PARALLEL_GROUP[group_name] = group

    # Build the expert parallel groups.
@@ -161,24 +163,29 @@ def _create_expert_and_data_parallel(expert_parallel_size_, use_data_before_expe
    # Only create group if it does not already exist
    if group_name not in _EXPERT_PARALLEL_GROUP:
        if use_data_before_expert_parallel_:
-            for i in range(ep_stride):
-                ranks = range(i, world_size, ep_stride)
-                group = dist.new_group(ranks)
-                log_dist(f'creating expert parallel process group named {group_name} with ranks: {list(ranks)}', [0])
-                if i == (rank % ep_stride):
-                    _EXPERT_PARALLEL_GROUP[group_name] = group
+            for pp_stage_start in range(0, world_size, pp_stride):
+                for i in range(ep_stride):
+                    ranks = range(pp_stage_start + i, pp_stage_start + pp_stride, ep_stride)
+                    group = dist.new_group(ranks)
+                    log_dist(
+                        f'creating expert parallel process group named {group_name} '
+                        f'with ranks: {list(ranks)}', [0])
+                    if rank in ranks:
+                        _EXPERT_PARALLEL_GROUP[group_name] = group
        else:
            for i in range(world_size // expert_parallel_size_):
                ranks = range(i * expert_parallel_size_, (i + 1) * expert_parallel_size_)
                group = dist.new_group(ranks)
-                log_dist(f'creating expert parallel process group named {group_name} with ranks: {list(ranks)}', [0])
-                if i == (rank // expert_parallel_size_):
+                log_dist(f'creating expert parallel process group named {group_name} '
+                         f'with ranks: {list(ranks)}', [0])
+                if rank in ranks:
                    _EXPERT_PARALLEL_GROUP[group_name] = group


def _get_expert_parallel_ranks(world_size,
-                               model_parallel_size_,
+                               tensor_parallel_size_,
                               expert_parallel_size_,
+                               pipeline_parallel_size_=1,
                               use_data_before_expert_parallel_=False):
    """Generate expert parallel and expert data parallel group ranks list.

@@ -193,32 +200,40 @@ def _get_expert_parallel_ranks(world_size,

    Args:
        world_size (int): Distributed world size.
-        model_parallel_size_ (int): Model parallel group size.
+        tensor_parallel_size_ (int): Tensor parallel group size.
        expert_parallel_size_ (int): Expert parallel group size.
+        pipeline_parallel_size_ (int): Pipeline parallel group size
        use_data_before_expert_parallel_ (bool): Use the D + E instead of E + D topology
    Returns:
        Expert parallel group ranks and Expert data parallel group ranks list.
    """
-    _ensure_divisibility(world_size, model_parallel_size_)
-    dp_world_size = world_size // model_parallel_size_
+    _ensure_divisibility(world_size, tensor_parallel_size_ * pipeline_parallel_size_)
+    dp_world_size = world_size // (tensor_parallel_size_ * pipeline_parallel_size_)
    _ensure_divisibility(dp_world_size, expert_parallel_size_)

    # Generate data parallel groups
    data_parallel_groups = []
-    dp_group_size = model_parallel_size_
+    dp_group_size = tensor_parallel_size_
+    pp_stride = world_size // pipeline_parallel_size_

    if use_data_before_expert_parallel_:
-        dp_stride = world_size // expert_parallel_size_ // model_parallel_size_
-        for i in range(dp_group_size):
-            data_parallel_groups.append(list())
-            for ds in range(dp_stride):
-                # [0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30]
-                # [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31]
-                data_parallel_groups[-1].extend(
-                    list(range(i + ds * model_parallel_size_, world_size, dp_stride * model_parallel_size_)))
+        dp_stride = world_size // expert_parallel_size_ // tensor_parallel_size_ // pipeline_parallel_size_
+        for pp_stage_start in range(0, world_size, pp_stride):
+            pp_stage_next = pp_stage_start + pp_stride
+            for i in range(dp_group_size):
+                data_parallel_groups.append(list())
+                for ds in range(dp_stride):
+                    # [0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30]
+                    # [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31]
+                    data_parallel_groups[-1].extend(
+                        list(
+                            range(pp_stage_start + i + ds * tensor_parallel_size_, pp_stage_next,
                                  dp_stride * tensor_parallel_size_)))
    else:
-        for i in range(dp_group_size):
-            data_parallel_groups.append(list(range(i, world_size, dp_group_size)))
+        for pp_stage_start in range(0, world_size, pp_stride):
+            pp_stage_next = pp_stage_start + pp_stride
+            for i in range(dp_group_size):
+                data_parallel_groups.append(list(range(pp_stage_start + i, pp_stage_next, dp_group_size)))

    expert_parallel_groups = []
    expert_data_parallel_groups = []
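To see what the per-pipeline-stage loop above produces, here is a small standalone sketch of the non-D+E branch for an assumed world size of 16 with tensor parallel 2 and pipeline parallel 2 (the numbers are illustrative):

```python
world_size, tensor_parallel_size_, pipeline_parallel_size_ = 16, 2, 2
dp_group_size = tensor_parallel_size_
pp_stride = world_size // pipeline_parallel_size_  # 8 ranks per pipeline stage

data_parallel_groups = []
for pp_stage_start in range(0, world_size, pp_stride):
    pp_stage_next = pp_stage_start + pp_stride
    for i in range(dp_group_size):
        data_parallel_groups.append(list(range(pp_stage_start + i, pp_stage_next, dp_group_size)))

print(data_parallel_groups)
# [[0, 2, 4, 6], [1, 3, 5, 7], [8, 10, 12, 14], [9, 11, 13, 15]]
# i.e. data-parallel peers never cross a pipeline-stage boundary.
```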
@@ -252,36 +267,33 @@ def _create_expert_data_and_model_parallel(expert_parallel_size_, mpu, use_data_
        expert_data_parallel_group = [0,8],[2,10],[4,12],[6,14], [1,9],[3,11],[5,13],[7,15]
    """
    assert dist.is_initialized(), "dist is not initialized"
-    model_parallel_size_ = mpu.get_model_parallel_world_size()
+    tensor_parallel_size_ = bwc_tensor_model_parallel_world_size(mpu)

    global expert_tensor_parallel_world_size
-    expert_tensor_parallel_world_size = model_parallel_size_
+    expert_tensor_parallel_world_size = tensor_parallel_size_

    world_size = dist.get_world_size()
    rank = dist.get_rank()
    dp_world_size = mpu.get_data_parallel_world_size()
-    dp_rank = mpu.get_data_parallel_rank()
+    pp_world_size = 1 if mpu is None else bwc_pipeline_parallel_world_size(mpu)

-    _ensure_divisibility(world_size, model_parallel_size_)
+    _ensure_divisibility(world_size, tensor_parallel_size_)
    _ensure_divisibility(dp_world_size, expert_parallel_size_)

    log_dist(
-        f"Creating deepspeed groups with model parallel size {model_parallel_size_}, expert parallel size {expert_parallel_size_}, world size {world_size}, dp world size {dp_world_size}",
-        [0])
+        f"Creating deepspeed groups with model parallel size {tensor_parallel_size_}, "
+        f"pipeline parallel size {pp_world_size}, expert parallel size {expert_parallel_size_}, "
+        f"world size {world_size}, dp world size {dp_world_size}", [0])

    global _EXPERT_PARALLEL_GROUP, _EXPERT_DATA_PARALLEL_GROUP

-    # Get world size and rank. Ensure some consistencies.
-    _DATA_PARALLEL_GROUP = mpu.get_data_parallel_group()
-    _MODEL_PARALLEL_GROUP = mpu.get_model_parallel_group()
-
    group_name = f"ep_size_{expert_parallel_size_}"

    # Only create groups if they don't already exist
    # Need to check conditions outside the group creation loop because of the way torch.dist group creation works
    if group_name not in _EXPERT_DATA_PARALLEL_GROUP and group_name not in _EXPERT_PARALLEL_GROUP:
        expert_parallel_groups, expert_data_parallel_groups = _get_expert_parallel_ranks(
-            world_size, model_parallel_size_, expert_parallel_size_, use_data_before_expert_parallel_)
+            world_size, tensor_parallel_size_, expert_parallel_size_, pp_world_size, use_data_before_expert_parallel_)
        for ranks in expert_parallel_groups:
            group = dist.new_group(ranks)
            if rank in list(ranks):
@@ -18,7 +18,7 @@ def test_get_expert_parallel_ranks():
        expert_data_parallel_group = [0,8],[2,10],[4,12],[6,14], [1,9],[3,11],[5,13],[7,15]
    """
    expert_parallel_groups, expert_data_parallel_groups = _get_expert_parallel_ranks(world_size=16,
-                                                                                     model_parallel_size_=2,
+                                                                                     tensor_parallel_size_=2,
                                                                                      expert_parallel_size_=4)
    assert expert_parallel_groups == [
        [0, 2, 4, 6],