Support expert parallel load balancing in Transformers backend (#26287)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-06 12:20:16 +01:00
committed by GitHub
parent 19a00eb210
commit 0340f45553
2 changed files with 76 additions and 27 deletions

View File

@ -40,7 +40,7 @@ from vllm.config import (
)
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.utils import getattr_iter
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tp_group
from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
@ -506,9 +506,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
self.quant_config: Optional[QuantizationConfig] = vllm_config.quant_config
self.pp_group = get_pp_group()
self.pp_size = self.pp_group.world_size
self.pp_rank = self.pp_group.rank_in_group
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_group = get_tp_group()
# Weights to skip in `self.load_weights`
self.skip_prefixes: list[str] = []
@ -576,7 +574,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
"""
Apply the model's pipeline parallelization plan.
"""
if self.pp_size <= 1:
if self.pp_group.world_size <= 1:
return
if not self.model.supports_pp_plan:
@ -613,7 +611,9 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
# Module list
start_layer, end_layer = get_pp_indices(
self.text_config.num_hidden_layers, self.pp_rank, self.pp_size
self.text_config.num_hidden_layers,
self.pp_group.rank_in_group,
self.pp_group.world_size,
)
layers_name = pp_plan[module_list_idx]
layers = getattr(self.model, layers_name)
@ -638,7 +638,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
"""
tp_plan = self.model.tp_plan
if not tp_plan and self.tp_size > 1:
if not tp_plan and self.tp_group.world_size > 1:
tip = get_feature_request_tip(
self.model_config.model, self.model_config.trust_remote_code
)
@ -687,7 +687,9 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
head_size = self.model_config.get_head_size()
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
start, end = get_pp_indices(
self.text_config.num_hidden_layers, self.pp_rank, self.pp_size
self.text_config.num_hidden_layers,
self.pp_group.rank_in_group,
self.pp_group.world_size,
)
attention_instances = {}
@ -749,7 +751,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if not get_pp_group().is_first_rank:
if not self.pp_group.is_first_rank:
assert intermediate_tensors is not None
input_ids = None
inputs_embeds = intermediate_tensors["hidden_states"]
@ -773,7 +775,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
return_dict=False,
)[0][0, ...] # we remove batch dimension for now
if not get_pp_group().is_last_rank:
if not self.pp_group.is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
return hidden_states
@ -811,7 +813,7 @@ class TransformersForCausalLM(TransformersBase):
if self.text_config.tie_word_embeddings:
self.skip_prefixes.append("lm_head.")
if get_pp_group().is_last_rank:
if self.pp_group.is_last_rank:
self.unpadded_vocab_size = self.text_config.vocab_size
self.lm_head = ParallelLMHead(
self.text_config.vocab_size,

View File

@ -30,6 +30,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from .interfaces import MixtureOfExperts
from .transformers import (
TransformersBase,
TransformersForCausalLM,
@ -116,17 +117,41 @@ direct_register_custom_op(
)
class TransformersMoEBase(TransformersBase):
class TransformersMoEBase(TransformersBase, MixtureOfExperts):
def __init__(self, *, vllm_config, prefix=""):
self.check_version("4.57.0.dev0", "MoE models support")
self.ep_group = get_ep_group()
super().__init__(vllm_config=vllm_config, prefix=prefix)
if self.parallel_config.enable_eplb:
raise NotImplementedError(
"Transformers backend does not support expert parallel load "
"balancing yet."
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
):
for moe_layer_idx, mlp_layer in enumerate(self.mlp_layers):
mlp_layer.experts.set_eplb_state(
moe_layer_idx=moe_layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
):
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
for mlp in self.mlp_layers:
mlp.n_local_physical_experts = num_local_physical_experts
mlp.n_physical_experts = num_physical_experts
mlp.n_redundant_experts = self.num_redundant_experts
mlp.experts.update_expert_map()
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
"""
Params for weights, fp8 weight scales, fp8 activation scales
@ -138,6 +163,8 @@ class TransformersMoEBase(TransformersBase):
("w1", "w2", "w3"), # Granite, Mixtral, Phi MoE style
("linear", "linear_1", "linear_v"), # Grok1 style
]
num_experts = self.model_config.get_num_experts()
num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
expert_mapping = []
for gate_proj, down_proj, up_proj in ckpt_names:
expert_mapping.extend(
@ -145,8 +172,8 @@ class TransformersMoEBase(TransformersBase):
ckpt_gate_proj_name=gate_proj,
ckpt_down_proj_name=down_proj,
ckpt_up_proj_name=up_proj,
num_experts=self.model_config.get_num_experts(),
num_redundant_experts=0, # TODO: enable EPLB
num_experts=num_experts,
num_redundant_experts=num_redundant_experts,
)
)
return expert_mapping
@ -167,12 +194,15 @@ class TransformersMoEBase(TransformersBase):
# If there are shared experts, the results are
# reduced after mlp.forward() not inside FusedMoE
num_experts_shared = getattr_iter(
num_shared_experts = getattr_iter(
text_config,
["num_experts_shared", "n_shared_experts", "moe_num_shared_experts"],
[
"n_shared_experts", # DeepSeek, Docs, GLM
"moe_num_shared_experts", # Aria, Ernie
],
0,
)
reduce_results = num_experts_shared == 0
reduce_results = num_shared_experts == 0
def add_all_reduce(mlp: nn.Module):
"""Adds an all-reduce to the output of `mlp.forward()`."""
@ -207,13 +237,23 @@ class TransformersMoEBase(TransformersBase):
# Expert mapping for `AutoWeightsLoader`
expert_mapping = self.get_expert_mapping()
# Configs
parallel_config = self.parallel_config
eplb_config = parallel_config.eplb_config
# Expert parallel load balancing kwargs
enable_eplb = parallel_config.enable_eplb
num_redundant_experts = eplb_config.num_redundant_experts
enable_eplb = self.parallel_config.enable_eplb
num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
# MixtureOfExperts mixin settings
ep_size = self.ep_group.world_size
self.mlp_layers = [] # Used for MixtureOfExperts methods
self.expert_weights = []
self.num_moe_layers = 0
self.num_expert_groups = 1 if num_expert_group is None else num_expert_group
self.num_logical_experts = num_experts
self.num_physical_experts = num_experts + num_redundant_experts
self.num_local_physical_experts = self.num_physical_experts // ep_size
self.num_routed_experts = num_experts
self.num_shared_experts = num_shared_experts
self.num_redundant_experts = num_redundant_experts
# Recursively fuse MoE layers
def _recursive_replace(module: nn.Module, prefix: str):
@ -235,6 +275,9 @@ class TransformersMoEBase(TransformersBase):
for mlp_param_name, _ in mlp.named_parameters():
if "shared_expert" in mlp_param_name:
reduce_results = False
# If the config does not specify num_shared_experts, but
# the model has shared experts, we assume there is one.
self.num_shared_experts = 1
break
# Replace experts module with FusedMoE
fused_experts = TransformersFusedMoE(
@ -258,6 +301,10 @@ class TransformersMoEBase(TransformersBase):
)
mlp.experts = fused_experts
log_replacement(qual_name, experts, fused_experts)
# Update MixtureOfExperts mixin state
self.mlp_layers.append(mlp)
self.expert_weights.append(fused_experts.get_expert_weights())
self.num_moe_layers += 1
# If results are not all-reduced in FusedMoE, ensure they
# are all-reduced at the end of mlp.forward() if tensor
# parallel or expert parallel is enabled