[V0 Deprecation] Remove V0 logic from get_input_embeddings interface (#25242)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-09-19 19:10:52 +08:00 (committed via GitHub)
Parent: a3d087adec
Commit: 5089fd749c

4 changed files with 21 additions and 83 deletions

vllm/model_executor/models/hyperclovax_vision.py

@@ -46,7 +46,8 @@ from vllm.sequence import IntermediateTensors
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
+from .utils import (AutoWeightsLoader, init_vllm_registered_model,
+                    maybe_prefix, merge_multimodal_embeddings)
 from .vision import get_vision_encoder_info
 
 EOT = "<|endofturn|>"
@@ -740,33 +741,20 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         self,
         input_ids: torch.Tensor,
         multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
-        **kwargs,
     ) -> torch.Tensor:
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
-        if (kwargs.get("pixel_values_images") is not None
-                or kwargs.get("pixel_values_videos")
-                is not None):  # v0 compatibility
-            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
-        if multimodal_embeddings is not None:
-            multimodal_embeddings = torch.cat(multimodal_embeddings, dim=0)
-            _mask_image = input_ids == self.config.image_token_id
-            _mask_video = input_ids == self.config.video_token_id
-            assert _mask_image.sum() + _mask_video.sum() == len(
-                multimodal_embeddings)
-            if multimodal_embeddings.dtype != inputs_embeds.dtype:
-                multimodal_embeddings = multimodal_embeddings.to(
-                    dtype=inputs_embeds.dtype)
-            if multimodal_embeddings.device != inputs_embeds.device:
-                multimodal_embeddings = multimodal_embeddings.to(
-                    device=inputs_embeds.device)
-            if _mask_image.sum() > 0:
-                inputs_embeds[
-                    _mask_image] = multimodal_embeddings[:sum(_mask_image)]
-            if _mask_video.sum() > 0:
-                inputs_embeds[_mask_video] = multimodal_embeddings[
-                    -sum(_mask_video):]
+        if multimodal_embeddings is not None \
+            and len(multimodal_embeddings) != 0:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                multimodal_embeddings,
+                placeholder_token_id=[
+                    self.config.image_token_id,
+                    self.config.video_token_id,
+                ],
+            )
         return inputs_embeds
 
     def forward(
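The hand-rolled merge removed above (concatenate, build per-modality masks, cast dtype/device, slice-assign) is what the shared merge_multimodal_embeddings helper now does in one place. A minimal sketch of the scatter it performs, with hypothetical token ids and shapes; the real helper in vllm/model_executor/models/utils.py additionally validates counts and flattens nested embeddings:

import torch

def merge_sketch(input_ids: torch.Tensor, inputs_embeds: torch.Tensor,
                 mm_embeds: torch.Tensor,
                 placeholder_ids: list[int]) -> torch.Tensor:
    # One True per position holding any placeholder token (image or video).
    is_mm = torch.isin(input_ids,
                       torch.tensor(placeholder_ids, device=input_ids.device))
    # Rows of mm_embeds fill the placeholder slots in order; the cast covers
    # the dtype/device handling the old inline code did by hand.
    inputs_embeds[is_mm] = mm_embeds.to(dtype=inputs_embeds.dtype,
                                        device=inputs_embeds.device)
    return inputs_embeds

# Hypothetical usage: ids 32001/32002 stand in for image/video placeholders.
ids = torch.tensor([1, 32001, 32002])
out = merge_sketch(ids, torch.zeros(3, 8), torch.randn(2, 8), [32001, 32002])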
@@ -783,8 +771,9 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         # NOTE: In v1, inputs_embeds is always generated at model runner, this
         # condition is for v0 compatibility.
         elif inputs_embeds is None:
-            inputs_embeds = self.get_input_embeddings(input_ids=input_ids,
-                                                      **kwargs)
+            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
+            inputs_embeds = self.get_input_embeddings(input_ids,
+                                                      multimodal_embeddings)
             input_ids = None
         hidden_states = self.language_model.model(input_ids,
                                                   positions,
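With the v0 kwargs plumbing gone, the compatibility path in forward() follows the same contract as the v1 model runner: derive the multimodal embeddings first, merge them into the text embeddings, then drop input_ids. A toy, self-contained illustration of that ordering (the embedding table and token id are made up):

import torch

embed = torch.nn.Embedding(10, 4)        # stand-in language-model embedding
IMAGE_TOKEN_ID = 3                       # hypothetical placeholder id

input_ids = torch.tensor([1, IMAGE_TOKEN_ID, 2])
mm_embeds = torch.randn(1, 4)            # what get_multimodal_embeddings returns

with torch.no_grad():
    inputs_embeds = embed(input_ids)                        # 1: text embeddings
    inputs_embeds[input_ids == IMAGE_TOKEN_ID] = mm_embeds  # 2: merge
input_ids = None                                            # 3: embeds-only from here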

vllm/model_executor/models/interfaces.py

@@ -23,7 +23,6 @@ from vllm.utils import supports_kw
 from .interfaces_base import is_pooling_model
 
 if TYPE_CHECKING:
-    from vllm.attention import AttentionMetadata
     from vllm.config import VllmConfig
     from vllm.model_executor.models.utils import WeightsMapper
     from vllm.sequence import IntermediateTensors
@@ -97,33 +96,10 @@ class SupportsMultiModal(Protocol):
         """
         ...
 
-    # Only for models that support v0 chunked prefill
-    # TODO(ywang96): Remove this overload once v0 is deprecated
-    @overload
-    def get_input_embeddings(
-        self,
-        input_ids: Tensor,
-        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
-        attn_metadata: Optional["AttentionMetadata"] = None,
-    ) -> Tensor:
-        ...
-
-    # TODO: Remove this overload once v0 is deprecated
-    @overload
-    def get_input_embeddings(
-        self,
-        input_ids: Tensor,
-        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
-    ) -> Tensor:
-        ...
-
     def get_input_embeddings(
         self,
         input_ids: Tensor,
         multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
-        # Only necessary so that the v0 overload is valid
-        # TODO: Remove attn_metadata once v0 is deprecated
-        attn_metadata: Optional["AttentionMetadata"] = None,
     ) -> Tensor:
         """
         Returns the input embeddings merged from the text embeddings from
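After the deletion, the interface is a single un-overloaded method with no attn_metadata parameter. Reconstructed standalone for illustration (MultiModalEmbeddings is simplified here to a tuple of tensors; in vLLM it is a broader nested-tensor type):

from typing import Optional, Protocol

import torch

MultiModalEmbeddings = tuple[torch.Tensor, ...]  # simplified stand-in

class SupportsMultiModal(Protocol):

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Merge multimodal embeddings into the text token embeddings."""
        ...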

vllm/model_executor/models/ultravox.py

@@ -13,9 +13,7 @@ from transformers import BatchFeature, ProcessorMixin
 from transformers.models.whisper import WhisperFeatureExtractor
 from transformers.models.whisper.modeling_whisper import WhisperEncoder
 
-from vllm import envs
 from vllm.config import VllmConfig
-from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.model_loader import DefaultModelLoader
@@ -37,8 +35,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix,
-                    merge_multimodal_embeddings,
-                    merge_multimodal_embeddings_from_map)
+                    merge_multimodal_embeddings)
 
 _AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
 _MAX_ENCODER_BATCH_SIZE = 16
@@ -568,17 +565,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
             safe_input_ids)
         if multimodal_embeddings is not None and len(
                 multimodal_embeddings) > 0:
-            # TODO(ywang96): remove this block after v0 is deprecated.
-            if not envs.VLLM_USE_V1:
-                attn_metadata = get_forward_context().attn_metadata
-                merge_multimodal_embeddings_from_map(
-                    inputs_embeds, multimodal_embeddings,
-                    attn_metadata.multi_modal_placeholder_index_maps["audio"])
-            else:
-                inputs_embeds = merge_multimodal_embeddings(
-                    input_ids, inputs_embeds, multimodal_embeddings,
-                    self.config.audio_token_index)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, multimodal_embeddings,
+                self.config.audio_token_index)
         return inputs_embeds
 
     def forward(self,
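Ultravox keeps only the token-id path: merge_multimodal_embeddings flattens the per-audio embedding chunks and scatters them over the positions holding config.audio_token_index, with no attention-metadata index map. A small sketch of that flatten-then-scatter, with made-up token id and shapes:

import torch

AUDIO_TOKEN_ID = 9                                 # hypothetical audio_token_index
chunks = [torch.randn(3, 4), torch.randn(2, 4)]    # two audio clips: 3 + 2 frames
flat = torch.cat(chunks, dim=0)                    # (5, 4), in clip order

input_ids = torch.tensor([1, 9, 9, 9, 2, 9, 9, 8])  # five audio placeholders
inputs_embeds = torch.zeros(8, 4)

mask = input_ids == AUDIO_TOKEN_ID
assert int(mask.sum()) == flat.shape[0]            # count check the helper enforces
inputs_embeds[mask] = flat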

vllm/model_executor/models/utils.py

@@ -15,7 +15,7 @@ import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
+from vllm.multimodal import NestedTensors
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available,
                         is_uva_available)
@@ -389,22 +389,6 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:
         _embedding_count_expression(inner) for inner in embeddings)
 
-def merge_multimodal_embeddings_from_map(
-        inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors,
-        placeholder_map: MultiModalPlaceholderMap.IndexMap) -> torch.Tensor:
-    """
-    Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided
-    placeholder map.
-
-    Note:
-        This updates ``inputs_embeds`` in place.
-    """
-    flattened_embeddings = _flatten_embeddings(multimodal_embeddings)
-    inputs_embeds[placeholder_map.dest] = flattened_embeddings[
-        placeholder_map.src].to(dtype=inputs_embeds.dtype)
-    return inputs_embeds
-
 def _merge_multimodal_embeddings(
     inputs_embeds: torch.Tensor,
     is_multimodal: torch.Tensor,
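For reference, the deleted helper relied on an explicit src/dest index map supplied by v0 attention metadata rather than on placeholder token ids. A standalone reconstruction of that map-based merge (IndexMap is a simplified stand-in for MultiModalPlaceholderMap.IndexMap):

from dataclasses import dataclass

import torch

@dataclass
class IndexMap:
    src: list[int]    # rows to read from the flattened multimodal embeddings
    dest: list[int]   # positions to overwrite inside inputs_embeds

def merge_from_map(inputs_embeds: torch.Tensor, flat_mm: torch.Tensor,
                   index_map: IndexMap) -> torch.Tensor:
    # In-place scatter, as merge_multimodal_embeddings_from_map did.
    inputs_embeds[index_map.dest] = flat_mm[index_map.src].to(
        dtype=inputs_embeds.dtype)
    return inputs_embeds

# Hypothetical usage: write mm rows 0..1 into sequence positions 1 and 2.
out = merge_from_map(torch.zeros(4, 8), torch.randn(2, 8), IndexMap([0, 1], [1, 2]))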