mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[feat]: Create interface for model-specific M-RoPE (#24194)
Signed-off-by: AzizCode92 <azizbenothman76@gmail.com> Signed-off-by: Aziz <azizbenothman76@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@ -1,10 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
|
||||
SupportsPP, SupportsTranscription, SupportsV0Only,
|
||||
has_inner_state, supports_lora, supports_multimodal,
|
||||
supports_pp, supports_transcription, supports_v0_only)
|
||||
from .interfaces import (HasInnerState, SupportsLoRA, SupportsMRoPE,
|
||||
SupportsMultiModal, SupportsPP, SupportsTranscription,
|
||||
SupportsV0Only, has_inner_state, supports_lora,
|
||||
supports_mrope, supports_multimodal, supports_pp,
|
||||
supports_transcription, supports_v0_only)
|
||||
from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration,
|
||||
is_pooling_model, is_text_generation_model)
|
||||
from .registry import ModelRegistry
|
||||
@ -21,6 +22,8 @@ __all__ = [
|
||||
"supports_lora",
|
||||
"SupportsMultiModal",
|
||||
"supports_multimodal",
|
||||
"SupportsMRoPE",
|
||||
"supports_mrope",
|
||||
"SupportsPP",
|
||||
"supports_pp",
|
||||
"SupportsTranscription",
|
||||
|
@ -8,6 +8,7 @@ from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.models.whisper.tokenization_whisper import LANGUAGES
|
||||
from typing_extensions import Self, TypeIs
|
||||
|
||||
@ -852,3 +853,70 @@ def supports_eagle3(
|
||||
model: Union[type[object], object],
|
||||
) -> Union[TypeIs[type[SupportsEagle3]], TypeIs[SupportsEagle3]]:
|
||||
return isinstance(model, SupportsEagle3)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SupportsMRoPE(Protocol):
|
||||
"""The interface required for all models that support M-RoPE."""
|
||||
|
||||
supports_mrope: ClassVar[Literal[True]] = True
|
||||
"""
|
||||
A flag that indicates this model supports M-RoPE.
|
||||
|
||||
Note:
|
||||
There is no need to redefine this flag if this class is in the
|
||||
MRO of your model class.
|
||||
"""
|
||||
|
||||
def get_mrope_input_positions(
|
||||
self,
|
||||
input_tokens: list[int],
|
||||
hf_config: PretrainedConfig,
|
||||
image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
|
||||
video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
|
||||
second_per_grid_ts: Optional[list[float]] = None,
|
||||
context_len: int = 0,
|
||||
seq_len: Optional[int] = None,
|
||||
audio_feature_lengths: Optional[torch.Tensor] = None,
|
||||
use_audio_in_video: bool = False,
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
"""
|
||||
Get M-RoPE input positions and delta value for this specific model.
|
||||
|
||||
This method should be implemented by each model that supports M-RoPE
|
||||
to provide model-specific logic for computing input positions.
|
||||
|
||||
Args:
|
||||
input_tokens: List of input token IDs
|
||||
hf_config: HuggingFace model configuration
|
||||
image_grid_thw: Image grid dimensions (t, h, w)
|
||||
video_grid_thw: Video grid dimensions (t, h, w)
|
||||
second_per_grid_ts: Seconds per grid timestep for videos
|
||||
context_len: Context length
|
||||
seq_len: Sequence length
|
||||
audio_feature_lengths: Audio feature lengths for multimodal models
|
||||
use_audio_in_video: Whether to use audio in video for interleaving
|
||||
|
||||
Returns:
|
||||
Tuple of (llm_positions, mrope_position_delta)
|
||||
- llm_positions: Tensor of shape [3, num_tokens]
|
||||
with T/H/W positions
|
||||
- mrope_position_delta: Delta for position calculations
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def supports_mrope(model: type[object]) -> TypeIs[type[SupportsMRoPE]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def supports_mrope(model: object) -> TypeIs[SupportsMRoPE]:
|
||||
...
|
||||
|
||||
|
||||
def supports_mrope(
|
||||
model: Union[type[object], object],
|
||||
) -> Union[TypeIs[type[SupportsMRoPE]], TypeIs[SupportsMRoPE]]:
|
||||
return isinstance(model, SupportsMRoPE)
|
||||
|
@ -32,7 +32,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from transformers import AutoConfig, BatchFeature
|
||||
from transformers import AutoConfig, BatchFeature, PretrainedConfig
|
||||
from transformers.models.qwen2_vl import (Qwen2VLImageProcessor,
|
||||
Qwen2VLProcessor)
|
||||
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
|
||||
@ -73,7 +73,7 @@ from vllm.transformers_utils.config import uses_mrope
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMRoPE,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
@ -1096,7 +1096,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
|
||||
info=Qwen2VLProcessingInfo,
|
||||
dummy_inputs=Qwen2VLDummyInputsBuilder)
|
||||
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsLoRA, SupportsPP):
|
||||
SupportsLoRA, SupportsPP, SupportsMRoPE):
|
||||
|
||||
# To ensure correct weight loading and mapping.
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
@ -1109,6 +1109,118 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
"model.": "language_model.model.",
|
||||
})
|
||||
|
||||
def get_mrope_input_positions(
|
||||
self,
|
||||
input_tokens: list[int],
|
||||
hf_config: PretrainedConfig,
|
||||
image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
|
||||
video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
|
||||
second_per_grid_ts: Optional[list[float]] = None,
|
||||
context_len: int = 0,
|
||||
seq_len: Optional[int] = None,
|
||||
audio_feature_lengths: Optional[torch.Tensor] = None,
|
||||
use_audio_in_video: bool = False,
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
"""Get M-RoPE input positions for Qwen2-VL model."""
|
||||
if image_grid_thw is None:
|
||||
image_grid_thw = []
|
||||
if video_grid_thw is None:
|
||||
video_grid_thw = []
|
||||
if second_per_grid_ts is None:
|
||||
second_per_grid_ts = []
|
||||
|
||||
image_token_id = hf_config.image_token_id
|
||||
video_token_id = hf_config.video_token_id
|
||||
vision_start_token_id = hf_config.vision_start_token_id
|
||||
spatial_merge_size = hf_config.vision_config.spatial_merge_size
|
||||
tokens_per_second = getattr(hf_config.vision_config,
|
||||
"tokens_per_second", 1.0)
|
||||
|
||||
input_tokens_tensor = torch.tensor(input_tokens)
|
||||
vision_start_indices = torch.argwhere(
|
||||
input_tokens_tensor == vision_start_token_id).squeeze(1)
|
||||
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
|
||||
image_nums = (vision_tokens == image_token_id).sum()
|
||||
video_nums = (vision_tokens == video_token_id).sum()
|
||||
llm_pos_ids_list: list = []
|
||||
|
||||
st = 0
|
||||
remain_images, remain_videos = image_nums, video_nums
|
||||
|
||||
image_index, video_index = 0, 0
|
||||
for _ in range(image_nums + video_nums):
|
||||
video_second_per_grid_t = 0.0
|
||||
if remain_images > 0:
|
||||
try:
|
||||
ed_image = input_tokens.index(image_token_id, st)
|
||||
except ValueError:
|
||||
ed_image = len(input_tokens) + 1
|
||||
else:
|
||||
ed_image = len(input_tokens) + 1
|
||||
if remain_videos > 0:
|
||||
try:
|
||||
ed_video = input_tokens.index(video_token_id, st)
|
||||
except ValueError:
|
||||
ed_video = len(input_tokens) + 1
|
||||
else:
|
||||
ed_video = len(input_tokens) + 1
|
||||
if ed_image < ed_video:
|
||||
t, h, w = (
|
||||
image_grid_thw[image_index][0],
|
||||
image_grid_thw[image_index][1],
|
||||
image_grid_thw[image_index][2],
|
||||
)
|
||||
image_index += 1
|
||||
remain_images -= 1
|
||||
ed = ed_image
|
||||
else:
|
||||
t, h, w = (
|
||||
video_grid_thw[video_index][0],
|
||||
video_grid_thw[video_index][1],
|
||||
video_grid_thw[video_index][2],
|
||||
)
|
||||
video_second_per_grid_t = 1.0
|
||||
if second_per_grid_ts:
|
||||
video_second_per_grid_t = second_per_grid_ts[video_index]
|
||||
video_index += 1
|
||||
remain_videos -= 1
|
||||
ed = ed_video
|
||||
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = \
|
||||
t, h // spatial_merge_size, w // spatial_merge_size
|
||||
text_len = ed - st
|
||||
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
|
||||
llm_pos_ids_list) > 0 else 0
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
||||
|
||||
t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
|
||||
-1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
|
||||
tokens_per_second).long().flatten()
|
||||
|
||||
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
|
||||
llm_grid_t, -1, llm_grid_w).flatten()
|
||||
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
|
||||
llm_grid_t, llm_grid_h, -1).flatten()
|
||||
llm_pos_ids_list.append(
|
||||
torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
|
||||
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
||||
|
||||
if st < len(input_tokens):
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
|
||||
llm_pos_ids_list) > 0 else 0
|
||||
text_len = len(input_tokens) - st
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
mrope_position_delta = (llm_positions.max() + 1 -
|
||||
len(input_tokens)).item()
|
||||
llm_positions = llm_positions[:, context_len:seq_len]
|
||||
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
|
@ -42,6 +42,7 @@ from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
|
||||
from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
|
||||
supports_eagle3,
|
||||
supports_mrope,
|
||||
supports_transcription)
|
||||
from vllm.model_executor.models.interfaces_base import (
|
||||
VllmModelForPooling, is_pooling_model, is_text_generation_model)
|
||||
@ -730,16 +731,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
if mm_input.get("use_audio_in_video") is True:
|
||||
use_audio_in_video = True
|
||||
|
||||
req_state.mrope_positions, req_state.mrope_position_delta = \
|
||||
MRotaryEmbedding.get_input_positions_tensor(
|
||||
req_state.prompt_token_ids,
|
||||
hf_config=self.model_config.hf_config,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
use_audio_in_video=use_audio_in_video,
|
||||
)
|
||||
if supports_mrope(self.model):
|
||||
req_state.mrope_positions, req_state.mrope_position_delta = \
|
||||
self.model.get_mrope_input_positions(
|
||||
req_state.prompt_token_ids,
|
||||
hf_config=self.model_config.hf_config,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
use_audio_in_video=use_audio_in_video,
|
||||
)
|
||||
else:
|
||||
req_state.mrope_positions, req_state.mrope_position_delta = \
|
||||
MRotaryEmbedding.get_input_positions_tensor(
|
||||
req_state.prompt_token_ids,
|
||||
hf_config=self.model_config.hf_config,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
use_audio_in_video=use_audio_in_video,
|
||||
)
|
||||
|
||||
def _extract_mm_kwargs(
|
||||
self,
|
||||
|
@ -41,7 +41,8 @@ from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput,
|
||||
get_sampler)
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
from vllm.model_executor.models import supports_lora, supports_multimodal
|
||||
from vllm.model_executor.models import (supports_lora, supports_mrope,
|
||||
supports_multimodal)
|
||||
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
|
||||
MultiModalKwargs, MultiModalPlaceholderMap,
|
||||
@ -670,18 +671,33 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
inter_data.seq_ids[seq_idx]]
|
||||
token_ids = seq_data.get_token_ids()
|
||||
|
||||
mrope_input_positions, mrope_position_delta = \
|
||||
MRotaryEmbedding.get_input_positions(
|
||||
token_ids,
|
||||
hf_config=hf_config,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
context_len=inter_data.context_lens[seq_idx],
|
||||
seq_len=inter_data.seq_lens[seq_idx],
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
use_audio_in_video=use_audio_in_video,
|
||||
)
|
||||
if supports_mrope(self.runner.model):
|
||||
mrope_input_positions, mrope_position_delta = \
|
||||
self.runner.model.get_mrope_input_positions(
|
||||
token_ids,
|
||||
hf_config=hf_config,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
context_len=inter_data.context_lens[seq_idx],
|
||||
seq_len=inter_data.seq_lens[seq_idx],
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
use_audio_in_video=use_audio_in_video,
|
||||
)
|
||||
mrope_input_positions = mrope_input_positions.tolist()
|
||||
else:
|
||||
mrope_input_positions, mrope_position_delta = \
|
||||
MRotaryEmbedding.get_input_positions(
|
||||
token_ids,
|
||||
hf_config=hf_config,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
context_len=inter_data.context_lens[seq_idx],
|
||||
seq_len=inter_data.seq_lens[seq_idx],
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
use_audio_in_video=use_audio_in_video,
|
||||
)
|
||||
|
||||
seq_data.mrope_position_delta = mrope_position_delta
|
||||
inter_data.mrope_input_positions[
|
||||
|
Reference in New Issue
Block a user