mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
Gemma3n (Text-only) (#20134)
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com> Signed-off-by: Roger Wang <hey@rogerw.me> Co-authored-by: Roger Wang <hey@rogerw.me>
This commit is contained in:
@ -336,6 +336,7 @@ Specified using `--task generate`.
|
||||
| `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Gemma3nForConditionalGeneration` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
|
||||
| `GlmForCausalLM` | GLM-4 | `THUDM/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Glm4ForCausalLM` | GLM-4-0414 | `THUDM/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GPT2LMHeadModel` | GPT-2 | `gpt2`, `gpt2-xl`, etc. | | ✅︎ | ✅︎ |
|
||||
@ -392,6 +393,9 @@ Specified using `--task generate`.
|
||||
!!! note
|
||||
Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
|
||||
|
||||
!!! note
|
||||
Only text inputs are currently supported for `Gemma3nForConditionalGeneration`. To use this model, please upgrade Hugging Face Transformers to version 4.53.0.
|
||||
|
||||
### Pooling Models
|
||||
|
||||
See [this page](./pooling_models.md) for more information on how to use pooling models.
|
||||
|
@ -164,6 +164,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
|
||||
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
|
||||
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
|
||||
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501
|
||||
min_transformers_version="4.53"),
|
||||
"GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"),
|
||||
"Glm4ForCausalLM": _HfExamplesInfo("THUDM/GLM-4-9B-0414"),
|
||||
"GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2",
|
||||
|
@ -135,6 +135,57 @@ class MulAndSilu(CustomOp):
|
||||
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
@CustomOp.register("gelu_and_mul_sparse")
|
||||
class GeluAndMulSparse(CustomOp):
|
||||
"""An activation function for GeluAndMulSparse.
|
||||
This activation function is used in Gemma3n. It computes:
|
||||
up_proj = self.up_proj(x)
|
||||
gate_proj = self.gate_proj(x)
|
||||
gate_proj = self._gaussian_topk(gate_proj) # sparsity
|
||||
activations = self.act_fn(gate_proj) # gelu
|
||||
down_proj = self.down_proj(activations * up_proj)
|
||||
Shapes:
|
||||
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
|
||||
return: (num_tokens, d) or (batch_size, seq_len, d)
|
||||
"""
|
||||
|
||||
def __init__(self, activation_sparsity: float, approximate: str = "none"):
|
||||
super().__init__()
|
||||
# Gelu.
|
||||
self.approximate = approximate
|
||||
if approximate not in ("none", "tanh"):
|
||||
raise ValueError(f"Unknown approximate mode: {approximate}")
|
||||
|
||||
# Sparsity.
|
||||
if activation_sparsity == 0.0:
|
||||
raise ValueError(
|
||||
"activation_sparsity is 0.0. Please use GeluAndMul.")
|
||||
target_sparsity_tensor = torch.tensor(activation_sparsity,
|
||||
dtype=torch.float32)
|
||||
normal_dist = torch.distributions.normal.Normal(0, 1)
|
||||
self.std_multiplier = normal_dist.icdf(target_sparsity_tensor)
|
||||
|
||||
def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Get % sparse percentile of the Gaussian distribution."""
|
||||
# NOTE(rob): for TP>1, we could all-gather to get the means/std.
|
||||
# But we do not do this because in expectation they are the same
|
||||
# and in practice the eval scores are good without gathering.
|
||||
mean = torch.mean(x, dim=-1, keepdim=True)
|
||||
std = torch.std(x, dim=-1, keepdim=True, unbiased=False)
|
||||
cutoff_x = mean + std * self.std_multiplier
|
||||
return nn.functional.relu(x - cutoff_x)
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
d = x.shape[-1] // 2
|
||||
out = self._gaussian_topk(x[..., :d])
|
||||
out = F.gelu(out, approximate=self.approximate)
|
||||
return out * x[..., d:]
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.forward_native(x)
|
||||
|
||||
|
||||
@CustomOp.register("gelu_and_mul")
|
||||
class GeluAndMul(CustomOp):
|
||||
"""An activation function for GeGLU.
|
||||
|
811
vllm/model_executor/models/gemma3n.py
Normal file
811
vllm/model_executor/models/gemma3n.py
Normal file
@ -0,0 +1,811 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright 2025 The vLLM team.
|
||||
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.models.gemma3n.configuration_gemma3n import Gemma3nTextConfig
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
|
||||
GeluAndMul,
|
||||
GeluAndMulSparse)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .utils import (AutoWeightsLoader, extract_layer_index,
|
||||
is_pp_missing_parameter, make_layers, maybe_prefix)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Gemma3nAltUp(nn.Module):
|
||||
"""Alternating updates (Altup)
|
||||
The AltUp module wraps transformer layers. The `predict` step modifies the
|
||||
input to the transformer layer, and the `correct` step propagates the output
|
||||
of the transformer layer to the sparsely updated dimensions.
|
||||
See more in the research paper:
|
||||
https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
rms_norm_eps: float,
|
||||
altup_num_inputs: int,
|
||||
altup_coef_clip: float,
|
||||
altup_active_idx: int,
|
||||
prefix: str,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.altup_num_inputs = altup_num_inputs
|
||||
self.altup_active_idx = altup_active_idx
|
||||
self.altup_coef_clip = altup_coef_clip
|
||||
|
||||
self.correction_coefs = ReplicatedLinear(
|
||||
altup_num_inputs,
|
||||
altup_num_inputs,
|
||||
bias=False,
|
||||
prefix=f"{prefix}.correction_coefs",
|
||||
return_bias=False,
|
||||
)
|
||||
self.prediction_coefs = ReplicatedLinear(
|
||||
altup_num_inputs,
|
||||
altup_num_inputs**2,
|
||||
bias=False,
|
||||
prefix=f"{prefix}.prediction_coefs",
|
||||
return_bias=False,
|
||||
)
|
||||
self.modality_router = ReplicatedLinear(
|
||||
hidden_size,
|
||||
altup_num_inputs,
|
||||
bias=False,
|
||||
prefix=f"{prefix}.modality_router",
|
||||
return_bias=False,
|
||||
)
|
||||
self.router_norm = RMSNorm(
|
||||
hidden_size=hidden_size,
|
||||
eps=rms_norm_eps,
|
||||
)
|
||||
self.router_input_scale = torch.tensor(
|
||||
hidden_size**-1.0, dtype=self.modality_router.weight.dtype)
|
||||
self.correct_output_scale = nn.Parameter(
|
||||
torch.zeros(hidden_size, dtype=torch.float32))
|
||||
|
||||
def _compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor:
|
||||
router_inputs = self.router_norm(x) * self.router_input_scale
|
||||
routed = self.modality_router(router_inputs)
|
||||
return torch.tanh(routed.float()).type_as(x)
|
||||
|
||||
def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
|
||||
return (corrected.type_as(self.correct_output_scale) *
|
||||
self.correct_output_scale).type_as(corrected)
|
||||
|
||||
def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# hidden: [altup_num_inputs, num_tokens, hidden_size]
|
||||
# modalities: [num_tokens, num_altup_inputs]
|
||||
# all_coefs: [num_tokens, num_altup_inputs ** 2]
|
||||
modalities = self._compute_router_modalities(
|
||||
hidden_states[self.altup_active_idx])
|
||||
all_coefs = self.prediction_coefs(modalities)
|
||||
|
||||
# Reshape and transpose the 2D matrix for the matmul.
|
||||
# all_coefs_T: [num_tokens, num_altup_inputs, num_altup_inputs]
|
||||
all_coefs_T = all_coefs.reshape(
|
||||
-1,
|
||||
self.altup_num_inputs,
|
||||
self.altup_num_inputs,
|
||||
).permute(0, 2, 1)
|
||||
|
||||
# hidden_states to [num_tokens, hidden_size, altup_num_inputs]
|
||||
predictions = torch.matmul(hidden_states.permute(1, 2, 0), all_coefs_T)
|
||||
# [altup_num_inputs, num_tokens, hidden_size]
|
||||
predictions = predictions.permute(2, 0, 1)
|
||||
predictions += hidden_states
|
||||
return predictions.contiguous()
|
||||
|
||||
def correct(self, predictions: torch.Tensor,
|
||||
activated: torch.Tensor) -> torch.Tensor:
|
||||
# predictions: [altup_num_inputs, num_tokens, hidden_size]
|
||||
# activated: [num_tokens, hidden_size]
|
||||
# modalities: [num_tokens, altup_num_inputs]
|
||||
modalities = self._compute_router_modalities(activated)
|
||||
# innovation: [num_tokens, altup_num_inputs]
|
||||
innovation = activated - predictions[self.altup_active_idx]
|
||||
# innovation: [altup_num_inputs, num_tokens, hidden_size]
|
||||
innovation = innovation.repeat(self.altup_num_inputs, 1, 1)
|
||||
|
||||
# Permute to [altup_num_inputs, num_tokens] as the last dim
|
||||
# is a scalar applied to each altup input and expand on
|
||||
# num_tokens dim for broadcastability over hidden_size.
|
||||
# all_coefs: [num_tokens, altup_num_inputs]
|
||||
all_coefs = self.correction_coefs(modalities) + 1.0
|
||||
# all_coefs: [altup_num_inputs, num_tokens, 1]
|
||||
all_coefs = all_coefs.T.unsqueeze(-1)
|
||||
|
||||
# Elementwise (broadcast over hidden_size).
|
||||
corrected = torch.mul(innovation, all_coefs)
|
||||
corrected += predictions
|
||||
|
||||
return corrected.contiguous()
|
||||
|
||||
|
||||
class Gemma3nLaurelBlock(nn.Module):
|
||||
"""Learned Augmented Residual Layer"""
|
||||
|
||||
def __init__(self, hidden_size: int, laurel_rank: int, rms_norm_eps: float,
|
||||
prefix: str):
|
||||
super().__init__()
|
||||
|
||||
self.linear_left = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
laurel_rank,
|
||||
bias=False,
|
||||
prefix=f"{prefix}.linear_left",
|
||||
return_bias=False,
|
||||
)
|
||||
self.linear_right = RowParallelLinear(laurel_rank,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
prefix=f"{prefix}.linear_right",
|
||||
return_bias=False)
|
||||
self.post_laurel_norm = RMSNorm(
|
||||
hidden_size=hidden_size,
|
||||
eps=rms_norm_eps,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
laurel_x = self.linear_left(x)
|
||||
laurel_x = self.linear_right(laurel_x)
|
||||
normed_laurel_x = self.post_laurel_norm(laurel_x)
|
||||
return x + normed_laurel_x
|
||||
|
||||
|
||||
class Gemma3nMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_activation: str,
|
||||
activation_sparsity: float = 0.0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
if hidden_activation != "gelu_pytorch_tanh":
|
||||
raise ValueError(
|
||||
"Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
|
||||
"function. Please set `hidden_act` and `hidden_activation` to "
|
||||
"`gelu_pytorch_tanh`.")
|
||||
|
||||
self.act_fn = GeluAndMulSparse(
|
||||
activation_sparsity=activation_sparsity,
|
||||
approximate="tanh") if activation_sparsity > 0.0 else GeluAndMul(
|
||||
approximate="tanh")
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class Gemma3nAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: Gemma3nTextConfig,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
max_position_embeddings: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = head_dim
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=config.attention_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=config.attention_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
self.q_norm = RMSNorm(hidden_size=self.head_dim,
|
||||
eps=config.rms_norm_eps)
|
||||
self.k_norm = RMSNorm(hidden_size=self.head_dim,
|
||||
eps=config.rms_norm_eps)
|
||||
self.v_norm = RMSNorm(hidden_size=self.head_dim,
|
||||
eps=config.rms_norm_eps,
|
||||
has_weight=False)
|
||||
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
if config.layer_types[layer_idx] == "sliding_attention":
|
||||
self.sliding_window = config.sliding_window
|
||||
rope_theta = config.rope_local_base_freq
|
||||
rope_scaling = {"rope_type": "default"}
|
||||
else:
|
||||
self.sliding_window = None
|
||||
rope_theta = config.rope_theta
|
||||
rope_scaling = config.rope_scaling
|
||||
|
||||
first_kv_shared_layer_idx = (config.num_hidden_layers -
|
||||
config.num_kv_shared_layers)
|
||||
self.is_kv_shared = layer_idx >= first_kv_shared_layer_idx
|
||||
|
||||
if self.is_kv_shared:
|
||||
# Last full attention layer is 1 before sharing
|
||||
# Last sliding attention layer is 2 before sharing
|
||||
offset = 2 if self.sliding_window is not None else 1
|
||||
kv_shared_layer_index = first_kv_shared_layer_idx - offset
|
||||
kv_sharing_target_layer_name = f"model.language_model.layers.{kv_shared_layer_index}.self_attn.attn" # noqa: E501
|
||||
else:
|
||||
kv_sharing_target_layer_name = None
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
is_neox_style=True,
|
||||
rope_scaling=rope_scaling,
|
||||
)
|
||||
|
||||
self.attn = Attention(
|
||||
num_heads=self.num_heads,
|
||||
head_size=self.head_dim,
|
||||
scale=1.0,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
per_layer_sliding_window=self.sliding_window,
|
||||
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
q = q.unflatten(-1, (self.num_heads, self.head_dim))
|
||||
q = self.q_norm(q)
|
||||
q = q.flatten(-2, -1)
|
||||
k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
|
||||
k = self.k_norm(k)
|
||||
k = k.flatten(-2, -1)
|
||||
v = v.unflatten(-1, (self.num_kv_heads, self.head_dim))
|
||||
v = self.v_norm(v)
|
||||
v = v.flatten(-2, -1)
|
||||
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class Gemma3nDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Gemma3nTextConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.altup_active_idx = config.altup_active_idx
|
||||
assert config.altup_correct_scale
|
||||
|
||||
self.altup = Gemma3nAltUp(
|
||||
hidden_size=config.hidden_size,
|
||||
rms_norm_eps=config.rms_norm_eps,
|
||||
altup_num_inputs=config.altup_num_inputs,
|
||||
altup_coef_clip=config.altup_coef_clip,
|
||||
altup_active_idx=config.altup_active_idx,
|
||||
prefix=f"{prefix}.altup",
|
||||
)
|
||||
self.self_attn = Gemma3nAttention(
|
||||
config=config,
|
||||
hidden_size=config.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
head_dim=config.head_dim,
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
self.mlp = Gemma3nMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
# NOTE: Matformer https://github.com/huggingface/transformers/blob/a52478253bbe522a420e88ea3940d4d98a935300/src/transformers/models/gemma3n/modular_gemma3n.py#L258 # noqa: E501
|
||||
intermediate_size=config.intermediate_size[extract_layer_index(
|
||||
prefix)],
|
||||
hidden_activation=config.hidden_activation,
|
||||
quant_config=quant_config,
|
||||
activation_sparsity=config.activation_sparsity_pattern[
|
||||
extract_layer_index(prefix)],
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.laurel = Gemma3nLaurelBlock(
|
||||
hidden_size=config.hidden_size,
|
||||
laurel_rank=config.laurel_rank,
|
||||
rms_norm_eps=config.rms_norm_eps,
|
||||
prefix=f"{prefix}.laurel",
|
||||
)
|
||||
|
||||
# NOTE(rob): should be ColumnParallelLinear and RowParallelLinear
|
||||
# But, we need to add per_layer_input_gate(x) to per_layer_input.
|
||||
# per_layer_input cannot be sharded, so we replicate for now.
|
||||
self.per_layer_input_gate = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
config.hidden_size_per_layer_input,
|
||||
bias=False,
|
||||
prefix=f"{prefix}.per_layer_input_gate",
|
||||
return_bias=False,
|
||||
)
|
||||
self.per_layer_projection = ReplicatedLinear(
|
||||
config.hidden_size_per_layer_input,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
prefix=f"{prefix}.per_layer_projection",
|
||||
return_bias=False,
|
||||
)
|
||||
|
||||
# LayerNorms.
|
||||
self.input_layernorm = RMSNorm(
|
||||
config.hidden_size,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
self.pre_feedforward_layernorm = RMSNorm(
|
||||
config.hidden_size,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
self.post_feedforward_layernorm = RMSNorm(
|
||||
config.hidden_size,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
self.post_per_layer_input_norm = RMSNorm(
|
||||
config.hidden_size,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
|
||||
self.act_fn = _ACTIVATION_REGISTRY[config.hidden_activation]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
per_layer_input: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
# ActUp (predict).
|
||||
predictions = self.altup.predict(hidden_states)
|
||||
active_prediction = predictions[self.altup_active_idx]
|
||||
active_prediction_normed = self.input_layernorm(active_prediction)
|
||||
laurel_output = self.laurel(active_prediction_normed)
|
||||
|
||||
# Attention.
|
||||
attn = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=active_prediction_normed,
|
||||
**kwargs,
|
||||
)
|
||||
attn = self.post_attention_layernorm(attn)
|
||||
attn_gated = attn + active_prediction
|
||||
attn_laurel = (attn_gated + laurel_output) / torch.sqrt(
|
||||
torch.tensor(2.0))
|
||||
|
||||
# MLP.
|
||||
attn_norm = self.pre_feedforward_layernorm(attn_laurel)
|
||||
attn_ffw = self.mlp(attn_norm)
|
||||
attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw)
|
||||
attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
|
||||
|
||||
# ActUp (connect).
|
||||
corrected_predictions = self.altup.correct(predictions,
|
||||
attn_ffw_laurel_gated)
|
||||
first_prediction = corrected_predictions[self.altup_active_idx]
|
||||
first_prediction = self.altup.scale_corrected_output(first_prediction)
|
||||
|
||||
# per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
|
||||
first_prediction = self.per_layer_input_gate(first_prediction)
|
||||
first_prediction = self.act_fn(first_prediction)
|
||||
first_prediction = torch.mul(first_prediction, per_layer_input)
|
||||
|
||||
# per_layer_projection adapted from jax.numpy.einsum("btp,pd->btd", ...)
|
||||
first_prediction = self.per_layer_projection(first_prediction)
|
||||
first_prediction = self.post_per_layer_input_norm(first_prediction)
|
||||
corrected_predictions[1:] += first_prediction
|
||||
|
||||
return corrected_predictions
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class Gemma3nTextModel(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config.text_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=f"{prefix}.embed_tokens",
|
||||
)
|
||||
self.embed_scale = torch.tensor(
|
||||
config.hidden_size**0.5,
|
||||
dtype=self.embed_tokens.weight.dtype,
|
||||
)
|
||||
self.embed_tokens_per_layer = VocabParallelEmbedding(
|
||||
config.vocab_size_per_layer_input,
|
||||
config.num_hidden_layers * config.hidden_size_per_layer_input,
|
||||
prefix=f"{prefix}.per_layer_embed_tokens",
|
||||
)
|
||||
self.embed_scale_per_layer = torch.tensor(
|
||||
config.hidden_size_per_layer_input**0.5,
|
||||
dtype=self.embed_tokens.weight.dtype,
|
||||
)
|
||||
self.per_layer_model_projection = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.num_hidden_layers * config.hidden_size_per_layer_input,
|
||||
bias=False,
|
||||
gather_output=True,
|
||||
return_bias=False,
|
||||
prefix=f"{prefix}.per_layer_model_projection",
|
||||
)
|
||||
self.per_layer_projection_norm = RMSNorm(
|
||||
hidden_size=config.hidden_size_per_layer_input,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
self.per_layer_input_scale = torch.rsqrt(torch.tensor(2.0)).to(
|
||||
self.embed_tokens.weight.dtype)
|
||||
self.per_layer_projection_scale = torch.tensor(
|
||||
config.hidden_size**0.5,
|
||||
dtype=self.embed_tokens.weight.dtype,
|
||||
)
|
||||
self.altup_projections = nn.ModuleList([
|
||||
ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
gather_output=True,
|
||||
return_bias=False,
|
||||
prefix=f"{prefix}.{idx-1}.altup_projections",
|
||||
) for idx in range(1, self.config.altup_num_inputs)
|
||||
])
|
||||
self.altup_unembed_projections = nn.ModuleList([
|
||||
ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
gather_output=True,
|
||||
return_bias=False,
|
||||
prefix=f"{prefix}.{idx-1}.altup_unembed_projections",
|
||||
) for idx in range(1, self.config.altup_num_inputs)
|
||||
])
|
||||
|
||||
# Transformer blocks.
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: Gemma3nDecoderLayer(
|
||||
config, cache_config, quant_config, prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
self.norm = RMSNorm(
|
||||
config.hidden_size,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
self.eps = torch.tensor(torch.finfo().min)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids) * self.embed_scale
|
||||
|
||||
def get_per_layer_input_embeddings(
|
||||
self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
# Deal with the fact that vocab_size_per_layer_input < vocab_size
|
||||
# which causes us to have some out of vocab tokens by setting
|
||||
# those token ids to 0. This matches the HF implementation.
|
||||
per_layer_inputs_mask = torch.logical_and(
|
||||
input_ids >= 0, input_ids < self.config.vocab_size_per_layer_input)
|
||||
per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids,
|
||||
torch.zeros_like(input_ids))
|
||||
return self.embed_tokens_per_layer(
|
||||
per_layer_inputs_tokens) * self.embed_scale_per_layer
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states_0 = inputs_embeds
|
||||
else:
|
||||
hidden_states_0 = self.get_input_embeddings(input_ids)
|
||||
|
||||
# Per layer inputs.
|
||||
if input_ids is None:
|
||||
raise ValueError("Passing None for input ids is not supported.")
|
||||
per_layer_inputs = self.get_per_layer_input_embeddings(input_ids)
|
||||
per_layer_inputs = per_layer_inputs.reshape(
|
||||
-1, self.config.num_hidden_layers,
|
||||
self.config.hidden_size_per_layer_input)
|
||||
per_layer_projection = self.per_layer_model_projection(hidden_states_0)
|
||||
per_layer_projection = per_layer_projection.reshape(
|
||||
*hidden_states_0.shape[:-1],
|
||||
self.config.num_hidden_layers,
|
||||
self.config.hidden_size_per_layer_input,
|
||||
)
|
||||
per_layer_projection = self.per_layer_projection_norm(
|
||||
per_layer_projection)
|
||||
per_layer_inputs = per_layer_projection + per_layer_inputs
|
||||
per_layer_inputs *= self.per_layer_input_scale
|
||||
|
||||
# Altup embed.
|
||||
hidden_states = [hidden_states_0] * self.config.altup_num_inputs
|
||||
target_magnitude = torch.mean(hidden_states_0**2, dim=-1,
|
||||
keepdim=True)**0.5
|
||||
for i in range(1, self.config.altup_num_inputs):
|
||||
hidden_states[i] = self.altup_projections[i - 1](hidden_states[i])
|
||||
new_magnitude = torch.mean(hidden_states[i]**2,
|
||||
dim=-1,
|
||||
keepdim=True)**0.5
|
||||
hidden_states[i] *= target_magnitude / torch.maximum(
|
||||
new_magnitude, self.eps)
|
||||
hidden_states = torch.stack(hidden_states, dim=0)
|
||||
|
||||
# Transformer blocks.
|
||||
for layer_idx, layer in enumerate(self.layers):
|
||||
# [altup_num_inputs, num_tokens, hidden_size]
|
||||
hidden_states = layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
per_layer_input=per_layer_inputs[:, layer_idx, :],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Altup unembed.
|
||||
target_magnitude = torch.mean(hidden_states[0]**2,
|
||||
dim=-1,
|
||||
keepdim=True)**0.5
|
||||
for i in range(1, self.config.altup_num_inputs):
|
||||
hidden_states[i] = self.altup_unembed_projections[i - 1](
|
||||
hidden_states[i])
|
||||
new_magnitude = torch.mean(hidden_states[i]**2,
|
||||
dim=-1,
|
||||
keepdim=True)**0.5
|
||||
hidden_states[i] *= target_magnitude / torch.maximum(
|
||||
new_magnitude, self.eps)
|
||||
# [altup_num_inputs,num_tokens,hidden_size] -> [num_tokens,hidden_size]
|
||||
hidden_states = torch.mean(hidden_states, dim=0)
|
||||
|
||||
return self.norm(hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if (self.quant_config is not None and
|
||||
(scale_name := self.quant_config.get_cache_scale(name))):
|
||||
# Loading kv cache scales for compressed-tensors quantization
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
loaded_weight = loaded_weight[0]
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for (param_name, shard_name, shard_id) in stacked_params_mapping:
|
||||
if shard_name not in name:
|
||||
continue
|
||||
# Avoid spurious match with ".up_proj".
|
||||
if "altup_projections" in name:
|
||||
continue
|
||||
name = name.replace(shard_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Gemma3nModel(nn.Module):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.language_model = Gemma3nTextModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "language_model"))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
return self.language_model(input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**kwargs)
|
||||
|
||||
|
||||
class Gemma3nForConditionalGeneration(nn.Module):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
config = vllm_config.model_config.hf_config
|
||||
lora_config = vllm_config.lora_config
|
||||
del lora_config # Unused.
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = Gemma3nModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.logits_processor = LogitsProcessor(
|
||||
config.text_config.vocab_size,
|
||||
soft_cap=config.text_config.final_logit_softcapping)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.language_model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds, **kwargs)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: Optional[SamplingMetadata],
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.model.language_model.embed_tokens,
|
||||
hidden_states, sampling_metadata)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self,
|
||||
skip_substrs=([
|
||||
"embed_audio.", "embed_vision.",
|
||||
"audio_tower.", "vision_tower."
|
||||
]))
|
||||
return loader.load_weights(weights)
|
@ -58,6 +58,8 @@ _TEXT_GENERATION_MODELS = {
|
||||
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
|
||||
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
|
||||
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
|
||||
#TODO(ywang96): Support multimodal gemma3n
|
||||
"Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"), # noqa: E501
|
||||
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
|
||||
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
|
||||
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
|
||||
|
Reference in New Issue
Block a user