Re-submit: Fix: Proper RGBA -> RGB conversion for PIL images. (#18569)

Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
Chenheli Hua
2025-05-22 18:59:18 -07:00
committed by GitHub
parent 46791e1b4b
commit 04eb88dc80
15 changed files with 89 additions and 20 deletions
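Background on the fix: PIL's Image.convert("RGB") simply drops the alpha channel of an RGBA image without compositing, so fully transparent pixels keep whatever RGB values they happen to carry (frequently black). The convert_image_mode helper introduced in this commit instead pastes RGBA images onto a white background using the alpha channel as a mask. A minimal sketch of the difference, assuming a fully transparent pixel:

from PIL import Image

rgba = Image.new("RGBA", (1, 1), (0, 0, 0, 0))    # fully transparent pixel
naive = rgba.convert("RGB")                       # alpha is discarded, not composited
print(naive.getpixel((0, 0)))                     # (0, 0, 0) -> black artifact

white_bg = Image.new("RGB", rgba.size, (255, 255, 255))
white_bg.paste(rgba, mask=rgba.split()[3])        # composite onto white, as rgba_to_rgb does below
print(white_bg.getpixel((0, 0)))                  # (255, 255, 255)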

View File

@ -35,6 +35,7 @@ from transformers import PreTrainedTokenizerBase
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.image import convert_image_mode
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
logger = logging.getLogger(__name__)
@ -257,7 +258,7 @@ def process_image(image: Any) -> Mapping[str, Any]:
if isinstance(image, dict) and "bytes" in image:
image = Image.open(BytesIO(image["bytes"]))
if isinstance(image, Image.Image):
image = image.convert("RGB")
image = convert_image_mode(image, "RGB")
with io.BytesIO() as image_data:
image.save(image_data, format="JPEG")
image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
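The conversion has to happen before the JPEG encode in any case: Pillow refuses to serialize images that still carry an alpha channel as JPEG. A small illustration of the failure mode this guards against, assuming Pillow's standard JPEG encoder:

import io
from PIL import Image

rgba = Image.new("RGBA", (4, 4), (255, 0, 0, 128))
buf = io.BytesIO()
try:
    rgba.save(buf, format="JPEG")        # Pillow rejects alpha channels for JPEG
except OSError as exc:
    print(exc)                           # e.g. "cannot write mode RGBA as JPEG"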

View File

@ -11,6 +11,7 @@ from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.multimodal.image import convert_image_mode
from vllm.utils import FlexibleArgumentParser
@ -45,7 +46,8 @@ def get_mixed_modalities_query() -> QueryResult:
"audio":
AudioAsset("mary_had_lamb").audio_and_sample_rate,
"image":
ImageAsset("cherry_blossom").pil_image.convert("RGB"),
convert_image_mode(
ImageAsset("cherry_blossom").pil_image, "RGB"),
"video":
VideoAsset(name="baby_reading", num_frames=16).np_ndarrays,
},

View File

@ -19,6 +19,7 @@ from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.lora.request import LoRARequest
from vllm.multimodal.image import convert_image_mode
from vllm.utils import FlexibleArgumentParser
@ -1096,8 +1097,8 @@ def get_multi_modal_input(args):
"""
if args.modality == "image":
# Input image and question
image = ImageAsset("cherry_blossom") \
.pil_image.convert("RGB")
image = convert_image_mode(
ImageAsset("cherry_blossom").pil_image, "RGB")
img_questions = [
"What is the content of this image?",
"Describe the content of this image in detail.",

View File

@ -4,6 +4,7 @@ import pytest
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.multimodal.image import convert_image_mode
models = ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]
@ -26,8 +27,9 @@ def test_models(vllm_runner, model, dtype: str, max_tokens: int) -> None:
give the same result.
"""
image_cherry = ImageAsset("cherry_blossom").pil_image.convert("RGB")
image_stop = ImageAsset("stop_sign").pil_image.convert("RGB")
image_cherry = convert_image_mode(
ImageAsset("cherry_blossom").pil_image, "RGB")
image_stop = convert_image_mode(ImageAsset("stop_sign").pil_image, "RGB")
images = [image_cherry, image_stop]
video = VideoAsset(name="baby_reading", num_frames=16).np_ndarrays

View File

@ -12,7 +12,7 @@ from transformers import AutoTokenizer
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest
from vllm.multimodal.image import rescale_image_size
from vllm.multimodal.image import convert_image_mode, rescale_image_size
from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs
@ -267,7 +267,7 @@ def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str,
# use the example speech question so that the model outputs are reasonable
audio = librosa.load(speech_question, sr=None)
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
inputs_vision_speech = [
(

View File

@ -4,6 +4,7 @@ import pytest
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.multimodal.image import convert_image_mode
from ..utils import create_new_process_for_each_test
@ -58,7 +59,7 @@ def test_oot_registration_embedding(
assert all(v == 0 for v in output.outputs.embedding)
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
@create_new_process_for_each_test()

Binary file not shown (new image asset, 219 KiB).

View File

@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
import numpy as np
from PIL import Image, ImageChops
from vllm.multimodal.image import convert_image_mode
ASSETS_DIR = Path(__file__).parent / "assets"
assert ASSETS_DIR.exists()
def test_rgb_to_rgb():
# Start with an RGB image.
original_image = Image.open(ASSETS_DIR / "image1.png").convert("RGB")
converted_image = convert_image_mode(original_image, "RGB")
# RGB to RGB should be a no-op.
diff = ImageChops.difference(original_image, converted_image)
assert diff.getbbox() is None
def test_rgba_to_rgb():
original_image = Image.open(ASSETS_DIR / "rgba.png")
original_image_numpy = np.array(original_image)
converted_image = convert_image_mode(original_image, "RGB")
converted_image_numpy = np.array(converted_image)
for i in range(original_image_numpy.shape[0]):
for j in range(original_image_numpy.shape[1]):
# Verify that all transparent pixels are converted to white.
if original_image_numpy[i][j][3] == 0:
assert converted_image_numpy[i][j][0] == 255
assert converted_image_numpy[i][j][1] == 255
assert converted_image_numpy[i][j][2] == 255
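The per-pixel loop keeps the test easy to follow; the same assertion can also be written as a NumPy mask over the arrays built above. A compact equivalent sketch:

transparent = original_image_numpy[..., 3] == 0           # pixels with alpha == 0
assert (converted_image_numpy[transparent] == 255).all()  # every masked RGB value is white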

View File

@ -10,6 +10,7 @@ import numpy as np
import pytest
from PIL import Image, ImageChops
from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import (MediaConnector,
merge_and_sort_multimodal_metadata)
@ -53,7 +54,7 @@ def get_supported_suffixes() -> tuple[str, ...]:
def _image_equals(a: Image.Image, b: Image.Image) -> bool:
return (np.asarray(a) == np.asarray(b.convert(a.mode))).all()
return (np.asarray(a) == np.asarray(convert_image_mode(b, a.mode))).all()
@pytest.mark.asyncio

View File

@ -13,7 +13,6 @@ generation. Supported dataset types include:
TODO: Implement CustomDataset to parse a JSON file and convert its contents into
SampleRequest instances, similar to the approach used in ShareGPT.
"""
import base64
import io
import json
@ -33,6 +32,7 @@ from transformers import PreTrainedTokenizerBase
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.image import convert_image_mode
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
logger = logging.getLogger(__name__)
@ -259,7 +259,7 @@ def process_image(image: Any) -> Mapping[str, Any]:
if isinstance(image, dict) and 'bytes' in image:
image = Image.open(BytesIO(image['bytes']))
if isinstance(image, Image.Image):
image = image.convert("RGB")
image = convert_image_mode(image, "RGB")
with io.BytesIO() as image_data:
image.save(image_data, format="JPEG")
image_base64 = base64.b64encode(

View File

@ -23,6 +23,7 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
@ -77,7 +78,7 @@ InternVLImageInputs = Union[InternVLImagePixelInputs,
def build_transform(input_size: int):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
return T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Lambda(lambda img: convert_image_mode(img, 'RGB')),
T.Resize((input_size, input_size),
interpolation=T.InterpolationMode.BICUBIC),
T.ToTensor(),

View File

@ -24,6 +24,7 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
@ -78,7 +79,7 @@ SkyworkR1VImageInputs = Union[SkyworkR1VImagePixelInputs,
def build_transform(input_size: int):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
return T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Lambda(lambda img: convert_image_mode(img, 'RGB')),
T.Resize((input_size, input_size),
interpolation=T.InterpolationMode.BICUBIC),
T.ToTensor(),

View File

@ -10,6 +10,7 @@ from blake3 import blake3
from PIL import Image
from vllm.logger import init_logger
from vllm.multimodal.image import convert_image_mode
if TYPE_CHECKING:
from vllm.inputs import TokensPrompt
@ -35,7 +36,8 @@ class MultiModalHasher:
return np.array(obj).tobytes()
if isinstance(obj, Image.Image):
return cls.item_to_bytes("image", np.array(obj.convert("RGBA")))
return cls.item_to_bytes("image",
np.array(convert_image_mode(obj, "RGBA")))
if isinstance(obj, torch.Tensor):
return cls.item_to_bytes("tensor", obj.numpy())
if isinstance(obj, np.ndarray):
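For hashing, images are canonicalised to RGBA rather than RGB so that inputs which differ only in transparency still produce different digests. A simplified sketch of that serialisation step, assuming only the raw image bytes feed the hash (the surrounding class additionally passes the "image" item name to item_to_bytes):

from blake3 import blake3
import numpy as np
from PIL import Image
from vllm.multimodal.image import convert_image_mode

def image_digest(img: Image.Image) -> str:
    # Canonicalise to RGBA so the alpha channel contributes to the hash.
    data = np.array(convert_image_mode(img, "RGBA")).tobytes()
    return blake3(data).hexdigest()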

View File

@ -22,6 +22,25 @@ def rescale_image_size(image: Image.Image,
return image
# TODO: Support customizable background color to fill in.
def rgba_to_rgb(
image: Image.Image, background_color=(255, 255, 255)) -> Image.Image:
"""Convert an RGBA image to RGB with filled background color."""
assert image.mode == "RGBA"
converted = Image.new("RGB", image.size, background_color)
converted.paste(image, mask=image.split()[3]) # 3 is the alpha channel
return converted
def convert_image_mode(image: Image.Image, to_mode: str):
if image.mode == to_mode:
return image
elif image.mode == "RGBA" and to_mode == "RGB":
return rgba_to_rgb(image)
else:
return image.convert(to_mode)
class ImageMediaIO(MediaIO[Image.Image]):
def __init__(self, *, image_mode: str = "RGB") -> None:
@ -32,7 +51,7 @@ class ImageMediaIO(MediaIO[Image.Image]):
def load_bytes(self, data: bytes) -> Image.Image:
image = Image.open(BytesIO(data))
image.load()
return image.convert(self.image_mode)
return convert_image_mode(image, self.image_mode)
def load_base64(self, media_type: str, data: str) -> Image.Image:
return self.load_bytes(base64.b64decode(data))
@ -40,7 +59,7 @@ class ImageMediaIO(MediaIO[Image.Image]):
def load_file(self, filepath: Path) -> Image.Image:
image = Image.open(filepath)
image.load()
return image.convert(self.image_mode)
return convert_image_mode(image, self.image_mode)
def encode_base64(
self,
@ -51,7 +70,7 @@ class ImageMediaIO(MediaIO[Image.Image]):
image = media
with BytesIO() as buffer:
image = image.convert(self.image_mode)
image = convert_image_mode(image, self.image_mode)
image.save(buffer, image_format)
data = buffer.getvalue()
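Call sites that previously used image.convert(mode) directly now route through convert_image_mode, which special-cases only the RGBA -> RGB path and otherwise defers to PIL. A usage sketch (the file names are hypothetical):

from PIL import Image
from vllm.multimodal.image import convert_image_mode

logo = Image.open("logo_with_alpha.png")        # hypothetical RGBA input
rgb = convert_image_mode(logo, "RGB")           # transparent pixels become white
assert convert_image_mode(rgb, "RGB") is rgb    # already-RGB images are returned unchanged
paletted = Image.open("icon.gif")               # hypothetical mode "P" input
rgb2 = convert_image_mode(paletted, "RGB")      # any other mode falls back to PIL's convert()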

View File

@ -33,6 +33,8 @@ from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin,
Unpack)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from vllm.multimodal.image import convert_image_mode
__all__ = ['OvisProcessor']
IGNORE_ID = -100
@ -361,8 +363,8 @@ class OvisProcessor(ProcessorMixin):
# pick the partition with maximum covering_ratio and break the tie using #sub_images
return sorted(all_grids, key=lambda x: (-x[1], x[0][0] * x[0][1]))[0][0]
if convert_to_rgb and image.mode != 'RGB':
image = image.convert('RGB')
if convert_to_rgb:
image = convert_image_mode(image, 'RGB')
sides = self.get_image_size()