[Bugfix] EPLB load statistics problem (#22167)

Signed-off-by: ycyaw66 <497410282@qq.com>
Signed-off-by: David Chen <530634352@qq.com>
Co-authored-by: ycyaw66 <497410282@qq.com>
This commit is contained in:
WeiQing Chen
2025-08-07 12:07:54 +08:00
committed by GitHub
parent f6278b6243
commit 4be02a3776
2 changed files with 26 additions and 41 deletions

View File

@ -32,7 +32,7 @@ from dataclasses import dataclass
from typing import Optional, Union
import torch
from torch.distributed import ProcessGroup, all_gather, all_reduce
from torch.distributed import ProcessGroup, all_reduce
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import (get_ep_group, get_node_count,
@ -112,13 +112,21 @@ class EplbState:
Expert load during this forward pass.
We use the token count each expert processes as the load.
Shape: (num_moe_layers, num_local_physical_experts)
Shape: (num_moe_layers, num_physical_experts)
"""
expert_load_window: torch.Tensor
"""
A sliding window of expert load.
Shape: (window_size, num_moe_layers, num_local_physical_experts)
Shape: (window_size, num_moe_layers, num_physical_experts)
NOTE: The expert_load_view now records load for all physical experts
rather than just local experts. This ensures consistent load statistics
across different dispatch methods (naive all-to-all, DeepEP, pplx-kernels).
The recorded load will be multiplied by dp_size when using naive all-to-all
due to each DP rank contributing the same token set to the calculation.
See:
https://github.com/vllm-project/vllm/pull/22167#pullrequestreview-3086143856
"""
expert_load_window_step: int = 0
"""
@ -232,14 +240,14 @@ class EplbState:
).contiguous()
expert_load_pass = torch.zeros(
(model.num_moe_layers, model.num_local_physical_experts),
(model.num_moe_layers, model.num_physical_experts),
dtype=torch.int32,
device=device,
)
expert_load_window_size = parallel_config.eplb_window_size
expert_load_window = torch.zeros(
(expert_load_window_size, model.num_moe_layers,
model.num_local_physical_experts),
model.num_physical_experts),
dtype=torch.int32,
device=device,
)
@ -353,18 +361,18 @@ class EplbState:
self.expert_load_pass.zero_()
if log_stats:
# `num_tokens`: (num_moe_layers,)
num_tokens = self.expert_load_pass.sum(dim=-1)
# total_expert_load_pass: (num_moe_layers, num_physical_experts)
total_expert_load_pass = self.expert_load_pass.clone()
# Collect load metrics from all ranks
ep_group = get_ep_group().device_group
assert ep_group is not None
num_tokens_list = [
torch.empty_like(num_tokens) for _ in range(ep_group.size())
]
all_gather(num_tokens_list, num_tokens, group=ep_group)
# Stack to get (num_ranks, num_moe_layers)
num_tokens_per_rank = torch.stack(num_tokens_list).float()
all_reduce(total_expert_load_pass, group=ep_group)
# num_tokens_per_rank: (num_moe_layers, num_ranks)
num_tokens_per_rank = total_expert_load_pass.reshape(
total_expert_load_pass.shape[0], ep_group.size(),
-1).sum(dim=-1).float()
# Compute balancedness ratio:
# for each layer:
@ -426,17 +434,7 @@ class EplbState:
"(profile)" if is_profile else "")
if global_expert_load is None:
# This mapping is only used here, so we do not store it in the state
physical_expert_start = ep_rank * model.num_local_physical_experts
physical_expert_end = (physical_expert_start +
model.num_local_physical_experts)
# (num_moe_layers, num_local_physical_experts)
local_physical_to_logical_map = self.physical_to_logical_map[
:,
physical_expert_start:physical_expert_end,
]
# Map the local physical expert load to global logical experts
# Map the physical expert load to global logical experts
logical_expert_load_window = torch.zeros(
self.expert_load_window_size,
model.num_moe_layers,
@ -446,7 +444,7 @@ class EplbState:
)
logical_expert_load_window.scatter_add_(
dim=-1,
index=local_physical_to_logical_map.unsqueeze(0).expand_as(
index=self.physical_to_logical_map.unsqueeze(0).expand_as(
self.expert_load_window).long(),
src=self.expert_load_window,
)
@ -618,4 +616,4 @@ def _node_count_with_rank_mapping(
if is_same_node and node_assignment[other_rank] == 0:
node_assignment[other_rank] = next_node_id
return next_node_id
return next_node_id

View File

@ -1430,22 +1430,9 @@ class FusedMoE(torch.nn.Module):
# to the modular kernel, we can move this logic there
# to achieve better efficiency.
# `expert_load_view`: (num_logical_experts,)
# `expert_load_view`: (num_physical_experts,)
# Mask out non-local experts
if expert_map is not None:
topk_ids_local = expert_map[topk_ids]
topk_ids_flatten = topk_ids_local.flatten()
else:
topk_ids_flatten = topk_ids.flatten()
# Should be equivalent to:
# ```
# topk_ids_masked = topk_ids_local[topk_ids_local >= 0]
# expert_load_view += topk_ids_masked.bincount(
# minlength=expert_load_view.shape[0])
# ```
# We use `scatter_add_` since `bincount` cannot be compiled
topk_ids_flatten = topk_ids.flatten()
# Performance optimization:
# `masked_fill` is significantly faster than `masked_select`