Mirror of https://github.com/volcengine/verl.git, synced 2025-10-20 21:53:50 +08:00
[model] fix: refactor qwen2vl patches & support no-image input for fsdp (#3496)
### What does this PR do?
This PR refactors the Qwen2-VL / Qwen2.5-VL monkey patches and adds support for no-image (text-only) input under FSDP, fixing #3491.
### Checklist Before Starting
- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
- `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
- Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`
### Test
Tested with the [latest transformers](6e50a8afb2):
<img width="2448" height="540" alt="image"
src="https://github.com/user-attachments/assets/06d40f40-572c-4454-8e08-115857f61f21"
/>
<img width="2796" height="1394" alt="image"
src="https://github.com/user-attachments/assets/17489b9c-e376-46e3-80d8-71106d304077"
/>
<img width="2098" height="744" alt="image"
src="https://github.com/user-attachments/assets/8c7f736d-bf09-4ba9-9cf4-0d56e367c526"
/>
### API and Usage Example
> Demonstrate how the API changes if any, and provide usage example(s)
if possible.
```python
# Add code snippet or script demonstrating how to use this
```
### Design & Code Changes
#### ⚠️ Breaking
We adopt a new format for Qwen2-VL's position ids: `(4, batch_size, seq_len)`.
Assuming the vision position ids (mrope) have shape `(3, batch_size, seq_len)` and the text position ids (standard RoPE) have shape `(1, batch_size, seq_len)`, we concatenate the two to obtain the final position ids.
This aligns with the implementation in Transformers >= 4.54.0 🤗
https://github.com/huggingface/transformers/blob/v4.54.0/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1469
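
A minimal sketch of the new layout (sizes are illustrative; the text ids occupying channel 0 is inferred from `process_position_ids` in this PR):

```python
import torch

batch_size, seq_len = 2, 8  # illustrative sizes

# text position ids (standard RoPE): (1, batch_size, seq_len)
text_pos_ids = torch.arange(seq_len).view(1, 1, seq_len).expand(1, batch_size, seq_len)
# vision position ids (mrope, one channel each for t/h/w): (3, batch_size, seq_len)
vision_pos_ids = torch.zeros(3, batch_size, seq_len, dtype=torch.long)

# new verl convention: concatenate text + vision ids along dim 0
position_ids = torch.cat([text_pos_ids, vision_pos_ids], dim=0)
assert position_ids.shape == (4, batch_size, seq_len)
```

For transformers < 4.54.0, the text channel is dropped again before the model call (`position_ids[1:]`), since those versions only accept the 3-channel vision ids.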
#### 🎤 New
We have refactored the Qwen2-VL and Qwen2.5-VL patches, adding support for no-image input under FSDP by introducing fake ViT inputs (sketched below), and removed redundant code for better maintainability.
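
The core of the no-image path is a tiny fake ViT input whose contribution is scaled by zero, so the vision tower still participates in the backward pass on text-only batches, which is what makes mixed text/image data work under FSDP. A condensed sketch mirroring `_get_input_embeds` from this PR; the helper name and call site are illustrative:

```python
import torch

def add_fake_vit_input(model, inputs_embeds: torch.Tensor) -> torch.Tensor:
    """Run the vision tower on a dummy image so every ViT parameter joins the autograd graph."""
    # minimal dummy image: 16 patches of dim 1176, laid out as a 1x4x4 grid
    pixel_values = torch.zeros((16, 1176), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
    image_grid_thw = torch.tensor([[1, 4, 4]], dtype=torch.long, device=inputs_embeds.device)
    image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
    # zero contribution to the loss, non-zero contribution to the gradient graph
    return inputs_embeds + 0.0 * image_embeds.mean()
```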
#### 🚨 Changes
We moved the Ulysses sequence-parallel logic into the attention function, so the position ids are now scattered (sliced along the sequence dimension) before the language-model part; see the sketch below.
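
A rough sketch of the per-rank slicing, assuming the sequence length divides evenly by the SP world size (function and argument names are illustrative; inside the patched attention the ids are all-gathered again only where flash-attention needs the full sequence):

```python
import torch

def slice_position_ids(position_ids: torch.Tensor, sp_rank: int, sp_size: int) -> torch.Tensor:
    # position_ids: (4, batch_size, seq_len) -> (4, batch_size, seq_len // sp_size)
    seq_len = position_ids.size(-1)
    assert seq_len % sp_size == 0, "sequence length must be divisible by the SP world size"
    chunk = seq_len // sp_size
    return position_ids[..., sp_rank * chunk : (sp_rank + 1) * chunk]
```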
### Checklist Before Submitting
> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.
- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
### Files changed

**Example script**: the Qwen2.5-VL PPO example command is updated (hunk `@@ -14,6 +14,7 @@`, one line added):

    python3 -m verl.trainer.main_ppo \
        actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \
        actor_rollout_ref.actor.optim.lr=1e-6 \
        actor_rollout_ref.model.use_remove_padding=True \
        actor_rollout_ref.model.use_fused_kernels=True \
        actor_rollout_ref.actor.ppo_mini_batch_size=128 \
        actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \
        actor_rollout_ref.actor.use_kl_loss=True \
**`verl/models/transformers/monkey_patch.py`**: the Qwen2-VL / Qwen2.5-VL patching is consolidated.

- `is_transformers_version_in_range` now comes from `verl.utils.transformers_compat` rather than being defined locally in this module.
- The docstring of `_ulysses_flash_attention_forward` notes that the flash-attention API changed in transformers >= 4.55 and that `query_length` must be passed after the Ulysses all-to-all (huggingface/transformers#40399); the version-specific handling lives in `_ulysses_flash_attention_forward_transformers_4_55`, and the Qwen2-VL-specific all-to-all code previously embedded here is trimmed in favor of the new attention patch in `verl.models.transformers.qwen2_vl`.
- `patch_forward_with_backends` now treats `qwen2_5_vl` and `qwen2_vl` identically (`model.config.model_type in ["qwen2_5_vl", "qwen2_vl"]`) and imports `forward_with_torch_backend` / `forward_with_triton_backend` from `verl.models.transformers.qwen2_vl` for both.
- In `apply_monkey_patch`, the transformers version gate for both model families moves from 4.53.0 to 4.52.0. For transformers >= 4.52.0, `Qwen2_5_VLModel.forward` / `Qwen2VLModel.forward` is replaced with `qwen2_vl_base_forward`, the corresponding `*ForConditionalGeneration.forward` with `forward_with_normal_backend`, the attention forward (renamed from `ulysses_flash_attn_forward` to `qwen2_vl_attn_forward`) is patched onto `Qwen2_5_VLAttention` / `Qwen2VLAttention`, and Ulysses input slicing is applied to `Qwen2_5_VLTextModel` / `Qwen2VLTextModel`.
- For attention modules without a `_flash_attention_forward` attribute, `transformers.integrations.flash_attention._flash_attention_forward` is patched with `_ulysses_flash_attention_forward_transformers_4_55` on transformers >= 4.55.0 and with `_ulysses_flash_attention_forward` on 4.48.0 through 4.54.1; transformers <= 4.47.1 and legacy models keep the per-module `_flash_attention_forward` patch.
**`verl/models/transformers/qwen2_5_vl.py`**: deleted (349 lines). It held the Qwen2.5-VL-specific `Qwen2_5_VLCausalLMOutputForPPO`, the `forward_base_model_old_api` / `forward_base_model_new_api` pair (for transformers below and at/above 4.52.0), and the `forward_with_torch_backend` / `forward_with_triton_backend` PPO forwards; Qwen2.5-VL now reuses the shared implementations in `verl.models.transformers.qwen2_vl`.
@ -19,20 +19,20 @@ from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
||||
import torch.distributed as dist
|
||||
from transformers.modeling_flash_attention_utils import _flash_attention_forward, fa_peft_integration_check
|
||||
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
|
||||
Qwen2VLAttention,
|
||||
Qwen2VLCausalLMOutputWithPast,
|
||||
Qwen2VLForConditionalGeneration,
|
||||
)
|
||||
from transformers.utils import is_flash_attn_greater_or_equal
|
||||
from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10
|
||||
|
||||
from verl.models.transformers.monkey_patch import is_transformers_version_in_range
|
||||
|
||||
# Import compatibility wrapper for flash_attn_supports_top_left_mask
|
||||
from verl.utils.transformers_compat import flash_attn_supports_top_left_mask
|
||||
from verl.utils.transformers_compat import is_transformers_version_in_range
|
||||
from verl.utils.ulysses import (
|
||||
gather_heads_scatter_seq,
|
||||
gather_seq_scatter_heads,
|
||||
get_ulysses_sequence_parallel_group,
|
||||
get_ulysses_sequence_parallel_world_size,
|
||||
validate_ulysses_config,
|
||||
)
|
||||
@ -40,22 +40,14 @@ from verl.utils.ulysses import (
|
||||
logger = logging.getLogger(__file__)
|
||||
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
|
||||
|
||||
try:
|
||||
from transformers.modeling_flash_attention_utils import flash_attn_func, flash_attn_varlen_func
|
||||
|
||||
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
||||
except ImportError:
|
||||
# Fallback: try to import from flash_attn package directly
|
||||
flash_attn_func = None
|
||||
_flash_supports_window_size = None
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
except ImportError:
|
||||
# If flash_attn is not available, set it to None
|
||||
flash_attn_varlen_func = None
|
||||
logger.warning(
|
||||
"flash_attn_varlen_func not available. Variable length attention will fall back to standard attention."
|
||||
)
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
|
||||
_flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
|
||||
_flash_supports_deterministic = "deterministic" in inspect.signature(flash_attn_func).parameters
|
||||
_flash_deterministic_enabled = os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
|
||||
_flash_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
|
||||
def get_rope_index(
|
||||
@ -69,7 +61,7 @@ def get_rope_index(
|
||||
"""
|
||||
Gets the position ids for Qwen2-VL, it should be generated before sharding the sequence.
|
||||
The batch dim has been removed and the input_ids should be a 1D tensor representing a single example.
|
||||
https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1546
|
||||
https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1405
|
||||
"""
|
||||
spatial_merge_size = processor.image_processor.merge_size
|
||||
tokens_per_second = 2
|
||||
@ -161,26 +153,26 @@ def get_rope_index(
|
||||
def prepare_fa2_from_position_ids(
|
||||
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor
|
||||
):
|
||||
query = query.view(-1, query.size(-2), query.size(-1))
|
||||
key = key.view(-1, key.size(-2), key.size(-1))
|
||||
value = value.view(-1, value.size(-2), value.size(-1))
|
||||
position_ids = position_ids.flatten()
|
||||
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
|
||||
assert position_ids.ndim == 2 # (batch_size, seq_length)
|
||||
query = query.contiguous().view(-1, query.size(-2), query.size(-1))
|
||||
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
|
||||
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
|
||||
position_ids = position_ids.view(-1)
|
||||
cu_seqlens = torch.cat(
|
||||
(
|
||||
indices_q[position_ids == 0],
|
||||
(position_ids == 0).nonzero().view(-1).to(torch.int32),
|
||||
torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
|
||||
)
|
||||
)
|
||||
max_length = cu_seqlens.diff().max() # use cu_seqlens to infer max_length for qwen2vl mrope
|
||||
return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length))
|
||||
return (query, key, value, (cu_seqlens, cu_seqlens), (max_length, max_length))
|
||||
|
||||
|
||||
def flash_attention_forward(
|
||||
def _custom_flash_attention_forward(
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
query_length: int,
|
||||
is_causal: bool = True,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
@ -192,74 +184,58 @@ def flash_attention_forward(
|
||||
"""
|
||||
Patches flash attention forward to handle 3D position ids in mrope. (3, batch_size, seq_length)
|
||||
"""
|
||||
causal = is_causal if not use_top_left_mask else is_causal and query_length != 1
|
||||
|
||||
# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
|
||||
use_sliding_windows = (
|
||||
_flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
|
||||
)
|
||||
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
|
||||
|
||||
if is_flash_attn_greater_or_equal("2.4.1"):
|
||||
if deterministic is None:
|
||||
deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
|
||||
flash_kwargs["deterministic"] = deterministic
|
||||
if _flash_supports_deterministic:
|
||||
flash_kwargs["deterministic"] = deterministic if deterministic is not None else _flash_deterministic_enabled
|
||||
|
||||
if (
|
||||
flash_attn_varlen_func is not None
|
||||
and position_ids is not None
|
||||
and query_length != 1
|
||||
and not (torch.diff(position_ids[0], dim=-1) >= 0).all()
|
||||
):
|
||||
if kwargs.get("softcap") is not None:
|
||||
flash_kwargs["softcap"] = kwargs.pop("softcap")
|
||||
|
||||
query_states, key_states, value_states = fa_peft_integration_check(
|
||||
query_states, key_states, value_states, target_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
if position_ids is not None:
|
||||
assert position_ids.ndim == 2 # (batch_size, seq_length / sp_size)
|
||||
|
||||
sp_size = get_ulysses_sequence_parallel_world_size()
|
||||
if sp_size > 1:
|
||||
# qkv: (batch_size, seq_length / sp_size, num_head, head_size)
|
||||
validate_ulysses_config(query_states.size(2), sp_size)
|
||||
query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)
|
||||
key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)
|
||||
value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)
|
||||
position_ids_lst = [torch.empty_like(position_ids) for _ in range(sp_size)]
|
||||
position_ids = dist.all_gather(position_ids_lst, position_ids, group=get_ulysses_sequence_parallel_group())
|
||||
position_ids = torch.cat(position_ids_lst, dim=-1) # (batch_size, seq_length)
|
||||
|
||||
if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
|
||||
batch_size = query_states.size(0)
|
||||
query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
|
||||
query_states, key_states, value_states, position_ids[0]
|
||||
) # remove channel dimension
|
||||
query_states, key_states, value_states, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
|
||||
query_states, key_states, value_states, position_ids
|
||||
)
|
||||
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
||||
|
||||
flash_attn_func = flash_attn_varlen_func
|
||||
common_attn_kwargs = {
|
||||
"cu_seqlens_q": cu_seqlens_q,
|
||||
"cu_seqlens_k": cu_seqlens_k,
|
||||
"max_seqlen_q": max_seqlen_in_batch_q,
|
||||
"max_seqlen_k": max_seqlen_in_batch_k,
|
||||
"dropout_p": kwargs.pop("dropout", 0.0),
|
||||
"softmax_scale": kwargs.pop("softmax_scale", None),
|
||||
**flash_kwargs,
|
||||
}
|
||||
|
||||
if flash_attn_func is None:
|
||||
# Use transformers >= 4.54
|
||||
flash_attn_func = _flash_attention_forward
|
||||
specific_attn_kwargs = {
|
||||
"attention_mask": attention_mask,
|
||||
"position_ids": position_ids,
|
||||
"query_length": query_length,
|
||||
"is_causal": causal,
|
||||
}
|
||||
else:
|
||||
specific_attn_kwargs = {"causal": causal}
|
||||
|
||||
attn_output = flash_attn_func(
|
||||
attn_output = flash_attn_varlen_func(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
**common_attn_kwargs,
|
||||
**specific_attn_kwargs,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_in_batch_q,
|
||||
max_seqlen_k=max_seqlen_in_batch_k,
|
||||
dropout_p=kwargs.pop("dropout", 0.0),
|
||||
softmax_scale=kwargs.pop("softmax_scale", None),
|
||||
causal=is_causal,
|
||||
**flash_kwargs,
|
||||
)
|
||||
attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
|
||||
else:
|
||||
if (
|
||||
flash_attn_varlen_func is None
|
||||
and position_ids is not None
|
||||
and query_length != 1
|
||||
and not (torch.diff(position_ids[0], dim=-1) >= 0).all()
|
||||
):
|
||||
logger.warning_once(
|
||||
"flash_attn_varlen_func is not available; falling back to _flash_attention_forward."
|
||||
"This may be suboptimal for non-monotonic position_ids in VLM mRoPE."
|
||||
)
|
||||
attn_output = _flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
@ -268,16 +244,20 @@ def flash_attention_forward(
|
||||
query_length,
|
||||
is_causal=is_causal,
|
||||
sliding_window=sliding_window,
|
||||
use_top_left_mask=flash_attn_supports_top_left_mask(),
|
||||
use_top_left_mask=use_top_left_mask,
|
||||
deterministic=deterministic,
|
||||
**kwargs,
|
||||
)
|
||||
) # do not pass position_ids to old flash_attention_forward
|
||||
|
||||
if sp_size > 1:
|
||||
# (batch_size, seq_length, num_head, head_size)
|
||||
attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
def ulysses_flash_attn_forward(
|
||||
self,
|
||||
def qwen2_vl_attn_forward(
|
||||
self: "Qwen2VLAttention",
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
@ -295,225 +275,215 @@ def ulysses_flash_attn_forward(
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
|
||||
|
||||
if ulysses_sp_size > 1:
|
||||
validate_ulysses_config(self.num_heads, ulysses_sp_size)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
|
||||
key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
|
||||
value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)
|
||||
# (batch_size, num_head / sp_size, seq_length, head_size)
|
||||
full_q_len = query_states.size(2) # full_q_len = seq_length
|
||||
else:
|
||||
full_q_len = q_len
|
||||
|
||||
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
||||
if position_embeddings is None:
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
|
||||
)
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
dropout_rate = 0.0 if not self.training else self.attention_dropout
|
||||
|
||||
# Reashape to the expected shape for Flash Attention
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
sliding_window = None
|
||||
if (
|
||||
self.config.use_sliding_window
|
||||
and getattr(self.config, "sliding_window", None) is not None
|
||||
and self.layer_idx >= self.config.max_window_layers
|
||||
):
|
||||
sliding_window = self.config.sliding_window
|
||||
else:
|
||||
sliding_window = None
|
||||
|
||||
attn_output = flash_attention_forward(
|
||||
# This is before the transpose
|
||||
q_len = query_states.shape[2]
|
||||
|
||||
# FA2 uses non-transposed inputs
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
if position_ids.ndim == 3:
|
||||
position_ids = position_ids[0]
|
||||
|
||||
attn_output = _custom_flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
full_q_len,
|
||||
query_length=q_len,
|
||||
is_causal=getattr(self, "is_causal", True),
|
||||
dropout=dropout_rate,
|
||||
sliding_window=sliding_window,
|
||||
is_causal=self.is_causal,
|
||||
use_top_left_mask=flash_attn_supports_top_left_mask(),
|
||||
use_top_left_mask=_flash_use_top_left_mask,
|
||||
position_ids=position_ids, # important: pass position ids
|
||||
) # (batch_size, seq_length, num_head / sp_size, head_size)
|
||||
if ulysses_sp_size > 1:
|
||||
attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)
|
||||
|
||||
) # (batch_size, seq_length / sp_size, num_head, head_size)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
if is_transformers_version_in_range(min_version="4.53.0"):
|
||||
if is_transformers_version_in_range(min_version="4.54.0"):
|
||||
return attn_output, None
|
||||
else:
|
||||
return attn_output, None, None
|
||||
|
||||
|
||||
def _get_input_embeds(
|
||||
model: "Qwen2VLForConditionalGeneration",
|
||||
input_ids: torch.LongTensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
if pixel_values is not None:
|
||||
pixel_values = pixel_values.type(model.visual.dtype)
|
||||
image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
n_image_tokens = (input_ids == model.config.image_token_id).sum().item()
|
||||
n_image_features = image_embeds.shape[0]
|
||||
if n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
|
||||
mask = input_ids == model.config.image_token_id
|
||||
mask_unsqueezed = mask.unsqueeze(-1)
|
||||
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
||||
image_mask = mask_expanded.to(inputs_embeds.device)
|
||||
|
||||
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
pixel_values_videos = pixel_values_videos.type(model.visual.dtype)
|
||||
video_embeds = model.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||
n_video_tokens = (input_ids == model.config.video_token_id).sum().item()
|
||||
n_video_features = video_embeds.shape[0]
|
||||
if n_video_tokens != n_video_features:
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
)
|
||||
|
||||
mask = input_ids == model.config.video_token_id
|
||||
mask_unsqueezed = mask.unsqueeze(-1)
|
||||
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
||||
video_mask = mask_expanded.to(inputs_embeds.device)
|
||||
|
||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||
|
||||
if model.training and pixel_values is None and pixel_values_videos is None: # handle mixed text-image data
|
||||
pixel_values = torch.zeros((16, 1176), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
|
||||
image_grid_thw = torch.tensor([[1, 4, 4]], dtype=torch.long, device=inputs_embeds.device)
|
||||
image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
inputs_embeds += 0.0 * image_embeds.mean()
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
return inputs_embeds, attention_mask
|
||||
|
||||
|
||||
def process_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
|
||||
if position_ids.ndim != 3 or position_ids.size(0) != 4:
|
||||
# we concat the text position ids with the 3D vision position ids by default
|
||||
# see https://github.com/huggingface/transformers/pull/39447
|
||||
raise ValueError("position_ids should be a 3D tensor of shape (4, batch_size, seq_length).")
|
||||
|
||||
if is_transformers_version_in_range(max_version="4.53.3"):
|
||||
# transformers < 4.54.0 only accepts vision position ids, so we discard the text position ids here
|
||||
position_ids = position_ids[1:]
|
||||
|
||||
return position_ids
|
||||
|
||||
|
||||
@dataclass
|
||||
class Qwen2VLCausalLMOutputForPPO(Qwen2VLCausalLMOutputWithPast):
|
||||
log_probs: Optional[torch.FloatTensor] = None
|
||||
entropy: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
def forward_base_model(
|
||||
self: Qwen2VLForConditionalGeneration,
|
||||
input_ids: torch.LongTensor = None,
|
||||
def qwen2_vl_base_forward(
|
||||
self: "Qwen2VLForConditionalGeneration",
|
||||
input_ids: torch.LongTensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
pixel_values: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||
rope_deltas: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> tuple | Qwen2VLCausalLMOutputWithPast:
|
||||
r"""
|
||||
Copy paste Qwen2VL's forward
|
||||
https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_vl.py
|
||||
```"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||
if pixel_values is not None:
|
||||
pixel_values = pixel_values.type(self.visual.get_dtype())
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
||||
n_image_features = image_embeds.shape[0]
|
||||
if n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, "
|
||||
f"features {n_image_features}"
|
||||
)
|
||||
image_mask = (
|
||||
(input_ids == self.config.image_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
||||
n_video_features = video_embeds.shape[0]
|
||||
if n_video_tokens != n_video_features:
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, "
|
||||
f"features {n_video_features}"
|
||||
)
|
||||
video_mask = (
|
||||
(input_ids == self.config.video_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
||||
# calculate RoPE index once per generation in the pre-fill stage only
|
||||
if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
|
||||
position_ids, rope_deltas = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
|
||||
self.rope_deltas = rope_deltas
|
||||
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||
else:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
|
||||
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
|
||||
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
|
||||
if cache_position is not None: # otherwise `deltas` is an int `0`
|
||||
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
|
||||
position_ids = position_ids.add(delta)
|
||||
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
||||
|
||||
outputs = self.model(
|
||||
**kwargs,
|
||||
):
|
||||
kwargs["inputs_embeds"], kwargs["attention_mask"] = _get_input_embeds(
|
||||
self, input_ids, attention_mask, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw
|
||||
) # avoid lora module having multiple keyword arguments
|
||||
return self.language_model(
|
||||
input_ids=None,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
def qwen2_vl_forward(
    self: "Qwen2VLForConditionalGeneration",
    input_ids: torch.LongTensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    **kwargs,
):
    if is_transformers_version_in_range(min_version="4.52.0"):
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=process_position_ids(position_ids),
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            **kwargs,
        )
    else:
        inputs_embeds, attention_mask = _get_input_embeds(
            self, input_ids, attention_mask, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw
        )
        return self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=process_position_ids(position_ids),
            inputs_embeds=inputs_embeds,
            **kwargs,
        )
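For orientation, here is a minimal call sketch of `qwen2_vl_forward`. The toy tensors and shapes are assumptions for illustration only; the trainer code below passes position ids with a leading channel dimension of 4 (one text-RoPE row stacked on three mrope rows), so the sketch builds that layout by hand.

```python
# Hypothetical usage sketch; `model` stands in for a patched Qwen2VLForConditionalGeneration.
import torch

batch_size, seq_len = 2, 16
input_ids = torch.randint(0, 1000, (batch_size, seq_len))
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

# One text-RoPE row plus three mrope rows -> (4, batch_size, seq_len).
text_pos = torch.arange(seq_len).view(1, 1, -1).expand(1, batch_size, -1)
mrope_pos = torch.arange(seq_len).view(1, 1, -1).expand(3, batch_size, -1)
position_ids = torch.cat((text_pos, mrope_pos), dim=0)
assert position_ids.shape == (4, batch_size, seq_len)

# outputs = qwen2_vl_forward(model, input_ids, attention_mask=attention_mask, position_ids=position_ids)
```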

def forward_with_normal_backend(
    self: Qwen2VLForConditionalGeneration,
    input_ids: torch.LongTensor = None,
    labels: Optional[torch.LongTensor] = None,
    temperature: float = 1.0,
    **kwargs,
) -> "Qwen2VLCausalLMOutputWithPast":
    outputs = qwen2_vl_forward(self, input_ids, **kwargs)
    hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)

    return Qwen2VLCausalLMOutputWithPast(
        logits=logits,
        hidden_states=outputs.hidden_states,
    )

def forward_with_torch_backend(
    self: Qwen2VLForConditionalGeneration,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    temperature: float = 1.0,
    **loss_kwargs,
    **kwargs,
) -> tuple | Qwen2VLCausalLMOutputForPPO:
    from verl.utils.experimental.torch_functional import FusedLinearForPPO

    outputs = forward_base_model(
        self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        pixel_values=pixel_values,
        pixel_values_videos=pixel_values_videos,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        rope_deltas=rope_deltas,
        cache_position=cache_position,
    )

    outputs = qwen2_vl_forward(self, input_ids, **kwargs)
    hidden_states = outputs[0]

    if not return_dict:
        raise NotImplementedError("forward_with_torch_backend has to return_dict")

    # Loss calculations
    if labels is not None:
        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
@@ -529,64 +499,25 @@ def forward_with_torch_backend(
            input_ids=rolled_labels,
            temperature=temperature,
        )

    return Qwen2VLCausalLMOutputForPPO(
        log_probs=log_probs,
        entropy=entropy,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=rope_deltas,
    )

def forward_with_triton_backend(
    self: Qwen2VLForConditionalGeneration,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    temperature: float = 1.0,
    **loss_kwargs,
    **kwargs,
) -> tuple | Qwen2VLCausalLMOutputForPPO:
    from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy

    outputs = forward_base_model(
        self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        pixel_values=pixel_values,
        pixel_values_videos=pixel_values_videos,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        rope_deltas=rope_deltas,
        cache_position=cache_position,
    )

    outputs = qwen2_vl_forward(self, input_ids, **kwargs)
    hidden_states = outputs[0]

    if not return_dict:
        raise NotImplementedError("forward_with_triton_backend has to return_dict")

    # Loss calculations
    if labels is not None:
        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
@@ -602,12 +533,8 @@ def forward_with_triton_backend(
            temperature,
            "none",
        )

    return Qwen2VLCausalLMOutputForPPO(
        log_probs=log_probs,
        entropy=entropy,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=rope_deltas,
    )

@@ -299,17 +299,18 @@ class RLHFDataset(Dataset):
        if self.processor is not None and "Qwen2VLImageProcessor" in self.processor.image_processor.__class__.__name__:
            from verl.models.transformers.qwen2_vl import get_rope_index

            position_ids = [
                get_rope_index(
                    self.processor,
                    input_ids=input_ids[0],
                    image_grid_thw=model_inputs.get("image_grid_thw"),
                    video_grid_thw=model_inputs.get("video_grid_thw"),
                    second_per_grid_ts=model_inputs.get("second_per_grid_ts"),
                    attention_mask=attention_mask[0],
                )
            ]  # (1, 3, seq_len)

            vision_position_ids = get_rope_index(
                self.processor,
                input_ids=input_ids[0],
                image_grid_thw=model_inputs.get("image_grid_thw"),
                video_grid_thw=model_inputs.get("video_grid_thw"),
                second_per_grid_ts=model_inputs.get("second_per_grid_ts"),
                attention_mask=attention_mask[0],
            )  # (3, seq_length)
            valid_mask = attention_mask[0].bool()
            text_position_ids = torch.ones((1, len(input_ids[0])), dtype=torch.long)
            text_position_ids[0, valid_mask] = torch.arange(valid_mask.sum().item())
            position_ids = [torch.cat((text_position_ids, vision_position_ids), dim=0)]  # (1, 4, seq_length)
        else:
            position_ids = compute_position_id_with_mask(attention_mask)
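As a quick check of the dataset-side shapes, here is a standalone toy sketch (the values and the dummy mrope tensor are assumptions) of stacking a text-RoPE row on a (3, seq_len) vision position-id tensor to get the (4, seq_len) layout that the dataset then wraps into a length-1 list.

```python
import torch

# Toy example: seq_len = 6 with two left-padding positions (assumed values).
attention_mask = torch.tensor([0, 0, 1, 1, 1, 1])
seq_len = attention_mask.numel()

# Dummy stand-in for the (3, seq_len) mrope tensor returned by get_rope_index.
vision_position_ids = torch.arange(seq_len).expand(3, -1)

# Text-RoPE row: ones on padding, 0..n_valid-1 on valid tokens, shape (1, seq_len).
valid_mask = attention_mask.bool()
text_position_ids = torch.ones((1, seq_len), dtype=torch.long)
text_position_ids[0, valid_mask] = torch.arange(valid_mask.sum().item())

position_ids = torch.cat((text_position_ids, vision_position_ids), dim=0)
print(position_ids.shape)  # torch.Size([4, 6])
```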

@@ -16,6 +16,12 @@
Compatibility utilities for different versions of transformers library.
"""

import importlib.metadata
from functools import lru_cache
from typing import Optional

from packaging import version

# Handle version compatibility for flash_attn_supports_top_left_mask
# This function was added in newer versions of transformers
try:
@@ -28,3 +34,24 @@ except ImportError:
        Returns False to disable features that require this function.
        """
        return False


@lru_cache
def is_transformers_version_in_range(min_version: Optional[str] = None, max_version: Optional[str] = None) -> bool:
    try:
        # Get the installed version of the transformers library
        transformers_version_str = importlib.metadata.version("transformers")
    except importlib.metadata.PackageNotFoundError as e:
        raise ModuleNotFoundError("The `transformers` package is not installed.") from e

    transformers_version = version.parse(transformers_version_str)

    lower_bound_check = True
    if min_version is not None:
        lower_bound_check = version.parse(min_version) <= transformers_version

    upper_bound_check = True
    if max_version is not None:
        upper_bound_check = transformers_version <= version.parse(max_version)

    return lower_bound_check and upper_bound_check
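A short usage sketch of this helper; the version bounds below are illustrative examples, not prescriptions.

```python
# Illustrative gating on the installed transformers version (inclusive bounds).
if is_transformers_version_in_range(min_version="4.52.0"):
    # take the newer code path
    ...
elif is_transformers_version_in_range(max_version="4.51.3"):
    # fall back to the legacy code path
    ...
```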

@@ -288,7 +288,7 @@ def ulysses_pad(input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torc
    if position_ids_rmpad is not None:
        pad_pos_ids = torch.arange(pad_size, device=position_ids_rmpad.device).unsqueeze(0)
        if position_ids_rmpad.dim() == 3:
            pad_pos_ids = pad_pos_ids.unsqueeze(0).repeat(3, 1, 1)
            pad_pos_ids = pad_pos_ids.unsqueeze(0).repeat(position_ids_rmpad.size(0), 1, 1)
        position_ids_rmpad = torch.cat((position_ids_rmpad, pad_pos_ids), dim=-1)
    return input_ids_rmpad, position_ids_rmpad, pad_size
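The hunk above generalizes the pad from a hard-coded 3 mrope channels to however many channels the position ids carry. A standalone sketch of the same idea, with assumed toy shapes:

```python
import torch

pad_size = 2
# Assumed shape after removing padding: (channels, 1, total_tokens), e.g. 4 channels for Qwen2-VL.
position_ids_rmpad = torch.arange(6).view(1, 1, -1).expand(4, 1, -1).contiguous()

pad_pos_ids = torch.arange(pad_size).unsqueeze(0)  # (1, pad_size)
if position_ids_rmpad.dim() == 3:
    # repeat across however many position-id channels are present
    pad_pos_ids = pad_pos_ids.unsqueeze(0).repeat(position_ids_rmpad.size(0), 1, 1)  # (4, 1, pad_size)

position_ids_rmpad = torch.cat((position_ids_rmpad, pad_pos_ids), dim=-1)
print(position_ids_rmpad.shape)  # torch.Size([4, 1, 8])
```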

@@ -114,7 +114,7 @@ class DataParallelPPOActor(BasePPOActor):
            position_ids = micro_batch["position_ids"]
            entropy = None
            if position_ids.dim() == 3:  # qwen2vl mrope
                position_ids = position_ids.transpose(0, 1)  # (bsz, 3, seqlen) -> (3, bsz, seqlen)
                position_ids = position_ids.transpose(0, 1)  # (bsz, 4, seqlen) -> (4, bsz, seqlen)

            if self.use_remove_padding:
                input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input(
@@ -128,7 +128,7 @@ class DataParallelPPOActor(BasePPOActor):
                        index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
                        .transpose(0, 1)
                        .unsqueeze(1)
                    )  # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
                    )  # (4, bsz, seqlen) -> (4, 1, bsz * seqlen)
                else:
                    position_ids_rmpad = index_first_axis(
                        rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices

@@ -83,7 +83,7 @@ class DataParallelPPOCritic(BasePPOCritic):
                        index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
                        .transpose(0, 1)
                        .unsqueeze(1)
                    )  # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
                    )  # (4, bsz, seqlen) -> (4, 1, bsz * seqlen)
                else:
                    position_ids_rmpad = index_first_axis(
                        rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
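To make the actor/critic shape flow concrete, here is a self-contained sketch (toy tensors, einops only, omitting the flash-attn `index_first_axis` gather of unpadded tokens) of how a (bsz, 4, seqlen) position-id batch is transposed and flattened for the remove-padding path:

```python
import torch
from einops import rearrange

bsz, channels, seqlen = 2, 4, 5
position_ids = torch.arange(seqlen).expand(bsz, channels, seqlen)  # (bsz, 4, seqlen)

position_ids = position_ids.transpose(0, 1)                 # (4, bsz, seqlen)
flat = rearrange(position_ids, "c b s ... -> (b s) c ...")  # (bsz * seqlen, 4)
# index_first_axis(flat, indices) would keep only the unpadded token rows here.
position_ids_rmpad = flat.transpose(0, 1).unsqueeze(1)      # (4, 1, bsz * seqlen)
print(position_ids_rmpad.shape)  # torch.Size([4, 1, 10])
```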

@@ -755,8 +755,8 @@ class SGLangRollout(BaseRollout):
        response_length = response.size(1)
        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
        delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)
        if position_ids.dim() == 3:  # qwen2vl mrope
            delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)
        if position_ids.dim() == 3:  # qwen2vl mrope (batch size, 4, seq len)
            delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, position_ids.size(1), -1)

        # TODO(sgm): fix position_ids on right_pad
        # prompt: left pad + response: right pad

@@ -372,8 +372,8 @@ class vLLMRollout(BaseRollout):
        response_length = response.size(1)
        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
        delta_position_id = delta_position_id.unsqueeze(0).expand(batch_size, -1)
        if position_ids.dim() == 3:  # qwen2vl mrope
            delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)
        if position_ids.dim() == 3:  # qwen2vl mrope (batch size, 4, seq len)
            delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, position_ids.size(1), -1)

        # TODO(sgm): fix position_ids on right_pad
        # prompt: left pad + response: right pad
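A small standalone sketch (toy shapes and variable names are assumptions) of how the response deltas above extend the prompt position ids when the channel dimension is 4:

```python
import torch

batch_size, channels, prompt_len, response_length = 2, 4, 5, 3
position_ids = torch.arange(prompt_len).expand(batch_size, channels, prompt_len)  # (bsz, 4, prompt_len)

delta_position_id = torch.arange(1, response_length + 1)
delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)  # (bsz, response_len)
if position_ids.dim() == 3:  # qwen2vl mrope (batch size, 4, seq len)
    delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, position_ids.size(1), -1)

# Response positions continue from the last prompt position of each channel.
response_position_ids = position_ids[..., -1:] + delta_position_id            # (bsz, 4, response_len)
full_position_ids = torch.cat([position_ids, response_position_ids], dim=-1)  # (bsz, 4, prompt_len + response_len)
print(full_position_ids.shape)  # torch.Size([2, 4, 8])
```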