mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Core][VLM] Stack multimodal tensors to represent multiple images within each prompt (#7902)
This commit is contained in:
@ -45,8 +45,6 @@ Base Classes
|
||||
|
||||
.. autodata:: vllm.multimodal.NestedTensors
|
||||
|
||||
.. autodata:: vllm.multimodal.BatchedTensors
|
||||
|
||||
.. autodata:: vllm.multimodal.BatchedTensorInputs
|
||||
|
||||
.. autoclass:: vllm.multimodal.MultiModalDataBuiltins
|
||||
|
83
tests/multimodal/test_base.py
Normal file
83
tests/multimodal/test_base.py
Normal file
@ -0,0 +1,83 @@
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.base import MultiModalInputs, NestedTensors
|
||||
|
||||
|
||||
def assert_nested_tensors_equal(expected: NestedTensors,
|
||||
actual: NestedTensors):
|
||||
assert type(expected) == type(actual)
|
||||
if isinstance(expected, torch.Tensor):
|
||||
assert torch.equal(expected, actual)
|
||||
else:
|
||||
for expected_item, actual_item in zip(expected, actual):
|
||||
assert_nested_tensors_equal(expected_item, actual_item)
|
||||
|
||||
|
||||
def assert_multimodal_inputs_equal(expected: MultiModalInputs,
|
||||
actual: MultiModalInputs):
|
||||
assert set(expected.keys()) == set(actual.keys())
|
||||
for key in expected:
|
||||
assert_nested_tensors_equal(expected[key], actual[key])
|
||||
|
||||
|
||||
def test_multimodal_input_batch_single_tensor():
|
||||
t = torch.rand([1, 2])
|
||||
result = MultiModalInputs.batch([{"image": t}])
|
||||
assert_multimodal_inputs_equal(result, {"image": t.unsqueeze(0)})
|
||||
|
||||
|
||||
def test_multimodal_input_batch_multiple_tensors():
|
||||
a = torch.rand([1, 1, 2])
|
||||
b = torch.rand([1, 1, 2])
|
||||
c = torch.rand([1, 1, 2])
|
||||
result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}])
|
||||
assert_multimodal_inputs_equal(result, {"image": torch.stack([a, b, c])})
|
||||
|
||||
|
||||
def test_multimodal_input_batch_multiple_heterogeneous_tensors():
|
||||
a = torch.rand([1, 2, 2])
|
||||
b = torch.rand([1, 3, 2])
|
||||
c = torch.rand([1, 4, 2])
|
||||
result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}])
|
||||
assert_multimodal_inputs_equal(result, {"image": [a, b, c]})
|
||||
|
||||
|
||||
def test_multimodal_input_batch_nested_tensors():
|
||||
a = torch.rand([2, 3])
|
||||
b = torch.rand([2, 3])
|
||||
c = torch.rand([2, 3])
|
||||
result = MultiModalInputs.batch([{
|
||||
"image": [a]
|
||||
}, {
|
||||
"image": [b]
|
||||
}, {
|
||||
"image": [c]
|
||||
}])
|
||||
assert_multimodal_inputs_equal(result, {
|
||||
"image":
|
||||
torch.stack([a.unsqueeze(0),
|
||||
b.unsqueeze(0),
|
||||
c.unsqueeze(0)])
|
||||
})
|
||||
|
||||
|
||||
def test_multimodal_input_batch_heterogeneous_lists():
|
||||
a = torch.rand([1, 2, 3])
|
||||
b = torch.rand([1, 2, 3])
|
||||
c = torch.rand([1, 2, 3])
|
||||
result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}])
|
||||
assert_multimodal_inputs_equal(
|
||||
result,
|
||||
{"image": [torch.stack([a, b]), c.unsqueeze(0)]})
|
||||
|
||||
|
||||
def test_multimodal_input_batch_multiple_batchable_lists():
|
||||
a = torch.rand([1, 2, 3])
|
||||
b = torch.rand([1, 2, 3])
|
||||
c = torch.rand([1, 2, 3])
|
||||
d = torch.rand([1, 2, 3])
|
||||
result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c, d]}])
|
||||
assert_multimodal_inputs_equal(
|
||||
result,
|
||||
{"image": torch.stack([torch.stack([a, b]),
|
||||
torch.stack([c, d])])})
|
@ -555,6 +555,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
# Remove the N dimension until multiple images are supported.
|
||||
pixel_values = pixel_values.squeeze(1)
|
||||
|
||||
return Blip2ImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=self._validate_pixel_values(pixel_values),
|
||||
@ -564,6 +567,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
if not isinstance(image_embeds, torch.Tensor):
|
||||
raise ValueError("Incorrect type of image embeddings. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
|
||||
# Remove the N dimension until multiple images are supported.
|
||||
image_embeds = image_embeds.squeeze(1)
|
||||
|
||||
return Blip2ImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
data=image_embeds,
|
||||
|
@ -946,6 +946,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
# Remove the N dimension until multiple images are supported.
|
||||
pixel_values = pixel_values.squeeze(1)
|
||||
|
||||
return ChameleonImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=self._validate_pixel_values(pixel_values),
|
||||
|
@ -249,6 +249,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
|
||||
image_patches = kwargs.pop("image_patches", None)
|
||||
|
||||
if isinstance(image_patches, torch.Tensor):
|
||||
# Remove the N dimension until multiple images are supported.
|
||||
image_patches = image_patches.squeeze(1)
|
||||
|
||||
expected_feature_size = self.image_feature_size
|
||||
if image_patches.size(-1) != expected_feature_size:
|
||||
raise ValueError(
|
||||
|
@ -244,6 +244,8 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
|
||||
min_num,
|
||||
max_num,
|
||||
use_thumbnail=use_thumbnail)
|
||||
# Add an N dimension for number of images per prompt (currently 1).
|
||||
data = data.unsqueeze(0)
|
||||
model_config = ctx.model_config
|
||||
tokenizer = cached_get_tokenizer(model_config.tokenizer,
|
||||
trust_remote_code=True)
|
||||
@ -410,6 +412,10 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
|
||||
if not isinstance(image_embeds, torch.Tensor):
|
||||
raise ValueError("Incorrect type of image embeddings. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
|
||||
# Flatten the B and N dimensions
|
||||
image_embeds = image_embeds.flatten(0, 2)
|
||||
|
||||
return InternVLImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
data=image_embeds,
|
||||
@ -422,6 +428,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
# Flatten the B and N dimensions
|
||||
pixel_values = pixel_values.flatten(0, 2)
|
||||
|
||||
return InternVLImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=self._validate_pixel_values(pixel_values),
|
||||
|
@ -232,6 +232,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
if not isinstance(pixel_values, torch.Tensor):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
# Remove the N dimension until multiple images are supported.
|
||||
pixel_values = pixel_values.squeeze(1)
|
||||
|
||||
return LlavaImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=self._validate_pixel_values(pixel_values),
|
||||
@ -241,6 +245,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
if not isinstance(image_embeds, torch.Tensor):
|
||||
raise ValueError("Incorrect type of image embeddings. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
|
||||
# Remove the N dimension until multiple images are supported.
|
||||
image_embeds = image_embeds.squeeze(1)
|
||||
|
||||
return LlavaImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
data=image_embeds,
|
||||
|
@ -361,6 +361,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
raise ValueError("Incorrect type of image sizes. "
|
||||
f"Got type: {type(image_sizes)}")
|
||||
|
||||
# Remove the N dimension until multiple images are supported.
|
||||
if isinstance(pixel_values, torch.Tensor):
|
||||
pixel_values = pixel_values.squeeze(1)
|
||||
else:
|
||||
pixel_values = [t.squeeze(0) for t in pixel_values]
|
||||
|
||||
image_sizes = image_sizes.squeeze(1)
|
||||
|
||||
return LlavaNextImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=self._validate_pixel_values(pixel_values),
|
||||
@ -372,6 +380,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
raise ValueError("Incorrect type of image embeds. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
|
||||
# Remove the N dimension until multiple images are supported.
|
||||
image_embeds = image_embeds.squeeze(1)
|
||||
|
||||
return LlavaNextImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
data=image_embeds,
|
||||
|
@ -594,9 +594,14 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
|
||||
|
||||
pixel_values_flat: List[torch.Tensor] = []
|
||||
tgt_sizes_flat: List[torch.Tensor] = []
|
||||
for b in range(len(pixel_values)):
|
||||
pixel_values_flat += pixel_values[b]
|
||||
tgt_sizes_flat += tgt_sizes[b]
|
||||
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
|
||||
if len(pixel_b) != len(tgt_b):
|
||||
raise ValueError("Inconsistent N lengths, found: "
|
||||
f"{len(pixel_b)} vs {len(tgt_b)}")
|
||||
|
||||
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
|
||||
pixel_values_flat += pixel_n
|
||||
tgt_sizes_flat += tgt_n
|
||||
|
||||
# NOTE: Input IDs does not contain image tokens during memory profiling,
|
||||
# so we allow it to be empty
|
||||
|
@ -185,6 +185,10 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
if not isinstance(pixel_values, torch.Tensor):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
# Remove the N dimension until multiple images are supported.
|
||||
pixel_values = pixel_values.squeeze(1)
|
||||
|
||||
return PaliGemmaImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=self._validate_pixel_values(pixel_values),
|
||||
@ -194,6 +198,10 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
if not isinstance(image_embeds, torch.Tensor):
|
||||
raise ValueError("Incorrect type of image embeddings. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
|
||||
# Remove the N dimension until multiple images are supported.
|
||||
image_embeds = image_embeds.squeeze(1)
|
||||
|
||||
return PaliGemmaImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
data=image_embeds,
|
||||
|
@ -560,6 +560,14 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
|
||||
raise ValueError("Incorrect type of image sizes. "
|
||||
f"Got type: {type(image_sizes)}")
|
||||
|
||||
# Merge the B and N dimensions.
|
||||
if isinstance(pixel_values, torch.Tensor):
|
||||
pixel_values = pixel_values.flatten(0, 1)
|
||||
else:
|
||||
pixel_values = torch.cat(pixel_values)
|
||||
|
||||
image_sizes = image_sizes.flatten(0, 1)
|
||||
|
||||
return Phi3VImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=self._validate_pixel_values(pixel_values),
|
||||
|
@ -333,6 +333,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
|
||||
raise ValueError("Incorrect type of audio features. "
|
||||
f"Got type: {type(audio_features)}")
|
||||
|
||||
# Remove the N dimension until multiple audios are supported.
|
||||
if isinstance(audio_features, torch.Tensor):
|
||||
audio_features = audio_features.squeeze(1)
|
||||
else:
|
||||
audio_features = [t.squeeze(0) for t in audio_features]
|
||||
|
||||
return UltravoxAudioFeatureInputs(type="audio_features",
|
||||
data=audio_features)
|
||||
|
||||
@ -341,6 +347,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
|
||||
raise ValueError("Incorrect type of audio embeds. "
|
||||
f"Got type: {type(audio_embeds)}")
|
||||
|
||||
# Remove the N dimension until multiple audios are supported.
|
||||
audio_embeds = audio_embeds.squeeze(1)
|
||||
|
||||
return UltravoxAudioEmbeddingInputs(type="audio_embeds",
|
||||
data=audio_embeds)
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
from typing import Dict, Iterable, List, Optional, Protocol, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.func import functional_call
|
||||
@ -10,7 +11,7 @@ from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.loader import build_model
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.multimodal import BatchedTensors
|
||||
from vllm.multimodal.base import NestedTensors
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
|
||||
@ -54,9 +55,34 @@ def init_vllm_registered_model(
|
||||
)
|
||||
|
||||
|
||||
def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
|
||||
"""
|
||||
Recursively concatenates NestedTensors along any heterogeneously sized
|
||||
dimensions.
|
||||
"""
|
||||
|
||||
if isinstance(embeddings, torch.Tensor):
|
||||
return embeddings
|
||||
|
||||
return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
|
||||
|
||||
|
||||
def _embedding_count_expression(embeddings: NestedTensors) -> str:
|
||||
"""
|
||||
Constructs a debugging representation of the number of embeddings in the
|
||||
NestedTensors.
|
||||
"""
|
||||
|
||||
if isinstance(embeddings, torch.Tensor):
|
||||
return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
|
||||
|
||||
return " + ".join(
|
||||
_embedding_count_expression(inner) for inner in embeddings)
|
||||
|
||||
|
||||
def merge_multimodal_embeddings(input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
multimodal_embeddings: BatchedTensors,
|
||||
multimodal_embeddings: NestedTensors,
|
||||
placeholder_token_id: int) -> torch.Tensor:
|
||||
"""
|
||||
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
|
||||
@ -69,28 +95,16 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
|
||||
mask = (input_ids == placeholder_token_id)
|
||||
num_expected_tokens = mask.sum()
|
||||
|
||||
if isinstance(multimodal_embeddings, torch.Tensor):
|
||||
batch_size, batch_tokens, *_, embed_dim = multimodal_embeddings.shape
|
||||
total_tokens = batch_size * batch_tokens
|
||||
if num_expected_tokens != total_tokens:
|
||||
expr = f"{batch_size} x {batch_tokens}"
|
||||
raise ValueError(
|
||||
f"Attempted to assign {expr} = {total_tokens} "
|
||||
f"multimodal tokens to {num_expected_tokens} placeholders")
|
||||
|
||||
inputs_embeds[mask] = multimodal_embeddings.view(
|
||||
total_tokens, embed_dim)
|
||||
else:
|
||||
size_per_batch = [t.shape[0] for t in multimodal_embeddings]
|
||||
total_tokens = sum(size_per_batch)
|
||||
if num_expected_tokens != total_tokens:
|
||||
expr = ' + '.join(map(str, size_per_batch))
|
||||
raise ValueError(
|
||||
f"Attempted to assign {expr} = {total_tokens} "
|
||||
f"multimodal tokens to {num_expected_tokens} placeholders")
|
||||
|
||||
inputs_embeds[mask] = torch.cat(multimodal_embeddings)
|
||||
flattened = _flatten_embeddings(multimodal_embeddings)
|
||||
*dims, embed_dim = flattened.shape
|
||||
num_multimodal_embeddings = np.prod(dims)
|
||||
if num_multimodal_embeddings != num_expected_tokens:
|
||||
expr = _embedding_count_expression(multimodal_embeddings)
|
||||
raise ValueError(
|
||||
f"Attempted to assign {expr} = {num_multimodal_embeddings} "
|
||||
f"multimodal tokens to {num_expected_tokens} placeholders")
|
||||
|
||||
inputs_embeds[mask] = flattened.view(num_expected_tokens, embed_dim)
|
||||
return inputs_embeds
|
||||
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
from .base import (BatchedTensorInputs, BatchedTensors, MultiModalDataBuiltins,
|
||||
from .base import (BatchedTensorInputs, MultiModalDataBuiltins,
|
||||
MultiModalDataDict, MultiModalInputs, MultiModalPlugin,
|
||||
NestedTensors)
|
||||
from .registry import MultiModalRegistry
|
||||
@ -14,7 +14,6 @@ See also:
|
||||
|
||||
__all__ = [
|
||||
"BatchedTensorInputs",
|
||||
"BatchedTensors",
|
||||
"MultiModalDataBuiltins",
|
||||
"MultiModalDataDict",
|
||||
"MultiModalInputs",
|
||||
|
@ -1,9 +1,8 @@
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import UserDict, defaultdict
|
||||
from typing import Callable, Dict, List, Mapping, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Tuple, Type, TypedDict, TypeVar, Union, cast, final
|
||||
from typing import (Callable, Dict, List, Mapping, Optional, Tuple, Type,
|
||||
TypedDict, TypeVar, Union, cast, final)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -15,23 +14,16 @@ from typing_extensions import TypeAlias
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import JSONTree, json_map_leaves
|
||||
from vllm.utils import json_map_leaves
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
NestedTensors = Union[GenericSequence[torch.Tensor], torch.Tensor]
|
||||
NestedTensors = Union[List["NestedTensors"], torch.Tensor]
|
||||
"""
|
||||
Use a list instead of a tensor if the dimensions of each element do not match.
|
||||
Currently only supports up to singly nested list of tensors.
|
||||
Uses a list instead of a tensor if the dimensions of each element do not match.
|
||||
"""
|
||||
|
||||
BatchedTensors: TypeAlias = JSONTree[torch.Tensor]
|
||||
"""
|
||||
A nested JSON structure of tensors which have been batched via
|
||||
:meth:`MultiModalInputs.batch`.
|
||||
"""
|
||||
|
||||
BatchedTensorInputs: TypeAlias = Dict[str, JSONTree[torch.Tensor]]
|
||||
BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors]
|
||||
"""
|
||||
A dictionary containing nested tensors which have been batched via
|
||||
:meth:`MultiModalInputs.batch`.
|
||||
@ -54,26 +46,23 @@ class MultiModalInputs(_MultiModalInputsBase):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _try_concat(tensors: List[NestedTensors]) -> BatchedTensors:
|
||||
def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
|
||||
"""
|
||||
If each input tensor in the batch has the same shape, return a single
|
||||
batched tensor; otherwise, return a list of :class:`NestedTensors` with
|
||||
one element per item in the batch.
|
||||
Recursively stacks lists of tensors when they all have the same shape.
|
||||
"""
|
||||
# may be list rather than tensors
|
||||
if isinstance(tensors[0], list):
|
||||
return [[t for t in tensor[0]]
|
||||
for tensor in cast(List[List[torch.Tensor]], tensors)]
|
||||
if isinstance(nested_tensors, torch.Tensor):
|
||||
return nested_tensors
|
||||
|
||||
tensors_ = cast(List[torch.Tensor], tensors)
|
||||
stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]
|
||||
if any(isinstance(t, list) for t in stacked):
|
||||
return stacked
|
||||
|
||||
unbatched_shape = tensors_[0].shape[1:]
|
||||
tensors_ = cast(List[torch.Tensor], stacked)
|
||||
if any(t.shape != tensors_[0].shape for t in tensors_):
|
||||
# The tensors have incompatible shapes and can't be stacked.
|
||||
return tensors_
|
||||
|
||||
for tensor in tensors_:
|
||||
if tensor.shape[1:] != unbatched_shape:
|
||||
return [tensor.squeeze(0) for tensor in tensors_]
|
||||
|
||||
return torch.cat(tensors_, dim=0)
|
||||
return torch.stack(tensors_)
|
||||
|
||||
@staticmethod
|
||||
def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs:
|
||||
@ -102,7 +91,7 @@ class MultiModalInputs(_MultiModalInputsBase):
|
||||
item_lists[k].append(v)
|
||||
|
||||
return {
|
||||
k: MultiModalInputs._try_concat(item_list)
|
||||
k: MultiModalInputs._try_stack(item_list)
|
||||
for k, item_list in item_lists.items()
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user