[Model] Gemma3n MM (#20495)

Signed-off-by: ShriKode <shrikode@gmail.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Roger Wang <hey@rogerw.me>
Co-authored-by: ShriKode <shrikode@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.me>
This commit is contained in:
Nicolò Lucchesi
2025-08-09 18:56:25 +02:00
committed by GitHub
parent 2d18256e47
commit 5a16fa614c
11 changed files with 864 additions and 55 deletions

View File

@ -349,7 +349,7 @@ th {
| `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Gemma3nForConditionalGeneration` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
| `Gemma3nForCausalLM` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
| `GlmForCausalLM` | GLM-4 | `zai-org/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4ForCausalLM` | GLM-4-0414 | `zai-org/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4MoeForCausalLM` | GLM-4.5 | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ |
@ -412,9 +412,6 @@ th {
!!! note
Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
!!! note
Only text inputs are currently supported for `Gemma3nForConditionalGeneration`. To use this model, please upgrade Hugging Face Transformers to version 4.53.0.
### Pooling Models
See [this page](./pooling_models.md) for more information on how to use pooling models.
@ -608,6 +605,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | |
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ |
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | ✅︎ |
@ -677,6 +675,15 @@ Some models are supported only via the [Transformers backend](#transformers). Th
This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
!!! note
`Gemma3nForConditionalGeneration` is only supported on V1 due to shared KV caching and it depends on `timm>=1.0.17` to make use of its
MobileNet-v5 vision backbone.
Performance is not yet fully optimized mainly due to:
- Both audio and vision MM encoders use `transformers.AutoModel` implementation.
- There's no PLE caching or out-of-memory swapping support, as described in [Google's blog](https://developers.googleblog.com/en/introducing-gemma-3n/). These features might be too model-specific for vLLM, and swapping in particular may be better suited for constrained setups.
!!! note
Only `InternVLChatModel` with Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B` etc) has video inputs support currently.

View File

@ -96,6 +96,25 @@ def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
)
# Gemma3N
def run_gemma3n(question: str, audio_count: int) -> ModelRequestData:
model_name = "google/gemma-3n-E2B-it"
engine_args = EngineArgs(
model=model_name,
max_model_len=2048,
max_num_batched_tokens=2048,
max_num_seqs=2,
limit_mm_per_prompt={"audio": audio_count},
enforce_eager=True,
)
prompt = f"<start_of_turn>user\n<audio_soft_token>{question}"
"<end_of_turn>\n<start_of_turn>model\n"
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# Granite Speech
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
# NOTE - the setting in this example are somehat different than what is
@ -331,6 +350,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
model_example_map = {
"voxtral": run_voxtral,
"gemma3n": run_gemma3n,
"granite_speech": run_granite_speech,
"minicpmo": run_minicpmo,
"phi4_mm": run_phi4mm,

View File

@ -211,7 +211,33 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
)
for question in questions
]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# Gemma3N
def run_gemma3n(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "google/gemma-3n-E2B-it"
engine_args = EngineArgs(
model=model_name,
max_model_len=2048,
max_num_seqs=2,
limit_mm_per_prompt={modality: 1},
enforce_eager=True,
)
prompts = [
(
"<start_of_turn>user\n"
f"<image_soft_token>{question}<end_of_turn>\n"
"<start_of_turn>model\n"
)
for question in questions
]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
@ -1395,6 +1421,7 @@ model_example_map = {
"florence2": run_florence2,
"fuyu": run_fuyu,
"gemma3": run_gemma3,
"gemma3n": run_gemma3n,
"glm4v": run_glm4v,
"glm4_1v": run_glm4_1v,
"h2ovl_chat": run_h2ovl,

View File

@ -21,7 +21,7 @@ ray[cgraph,default]>=2.48.0 # Ray Compiled Graph, required by pipeline paralleli
sentence-transformers # required for embedding tests
soundfile # required for audio tests
jiwer # required for audio tests
timm # required for internvl test
timm >=1.0.17 # required for internvl and gemma3n-mm test
torch==2.7.1
torchaudio==2.7.1
torchvision==0.22.1

View File

@ -1051,7 +1051,7 @@ tiktoken==0.7.0
# via
# lm-eval
# mistral-common
timm==1.0.15
timm==1.0.17
# via
# -r requirements/test.in
# open-clip-torch

View File

@ -271,6 +271,7 @@ def _test_processing_correctness_one(
"microsoft/Florence-2-base",
"adept/fuyu-8b",
"google/gemma-3-4b-it",
"google/gemma-3n-E2B-it",
"zai-org/glm-4v-9b",
"zai-org/GLM-4.1V-9B-Thinking",
"ibm-granite/granite-speech-3.3-2b",
@ -315,7 +316,7 @@ def _test_processing_correctness_one(
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
"openai/whisper-large-v3",
"omni-research/Tarsier-7b",
"omni-research/Tarsier2-Recap-7b"
"omni-research/Tarsier2-Recap-7b",
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@ -327,6 +328,8 @@ def test_processing_correctness(
num_batches: int,
simplify_rate: float,
):
if model_id == "google/gemma-3n-E2B-it":
pytest.skip("Skipping gemma-3n-E2B-it due to transformers #39911 bug.")
_test_processing_correctness(
model_id,
hit_rate=hit_rate,

View File

@ -186,7 +186,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501
"Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it",
min_transformers_version="4.53"),
"GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"),
"Glm4ForCausalLM": _HfExamplesInfo("zai-org/GLM-4-9B-0414"),
@ -391,6 +391,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501
min_transformers_version="4.53"),
"GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-2b"), # noqa: E501
"GLM4VForCausalLM": _HfExamplesInfo("zai-org/glm-4v-9b",
trust_remote_code=True,

61
tests/test_test.py Normal file
View File

@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm import LLM, envs
from vllm.sampling_params import SamplingParams
if not envs.VLLM_USE_V1:
pytest.skip(
"Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
allow_module_level=True,
)
@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
# TODO TPU will appear busy if we fan-out test params here
@pytest.mark.parametrize("n_prompts", [1])
def test_logprobs(model_name: str, n_prompts: int):
"""
Request top logprobs with different sampling settings and check
that results contains the requested number, ordered ascendingly.
"""
def check_num_logprobs(logprobs, expected_num: int):
for step in logprobs:
prev_logp = 1.0
# order by rank
sorted_step = dict(
sorted(step.items(), key=lambda item: item[1].rank))
if len(step) != expected_num:
print("watch out", sorted_step)
# check results are ordered by prob value
# assert len(step) == expected_num
for rankno, (tid, logp) in enumerate(sorted_step.items()):
assert logp.logprob <= prev_logp
prev_logp = logp.logprob
assert logp.rank == rankno + 1
llm = LLM(model_name,
enforce_eager=False,
max_num_seqs=1,
max_model_len=128,
max_num_batched_tokens=128)
prompts = [
"Write a short story about a robot that dreams for the first time."
] * n_prompts
greedy_sampling_params = SamplingParams(temperature=0.0, max_tokens=64,\
logprobs=4)
regular_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,\
logprobs=4)
topkp_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,\
logprobs=4, top_k=12, top_p=0.5)
for sp in [greedy_sampling_params, regular_sampling_params, \
topkp_sampling_params]:
output = llm.generate(prompts, sp)
for o in output:
check_num_logprobs(o.outputs[0].logprobs, 4)

View File

@ -331,14 +331,15 @@ class Gemma3nAttention(nn.Module):
config.num_kv_shared_layers)
self.is_kv_shared = layer_idx >= first_kv_shared_layer_idx
kv_sharing_target_layer_name = None
if self.is_kv_shared:
# Last full attention layer is 1 before sharing
# Last sliding attention layer is 2 before sharing
offset = 2 if self.sliding_window is not None else 1
kv_shared_layer_index = first_kv_shared_layer_idx - offset
kv_sharing_target_layer_name = f"model.language_model.layers.{kv_shared_layer_index}.self_attn.attn" # noqa: E501
else:
kv_sharing_target_layer_name = None
if kv_shared_layer_index >= 0:
# Only the greater layer is required to specify sharing.
kv_sharing_target_layer_name = f"language_model.model.layers.{kv_shared_layer_index}.self_attn.attn" # noqa: E501
self.rotary_emb = get_rope(
self.head_dim,
@ -396,6 +397,7 @@ class Gemma3nDecoderLayer(nn.Module):
prefix: str = "",
) -> None:
super().__init__()
assert isinstance(config, Gemma3nTextConfig)
self.altup_active_idx = config.altup_active_idx
assert config.altup_correct_scale
@ -537,7 +539,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config.text_config
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
@ -553,6 +555,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
config.hidden_size**0.5,
dtype=self.embed_tokens.weight.dtype,
)
# Additional per-layer embeddings (PLE)
self.embed_tokens_per_layer = VocabParallelEmbedding(
config.vocab_size_per_layer_input,
config.num_hidden_layers * config.hidden_size_per_layer_input,
@ -636,6 +639,8 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
per_layer_inputs: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
@ -644,13 +649,6 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
else:
hidden_states_0 = self.get_input_embeddings(input_ids)
# Per layer inputs.
if input_ids is None:
raise ValueError("Passing None for input ids is not supported.")
per_layer_inputs = self.get_per_layer_input_embeddings(input_ids)
per_layer_inputs = per_layer_inputs.reshape(
-1, self.config.num_hidden_layers,
self.config.hidden_size_per_layer_input)
per_layer_projection = self.per_layer_model_projection(hidden_states_0)
per_layer_projection = per_layer_projection.reshape(
*hidden_states_0.shape[:-1],
@ -659,8 +657,13 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
)
per_layer_projection = self.per_layer_projection_norm(
per_layer_projection)
if per_layer_inputs is not None:
# Profiling run does not compute per_layer_inputs
per_layer_inputs = per_layer_projection + per_layer_inputs
per_layer_inputs *= self.per_layer_input_scale
else:
per_layer_inputs = per_layer_projection
# Altup embed.
hidden_states = [hidden_states_0] * self.config.altup_num_inputs
@ -760,29 +763,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
return loaded_params
class Gemma3nModel(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.language_model = Gemma3nTextModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "language_model"))
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
return self.language_model(input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
**kwargs)
class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant):
class Gemma3nForCausalLM(nn.Module):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@ -802,25 +783,33 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant):
super().__init__()
self.config = config
self.cache_config = vllm_config.cache_config
self.model = Gemma3nModel(vllm_config=vllm_config,
self.model = Gemma3nTextModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.logits_processor = LogitsProcessor(
config.text_config.vocab_size,
soft_cap=config.text_config.final_logit_softcapping)
config.vocab_size, soft_cap=config.final_logit_softcapping)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.language_model.get_input_embeddings(input_ids)
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
*,
per_layer_inputs: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds, **kwargs)
hidden_states = self.model(
input_ids,
positions,
per_layer_inputs=per_layer_inputs,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**kwargs,
)
return hidden_states
def compute_logits(
@ -828,8 +817,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant):
hidden_states: torch.Tensor,
sampling_metadata: Optional[SamplingMetadata],
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.model.language_model.embed_tokens,
hidden_states, sampling_metadata)
logits = self.logits_processor(self.model.embed_tokens, hidden_states,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,

View File

@ -0,0 +1,700 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Optional, TypedDict, Union, cast
import torch
from torch import nn
from transformers import AutoModel, BatchFeature
from transformers.models.gemma3n import (Gemma3nAudioConfig,
Gemma3nAudioFeatureExtractor,
Gemma3nConfig, Gemma3nProcessor,
Gemma3nTextConfig,
Gemma3nVisionConfig)
from transformers.models.siglip import SiglipImageProcessorFast
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.models.gemma3n import Gemma3nForCausalLM
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import (ImageProcessorItems, MultiModalDataItems,
MultiModalDataParser)
# yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BoundPromptUpdate,
PlaceholderFeaturesInfo,
PromptReplacement, PromptTargetMatch,
PromptUpdate, PromptUpdateDetails,
find_mm_placeholders,
replace_token_matches)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
logger = init_logger(__name__)
# This should be based on model config but we hardcode them for now.
TOKENS_PER_IMAGE = 256
TOKENS_PER_AUDIO = 188
class Gemma3nImagePixelInputs(TypedDict):
pixel_values: torch.Tensor
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
class Gemma3nAudioInputs(TypedDict):
input_features: torch.Tensor
"""Shape: `(batch_size * num_audio, seq_length, num_features)`"""
input_features_mask: torch.Tensor
"""Shape: `(batch_size * num_audio, seq_length)`"""
Gemma3nImageInputs = Gemma3nImagePixelInputs
class Gemma3nProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(Gemma3nConfig)
def get_hf_processor(self, **kwargs: object):
return self.ctx.get_hf_processor(Gemma3nProcessor, **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "audio": None}
def get_max_tokens_per_item(
self, seq_len: int,
mm_counts: Mapping[str, int]) -> Optional[Mapping[str, int]]:
return {"image": TOKENS_PER_IMAGE, "audio": TOKENS_PER_AUDIO}
def get_image_repl(
self,
*,
image_width: int,
image_height: int,
processor: Optional[Gemma3nProcessor],
) -> str:
"""
Get the replacement text for image tokens.
For Gemma3n, this should return the full_image_sequence which includes
BOI token, repeated image tokens, and EOI token.
"""
if processor is None:
processor = self.get_hf_processor()
return PromptUpdateDetails.select_token_id(
processor.full_image_sequence, processor.image_token_id)
def get_audio_repl(
self,
*,
processor: Optional[Gemma3nProcessor],
) -> str:
"""
Get the replacement text for audio tokens.
For Gemma3n, this should return the full_audio_sequence which includes
BOA token, repeated audio tokens, and EOA token.
"""
if processor is None:
processor = self.get_hf_processor()
# Return the full audio sequence as defined by the processor
return PromptUpdateDetails.select_token_id(
processor.full_audio_sequence, processor.audio_token_id)
class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
num_audios = mm_counts.get("audio", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
audio_token = processor.audio_token
return image_token * num_images + audio_token * num_audios
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
num_audios = mm_counts.get("audio", 0)
processor = self.info.get_hf_processor()
audio_feature_extractor: Gemma3nAudioFeatureExtractor = processor.feature_extractor # noqa: E501
audio_len = audio_feature_extractor.fft_length
image_processor: SiglipImageProcessorFast = processor.image_processor
img_width = image_processor.size.get("width", 224)
img_height = image_processor.size.get("height", 224)
return {
"image":
self._get_dummy_images(width=img_width,
height=img_height,
num_images=num_images),
"audio":
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
}
class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_hf_processor().feature_extractor
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
# HF Transformers audio processor no longer accepts `audios` key.
# We pop `audios` and replace it with `audio` key to surpress
# the warning.
if 'audios' in mm_data:
mm_data['audio'] = mm_data.pop('audios')
processed_outputs = super()._call_hf_processor(
prompt,
mm_data,
mm_kwargs,
tok_kwargs,
)
if 'input_features' in processed_outputs:
# Avoid padding since we need the output of each item to be
# independent of other items for the cache to work correctly
unpadded_features = [
f[mask] for f, mask in zip(
processed_outputs["input_features"],
processed_outputs["input_features_mask"],
)
]
processed_outputs["input_features"] = unpadded_features
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"),
input_features=MultiModalFieldConfig.batched("audio"),
input_features_mask=MultiModalFieldConfig.batched("audio"))
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
prompt_updates = []
# Handle image tokens
if "image" in mm_items:
image_token = hf_processor.image_token
def get_replacement_image(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
return self.info.get_image_repl(
image_width=image_size.width,
image_height=image_size.height,
processor=hf_processor,
)
prompt_updates.append(
PromptReplacement(
modality="image",
target=image_token,
replacement=get_replacement_image,
))
# Handle audio tokens
if "audio" in mm_items:
audio_token = hf_processor.audio_token
def get_replacement_audio(item_idx: int):
return self.info.get_audio_repl(processor=hf_processor, )
prompt_updates.append(
PromptReplacement(
modality="audio",
target=audio_token,
replacement=get_replacement_audio,
))
return prompt_updates
def _apply_token_matches(
self,
prompt: list[int],
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
mm_item_counts: Mapping[str, int],
) -> list[int]:
token_ids = super()._apply_token_matches(
prompt,
mm_matches,
mm_item_counts,
)
# "\n\n\n" and "\n\n\n\n" are single tokens
# Since our replacement can insert "\n\n" next to "\n"
# tokens, we have to combine them to be consistent with
# the output of the tokenizer
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
newline_1 = vocab["\n"]
newline_2 = vocab["\n\n"]
newline_3 = vocab["\n\n\n"]
newline_4 = vocab["\n\n\n\n"]
token_ids = replace_token_matches(
token_ids,
[newline_1, newline_2],
[newline_3],
)
token_ids = replace_token_matches(
token_ids,
[newline_2, newline_1],
[newline_3],
)
token_ids = replace_token_matches(
token_ids,
[newline_2, newline_2],
[newline_4],
)
return token_ids
def _find_mm_placeholders(
self,
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
new_token_ids: list[int],
mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
# We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
newline_1 = vocab["\n"]
newline_2 = vocab["\n\n"]
newline_3 = vocab["\n\n\n"]
newline_4 = vocab["\n\n\n\n"]
def get_repl_toks(tok: int) -> list[int]:
if tok == newline_3:
return [newline_1, newline_2]
if tok == newline_4:
return [newline_2, newline_2]
return [tok]
repl_token_ids = list[int]()
repl_orig_idxs = list[int]()
for orig_idx, orig_tok in enumerate(new_token_ids):
repl_toks = get_repl_toks(orig_tok)
repl_token_ids.extend(repl_toks)
repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids,
mm_item_counts)
return {
modality: [
PlaceholderFeaturesInfo(
modality=p.modality,
item_idx=p.item_idx,
start_idx=repl_orig_idxs[p.start_idx],
tokens=p.tokens,
is_embed=p.is_embed,
) for p in placeholders
]
for modality, placeholders in repls.items()
}
class Gemma3nMultimodalEmbedder(nn.Module):
"""Embeds token ids or soft tokens for multimodal content into language
model space."""
def __init__(
self,
multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
text_config: Gemma3nTextConfig,
):
super().__init__()
self.multimodal_hidden_size = multimodal_config.hidden_size
self.eps = multimodal_config.rms_norm_eps
self.vocab_offset = multimodal_config.vocab_offset
self.vocab_size = multimodal_config.vocab_size
self.text_hidden_size = text_config.hidden_size
self.embedding = VocabParallelEmbedding(
self.vocab_size,
self.multimodal_hidden_size,
)
self.hard_embedding_norm = RMSNorm(
self.multimodal_hidden_size,
eps=self.eps,
)
self.soft_embedding_norm = RMSNorm(
self.multimodal_hidden_size,
eps=self.eps,
)
self.embedding_projection = RowParallelLinear(
self.multimodal_hidden_size,
self.text_hidden_size,
bias=False,
)
self.embedding_post_projection_norm = RMSNorm(
self.text_hidden_size,
eps=self.eps,
has_weight=False,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Embeds token ids or soft tokens for multimodal content into language model space.
Args:
input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range
`[vocab_offset, vocab_offset + vocab_size)`.
inputs_embeds: A torch.Tensor containing the soft tokens to embed.
Returns:
A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
""" # noqa: E501
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is not None:
emb_norm = self.soft_embedding_norm(inputs_embeds)
else:
hard_emb = self.embedding(input_ids - self.vocab_offset)
emb_norm = self.hard_embedding_norm(hard_emb)
emb_norm_proj, _ = self.embedding_projection(emb_norm)
return self.embedding_post_projection_norm(emb_norm_proj)
@MULTIMODAL_REGISTRY.register_processor(Gemma3nMultiModalProcessor,
info=Gemma3nProcessingInfo,
dummy_inputs=Gemma3nDummyInputsBuilder)
class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# mapping for new names in checkpoint saved after transformers v4.52
"model.embed_audio.": "embed_audio.",
"model.embed_vision.": "embed_vision.",
"model.language_model.": "language_model.model.",
"model.vision_tower.": "vision_tower.",
"model.audio_tower.": "audio_tower.",
"model.multi_modal_projector.": "multi_modal_projector.",
"lm_head.": "language_model.lm_head.",
"model": "language_model.model",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.quant_config = quant_config
self.multimodal_config = multimodal_config
self.vocab_size = config.text_config.vocab_size
self.sliding_window = getattr(config.text_config,
"interleaved_sliding_window", None)
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.audio_tower = AutoModel.from_config(config=config.audio_config)
self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config,
config.text_config)
self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config,
config.text_config)
self.language_model: nn.Module = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Gemma3nForCausalLM"],
)
self.language_model = cast(Gemma3nForCausalLM, self.language_model)
# NOTE (NickLucche) In order to be compatible with cudagraph, the
# buffer needs to be consistent, so we pre-allocate here.
self.per_layer_embeddings = torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
self.config.text_config.num_hidden_layers,
self.config.text_config.hidden_size_per_layer_input,
device=self.language_model.model.embed_tokens.weight.device,
dtype=self.language_model.model.embed_tokens.weight.dtype)
@property
def dtype(self):
return next(self.parameters()).dtype
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
# TODO check if there are any
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Gemma3nImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
# TODO is this the case?
assert image_embeds is None, "Gemma3n does not support image_embeds."
if pixel_values is None:
return None
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
pixel_values = flatten_bn(pixel_values, concat=True)
pixel_values = pixel_values.contiguous()
return Gemma3nImagePixelInputs(
pixel_values=self._validate_pixel_values(pixel_values), )
def _parse_and_validate_audio_input(
self, **kwargs: object) -> Optional[Gemma3nAudioInputs]:
input_features = kwargs.pop("input_features", None)
if input_features is None:
return None
input_features_mask = kwargs.pop("input_features_mask", None)
if input_features_mask is None:
return None
return Gemma3nAudioInputs(
input_features=input_features,
input_features_mask=input_features_mask,
)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
mm_input_by_modality = {}
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
if input_key in ("pixel_values", "image_embeds"
) and "image" not in mm_input_by_modality:
mm_input_by_modality[
"image"] = self._parse_and_validate_image_input(**kwargs)
if input_key == "input_features" \
and "audio" not in mm_input_by_modality:
mm_input_by_modality[
"audio"] = self._parse_and_validate_audio_input(**kwargs)
return mm_input_by_modality
def _process_image_input(
self,
image_input: Gemma3nImageInputs,
) -> list[torch.Tensor]:
assert self.vision_tower is not None
pixel_values = image_input["pixel_values"]
vision_outputs = self.vision_tower(pixel_values=pixel_values,
do_pooling=False,
return_dict=True).last_hidden_state
# TODO try to avoid copy here
# (batch, channels, height, width) to (batch, height * width, channels)
vision_outputs = vision_outputs.reshape(
vision_outputs.shape[0],
self.config.vision_config.hidden_size,
self.config.vision_soft_tokens_per_image,
).permute(0, 2, 1).contiguous()
# Normalize and embed the soft tokens into language model space.
vision_outputs *= self.config.vision_config.hidden_size**0.5
# Return a list of embeddings instead of a batched tensor
return self.embed_vision(inputs_embeds=vision_outputs).unbind(0)
def _process_audio_input(
self,
audio_input: Gemma3nAudioInputs,
) -> list[torch.Tensor]:
assert self.audio_tower is not None
input_features = audio_input["input_features"].squeeze(1)
input_features_mask = audio_input["input_features_mask"].squeeze(1)
audio_outputs, audio_mask = self.audio_tower(input_features,
~input_features_mask)
audio_features = self.embed_audio(inputs_embeds=audio_outputs)
# ruff: noqa
# The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the
# text to account for this. However, the audio preprocessing and encoder do not gurarantee they will
# produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens
# depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
# the audio feature out to 188 soft tokens with the emebedding of the last token in the embed_audio vocab.
# TODO precompute and cache padding
audio_padding_toks = torch.tensor([[self.vocab_size - 1]],
dtype=torch.long,
device=audio_features.device)
audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
audio_features = torch.where(audio_mask.unsqueeze(-1),
audio_padding_embs, audio_features)
audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len # noqa: E501
extra_padding_features = audio_padding_embs.expand(
audio_batch_size, extra_padding_tokens, audio_embed_dim)
audio_features = torch.cat((audio_features, extra_padding_features),
dim=1)
# Return a list of embeddings instead of a batched tensor
return audio_features.unbind(0)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
**kwargs)
if mm_input_by_modality is None:
return []
multimodal_embeddings: list[torch.Tensor] = []
# NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities.
for modality in mm_input_by_modality:
multimodal_input = mm_input_by_modality[modality]
if modality == "image":
vision_embeddings = self._process_image_input(multimodal_input)
multimodal_embeddings.extend(vision_embeddings)
if modality == "audio":
audio_embeddings = self._process_audio_input(multimodal_input)
multimodal_embeddings.extend(audio_embeddings)
return multimodal_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
# NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
# them here, as the model forward has only access to the input_embeds.
if input_ids is not None:
per_layer_inputs = self.language_model.model.get_per_layer_input_embeddings(
input_ids)
per_layer_inputs = per_layer_inputs.reshape(
-1, self.config.text_config.num_hidden_layers,
self.config.text_config.hidden_size_per_layer_input)
self.per_layer_embeddings[:per_layer_inputs.shape[0]].copy_(
per_layer_inputs)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
# NOTE: this order of processing mm items is important
[self.config.image_token_id, self.config.audio_token_id])
return inputs_embeds
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object) -> IntermediateTensors:
if intermediate_tensors is not None:
inputs_embeds = None
# NOTE (NickLucche) During profiling, `get_input_embeddings` is not
# called, hence we don't have input_ids to compute PLEs. We simply
# select a chunk of pre-allocated PLEs. During normal execution,
# `get_input_embeddings` is called before forward, hence this slice
# will contain PLEs computed from the actual input_ids.
per_layer_inputs = self.per_layer_embeddings[:inputs_embeds.shape[0]]
hidden_states = self.language_model.model(
input_ids,
positions,
per_layer_inputs=per_layer_inputs,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**kwargs)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="multi_modal_projector",
tower_model="vision_tower")
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality == "image":
return "<image_soft_token>"
elif modality == "audio":
return "<audio_soft_token>"
else:
raise ValueError(f"Unsupported modality: {modality}")

View File

@ -69,8 +69,7 @@ _TEXT_GENERATION_MODELS = {
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
#TODO(ywang96): Support multimodal gemma3n
"Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"), # noqa: E501
"Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
"Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
@ -205,6 +204,7 @@ _MULTIMODAL_MODELS = {
"DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501
"Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"), # noqa: E501
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501
"Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501