Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Bugfix] EPLB load statistics problem (#22167)
Signed-off-by: ycyaw66 <497410282@qq.com>
Signed-off-by: David Chen <530634352@qq.com>
Co-authored-by: ycyaw66 <497410282@qq.com>
@@ -32,7 +32,7 @@ from dataclasses import dataclass
 from typing import Optional, Union
 
 import torch
-from torch.distributed import ProcessGroup, all_gather, all_reduce
+from torch.distributed import ProcessGroup, all_reduce
 
 from vllm.config import ParallelConfig
 from vllm.distributed.parallel_state import (get_ep_group, get_node_count,
@@ -112,13 +112,21 @@ class EplbState:
     Expert load during this forward pass.
     We use the token count each expert processes as the load.
 
-    Shape: (num_moe_layers, num_local_physical_experts)
+    Shape: (num_moe_layers, num_physical_experts)
     """
     expert_load_window: torch.Tensor
     """
     A sliding window of expert load.
 
-    Shape: (window_size, num_moe_layers, num_local_physical_experts)
+    Shape: (window_size, num_moe_layers, num_physical_experts)
+
+    NOTE: The expert_load_view now records load for all physical experts
+    rather than just local experts. This ensures consistent load statistics
+    across different dispatch methods (naive all-to-all, DeepEP, pplx-kernels).
+    The recorded load will be multiplied by dp_size when using naive all-to-all
+    due to each DP rank contributing the same token set to the calculation.
+    See:
+    https://github.com/vllm-project/vllm/pull/22167#pullrequestreview-3086143856
     """
     expert_load_window_step: int = 0
     """
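The added NOTE is the core of the fix: per-pass load is now tracked for every physical expert globally, not only for the experts hosted on the local rank. A minimal standalone sketch of the per-pass counter plus sliding window described above; the names mirror the fields in the docstring, but sizes, values, and the helper function are illustrative rather than vLLM's actual implementation.

import torch

num_moe_layers, num_physical_experts, window_size = 2, 8, 4

expert_load_pass = torch.zeros(num_moe_layers, num_physical_experts,
                               dtype=torch.int32)
expert_load_window = torch.zeros(window_size, num_moe_layers,
                                 num_physical_experts, dtype=torch.int32)
window_step = 0

def record_pass() -> None:
    """Copy the latest per-pass load into the oldest window slot, then reset."""
    global window_step
    expert_load_window[window_step % window_size] = expert_load_pass
    window_step += 1
    expert_load_pass.zero_()

# Simulate one forward pass: each physical expert processed some tokens.
expert_load_pass += torch.randint(0, 16, (num_moe_layers, num_physical_experts),
                                  dtype=torch.int32)
record_pass()
print(expert_load_window[0])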
@@ -232,14 +240,14 @@ class EplbState:
         ).contiguous()
 
         expert_load_pass = torch.zeros(
-            (model.num_moe_layers, model.num_local_physical_experts),
+            (model.num_moe_layers, model.num_physical_experts),
             dtype=torch.int32,
             device=device,
         )
         expert_load_window_size = parallel_config.eplb_window_size
         expert_load_window = torch.zeros(
             (expert_load_window_size, model.num_moe_layers,
-             model.num_local_physical_experts),
+             model.num_physical_experts),
             dtype=torch.int32,
             device=device,
         )
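For scale, the buffers above grow from the local to the global expert count. A tiny sketch of the assumed size relationship (EPLB splits physical expert slots evenly across expert-parallel ranks):

# Assumed relationship, for illustration only.
ep_size = 4                      # expert-parallel world size
num_local_physical_experts = 8   # physical expert slots hosted per rank
num_physical_experts = num_local_physical_experts * ep_size  # 32 global slots

# After this change every rank allocates statistics of the global size, so the
# tensors are identically shaped across ranks and can be summed directly.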
@@ -353,18 +361,18 @@ class EplbState:
         self.expert_load_pass.zero_()
 
         if log_stats:
-            # `num_tokens`: (num_moe_layers,)
-            num_tokens = self.expert_load_pass.sum(dim=-1)
+            # total_expert_load_pass: (num_moe_layers, num_physical_experts)
+            total_expert_load_pass = self.expert_load_pass.clone()
 
             # Collect load metrics from all ranks
             ep_group = get_ep_group().device_group
             assert ep_group is not None
-            num_tokens_list = [
-                torch.empty_like(num_tokens) for _ in range(ep_group.size())
-            ]
-            all_gather(num_tokens_list, num_tokens, group=ep_group)
-            # Stack to get (num_ranks, num_moe_layers)
-            num_tokens_per_rank = torch.stack(num_tokens_list).float()
+            all_reduce(total_expert_load_pass, group=ep_group)
+
+            # num_tokens_per_rank: (num_moe_layers, num_ranks)
+            num_tokens_per_rank = total_expert_load_pass.reshape(
+                total_expert_load_pass.shape[0], ep_group.size(),
+                -1).sum(dim=-1).float()
 
             # Compute balancedness ratio:
             # for each layer:
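A single-process toy version of the aggregation above: the real code sums the per-rank load tensors with all_reduce on the EP group, then folds the physical-expert axis back into (rank, local slot) blocks. The per-layer balancedness formula at the end is an assumption inferred from the trailing comments, not copied from vLLM.

import torch

num_moe_layers, ep_size, num_local_physical_experts = 2, 4, 3
num_physical_experts = ep_size * num_local_physical_experts

# Pretend each EP rank recorded this per-pass load; all_reduce would sum them.
per_rank_loads = [
    torch.randint(1, 10, (num_moe_layers, num_physical_experts))
    for _ in range(ep_size)
]
total_expert_load_pass = torch.stack(per_rank_loads).sum(dim=0)

# Physical experts are laid out contiguously per rank (rank r hosts slots
# [r * L, (r + 1) * L)), so folding the last axis into (rank, local slot)
# recovers the token count each rank processed, per layer.
num_tokens_per_rank = total_expert_load_pass.reshape(
    num_moe_layers, ep_size, -1).sum(dim=-1).float()

# Per-layer balancedness: 1.0 means load is perfectly even across ranks.
balancedness = (num_tokens_per_rank.mean(dim=-1) /
                num_tokens_per_rank.max(dim=-1).values)
print(balancedness)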
@@ -426,17 +434,7 @@ class EplbState:
                     "(profile)" if is_profile else "")
 
         if global_expert_load is None:
-            # This mapping is only used here, so we do not store it in the state
-            physical_expert_start = ep_rank * model.num_local_physical_experts
-            physical_expert_end = (physical_expert_start +
-                                   model.num_local_physical_experts)
-            # (num_moe_layers, num_local_physical_experts)
-            local_physical_to_logical_map = self.physical_to_logical_map[
-                :,
-                physical_expert_start:physical_expert_end,
-            ]
-
-            # Map the local physical expert load to global logical experts
+            # Map the physical expert load to global logical experts
             logical_expert_load_window = torch.zeros(
                 self.expert_load_window_size,
                 model.num_moe_layers,
@@ -446,7 +444,7 @@
             )
             logical_expert_load_window.scatter_add_(
                 dim=-1,
-                index=local_physical_to_logical_map.unsqueeze(0).expand_as(
+                index=self.physical_to_logical_map.unsqueeze(0).expand_as(
                     self.expert_load_window).long(),
                 src=self.expert_load_window,
             )
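A self-contained sketch of the scatter_add_ mapping above: every physical slot's load is added into the bucket of the logical expert it replicates, so replicas of the same expert are merged. The map, shapes, and values here are made up for illustration.

import torch

window_size, num_moe_layers = 2, 1
num_physical_experts, num_logical_experts = 6, 4

# physical slot -> logical expert id, per layer (replicas share a logical id).
physical_to_logical_map = torch.tensor([[0, 1, 2, 3, 0, 2]])

expert_load_window = torch.randint(
    0, 5, (window_size, num_moe_layers, num_physical_experts))

logical_expert_load_window = torch.zeros(
    window_size, num_moe_layers, num_logical_experts,
    dtype=expert_load_window.dtype)

# Add each physical slot's load into its logical expert's bucket.
logical_expert_load_window.scatter_add_(
    dim=-1,
    index=physical_to_logical_map.unsqueeze(0).expand_as(
        expert_load_window).long(),
    src=expert_load_window,
)
print(logical_expert_load_window)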
@@ -618,4 +616,4 @@ def _node_count_with_rank_mapping(
                 if is_same_node and node_assignment[other_rank] == 0:
                     node_assignment[other_rank] = next_node_id
 
-    return next_node_id
+    return next_node_id
@@ -1430,22 +1430,9 @@ class FusedMoE(torch.nn.Module):
         # to the modular kernel, we can move this logic there
         # to achieve better efficiency.
 
-        # `expert_load_view`: (num_logical_experts,)
+        # `expert_load_view`: (num_physical_experts,)
 
-        # Mask out non-local experts
-        if expert_map is not None:
-            topk_ids_local = expert_map[topk_ids]
-            topk_ids_flatten = topk_ids_local.flatten()
-        else:
-            topk_ids_flatten = topk_ids.flatten()
-
-        # Should be equivalent to:
-        # ```
-        # topk_ids_masked = topk_ids_local[topk_ids_local >= 0]
-        # expert_load_view += topk_ids_masked.bincount(
-        #     minlength=expert_load_view.shape[0])
-        # ```
-        # We use `scatter_add_` since `bincount` cannot be compiled
+        topk_ids_flatten = topk_ids.flatten()
 
         # Performance optimization:
         # `masked_fill` is significantly faster than `masked_select`
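The removed comments referred to a scatter_add_-based replacement for bincount, which torch.compile cannot trace. A toy version of that counting trick follows; the -1 masking and variable names are illustrative assumptions, not code copied from vLLM.

import torch

num_physical_experts = 8
expert_load_view = torch.zeros(num_physical_experts, dtype=torch.int64)

# Router output: top-k expert ids per token; -1 marks an invalid/padded slot.
topk_ids = torch.tensor([[0, 3], [5, 3], [-1, 7]])
topk_ids_flatten = topk_ids.flatten()

invalid = topk_ids_flatten < 0
index = topk_ids_flatten.masked_fill(invalid, 0).long()   # clamp invalid ids to 0
src = (~invalid).to(expert_load_view.dtype)               # they contribute 0

# Equivalent to bincount over the valid ids, but compilable:
expert_load_view.scatter_add_(dim=0, index=index, src=src)
print(expert_load_view)  # expert 3 counted twice; experts 0, 5 and 7 once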