[MoE] More balanced expert sharding (#21497)

Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
Author: Woosuk Kwon
Date: 2025-07-24 15:56:08 -07:00
Committed by: GitHub
Parent: 07d80d7b0e
Commit: fe56180c7f


@@ -591,22 +591,20 @@ def determine_expert_map(
     if ep_size == 1:
         return (global_num_experts, None)
-    local_num_experts = global_num_experts // ep_size
+    # Distribute experts as evenly as possible to each rank.
+    base_experts = global_num_experts // ep_size
+    remainder = global_num_experts % ep_size
+    if ep_rank < remainder:
+        local_num_experts = base_experts + 1
+    else:
+        local_num_experts = base_experts
     # Create a tensor of size num_experts filled with -1
     expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
     # Create a expert map for the local experts
-    if ep_rank < (ep_size - 1):
-        # Each non-last rank gets local_num_experts experts.
-        expert_map[ep_rank * local_num_experts:
-                   (ep_rank + 1) * local_num_experts] = \
-            torch.arange(0, local_num_experts, dtype=torch.int32)
-    else:
-        # All remaining experts are assigned to the last rank.
-        local_num_experts = (global_num_experts - ep_rank * local_num_experts)
-        expert_map[-local_num_experts:] = \
-            torch.arange(0, local_num_experts, dtype=torch.int32)
+    start_idx = ep_rank * base_experts + min(ep_rank, remainder)
+    expert_map[start_idx:start_idx + local_num_experts] = torch.arange(
+        0, local_num_experts, dtype=torch.int32)
     return (local_num_experts, expert_map)
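
For illustration only (not part of the commit): a minimal standalone sketch of the new balanced sharding rule, using plain Python lists instead of torch so the per-rank layout is easy to inspect. The function name balanced_expert_map is hypothetical; it mirrors the semantics of determine_expert_map above. For example, with 10 global experts over 4 ranks, the old scheme produced per-rank counts (2, 2, 2, 4), while the new rule yields (3, 3, 2, 2).

# Illustrative sketch (not part of the commit): reimplements the new
# balanced rule with plain Python lists so results are easy to print.

def balanced_expert_map(global_num_experts: int, ep_size: int, ep_rank: int):
    """Return (local_num_experts, expert_map) for one rank.

    expert_map[g] is the local index of global expert g on this rank,
    or -1 if the expert is not hosted here.
    """
    if ep_size == 1:
        return global_num_experts, None
    base_experts = global_num_experts // ep_size
    remainder = global_num_experts % ep_size
    # The first `remainder` ranks each host one extra expert.
    local_num_experts = base_experts + (1 if ep_rank < remainder else 0)
    expert_map = [-1] * global_num_experts
    # Contiguous block of global experts assigned to this rank.
    start_idx = ep_rank * base_experts + min(ep_rank, remainder)
    for local_idx in range(local_num_experts):
        expert_map[start_idx + local_idx] = local_idx
    return local_num_experts, expert_map


if __name__ == "__main__":
    # 10 experts over 4 ranks: counts are (3, 3, 2, 2) instead of the
    # old scheme's (2, 2, 2, 4).
    for rank in range(4):
        n, emap = balanced_expert_map(10, 4, rank)
        print(f"rank {rank}: {n} experts, map={emap}")

The design point of the change is that the remainder experts are spread one-per-rank across the first ranks rather than all piled onto the last rank, so the maximum per-rank expert count differs from the minimum by at most one.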