[Model] [Quantization] Support quantization for Gemma3n (#21974)

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Author: Kyle Sayers
Date: 2025-08-01 01:45:15 -04:00 (committed by GitHub)
Parent: e1a7fe4af5
Commit: 0f46a780d4


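This change threads the quantization config through Gemma3n's text-model layers so that quantized Gemma3n checkpoints load like other quantized models in vLLM. A minimal usage sketch (not part of this commit; the model id below is a placeholder):

# Minimal usage sketch (not part of this commit). The model id is a
# placeholder for any quantized Gemma3n checkpoint; vLLM normally picks up
# the quantization method from the checkpoint config, or it can be set
# explicitly via the `quantization` argument.
from vllm import LLM, SamplingParams

llm = LLM(model="your-org/gemma-3n-quantized")  # placeholder model id
outputs = llm.generate(["Hello from Gemma3n!"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)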
@@ -46,6 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
+from .interfaces import SupportsQuant
 from .utils import (AutoWeightsLoader, extract_layer_index,
                     is_pp_missing_parameter, make_layers, maybe_prefix)
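The new import brings in the SupportsQuant marker that the model classes below opt into. As a rough illustration of the marker-interface idea (hypothetical names, not vLLM's actual SupportsQuant definition):

# Hypothetical sketch only: `QuantCapable` and `supports_quantization` are
# made-up names illustrating the marker-interface pattern.
from typing import Optional


class QuantCapable:
    """Marker mixin: the model accepts a quant config and threads it down."""

    # Fused runtime modules mapped to their separate checkpoint shards.
    packed_modules_mapping: dict[str, list[str]] = {}
    quant_config: Optional[object] = None


def supports_quantization(model_cls: type) -> bool:
    # Shared loader code can branch on the marker instead of special-casing
    # individual model classes.
    return issubclass(model_cls, QuantCapable)


class ToyModel(QuantCapable):
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}


print(supports_quantization(ToyModel))  # True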
@@ -68,6 +69,7 @@ class Gemma3nAltUp(nn.Module):
         altup_num_inputs: int,
         altup_coef_clip: float,
         altup_active_idx: int,
+        quant_config: QuantizationConfig,
         prefix: str,
     ):
         super().__init__()
@@ -80,6 +82,7 @@ class Gemma3nAltUp(nn.Module):
             altup_num_inputs,
             altup_num_inputs,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.correction_coefs",
             return_bias=False,
         )
@@ -87,6 +90,7 @@ class Gemma3nAltUp(nn.Module):
             altup_num_inputs,
             altup_num_inputs**2,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.prediction_coefs",
             return_bias=False,
         )
@@ -94,6 +98,7 @@ class Gemma3nAltUp(nn.Module):
             hidden_size,
             altup_num_inputs,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.modality_router",
             return_bias=False,
         )
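Each AltUp projection now receives quant_config, the standard way a vLLM submodule lets its linear layers pick up quantized kernels. A simplified, self-contained sketch of the pattern, with plain nn.Linear and a hypothetical maybe_quantize helper standing in for vLLM's parallel linear layers:

# Simplified sketch of threading a quantization config through a submodule.
# `QuantConfig` and `maybe_quantize` are illustrative stand-ins, not vLLM APIs.
from dataclasses import dataclass
from typing import Optional

import torch.nn as nn


@dataclass
class QuantConfig:
    bits: int = 8


def maybe_quantize(layer: nn.Linear,
                   quant_config: Optional[QuantConfig]) -> nn.Module:
    # A real implementation would swap in a quantized kernel; here we only
    # tag the layer so the example stays self-contained.
    layer.quantized = quant_config is not None
    return layer


class AltUpSketch(nn.Module):
    def __init__(self, hidden_size: int, num_inputs: int,
                 quant_config: Optional[QuantConfig] = None) -> None:
        super().__init__()
        # Every projection is constructed with the same quant_config, so a
        # checkpoint-declared scheme can cover these weights as well.
        self.correction_coefs = maybe_quantize(
            nn.Linear(num_inputs, num_inputs, bias=False), quant_config)
        self.prediction_coefs = maybe_quantize(
            nn.Linear(num_inputs, num_inputs**2, bias=False), quant_config)
        self.modality_router = maybe_quantize(
            nn.Linear(hidden_size, num_inputs, bias=False), quant_config)


altup = AltUpSketch(hidden_size=16, num_inputs=4, quant_config=QuantConfig())
print(altup.modality_router.quantized)  # True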
@@ -400,6 +405,7 @@ class Gemma3nDecoderLayer(nn.Module):
             altup_num_inputs=config.altup_num_inputs,
             altup_coef_clip=config.altup_coef_clip,
             altup_active_idx=config.altup_active_idx,
+            quant_config=quant_config,
             prefix=f"{prefix}.altup",
         )
         self.self_attn = Gemma3nAttention(
@@ -527,7 +533,7 @@ class Gemma3nDecoderLayer(nn.Module):
 @support_torch_compile
-class Gemma3nTextModel(nn.Module):
+class Gemma3nTextModel(nn.Module, SupportsQuant):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -540,6 +546,7 @@ class Gemma3nTextModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            quant_config=quant_config,
             prefix=f"{prefix}.embed_tokens",
         )
         self.embed_scale = torch.tensor(
@@ -549,6 +556,7 @@ class Gemma3nTextModel(nn.Module):
         self.embed_tokens_per_layer = VocabParallelEmbedding(
             config.vocab_size_per_layer_input,
             config.num_hidden_layers * config.hidden_size_per_layer_input,
+            quant_config=quant_config,
             prefix=f"{prefix}.per_layer_embed_tokens",
         )
         self.embed_scale_per_layer = torch.tensor(
@@ -582,7 +590,7 @@ class Gemma3nTextModel(nn.Module):
                 gather_output=True,
                 return_bias=False,
                 quant_config=quant_config,
-                prefix=f"{prefix}.{idx-1}.altup_projections",
+                prefix=f"{prefix}.altup_projections.{idx-1}",
             ) for idx in range(1, self.config.altup_num_inputs)
         ])
         self.altup_unembed_projections = nn.ModuleList([
@@ -593,7 +601,7 @@ class Gemma3nTextModel(nn.Module):
                 gather_output=True,
                 return_bias=False,
                 quant_config=quant_config,
-                prefix=f"{prefix}.{idx-1}.altup_unembed_projections",
+                prefix=f"{prefix}.altup_unembed_projections.{idx-1}",
             ) for idx in range(1, self.config.altup_num_inputs)
         ])
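The prefix reorderings above matter for quantization because per-layer schemes are resolved against fully qualified module names, so the runtime prefix has to mirror the checkpoint layout (altup_projections.{idx}, not {idx}.altup_projections). A toy illustration of name-based scheme matching (the rule format here is made up, not a specific quantization config schema):

# Sketch of why module prefixes must mirror checkpoint names. The regexes and
# rule table below are illustrative only.
import re

# A checkpoint-style per-layer rule: quantize altup projections, skip routers.
rules = {
    r"\.altup_projections\.\d+$": "int8",
    r"\.modality_router$": None,  # leave unquantized
}


def scheme_for(module_name: str):
    for pattern, scheme in rules.items():
        if re.search(pattern, module_name):
            return scheme
    return None


# With the corrected prefix the rule matches; with the old ordering it did not.
print(scheme_for("model.altup_projections.0"))  # int8
print(scheme_for("model.0.altup_projections"))  # None (old, wrong prefix)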
@@ -774,7 +782,7 @@ class Gemma3nModel(nn.Module):
                                     **kwargs)
 
 
-class Gemma3nForConditionalGeneration(nn.Module):
+class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",