[model] fix: qwen3vl models shape mismatch error with SP (#3735)
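The mismatch this commit addresses appears when Ulysses sequence parallelism (SP) slices the text sequence per rank while the Qwen3-VL visual inputs (visual_pos_masks, deepstack_visual_embeds) are still passed in full: the local mask then selects fewer positions than the embedding tensors have rows. A minimal sketch of that failure mode, assuming the deepstack injection reduces to a masked add of visual embeddings into the hidden states (the exact Qwen3-VL internals may differ):

# Minimal sketch of the failure mode (not part of the diff), assuming the deepstack
# injection boils down to hidden_states[visual_pos_masks] += deepstack_embeds.
import torch

sp_size, rank = 2, 0                     # pretend Ulysses SP world size 2, look at rank 0
batch, seq_len, hidden = 1, 8, 4
visual_pos_masks = torch.tensor([[0, 1, 1, 1, 0, 1, 1, 0]], dtype=torch.bool)  # 5 visual tokens
deepstack_embeds = torch.randn(int(visual_pos_masks.sum()), hidden)            # one row per visual token

# Ulysses SP gives each rank a contiguous sequence shard: rank 0 sees tokens [0:4).
local = slice(rank * seq_len // sp_size, (rank + 1) * seq_len // sp_size)
local_hidden = torch.zeros(batch, seq_len // sp_size, hidden)
local_mask = visual_pos_masks[:, local]                                        # only 3 visual tokens here

try:
    # Without the fix: 5 embedding rows are added into 3 masked positions.
    local_hidden[local_mask] += deepstack_embeds
except RuntimeError as err:
    print("shape mismatch without slicing:", err)

# With the fix: keep only the rows whose visual tokens fall inside this shard.
start = int(visual_pos_masks[:, : local.start].sum())
stop = int(visual_pos_masks[:, : local.stop].sum())
local_hidden[local_mask] += deepstack_embeds[start:stop]                       # 3 rows for 3 positions
print("ok with slicing:", tuple(local_hidden.shape))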
@@ -127,6 +127,8 @@ def patch_vlm_for_ulysses_input_slicing(model_class: type):
     def ulysses_wrapped_decoder_forward(self, *args, **kwargs):
         inputs_embeds = kwargs.get("inputs_embeds")
         position_ids = kwargs.get("position_ids")
+        visual_pos_masks = kwargs.get("visual_pos_masks")
+        deepstack_visual_embeds = kwargs.get("deepstack_visual_embeds")
         call_kwargs = kwargs.copy()
 
         current_ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
@@ -139,6 +141,43 @@ def patch_vlm_for_ulysses_input_slicing(model_class: type):
         if slice_now:
             call_kwargs["inputs_embeds"] = slice_input_tensor(inputs_embeds, dim=1, padding=False)
             call_kwargs["position_ids"] = slice_input_tensor(position_ids, dim=-1, padding=False)
+            # Also slice visual_pos_masks and deepstack_visual_embeds for Qwen3 VL models
+            if visual_pos_masks is not None:
+                original_visual_mask = visual_pos_masks
+                sliced_visual_mask = slice_input_tensor(visual_pos_masks, dim=1, padding=False)
+                call_kwargs["visual_pos_masks"] = sliced_visual_mask
+
+                if deepstack_visual_embeds is not None:
+                    sliced_embeds = []
+
+                    num_visual_before = original_visual_mask.sum().item()
+                    num_visual_in_shard = sliced_visual_mask.sum().item()
+
+                    if num_visual_in_shard > 0 and num_visual_before > 0:
+                        # Calculate which visual embeddings belong to this shard
+                        # We need to find the offset of visual tokens in this shard
+                        from verl.utils.ulysses import get_ulysses_sequence_parallel_rank
+
+                        rank = get_ulysses_sequence_parallel_rank()
+                        seq_len = original_visual_mask.shape[1]
+                        local_seq_len = seq_len // current_ulysses_sp_size
+                        start_idx = rank * local_seq_len
+                        end_idx = start_idx + local_seq_len
+
+                        # Get total visual tokens before and up to the end of the shard's sequence slice
+                        # This correctly handles batches by summing across all samples
+                        visual_start = original_visual_mask[:, :start_idx].sum().item() if start_idx > 0 else 0
+                        visual_end = original_visual_mask[:, :end_idx].sum().item()
+
+                        # Slice each tensor in deepstack_visual_embeds
+                        for embed in deepstack_visual_embeds:
+                            sliced_embeds.append(embed[visual_start:visual_end])
+                    else:
+                        # No visual tokens in this shard, create empty tensors to maintain gradient flow
+                        for embed in deepstack_visual_embeds:
+                            sliced_embeds.append(embed[:0])
+                    call_kwargs["deepstack_visual_embeds"] = sliced_embeds
+
             self._needs_initial_slice = False
         try:
             return original_forward(self, *args, **call_kwargs)
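For reference, the offset arithmetic introduced above can be read as a standalone helper: count the visual tokens that precede the rank's sequence shard and the ones that fall inside it, then take that row range from each deepstack tensor. The code below is a hypothetical illustration that mirrors the hunk's logic (slice_deepstack_embeds_for_rank is not a verl function):

# Hypothetical standalone helper mirroring the offset arithmetic above.
from typing import List

import torch


def slice_deepstack_embeds_for_rank(
    visual_pos_masks: torch.Tensor,               # (batch, seq_len) bool mask of visual positions
    deepstack_visual_embeds: List[torch.Tensor],  # each (num_visual_tokens, hidden)
    rank: int,
    sp_size: int,
) -> List[torch.Tensor]:
    seq_len = visual_pos_masks.shape[1]
    local_seq_len = seq_len // sp_size
    start_idx = rank * local_seq_len
    end_idx = start_idx + local_seq_len

    # Visual tokens before the shard and up to its end, summed over the batch,
    # give the row range of the flattened visual-embedding tensors.
    visual_start = int(visual_pos_masks[:, :start_idx].sum()) if start_idx > 0 else 0
    visual_end = int(visual_pos_masks[:, :end_idx].sum())

    if visual_end == visual_start:
        # No visual tokens in this shard: return empty (0, hidden) slices so the
        # downstream masked add is a no-op while gradients still flow.
        return [embed[:0] for embed in deepstack_visual_embeds]
    return [embed[visual_start:visual_end] for embed in deepstack_visual_embeds]


# Worked example: 8 tokens, 5 of them visual, SP world size 2.
mask = torch.tensor([[0, 1, 1, 1, 0, 1, 1, 0]], dtype=torch.bool)
embeds = [torch.randn(5, 4) for _ in range(3)]  # e.g. one tensor per deepstack layer
print([e.shape[0] for e in slice_deepstack_embeds_for_rank(mask, embeds, rank=0, sp_size=2)])  # [3, 3, 3]
print([e.shape[0] for e in slice_deepstack_embeds_for_rank(mask, embeds, rank=1, sp_size=2)])  # [2, 2, 2]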
@@ -290,9 +329,7 @@ def apply_monkey_patch(
         from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
             Qwen2_5_VLFlashAttention2 as Qwen2_5_VLAttention,
         )
-        from transformers.models.qwen2_vl.modeling_qwen2_vl import (
-            Qwen2VLFlashAttention2 as Qwen2VLAttention,
-        )
+        from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 as Qwen2VLAttention
 
         if use_remove_padding or ulysses_sp_size > 1:
             from verl.models.transformers.qwen2_vl import qwen2_vl_attn_forward