[Core][VLM] Stack multimodal tensors to represent multiple images within each prompt (#7902)

Author: Peter Salas
Date: 2024-08-27 18:53:56 -07:00
Committed by: GitHub
Parent: 9c71c97ae2
Commit: fab5f53e2d
15 changed files with 214 additions and 60 deletions
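
For orientation, here is a minimal usage sketch (not part of the commit) of the batching behaviour this change introduces, mirroring the new tests further down; the tensor shapes and the single-image case are illustrative assumptions, not values taken from any model.

import torch

from vllm.multimodal.base import MultiModalInputs

# Each prompt's "image" tensor now carries a leading N dimension (number of
# images in that prompt); here N = 1 and the remaining dims are made up.
a = torch.rand(1, 3, 336, 336)
b = torch.rand(1, 3, 336, 336)

batched = MultiModalInputs.batch([{"image": a}, {"image": b}])
# Same-shaped inputs stack into a single (B, N, ...) tensor.
assert batched["image"].shape == (2, 1, 3, 336, 336)

# Inputs whose shapes differ (e.g. different tile counts) fall back to a list.
c = torch.rand(2, 3, 336, 336)
batched = MultiModalInputs.batch([{"image": a}, {"image": c}])
assert isinstance(batched["image"], list)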

View File

@@ -45,8 +45,6 @@ Base Classes
.. autodata:: vllm.multimodal.NestedTensors
.. autodata:: vllm.multimodal.BatchedTensors
.. autodata:: vllm.multimodal.BatchedTensorInputs
.. autoclass:: vllm.multimodal.MultiModalDataBuiltins

View File

@@ -0,0 +1,83 @@
import torch
from vllm.multimodal.base import MultiModalInputs, NestedTensors
def assert_nested_tensors_equal(expected: NestedTensors,
actual: NestedTensors):
assert type(expected) == type(actual)
if isinstance(expected, torch.Tensor):
assert torch.equal(expected, actual)
else:
for expected_item, actual_item in zip(expected, actual):
assert_nested_tensors_equal(expected_item, actual_item)
def assert_multimodal_inputs_equal(expected: MultiModalInputs,
actual: MultiModalInputs):
assert set(expected.keys()) == set(actual.keys())
for key in expected:
assert_nested_tensors_equal(expected[key], actual[key])
def test_multimodal_input_batch_single_tensor():
t = torch.rand([1, 2])
result = MultiModalInputs.batch([{"image": t}])
assert_multimodal_inputs_equal(result, {"image": t.unsqueeze(0)})
def test_multimodal_input_batch_multiple_tensors():
a = torch.rand([1, 1, 2])
b = torch.rand([1, 1, 2])
c = torch.rand([1, 1, 2])
result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}])
assert_multimodal_inputs_equal(result, {"image": torch.stack([a, b, c])})
def test_multimodal_input_batch_multiple_heterogeneous_tensors():
a = torch.rand([1, 2, 2])
b = torch.rand([1, 3, 2])
c = torch.rand([1, 4, 2])
result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}])
assert_multimodal_inputs_equal(result, {"image": [a, b, c]})
def test_multimodal_input_batch_nested_tensors():
a = torch.rand([2, 3])
b = torch.rand([2, 3])
c = torch.rand([2, 3])
result = MultiModalInputs.batch([{
"image": [a]
}, {
"image": [b]
}, {
"image": [c]
}])
assert_multimodal_inputs_equal(result, {
"image":
torch.stack([a.unsqueeze(0),
b.unsqueeze(0),
c.unsqueeze(0)])
})
def test_multimodal_input_batch_heterogeneous_lists():
a = torch.rand([1, 2, 3])
b = torch.rand([1, 2, 3])
c = torch.rand([1, 2, 3])
result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}])
assert_multimodal_inputs_equal(
result,
{"image": [torch.stack([a, b]), c.unsqueeze(0)]})
def test_multimodal_input_batch_multiple_batchable_lists():
a = torch.rand([1, 2, 3])
b = torch.rand([1, 2, 3])
c = torch.rand([1, 2, 3])
d = torch.rand([1, 2, 3])
result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c, d]}])
assert_multimodal_inputs_equal(
result,
{"image": torch.stack([torch.stack([a, b]),
torch.stack([c, d])])})

View File

@@ -555,6 +555,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
# Remove the N dimension until multiple images are supported.
pixel_values = pixel_values.squeeze(1)
return Blip2ImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
@@ -564,6 +567,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
# Remove the N dimension until multiple images are supported.
image_embeds = image_embeds.squeeze(1)
return Blip2ImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
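
The same `squeeze(1)` pattern recurs in the other single-image models below (Chameleon, Fuyu, LLaVA, PaliGemma, Ultravox). A shape-only sketch of what it does, with made-up sizes rather than BLIP-2's real dimensions:

import torch

batch_size, num_images = 4, 1                      # N is fixed at 1 for now
pixel_values = torch.rand(batch_size, num_images, 3, 224, 224)

# Drop the per-prompt N dimension so downstream code still sees (B, C, H, W).
pixel_values = pixel_values.squeeze(1)
assert pixel_values.shape == (batch_size, 3, 224, 224)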

View File

@@ -946,6 +946,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
# Remove the N dimension until multiple images are supported.
pixel_values = pixel_values.squeeze(1)
return ChameleonImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),

View File

@@ -249,6 +249,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
image_patches = kwargs.pop("image_patches", None)
if isinstance(image_patches, torch.Tensor):
# Remove the N dimension until multiple images are supported.
image_patches = image_patches.squeeze(1)
expected_feature_size = self.image_feature_size
if image_patches.size(-1) != expected_feature_size:
raise ValueError(

View File

@@ -244,6 +244,8 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
min_num,
max_num,
use_thumbnail=use_thumbnail)
# Add an N dimension for number of images per prompt (currently 1).
data = data.unsqueeze(0)
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
@@ -410,6 +412,10 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
# Flatten the B and N dimensions
image_embeds = image_embeds.flatten(0, 2)
return InternVLImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
@@ -422,6 +428,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
# Flatten the B and N dimensions
pixel_values = pixel_values.flatten(0, 2)
return InternVLImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
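
A shape walk-through of the InternVL handling above, as a sketch: the input mapper adds the per-prompt N dimension, batching stacks prompts, and the model flattens the leading dimensions back into a flat tile axis. The tile count and image size here are assumed for illustration.

import torch

num_tiles = 5                                   # tiles from the tiling preprocessor
data = torch.rand(num_tiles, 3, 448, 448)       # mapper output for one image

# Input mapper: add the N dimension (one image per prompt for now).
data = data.unsqueeze(0)                        # (N=1, tiles, C, H, W)

# After MultiModalInputs.batch, same-shaped prompts stack to (B, N, tiles, C, H, W).
pixel_values = torch.stack([data, data])        # pretend B = 2

# The model then flattens the leading dims before the vision tower.
flat = pixel_values.flatten(0, 2)
assert flat.shape == (2 * 1 * num_tiles, 3, 448, 448)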

View File

@@ -232,6 +232,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
# Remove the N dimension until multiple images are supported.
pixel_values = pixel_values.squeeze(1)
return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
@@ -241,6 +245,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
# Remove the N dimension until multiple images are supported.
image_embeds = image_embeds.squeeze(1)
return LlavaImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,

View File

@@ -361,6 +361,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
# Remove the N dimension until multiple images are supported.
if isinstance(pixel_values, torch.Tensor):
pixel_values = pixel_values.squeeze(1)
else:
pixel_values = [t.squeeze(0) for t in pixel_values]
image_sizes = image_sizes.squeeze(1)
return LlavaNextImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
@@ -372,6 +380,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
raise ValueError("Incorrect type of image embeds. "
f"Got type: {type(image_embeds)}")
# Remove the N dimension until multiple images are supported.
image_embeds = image_embeds.squeeze(1)
return LlavaNextImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
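
The two pixel-value branches above cover both shapes that batching can produce for LLaVA-NeXT; a sketch with assumed patch counts and image sizes:

import torch

# Homogeneous prompts: batching produced one (B, N, patches, C, H, W) tensor.
pixel_values = torch.rand(2, 1, 5, 3, 336, 336)
pixel_values = pixel_values.squeeze(1)                 # -> (B, patches, C, H, W)
assert pixel_values.shape == (2, 5, 3, 336, 336)

# Heterogeneous prompts: batching fell back to a list with one
# (N, patches, C, H, W) tensor per prompt, so each drops its own N dimension.
pixel_values = [torch.rand(1, 5, 3, 336, 336), torch.rand(1, 9, 3, 336, 336)]
pixel_values = [t.squeeze(0) for t in pixel_values]
assert [t.shape[0] for t in pixel_values] == [5, 9]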

View File

@@ -594,9 +594,14 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
pixel_values_flat: List[torch.Tensor] = []
tgt_sizes_flat: List[torch.Tensor] = []
for b in range(len(pixel_values)):
pixel_values_flat += pixel_values[b]
tgt_sizes_flat += tgt_sizes[b]
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
if len(pixel_b) != len(tgt_b):
raise ValueError("Inconsistent N lengths, found: "
f"{len(pixel_b)} vs {len(tgt_b)}")
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
pixel_values_flat += pixel_n
tgt_sizes_flat += tgt_n
# NOTE: Input IDs does not contain image tokens during memory profiling,
# so we allow it to be empty
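
A toy version of the nested flattening above; the slice shapes and counts are invented and are not MiniCPM-V's real slice geometry:

import torch

# One batch entry with two images: the first has one slice, the second two.
pixel_values = [[[torch.rand(3, 14, 14)],
                 [torch.rand(3, 14, 14), torch.rand(3, 14, 14)]]]
tgt_sizes = [[[torch.tensor([1, 1])],
              [torch.tensor([1, 1]), torch.tensor([1, 1])]]]

pixel_values_flat, tgt_sizes_flat = [], []
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
    if len(pixel_b) != len(tgt_b):
        raise ValueError("Inconsistent N lengths")
    for pixel_n, tgt_n in zip(pixel_b, tgt_b):
        pixel_values_flat += pixel_n
        tgt_sizes_flat += tgt_n

assert len(pixel_values_flat) == len(tgt_sizes_flat) == 3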

View File

@@ -185,6 +185,10 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
# Remove the N dimension until multiple images are supported.
pixel_values = pixel_values.squeeze(1)
return PaliGemmaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
@@ -194,6 +198,10 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
# Remove the N dimension until multiple images are supported.
image_embeds = image_embeds.squeeze(1)
return PaliGemmaImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,

View File

@@ -560,6 +560,14 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
# Merge the B and N dimensions.
if isinstance(pixel_values, torch.Tensor):
pixel_values = pixel_values.flatten(0, 1)
else:
pixel_values = torch.cat(pixel_values)
image_sizes = image_sizes.flatten(0, 1)
return Phi3VImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
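
A shape sketch for the tensor branch above (when batching fell back to a list, the code instead concatenates the per-prompt tensors); the patch count and image sizes are assumptions:

import torch

batch_size, num_images = 2, 1
pixel_values = torch.rand(batch_size, num_images, 17, 3, 336, 336)
image_sizes = torch.randint(1, 2000, (batch_size, num_images, 2))

# Merging B and N leaves a single leading "image index" dimension.
pixel_values = pixel_values.flatten(0, 1)
image_sizes = image_sizes.flatten(0, 1)
assert pixel_values.shape[0] == image_sizes.shape[0] == batch_size * num_images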

View File

@@ -333,6 +333,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
raise ValueError("Incorrect type of audio features. "
f"Got type: {type(audio_features)}")
# Remove the N dimension until multiple audios are supported.
if isinstance(audio_features, torch.Tensor):
audio_features = audio_features.squeeze(1)
else:
audio_features = [t.squeeze(0) for t in audio_features]
return UltravoxAudioFeatureInputs(type="audio_features",
data=audio_features)
@@ -341,6 +347,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
raise ValueError("Incorrect type of audio embeds. "
f"Got type: {type(audio_embeds)}")
# Remove the N dimension until multiple audios are supported.
audio_embeds = audio_embeds.squeeze(1)
return UltravoxAudioEmbeddingInputs(type="audio_embeds",
data=audio_embeds)

View File

@@ -1,5 +1,6 @@
from typing import Dict, Iterable, List, Optional, Protocol, Tuple
import numpy as np
import torch
import torch.nn as nn
from torch.func import functional_call
@@ -10,7 +11,7 @@ from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.loader import build_model
from vllm.model_executor.models import ModelRegistry
from vllm.multimodal import BatchedTensors
from vllm.multimodal.base import NestedTensors
from vllm.utils import is_pin_memory_available
@@ -54,9 +55,34 @@ def init_vllm_registered_model(
)
def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
"""
Recursively concatenates NestedTensors along any heterogeneously sized
dimensions.
"""
if isinstance(embeddings, torch.Tensor):
return embeddings
return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
def _embedding_count_expression(embeddings: NestedTensors) -> str:
"""
Constructs a debugging representation of the number of embeddings in the
NestedTensors.
"""
if isinstance(embeddings, torch.Tensor):
return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
return " + ".join(
_embedding_count_expression(inner) for inner in embeddings)
def merge_multimodal_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
multimodal_embeddings: BatchedTensors,
multimodal_embeddings: NestedTensors,
placeholder_token_id: int) -> torch.Tensor:
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
@@ -69,28 +95,16 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
mask = (input_ids == placeholder_token_id)
num_expected_tokens = mask.sum()
if isinstance(multimodal_embeddings, torch.Tensor):
batch_size, batch_tokens, *_, embed_dim = multimodal_embeddings.shape
total_tokens = batch_size * batch_tokens
if num_expected_tokens != total_tokens:
expr = f"{batch_size} x {batch_tokens}"
raise ValueError(
f"Attempted to assign {expr} = {total_tokens} "
f"multimodal tokens to {num_expected_tokens} placeholders")
inputs_embeds[mask] = multimodal_embeddings.view(
total_tokens, embed_dim)
else:
size_per_batch = [t.shape[0] for t in multimodal_embeddings]
total_tokens = sum(size_per_batch)
if num_expected_tokens != total_tokens:
expr = ' + '.join(map(str, size_per_batch))
raise ValueError(
f"Attempted to assign {expr} = {total_tokens} "
f"multimodal tokens to {num_expected_tokens} placeholders")
inputs_embeds[mask] = torch.cat(multimodal_embeddings)
flattened = _flatten_embeddings(multimodal_embeddings)
*dims, embed_dim = flattened.shape
num_multimodal_embeddings = np.prod(dims)
if num_multimodal_embeddings != num_expected_tokens:
expr = _embedding_count_expression(multimodal_embeddings)
raise ValueError(
f"Attempted to assign {expr} = {num_multimodal_embeddings} "
f"multimodal tokens to {num_expected_tokens} placeholders")
inputs_embeds[mask] = flattened.view(num_expected_tokens, embed_dim)
return inputs_embeds
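
A self-contained toy (deliberately not importing the private helpers) that mirrors what the rewritten merge does: placeholder positions in ``inputs_embeds`` are overwritten by the flattened multimodal embeddings, and the counts must match. The placeholder id and sizes are arbitrary:

import torch

PLACEHOLDER = 32000                      # arbitrary placeholder token id
embed_dim = 8

input_ids = torch.tensor(
    [1, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER, 2, PLACEHOLDER, PLACEHOLDER])
inputs_embeds = torch.zeros(input_ids.shape[0], embed_dim)

# Nested embeddings: the first image contributes 3 tokens, the second 2.
multimodal_embeddings = [torch.ones(3, embed_dim), 2 * torch.ones(2, embed_dim)]

mask = input_ids == PLACEHOLDER
flattened = torch.cat(multimodal_embeddings)    # what _flatten_embeddings yields here
assert flattened.shape[0] == int(mask.sum())    # 3 + 2 embeddings, 5 placeholders
inputs_embeds[mask] = flattened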

View File

@@ -1,4 +1,4 @@
from .base import (BatchedTensorInputs, BatchedTensors, MultiModalDataBuiltins,
from .base import (BatchedTensorInputs, MultiModalDataBuiltins,
MultiModalDataDict, MultiModalInputs, MultiModalPlugin,
NestedTensors)
from .registry import MultiModalRegistry
@@ -14,7 +14,6 @@ See also:
__all__ = [
"BatchedTensorInputs",
"BatchedTensors",
"MultiModalDataBuiltins",
"MultiModalDataDict",
"MultiModalInputs",

View File

@@ -1,9 +1,8 @@
import sys
from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from typing import Callable, Dict, List, Mapping, Optional
from typing import Sequence as GenericSequence
from typing import Tuple, Type, TypedDict, TypeVar, Union, cast, final
from typing import (Callable, Dict, List, Mapping, Optional, Tuple, Type,
TypedDict, TypeVar, Union, cast, final)
import numpy as np
import torch
@@ -15,23 +14,16 @@ from typing_extensions import TypeAlias
from vllm.config import ModelConfig
from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.utils import JSONTree, json_map_leaves
from vllm.utils import json_map_leaves
logger = init_logger(__name__)
NestedTensors = Union[GenericSequence[torch.Tensor], torch.Tensor]
NestedTensors = Union[List["NestedTensors"], torch.Tensor]
"""
Use a list instead of a tensor if the dimensions of each element do not match.
Currently only supports up to singly nested list of tensors.
Uses a list instead of a tensor if the dimensions of each element do not match.
"""
BatchedTensors: TypeAlias = JSONTree[torch.Tensor]
"""
A nested JSON structure of tensors which have been batched via
:meth:`MultiModalInputs.batch`.
"""
BatchedTensorInputs: TypeAlias = Dict[str, JSONTree[torch.Tensor]]
BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalInputs.batch`.
@@ -54,26 +46,23 @@ class MultiModalInputs(_MultiModalInputsBase):
"""
@staticmethod
def _try_concat(tensors: List[NestedTensors]) -> BatchedTensors:
def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
"""
If each input tensor in the batch has the same shape, return a single
batched tensor; otherwise, return a list of :class:`NestedTensors` with
one element per item in the batch.
Recursively stacks lists of tensors when they all have the same shape.
"""
# may be list rather than tensors
if isinstance(tensors[0], list):
return [[t for t in tensor[0]]
for tensor in cast(List[List[torch.Tensor]], tensors)]
if isinstance(nested_tensors, torch.Tensor):
return nested_tensors
tensors_ = cast(List[torch.Tensor], tensors)
stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]
if any(isinstance(t, list) for t in stacked):
return stacked
unbatched_shape = tensors_[0].shape[1:]
tensors_ = cast(List[torch.Tensor], stacked)
if any(t.shape != tensors_[0].shape for t in tensors_):
# The tensors have incompatible shapes and can't be stacked.
return tensors_
for tensor in tensors_:
if tensor.shape[1:] != unbatched_shape:
return [tensor.squeeze(0) for tensor in tensors_]
return torch.cat(tensors_, dim=0)
return torch.stack(tensors_)
@staticmethod
def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs:
@@ -102,7 +91,7 @@ class MultiModalInputs(_MultiModalInputsBase):
item_lists[k].append(v)
return {
k: MultiModalInputs._try_concat(item_list)
k: MultiModalInputs._try_stack(item_list)
for k, item_list in item_lists.items()
}
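
Since ``NestedTensors`` is now defined recursively, batching also collapses deeper nesting when every leaf has the same shape. An illustrative sketch (the shapes are arbitrary, and this is not one of the new tests):

import torch

from vllm.multimodal.base import MultiModalInputs

a, b, c, d = (torch.rand(2, 3) for _ in range(4))

batched = MultiModalInputs.batch([
    {"image": [[a], [b]]},   # doubly nested NestedTensors
    {"image": [[c], [d]]},
])
# Every leaf has the same shape, so the whole structure stacks into one tensor
# of shape (batch, outer list, inner list, 2, 3).
assert batched["image"].shape == (2, 2, 1, 2, 3)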