mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[VLM][Core] Fix exceptions on ragged NestedTensors (#7974)
This commit is contained in:
@ -81,3 +81,15 @@ def test_multimodal_input_batch_multiple_batchable_lists():
|
||||
result,
|
||||
{"image": torch.stack([torch.stack([a, b]),
|
||||
torch.stack([c, d])])})
|
||||
|
||||
|
||||
def test_multimodal_input_batch_mixed_stacking_depths():
|
||||
a = torch.rand([1, 2, 3])
|
||||
b = torch.rand([1, 3, 3])
|
||||
c = torch.rand([1, 4, 3])
|
||||
|
||||
result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}])
|
||||
assert_multimodal_inputs_equal(result, {"image": [[a, b], c.unsqueeze(0)]})
|
||||
|
||||
result = MultiModalInputs.batch([{"image": [a]}, {"image": [b, c]}])
|
||||
assert_multimodal_inputs_equal(result, {"image": [a.unsqueeze(0), [b, c]]})
|
||||
|
@ -1,7 +1,6 @@
|
||||
from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
|
||||
Union, overload)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.func import functional_call
|
||||
@ -96,12 +95,13 @@ def flatten_bn(
|
||||
|
||||
def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
|
||||
"""
|
||||
Recursively concatenates NestedTensors along any heterogeneously sized
|
||||
dimensions.
|
||||
Recursively flattens and concatenates NestedTensors on all but the last
|
||||
dimension.
|
||||
"""
|
||||
|
||||
if isinstance(embeddings, torch.Tensor):
|
||||
return embeddings
|
||||
# Flatten all but the last dimension.
|
||||
return embeddings.flatten(0, -2)
|
||||
|
||||
return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
|
||||
|
||||
@ -136,15 +136,13 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
|
||||
assert isinstance(num_expected_tokens, int)
|
||||
|
||||
flattened = _flatten_embeddings(multimodal_embeddings)
|
||||
*dims, embed_dim = flattened.shape
|
||||
num_multimodal_embeddings = np.prod(dims)
|
||||
if num_multimodal_embeddings != num_expected_tokens:
|
||||
if flattened.shape[0] != num_expected_tokens:
|
||||
expr = _embedding_count_expression(multimodal_embeddings)
|
||||
raise ValueError(
|
||||
f"Attempted to assign {expr} = {num_multimodal_embeddings} "
|
||||
f"Attempted to assign {expr} = {flattened.shape[0]} "
|
||||
f"multimodal tokens to {num_expected_tokens} placeholders")
|
||||
|
||||
inputs_embeds[mask] = flattened.view(num_expected_tokens, embed_dim)
|
||||
inputs_embeds[mask] = flattened
|
||||
return inputs_embeds
|
||||
|
||||
|
||||
|
@ -54,8 +54,8 @@ class MultiModalInputs(_MultiModalInputsBase):
|
||||
return nested_tensors
|
||||
|
||||
stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]
|
||||
if is_list_of(stacked, list):
|
||||
# Do not stack nested lists
|
||||
if not is_list_of(stacked, torch.Tensor, check="all"):
|
||||
# Only tensors (not lists) can be stacked.
|
||||
return stacked
|
||||
|
||||
tensors_ = cast(List[torch.Tensor], stacked)
|
||||
|
Reference in New Issue
Block a user