mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[EPLB] Support ernie4.5-moe (#22100)
Signed-off-by: Haisheng Chen <langzs335@outlook.com> Signed-off-by: Haisheng Chen <60504847+HsChen-sys@users.noreply.github.com> Signed-off-by: Haisheng Chen <hac048@ucsd.edu> Co-authored-by: Haisheng Chen <langzs335@outlook.com>
This commit is contained in:
@ -33,8 +33,12 @@ from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import (
|
||||
get_ep_group,
|
||||
get_pp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
|
||||
@ -58,7 +62,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
PPMissingLayer,
|
||||
@ -118,12 +122,34 @@ class Ernie4_5_MoeMoE(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
enable_eplb: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
self.layer_idx = layer_idx
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
self.moe_num_shared_experts = getattr(config, "moe_num_shared_experts", None)
|
||||
self.ep_group = get_ep_group().device_group
|
||||
self.ep_rank = self.ep_group.rank()
|
||||
self.ep_size = self.ep_group.size()
|
||||
self.n_routed_experts: int = config.moe_num_experts
|
||||
self.n_shared_experts: int = self.moe_num_shared_experts
|
||||
|
||||
# Load balancing settings.
|
||||
vllm_config = get_current_vllm_config()
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self.enable_eplb = enable_eplb
|
||||
|
||||
self.n_redundant_experts = parallel_config.num_redundant_experts
|
||||
self.n_logical_experts = self.n_routed_experts
|
||||
self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
|
||||
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
|
||||
self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
|
||||
self.physical_expert_end = (
|
||||
self.physical_expert_start + self.n_local_physical_experts
|
||||
)
|
||||
self.has_shared_experts = getattr(config, "moe_num_shared_experts", 0) > 0
|
||||
|
||||
if self.tp_size > config.moe_num_experts:
|
||||
@ -171,6 +197,8 @@ class Ernie4_5_MoeMoE(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts",
|
||||
e_score_correction_bias=self.gate.e_score_correction_bias,
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.n_redundant_experts,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@ -298,6 +326,7 @@ class Ernie4_5_MoeDecoderLayer(nn.Module):
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
enable_eplb: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -338,7 +367,10 @@ class Ernie4_5_MoeDecoderLayer(nn.Module):
|
||||
and layer_idx <= moe_layer_end_index
|
||||
):
|
||||
self.mlp = Ernie4_5_MoeMoE(
|
||||
config=config, quant_config=quant_config, prefix=f"{prefix}.mlp"
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
enable_eplb=enable_eplb,
|
||||
)
|
||||
else:
|
||||
self.mlp = Ernie4_5_MoeMLP(
|
||||
@ -393,6 +425,9 @@ class Ernie4_5_MoeModel(nn.Module):
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.config = config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
enable_eplb = parallel_config.enable_eplb
|
||||
self.num_redundant_experts = parallel_config.num_redundant_experts
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
@ -411,6 +446,7 @@ class Ernie4_5_MoeModel(nn.Module):
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
enable_eplb=enable_eplb,
|
||||
),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
@ -465,6 +501,7 @@ class Ernie4_5_MoeModel(nn.Module):
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.moe_num_experts,
|
||||
num_redundant_experts=self.num_redundant_experts,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
@ -513,15 +550,22 @@ class Ernie4_5_MoeModel(nn.Module):
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
is_expert_weight = False
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Anyway, this is an expert weight and should not be
|
||||
# attempted to load as other weights later
|
||||
is_expert_weight = True
|
||||
|
||||
# Do not modify `name` since the loop may continue here
|
||||
# Instead, create a new variable
|
||||
name_mapped = name.replace(weight_name, param_name)
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
if is_pp_missing_parameter(name_mapped, self):
|
||||
continue
|
||||
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
@ -541,6 +585,12 @@ class Ernie4_5_MoeModel(nn.Module):
|
||||
)
|
||||
break
|
||||
else:
|
||||
if is_expert_weight:
|
||||
# We've checked that this is an expert weight
|
||||
# However it's not mapped locally to this rank
|
||||
# So we simply skip it
|
||||
continue
|
||||
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if (
|
||||
name.endswith(".bias") or name.endswith("_bias")
|
||||
@ -563,7 +613,7 @@ class Ernie4_5_MoeModel(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
||||
class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExperts):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
@ -605,6 +655,81 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
self.expert_weights = []
|
||||
|
||||
# Set MoE hyperparameters
|
||||
moe_layers_indices = [
|
||||
i
|
||||
for i in range(config.num_hidden_layers)
|
||||
if (
|
||||
i >= config.moe_layer_start_index
|
||||
and i <= config.moe_layer_end_index
|
||||
and (i + 1) % config.moe_layer_interval == 0
|
||||
)
|
||||
]
|
||||
self.num_moe_layers = len(moe_layers_indices)
|
||||
self.num_expert_groups = 1
|
||||
|
||||
self.moe_layers: list[SharedFusedMoE] = []
|
||||
example_moe = None
|
||||
for layer in self.model.layers:
|
||||
if isinstance(layer, PPMissingLayer):
|
||||
continue
|
||||
|
||||
assert isinstance(layer, Ernie4_5_MoeDecoderLayer)
|
||||
if isinstance(layer.mlp, Ernie4_5_MoeMoE):
|
||||
example_moe = layer.mlp
|
||||
self.moe_layers.append(layer.mlp.experts)
|
||||
|
||||
if example_moe is None:
|
||||
logger.warning("No Ernie4_5_MoeMoE layer found in model.layers.")
|
||||
self.num_logical_experts = 0
|
||||
self.num_physical_experts = 0
|
||||
self.num_local_physical_experts = 0
|
||||
self.num_routed_experts = 0
|
||||
self.num_shared_experts = 0
|
||||
self.num_redundant_experts = 0
|
||||
else:
|
||||
self.num_logical_experts = example_moe.n_logical_experts
|
||||
self.num_physical_experts = example_moe.n_physical_experts
|
||||
self.num_local_physical_experts = example_moe.n_local_physical_experts
|
||||
self.num_routed_experts = example_moe.n_routed_experts
|
||||
self.num_shared_experts = example_moe.n_shared_experts
|
||||
self.num_redundant_experts = example_moe.n_redundant_experts
|
||||
|
||||
def set_eplb_state(
|
||||
self,
|
||||
expert_load_view: torch.Tensor,
|
||||
logical_to_physical_map: torch.Tensor,
|
||||
logical_replica_count: torch.Tensor,
|
||||
) -> None:
|
||||
for layer_idx, layer in enumerate(self.moe_layers):
|
||||
# Register the expert weights.
|
||||
self.expert_weights.append(layer.get_expert_weights())
|
||||
layer.set_eplb_state(
|
||||
moe_layer_idx=layer_idx,
|
||||
expert_load_view=expert_load_view,
|
||||
logical_to_physical_map=logical_to_physical_map,
|
||||
logical_replica_count=logical_replica_count,
|
||||
)
|
||||
|
||||
def update_physical_experts_metadata(
|
||||
self,
|
||||
num_physical_experts: int,
|
||||
num_local_physical_experts: int,
|
||||
) -> None:
|
||||
assert self.num_local_physical_experts == num_local_physical_experts
|
||||
self.num_physical_experts = num_physical_experts
|
||||
self.num_local_physical_experts = num_local_physical_experts
|
||||
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
|
||||
for layer in self.model.layers:
|
||||
if isinstance(layer.mlp, Ernie4_5_MoeMoE):
|
||||
moe = layer.mlp
|
||||
moe.n_local_physical_experts = num_local_physical_experts
|
||||
moe.n_physical_experts = num_physical_experts
|
||||
moe.n_redundant_experts = self.num_redundant_experts
|
||||
moe.experts.update_expert_map()
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
|
Reference in New Issue
Block a user