[Bugfix] Fix precision error in LLaVA-NeXT (#11735)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -15,10 +15,9 @@ def processor_for_llava_next():
     return LlavaNextMultiModalProcessor
 
 
-# FIXME: image_size [(198, 176), (176, 198)]
 @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
 @pytest.mark.parametrize("image_size", [(1669, 2560), (2560, 1669), (183, 488),
-                                        (488, 183)])
+                                        (488, 183), (198, 176), (176, 198)])
 @pytest.mark.parametrize("num_imgs", [1, 2])
 def test_processor_prompt_replacements(
     processor_for_llava_next,
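The two sizes promoted out of the FIXME, (198, 176) and (176, 198), sit on a float-precision boundary in the unpadded-feature computation patched below. A minimal sketch of the divergence; the tile-grid numbers (24 patches per tile side, width scaled against a single-tile column) are assumed for illustration and are not taken from this diff:

import numpy as np

# Assumed illustration values: original image 176 wide x 198 tall,
# current (tiled) width of 24 patches.
original_height = 198
original_width = 176
current_width = 24

# Plain Python floats (float64), as the processor computed before this fix:
scale64 = current_width / original_width   # inexact in binary, lands just below 3/22
print(int(original_height * scale64))      # -> 26 (product lands just below 27.0)

# float32, matching how the HF image processor resizes:
scale32 = np.array(current_width / original_width, dtype=np.float32)
print(int(original_height * scale32))      # -> 27

The off-by-one in the truncated height propagates into the unpadded feature count, so the processor emitted a different number of image placeholder tokens than the HF reference for these shapes.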
@@ -2,6 +2,7 @@ from functools import cached_property
 from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union)
 
+import numpy as np
 import torch
 import torch.nn as nn
 from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
@@ -139,16 +140,21 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
         current_height = npatches * num_patch_height
         current_width = npatches * num_patch_width
 
-        original_aspect_ratio = original_width / original_height
-        current_aspect_ratio = current_width / current_height
+        # NOTE: HF resizes based on float32
+        original_aspect_ratio = np.array(original_width / original_height,
+                                         dtype=np.float32)
+        current_aspect_ratio = np.array(current_width / current_height,
+                                        dtype=np.float32)
 
         if original_aspect_ratio > current_aspect_ratio:
-            scale_factor = current_width / original_width
+            scale_factor = np.array(current_width / original_width,
+                                    dtype=np.float32)
             new_height = int(original_height * scale_factor)
             padding = (current_height - new_height) // 2
             current_height -= 2 * padding
         else:
-            scale_factor = current_height / original_height
+            scale_factor = np.array(current_height / original_height,
+                                    dtype=np.float32)
             new_width = int(original_width * scale_factor)
             padding = (current_width - new_width) // 2
             current_width -= 2 * padding
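The hunk above is the unpadded-feature computation whose output feeds the image placeholder token count. A self-contained sketch of that logic, runnable outside vLLM, comparing the pre-fix float64 behaviour with the float32 behaviour the patch adopts; the 24-patch, 2x1 tile grid for a 176x198 input is an assumed example, and the printed values are worked out for that shape only:

import numpy as np

def unpadded_features(original_height, original_width,
                      npatches, num_patch_height, num_patch_width,
                      dtype=None):
    """Mirror of the unpad logic above; dtype=np.float32 follows HF's
    float32 resize, dtype=None keeps plain Python float64."""
    cast = (lambda x: np.array(x, dtype=dtype)) if dtype is not None else (lambda x: x)
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    original_aspect_ratio = cast(original_width / original_height)
    current_aspect_ratio = cast(current_width / current_height)

    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = cast(current_width / original_width)
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        current_height -= 2 * padding
    else:
        scale_factor = cast(current_height / original_height)
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        current_width -= 2 * padding

    # unpadded patch features plus one newline feature per row
    return current_height * current_width + current_height

# Assumed shape: a 176x198 image on a 2x1 grid of 24x24-patch tiles.
print(unpadded_features(198, 176, npatches=24,
                        num_patch_height=2, num_patch_width=1))
# -> 650 with float64 (pre-fix)
print(unpadded_features(198, 176, npatches=24,
                        num_patch_height=2, num_patch_width=1,
                        dtype=np.float32))
# -> 700 with float32 (post-fix, aligned with HF per the NOTE above)

For this one shape the two paths differ by 50 features, which shows up downstream as a mismatched placeholder-token count.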
@@ -3,6 +3,7 @@ from functools import cached_property
 from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union)
 
+import numpy as np
 import torch
 import torch.nn as nn
 from transformers import (BatchFeature, LlavaOnevisionConfig,
@@ -127,18 +128,24 @@ class LlavaOnevisionMultiModalProcessor(LlavaNextMultiModalProcessor):
         current_height = npatches * num_patch_height
         current_width = npatches * num_patch_width
 
-        original_aspect_ratio = original_width / original_height
-        current_aspect_ratio = current_width / current_height
+        # NOTE: HF resizes based on float32
+        original_aspect_ratio = np.array(original_width / original_height,
+                                         dtype=np.float32)
+        current_aspect_ratio = np.array(current_width / current_height,
+                                        dtype=np.float32)
 
         if original_aspect_ratio > current_aspect_ratio:
-            new_height = int(original_height *
-                             (current_width / original_width))
+            scale_factor = np.array(current_width / original_width,
+                                    dtype=np.float32)
+            new_height = int(original_height * scale_factor)
             padding = (current_height - new_height) // 2
-            current_height -= padding * 2
+            current_height -= 2 * padding
         else:
-            new_width = int(original_width *
-                            (current_height / original_height))
+            scale_factor = np.array(current_height / original_height,
+                                    dtype=np.float32)
+            new_width = int(original_width * scale_factor)
             padding = (current_width - new_width) // 2
-            current_width -= padding * 2
+            current_width -= 2 * padding
 
         unpadded_features = current_height * current_width
         newline_features = current_height
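The LLaVA-OneVision processor gets the same float32 cast, plus a small `padding * 2` -> `2 * padding` cleanup and a named `scale_factor` variable. One way to see why sizes like 176x198 were singled out in the test is to sweep image sizes for a fixed, assumed tile grid and report where the float64 and float32 paths truncate differently; this is an illustrative sketch, not code from the repository:

import numpy as np

# Assumed fixed grid: 24 patches per tile side on a 2x1 tile grid; only the
# width-constrained branch of the unpad logic is probed.
NPATCHES, GRID_H, GRID_W = 24, 2, 1
CURRENT_HEIGHT = NPATCHES * GRID_H   # 48
CURRENT_WIDTH = NPATCHES * GRID_W    # 24

def truncated_height(original_height, original_width, use_float32):
    scale = CURRENT_WIDTH / original_width
    if use_float32:
        scale = np.array(scale, dtype=np.float32)
    return int(original_height * scale)

for h in range(150, 260):
    for w in range(150, 260):
        # skip sizes that would take the other branch of the unpad logic
        if w / h <= CURRENT_WIDTH / CURRENT_HEIGHT:
            continue
        if truncated_height(h, w, False) != truncated_height(h, w, True):
            print(f"divergent: width={w}, height={h}")

Among the reported sizes is width=176, height=198, matching the 176/198 pair the test at the top of this commit now exercises.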