[Model] Gemma3n MM (#20495)
Signed-off-by: ShriKode <shrikode@gmail.com> Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Roger Wang <hey@rogerw.me> Co-authored-by: ShriKode <shrikode@gmail.com> Co-authored-by: Roger Wang <hey@rogerw.me>
This commit is contained in:
@ -349,7 +349,7 @@ th {
|
||||
| `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Gemma3nForConditionalGeneration` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
|
||||
| `Gemma3nForCausalLM` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
|
||||
| `GlmForCausalLM` | GLM-4 | `zai-org/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Glm4ForCausalLM` | GLM-4-0414 | `zai-org/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Glm4MoeForCausalLM` | GLM-4.5 | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
@ -412,9 +412,6 @@ th {
|
||||
!!! note
|
||||
Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
|
||||
|
||||
!!! note
|
||||
Only text inputs are currently supported for `Gemma3nForConditionalGeneration`. To use this model, please upgrade Hugging Face Transformers to version 4.53.0.
|
||||
|
||||
### Pooling Models
|
||||
|
||||
See [this page](./pooling_models.md) for more information on how to use pooling models.
|
||||
@ -608,6 +605,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
||||
| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | |
|
||||
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
|
||||
| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
|
||||
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
@ -677,6 +675,15 @@ Some models are supported only via the [Transformers backend](#transformers). Th
|
||||
|
||||
This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
|
||||
|
||||
!!! note
|
||||
`Gemma3nForConditionalGeneration` is only supported on V1 due to shared KV caching and it depends on `timm>=1.0.17` to make use of its
|
||||
MobileNet-v5 vision backbone.
|
||||
|
||||
Performance is not yet fully optimized mainly due to:
|
||||
|
||||
- Both audio and vision MM encoders use `transformers.AutoModel` implementation.
|
||||
- There's no PLE caching or out-of-memory swapping support, as described in [Google's blog](https://developers.googleblog.com/en/introducing-gemma-3n/). These features might be too model-specific for vLLM, and swapping in particular may be better suited for constrained setups.
|
||||
|
||||
!!! note
|
||||
Only `InternVLChatModel` with Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B` etc) has video inputs support currently.
|
||||
|
||||
|
@ -96,6 +96,25 @@ def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
# Gemma3N
|
||||
def run_gemma3n(question: str, audio_count: int) -> ModelRequestData:
|
||||
model_name = "google/gemma-3n-E2B-it"
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=2048,
|
||||
max_num_batched_tokens=2048,
|
||||
max_num_seqs=2,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
enforce_eager=True,
|
||||
)
|
||||
prompt = f"<start_of_turn>user\n<audio_soft_token>{question}"
|
||||
"<end_of_turn>\n<start_of_turn>model\n"
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
# Granite Speech
|
||||
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
|
||||
# NOTE - the setting in this example are somehat different than what is
|
||||
@ -331,6 +350,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
|
||||
|
||||
model_example_map = {
|
||||
"voxtral": run_voxtral,
|
||||
"gemma3n": run_gemma3n,
|
||||
"granite_speech": run_granite_speech,
|
||||
"minicpmo": run_minicpmo,
|
||||
"phi4_mm": run_phi4mm,
|
||||
|
@ -211,7 +211,33 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
|
||||
)
|
||||
for question in questions
|
||||
]
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# Gemma3N
|
||||
def run_gemma3n(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
model_name = "google/gemma-3n-E2B-it"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=2048,
|
||||
max_num_seqs=2,
|
||||
limit_mm_per_prompt={modality: 1},
|
||||
enforce_eager=True,
|
||||
)
|
||||
|
||||
prompts = [
|
||||
(
|
||||
"<start_of_turn>user\n"
|
||||
f"<image_soft_token>{question}<end_of_turn>\n"
|
||||
"<start_of_turn>model\n"
|
||||
)
|
||||
for question in questions
|
||||
]
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
@ -1395,6 +1421,7 @@ model_example_map = {
|
||||
"florence2": run_florence2,
|
||||
"fuyu": run_fuyu,
|
||||
"gemma3": run_gemma3,
|
||||
"gemma3n": run_gemma3n,
|
||||
"glm4v": run_glm4v,
|
||||
"glm4_1v": run_glm4_1v,
|
||||
"h2ovl_chat": run_h2ovl,
|
||||
|
@ -21,7 +21,7 @@ ray[cgraph,default]>=2.48.0 # Ray Compiled Graph, required by pipeline paralleli
|
||||
sentence-transformers # required for embedding tests
|
||||
soundfile # required for audio tests
|
||||
jiwer # required for audio tests
|
||||
timm # required for internvl test
|
||||
timm >=1.0.17 # required for internvl and gemma3n-mm test
|
||||
torch==2.7.1
|
||||
torchaudio==2.7.1
|
||||
torchvision==0.22.1
|
||||
|
@ -1051,7 +1051,7 @@ tiktoken==0.7.0
|
||||
# via
|
||||
# lm-eval
|
||||
# mistral-common
|
||||
timm==1.0.15
|
||||
timm==1.0.17
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# open-clip-torch
|
||||
|
@ -271,6 +271,7 @@ def _test_processing_correctness_one(
|
||||
"microsoft/Florence-2-base",
|
||||
"adept/fuyu-8b",
|
||||
"google/gemma-3-4b-it",
|
||||
"google/gemma-3n-E2B-it",
|
||||
"zai-org/glm-4v-9b",
|
||||
"zai-org/GLM-4.1V-9B-Thinking",
|
||||
"ibm-granite/granite-speech-3.3-2b",
|
||||
@ -315,7 +316,7 @@ def _test_processing_correctness_one(
|
||||
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
|
||||
"openai/whisper-large-v3",
|
||||
"omni-research/Tarsier-7b",
|
||||
"omni-research/Tarsier2-Recap-7b"
|
||||
"omni-research/Tarsier2-Recap-7b",
|
||||
])
|
||||
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
|
||||
@pytest.mark.parametrize("num_batches", [32])
|
||||
@ -327,6 +328,8 @@ def test_processing_correctness(
|
||||
num_batches: int,
|
||||
simplify_rate: float,
|
||||
):
|
||||
if model_id == "google/gemma-3n-E2B-it":
|
||||
pytest.skip("Skipping gemma-3n-E2B-it due to transformers #39911 bug.")
|
||||
_test_processing_correctness(
|
||||
model_id,
|
||||
hit_rate=hit_rate,
|
||||
|
@ -186,7 +186,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
|
||||
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
|
||||
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
|
||||
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501
|
||||
"Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it",
|
||||
min_transformers_version="4.53"),
|
||||
"GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"),
|
||||
"Glm4ForCausalLM": _HfExamplesInfo("zai-org/GLM-4-9B-0414"),
|
||||
@ -391,6 +391,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
"Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
|
||||
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
|
||||
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
|
||||
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501
|
||||
min_transformers_version="4.53"),
|
||||
"GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-2b"), # noqa: E501
|
||||
"GLM4VForCausalLM": _HfExamplesInfo("zai-org/glm-4v-9b",
|
||||
trust_remote_code=True,
|
||||
|
61
tests/test_test.py
Normal file
61
tests/test_test.py
Normal file
@ -0,0 +1,61 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM, envs
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
if not envs.VLLM_USE_V1:
|
||||
pytest.skip(
|
||||
"Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
|
||||
# TODO TPU will appear busy if we fan-out test params here
|
||||
@pytest.mark.parametrize("n_prompts", [1])
|
||||
def test_logprobs(model_name: str, n_prompts: int):
|
||||
"""
|
||||
Request top logprobs with different sampling settings and check
|
||||
that results contains the requested number, ordered ascendingly.
|
||||
"""
|
||||
|
||||
def check_num_logprobs(logprobs, expected_num: int):
|
||||
for step in logprobs:
|
||||
prev_logp = 1.0
|
||||
# order by rank
|
||||
sorted_step = dict(
|
||||
sorted(step.items(), key=lambda item: item[1].rank))
|
||||
|
||||
if len(step) != expected_num:
|
||||
print("watch out", sorted_step)
|
||||
|
||||
# check results are ordered by prob value
|
||||
# assert len(step) == expected_num
|
||||
for rankno, (tid, logp) in enumerate(sorted_step.items()):
|
||||
assert logp.logprob <= prev_logp
|
||||
prev_logp = logp.logprob
|
||||
assert logp.rank == rankno + 1
|
||||
|
||||
llm = LLM(model_name,
|
||||
enforce_eager=False,
|
||||
max_num_seqs=1,
|
||||
max_model_len=128,
|
||||
max_num_batched_tokens=128)
|
||||
prompts = [
|
||||
"Write a short story about a robot that dreams for the first time."
|
||||
] * n_prompts
|
||||
greedy_sampling_params = SamplingParams(temperature=0.0, max_tokens=64,\
|
||||
logprobs=4)
|
||||
regular_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,\
|
||||
logprobs=4)
|
||||
topkp_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,\
|
||||
logprobs=4, top_k=12, top_p=0.5)
|
||||
|
||||
for sp in [greedy_sampling_params, regular_sampling_params, \
|
||||
topkp_sampling_params]:
|
||||
output = llm.generate(prompts, sp)
|
||||
for o in output:
|
||||
check_num_logprobs(o.outputs[0].logprobs, 4)
|
@ -331,14 +331,15 @@ class Gemma3nAttention(nn.Module):
|
||||
config.num_kv_shared_layers)
|
||||
self.is_kv_shared = layer_idx >= first_kv_shared_layer_idx
|
||||
|
||||
kv_sharing_target_layer_name = None
|
||||
if self.is_kv_shared:
|
||||
# Last full attention layer is 1 before sharing
|
||||
# Last sliding attention layer is 2 before sharing
|
||||
offset = 2 if self.sliding_window is not None else 1
|
||||
kv_shared_layer_index = first_kv_shared_layer_idx - offset
|
||||
kv_sharing_target_layer_name = f"model.language_model.layers.{kv_shared_layer_index}.self_attn.attn" # noqa: E501
|
||||
else:
|
||||
kv_sharing_target_layer_name = None
|
||||
if kv_shared_layer_index >= 0:
|
||||
# Only the greater layer is required to specify sharing.
|
||||
kv_sharing_target_layer_name = f"language_model.model.layers.{kv_shared_layer_index}.self_attn.attn" # noqa: E501
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
@ -396,6 +397,7 @@ class Gemma3nDecoderLayer(nn.Module):
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert isinstance(config, Gemma3nTextConfig)
|
||||
self.altup_active_idx = config.altup_active_idx
|
||||
assert config.altup_correct_scale
|
||||
|
||||
@ -537,7 +539,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config.text_config
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
@ -553,6 +555,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||
config.hidden_size**0.5,
|
||||
dtype=self.embed_tokens.weight.dtype,
|
||||
)
|
||||
# Additional per-layer embeddings (PLE)
|
||||
self.embed_tokens_per_layer = VocabParallelEmbedding(
|
||||
config.vocab_size_per_layer_input,
|
||||
config.num_hidden_layers * config.hidden_size_per_layer_input,
|
||||
@ -636,6 +639,8 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
per_layer_inputs: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
@ -644,13 +649,6 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||
else:
|
||||
hidden_states_0 = self.get_input_embeddings(input_ids)
|
||||
|
||||
# Per layer inputs.
|
||||
if input_ids is None:
|
||||
raise ValueError("Passing None for input ids is not supported.")
|
||||
per_layer_inputs = self.get_per_layer_input_embeddings(input_ids)
|
||||
per_layer_inputs = per_layer_inputs.reshape(
|
||||
-1, self.config.num_hidden_layers,
|
||||
self.config.hidden_size_per_layer_input)
|
||||
per_layer_projection = self.per_layer_model_projection(hidden_states_0)
|
||||
per_layer_projection = per_layer_projection.reshape(
|
||||
*hidden_states_0.shape[:-1],
|
||||
@ -659,8 +657,13 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||
)
|
||||
per_layer_projection = self.per_layer_projection_norm(
|
||||
per_layer_projection)
|
||||
per_layer_inputs = per_layer_projection + per_layer_inputs
|
||||
per_layer_inputs *= self.per_layer_input_scale
|
||||
|
||||
if per_layer_inputs is not None:
|
||||
# Profiling run does not compute per_layer_inputs
|
||||
per_layer_inputs = per_layer_projection + per_layer_inputs
|
||||
per_layer_inputs *= self.per_layer_input_scale
|
||||
else:
|
||||
per_layer_inputs = per_layer_projection
|
||||
|
||||
# Altup embed.
|
||||
hidden_states = [hidden_states_0] * self.config.altup_num_inputs
|
||||
@ -760,29 +763,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Gemma3nModel(nn.Module):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.language_model = Gemma3nTextModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "language_model"))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
return self.language_model(input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**kwargs)
|
||||
|
||||
|
||||
class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant):
|
||||
class Gemma3nForCausalLM(nn.Module):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
@ -802,25 +783,33 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.model = Gemma3nModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.model = Gemma3nTextModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.logits_processor = LogitsProcessor(
|
||||
config.text_config.vocab_size,
|
||||
soft_cap=config.text_config.final_logit_softcapping)
|
||||
config.vocab_size, soft_cap=config.final_logit_softcapping)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.language_model.get_input_embeddings(input_ids)
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
*,
|
||||
per_layer_inputs: Optional[torch.Tensor] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds, **kwargs)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
positions,
|
||||
per_layer_inputs=per_layer_inputs,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**kwargs,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
@ -828,8 +817,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant):
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: Optional[SamplingMetadata],
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.model.language_model.embed_tokens,
|
||||
hidden_states, sampling_metadata)
|
||||
logits = self.logits_processor(self.model.embed_tokens, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
|
700
vllm/model_executor/models/gemma3n_mm.py
Normal file
700
vllm/model_executor/models/gemma3n_mm.py
Normal file
@ -0,0 +1,700 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Any, Optional, TypedDict, Union, cast
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import AutoModel, BatchFeature
|
||||
from transformers.models.gemma3n import (Gemma3nAudioConfig,
|
||||
Gemma3nAudioFeatureExtractor,
|
||||
Gemma3nConfig, Gemma3nProcessor,
|
||||
Gemma3nTextConfig,
|
||||
Gemma3nVisionConfig)
|
||||
from transformers.models.siglip import SiglipImageProcessorFast
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import RowParallelLinear
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.gemma3n import Gemma3nForCausalLM
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalKwargs)
|
||||
from vllm.multimodal.parse import (ImageProcessorItems, MultiModalDataItems,
|
||||
MultiModalDataParser)
|
||||
# yapf: disable
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, BoundPromptUpdate,
|
||||
PlaceholderFeaturesInfo,
|
||||
PromptReplacement, PromptTargetMatch,
|
||||
PromptUpdate, PromptUpdateDetails,
|
||||
find_mm_placeholders,
|
||||
replace_token_matches)
|
||||
# yapf: enable
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# This should be based on model config but we hardcode them for now.
|
||||
TOKENS_PER_IMAGE = 256
|
||||
TOKENS_PER_AUDIO = 188
|
||||
|
||||
|
||||
class Gemma3nImagePixelInputs(TypedDict):
|
||||
pixel_values: torch.Tensor
|
||||
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
|
||||
|
||||
|
||||
class Gemma3nAudioInputs(TypedDict):
|
||||
input_features: torch.Tensor
|
||||
"""Shape: `(batch_size * num_audio, seq_length, num_features)`"""
|
||||
input_features_mask: torch.Tensor
|
||||
"""Shape: `(batch_size * num_audio, seq_length)`"""
|
||||
|
||||
|
||||
Gemma3nImageInputs = Gemma3nImagePixelInputs
|
||||
|
||||
|
||||
class Gemma3nProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def get_hf_config(self):
|
||||
return self.ctx.get_hf_config(Gemma3nConfig)
|
||||
|
||||
def get_hf_processor(self, **kwargs: object):
|
||||
return self.ctx.get_hf_processor(Gemma3nProcessor, **kwargs)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None, "audio": None}
|
||||
|
||||
def get_max_tokens_per_item(
|
||||
self, seq_len: int,
|
||||
mm_counts: Mapping[str, int]) -> Optional[Mapping[str, int]]:
|
||||
|
||||
return {"image": TOKENS_PER_IMAGE, "audio": TOKENS_PER_AUDIO}
|
||||
|
||||
def get_image_repl(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
processor: Optional[Gemma3nProcessor],
|
||||
) -> str:
|
||||
"""
|
||||
Get the replacement text for image tokens.
|
||||
|
||||
For Gemma3n, this should return the full_image_sequence which includes
|
||||
BOI token, repeated image tokens, and EOI token.
|
||||
"""
|
||||
if processor is None:
|
||||
processor = self.get_hf_processor()
|
||||
|
||||
return PromptUpdateDetails.select_token_id(
|
||||
processor.full_image_sequence, processor.image_token_id)
|
||||
|
||||
def get_audio_repl(
|
||||
self,
|
||||
*,
|
||||
processor: Optional[Gemma3nProcessor],
|
||||
) -> str:
|
||||
"""
|
||||
Get the replacement text for audio tokens.
|
||||
|
||||
For Gemma3n, this should return the full_audio_sequence which includes
|
||||
BOA token, repeated audio tokens, and EOA token.
|
||||
"""
|
||||
if processor is None:
|
||||
processor = self.get_hf_processor()
|
||||
|
||||
# Return the full audio sequence as defined by the processor
|
||||
return PromptUpdateDetails.select_token_id(
|
||||
processor.full_audio_sequence, processor.audio_token_id)
|
||||
|
||||
|
||||
class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]):
|
||||
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
|
||||
processor = self.info.get_hf_processor()
|
||||
image_token = processor.image_token
|
||||
audio_token = processor.audio_token
|
||||
|
||||
return image_token * num_images + audio_token * num_audios
|
||||
|
||||
def get_dummy_mm_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
processor = self.info.get_hf_processor()
|
||||
audio_feature_extractor: Gemma3nAudioFeatureExtractor = processor.feature_extractor # noqa: E501
|
||||
audio_len = audio_feature_extractor.fft_length
|
||||
image_processor: SiglipImageProcessorFast = processor.image_processor
|
||||
img_width = image_processor.size.get("width", 224)
|
||||
img_height = image_processor.size.get("height", 224)
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=img_width,
|
||||
height=img_height,
|
||||
num_images=num_images),
|
||||
"audio":
|
||||
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
|
||||
}
|
||||
|
||||
|
||||
class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
|
||||
):
|
||||
|
||||
def _get_data_parser(self) -> MultiModalDataParser:
|
||||
feature_extractor = self.info.get_hf_processor().feature_extractor
|
||||
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
|
||||
# HF Transformers audio processor no longer accepts `audios` key.
|
||||
# We pop `audios` and replace it with `audio` key to surpress
|
||||
# the warning.
|
||||
if 'audios' in mm_data:
|
||||
mm_data['audio'] = mm_data.pop('audios')
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt,
|
||||
mm_data,
|
||||
mm_kwargs,
|
||||
tok_kwargs,
|
||||
)
|
||||
if 'input_features' in processed_outputs:
|
||||
# Avoid padding since we need the output of each item to be
|
||||
# independent of other items for the cache to work correctly
|
||||
unpadded_features = [
|
||||
f[mask] for f, mask in zip(
|
||||
processed_outputs["input_features"],
|
||||
processed_outputs["input_features_mask"],
|
||||
)
|
||||
]
|
||||
processed_outputs["input_features"] = unpadded_features
|
||||
return processed_outputs
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
|
||||
return dict(pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
input_features=MultiModalFieldConfig.batched("audio"),
|
||||
input_features_mask=MultiModalFieldConfig.batched("audio"))
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, Any],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> Sequence[PromptUpdate]:
|
||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
|
||||
prompt_updates = []
|
||||
|
||||
# Handle image tokens
|
||||
if "image" in mm_items:
|
||||
image_token = hf_processor.image_token
|
||||
|
||||
def get_replacement_image(item_idx: int):
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
image_size = images.get_image_size(item_idx)
|
||||
return self.info.get_image_repl(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
processor=hf_processor,
|
||||
)
|
||||
|
||||
prompt_updates.append(
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=image_token,
|
||||
replacement=get_replacement_image,
|
||||
))
|
||||
|
||||
# Handle audio tokens
|
||||
if "audio" in mm_items:
|
||||
audio_token = hf_processor.audio_token
|
||||
|
||||
def get_replacement_audio(item_idx: int):
|
||||
return self.info.get_audio_repl(processor=hf_processor, )
|
||||
|
||||
prompt_updates.append(
|
||||
PromptReplacement(
|
||||
modality="audio",
|
||||
target=audio_token,
|
||||
replacement=get_replacement_audio,
|
||||
))
|
||||
|
||||
return prompt_updates
|
||||
|
||||
def _apply_token_matches(
|
||||
self,
|
||||
prompt: list[int],
|
||||
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
|
||||
mm_item_counts: Mapping[str, int],
|
||||
) -> list[int]:
|
||||
token_ids = super()._apply_token_matches(
|
||||
prompt,
|
||||
mm_matches,
|
||||
mm_item_counts,
|
||||
)
|
||||
|
||||
# "\n\n\n" and "\n\n\n\n" are single tokens
|
||||
# Since our replacement can insert "\n\n" next to "\n"
|
||||
# tokens, we have to combine them to be consistent with
|
||||
# the output of the tokenizer
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
vocab = tokenizer.get_vocab()
|
||||
newline_1 = vocab["\n"]
|
||||
newline_2 = vocab["\n\n"]
|
||||
newline_3 = vocab["\n\n\n"]
|
||||
newline_4 = vocab["\n\n\n\n"]
|
||||
|
||||
token_ids = replace_token_matches(
|
||||
token_ids,
|
||||
[newline_1, newline_2],
|
||||
[newline_3],
|
||||
)
|
||||
token_ids = replace_token_matches(
|
||||
token_ids,
|
||||
[newline_2, newline_1],
|
||||
[newline_3],
|
||||
)
|
||||
token_ids = replace_token_matches(
|
||||
token_ids,
|
||||
[newline_2, newline_2],
|
||||
[newline_4],
|
||||
)
|
||||
|
||||
return token_ids
|
||||
|
||||
def _find_mm_placeholders(
|
||||
self,
|
||||
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
|
||||
new_token_ids: list[int],
|
||||
mm_item_counts: Mapping[str, int],
|
||||
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
|
||||
# We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
vocab = tokenizer.get_vocab()
|
||||
newline_1 = vocab["\n"]
|
||||
newline_2 = vocab["\n\n"]
|
||||
newline_3 = vocab["\n\n\n"]
|
||||
newline_4 = vocab["\n\n\n\n"]
|
||||
|
||||
def get_repl_toks(tok: int) -> list[int]:
|
||||
if tok == newline_3:
|
||||
return [newline_1, newline_2]
|
||||
if tok == newline_4:
|
||||
return [newline_2, newline_2]
|
||||
|
||||
return [tok]
|
||||
|
||||
repl_token_ids = list[int]()
|
||||
repl_orig_idxs = list[int]()
|
||||
for orig_idx, orig_tok in enumerate(new_token_ids):
|
||||
repl_toks = get_repl_toks(orig_tok)
|
||||
repl_token_ids.extend(repl_toks)
|
||||
repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
|
||||
|
||||
repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids,
|
||||
mm_item_counts)
|
||||
|
||||
return {
|
||||
modality: [
|
||||
PlaceholderFeaturesInfo(
|
||||
modality=p.modality,
|
||||
item_idx=p.item_idx,
|
||||
start_idx=repl_orig_idxs[p.start_idx],
|
||||
tokens=p.tokens,
|
||||
is_embed=p.is_embed,
|
||||
) for p in placeholders
|
||||
]
|
||||
for modality, placeholders in repls.items()
|
||||
}
|
||||
|
||||
|
||||
class Gemma3nMultimodalEmbedder(nn.Module):
|
||||
"""Embeds token ids or soft tokens for multimodal content into language
|
||||
model space."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
|
||||
text_config: Gemma3nTextConfig,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.multimodal_hidden_size = multimodal_config.hidden_size
|
||||
self.eps = multimodal_config.rms_norm_eps
|
||||
self.vocab_offset = multimodal_config.vocab_offset
|
||||
self.vocab_size = multimodal_config.vocab_size
|
||||
self.text_hidden_size = text_config.hidden_size
|
||||
|
||||
self.embedding = VocabParallelEmbedding(
|
||||
self.vocab_size,
|
||||
self.multimodal_hidden_size,
|
||||
)
|
||||
|
||||
self.hard_embedding_norm = RMSNorm(
|
||||
self.multimodal_hidden_size,
|
||||
eps=self.eps,
|
||||
)
|
||||
|
||||
self.soft_embedding_norm = RMSNorm(
|
||||
self.multimodal_hidden_size,
|
||||
eps=self.eps,
|
||||
)
|
||||
|
||||
self.embedding_projection = RowParallelLinear(
|
||||
self.multimodal_hidden_size,
|
||||
self.text_hidden_size,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.embedding_post_projection_norm = RMSNorm(
|
||||
self.text_hidden_size,
|
||||
eps=self.eps,
|
||||
has_weight=False,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Embeds token ids or soft tokens for multimodal content into language model space.
|
||||
|
||||
Args:
|
||||
input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range
|
||||
`[vocab_offset, vocab_offset + vocab_size)`.
|
||||
inputs_embeds: A torch.Tensor containing the soft tokens to embed.
|
||||
|
||||
Returns:
|
||||
A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
|
||||
""" # noqa: E501
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
"You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is not None:
|
||||
emb_norm = self.soft_embedding_norm(inputs_embeds)
|
||||
else:
|
||||
hard_emb = self.embedding(input_ids - self.vocab_offset)
|
||||
emb_norm = self.hard_embedding_norm(hard_emb)
|
||||
|
||||
emb_norm_proj, _ = self.embedding_projection(emb_norm)
|
||||
return self.embedding_post_projection_norm(emb_norm_proj)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(Gemma3nMultiModalProcessor,
|
||||
info=Gemma3nProcessingInfo,
|
||||
dummy_inputs=Gemma3nDummyInputsBuilder)
|
||||
class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
# mapping for new names in checkpoint saved after transformers v4.52
|
||||
"model.embed_audio.": "embed_audio.",
|
||||
"model.embed_vision.": "embed_vision.",
|
||||
"model.language_model.": "language_model.model.",
|
||||
"model.vision_tower.": "vision_tower.",
|
||||
"model.audio_tower.": "audio_tower.",
|
||||
"model.multi_modal_projector.": "multi_modal_projector.",
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
"model": "language_model.model",
|
||||
})
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
|
||||
self.sliding_window = getattr(config.text_config,
|
||||
"interleaved_sliding_window", None)
|
||||
|
||||
self.vision_tower = AutoModel.from_config(config=config.vision_config)
|
||||
self.audio_tower = AutoModel.from_config(config=config.audio_config)
|
||||
self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config,
|
||||
config.text_config)
|
||||
self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config,
|
||||
config.text_config)
|
||||
|
||||
self.language_model: nn.Module = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
architectures=["Gemma3nForCausalLM"],
|
||||
)
|
||||
self.language_model = cast(Gemma3nForCausalLM, self.language_model)
|
||||
# NOTE (NickLucche) In order to be compatible with cudagraph, the
|
||||
# buffer needs to be consistent, so we pre-allocate here.
|
||||
self.per_layer_embeddings = torch.zeros(
|
||||
vllm_config.scheduler_config.max_num_batched_tokens,
|
||||
self.config.text_config.num_hidden_layers,
|
||||
self.config.text_config.hidden_size_per_layer_input,
|
||||
device=self.language_model.model.embed_tokens.weight.device,
|
||||
dtype=self.language_model.model.embed_tokens.weight.dtype)
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
|
||||
# TODO check if there are any
|
||||
return data
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[Gemma3nImageInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
# TODO is this the case?
|
||||
assert image_embeds is None, "Gemma3n does not support image_embeds."
|
||||
if pixel_values is None:
|
||||
return None
|
||||
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
pixel_values = flatten_bn(pixel_values, concat=True)
|
||||
pixel_values = pixel_values.contiguous()
|
||||
|
||||
return Gemma3nImagePixelInputs(
|
||||
pixel_values=self._validate_pixel_values(pixel_values), )
|
||||
|
||||
def _parse_and_validate_audio_input(
|
||||
self, **kwargs: object) -> Optional[Gemma3nAudioInputs]:
|
||||
input_features = kwargs.pop("input_features", None)
|
||||
if input_features is None:
|
||||
return None
|
||||
|
||||
input_features_mask = kwargs.pop("input_features_mask", None)
|
||||
if input_features_mask is None:
|
||||
return None
|
||||
|
||||
return Gemma3nAudioInputs(
|
||||
input_features=input_features,
|
||||
input_features_mask=input_features_mask,
|
||||
)
|
||||
|
||||
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||
mm_input_by_modality = {}
|
||||
|
||||
# Preserve the order of modalities if there are multiple of them
|
||||
# from the order of kwargs.
|
||||
for input_key in kwargs:
|
||||
if input_key in ("pixel_values", "image_embeds"
|
||||
) and "image" not in mm_input_by_modality:
|
||||
mm_input_by_modality[
|
||||
"image"] = self._parse_and_validate_image_input(**kwargs)
|
||||
if input_key == "input_features" \
|
||||
and "audio" not in mm_input_by_modality:
|
||||
mm_input_by_modality[
|
||||
"audio"] = self._parse_and_validate_audio_input(**kwargs)
|
||||
return mm_input_by_modality
|
||||
|
||||
def _process_image_input(
|
||||
self,
|
||||
image_input: Gemma3nImageInputs,
|
||||
) -> list[torch.Tensor]:
|
||||
assert self.vision_tower is not None
|
||||
|
||||
pixel_values = image_input["pixel_values"]
|
||||
vision_outputs = self.vision_tower(pixel_values=pixel_values,
|
||||
do_pooling=False,
|
||||
return_dict=True).last_hidden_state
|
||||
# TODO try to avoid copy here
|
||||
# (batch, channels, height, width) to (batch, height * width, channels)
|
||||
vision_outputs = vision_outputs.reshape(
|
||||
vision_outputs.shape[0],
|
||||
self.config.vision_config.hidden_size,
|
||||
self.config.vision_soft_tokens_per_image,
|
||||
).permute(0, 2, 1).contiguous()
|
||||
# Normalize and embed the soft tokens into language model space.
|
||||
vision_outputs *= self.config.vision_config.hidden_size**0.5
|
||||
# Return a list of embeddings instead of a batched tensor
|
||||
return self.embed_vision(inputs_embeds=vision_outputs).unbind(0)
|
||||
|
||||
def _process_audio_input(
|
||||
self,
|
||||
audio_input: Gemma3nAudioInputs,
|
||||
) -> list[torch.Tensor]:
|
||||
assert self.audio_tower is not None
|
||||
input_features = audio_input["input_features"].squeeze(1)
|
||||
input_features_mask = audio_input["input_features_mask"].squeeze(1)
|
||||
audio_outputs, audio_mask = self.audio_tower(input_features,
|
||||
~input_features_mask)
|
||||
audio_features = self.embed_audio(inputs_embeds=audio_outputs)
|
||||
|
||||
# ruff: noqa
|
||||
# The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the
|
||||
# text to account for this. However, the audio preprocessing and encoder do not gurarantee they will
|
||||
# produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens
|
||||
# depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
|
||||
# the audio feature out to 188 soft tokens with the emebedding of the last token in the embed_audio vocab.
|
||||
# TODO precompute and cache padding
|
||||
audio_padding_toks = torch.tensor([[self.vocab_size - 1]],
|
||||
dtype=torch.long,
|
||||
device=audio_features.device)
|
||||
audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
|
||||
audio_features = torch.where(audio_mask.unsqueeze(-1),
|
||||
audio_padding_embs, audio_features)
|
||||
|
||||
audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
|
||||
extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len # noqa: E501
|
||||
extra_padding_features = audio_padding_embs.expand(
|
||||
audio_batch_size, extra_padding_tokens, audio_embed_dim)
|
||||
|
||||
audio_features = torch.cat((audio_features, extra_padding_features),
|
||||
dim=1)
|
||||
# Return a list of embeddings instead of a batched tensor
|
||||
return audio_features.unbind(0)
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self,
|
||||
**kwargs: object) -> MultiModalEmbeddings:
|
||||
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
|
||||
**kwargs)
|
||||
if mm_input_by_modality is None:
|
||||
return []
|
||||
|
||||
multimodal_embeddings: list[torch.Tensor] = []
|
||||
|
||||
# NOTE: It is important to iterate over the keys in this dictionary
|
||||
# to preserve the order of the modalities.
|
||||
for modality in mm_input_by_modality:
|
||||
multimodal_input = mm_input_by_modality[modality]
|
||||
if modality == "image":
|
||||
vision_embeddings = self._process_image_input(multimodal_input)
|
||||
multimodal_embeddings.extend(vision_embeddings)
|
||||
if modality == "audio":
|
||||
audio_embeddings = self._process_audio_input(multimodal_input)
|
||||
multimodal_embeddings.extend(audio_embeddings)
|
||||
return multimodal_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
# NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
|
||||
# them here, as the model forward has only access to the input_embeds.
|
||||
if input_ids is not None:
|
||||
per_layer_inputs = self.language_model.model.get_per_layer_input_embeddings(
|
||||
input_ids)
|
||||
per_layer_inputs = per_layer_inputs.reshape(
|
||||
-1, self.config.text_config.num_hidden_layers,
|
||||
self.config.text_config.hidden_size_per_layer_input)
|
||||
self.per_layer_embeddings[:per_layer_inputs.shape[0]].copy_(
|
||||
per_layer_inputs)
|
||||
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
# NOTE: this order of processing mm items is important
|
||||
[self.config.image_token_id, self.config.audio_token_id])
|
||||
return inputs_embeds
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object) -> IntermediateTensors:
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
|
||||
# NOTE (NickLucche) During profiling, `get_input_embeddings` is not
|
||||
# called, hence we don't have input_ids to compute PLEs. We simply
|
||||
# select a chunk of pre-allocated PLEs. During normal execution,
|
||||
# `get_input_embeddings` is called before forward, hence this slice
|
||||
# will contain PLEs computed from the actual input_ids.
|
||||
per_layer_inputs = self.per_layer_embeddings[:inputs_embeds.shape[0]]
|
||||
|
||||
hidden_states = self.language_model.model(
|
||||
input_ids,
|
||||
positions,
|
||||
per_layer_inputs=per_layer_inputs,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**kwargs)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
return self.language_model.compute_logits(hidden_states,
|
||||
sampling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""
|
||||
Get the module prefix in multimodal models
|
||||
"""
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="language_model",
|
||||
connector="multi_modal_projector",
|
||||
tower_model="vision_tower")
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality == "image":
|
||||
return "<image_soft_token>"
|
||||
elif modality == "audio":
|
||||
return "<audio_soft_token>"
|
||||
else:
|
||||
raise ValueError(f"Unsupported modality: {modality}")
|
@ -69,8 +69,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
|
||||
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
|
||||
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
|
||||
#TODO(ywang96): Support multimodal gemma3n
|
||||
"Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"), # noqa: E501
|
||||
"Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
|
||||
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
|
||||
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
|
||||
"Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
|
||||
@ -205,6 +204,7 @@ _MULTIMODAL_MODELS = {
|
||||
"DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
|
||||
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
|
||||
"Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501
|
||||
"Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"), # noqa: E501
|
||||
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
|
||||
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501
|
||||
"Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501
|
||||
|
Reference in New Issue
Block a user