mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[V0 Deprecation] Remove V0 logic from get_input_embeddings
interface (#25242)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@ -46,7 +46,8 @@ from vllm.sequence import IntermediateTensors
|
||||
from .clip import CLIPVisionModel
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
|
||||
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from .vision import get_vision_encoder_info
|
||||
|
||||
EOT = "<|endofturn|>"
|
||||
@ -740,33 +741,20 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if (kwargs.get("pixel_values_images") is not None
|
||||
or kwargs.get("pixel_values_videos")
|
||||
is not None): # v0 compatibility
|
||||
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
if multimodal_embeddings is not None:
|
||||
multimodal_embeddings = torch.cat(multimodal_embeddings, dim=0)
|
||||
_mask_image = input_ids == self.config.image_token_id
|
||||
_mask_video = input_ids == self.config.video_token_id
|
||||
assert _mask_image.sum() + _mask_video.sum() == len(
|
||||
multimodal_embeddings)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
placeholder_token_id=[
|
||||
self.config.image_token_id,
|
||||
self.config.video_token_id,
|
||||
],
|
||||
)
|
||||
|
||||
if multimodal_embeddings.dtype != inputs_embeds.dtype:
|
||||
multimodal_embeddings = multimodal_embeddings.to(
|
||||
dtype=inputs_embeds.dtype)
|
||||
if multimodal_embeddings.device != inputs_embeds.device:
|
||||
multimodal_embeddings = multimodal_embeddings.to(
|
||||
device=inputs_embeds.device)
|
||||
|
||||
if _mask_image.sum() > 0:
|
||||
inputs_embeds[
|
||||
_mask_image] = multimodal_embeddings[:sum(_mask_image)]
|
||||
if _mask_video.sum() > 0:
|
||||
inputs_embeds[_mask_video] = multimodal_embeddings[
|
||||
-sum(_mask_video):]
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
@ -783,8 +771,9 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings(input_ids=input_ids,
|
||||
**kwargs)
|
||||
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
multimodal_embeddings)
|
||||
input_ids = None
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
positions,
|
||||
|
@ -23,7 +23,6 @@ from vllm.utils import supports_kw
|
||||
from .interfaces_base import is_pooling_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@ -97,33 +96,10 @@ class SupportsMultiModal(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
# Only for models that support v0 chunked prefill
|
||||
# TODO(ywang96): Remove this overload once v0 is deprecated
|
||||
@overload
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
attn_metadata: Optional["AttentionMetadata"] = None,
|
||||
) -> Tensor:
|
||||
...
|
||||
|
||||
# TODO: Remove this overload once v0 is deprecated
|
||||
@overload
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> Tensor:
|
||||
...
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
# Only necessary so that the v0 overload is valid
|
||||
# TODO: Remove attn_metadata once v0 is deprecated
|
||||
attn_metadata: Optional["AttentionMetadata"] = None,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Returns the input embeddings merged from the text embeddings from
|
||||
|
@ -13,9 +13,7 @@ from transformers import BatchFeature, ProcessorMixin
|
||||
from transformers.models.whisper import WhisperFeatureExtractor
|
||||
from transformers.models.whisper.modeling_whisper import WhisperEncoder
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.model_loader import DefaultModelLoader
|
||||
@ -37,8 +35,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings,
|
||||
merge_multimodal_embeddings_from_map)
|
||||
merge_multimodal_embeddings)
|
||||
|
||||
_AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
|
||||
_MAX_ENCODER_BATCH_SIZE = 16
|
||||
@ -568,17 +565,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
|
||||
safe_input_ids)
|
||||
if multimodal_embeddings is not None and len(
|
||||
multimodal_embeddings) > 0:
|
||||
|
||||
# TODO(ywang96): remove this block after v0 is deprecated.
|
||||
if not envs.VLLM_USE_V1:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
merge_multimodal_embeddings_from_map(
|
||||
inputs_embeds, multimodal_embeddings,
|
||||
attn_metadata.multi_modal_placeholder_index_maps["audio"])
|
||||
else:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, multimodal_embeddings,
|
||||
self.config.audio_token_index)
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, multimodal_embeddings,
|
||||
self.config.audio_token_index)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(self,
|
||||
|
@ -15,7 +15,7 @@ import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
|
||||
from vllm.multimodal import NestedTensors
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available,
|
||||
is_uva_available)
|
||||
@ -389,22 +389,6 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:
|
||||
_embedding_count_expression(inner) for inner in embeddings)
|
||||
|
||||
|
||||
def merge_multimodal_embeddings_from_map(
|
||||
inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors,
|
||||
placeholder_map: MultiModalPlaceholderMap.IndexMap) -> torch.Tensor:
|
||||
"""
|
||||
Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided
|
||||
placeholder map .
|
||||
|
||||
Note:
|
||||
This updates ``inputs_embeds`` in place.
|
||||
"""
|
||||
flattened_embeddings = _flatten_embeddings(multimodal_embeddings)
|
||||
inputs_embeds[placeholder_map.dest] = flattened_embeddings[
|
||||
placeholder_map.src].to(dtype=inputs_embeds.dtype)
|
||||
return inputs_embeds
|
||||
|
||||
|
||||
def _merge_multimodal_embeddings(
|
||||
inputs_embeds: torch.Tensor,
|
||||
is_multimodal: torch.Tensor,
|
||||
|
Reference in New Issue
Block a user