[model] fix: qwen3vl models shape mismatch error with SP (#3735)

ℍ𝕠𝕝𝕝𝕠𝕨 𝕄𝕒𝕟
2025-10-13 08:09:10 +03:00
committed by GitHub
parent 9d4554b931
commit e9ee6b39c6


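Background for the fix, illustrated by a minimal sketch (toy shapes assumed for illustration, not code from this commit): under Ulysses sequence parallelism each rank's decoder only receives a shard of the sequence, but visual_pos_masks and deepstack_visual_embeds were still passed at full length, so Qwen3-VL's merging of deepstack visual embeddings into the hidden states hit a shape mismatch.

    import torch

    # Toy shapes (assumed): batch 1, full sequence 8, hidden 4, SP world size 2.
    batch, seq, hidden, sp_size = 1, 8, 4, 2
    hidden_states = torch.zeros(batch, seq // sp_size, hidden)    # the per-rank shard the decoder sees
    visual_pos_masks = torch.zeros(batch, seq, dtype=torch.bool)  # full-length mask, not yet sliced
    visual_pos_masks[0, 2:6] = True

    # The mask still covers 8 positions while the shard has only 4, so any masked
    # indexing of hidden_states with visual_pos_masks fails; the patch below slices the
    # mask (and the matching rows of deepstack_visual_embeds) before calling the decoder.
    assert visual_pos_masks.shape[1] != hidden_states.shape[1]    # 8 vs 4: the mismatch being fixed
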
@@ -127,6 +127,8 @@ def patch_vlm_for_ulysses_input_slicing(model_class: type):
         def ulysses_wrapped_decoder_forward(self, *args, **kwargs):
             inputs_embeds = kwargs.get("inputs_embeds")
             position_ids = kwargs.get("position_ids")
+            visual_pos_masks = kwargs.get("visual_pos_masks")
+            deepstack_visual_embeds = kwargs.get("deepstack_visual_embeds")
             call_kwargs = kwargs.copy()
             current_ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
@@ -139,6 +141,43 @@ def patch_vlm_for_ulysses_input_slicing(model_class: type):
             if slice_now:
                 call_kwargs["inputs_embeds"] = slice_input_tensor(inputs_embeds, dim=1, padding=False)
                 call_kwargs["position_ids"] = slice_input_tensor(position_ids, dim=-1, padding=False)
+                # Also slice visual_pos_masks and deepstack_visual_embeds for Qwen3-VL models
+                if visual_pos_masks is not None:
+                    original_visual_mask = visual_pos_masks
+                    sliced_visual_mask = slice_input_tensor(visual_pos_masks, dim=1, padding=False)
+                    call_kwargs["visual_pos_masks"] = sliced_visual_mask
+                    if deepstack_visual_embeds is not None:
+                        sliced_embeds = []
+                        num_visual_before = original_visual_mask.sum().item()
+                        num_visual_in_shard = sliced_visual_mask.sum().item()
+                        if num_visual_in_shard > 0 and num_visual_before > 0:
+                            # Work out which visual embeddings belong to this shard:
+                            # find the offset of the shard's visual tokens within the full sequence.
+                            from verl.utils.ulysses import get_ulysses_sequence_parallel_rank
+
+                            rank = get_ulysses_sequence_parallel_rank()
+                            seq_len = original_visual_mask.shape[1]
+                            local_seq_len = seq_len // current_ulysses_sp_size
+                            start_idx = rank * local_seq_len
+                            end_idx = start_idx + local_seq_len
+                            # Count visual tokens before the shard start and up to the shard end;
+                            # summing across all samples handles batched inputs correctly.
+                            visual_start = original_visual_mask[:, :start_idx].sum().item() if start_idx > 0 else 0
+                            visual_end = original_visual_mask[:, :end_idx].sum().item()
+                            # Slice each tensor in deepstack_visual_embeds
+                            for embed in deepstack_visual_embeds:
+                                sliced_embeds.append(embed[visual_start:visual_end])
+                        else:
+                            # No visual tokens in this shard; use empty tensors to keep gradient flow intact
+                            for embed in deepstack_visual_embeds:
+                                sliced_embeds.append(embed[:0])
+                        call_kwargs["deepstack_visual_embeds"] = sliced_embeds
                 self._needs_initial_slice = False
             try:
                 return original_forward(self, *args, **call_kwargs)
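
For intuition about the offset arithmetic added above, a small self-contained sketch (toy values, and a single deepstack tensor instead of the per-layer list the model uses): the count of visual tokens before a rank's sequence slice gives the start row into the flattened visual embeddings, and the count up to the slice's end gives the stop row.

    import torch

    sp_size, rank = 2, 1                                    # inspect the second sequence shard
    visual_pos_masks = torch.tensor([[False, False, True, True, True, True, False, False]])
    deepstack_embeds = torch.arange(4 * 3, dtype=torch.float32).reshape(4, 3)  # one row per visual token

    seq_len = visual_pos_masks.shape[1]
    local_seq_len = seq_len // sp_size
    start_idx, end_idx = rank * local_seq_len, (rank + 1) * local_seq_len

    # Visual tokens before the shard start / up to the shard end, summed over the batch.
    visual_start = visual_pos_masks[:, :start_idx].sum().item()
    visual_end = visual_pos_masks[:, :end_idx].sum().item()

    local_mask = visual_pos_masks[:, start_idx:end_idx]       # what slice_input_tensor would return for this rank
    local_embeds = deepstack_embeds[visual_start:visual_end]  # rows 2 and 3 belong to rank 1

    assert local_embeds.shape[0] == local_mask.sum().item()   # per-rank counts line up
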
@@ -290,9 +329,7 @@ def apply_monkey_patch(
         from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
             Qwen2_5_VLFlashAttention2 as Qwen2_5_VLAttention,
         )
-        from transformers.models.qwen2_vl.modeling_qwen2_vl import (
-            Qwen2VLFlashAttention2 as Qwen2VLAttention,
-        )
+        from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 as Qwen2VLAttention

         if use_remove_padding or ulysses_sp_size > 1:
             from verl.models.transformers.qwen2_vl import qwen2_vl_attn_forward
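
The import above feeds the usual class-level monkey-patching step; the assignment itself is outside this hunk, so the line below is an assumption about that pattern rather than a quote of the file.

    # Assumed continuation (not shown in this hunk): swap in verl's attention forward,
    # which supports removing padding and Ulysses sequence parallelism.
    Qwen2VLAttention.forward = qwen2_vl_attn_forward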