[VLM] Report multi_modal_placeholders in output (#10407)

Signed-off-by: Linkun Chen <lkchen+anyscale@github.com>
Author: lkchen
Date: 2024-11-18 00:06:16 -08:00
Committed by: GitHub
Parent: 51bb12d17b
Commit: c7dec926f6
3 changed files with 115 additions and 10 deletions
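
This commit threads the placeholder token ranges of multi-modal inputs through to `RequestOutput`. As a rough usage sketch (not part of the diff below; the prompt string and dummy image are illustrative assumptions, the model name is the one used in the tests), a caller could read the new field like this after generation:

    # Sketch: reading the new RequestOutput.multi_modal_placeholders field.
    # The prompt format and the blank image are illustrative placeholders.
    from PIL import Image

    from vllm import LLM, SamplingParams

    llm = LLM(model="mistral-community/pixtral-12b", max_model_len=8192)
    image = Image.new("RGB", (512, 512))  # stand-in for a real image

    outputs = llm.generate(
        {
            "prompt": "<s>[INST]Describe the image.\n[IMG][/INST]",
            "multi_modal_data": {"image": image},
        },
        SamplingParams(max_tokens=32),
    )

    for out in outputs:
        # New with this commit: per-modality placeholder ranges in the prompt.
        for rng in out.multi_modal_placeholders.get("image", []):
            print(f"image tokens start at {rng['offset']}, span {rng['length']} tokens")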

File 1 of 3

@@ -8,13 +8,17 @@ from dataclasses import asdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import pytest
from mistral_common.multimodal import download_image
from mistral_common.protocol.instruct.messages import ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
from transformers import AutoProcessor
from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt
from vllm import (EngineArgs, LLMEngine, RequestOutput, SamplingParams,
                  TextPrompt, TokensPrompt)
from vllm.multimodal import MultiModalDataBuiltins
from vllm.multimodal.inputs import PlaceholderRange
from vllm.sequence import Logprob, SampleLogprobs
from ....utils import VLLM_PATH, large_gpu_test
@@ -49,6 +53,20 @@ def _create_msg_format(urls: List[str]) -> List[Dict[str, Any]]:
    }]


def _create_msg_format_hf(urls: List[str]) -> List[Dict[str, Any]]:
    return [{
        "role":
        "user",
        "content": [{
            "type": "text",
            "content": PROMPT,
        }, *({
            "type": "image",
            "image": download_image(url)
        } for url in urls)],
    }]


def _create_engine_inputs(urls: List[str]) -> TokensPrompt:
    msg = _create_msg_format(urls)
@@ -70,6 +88,23 @@ def _create_engine_inputs(urls: List[str]) -> TokensPrompt:
    return engine_inputs


def _create_engine_inputs_hf(urls: List[str]) -> TextPrompt:
    msg = _create_msg_format_hf(urls)

    tokenizer = AutoProcessor.from_pretrained("mistral-community/pixtral-12b")
    prompt = tokenizer.apply_chat_template(msg)

    images = []
    for chunk in msg[0]["content"]:
        if chunk["type"] == "image":
            images.append(chunk["image"])

    mm_data = MultiModalDataBuiltins(image=images)

    engine_inputs = TextPrompt(prompt=prompt, multi_modal_data=mm_data)

    return engine_inputs


MSGS = [
    _create_msg_format(IMG_URLS[:1]),
    _create_msg_format(IMG_URLS[:2]),
@@ -191,3 +226,45 @@ def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
                         outputs_1_lst=logprobs,
                         name_0="h100_ref",
                         name_1="output")


@large_gpu_test(min_gb=24)
@pytest.mark.parametrize(
    "prompt,expected_ranges",
    [(_create_engine_inputs_hf(IMG_URLS[:1]), [{
        "offset": 10,
        "length": 494
    }]),
     (_create_engine_inputs_hf(IMG_URLS[1:4]), [{
         "offset": 10,
         "length": 266
     }, {
         "offset": 276,
         "length": 1056
     }, {
         "offset": 1332,
         "length": 418
     }])])
def test_multi_modal_placeholders(
        vllm_runner, prompt, expected_ranges: list[PlaceholderRange]) -> None:
    with vllm_runner(
            "mistral-community/pixtral-12b",
            max_model_len=8192,
            limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
    ) as vllm_model:
        outputs = vllm_model.model.generate(prompt)

        assert len(outputs) == 1, f"{len(outputs)=}"
        output: RequestOutput = outputs[0]
        assert hasattr(output,
                       "multi_modal_placeholders"), f"{output.__dict__=}"
        assert "image" in output.multi_modal_placeholders, \
            f"{output.multi_modal_placeholders.keys()=}"
        image_placeholder_ranges: list[
            PlaceholderRange] = output.multi_modal_placeholders["image"]
        assert len(image_placeholder_ranges) == len(
            expected_ranges), f"{image_placeholder_ranges=}"
        for real_range, expected_range in zip(image_placeholder_ranges,
                                              expected_ranges):
            assert real_range == expected_range, \
                f"{real_range=} {expected_range=}"

File 2 of 3

@@ -30,6 +30,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import merge_multimodal_embeddings
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import (cached_get_tokenizer,
                                   consecutive_placeholder_ranges)
from vllm.sequence import IntermediateTensors, SequenceData
@@ -773,15 +774,28 @@ def input_processor_for_pixtral_hf(
        replace_tokens[-1] = image_end_id
        replace_tokens_list.append(replace_tokens)

    reverse_offsets: List[int] = []
    # Backward iteration for replacement without affecting known indices
    for placeholder_idx, replace_tokens in zip(reversed(placeholder_indices),
                                               reversed(replace_tokens_list)):
        reverse_offsets.append(
            len(new_token_ids) - placeholder_idx + len(replace_tokens))
        new_token_ids[placeholder_idx:placeholder_idx + 1] = replace_tokens

    placeholder_ranges: List[PlaceholderRange] = []
    for reverse_offset, replace_tokens in zip(reversed(reverse_offsets),
                                              replace_tokens_list):
        placeholder_ranges.append(
            PlaceholderRange(
                offset=len(new_token_ids) - reverse_offset,
                length=len(replace_tokens),
            ))

    # NOTE: Create a defensive copy of the original inputs
    return token_inputs(prompt_token_ids=new_token_ids,
                        prompt=new_prompt,
                        multi_modal_data=multi_modal_data)
                        multi_modal_data=multi_modal_data,
                        multi_modal_placeholders={"image": placeholder_ranges})


class PixtralHFMLP(nn.Module):

File 3 of 3

@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
from typing import Union
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalPlaceholderDict
from vllm.sampling_params import RequestOutputKind
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
                           SequenceGroup, SequenceGroupBase, SequenceStatus)
@@ -103,10 +104,13 @@ class RequestOutput:
        encoder_prompt: Optional[str] = None,
        encoder_prompt_token_ids: Optional[List[int]] = None,
        num_cached_tokens: Optional[int] = None,
        *,
        multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.multi_modal_placeholders = multi_modal_placeholders or {}
        self.prompt_logprobs = prompt_logprobs
        self.outputs = outputs
        self.finished = finished
@@ -275,17 +279,26 @@ class RequestOutput:
        finished_time = time.time() if finished else None
        seq_group.set_finished_time(finished_time)

        init_args = (seq_group.request_id, prompt, prompt_token_ids,
                     prompt_logprobs, outputs, finished, seq_group.metrics,
                     seq_group.lora_request, encoder_prompt,
                     encoder_prompt_token_ids, num_cached_tokens)
        init_kwargs = {
            "request_id": seq_group.request_id,
            "prompt": prompt,
            "prompt_token_ids": prompt_token_ids,
            "prompt_logprobs": prompt_logprobs,
            "outputs": outputs,
            "finished": finished,
            "metrics": seq_group.metrics,
            "lora_request": seq_group.lora_request,
            "encoder_prompt": encoder_prompt,
            "encoder_prompt_token_ids": encoder_prompt_token_ids,
            "num_cached_tokens": num_cached_tokens,
            "multi_modal_placeholders": seq_group.multi_modal_placeholders
        }

        if use_cache:
            request_output = seq_group.cached_request_output
            request_output.__init__(*init_args)  # type: ignore
            request_output.__init__(**init_kwargs)  # type: ignore
        else:
            request_output = cls(*init_args)
            request_output = cls(**init_kwargs)  # type: ignore

        return request_output
@@ -300,7 +313,8 @@ class RequestOutput:
                f"finished={self.finished}, "
                f"metrics={self.metrics}, "
                f"lora_request={self.lora_request}, "
                f"num_cached_tokens={self.num_cached_tokens})")
                f"num_cached_tokens={self.num_cached_tokens}, "
                f"multi_modal_placeholders={self.multi_modal_placeholders})")
class EmbeddingRequestOutput:
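
Because `multi_modal_placeholders` is keyword-only, code that constructs `RequestOutput` directly has to pass it by name, as the `init_kwargs` change above does. A minimal sketch with made-up values (the prompt, token IDs, and `CompletionOutput` contents are illustrative assumptions, not taken from this commit):

    from vllm.outputs import CompletionOutput, RequestOutput

    # Illustrative completion; only the field names follow the real API.
    completion = CompletionOutput(index=0,
                                  text="A city street.",
                                  token_ids=[101, 102],
                                  cumulative_logprob=None,
                                  logprobs=None)

    out = RequestOutput(
        request_id="req-0",
        prompt="<s>[INST]Describe the image.\n[IMG][/INST]",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[completion],
        finished=True,
        # New keyword-only argument added by this commit.
        multi_modal_placeholders={"image": [{"offset": 10, "length": 494}]},
    )
    print(out)  # repr now ends with "..., multi_modal_placeholders={...})"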