Mirror of https://github.com/deepspeedai/DeepSpeed.git
reduce all-to-all communication volume when both expert and non-expert are tensor-parallel (#5626)
Example: E + M + D parallel, world_size = 8, model_degree = 2, expert_degree = 4; mp_group = [0, 1], [2, 3], [4, 5], [6, 7]; expert_parallel_group = [0, 2, 4, 6], [1, 3, 5, 7].

Previously there was no drop operation before the experts executed, and the two expert-parallel groups each ran their own all-to-all. Both ended up with the complete data, but ranks 0 and 1 received exactly the same data, as did ranks 2 and 3, and so on. We can therefore drop the duplicated data before the all-to-all and run an allgather after it to recover the complete data. After the experts execute, the data on ranks 0 and 1 is again identical, so we can drop it, run the second all-to-all, and then allgather to recover the complete data.

1. Non-expert uses TP, expert does not use TP:
   drop -> all-to-all -> execute MoE -> all-to-all -> allgather
2. Both non-expert and expert use TP:
   - original execution order: all-to-all -> execute MoE -> allreduce -> all-to-all
   - optimized execution order: drop -> all-to-all -> allgather -> execute MoE -> drop -> all-to-all -> allgather

Signed-off-by: --local <zhiwei.tao@enflame-tech.com>
Co-authored-by: --local <zhiwei.tao@enflame-tech.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
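As a rough single-process illustration of the drop/allgather pair described in the commit message above (this is not DeepSpeed code; the tensor sizes and tp_degree are invented for the example), each tensor-parallel rank keeps only its own slice before the all-to-all, and concatenating the slices afterwards reproduces the full tensor on every rank:

import torch

tp_degree = 2                                                       # mirrors model_degree = 2 above
tokens = torch.arange(8 * 4, dtype=torch.float32).reshape(8, 4)     # [num_tokens, d_model]

# "drop": each tensor-parallel rank keeps 1/tp_degree of the tokens, so the
# expert-parallel all-to-all that follows moves half as much data per rank.
per_rank_slices = list(torch.chunk(tokens, tp_degree, dim=0))

# ... the all-to-all would operate on the smaller slices here ...

# "allgather": concatenating the slices from all tensor-parallel ranks
# restores the complete tensor on every rank.
restored = torch.cat(per_rank_slices, dim=0)
assert torch.equal(restored, tokens)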
@@ -32,15 +32,23 @@ def _gather_tokens(input_, dim=0):
     mpu = deepspeed.utils.groups.mpu

     input_ = input_.contiguous()
-    # Size and dimension.
-    rank = bwc_tensor_model_parallel_rank(mpu)
+    world_size = bwc_tensor_model_parallel_world_size(mpu)
+    if world_size == 1:
+        return input_

-    tensor_list = [torch.empty_like(input_) for _ in range(bwc_tensor_model_parallel_world_size(mpu))]
-    tensor_list[rank] = input_
-    deepspeed.comm.all_gather(tensor_list, input_, group=bwc_tensor_model_parallel_group(mpu))
-
-    # Note: torch.cat already creates a contiguous tensor.
-    output = torch.cat(tensor_list, dim=dim).contiguous()
+    gather_buffer = torch.empty(world_size * input_.numel(), dtype=input_.dtype, device=input_.device)
+    deepspeed.comm.all_gather_into_tensor(gather_buffer, input_, group=bwc_tensor_model_parallel_group(mpu))
+    if dim == 0:
+        shape = list(input_.size())
+        shape[0] = shape[0] * world_size
+        output = gather_buffer.view(shape)
+    else:
+        tensor_list = [
+            gather_buffer.narrow(0,
+                                 input_.numel() * i, input_.numel()).view_as(input_) for i in range(world_size)
+        ]
+        # Note: torch.cat already creates a contiguous tensor.
+        output = torch.cat(tensor_list, dim=dim).contiguous()

     return output
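To see how the new flat gather_buffer maps back onto per-rank tensors when dim != 0, here is a hedged single-process sketch (world_size and the tensor shape are invented, and the buffer is faked locally instead of calling the collective):

import torch

world_size = 4
input_ = torch.randn(3, 5)

# Stand-in for all_gather_into_tensor: the flat buffer holds the flattened
# per-rank tensors back to back (here every rank contributes the same copy).
gather_buffer = input_.reshape(-1).repeat(world_size)

views = [
    gather_buffer.narrow(0, input_.numel() * i, input_.numel()).view_as(input_)
    for i in range(world_size)
]
output = torch.cat(views, dim=1)          # gathering along dim=1: (3, 5) -> (3, 20)
assert output.shape == (3, 5 * world_size)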
@@ -50,6 +58,8 @@ def _drop_tokens(input_, dim=0):
     mpu = deepspeed.utils.groups.mpu

     total_chunks = bwc_tensor_model_parallel_world_size(mpu)
+    if total_chunks == 1:
+        return input_
     this_chunk = bwc_tensor_model_parallel_rank(mpu)
     assert input_.shape[
         dim] % total_chunks == 0, f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
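For reference, a single-process sketch of the slice _drop_tokens keeps on each tensor-parallel rank (sizes and the rank are invented); the newly added early return simply skips the slicing when there is only one tensor-parallel rank:

import torch

total_chunks, this_chunk = 4, 1          # tensor-parallel world size and this rank (invented)
x = torch.randn(12, 16)

if total_chunks == 1:                    # mirrors the new early return
    kept = x
else:
    assert x.shape[0] % total_chunks == 0
    chunk = x.shape[0] // total_chunks
    kept = x.narrow(0, chunk * this_chunk, chunk)   # rank 1 of 4 keeps rows 3..5

assert kept.shape[0] == x.shape[0] // total_chunks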
@@ -533,13 +533,18 @@ class MOELayer(Base):
         if self.wall_clock_breakdown:
             self.timers(FIRST_ALLTOALL_TIMER).start()

-        if groups._get_expert_model_parallel_world_size() == 1:
-            # If the non-expert is tensor-parallel, it will create
-            # duplicate tokens on the tensor-parallel ranks.
-            # Since our experts are not tensor-parallel, these duplicates
-            # need to be dropped to ensure correctness.
-            # this also doubles up as a communication optimization as we are
-            # reducing the all-to-all communication volume.
+        tensor_model_world_size = bwc_tensor_model_parallel_world_size(groups.mpu)
+        if tensor_model_world_size > 1:
+            # If the non-expert is tensor-parallel, it creates duplicate tokens
+            # on the tensor-parallel ranks, whether or not the expert itself
+            # is tensor-parallel. Dropping the duplicates also doubles as a
+            # communication optimization, since it reduces the all-to-all volume.
+            # 1: if the expert is not tensor-parallel, drop the duplicate tokens
+            #    to ensure correctness and to reduce all-to-all communication.
+            # 2: if the expert is tensor-parallel, drop the duplicate tokens to
+            #    reduce the all-to-all volume; an allgather is then required
+            #    after the all-to-all to restore the full tokens so the experts
+            #    still see correct inputs.
             dispatched_input = drop_tokens(dispatched_input, dim=1)

         dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)
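A hedged, single-process walk-through of the optimized order for case 2 (both the expert and the non-expert tensor-parallel); the collectives are replaced by shape-preserving stand-ins so the tensor shapes through drop -> all-to-all -> allgather can be checked locally, and all sizes are invented:

import torch

ep_size, tp, capacity, d_model = 4, 2, 8, 16

# dispatched tokens on one rank before the first all-to-all: [experts, capacity, d_model]
dispatched = torch.randn(ep_size, capacity, d_model)

# drop: keep this rank's 1/tp of the capacity dimension (dim=1, as in MOELayer)
dropped = torch.chunk(dispatched, tp, dim=1)[0]      # [ep_size, capacity // tp, d_model]

# the expert-parallel all-to-all keeps the per-rank shape; stand-in here
after_a2a = dropped.clone()

# allgather across the tensor-parallel group restores the full capacity
gathered = torch.cat([after_a2a] * tp, dim=1)        # back to [ep_size, capacity, d_model]
assert gathered.shape == dispatched.shape

# ... the experts run here; the same drop -> all-to-all -> allgather pair is
# then applied to expert_output on the way back ...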
@@ -548,10 +553,22 @@ class MOELayer(Base):
             self.timers(FIRST_ALLTOALL_TIMER).stop()
             self.time_falltoall = self.timers(FIRST_ALLTOALL_TIMER).elapsed(reset=False)

+        if tensor_model_world_size > 1 and groups._get_expert_model_parallel_world_size() > 1:
+            # If both the expert and the non-expert are tensor-parallel, the
+            # dropped duplicate tokens need to be gathered on each
+            # tensor-parallel rank again to ensure correctness.
+            dispatched_input = gather_tokens(dispatched_input, dim=1)
         # Re-shape after all-to-all: ecm -> gecm
         dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)

         expert_output = self.experts(dispatched_input)
+        # Re-shape before drop_tokens: gecm -> ecm
+        expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
+        if tensor_model_world_size > 1 and groups._get_expert_model_parallel_world_size() > 1:
+            # If both the expert and the non-expert are tensor-parallel, drop the
+            # duplicate tokens to ensure correctness and to reduce the
+            # all-to-all communication volume.
+            expert_output = drop_tokens(expert_output, dim=1)

         if self.wall_clock_breakdown:
             self.timers(SECOND_ALLTOALL_TIMER).start()
@@ -562,10 +579,7 @@ class MOELayer(Base):
             self.timers(SECOND_ALLTOALL_TIMER).stop()
             self.time_salltoall = self.timers(SECOND_ALLTOALL_TIMER).elapsed(reset=False)

-        # Re-shape back: gecm -> ecm
-        expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
-
-        if groups._get_expert_model_parallel_world_size() == 1:
+        if tensor_model_world_size > 1:
             # the dropped duplicate tokens need to be gathered on each
             # tensor parallel rank again for the tensor-parallel
             # non-expert of the next layer.
@@ -327,7 +327,7 @@ class DeepSpeedMoEInference(nn.Module):

             if self.expert_mp_group is not None:
                 world_size = dist.get_world_size(group=self.expert_mp_group)
-                gather_buffer = torch.zeros(world_size * attention_output.numel(),
+                gather_buffer = torch.empty(world_size * attention_output.numel(),
                                             dtype=attention_output.dtype,
                                             device=attention_output.device)
                 dist.all_gather_into_tensor(gather_buffer, attention_output, group=self.expert_mp_group)
@@ -2237,7 +2237,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         return grad_dict

     def _fp32_state_allgather(self, param, fp32_state_partition):
-        reduce_buffer = torch.zeros(self.partition_count * fp32_state_partition.numel(),
+        reduce_buffer = torch.empty(self.partition_count * fp32_state_partition.numel(),
                                     dtype=torch.float32,
                                     device=param.device)
         my_rank = dist.get_rank(group=self.dp_process_group)
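Both zeros-to-empty changes above rely on the same fact: the buffer is completely overwritten by the subsequent all_gather_into_tensor, so initializing it to zero is wasted work. A rough, machine-dependent illustration of the zero-fill cost being skipped (purely local, no collectives involved):

import time

import torch

n = 64 * 1024 * 1024          # element count is arbitrary, just large enough to measure

t0 = time.perf_counter()
torch.zeros(n)                # allocates and zero-fills
t1 = time.perf_counter()
torch.empty(n)                # allocates only; contents are whatever is in memory
t2 = time.perf_counter()

print(f"torch.zeros: {(t1 - t0) * 1e3:.1f} ms, torch.empty: {(t2 - t1) * 1e3:.1f} ms")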