Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)
[Model] Use merge_by_field_config for MM models (D-F) (#26076)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -20,8 +20,7 @@ from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.model_executor.models.transformers import replace_linear_class
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargsItems, MultiModalUUIDDict,
-                                    NestedTensors)
+                                    MultiModalKwargsItems, MultiModalUUIDDict)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -40,7 +39,7 @@ from vllm.utils import is_list_of
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix)
 
 # The image token id may be various
@@ -50,15 +49,15 @@ _IMAGE_TOKEN = "<image>"
 class DeepseekVL2ImagePixelInputs(TensorSchema):
     """
     Dimensions:
         - bn: Batch size * number of images
-        - p: Number of patches
+        - bnp: Batch size * number of images * number of patches
         - c: Number of channels (3)
         - h: Height of each image
         - w: Width of each image
     """
     type: Literal["pixel_values"]
-    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
-                    TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"})]
+    data: Annotated[torch.Tensor,
+                    TensorShape("bnp", 3, "h", "w", dynamic_dims={"bnp"})]
     images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
 
 
@@ -228,12 +227,8 @@ class DeepseekVL2MultiModalProcessor(
             tok_kwargs=tok_kwargs,
         )
 
-        pixel_values = processed_outputs["pixel_values"]
-        # split pixel values into patches corresponding to each image
-        images_spatial_crop = processed_outputs["images_spatial_crop"]
-        patches_per_image = [x.prod().item() + 1 for x in images_spatial_crop]
-        pixel_values = pixel_values.split(patches_per_image)
-        processed_outputs["pixel_values"] = pixel_values
+        processed_outputs["num_patches"] = (
+            processed_outputs["images_spatial_crop"].prod(-1) + 1)
 
         return processed_outputs
 
@@ -242,8 +237,11 @@ class DeepseekVL2MultiModalProcessor(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
+        num_patches = hf_inputs.get("num_patches", torch.empty(0))
+
         return dict(
-            pixel_values=MultiModalFieldConfig.batched("image"),
+            pixel_values=MultiModalFieldConfig.flat_from_sizes(
+                "image", num_patches),
            images_spatial_crop=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )
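Note: switching `pixel_values` from `MultiModalFieldConfig.batched("image")` to `MultiModalFieldConfig.flat_from_sizes("image", num_patches)` declares that the tensor arrives flattened over all patches and should be sliced back into per-image items using the per-image patch counts computed above (`images_spatial_crop.prod(-1) + 1`). A minimal sketch of that slicing in plain torch (hypothetical sizes, not the vLLM API itself):

```python
import torch

# Hypothetical: two images whose crop grids yield 3 and 5 patches respectively.
num_patches = torch.tensor([3, 5])
pixel_values = torch.randn(int(num_patches.sum()), 3, 384, 384)  # flattened over all patches

# flat_from_sizes semantics, roughly: split the flat leading dim back per image.
per_image = pixel_values.split(num_patches.tolist(), dim=0)
assert [p.shape[0] for p in per_image] == [3, 5]
```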
@@ -318,6 +316,7 @@ class DeepseekVL2MultiModalProcessor(
                                         info=DeepseekVL2ProcessingInfo,
                                         dummy_inputs=DeepseekVL2DummyInputsBuilder)
 class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
 
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
         "language.": "language_model.",
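Note: setting `merge_by_field_config = True` opts the model into the code path where the multimodal kwargs are merged and flattened according to the field config before they reach the model, so the manual `flatten_bn` calls below can be dropped (and the import above trimmed). Roughly, the concatenation the model used to do by hand (illustrative only, hypothetical shapes):

```python
import torch

# Per-image tensors as they used to arrive: [(num_patches_i, 3, H, W), ...]
per_image = [torch.randn(3, 3, 384, 384), torch.randn(5, 3, 384, 384)]

# What flatten_bn(..., concat=True) effectively produced inside the model:
flat = torch.cat(per_image, dim=0)  # (8, 3, 384, 384)

# With merge_by_field_config=True the model now receives the merged tensor directly.
```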
@@ -460,37 +459,30 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 
         if pixel_values is not None:
             expected_h = expected_w = self.vision_config.image_size
-            return DeepseekVL2ImagePixelInputs(type="pixel_values",
-                                               data=flatten_bn(pixel_values),
-                                               images_spatial_crop=flatten_bn(
-                                                   images_spatial_crop,
-                                                   concat=True),
-                                               resolve_bindings={
-                                                   "h": expected_h,
-                                                   "w": expected_w,
-                                               })
+            return DeepseekVL2ImagePixelInputs(
+                type="pixel_values",
+                data=pixel_values,
+                images_spatial_crop=images_spatial_crop,
+                resolve_bindings={
+                    "h": expected_h,
+                    "w": expected_w,
+                })
 
         if image_embeds is not None:
             return DeepseekVL2VImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds),
+                data=image_embeds,
             )
 
         raise AssertionError("This line should be unreachable.")
 
     def _pixel_values_to_embedding(
         self,
-        pixel_values: NestedTensors,
+        pixel_values: torch.Tensor,
         images_spatial_crop: torch.Tensor,
-    ) -> NestedTensors:
-        # Pixel_values: n_image * batch_size * [patch_per_img, 3, height, width]
-        total_tiles = [x for x in pixel_values]
-
-        # [batch_all_tiles, 3, height, width]
-        total_tiles = torch.cat(total_tiles, dim=0)
-
+    ) -> list[torch.Tensor]:
         # [batch_all_tiles, vit_seq_len, c]
-        images_feature = self.vision.forward_features(total_tiles)
+        images_feature = self.vision.forward_features(pixel_values)
 
         # [batch_all_tiles, hw, D]
         images_embeds = self.projector(images_feature)
@@ -573,7 +565,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         return vision_embeddings
 
     def _process_image_input(
-            self, image_input: DeepseekVL2ImageInputs) -> torch.Tensor:
+            self, image_input: DeepseekVL2ImageInputs) -> list[torch.Tensor]:
         if image_input["type"] == "image_embeds":
             image_data = image_input["data"]
             if is_list_of(image_data, torch.Tensor):
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -42,34 +42,38 @@ from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig,
                                                      DotsVisionConfig)
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .vision import run_dp_sharded_mrope_vision_model
 
 IMAGE_TOKEN = "<|imgpad|>"
 
 
-class DotsOCRImagePixelInputs(TypedDict):
-    type: Literal["pixel_values", "image_grid_thw"]
-
-    pixel_values: torch.Tensor
-    image_grid_thw: torch.Tensor
-
-
-class DotsOCRImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds", "image_grid_thw"]
-    image_embeds: torch.Tensor
-    """Supported types:
-    - List[`torch.Tensor`]: A list of tensors holding all images' features.
-        Each tensor holds an image's features.
-    - `torch.Tensor`: A tensor holding all images' features
-        (concatenation of all images' feature tensors).
-    Tensor shape: `(num_image_features, hidden_size)`
-    - `num_image_features` varies based on
-        the number and resolution of the images.
-    - `hidden_size` must match the hidden size of language model backbone.
-    """
+class DotsOCRImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - np: The total number of patches over each image over each prompt in
+          the batch
+        - ni: Number of images
+        - cps: Number of channels * patch_size * patch_size
+    """
+    type: Literal["pixel_values"]
 
-    image_grid_thw: torch.Tensor
+    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
+    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
+
+
+class DotsOCRImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - nf: Number of image features
+        - hs: Hidden size
+        - ni: Number of images
+    """
+    type: Literal["image_embeds"]
+
+    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
+    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
 
 
 DotsOCRImageInputs = Union[DotsOCRImagePixelInputs,
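Note: the `TypedDict` input classes become `TensorSchema` subclasses whose `Annotated[..., TensorShape(...)]` fields name each dimension symbolically ("np", "cps", "ni"), which documents the expected layout and lets it be checked. A standalone sketch of the idea (not vLLM's implementation):

```python
import torch

def check_shape(t: torch.Tensor, dims: tuple, bindings: dict) -> None:
    """Accept any size for a named dim, but require the same name to stay consistent."""
    assert t.ndim == len(dims), f"expected {len(dims)} dims, got {t.ndim}"
    for size, dim in zip(t.shape, dims):
        if isinstance(dim, int):
            assert size == dim
        else:
            assert bindings.setdefault(dim, size) == size

bindings: dict = {}
pixel_values = torch.randn(120, 3 * 14 * 14)  # hypothetical: 120 patches, 14x14 patch size
image_grid_thw = torch.tensor([[1, 10, 12]])  # hypothetical grid for a single image
check_shape(pixel_values, ("np", "cps"), bindings)
check_shape(image_grid_thw, ("ni", 3), bindings)
```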
@@ -654,6 +658,8 @@ class DotsVisionTransformer(nn.Module):
 )
 class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
                          SupportsLoRA):
+    merge_by_field_config = True
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_substr={
             ".attn.qkv_proj.": ".attn.qkv.",
@@ -709,22 +715,6 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
             architectures=["Qwen2ForCausalLM"],
         )
 
-    def _validate_and_reshape_mm_tensor(self, mm_input: object,
-                                        name: str) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
-                                 f"Got ndim: {mm_input.ndim} "
-                                 f"(shape={mm_input.shape})")
-            return torch.concat(list(mm_input))
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[DotsOCRImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -735,28 +725,11 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
             return None
 
         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, "image pixel values")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
             return DotsOCRImagePixelInputs(type="pixel_values",
                                            pixel_values=pixel_values,
                                            image_grid_thw=image_grid_thw)
 
         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, "image embeds")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
-            if not isinstance(image_embeds, torch.Tensor):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
             return DotsOCRImageEmbeddingInputs(type="image_embeds",
                                                image_embeds=image_embeds,
                                                image_grid_thw=image_grid_thw)
@@ -25,7 +25,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from functools import partial
-from typing import Any, Callable, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Callable, Literal, Optional, Union
 
 import numpy as np
 import torch
@@ -56,6 +56,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.platforms import _Backend, current_platform
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -579,38 +580,38 @@ class Ernie4_5_VisionTransformer(nn.Module):
 # === Vision Inputs === #
 
 
-class Ernie4_5_VLImagePixelInputs(TypedDict):
+class Ernie4_5_VLImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - np: The total number of patches over each image over each prompt in
+          the batch
+        - ni: Number of images
+        - cps: Number of channels * patch_size * patch_size
+    """
     type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
-    """Shape:
-    `(num_patches, num_channels * patch_size * patch_size)`
-    """
-
-    grid_thw: torch.Tensor
-    """Shape: `(num_images, 3)`
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
+    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
+    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
 
 
 Ernie4_5_VLImageInputs = Ernie4_5_VLImagePixelInputs
 
 
-class Ernie4_5_VLVideoPixelInputs(TypedDict):
+class Ernie4_5_VLVideoPixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - np: The total number of patches over each image over each prompt in
+          the batch
+        - ni: Number of images
+        - cps: Number of channels * temporal_patch_size * patch_size *
+          patch_size
+    """
     type: Literal["pixel_values_videos"]
-    pixel_values_videos: torch.Tensor
-    """Shape:
-    `(num_patches,
-     num_channels * temporal_patch_size * patch_size * patch_size)`
-    """
-
-    video_grid_thw: torch.Tensor
-    """Shape: `(num_videos, 3)`
-
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
+    pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "cps")]
+    video_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
 
 
-Ernie4_5_VLVideoInputs = Ernie4_5_VLImagePixelInputs
+Ernie4_5_VLVideoInputs = Ernie4_5_VLVideoPixelInputs
 
 # === Vision Processor === #
 
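Note: for these Qwen2-VL-style vision inputs, "np" is the total patch count implied by the grid (the sum of t * h * w over all items) and "cps" is channels times the (temporal) patch volume. A hedged sketch of that relationship (hypothetical grid and patch size):

```python
import torch

image_grid_thw = torch.tensor([[1, 16, 24], [1, 8, 12]])  # (t, h, w) per image

np_total = int(image_grid_thw.prod(dim=-1).sum())  # 384 + 96 = 480 patches in total
cps = 3 * 14 * 14                                  # assuming a 14x14 patch here

pixel_values = torch.randn(np_total, cps)          # matches TensorShape("np", "cps")
```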
@@ -1213,6 +1214,7 @@ class Ernie4_5_VLDummyInputsBuilder(
                                         dummy_inputs=Ernie4_5_VLDummyInputsBuilder)
 class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
                                              SupportsLoRA, SupportsPP):
+    merge_by_field_config = True
 
     packed_modules_mapping = {
         "qkv_proj": [
@@ -1325,22 +1327,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def _validate_and_reshape_mm_tensor(self, mm_input: object,
-                                        name: str) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
-                                 f"Got ndim: {mm_input.ndim} "
-                                 f"(shape={mm_input.shape})")
-            return mm_input.reshape(-1, mm_input.shape[-1])
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Ernie4_5_VLImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -1350,15 +1336,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
             return None
 
         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, "image pixel values")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
             return Ernie4_5_VLImagePixelInputs(type="pixel_values",
                                                pixel_values=pixel_values,
                                                image_grid_thw=image_grid_thw)
@@ -1372,11 +1349,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
             return None
 
         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos, "video pixel values")
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw")
-
             return Ernie4_5_VLVideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
@@ -59,17 +59,14 @@ class FuyuImagePatchInputs(TensorSchema):
 
     type: Literal["image_patches"] = "image_patches"
 
-    flat_data: Annotated[
-        torch.Tensor,
-        TensorShape("bnp", "fn"),
-    ]
+    image_patches_flat: Annotated[torch.Tensor, TensorShape("bnp", "fn")]
 
     patches_per_image: Annotated[list[int], TensorShape("bn")]
     """
     The number of total patches for each image in the batch.
 
     This is used to split the embeddings which has the first two dimensions
-    flattened just like `flat_data`.
+    flattened just like `image_patches_flat`.
     """
 
 
@@ -174,28 +171,10 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
             tok_kwargs=tok_kwargs,
         )
 
-        image_patches = processed_outputs.get("image_patches")
-        if image_patches is not None:
-            images = mm_data["images"]
-            assert isinstance(images, list)
-
-            # Original output: (1, num_images, Pn, Px * Py * C)
-            # New output: (num_images, Pn, Px * Py * C)
-            # image_patches is a list with shape:
-            # (1, num_images, Pn, Px * Py * C)
-            # before Transformers 4.53
-            if isinstance(image_patches, list):
-                assert len(image_patches) == 1
-                assert (isinstance(image_patches[0], torch.Tensor)
-                        and len(image_patches[0]) == len(images))
-                processed_outputs["image_patches"] = image_patches[0]
-            # image_patches is a tensor with shape:
-            # (num_images, Pn, Px * Py * C)
-            # after Transformers 4.53
-            elif isinstance(image_patches, torch.Tensor):
-                assert len(image_patches) == len(images)
-            else:
-                raise AssertionError("This line should be unreachable.")
+        image_patches = processed_outputs["image_patches"]
+        processed_outputs["image_patches"] = flatten_bn(image_patches)
+        processed_outputs["patches_per_image"] = torch.tensor(
+            [len(p) for p in image_patches])
 
         return processed_outputs
 
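Note: the Transformers-version special cases are gone; the processor simply flattens `image_patches` over the image dimension with `flatten_bn` and records how many patches each image contributed, which `_get_mm_fields_config` below consumes via `flat_from_sizes`. Roughly, in plain torch (hypothetical patch shapes, not vLLM's helper):

```python
import torch

# Hypothetical HF processor output: (num_images, Pn, Px * Py * C)
image_patches = torch.randn(2, 21, 2700)

patches_per_image = torch.tensor([len(p) for p in image_patches])        # tensor([21, 21])
image_patches_flat = image_patches.reshape(-1, image_patches.shape[-1])  # (42, 2700)
```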
@@ -218,7 +197,13 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(image_patches=MultiModalFieldConfig.batched("image"))
+        patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0))
+
+        return dict(
+            image_patches=MultiModalFieldConfig.flat_from_sizes(
+                "image", patches_per_image),
+            patches_per_image=MultiModalFieldConfig.batched("image"),
+        )
 
     def _get_prompt_updates(
         self,
@@ -263,6 +248,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
                                         info=FuyuProcessingInfo,
                                         dummy_inputs=FuyuDummyInputsBuilder)
 class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
 
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
@@ -306,29 +292,28 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
         image_patches = kwargs.pop("image_patches", None)
-        if image_patches is not None:
-            image_patches_flat = flatten_bn(image_patches)
-            flat_data = flatten_bn(image_patches_flat, concat=True)
+        patches_per_image = kwargs.pop("patches_per_image", None)
 
-            return FuyuImagePatchInputs(
-                type="image_patches",
-                flat_data=flat_data,
-                patches_per_image=[x.size(0) for x in image_patches_flat],
-                resolve_bindings={"fn": self.image_feature_size},
-            )
+        if image_patches is None:
+            return None
 
-        return None
+        return FuyuImagePatchInputs(
+            type="image_patches",
+            image_patches_flat=image_patches,
+            patches_per_image=patches_per_image,
+            resolve_bindings={"fn": self.image_feature_size},
+        )
 
     def _process_image_input(
             self, image_input: FuyuImagePatchInputs) -> MultiModalEmbeddings:
-        image_patches_flat = image_input["flat_data"]
+        image_patches_flat = image_input["image_patches_flat"]
         patches_per_image = image_input["patches_per_image"]
 
         assert self.vision_embed_tokens is not None
         vision_embeddings_flat, _ = self.vision_embed_tokens(
             image_patches_flat)
 
-        return vision_embeddings_flat.split(patches_per_image, dim=0)
+        return vision_embeddings_flat.split(patches_per_image.tolist(), dim=0)
 
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
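Note: `patches_per_image` now arrives as a tensor (it is built with `torch.tensor(...)` in the processor), and `Tensor.split` expects an int or a list of ints for the chunk sizes, hence the added `.tolist()`. A small sketch with hypothetical sizes:

```python
import torch

vision_embeddings_flat = torch.randn(42, 4096)  # hypothetical: 42 patches total, hidden size 4096
patches_per_image = torch.tensor([30, 12])

per_image = vision_embeddings_flat.split(patches_per_image.tolist(), dim=0)
assert [e.shape[0] for e in per_image] == [30, 12]
```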