Mirror of https://github.com/vllm-project/vllm.git
[Hotfix][Pixtral] Fix multiple images bugs (#8415)
Parent commit: b61bd98f90, commit: d31174a4e1 (committed via GitHub)
@@ -658,8 +658,8 @@ class VllmRunner:
         outputs.append((req_sample_output_ids, req_sample_output_strs))
         return outputs
 
+    @staticmethod
     def _final_steps_generate_w_logprobs(
-        self,
         req_outputs: List[RequestOutput],
     ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
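Turning the helper into a @staticmethod is what lets the tests below call it directly on the vllm_runner fixture (the VllmRunner class) without constructing an instance. A minimal sketch of that call pattern, with hypothetical names standing in for the real runner:

from typing import List, Tuple


class Runner:
    """Stand-in for VllmRunner; only the call pattern matters here."""

    @staticmethod
    def _collect_texts(outputs: List[Tuple[int, str]]) -> List[str]:
        # A staticmethod needs no instance state, so callers can invoke it
        # on the class itself, mirroring
        # vllm_runner._final_steps_generate_w_logprobs(outputs) below.
        return [text for _, text in outputs]


print(Runner._collect_texts([(1, "a dog"), (2, "a mountain")]))  # ['a dog', 'a mountain']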
BIN  tests/models/fixtures/pixtral_chat.pickle  (new file, binary file not shown)
BIN  tests/models/fixtures/pixtral_chat_engine.pickle  (new file, binary file not shown)
@@ -2,13 +2,128 @@
 
 Run `pytest tests/models/test_mistral.py`.
 """
-import pytest
+import pickle
+import uuid
+from typing import Any, Dict, List
 
-from vllm.sampling_params import SamplingParams
+import pytest
+from mistral_common.protocol.instruct.messages import ImageURLChunk
+from mistral_common.protocol.instruct.request import ChatCompletionRequest
+from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
+
+from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt
+from vllm.multimodal import MultiModalDataBuiltins
+
+from .utils import check_logprobs_close
 
 pytestmark = pytest.mark.vlm
 
 MODELS = ["mistralai/Pixtral-12B-2409"]
+IMG_URLS = [
+    "https://picsum.photos/id/237/400/300",
+    "https://picsum.photos/id/231/200/300",
+    "https://picsum.photos/id/27/500/500",
+    "https://picsum.photos/id/17/150/600",
+]
+PROMPT = "Describe each image in one short sentence."
+
+
+def _create_msg_format(urls: List[str]) -> List[Dict[str, Any]]:
+    return [{
+        "role":
+        "user",
+        "content": [{
+            "type": "text",
+            "text": PROMPT,
+        }] + [{
+            "type": "image_url",
+            "image_url": {
+                "url": url
+            }
+        } for url in urls],
+    }]
+
+
+def _create_engine_inputs(urls: List[str]) -> TokensPrompt:
+    msg = _create_msg_format(urls)
+
+    tokenizer = MistralTokenizer.from_model("pixtral")
+
+    request = ChatCompletionRequest(messages=msg)  # type: ignore[type-var]
+    tokenized = tokenizer.encode_chat_completion(request)
+
+    engine_inputs = TokensPrompt(prompt_token_ids=tokenized.tokens)
+
+    images = []
+    for chunk in request.messages[0].content:
+        if isinstance(chunk, ImageURLChunk):
+            images.append(image_from_chunk(chunk))
+
+    mm_data = MultiModalDataBuiltins(image=images)
+    engine_inputs["multi_modal_data"] = mm_data
+
+    return engine_inputs
+
+
+MSGS = [
+    _create_msg_format(IMG_URLS[:1]),
+    _create_msg_format(IMG_URLS[:2]),
+    _create_msg_format(IMG_URLS),
+]
+ENGINE_INPUTS = [
+    _create_engine_inputs(IMG_URLS[:1]),
+    _create_engine_inputs(IMG_URLS[:2]),
+    _create_engine_inputs(IMG_URLS),
+]
+
+SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
+LIMIT_MM_PER_PROMPT = dict(image=4)
+
+MAX_MODEL_LEN = [8192, 65536]
+FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.pickle"
+FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.pickle"
+
+
+def load_logprobs(filename: str) -> Any:
+    with open(filename, 'rb') as f:
+        return pickle.load(f)
+
+
+@pytest.mark.skip(
+    reason=
+    "Model is too big, test passed on A100 locally but will OOM on CI machine."
+)
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+def test_chat(
+    vllm_runner,
+    max_model_len: int,
+    model: str,
+    dtype: str,
+) -> None:
+    EXPECTED_CHAT_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_CHAT)
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            tokenizer_mode="mistral",
+            enable_chunked_prefill=False,
+            max_model_len=max_model_len,
+            limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
+    ) as vllm_model:
+        outputs = []
+        for msg in MSGS:
+            output = vllm_model.model.chat(msg,
+                                           sampling_params=SAMPLING_PARAMS)
+
+            outputs.extend(output)
+
+        logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
+        check_logprobs_close(outputs_0_lst=logprobs,
+                             outputs_1_lst=EXPECTED_CHAT_LOGPROBS,
+                             name_0="output",
+                             name_1="h100_ref")
+
+
 @pytest.mark.skip(
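test_chat above compares freshly generated logprobs against a pickled reference captured on an H100 (name_1="h100_ref"). A minimal, self-contained sketch of that fixture round-trip, using hypothetical data rather than real model outputs:

import pickle
import tempfile

# Hypothetical stand-in for per-token logprob records; the real fixtures
# hold the structures returned by _final_steps_generate_w_logprobs.
reference = [([1, 2, 3], "a black dog", None)]

# Writing the fixture (done once, offline, on the reference GPU).
with tempfile.NamedTemporaryFile(suffix=".pickle", delete=False) as f:
    pickle.dump(reference, f)
    path = f.name

# Loading it in the test, as load_logprobs() does above.
with open(path, "rb") as f:
    assert pickle.load(f) == reference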
@@ -17,48 +132,37 @@ MODELS = ["mistralai/Pixtral-12B-2409"]
 )
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
-@pytest.mark.parametrize("max_tokens", [64])
-@pytest.mark.parametrize("num_logprobs", [5])
-def test_models(
-    vllm_runner,
-    example_prompts,
-    model: str,
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-) -> None:
-    image_urls = [
-        "https://picsum.photos/id/237/200/300",
-        "https://picsum.photos/seed/picsum/200/300"
-    ]
-    expected = [
-        "The image depicts a black dog lying on a wooden surface, looking directly at the camera with a calm expression.",  # noqa
-        "The image depicts a serene landscape with a snow-covered mountain under a pastel-colored sky during sunset."  # noqa
-    ]
-    prompt = "Describe the image in one short sentence."
+def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
+    EXPECTED_ENGINE_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_ENGINE)
+    args = EngineArgs(
+        model=model,
+        tokenizer_mode="mistral",
+        enable_chunked_prefill=False,
+        limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
+        dtype=dtype,
+    )
+    engine = LLMEngine.from_engine_args(args)
 
-    sampling_params = SamplingParams(max_tokens=512, temperature=0.0)
+    engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[0], SAMPLING_PARAMS)
+    engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[1], SAMPLING_PARAMS)
 
-    with vllm_runner(model, dtype=dtype,
-                     tokenizer_mode="mistral") as vllm_model:
+    outputs = []
+    count = 0
+    while True:
+        out = engine.step()
+        count += 1
+        for request_output in out:
+            if request_output.finished:
+                outputs.append(request_output)
 
-        for i, image_url in enumerate(image_urls):
-            messages = [
-                {
-                    "role":
-                    "user",
-                    "content": [{
-                        "type": "text",
-                        "text": prompt
-                    }, {
-                        "type": "image_url",
-                        "image_url": {
-                            "url": image_url
-                        }
-                    }]
-                },
-            ]
+        if count == 2:
+            engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[2],
+                               SAMPLING_PARAMS)
+        if not engine.has_unfinished_requests():
+            break
 
-            outputs = vllm_model.model.chat(messages,
-                                            sampling_params=sampling_params)
-            assert outputs[0].outputs[0].text == expected[i]
+    logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
+    check_logprobs_close(outputs_0_lst=logprobs,
+                         outputs_1_lst=EXPECTED_ENGINE_LOGPROBS,
+                         name_0="output",
+                         name_1="h100_ref")
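The engine-level test exercises the multi-image fix by submitting two requests up front and injecting a third one (the four-image prompt) after the second step(). A minimal sketch of that add-while-stepping pattern, with a toy class standing in for LLMEngine so it runs without a model:

from collections import deque


class ToyEngine:
    """Toy stand-in for LLMEngine: each step() finishes one queued request."""

    def __init__(self) -> None:
        self.queue: deque[str] = deque()

    def add_request(self, request_id: str) -> None:
        self.queue.append(request_id)

    def step(self) -> list[str]:
        return [self.queue.popleft()] if self.queue else []

    def has_unfinished_requests(self) -> bool:
        return bool(self.queue)


engine = ToyEngine()
engine.add_request("req-0")
engine.add_request("req-1")

finished, count = [], 0
while True:
    finished.extend(engine.step())
    count += 1
    if count == 2:  # inject another request mid-stream,
        engine.add_request("req-2")  # as the test does with ENGINE_INPUTS[2]
    if not engine.has_unfinished_requests():
        break

print(finished)  # ['req-0', 'req-1', 'req-2']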
@@ -1,4 +1,3 @@
-import math
 from array import array
 from dataclasses import dataclass, fields
 from itertools import tee
@@ -15,11 +14,12 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalMask
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
-from vllm.inputs import INPUT_REGISTRY, InputContext
+from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.utils import merge_multimodal_embeddings
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
@@ -48,23 +48,29 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
     tokenizer = cached_get_tokenizer(
         ctx.model_config.tokenizer,
         tokenizer_mode=ctx.model_config.tokenizer_mode)
-    mm_encoder = tokenizer.instruct.mm_encoder
+
+    mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
+    patch_size = mm_encoder.mm_config.image_patch_size
+    image_token_id = mm_encoder.special_ids.img
 
     mm_config = ctx.model_config.multimodal_config
-    max_num_images_per_request = mm_config.limit_per_prompt.get("image", 1)
-
-    # approximate image size
-    size = int(math.sqrt(seq_len) * mm_encoder.mm_config.image_patch_size)
+    num_images = mm_config.limit_per_prompt.get("image", 1)
+
+    # dummy size
+    size = 256
     image = Image.new("RGB", (size, size), color=0)
-    img_chunk = ImageChunk(image=image)
-
-    tokens = mm_encoder(img_chunk).tokens
-    token_ids = max_num_images_per_request * array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                                   tokens)
+
+    image_feature_size = (size**2) // (patch_size**2)
+
+    num_image_tokens = image_feature_size * num_images
+
+    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                      [image_token_id]) * num_image_tokens
+    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                       [0]) * (seq_len - num_image_tokens)
 
     seq_data = SequenceData(token_ids)
-    mm_data = {"image": max_num_images_per_request * [image]}
+    mm_data = {"image": num_images * [image]}
    return seq_data, mm_data
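The new dummy-data path sizes the placeholder sequence from the patch grid instead of running the multimodal encoder on a profiled image. A worked check of that arithmetic, assuming the 16-pixel patch size that `mm_encoder.mm_config.image_patch_size` would typically report for Pixtral and an illustrative sequence length:

# Assumed values for illustration: 256x256 dummy image, 16-pixel patches.
size, patch_size = 256, 16
num_images, seq_len = 4, 4096  # e.g. limit_mm_per_prompt=dict(image=4)

image_feature_size = (size**2) // (patch_size**2)   # 65536 // 256 = 256
num_image_tokens = image_feature_size * num_images  # 256 * 4 = 1024

# The dummy prompt is num_image_tokens image placeholders padded with zeros.
assert num_image_tokens == 1024
assert seq_len - num_image_tokens == 3072  # zero-padding length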
@@ -99,32 +105,31 @@ def input_mapper_for_pixtral(ctx: InputContext,
     return MultiModalInputs({"images": images})
 
 
-def merge_multimodal_embeddings(input_ids: torch.Tensor,
-                                inputs_embeds: torch.Tensor,
-                                image_features: Optional[List[torch.Tensor]],
-                                image_id: int) -> torch.Tensor:
-    text_locations = input_ids != image_id
-    image_locations = input_ids == image_id
+def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs):
+    multi_modal_data = llm_inputs.get("multi_modal_data")
+    if multi_modal_data is not None and "image" in multi_modal_data:
+        tokenizer = cached_get_tokenizer(
+            ctx.model_config.tokenizer,
+            tokenizer_mode=ctx.model_config.tokenizer_mode)
 
-    seq_len = input_ids.shape[0]
+        mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
+        image_token_id = mm_encoder.special_ids.img
 
-    N_txt = text_locations.sum().item()
-    _, D_txt = inputs_embeds.shape
-    N_img, D_img = image_features.shape
+        if image_token_id not in llm_inputs['prompt_token_ids']:
+            raise ValueError(
+                (f"You've passed {llm_inputs=} without {image_token_id=}"
+                 " Make sure to process your input via mistral_common's"
+                 " tokenizer or pass a chat completion request."
+                 " For more info, see: "
+                 "https://github.com/vllm-project/vllm/issues/8411."))
 
-    assert (D_txt == D_img), (f"Text features dim {D_txt} should be equal "
-                              "to image features dim {D_img}")
-    assert (seq_len == N_txt +
-            N_img), (f"seq_len {seq_len} should be equal to N_txt + N_img "
-                     f"{(N_txt, N_img, image_locations.sum().item())}")
-
-    inputs_embeds[image_locations, :] = image_features
-    return inputs_embeds
+    return llm_inputs
 
 
 @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral)
 class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
 
     def __init__(self,
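The local merge_multimodal_embeddings helper is removed in favour of the shared implementation imported from vllm.model_executor.models.utils (see the import hunk above). The core idea, sketched minimally and independently of vLLM's actual signature, is to scatter image features into the embedding rows whose token id is the image placeholder:

import torch


def merge_embeddings_sketch(input_ids: torch.Tensor,
                            inputs_embeds: torch.Tensor,
                            image_features: torch.Tensor,
                            image_token_id: int) -> torch.Tensor:
    # Rows whose token id equals the placeholder receive image features;
    # the number of placeholder positions must match the number of
    # image feature rows.
    image_positions = input_ids == image_token_id
    assert int(image_positions.sum()) == image_features.shape[0]
    inputs_embeds[image_positions] = image_features
    return inputs_embeds


# Tiny usage example: 5 tokens, 2 of them image placeholders (id 10).
ids = torch.tensor([1, 10, 2, 10, 3])
embeds = torch.zeros(5, 4)
feats = torch.ones(2, 4)
print(merge_embeddings_sketch(ids, embeds, feats, image_token_id=10))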
@@ -201,11 +206,21 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
             return None
 
         if isinstance(images, torch.Tensor):
-            # always take last images
-            images = [images[-1][i] for i in range(images.size(1))]
+            # if passed as batch take all images
+            N, B, C, W, H = images.shape
+            images = images.reshape(N * B, C, W, H)
+            images = [images[i] for i in range(images.size(0))]
         elif isinstance(images, list):
-            # always take last images
-            images = [images[-1][i] for i in range(len(images[0]))]
+            # if passed as list flatten lists of tensors
+            flatten_images = []
+            for imgs_per_req in images:
+                imgs_per_req = [
+                    imgs_per_req[i] for i in range(imgs_per_req.size(0))
+                ] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req
+
+                flatten_images.extend(imgs_per_req)
+
+            images = flatten_images
 
         return images
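This hunk is the heart of the multiple-images fix: instead of keeping only the last request's images, every image from every request in the batch is flattened into one list before being passed on to the vision encoder. A minimal sketch of the batched-tensor branch with toy shapes:

import torch

# Toy batch: N=2 requests, B=3 images each, 3x8x8 "images".
images = torch.randn(2, 3, 3, 8, 8)

N, B, C, W, H = images.shape
flat = images.reshape(N * B, C, W, H)
image_list = [flat[i] for i in range(flat.size(0))]

# All 6 images survive, rather than only the 3 from the last request.
assert len(image_list) == N * B
assert image_list[0].shape == (C, W, H)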