[model] fix: refactor qwen2vl patches & support no-image input for fsdp (#3496)

### What does this PR do?

This PR tries to fix #3491 

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

Tested with the [latest transformers](6e50a8afb2).

<img width="2448" height="540" alt="image"
src="https://github.com/user-attachments/assets/06d40f40-572c-4454-8e08-115857f61f21"
/>
<img width="2796" height="1394" alt="image"
src="https://github.com/user-attachments/assets/17489b9c-e376-46e3-80d8-71106d304077"
/>
<img width="2098" height="744" alt="image"
src="https://github.com/user-attachments/assets/8c7f736d-bf09-4ba9-9cf4-0d56e367c526"
/>

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```
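
A concrete example of the new position-ids layout, mirroring the `RLHFDataset` change in this PR. This is a minimal sketch: `processor`, `model_inputs`, `input_ids` (shape `(1, seq_len)`) and `attention_mask` are assumed to come from the usual Qwen2-VL preprocessing.

```python
import torch

from verl.models.transformers.qwen2_vl import get_rope_index

# 3-row mrope (vision) position ids for a single sample
vision_position_ids = get_rope_index(
    processor,
    input_ids=input_ids[0],
    image_grid_thw=model_inputs.get("image_grid_thw"),
    video_grid_thw=model_inputs.get("video_grid_thw"),
    second_per_grid_ts=model_inputs.get("second_per_grid_ts"),
    attention_mask=attention_mask[0],
)  # (3, seq_len)

# 1-row plain-text (1D RoPE) position ids over the valid tokens
valid_mask = attention_mask[0].bool()
text_position_ids = torch.ones((1, len(input_ids[0])), dtype=torch.long)
text_position_ids[0, valid_mask] = torch.arange(valid_mask.sum().item())

# new format: text row stacked on top of the vision rows -> (4, seq_len)
position_ids = torch.cat((text_position_ids, vision_position_ids), dim=0)
```

After batching this becomes the `(batch_size, 4, seq_len)` tensor the trainer passes around; the actor/critic transpose it to `(4, batch_size, seq_len)` before removing padding, so downstream code should read the mrope dimension from the tensor instead of hard-coding 3.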

### Design & Code Changes

#### ⚠️ Breaking

We adopt a new format for Qwen2VL's position ids: (4, batch size, seq len).

The vision position ids (mrope) have a shape of (3, batch size, seq len) and the text position ids (normal RoPE) have a shape of (1, batch size, seq len); we concatenate the two to obtain the final position ids.

This aligns with the implementation in Transformers >= 4.54.0 🤗

https://github.com/huggingface/transformers/blob/v4.54.0/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1469
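
For transformers < 4.54.0, the patched model adapts this layout before calling the language model. A condensed sketch of the shim added to `verl/models/transformers/qwen2_vl.py` in this PR:

```python
import torch

from verl.utils.transformers_compat import is_transformers_version_in_range


def process_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    # expected layout: (4, batch_size, seq_length) =
    # one row of text (1D RoPE) ids stacked on three rows of vision (mrope) ids
    if position_ids.ndim != 3 or position_ids.size(0) != 4:
        raise ValueError("position_ids should be a 3D tensor of shape (4, batch_size, seq_length).")
    if is_transformers_version_in_range(max_version="4.53.3"):
        # transformers < 4.54.0 only accepts the vision rows, so drop the text row
        position_ids = position_ids[1:]
    return position_ids
```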

#### 🎤 New

We refactored the Qwen2VL and Qwen2.5VL patches and added support for no-image (text-only) input under FSDP by introducing fake ViT inputs. We also removed some redundant code for better maintainability.
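
The text-only (no-image) case is handled by running a tiny fake image through the ViT and adding a zero-weighted term to the token embeddings, so the vision tower still participates in FSDP's forward and backward passes. Condensed from the `_get_input_embeds` helper added in this PR (`model`, `inputs_embeds`, `pixel_values` and `pixel_values_videos` as in that function):

```python
import torch

# inside _get_input_embeds(...), after the normal image/video branches:
if model.training and pixel_values is None and pixel_values_videos is None:  # text-only batch
    # fake ViT input: 16 zero patches with a 1x4x4 grid
    pixel_values = torch.zeros((16, 1176), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
    image_grid_thw = torch.tensor([[1, 4, 4]], dtype=torch.long, device=inputs_embeds.device)
    image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
    # zero-weighted contribution keeps the ViT parameters in the graph without changing the output
    inputs_embeds += 0.0 * image_embeds.mean()
```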

#### 🚨 Changes

We move the Ulysses sequence-parallel logic into the attention function, so the position ids are now scattered (sliced along the sequence dimension) before entering the language model.
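
A minimal sketch of the pattern now used inside the patched attention forward, with the helpers from `verl.utils.ulysses`; `attn_fn` is an illustrative stand-in for the variable-length FlashAttention call:

```python
import torch
import torch.distributed as dist

from verl.utils.ulysses import (
    gather_heads_scatter_seq,
    gather_seq_scatter_heads,
    get_ulysses_sequence_parallel_group,
    get_ulysses_sequence_parallel_world_size,
)


def ulysses_attention(query, key, value, position_ids, attn_fn):
    # query/key/value: (bsz, seq_len / sp_size, n_heads, head_dim)
    # position_ids:    (bsz, seq_len / sp_size), already scattered before the language model
    sp_size = get_ulysses_sequence_parallel_world_size()
    if sp_size > 1:
        # all-to-all: gather the full sequence, scatter the heads across ranks
        query = gather_seq_scatter_heads(query, seq_dim=1, head_dim=2)
        key = gather_seq_scatter_heads(key, seq_dim=1, head_dim=2)
        value = gather_seq_scatter_heads(value, seq_dim=1, head_dim=2)
        # all-gather position ids so packed (varlen) attention can rebuild cu_seqlens
        gathered = [torch.empty_like(position_ids) for _ in range(sp_size)]
        dist.all_gather(gathered, position_ids, group=get_ulysses_sequence_parallel_group())
        position_ids = torch.cat(gathered, dim=-1)  # (bsz, seq_len)
    out = attn_fn(query, key, value, position_ids)  # (bsz, seq_len, n_heads / sp_size, head_dim)
    if sp_size > 1:
        # all-to-all back: scatter the sequence, gather the heads
        out = gather_heads_scatter_seq(out, seq_dim=1, head_dim=2)
    return out
```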

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
Commit 0d4541f397 (parent 214d0f0a94) by Yaowei Zheng, committed by GitHub on 2025-09-18 10:10:30 +08:00.
11 changed files with 334 additions and 792 deletions

View File

@@ -14,6 +14,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.model.use_fused_kernels=True \
actor_rollout_ref.actor.ppo_mini_batch_size=128 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \
actor_rollout_ref.actor.use_kl_loss=True \

View File

@@ -15,17 +15,15 @@
Apply monkey-patch function to models
"""
import importlib.metadata
import sys
from functools import lru_cache
from typing import Optional
import torch
from packaging import version
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.modeling_utils import PreTrainedModel
from verl.utils.import_utils import is_trl_available
from verl.utils.transformers_compat import is_transformers_version_in_range
from verl.utils.ulysses import (
gather_heads_scatter_seq,
gather_seq_scatter_heads,
@@ -51,6 +49,8 @@ def _ulysses_flash_attention_forward(
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
query_length: int,
*args,
position_ids: Optional[torch.Tensor] = None,
**kwargs,
@@ -58,6 +58,10 @@ def _ulysses_flash_attention_forward(
"""Insert all-to-all before and after flash attention.
DeepSpeed-Ulysses: https://arxiv.org/pdf/2309.14509
For transformers>=4.55, the flash attention api has changed,
we need to pass the query_length after doing ulysses all2all.
See https://github.com/huggingface/transformers/issues/40399
Args:
query_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads, head_dim)
key_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads_k, head_dim)
@@ -66,64 +70,7 @@ def _ulysses_flash_attention_forward(
Returns:
torch.Tensor: (batch_size, seqlen/sp_size, nheads, head_dim)
"""
ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
########## AlltoAll for Ulysses ##########
if ulysses_sp_size > 1:
assert position_ids is not None, "position_ids is required for Ulysses sequence parallelism"
# NOTE: repeat kv heads to be divided by sequence parallel. Instead of repeating nheads_q//nheads_k,
# we choose to repeat sp_size//nheads_k, since flash_attention supports MQA/GQA.
# For example:
# - nheads_k=4, sp=8, repeats=2
# - nheads_k=8, sp=8, repeats=1
# - nheads_k=16, sp=8, repeats=1
repeats = max(ulysses_sp_size // key_states.size(2), 1)
key_states = repeat_kv(key_states, repeats)
value_states = repeat_kv(value_states, repeats)
# (bsz, seq_len/n, n_head, head_dim) -> (bsz, seq_len, n_head/n, head_dim)
query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)
key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)
value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)
# TODO: all_gather position_ids because `prepare_fa2_from_position_ids` needs it, we can eliminate
# this all_gather by passing cu_seq_lens_q, cu_seq_lens_k, max_length_k, max_length_q explicitly.
# https://github.com/huggingface/transformers/pull/33932
# (bsz, seq_len/n) -> (bsz, seq_len)
position_ids_list = [torch.empty_like(position_ids) for _ in range(ulysses_sp_size)]
torch.distributed.all_gather(position_ids_list, position_ids, group=get_ulysses_sequence_parallel_group())
position_ids = torch.concat(position_ids_list, dim=-1)
# (bsz, seq_len, n_head/n, head_dim)
attn_output = _flash_attention_forward(
query_states, key_states, value_states, *args, position_ids=position_ids, **kwargs
)
########## AlltoAll for Ulysses ##########
if ulysses_sp_size > 1:
# (bsz, seq_len, n_head/n, head_dim) -> (bsz, seq_len/n, n_head, head_dim)
attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)
return attn_output
def _ulysses_flash_attention_forward_transformers_4_55(
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
query_length: int,
*args,
position_ids: Optional[torch.Tensor] = None,
**kwargs,
):
"""For transformers>=4.55, the flash attention api has changed,
we need to pass the query_length after doing ulysses alltoall.
See https://github.com/huggingface/transformers/issues/40399
"""
ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
@@ -178,6 +125,7 @@ def patch_vlm_for_ulysses_input_slicing(model_class: type):
def _create_ulysses_wrapped_decoder_forward(original_forward):
def ulysses_wrapped_decoder_forward(self, *args, **kwargs):
inputs_embeds = kwargs.get("inputs_embeds")
position_ids = kwargs.get("position_ids")
call_kwargs = kwargs.copy()
current_ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
@@ -189,6 +137,7 @@ def patch_vlm_for_ulysses_input_slicing(model_class: type):
)
if slice_now:
call_kwargs["inputs_embeds"] = slice_input_tensor(inputs_embeds, dim=1, padding=False)
call_kwargs["position_ids"] = slice_input_tensor(position_ids, dim=-1, padding=False)
self._needs_initial_slice = False
try:
return original_forward(self, *args, **call_kwargs)
@@ -225,12 +174,7 @@ def patch_forward_with_backends(
forward_with_torch_backend_function = model.__class__.forward
forward_with_triton_backend_function = model.__class__.forward
if model.config.model_type == "qwen2_5_vl":
from verl.models.transformers.qwen2_5_vl import forward_with_torch_backend, forward_with_triton_backend
forward_with_torch_backend_function = forward_with_torch_backend
forward_with_triton_backend_function = forward_with_triton_backend
elif model.config.model_type == "qwen2_vl":
if model.config.model_type in ["qwen2_5_vl", "qwen2_vl"]:
from verl.models.transformers.qwen2_vl import forward_with_torch_backend, forward_with_triton_backend
forward_with_torch_backend_function = forward_with_torch_backend
@@ -296,50 +240,70 @@ def apply_monkey_patch(
# TODO: VLM models only, unify monkey patch to LLM models.
if model.config.model_type == "qwen2_5_vl":
if is_transformers_version_in_range(min_version="4.53.0"):
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLAttention
if is_transformers_version_in_range(min_version="4.52.0"):
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLAttention,
Qwen2_5_VLForConditionalGeneration,
Qwen2_5_VLModel,
Qwen2_5_VLTextModel,
)
from verl.models.transformers.qwen2_vl import forward_with_normal_backend, qwen2_vl_base_forward
Qwen2_5_VLModel.forward = qwen2_vl_base_forward
Qwen2_5_VLForConditionalGeneration.forward = forward_with_normal_backend
else:
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLFlashAttention2 as Qwen2_5_VLAttention,
)
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration,
)
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel as Qwen2_5_VLTextModel
from verl.models.transformers.qwen2_vl import forward_with_normal_backend
Qwen2_5_VLForConditionalGeneration.forward = forward_with_normal_backend
if use_remove_padding or ulysses_sp_size > 1:
from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward
from verl.models.transformers.qwen2_vl import qwen2_vl_attn_forward
Qwen2_5_VLAttention.forward = ulysses_flash_attn_forward
print("Monkey patch FlashAttention2.forward in Qwen2.5VL")
Qwen2_5_VLAttention.forward = qwen2_vl_attn_forward
print("Monkey patch Qwen2.5VL attention layer")
if ulysses_sp_size > 1:
if is_transformers_version_in_range(min_version="4.52.0"):
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel
patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLTextModel)
else:
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel
patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLModel)
patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLTextModel)
elif model.config.model_type == "qwen2_vl":
if is_transformers_version_in_range(min_version="4.53.0"):
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention
if is_transformers_version_in_range(min_version="4.52.0"):
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
Qwen2VLAttention,
Qwen2VLForConditionalGeneration,
Qwen2VLModel,
Qwen2VLTextModel,
)
from verl.models.transformers.qwen2_vl import forward_with_normal_backend, qwen2_vl_base_forward
Qwen2VLModel.forward = qwen2_vl_base_forward
Qwen2VLForConditionalGeneration.forward = forward_with_normal_backend
else:
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 as Qwen2VLAttention
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel as Qwen2VLTextModel
from verl.models.transformers.qwen2_vl import forward_with_normal_backend
Qwen2VLForConditionalGeneration.forward = forward_with_normal_backend
if use_remove_padding or ulysses_sp_size > 1:
from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward
from verl.models.transformers.qwen2_vl import qwen2_vl_attn_forward
Qwen2VLAttention.forward = ulysses_flash_attn_forward
print("Monkey patch FlashAttention2.forward in Qwen2VL")
Qwen2VLAttention.forward = qwen2_vl_attn_forward
print("Monkey patch Qwen2VL attention layer")
if ulysses_sp_size > 1:
if is_transformers_version_in_range(min_version="4.52.0"):
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel
patch_vlm_for_ulysses_input_slicing(Qwen2VLTextModel)
else:
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
patch_vlm_for_ulysses_input_slicing(Qwen2VLModel)
patch_vlm_for_ulysses_input_slicing(Qwen2VLTextModel)
elif model.config.model_type == "kimi_vl":
if use_remove_padding or ulysses_sp_size > 1:
@@ -357,43 +321,14 @@ def apply_monkey_patch(
return
# transformers<=4.47.1
if use_remove_padding or ulysses_sp_size > 1:
if hasattr(module, "_flash_attention_forward"):
if hasattr(module, "_flash_attention_forward"): # transformers <= 4.47.1 or legacy models
module._flash_attention_forward = _ulysses_flash_attention_forward
print(f"Monkey patch _flash_attention_forward in {model.__module__}")
else:
if is_transformers_version_in_range(min_version="4.55.0"):
from transformers.integrations import flash_attention
from transformers.integrations import flash_attention
flash_attention._flash_attention_forward = _ulysses_flash_attention_forward_transformers_4_55
print(f"Monkey patch _flash_attention_forward in {model.__module__} for new api")
else:
# 4.48.0 <= transformers <= 4.54.1, Vision attention
from transformers.integrations import flash_attention
flash_attention._flash_attention_forward = _ulysses_flash_attention_forward
print(f"Monkey patch _flash_attention_forward in {flash_attention.__name__}")
flash_attention._flash_attention_forward = _ulysses_flash_attention_forward
print(f"Monkey patch _flash_attention_forward in {flash_attention.__name__}")
patch_forward_with_backends(model, use_fused_kernels=use_fused_kernels, fused_kernels_backend=fused_kernels_backend)
@lru_cache
def is_transformers_version_in_range(min_version: Optional[str] = None, max_version: Optional[str] = None) -> bool:
try:
# Get the installed version of the transformers library
transformers_version_str = importlib.metadata.version("transformers")
except importlib.metadata.PackageNotFoundError as e:
raise ModuleNotFoundError("The `transformers` package is not installed.") from e
transformers_version = version.parse(transformers_version_str)
lower_bound_check = True
if min_version is not None:
lower_bound_check = version.parse(min_version) <= transformers_version
upper_bound_check = True
if max_version is not None:
upper_bound_check = transformers_version <= version.parse(max_version)
return lower_bound_check and upper_bound_check

View File

@@ -1,349 +0,0 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from importlib.metadata import version
from typing import Optional
import torch
from packaging.version import Version
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLCausalLMOutputWithPast,
Qwen2_5_VLForConditionalGeneration,
)
@dataclass
class Qwen2_5_VLCausalLMOutputForPPO(Qwen2_5_VLCausalLMOutputWithPast):
log_probs: Optional[torch.FloatTensor] = None
entropy: Optional[torch.FloatTensor] = None
def forward_base_model_old_api(
self: Qwen2_5_VLForConditionalGeneration,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
) -> tuple | Qwen2_5_VLCausalLMOutputWithPast:
r"""
Copy paste Qwen2_5_VL's forward
https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_5_vl.py
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
inputs_embeds = self.model.embed_tokens(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
n_image_features = image_embeds.shape[0]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, "
f"features {n_image_features}"
)
mask = input_ids == self.config.image_token_id
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
image_mask = mask_expanded.to(inputs_embeds.device)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
if pixel_values_videos is not None:
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
n_video_features = video_embeds.shape[0]
if n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, "
f"features {n_video_features}"
)
mask = input_ids == self.config.video_token_id
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
video_mask = mask_expanded.to(inputs_embeds.device)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
# calculate RoPE index once per generation in the pre-fill stage only
if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
position_ids, rope_deltas = self.get_rope_index(
input_ids,
image_grid_thw,
video_grid_thw,
second_per_grid_ts,
attention_mask,
)
self.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids
else:
batch_size, seq_length, _ = inputs_embeds.shape
delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
if cache_position is not None: # otherwise `deltas` is an int `0`
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
outputs = self.model(
input_ids=None,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
return outputs
def forward_base_model_new_api(
self: Qwen2_5_VLForConditionalGeneration,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
) -> tuple | Qwen2_5_VLCausalLMOutputWithPast:
r"""
Copy paste Qwen2_5_VL's forward
https://github.com/huggingface/transformers/blob/v4.52.3/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1384
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
return outputs
def forward_with_torch_backend(
self: Qwen2_5_VLForConditionalGeneration,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
temperature: float = 1.0,
**loss_kwargs,
) -> tuple | Qwen2_5_VLCausalLMOutputForPPO:
from verl.utils.experimental.torch_functional import FusedLinearForPPO
if Version(version("transformers")) < Version("4.52.0"):
forward_base_model = forward_base_model_old_api
else:
forward_base_model = forward_base_model_new_api
outputs = forward_base_model(
self,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
rope_deltas=rope_deltas,
cache_position=cache_position,
second_per_grid_ts=second_per_grid_ts,
)
hidden_states = outputs[0]
if not return_dict:
raise NotImplementedError("forward_with_torch_backend has to return_dict")
# Loss calculations
if labels is not None:
rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
elif input_ids is not None:
rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)
else:
raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.")
fused_linear_for_ppo = FusedLinearForPPO()
log_probs, entropy = fused_linear_for_ppo.forward(
hidden_states=hidden_states,
vocab_weights=self.lm_head.weight,
input_ids=rolled_labels,
temperature=temperature,
)
return Qwen2_5_VLCausalLMOutputForPPO(
log_probs=log_probs,
entropy=entropy,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=rope_deltas,
)
def forward_with_triton_backend(
self: Qwen2_5_VLForConditionalGeneration,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
temperature: float = 1.0,
**loss_kwargs,
) -> tuple | Qwen2_5_VLCausalLMOutputForPPO:
from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy
if Version(version("transformers")) < Version("4.52.0"):
forward_base_model = forward_base_model_old_api
else:
forward_base_model = forward_base_model_new_api
outputs = forward_base_model(
self,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
rope_deltas=rope_deltas,
cache_position=cache_position,
second_per_grid_ts=second_per_grid_ts,
)
hidden_states = outputs[0]
if not return_dict:
raise NotImplementedError("forward_with_triton_backend has to return_dict")
# Loss calculations
if labels is not None:
rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
elif input_ids is not None:
rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)
else:
raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.")
log_probs, entropy = linear_cross_entropy(
hidden_states,
self.lm_head.weight,
rolled_labels,
temperature,
"none",
)
return Qwen2_5_VLCausalLMOutputForPPO(
log_probs=log_probs,
entropy=entropy,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=rope_deltas,
)

View File

@@ -19,20 +19,20 @@ from dataclasses import dataclass
from typing import Optional
import torch
from transformers.modeling_flash_attention_utils import _flash_attention_forward
import torch.distributed as dist
from transformers.modeling_flash_attention_utils import _flash_attention_forward, fa_peft_integration_check
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
Qwen2VLAttention,
Qwen2VLCausalLMOutputWithPast,
Qwen2VLForConditionalGeneration,
)
from transformers.utils import is_flash_attn_greater_or_equal
from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10
from verl.models.transformers.monkey_patch import is_transformers_version_in_range
# Import compatibility wrapper for flash_attn_supports_top_left_mask
from verl.utils.transformers_compat import flash_attn_supports_top_left_mask
from verl.utils.transformers_compat import is_transformers_version_in_range
from verl.utils.ulysses import (
gather_heads_scatter_seq,
gather_seq_scatter_heads,
get_ulysses_sequence_parallel_group,
get_ulysses_sequence_parallel_world_size,
validate_ulysses_config,
)
@@ -40,22 +40,14 @@ from verl.utils.ulysses import (
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
try:
from transformers.modeling_flash_attention_utils import flash_attn_func, flash_attn_varlen_func
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
except ImportError:
# Fallback: try to import from flash_attn package directly
flash_attn_func = None
_flash_supports_window_size = None
try:
from flash_attn import flash_attn_varlen_func
except ImportError:
# If flash_attn is not available, set it to None
flash_attn_varlen_func = None
logger.warning(
"flash_attn_varlen_func not available. Variable length attention will fall back to standard attention."
)
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
_flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
_flash_supports_deterministic = "deterministic" in inspect.signature(flash_attn_func).parameters
_flash_deterministic_enabled = os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
_flash_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def get_rope_index(
@@ -69,7 +61,7 @@ def get_rope_index(
"""
Gets the position ids for Qwen2-VL, it should be generated before sharding the sequence.
The batch dim has been removed and the input_ids should be a 1D tensor representing a single example.
https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1546
https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1405
"""
spatial_merge_size = processor.image_processor.merge_size
tokens_per_second = 2
@@ -161,26 +153,26 @@ def get_rope_index(
def prepare_fa2_from_position_ids(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor
):
query = query.view(-1, query.size(-2), query.size(-1))
key = key.view(-1, key.size(-2), key.size(-1))
value = value.view(-1, value.size(-2), value.size(-1))
position_ids = position_ids.flatten()
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
assert position_ids.ndim == 2 # (batch_size, seq_length)
query = query.contiguous().view(-1, query.size(-2), query.size(-1))
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
position_ids = position_ids.view(-1)
cu_seqlens = torch.cat(
(
indices_q[position_ids == 0],
(position_ids == 0).nonzero().view(-1).to(torch.int32),
torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
)
)
max_length = cu_seqlens.diff().max() # use cu_seqlens to infer max_length for qwen2vl mrope
return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length))
return (query, key, value, (cu_seqlens, cu_seqlens), (max_length, max_length))
def flash_attention_forward(
def _custom_flash_attention_forward(
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: torch.Tensor,
attention_mask: Optional[torch.Tensor],
query_length: int,
is_causal: bool = True,
position_ids: Optional[torch.Tensor] = None,
@@ -192,74 +184,58 @@ def flash_attention_forward(
"""
Patches flash attention forward to handle 3D position ids in mrope. (3, batch_size, seq_length)
"""
causal = is_causal if not use_top_left_mask else is_causal and query_length != 1
# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
use_sliding_windows = (
_flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
)
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
if is_flash_attn_greater_or_equal("2.4.1"):
if deterministic is None:
deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
flash_kwargs["deterministic"] = deterministic
if _flash_supports_deterministic:
flash_kwargs["deterministic"] = deterministic if deterministic is not None else _flash_deterministic_enabled
if (
flash_attn_varlen_func is not None
and position_ids is not None
and query_length != 1
and not (torch.diff(position_ids[0], dim=-1) >= 0).all()
):
if kwargs.get("softcap") is not None:
flash_kwargs["softcap"] = kwargs.pop("softcap")
query_states, key_states, value_states = fa_peft_integration_check(
query_states, key_states, value_states, target_dtype=torch.bfloat16
)
if position_ids is not None:
assert position_ids.ndim == 2 # (batch_size, seq_length / sp_size)
sp_size = get_ulysses_sequence_parallel_world_size()
if sp_size > 1:
# qkv: (batch_size, seq_length / sp_size, num_head, head_size)
validate_ulysses_config(query_states.size(2), sp_size)
query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)
key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)
value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)
position_ids_lst = [torch.empty_like(position_ids) for _ in range(sp_size)]
position_ids = dist.all_gather(position_ids_lst, position_ids, group=get_ulysses_sequence_parallel_group())
position_ids = torch.cat(position_ids_lst, dim=-1) # (batch_size, seq_length)
if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
batch_size = query_states.size(0)
query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
query_states, key_states, value_states, position_ids[0]
) # remove channel dimension
query_states, key_states, value_states, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
query_states, key_states, value_states, position_ids
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
flash_attn_func = flash_attn_varlen_func
common_attn_kwargs = {
"cu_seqlens_q": cu_seqlens_q,
"cu_seqlens_k": cu_seqlens_k,
"max_seqlen_q": max_seqlen_in_batch_q,
"max_seqlen_k": max_seqlen_in_batch_k,
"dropout_p": kwargs.pop("dropout", 0.0),
"softmax_scale": kwargs.pop("softmax_scale", None),
**flash_kwargs,
}
if flash_attn_func is None:
# Use transformers >= 4.54
flash_attn_func = _flash_attention_forward
specific_attn_kwargs = {
"attention_mask": attention_mask,
"position_ids": position_ids,
"query_length": query_length,
"is_causal": causal,
}
else:
specific_attn_kwargs = {"causal": causal}
attn_output = flash_attn_func(
attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
**common_attn_kwargs,
**specific_attn_kwargs,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=kwargs.pop("dropout", 0.0),
softmax_scale=kwargs.pop("softmax_scale", None),
causal=is_causal,
**flash_kwargs,
)
attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
else:
if (
flash_attn_varlen_func is None
and position_ids is not None
and query_length != 1
and not (torch.diff(position_ids[0], dim=-1) >= 0).all()
):
logger.warning_once(
"flash_attn_varlen_func is not available; falling back to _flash_attention_forward."
"This may be suboptimal for non-monotonic position_ids in VLM mRoPE."
)
attn_output = _flash_attention_forward(
query_states,
key_states,
@@ -268,16 +244,20 @@ def flash_attention_forward(
query_length,
is_causal=is_causal,
sliding_window=sliding_window,
use_top_left_mask=flash_attn_supports_top_left_mask(),
use_top_left_mask=use_top_left_mask,
deterministic=deterministic,
**kwargs,
)
) # do not pass position_ids to old flash_attention_forward
if sp_size > 1:
# (batch_size, seq_length, num_head, head_size)
attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)
return attn_output
def ulysses_flash_attn_forward(
self,
def qwen2_vl_attn_forward(
self: "Qwen2VLAttention",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@@ -295,225 +275,215 @@ def ulysses_flash_attn_forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
if ulysses_sp_size > 1:
validate_ulysses_config(self.num_heads, ulysses_sp_size)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)
# (batch_size, num_head / sp_size, seq_length, head_size)
full_q_len = query_states.size(2) # full_q_len = seq_length
else:
full_q_len = q_len
# Because the input can be padded, the absolute sequence length depends on the max position id.
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
cos, sin = position_embeddings
query_states, key_states = apply_multimodal_rotary_pos_emb(
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
dropout_rate = 0.0 if not self.training else self.attention_dropout
# Reashape to the expected shape for Flash Attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
sliding_window = None
if (
self.config.use_sliding_window
and getattr(self.config, "sliding_window", None) is not None
and self.layer_idx >= self.config.max_window_layers
):
sliding_window = self.config.sliding_window
else:
sliding_window = None
attn_output = flash_attention_forward(
# This is before the transpose
q_len = query_states.shape[2]
# FA2 uses non-transposed inputs
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
if position_ids.ndim == 3:
position_ids = position_ids[0]
attn_output = _custom_flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
full_q_len,
query_length=q_len,
is_causal=getattr(self, "is_causal", True),
dropout=dropout_rate,
sliding_window=sliding_window,
is_causal=self.is_causal,
use_top_left_mask=flash_attn_supports_top_left_mask(),
use_top_left_mask=_flash_use_top_left_mask,
position_ids=position_ids, # important: pass position ids
) # (batch_size, seq_length, num_head / sp_size, head_size)
if ulysses_sp_size > 1:
attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)
) # (batch_size, seq_length / sp_size, num_head, head_size)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
if is_transformers_version_in_range(min_version="4.53.0"):
if is_transformers_version_in_range(min_version="4.54.0"):
return attn_output, None
else:
return attn_output, None, None
def _get_input_embeds(
model: "Qwen2VLForConditionalGeneration",
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
):
inputs_embeds = model.get_input_embeddings()(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(model.visual.dtype)
image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
n_image_tokens = (input_ids == model.config.image_token_id).sum().item()
n_image_features = image_embeds.shape[0]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
mask = input_ids == model.config.image_token_id
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
image_mask = mask_expanded.to(inputs_embeds.device)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
if pixel_values_videos is not None:
pixel_values_videos = pixel_values_videos.type(model.visual.dtype)
video_embeds = model.visual(pixel_values_videos, grid_thw=video_grid_thw)
n_video_tokens = (input_ids == model.config.video_token_id).sum().item()
n_video_features = video_embeds.shape[0]
if n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
mask = input_ids == model.config.video_token_id
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
video_mask = mask_expanded.to(inputs_embeds.device)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
if model.training and pixel_values is None and pixel_values_videos is None: # handle mixed text-image data
pixel_values = torch.zeros((16, 1176), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
image_grid_thw = torch.tensor([[1, 4, 4]], dtype=torch.long, device=inputs_embeds.device)
image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
inputs_embeds += 0.0 * image_embeds.mean()
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
return inputs_embeds, attention_mask
def process_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
if position_ids.ndim != 3 or position_ids.size(0) != 4:
# we concat the text position ids with the 3D vision position ids by default
# see https://github.com/huggingface/transformers/pull/39447
raise ValueError("position_ids should be a 3D tensor of shape (4, batch_size, seq_length).")
if is_transformers_version_in_range(max_version="4.53.3"):
# transformers < 4.54.0 only accepts vision position ids, so we discard the text position ids here
position_ids = position_ids[1:]
return position_ids
@dataclass
class Qwen2VLCausalLMOutputForPPO(Qwen2VLCausalLMOutputWithPast):
log_probs: Optional[torch.FloatTensor] = None
entropy: Optional[torch.FloatTensor] = None
def forward_base_model(
self: Qwen2VLForConditionalGeneration,
input_ids: torch.LongTensor = None,
def qwen2_vl_base_forward(
self: "Qwen2VLForConditionalGeneration",
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> tuple | Qwen2VLCausalLMOutputWithPast:
r"""
Copy paste Qwen2VL's forward
https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_vl.py
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
inputs_embeds = self.model.embed_tokens(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.visual.get_dtype())
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
n_image_features = image_embeds.shape[0]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, "
f"features {n_image_features}"
)
image_mask = (
(input_ids == self.config.image_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
if pixel_values_videos is not None:
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
n_video_features = video_embeds.shape[0]
if n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, "
f"features {n_video_features}"
)
video_mask = (
(input_ids == self.config.video_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
# calculate RoPE index once per generation in the pre-fill stage only
if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
position_ids, rope_deltas = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
self.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids
else:
batch_size, seq_length, _ = inputs_embeds.shape
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
if cache_position is not None: # otherwise `deltas` is an int `0`
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
outputs = self.model(
**kwargs,
):
kwargs["inputs_embeds"], kwargs["attention_mask"] = _get_input_embeds(
self, input_ids, attention_mask, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw
) # avoid lora module having multiple keyword arguments
return self.language_model(
input_ids=None,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
return outputs
def qwen2_vl_forward(
self: "Qwen2VLForConditionalGeneration",
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
**kwargs,
):
if is_transformers_version_in_range(min_version="4.52.0"):
return self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=process_position_ids(position_ids),
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
**kwargs,
)
else:
inputs_embeds, attention_mask = _get_input_embeds(
self, input_ids, attention_mask, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw
)
return self.model(
input_ids=None,
attention_mask=attention_mask,
position_ids=process_position_ids(position_ids),
inputs_embeds=inputs_embeds,
**kwargs,
)
def forward_with_normal_backend(
self: Qwen2VLForConditionalGeneration,
input_ids: torch.LongTensor = None,
labels: Optional[torch.LongTensor] = None,
temperature: float = 1.0,
**kwargs,
) -> "Qwen2VLCausalLMOutputWithPast":
outputs = qwen2_vl_forward(self, input_ids, **kwargs)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
return Qwen2VLCausalLMOutputWithPast(
logits=logits,
hidden_states=outputs.hidden_states,
)
def forward_with_torch_backend(
self: Qwen2VLForConditionalGeneration,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
temperature: float = 1.0,
**loss_kwargs,
**kwargs,
) -> tuple | Qwen2VLCausalLMOutputForPPO:
from verl.utils.experimental.torch_functional import FusedLinearForPPO
outputs = forward_base_model(
self,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
rope_deltas=rope_deltas,
cache_position=cache_position,
)
outputs = qwen2_vl_forward(self, input_ids, **kwargs)
hidden_states = outputs[0]
if not return_dict:
raise NotImplementedError("forward_with_torch_backend has to return_dict")
# Loss calculations
if labels is not None:
rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
@@ -529,64 +499,25 @@ def forward_with_torch_backend(
input_ids=rolled_labels,
temperature=temperature,
)
return Qwen2VLCausalLMOutputForPPO(
log_probs=log_probs,
entropy=entropy,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=rope_deltas,
)
def forward_with_triton_backend(
self: Qwen2VLForConditionalGeneration,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
temperature: float = 1.0,
**loss_kwargs,
**kwargs,
) -> tuple | Qwen2VLCausalLMOutputForPPO:
from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy
outputs = forward_base_model(
self,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
rope_deltas=rope_deltas,
cache_position=cache_position,
)
outputs = qwen2_vl_forward(self, input_ids, **kwargs)
hidden_states = outputs[0]
if not return_dict:
raise NotImplementedError("forward_with_triton_backend has to return_dict")
# Loss calculations
if labels is not None:
rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
@@ -602,12 +533,8 @@ def forward_with_triton_backend(
temperature,
"none",
)
return Qwen2VLCausalLMOutputForPPO(
log_probs=log_probs,
entropy=entropy,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=rope_deltas,
)

View File

@@ -299,17 +299,18 @@ class RLHFDataset(Dataset):
if self.processor is not None and "Qwen2VLImageProcessor" in self.processor.image_processor.__class__.__name__:
from verl.models.transformers.qwen2_vl import get_rope_index
position_ids = [
get_rope_index(
self.processor,
input_ids=input_ids[0],
image_grid_thw=model_inputs.get("image_grid_thw"),
video_grid_thw=model_inputs.get("video_grid_thw"),
second_per_grid_ts=model_inputs.get("second_per_grid_ts"),
attention_mask=attention_mask[0],
)
] # (1, 3, seq_len)
vision_position_ids = get_rope_index(
self.processor,
input_ids=input_ids[0],
image_grid_thw=model_inputs.get("image_grid_thw"),
video_grid_thw=model_inputs.get("video_grid_thw"),
second_per_grid_ts=model_inputs.get("second_per_grid_ts"),
attention_mask=attention_mask[0],
) # (3, seq_length)
valid_mask = attention_mask[0].bool()
text_position_ids = torch.ones((1, len(input_ids[0])), dtype=torch.long)
text_position_ids[0, valid_mask] = torch.arange(valid_mask.sum().item())
position_ids = [torch.cat((text_position_ids, vision_position_ids), dim=0)] # (1, 4, seq_length)
else:
position_ids = compute_position_id_with_mask(attention_mask)

View File

@@ -16,6 +16,12 @@
Compatibility utilities for different versions of transformers library.
"""
import importlib.metadata
from functools import lru_cache
from typing import Optional
from packaging import version
# Handle version compatibility for flash_attn_supports_top_left_mask
# This function was added in newer versions of transformers
try:
@@ -28,3 +34,24 @@ except ImportError:
Returns False to disable features that require this function.
"""
return False
@lru_cache
def is_transformers_version_in_range(min_version: Optional[str] = None, max_version: Optional[str] = None) -> bool:
try:
# Get the installed version of the transformers library
transformers_version_str = importlib.metadata.version("transformers")
except importlib.metadata.PackageNotFoundError as e:
raise ModuleNotFoundError("The `transformers` package is not installed.") from e
transformers_version = version.parse(transformers_version_str)
lower_bound_check = True
if min_version is not None:
lower_bound_check = version.parse(min_version) <= transformers_version
upper_bound_check = True
if max_version is not None:
upper_bound_check = transformers_version <= version.parse(max_version)
return lower_bound_check and upper_bound_check

View File

@@ -288,7 +288,7 @@ def ulysses_pad(input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torc
if position_ids_rmpad is not None:
pad_pos_ids = torch.arange(pad_size, device=position_ids_rmpad.device).unsqueeze(0)
if position_ids_rmpad.dim() == 3:
pad_pos_ids = pad_pos_ids.unsqueeze(0).repeat(3, 1, 1)
pad_pos_ids = pad_pos_ids.unsqueeze(0).repeat(position_ids_rmpad.size(0), 1, 1)
position_ids_rmpad = torch.cat((position_ids_rmpad, pad_pos_ids), dim=-1)
return input_ids_rmpad, position_ids_rmpad, pad_size

View File

@@ -114,7 +114,7 @@ class DataParallelPPOActor(BasePPOActor):
position_ids = micro_batch["position_ids"]
entropy = None
if position_ids.dim() == 3: # qwen2vl mrope
position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen)
position_ids = position_ids.transpose(0, 1) # (bsz, 4, seqlen) -> (4, bsz, seqlen)
if self.use_remove_padding:
input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input(
@@ -128,7 +128,7 @@ class DataParallelPPOActor(BasePPOActor):
index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
.transpose(0, 1)
.unsqueeze(1)
) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
) # (4, bsz, seqlen) -> (4, 1, bsz * seqlen)
else:
position_ids_rmpad = index_first_axis(
rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices

View File

@@ -83,7 +83,7 @@ class DataParallelPPOCritic(BasePPOCritic):
index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
.transpose(0, 1)
.unsqueeze(1)
) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
) # (4, bsz, seqlen) -> (4, 1, bsz * seqlen)
else:
position_ids_rmpad = index_first_axis(
rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices

View File

@@ -755,8 +755,8 @@ class SGLangRollout(BaseRollout):
response_length = response.size(1)
delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)
if position_ids.dim() == 3: # qwen2vl mrope
delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)
if position_ids.dim() == 3: # qwen2vl mrope (batch size, 4, seq len)
delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, position_ids.size(1), -1)
# TODO(sgm): fix position_ids on right_pad
# prompt: left pad + response: right pad

View File

@@ -372,8 +372,8 @@ class vLLMRollout(BaseRollout):
response_length = response.size(1)
delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
delta_position_id = delta_position_id.unsqueeze(0).expand(batch_size, -1)
if position_ids.dim() == 3: # qwen2vl mrope
delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)
if position_ids.dim() == 3: # qwen2vl mrope (batch size, 4, seq len)
delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, position_ids.size(1), -1)
# TODO(sgm): fix position_ids on right_pad
# prompt: left pad + response: right pad