[VLM] Report multi_modal_placeholders in output (#10407)
Signed-off-by: Linkun Chen <lkchen+anyscale@github.com>
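
This change makes RequestOutput report, per modality, the placeholder token ranges that input processing reserved for each multi-modal item, so downstream code can map image features back to positions in the prompt. A minimal consumer-side sketch (illustration only, not code from this commit; it assumes the offline LLM API, a Pixtral-compatible checkpoint, and a throwaway PIL image):

# Hedged sketch: reading the newly reported placeholder ranges from a RequestOutput.
from PIL import Image

from vllm import LLM, TextPrompt
from vllm.multimodal import MultiModalDataBuiltins

llm = LLM(model="mistral-community/pixtral-12b", max_model_len=8192)
prompt = TextPrompt(
    # Prompt format is assumed here; the test below builds it with
    # AutoProcessor.apply_chat_template instead.
    prompt="<s>[INST]Describe the image.\n[IMG][/INST]",
    multi_modal_data=MultiModalDataBuiltins(image=[Image.new("RGB", (512, 512))]),
)
output = llm.generate(prompt)[0]

# New field: e.g. {"image": [{"offset": 10, "length": 494}]}
for placeholder in output.multi_modal_placeholders.get("image", []):
    print(placeholder["offset"], placeholder["length"])

The new test added below asserts exactly these offset/length pairs for one- and three-image prompts.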
@@ -8,13 +8,17 @@ from dataclasses import asdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
 import pytest
+from mistral_common.multimodal import download_image
 from mistral_common.protocol.instruct.messages import ImageURLChunk
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
 from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
+from transformers import AutoProcessor
 
-from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt
+from vllm import (EngineArgs, LLMEngine, RequestOutput, SamplingParams,
+                  TextPrompt, TokensPrompt)
 from vllm.multimodal import MultiModalDataBuiltins
+from vllm.multimodal.inputs import PlaceholderRange
 from vllm.sequence import Logprob, SampleLogprobs
 
 from ....utils import VLLM_PATH, large_gpu_test
@@ -49,6 +53,20 @@ def _create_msg_format(urls: List[str]) -> List[Dict[str, Any]]:
     }]
 
 
+def _create_msg_format_hf(urls: List[str]) -> List[Dict[str, Any]]:
+    return [{
+        "role":
+        "user",
+        "content": [{
+            "type": "text",
+            "content": PROMPT,
+        }, *({
+            "type": "image",
+            "image": download_image(url)
+        } for url in urls)],
+    }]
+
+
 def _create_engine_inputs(urls: List[str]) -> TokensPrompt:
     msg = _create_msg_format(urls)
 
@@ -70,6 +88,23 @@ def _create_engine_inputs(urls: List[str]) -> TokensPrompt:
     return engine_inputs
 
 
+def _create_engine_inputs_hf(urls: List[str]) -> TextPrompt:
+    msg = _create_msg_format_hf(urls)
+
+    tokenizer = AutoProcessor.from_pretrained("mistral-community/pixtral-12b")
+    prompt = tokenizer.apply_chat_template(msg)
+
+    images = []
+    for chunk in msg[0]["content"]:
+        if chunk["type"] == "image":
+            images.append(chunk["image"])
+
+    mm_data = MultiModalDataBuiltins(image=images)
+    engine_inputs = TextPrompt(prompt=prompt, multi_modal_data=mm_data)
+
+    return engine_inputs
+
+
 MSGS = [
     _create_msg_format(IMG_URLS[:1]),
     _create_msg_format(IMG_URLS[:2]),
@@ -191,3 +226,45 @@ def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
                          outputs_1_lst=logprobs,
                          name_0="h100_ref",
                          name_1="output")
+
+
+@large_gpu_test(min_gb=24)
+@pytest.mark.parametrize(
+    "prompt,expected_ranges",
+    [(_create_engine_inputs_hf(IMG_URLS[:1]), [{
+        "offset": 10,
+        "length": 494
+    }]),
+     (_create_engine_inputs_hf(IMG_URLS[1:4]), [{
+         "offset": 10,
+         "length": 266
+     }, {
+         "offset": 276,
+         "length": 1056
+     }, {
+         "offset": 1332,
+         "length": 418
+     }])])
+def test_multi_modal_placeholders(
+        vllm_runner, prompt, expected_ranges: list[PlaceholderRange]) -> None:
+    with vllm_runner(
+            "mistral-community/pixtral-12b",
+            max_model_len=8192,
+            limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
+    ) as vllm_model:
+        outputs = vllm_model.model.generate(prompt)
+
+        assert len(outputs) == 1, f"{len(outputs)=}"
+        output: RequestOutput = outputs[0]
+        assert hasattr(output,
+                       "multi_modal_placeholders"), f"{output.__dict__=}"
+        assert "image" in output.multi_modal_placeholders, \
+            f"{output.multi_modal_placeholders.keys()=}"
+        image_placeholder_ranges: list[
+            PlaceholderRange] = output.multi_modal_placeholders["image"]
+        assert len(image_placeholder_ranges) == len(
+            expected_ranges), f"{image_placeholder_ranges=}"
+        for real_range, expected_range in zip(image_placeholder_ranges,
+                                              expected_ranges):
+            assert real_range == expected_range, \
+                f"{real_range=} {expected_range=}"
@@ -30,6 +30,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.utils import merge_multimodal_embeddings
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
+from vllm.multimodal.inputs import PlaceholderRange
 from vllm.multimodal.utils import (cached_get_tokenizer,
                                    consecutive_placeholder_ranges)
 from vllm.sequence import IntermediateTensors, SequenceData
@@ -773,15 +774,28 @@ def input_processor_for_pixtral_hf(
         replace_tokens[-1] = image_end_id
         replace_tokens_list.append(replace_tokens)
 
+    reverse_offsets: List[int] = []
     # Backward iteration for replacement without affecting known indices
     for placeholder_idx, replace_tokens in zip(reversed(placeholder_indices),
                                                reversed(replace_tokens_list)):
+        reverse_offsets.append(
+            len(new_token_ids) - placeholder_idx + len(replace_tokens))
         new_token_ids[placeholder_idx:placeholder_idx + 1] = replace_tokens
 
+    placeholder_ranges: List[PlaceholderRange] = []
+    for reverse_offset, replace_tokens in zip(reversed(reverse_offsets),
+                                              replace_tokens_list):
+        placeholder_ranges.append(
+            PlaceholderRange(
+                offset=len(new_token_ids) - reverse_offset,
+                length=len(replace_tokens),
+            ))
+
     # NOTE: Create a defensive copy of the original inputs
     return token_inputs(prompt_token_ids=new_token_ids,
                         prompt=new_prompt,
-                        multi_modal_data=multi_modal_data)
+                        multi_modal_data=multi_modal_data,
+                        multi_modal_placeholders={"image": placeholder_ranges})
 
 
 class PixtralHFMLP(nn.Module):
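
Aside (illustration, not part of the diff): the hunk above replaces each single image-placeholder token with a multi-token run while iterating backward, so indices that have not yet been processed stay valid, and it records each run's distance from the end of the token list, which later (lower-index) replacements cannot change. A self-contained sketch of that bookkeeping, with hypothetical names:

# Standalone illustration of backward replacement with offset bookkeeping.
# Names and structure are hypothetical; this is not vLLM's implementation.
from typing import Dict, List, Tuple


def expand_placeholders(
        token_ids: List[int],
        placeholder_indices: List[int],  # ascending positions of placeholder tokens
        replacements: List[List[int]],  # replacement run for each placeholder
) -> Tuple[List[int], List[Dict[str, int]]]:
    new_token_ids = list(token_ids)
    from_end: List[Tuple[int, int]] = []  # (distance from list end to run start, run length)
    for idx, run in zip(reversed(placeholder_indices), reversed(replacements)):
        new_token_ids[idx:idx + 1] = run  # swap one placeholder token for the run
        # Remaining iterations only edit lower indices, so this distance is final.
        from_end.append((len(new_token_ids) - idx, len(run)))
    total = len(new_token_ids)
    ranges = [{"offset": total - back, "length": length}
              for back, length in reversed(from_end)]
    return new_token_ids, ranges


# Two placeholders (token 99) expanded into runs of different lengths.
tokens, ranges = expand_placeholders([1, 99, 2, 99, 3], [1, 3], [[7, 7], [8, 8, 8]])
assert tokens == [1, 7, 7, 2, 8, 8, 8, 3]
assert ranges == [{"offset": 1, "length": 2}, {"offset": 4, "length": 3}]

In the actual hunk each such pair becomes a PlaceholderRange and is passed to token_inputs via the new multi_modal_placeholders argument.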
@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
 from typing import Union
 
 from vllm.lora.request import LoRARequest
+from vllm.multimodal.inputs import MultiModalPlaceholderDict
 from vllm.sampling_params import RequestOutputKind
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
                            SequenceGroup, SequenceGroupBase, SequenceStatus)
@@ -103,10 +104,13 @@ class RequestOutput:
         encoder_prompt: Optional[str] = None,
         encoder_prompt_token_ids: Optional[List[int]] = None,
         num_cached_tokens: Optional[int] = None,
+        *,
+        multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
+        self.multi_modal_placeholders = multi_modal_placeholders or {}
         self.prompt_logprobs = prompt_logprobs
         self.outputs = outputs
         self.finished = finished
@@ -275,17 +279,26 @@ class RequestOutput:
         finished_time = time.time() if finished else None
         seq_group.set_finished_time(finished_time)
 
-        init_args = (seq_group.request_id, prompt, prompt_token_ids,
-                     prompt_logprobs, outputs, finished, seq_group.metrics,
-                     seq_group.lora_request, encoder_prompt,
-                     encoder_prompt_token_ids, num_cached_tokens)
+        init_kwargs = {
+            "request_id": seq_group.request_id,
+            "prompt": prompt,
+            "prompt_token_ids": prompt_token_ids,
+            "prompt_logprobs": prompt_logprobs,
+            "outputs": outputs,
+            "finished": finished,
+            "metrics": seq_group.metrics,
+            "lora_request": seq_group.lora_request,
+            "encoder_prompt": encoder_prompt,
+            "encoder_prompt_token_ids": encoder_prompt_token_ids,
+            "num_cached_tokens": num_cached_tokens,
+            "multi_modal_placeholders": seq_group.multi_modal_placeholders
+        }
 
         if use_cache:
             request_output = seq_group.cached_request_output
-            request_output.__init__(*init_args)  # type: ignore
-
+            request_output.__init__(**init_kwargs)  # type: ignore
         else:
-            request_output = cls(*init_args)
+            request_output = cls(**init_kwargs)  # type: ignore
 
         return request_output
 
@@ -300,7 +313,8 @@ class RequestOutput:
                 f"finished={self.finished}, "
                 f"metrics={self.metrics}, "
                 f"lora_request={self.lora_request}, "
-                f"num_cached_tokens={self.num_cached_tokens})")
+                f"num_cached_tokens={self.num_cached_tokens}, "
+                f"multi_modal_placeholders={self.multi_modal_placeholders})")
 
 
 class EmbeddingRequestOutput: