[Model] Qwen2.5 VL SiLU-and-Mul (#22066)
Signed-off-by: kf <kuanfu.liu@embeddedllm.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: kf <kuanfu.liu@embeddedllm.com>
@@ -43,9 +43,10 @@ from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
-from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
+from vllm.model_executor.layers.activation import get_act_and_mul_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
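
Note: _ACTIVATION_REGISTRY[hidden_act] looks up the plain activation (e.g. SiLU), while get_act_and_mul_fn(hidden_act) is expected to return a fused "activation-and-multiply" op (vllm's SiluAndMul for "silu"). A minimal, self-contained sketch of that contract in plain PyTorch (illustrative only, not vllm's implementation):

import torch
import torch.nn.functional as F


def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # The input packs [gate, up] along the last dimension; the output keeps
    # only half that width: silu(gate) * up.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]
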
@@ -171,16 +172,12 @@ class Qwen2_5_VisionMLP(nn.Module):
                  quant_config: Optional[QuantizationConfig] = None,
                  prefix: str = ""):
         super().__init__()
-        self.gate_proj = ColumnParallelLinear(in_features,
-                                              hidden_features,
-                                              bias=bias,
-                                              quant_config=quant_config,
-                                              prefix=f"{prefix}.gate_proj")
-        self.up_proj = ColumnParallelLinear(in_features,
-                                            hidden_features,
-                                            bias=bias,
-                                            quant_config=quant_config,
-                                            prefix=f"{prefix}.up_proj")
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=in_features,
+            output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj")
         self.down_proj = RowParallelLinear(hidden_features,
                                            in_features,
                                            bias=bias,
@@ -189,10 +186,9 @@ class Qwen2_5_VisionMLP(nn.Module):
         self.act_fn = act_fn
 
     def forward(self, x: torch.Tensor):
-        x_gate, _ = self.gate_proj(x)
-        x_gate = self.act_fn(x_gate)
-        x_up, _ = self.up_proj(x)
-        x_down, _ = self.down_proj(x_gate * x_up)
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x_down, _ = self.down_proj(x)
         return x_down
 
 
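
For reference, a small standalone check that the merged path matches the old two-projection path when the gate_up_proj output concatenates [gate, up]. This sketch uses plain nn.Linear stand-ins rather than vllm's MergedColumnParallelLinear / RowParallelLinear:

import torch
import torch.nn as nn
import torch.nn.functional as F

in_f, hid = 16, 32
gate = nn.Linear(in_f, hid, bias=False)
up = nn.Linear(in_f, hid, bias=False)
down = nn.Linear(hid, in_f, bias=False)

# Merged projection: weights stacked as [gate; up], like gate_up_proj.
gate_up = nn.Linear(in_f, 2 * hid, bias=False)
gate_up.weight.data = torch.cat([gate.weight.data, up.weight.data], dim=0)

x = torch.randn(4, in_f)
old = down(F.silu(gate(x)) * up(x))                   # previous forward
g_u = gate_up(x)
new = down(F.silu(g_u[..., :hid]) * g_u[..., hid:])   # fused SiLU-and-mul
assert torch.allclose(old, new, atol=1e-5)
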
@@ -540,14 +536,14 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
 
         self.blocks = nn.ModuleList([
-            Qwen2_5_VisionBlock(
-                dim=self.hidden_size,
-                num_heads=self.num_heads,
-                mlp_hidden_dim=vision_config.intermediate_size,
-                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
-                norm_layer=norm_layer,
-                quant_config=quant_config,
-                prefix=f"{prefix}.blocks.{layer_idx}")
+            Qwen2_5_VisionBlock(dim=self.hidden_size,
+                                num_heads=self.num_heads,
+                                mlp_hidden_dim=vision_config.intermediate_size,
+                                act_fn=get_act_and_mul_fn(
+                                    vision_config.hidden_act),
+                                norm_layer=norm_layer,
+                                quant_config=quant_config,
+                                prefix=f"{prefix}.blocks.{layer_idx}")
             for layer_idx in range(depth)
         ])
         self.merger = Qwen2_5_VisionPatchMerger(
@@ -752,6 +748,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
             ("attn.qkv.", "attn.q.", "q"),
             ("attn.qkv.", "attn.k.", "k"),
             ("attn.qkv.", "attn.v.", "v"),
+            ("mlp.gate_up_proj.", "mlp.gate_proj.", 0),
+            ("mlp.gate_up_proj.", "mlp.up_proj.", 1),
         ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         loaded_params: set[str] = set()
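
The two new stacked_params_mapping entries let load_weights route the checkpoint's separate mlp.gate_proj / mlp.up_proj tensors into shard 0 and shard 1 of the merged mlp.gate_up_proj parameter. A rough, self-contained illustration of that bookkeeping (a bare tensor stands in for vllm's weight-loader machinery; the names are for this sketch only):

import torch

hidden_features, in_features = 32, 16
gate_up_weight = torch.empty(2 * hidden_features, in_features)  # model param

checkpoint = {
    "mlp.gate_proj.weight": torch.randn(hidden_features, in_features),
    "mlp.up_proj.weight": torch.randn(hidden_features, in_features),
}
stacked_params_mapping = [
    ("mlp.gate_up_proj.", "mlp.gate_proj.", 0),
    ("mlp.gate_up_proj.", "mlp.up_proj.", 1),
]

for name, weight in checkpoint.items():
    for param_name, ckpt_name, shard_id in stacked_params_mapping:
        if ckpt_name not in name:
            continue
        merged_name = name.replace(ckpt_name, param_name)
        # Copy this checkpoint tensor into its slice of the merged weight.
        start = shard_id * hidden_features
        gate_up_weight[start:start + hidden_features].copy_(weight)
        print(f"loaded {name} into {merged_name} (shard {shard_id})")
        break
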