[Model] Enable DP for ViT in Qwen2-VL (#25445)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-09-23 13:17:10 +08:00
Committed by: GitHub
Parent: 5774b0a1da
Commit: c98be0a232

vllm/model_executor/models/qwen2_vl.py
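The diff below threads a use_data_parallel flag through the Qwen2 vision tower: when the multimodal encoder is configured for data parallelism, the ViT's linear layers are built unsharded (weights replicated on every rank) and whole images are distributed across the ranks instead of sharding each layer. A standalone usage sketch, not part of the diff, assuming the mm_encoder_tp_mode engine argument that the model reads back as multimodal_config.mm_encoder_tp_mode; model name and parallel size are illustrative:

from vllm import LLM

# With "data", the Qwen2-VL ViT keeps full weights on each TP rank and the
# image/video items are sharded across those ranks; the language model still
# runs tensor-parallel as usual.
llm = LLM(
    model="Qwen/Qwen2-VL-7B-Instruct",
    tensor_parallel_size=2,
    mm_encoder_tp_mode="data",
)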

@@ -66,6 +66,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
from vllm.platforms import _Backend, current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
@@ -217,17 +218,20 @@ class Qwen2VisionMLP(nn.Module):
act_layer: type[nn.Module] = QuickGELU,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
+use_data_parallel: bool = False,
):
super().__init__()
self.fc1 = ColumnParallelLinear(in_features,
hidden_features,
quant_config=quant_config,
prefix=f"{prefix}.fc1")
prefix=f"{prefix}.fc1",
disable_tp=use_data_parallel)
self.act = act_layer()
self.fc2 = RowParallelLinear(hidden_features,
in_features,
quant_config=quant_config,
prefix=f"{prefix}.fc2")
prefix=f"{prefix}.fc2",
disable_tp=use_data_parallel)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_parallel, _ = self.fc1(x)
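The disable_tp=use_data_parallel arguments above are what switch the MLP from tensor-parallel to replicated execution. A toy sketch of that distinction, using a made-up ToyColumnLinear rather than vLLM's ColumnParallelLinear, to show what disable_tp changes about the weight shard a rank holds:

import torch
import torch.nn as nn

class ToyColumnLinear(nn.Module):
    """Stand-in for a column-parallel layer: each rank owns a slice of the
    output dimension unless disable_tp keeps the full layer."""

    def __init__(self, in_features: int, out_features: int,
                 tp_size: int = 1, disable_tp: bool = False):
        super().__init__()
        shard = out_features if disable_tp else out_features // tp_size
        self.linear = nn.Linear(in_features, shard)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

x = torch.randn(4, 1280)
# TP=2: this rank computes only half of the hidden features.
print(ToyColumnLinear(1280, 5120, tp_size=2)(x).shape)                   # (4, 2560)
# DP mode (disable_tp=True): the rank holds the whole layer.
print(ToyColumnLinear(1280, 5120, tp_size=2, disable_tp=True)(x).shape)  # (4, 5120)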
@@ -293,25 +297,28 @@ class Qwen2VisionAttention(nn.Module):
projection_size: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
+use_data_parallel: bool = False,
) -> None:
super().__init__()
# Per attention head and per partition values.
-world_size = parallel_state.get_tensor_model_parallel_world_size()
-self.tp_size = world_size
+self.tp_size = (1 if use_data_parallel else
+parallel_state.get_tensor_model_parallel_world_size())
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
self.hidden_size_per_attention_head = dist_utils.divide(
projection_size, num_heads)
self.num_attention_heads_per_partition = dist_utils.divide(
-num_heads, world_size)
+num_heads, self.tp_size)
self.qkv = ColumnParallelLinear(input_size=embed_dim,
output_size=3 * projection_size,
quant_config=quant_config,
prefix=f"{prefix}.qkv")
prefix=f"{prefix}.qkv",
disable_tp=use_data_parallel)
self.proj = RowParallelLinear(input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj")
prefix=f"{prefix}.proj",
disable_tp=use_data_parallel)
# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
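With data parallelism, self.tp_size is forced to 1 above, so dist_utils.divide(num_heads, self.tp_size) leaves every attention head on each rank; under tensor parallelism the heads are still split. A quick standalone check of that arithmetic (16 heads is Qwen2-VL's vision encoder; divide mirrors the exact-division contract of dist_utils.divide):

def divide(numerator: int, denominator: int) -> int:
    # Exact division, as the parallel utilities require.
    assert numerator % denominator == 0, "heads must split evenly"
    return numerator // denominator

num_heads = 16
print(divide(num_heads, 4))  # TP=4 -> 4 heads per partition
print(divide(num_heads, 1))  # DP path (tp_size forced to 1) -> all 16 heads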
@@ -453,6 +460,7 @@ class Qwen2VisionBlock(nn.Module):
norm_layer: Optional[Callable[[int], nn.Module]] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
+use_data_parallel: bool = False,
) -> None:
super().__init__()
if norm_layer is None:
@@ -465,12 +473,14 @@
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn")
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel)
self.mlp = Qwen2VisionMLP(dim,
mlp_hidden_dim,
act_layer=act_layer,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel)
def forward(
self,
@@ -531,6 +541,7 @@ class Qwen2VisionPatchMerger(nn.Module):
spatial_merge_size: int = 2,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
+use_data_parallel: bool = False,
) -> None:
super().__init__()
self.hidden_size = context_dim * (spatial_merge_size**2)
@@ -542,13 +553,15 @@
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp.0"),
prefix=f"{prefix}.mlp.0",
disable_tp=use_data_parallel),
nn.GELU(),
RowParallelLinear(self.hidden_size,
d_model,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp.2"),
prefix=f"{prefix}.mlp.2",
disable_tp=use_data_parallel),
])
def forward(self, x: torch.Tensor) -> torch.Tensor:
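For reference, the merger's first projection is sized from the 2x2 spatial merge, and under data parallelism both of its linear layers are likewise built unsharded. Illustrative numbers (assumed typical Qwen2-VL values, not taken from the diff):

context_dim = 1280          # ViT embed dim fed into the merger
spatial_merge_size = 2      # 2x2 patch blocks folded into one token
hidden_size = context_dim * (spatial_merge_size ** 2)
print(hidden_size)          # 5120 input features for mlp.0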
@@ -600,6 +613,7 @@ class Qwen2VisionTransformer(nn.Module):
norm_eps: float = 1e-6,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
+use_data_parallel: bool = False,
) -> None:
super().__init__()
@@ -613,6 +627,9 @@
num_heads = vision_config.num_heads
mlp_ratio = vision_config.mlp_ratio
+self.use_data_parallel = use_data_parallel
+self.out_hidden_size = vision_config.hidden_size
self.spatial_merge_size = spatial_merge_size
self.num_heads = num_heads
self.embed_dim = embed_dim
@@ -634,7 +651,8 @@
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}")
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel)
for layer_idx in range(depth)
])
self.merger = Qwen2VisionPatchMerger(
@@ -643,6 +661,7 @@
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.merger",
+use_data_parallel=use_data_parallel,
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype())
@@ -659,8 +678,9 @@
def device(self) -> torch.device:
return self.patch_embed.proj.weight.device
-def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
pos_ids = []
+max_grid_size = 0
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
@@ -678,8 +698,8 @@
).permute(0, 2, 1, 3).flatten()
pos_ids.append(
torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+max_grid_size = max(max_grid_size, h, w)
pos_ids = torch.cat(pos_ids, dim=0)
-max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
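Because grid_thw is now a plain list[list[int]], the maximum spatial extent is accumulated inside the loop instead of being read off a tensor slice. A self-contained check that the two forms agree, with made-up grid sizes:

import torch

grid_thw = [[1, 32, 48], [4, 16, 24]]  # one [t, h, w] triple per item

max_grid_size = 0
for t, h, w in grid_thw:
    max_grid_size = max(max_grid_size, h, w)

# Matches the removed tensor form: grid_thw[:, 1:].max()
assert max_grid_size == torch.tensor(grid_thw)[:, 1:].max().item() == 48
print(max_grid_size)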
@@ -698,7 +718,7 @@
def forward(
self,
x: torch.Tensor,
-grid_thw: torch.Tensor,
+grid_thw: list[list[int]],
) -> torch.Tensor:
# patchify
x = x.to(device=self.device, dtype=self.dtype)
@@ -708,8 +728,9 @@
rotary_pos_emb = self.rot_pos_emb(grid_thw)
# compute cu_seqlens
-cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
-grid_thw[:, 0]).cumsum(
+grid_thw_ = torch.tensor(grid_thw)
+cu_seqlens = torch.repeat_interleave(grid_thw_[:, 1] * grid_thw_[:, 2],
+grid_thw_[:, 0]).cumsum(
dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
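The cu_seqlens computation gains a torch.tensor(grid_thw) conversion for the same reason; the result is unchanged. Worked through with made-up grids: each frame contributes an attention sequence of h*w patch tokens, and the cumulative sum marks the sequence boundaries used by the varlen attention kernels:

import torch
import torch.nn.functional as F

grid_thw = [[1, 32, 48], [4, 16, 24]]  # illustrative [t, h, w] values

grid_thw_ = torch.tensor(grid_thw)
cu_seqlens = torch.repeat_interleave(grid_thw_[:, 1] * grid_thw_[:, 2],
                                     grid_thw_[:, 0]).cumsum(
                                         dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
print(cu_seqlens.tolist())  # [0, 1536, 1920, 2304, 2688, 3072]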
@@ -1112,6 +1133,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
"model.": "language_model.model.",
})
+supports_encoder_tp_data = True
def get_mrope_input_positions(
self,
input_tokens: list[int],
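The new supports_encoder_tp_data = True class attribute advertises that this model can run its multimodal encoder in TP-data mode. A minimal sketch of how such a capability flag could be consulted; the gating logic here is an assumption, only the attribute name comes from the diff:

def resolve_encoder_tp_mode(model_cls: type, requested: str) -> str:
    # Honor mm_encoder_tp_mode="data" only if the model class opts in.
    if requested == "data" and not getattr(
            model_cls, "supports_encoder_tp_data", False):
        raise ValueError(f"{model_cls.__name__} does not support a "
                         "data-parallel multimodal encoder")
    return requested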
@@ -1239,6 +1262,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
+self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.config = config
self.multimodal_config = multimodal_config
@@ -1249,6 +1273,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "visual"),
+use_data_parallel=self.use_data_parallel,
)
else:
self.visual = None
@@ -1357,7 +1382,15 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
image_embeds = image_input["image_embeds"]
else:
pixel_values = image_input["pixel_values"]
-image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+if self.use_data_parallel:
+    return run_dp_sharded_mrope_vision_model(self.visual,
+                                             pixel_values,
+                                             grid_thw_list,
+                                             rope_type="rope_3d")
+else:
+    image_embeds = self.visual(pixel_values,
+                               grid_thw=grid_thw_list)
# Split concatenated embeddings for each image item.
merge_size = self.visual.spatial_merge_size
@@ -1377,7 +1410,14 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
video_embeds = video_input["video_embeds"]
else:
pixel_values_videos = video_input["pixel_values_videos"]
-video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
+if self.use_data_parallel:
+    return run_dp_sharded_mrope_vision_model(self.visual,
+                                             pixel_values_videos,
+                                             grid_thw_list,
+                                             rope_type="rope_3d")
+else:
+    video_embeds = self.visual(pixel_values_videos,
+                               grid_thw=grid_thw_list)
# Split concatenated embeddings for each video item.
merge_size = self.visual.spatial_merge_size
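In both branches the data-parallel path hands the unsharded vision tower and the per-item grid_thw_list to run_dp_sharded_mrope_vision_model, which splits the image/video items across the TP ranks, runs the ViT locally on each subset, and gathers the embeddings back. The real helper lives in vllm.multimodal.utils; the sketch below only illustrates the load-balancing idea with a hypothetical assign_items_to_ranks and no distributed communication:

def assign_items_to_ranks(grid_thw_list: list[list[int]],
                          num_ranks: int) -> list[list[int]]:
    """Greedily place each image/video (a [t, h, w] triple) on the
    least-loaded rank, measuring load in patch tokens (t * h * w)."""
    loads = [0] * num_ranks
    assignment: list[list[int]] = [[] for _ in range(num_ranks)]
    order = sorted(range(len(grid_thw_list)),
                   key=lambda i: -(grid_thw_list[i][0] *
                                   grid_thw_list[i][1] *
                                   grid_thw_list[i][2]))
    for idx in order:
        t, h, w = grid_thw_list[idx]
        rank = loads.index(min(loads))
        assignment[rank].append(idx)
        loads[rank] += t * h * w
    return assignment

# Two large items and one small one, split across 2 ranks.
print(assign_items_to_ranks([[1, 32, 48], [4, 16, 24], [1, 16, 16]], 2))
# -> [[0, 2], [1]]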