diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index 38c9d5abb5..68915d60ef 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -325,7 +325,7 @@ class SupportsLoRA(Protocol):
     # are empty by default.
     embedding_modules: ClassVar[dict[str, str]] = {}
     embedding_padding_modules: ClassVar[list[str]] = []
-    packed_modules_mapping: ClassVar[dict[str, list[str]]] = {}
+    packed_modules_mapping: dict[str, list[str]] = {}
 
 
 # We can't use runtime_checkable with ClassVar for issubclass checks
diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py
index c57299a2d3..7251e7b2ee 100644
--- a/vllm/model_executor/models/qwen2_moe.py
+++ b/vllm/model_executor/models/qwen2_moe.py
@@ -534,11 +534,7 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
             "q_proj",
             "k_proj",
             "v_proj",
-        ],
-        "gate_up_proj": [
-            "gate_proj",
-            "up_proj",
-        ],
+        ]
     }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -547,6 +543,20 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
         quant_config = vllm_config.quant_config
         self.config = config
         self.quant_config = quant_config
+        # Only perform the following mapping when Qwen2MoeMLP exists
+        if (
+            getattr(config, "mlp_only_layers", [])
+            or config.shared_expert_intermediate_size > 0
+        ):
+            # Shallow-copy first: packed_modules_mapping is a class-level
+            # dict, and mutating it in place would leak the extra entry
+            # into every other instance of this class.
+            self.packed_modules_mapping = dict(self.packed_modules_mapping)
+            self.packed_modules_mapping["gate_up_proj"] = [
+                "gate_proj",
+                "up_proj",
+            ]
+
         self.model = Qwen2MoeModel(
             vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
         )
diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index 825272535a..0769378933 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -634,11 +634,7 @@ class Qwen3MoeForCausalLM(
             "q_proj",
             "k_proj",
             "v_proj",
-        ],
-        "gate_up_proj": [
-            "gate_proj",
-            "up_proj",
-        ],
+        ]
     }
 
     fall_back_to_pt_during_load = False
@@ -649,6 +645,16 @@ class Qwen3MoeForCausalLM(
         quant_config = vllm_config.quant_config
         self.config = config
         self.quant_config = quant_config
+        # Only perform the following mapping when Qwen3MoeMLP exists
+        if getattr(config, "mlp_only_layers", []):
+            # Shallow-copy first: packed_modules_mapping is a class-level
+            # dict, and mutating it in place would leak the extra entry
+            # into every other instance of this class.
+            self.packed_modules_mapping = dict(self.packed_modules_mapping)
+            self.packed_modules_mapping["gate_up_proj"] = [
+                "gate_proj",
+                "up_proj",
+            ]
         self.model = Qwen3MoeModel(
             vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
         )