Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
Support expert parallel load balancing in Transformers backend (#26287)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
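For context: with this change, expert parallel load balancing (EPLB) can be turned on for MoE models served through the Transformers backend instead of raising NotImplementedError. A minimal usage sketch is below; the model name and every engine argument shown (model_impl, tensor_parallel_size, enable_expert_parallel, enable_eplb) are illustrative assumptions, not something this commit adds.

# Hypothetical usage sketch; the model name and every argument below are
# assumptions for illustration, not part of this commit.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen1.5-MoE-A2.7B-Chat",  # any MoE checkpoint, illustrative
    model_impl="transformers",            # force the Transformers backend
    tensor_parallel_size=4,
    enable_expert_parallel=True,          # shard experts across ranks
    enable_eplb=True,                     # expert parallel load balancing
)
print(llm.generate(["Hello"], SamplingParams(max_tokens=16))[0].outputs[0].text)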
@@ -40,7 +40,7 @@ from vllm.config import (
 )
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.config.utils import getattr_iter
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group, get_tp_group
 from vllm.distributed.utils import get_pp_indices
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -506,9 +506,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         self.quant_config: Optional[QuantizationConfig] = vllm_config.quant_config
 
         self.pp_group = get_pp_group()
-        self.pp_size = self.pp_group.world_size
-        self.pp_rank = self.pp_group.rank_in_group
-        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_group = get_tp_group()
 
         # Weights to skip in `self.load_weights`
         self.skip_prefixes: list[str] = []
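The hunk above replaces the cached pp_size / pp_rank / tp_size integers with the group coordinators themselves; the later hunks read world_size, rank_in_group, is_first_rank and is_last_rank directly off self.pp_group / self.tp_group. A small sketch of that access pattern, assuming vLLM's distributed environment has already been initialised:

# Access pattern sketch: the group coordinators expose everything the old
# cached ints provided. Only attributes used in this diff are shown.
from vllm.distributed import get_pp_group, get_tp_group

def describe_parallel_layout() -> str:
    pp_group = get_pp_group()
    tp_group = get_tp_group()
    return (
        f"pp rank {pp_group.rank_in_group}/{pp_group.world_size} "
        f"(first={pp_group.is_first_rank}, last={pp_group.is_last_rank}), "
        f"tp world size {tp_group.world_size}"
    )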
@@ -576,7 +574,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         """
         Apply the model's pipeline parallelization plan.
         """
-        if self.pp_size <= 1:
+        if self.pp_group.world_size <= 1:
             return
 
         if not self.model.supports_pp_plan:
@@ -613,7 +611,9 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
 
         # Module list
         start_layer, end_layer = get_pp_indices(
-            self.text_config.num_hidden_layers, self.pp_rank, self.pp_size
+            self.text_config.num_hidden_layers,
+            self.pp_group.rank_in_group,
+            self.pp_group.world_size,
         )
         layers_name = pp_plan[module_list_idx]
         layers = getattr(self.model, layers_name)
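get_pp_indices(num_layers, rank, world_size) returns the half-open [start, end) slice of decoder layers owned by a pipeline rank. The standalone sketch below reproduces a plain even split to show the shape of the result; it is an approximation and ignores any partition overrides the real helper supports.

# Rough standalone sketch of an even layer split; not vLLM's exact helper.
def even_pp_indices(num_layers: int, rank: int, world_size: int) -> tuple[int, int]:
    layers_per_rank = num_layers // world_size
    remainder = num_layers % world_size
    # In this sketch, earlier ranks absorb any remainder one layer at a time.
    start = rank * layers_per_rank + min(rank, remainder)
    end = start + layers_per_rank + (1 if rank < remainder else 0)
    return start, end

# Example: 32 layers over 4 pipeline ranks -> (0, 8), (8, 16), (16, 24), (24, 32)
print([even_pp_indices(32, r, 4) for r in range(4)])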
@@ -638,7 +638,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         """
         tp_plan = self.model.tp_plan
 
-        if not tp_plan and self.tp_size > 1:
+        if not tp_plan and self.tp_group.world_size > 1:
             tip = get_feature_request_tip(
                 self.model_config.model, self.model_config.trust_remote_code
             )
@@ -687,7 +687,9 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         head_size = self.model_config.get_head_size()
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
         start, end = get_pp_indices(
-            self.text_config.num_hidden_layers, self.pp_rank, self.pp_size
+            self.text_config.num_hidden_layers,
+            self.pp_group.rank_in_group,
+            self.pp_group.world_size,
         )
 
         attention_instances = {}
@@ -749,7 +751,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        if not get_pp_group().is_first_rank:
+        if not self.pp_group.is_first_rank:
             assert intermediate_tensors is not None
             input_ids = None
             inputs_embeds = intermediate_tensors["hidden_states"]
@@ -773,7 +775,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
             return_dict=False,
         )[0][0, ...]  # we remove batch dimension for now
 
-        if not get_pp_group().is_last_rank:
+        if not self.pp_group.is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
 
         return hidden_states
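The forward path above is the pipeline hand-off: every rank except the first consumes intermediate_tensors["hidden_states"] in place of input embeddings, and every rank except the last returns an IntermediateTensors instead of final hidden states. The toy sketch below mirrors that flow with stand-in classes; it is illustrative only and does not use vLLM's IntermediateTensors.

# Toy stage hand-off; ToyIntermediateTensors and the fake stages are invented.
import torch

class ToyIntermediateTensors(dict):
    """Dict-like container keyed by tensor name, mirroring the usage above."""

def stage_forward(is_first: bool, is_last: bool, inputs, layers):
    hidden = inputs["embeds"] if is_first else inputs["hidden_states"]
    for layer in layers:
        hidden = layer(hidden)
    if not is_last:
        # Non-last ranks ship hidden states to the next pipeline stage.
        return ToyIntermediateTensors({"hidden_states": hidden})
    return hidden

embeds = torch.randn(4, 8)
mid = stage_forward(is_first=True, is_last=False,
                    inputs={"embeds": embeds}, layers=[torch.nn.Linear(8, 8)])
out = stage_forward(is_first=False, is_last=True,
                    inputs=mid, layers=[torch.nn.Linear(8, 8)])
print(type(mid).__name__, out.shape)  # ToyIntermediateTensors torch.Size([4, 8])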
@@ -811,7 +813,7 @@ class TransformersForCausalLM(TransformersBase):
         if self.text_config.tie_word_embeddings:
             self.skip_prefixes.append("lm_head.")
 
-        if get_pp_group().is_last_rank:
+        if self.pp_group.is_last_rank:
             self.unpadded_vocab_size = self.text_config.vocab_size
             self.lm_head = ParallelLMHead(
                 self.text_config.vocab_size,
@@ -30,6 +30,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
+from .interfaces import MixtureOfExperts
 from .transformers import (
     TransformersBase,
     TransformersForCausalLM,
@@ -116,17 +117,41 @@ direct_register_custom_op(
 )
 
 
-class TransformersMoEBase(TransformersBase):
+class TransformersMoEBase(TransformersBase, MixtureOfExperts):
     def __init__(self, *, vllm_config, prefix=""):
         self.check_version("4.57.0.dev0", "MoE models support")
+        self.ep_group = get_ep_group()
         super().__init__(vllm_config=vllm_config, prefix=prefix)
 
-        if self.parallel_config.enable_eplb:
-            raise NotImplementedError(
-                "Transformers backend does not support expert parallel load "
-                "balancing yet."
-            )
+    def set_eplb_state(
+        self,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ):
+        for moe_layer_idx, mlp_layer in enumerate(self.mlp_layers):
+            mlp_layer.experts.set_eplb_state(
+                moe_layer_idx=moe_layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
+
+    def update_physical_experts_metadata(
+        self,
+        num_physical_experts: int,
+        num_local_physical_experts: int,
+    ):
+        assert self.num_local_physical_experts == num_local_physical_experts
+        self.num_physical_experts = num_physical_experts
+        self.num_local_physical_experts = num_local_physical_experts
+        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
+        for mlp in self.mlp_layers:
+            mlp.n_local_physical_experts = num_local_physical_experts
+            mlp.n_physical_experts = num_physical_experts
+            mlp.n_redundant_experts = self.num_redundant_experts
+            mlp.experts.update_expert_map()
 
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         """
         Params for weights, fp8 weight scales, fp8 activation scales
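set_eplb_state fans the shared EPLB tensors out to every fused MoE layer, giving each layer a view keyed by its moe_layer_idx. The toy sketch below shows that registration pattern with invented ToyExperts/ToyMLP classes and arbitrary tensor shapes; none of it is vLLM API.

# Toy registration pattern with invented classes; tensor shapes are chosen
# for this sketch only and are not vLLM's real EPLB layouts.
import torch

class ToyExperts:
    def set_eplb_state(self, moe_layer_idx, expert_load_view,
                       logical_to_physical_map, logical_replica_count):
        # Keep a per-layer view into the shared load-statistics tensor.
        self.load_view = expert_load_view[moe_layer_idx]

class ToyMLP:
    def __init__(self):
        self.experts = ToyExperts()

mlp_layers = [ToyMLP(), ToyMLP()]
num_layers, num_experts = len(mlp_layers), 8
expert_load_view = torch.zeros(num_layers, num_experts)
for idx, mlp in enumerate(mlp_layers):
    mlp.experts.set_eplb_state(
        moe_layer_idx=idx,
        expert_load_view=expert_load_view,
        logical_to_physical_map=torch.zeros(num_layers, num_experts, dtype=torch.long),
        logical_replica_count=torch.ones(num_layers, num_experts, dtype=torch.long),
    )
print(mlp_layers[0].experts.load_view.shape)  # torch.Size([8])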
@@ -138,6 +163,8 @@ class TransformersMoEBase(TransformersBase):
             ("w1", "w2", "w3"),  # Granite, Mixtral, Phi MoE style
             ("linear", "linear_1", "linear_v"),  # Grok1 style
         ]
+        num_experts = self.model_config.get_num_experts()
+        num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
         expert_mapping = []
         for gate_proj, down_proj, up_proj in ckpt_names:
             expert_mapping.extend(
@@ -145,8 +172,8 @@ class TransformersMoEBase(TransformersBase):
                     ckpt_gate_proj_name=gate_proj,
                     ckpt_down_proj_name=down_proj,
                     ckpt_up_proj_name=up_proj,
-                    num_experts=self.model_config.get_num_experts(),
-                    num_redundant_experts=0,  # TODO: enable EPLB
+                    num_experts=num_experts,
+                    num_redundant_experts=num_redundant_experts,
                 )
             )
         return expert_mapping
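Per the return annotation, get_expert_mapping yields (param_name, weight_name, expert_id, shard_id) tuples that tell the weight loader where each per-expert checkpoint tensor lands in the fused parameters. The hand-rolled sketch below builds a mapping of that shape; the "experts.w13_"/"experts.w2_" fused names and the shard ids are assumptions for illustration, not the verified output of FusedMoE.make_expert_params_mapping.

# Hand-rolled sketch of an expert weight mapping in the same spirit as
# FusedMoE.make_expert_params_mapping; fused names and shard ids are assumed.
def sketch_expert_mapping(gate: str, down: str, up: str,
                          num_experts: int) -> list[tuple[str, str, int, str]]:
    mapping = []
    for expert_id in range(num_experts):
        for shard_id, ckpt_name in (("w1", gate), ("w3", up), ("w2", down)):
            # Gate/up shards land in one fused parameter, down in another.
            fused = "experts.w13_" if shard_id in ("w1", "w3") else "experts.w2_"
            mapping.append(
                (fused, f"experts.{expert_id}.{ckpt_name}.", expert_id, shard_id)
            )
    return mapping

for row in sketch_expert_mapping("gate_proj", "down_proj", "up_proj", num_experts=2)[:4]:
    print(row)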
@@ -167,12 +194,15 @@ class TransformersMoEBase(TransformersBase):
 
         # If there are shared experts, the results are
         # reduced after mlp.forward() not inside FusedMoE
-        num_experts_shared = getattr_iter(
+        num_shared_experts = getattr_iter(
             text_config,
-            ["num_experts_shared", "n_shared_experts", "moe_num_shared_experts"],
+            [
+                "n_shared_experts",  # DeepSeek, Docs, GLM
+                "moe_num_shared_experts",  # Aria, Ernie
+            ],
             0,
         )
-        reduce_results = num_experts_shared == 0
+        reduce_results = num_shared_experts == 0
 
         def add_all_reduce(mlp: nn.Module):
             """Adds an all-reduce to the output of `mlp.forward()`."""
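getattr_iter (imported from vllm.config.utils in the first hunk) probes a list of candidate attribute names and falls back to a default, which is how the shared-expert count is read from heterogeneous HF configs here. A minimal reimplementation of that behaviour, as understood from this call site, not vLLM's actual code:

# Minimal sketch: return the first attribute present on the object, else default.
from types import SimpleNamespace
from typing import Any, Iterable

def getattr_iter_sketch(obj: Any, names: Iterable[str], default: Any) -> Any:
    for name in names:
        if hasattr(obj, name):
            return getattr(obj, name)
    return default

deepseek_like = SimpleNamespace(n_shared_experts=2)
aria_like = SimpleNamespace(moe_num_shared_experts=1)
dense_like = SimpleNamespace()

names = ["n_shared_experts", "moe_num_shared_experts"]
print(getattr_iter_sketch(deepseek_like, names, 0))  # 2
print(getattr_iter_sketch(aria_like, names, 0))      # 1
print(getattr_iter_sketch(dense_like, names, 0))     # 0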
@@ -207,13 +237,23 @@ class TransformersMoEBase(TransformersBase):
         # Expert mapping for `AutoWeightsLoader`
         expert_mapping = self.get_expert_mapping()
 
+        # Configs
+        parallel_config = self.parallel_config
+        eplb_config = parallel_config.eplb_config
+
         # Expert parallel load balancing kwargs
-        enable_eplb = self.parallel_config.enable_eplb
-        num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
+        enable_eplb = parallel_config.enable_eplb
+        num_redundant_experts = eplb_config.num_redundant_experts
+
+        # MixtureOfExperts mixin settings
+        ep_size = self.ep_group.world_size
+
+        self.mlp_layers = []  # Used for MixtureOfExperts methods
+        self.expert_weights = []
+        self.num_moe_layers = 0
+        self.num_expert_groups = 1 if num_expert_group is None else num_expert_group
+        self.num_logical_experts = num_experts
+        self.num_physical_experts = num_experts + num_redundant_experts
+        self.num_local_physical_experts = self.num_physical_experts // ep_size
+        self.num_routed_experts = num_experts
+        self.num_shared_experts = num_shared_experts
+        self.num_redundant_experts = num_redundant_experts
 
         # Recursively fuse MoE layers
         def _recursive_replace(module: nn.Module, prefix: str):
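The mixin settings above encode the EPLB accounting: physical experts are the logical (routed) experts plus the redundant replicas, split evenly across the expert-parallel group. Worked numbers, with illustrative values:

# Worked numbers for the accounting set up above. The concrete values are
# illustrative; only the relationships mirror the assignments in this hunk.
num_experts = 60               # routed/logical experts from the checkpoint
num_redundant_experts = 4      # extra replicas requested via eplb_config
ep_size = 8                    # expert-parallel world size

num_physical_experts = num_experts + num_redundant_experts      # 64
num_local_physical_experts = num_physical_experts // ep_size    # 8 per EP rank

assert num_local_physical_experts * ep_size == num_physical_experts
print(num_physical_experts, num_local_physical_experts)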
@@ -235,6 +275,9 @@ class TransformersMoEBase(TransformersBase):
                     for mlp_param_name, _ in mlp.named_parameters():
                         if "shared_expert" in mlp_param_name:
                             reduce_results = False
+                            # If the config does not specify num_shared_experts, but
+                            # the model has shared experts, we assume there is one.
+                            self.num_shared_experts = 1
                             break
                     # Replace experts module with FusedMoE
                     fused_experts = TransformersFusedMoE(
@@ -258,6 +301,10 @@ class TransformersMoEBase(TransformersBase):
                     )
                     mlp.experts = fused_experts
                     log_replacement(qual_name, experts, fused_experts)
+                    # Update MixtureOfExperts mixin state
+                    self.mlp_layers.append(mlp)
+                    self.expert_weights.append(fused_experts.get_expert_weights())
+                    self.num_moe_layers += 1
                     # If results are not all-reduced in FusedMoE, ensure they
                     # are all-reduced at the end of mlp.forward() if tensor
                     # parallel or expert parallel is enabled