[KV sharing] Re-land Gemma3n model changes from #22628 (#24357)

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Authored by Yong Hoon Shin on 2025-09-23 19:25:34 -07:00; committed by GitHub
parent 359d293006
commit 77d906995c
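For context: the fast-prefill path modified below is gated by the --kv-sharing-fast-prefill flag referenced in the code comments. A minimal, hypothetical usage sketch (not part of this commit; the keyword argument is assumed to mirror the CLI flag, and the checkpoint name is only illustrative):

from vllm import LLM

# Assumption: kv_sharing_fast_prefill is exposed as an engine argument that
# mirrors the --kv-sharing-fast-prefill CLI flag and sets
# CacheConfig.kv_sharing_fast_prefill.
llm = LLM(
    model="google/gemma-3n-E2B-it",  # illustrative Gemma3n checkpoint
    kv_sharing_fast_prefill=True,
)
outputs = llm.generate("The capital of France is")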


@@ -26,6 +26,7 @@ from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
GeluAndMul,
@@ -44,6 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata
from .interfaces import SupportsQuant
from .utils import (AutoWeightsLoader, extract_layer_index,
@@ -51,6 +53,8 @@ from .utils import (AutoWeightsLoader, extract_layer_index,
logger = init_logger(__name__)
EPS = torch.tensor(torch.finfo().min)
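# Module-level constant used as the second argument to torch.maximum() when
# rescaling AltUp hidden-state magnitudes; it replaces the per-instance
# self.eps removed further down.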
class Gemma3nAltUp(nn.Module):
"""Alternating updates (Altup)
@@ -532,16 +536,29 @@ class Gemma3nDecoderLayer(nn.Module):
return corrected_predictions
@support_torch_compile
class Gemma3nTextModel(nn.Module, SupportsQuant):
# This enables torch.compile if --kv-sharing-fast-prefill is passed
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
kv_sharing_fast_prefill)
class Gemma3nSelfDecoder(nn.Module):
"""
Includes altup embedding and self decoder layers
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
decoder_layers: list[Gemma3nDecoderLayer],
layer_idx_start: int,
):
super().__init__()
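# decoder_layers is a slice of the text model's layer list; layer_idx_start
# maps local indices back to global layer indices when indexing the
# per-layer inputs.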
self.decoder_layers = decoder_layers
self.layer_idx_start = layer_idx_start
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
@@ -594,32 +611,6 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
prefix=f"{prefix}.altup_projections.{idx-1}",
) for idx in range(1, self.config.altup_num_inputs)
])
self.altup_unembed_projections = nn.ModuleList([
ColumnParallelLinear(
config.hidden_size,
config.hidden_size,
bias=False,
gather_output=True,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.altup_unembed_projections.{idx-1}",
) for idx in range(1, self.config.altup_num_inputs)
])
# Transformer blocks.
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Gemma3nDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(
config.hidden_size,
eps=config.rms_norm_eps,
)
self.eps = torch.tensor(torch.finfo().min)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) * self.embed_scale
def get_per_layer_input_embeddings(
self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -633,20 +624,11 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
return self.embed_tokens_per_layer(
per_layer_inputs_tokens) * self.embed_scale_per_layer
def forward(
def get_per_layer_inputs(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
per_layer_inputs: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
if inputs_embeds is not None:
hidden_states_0 = inputs_embeds
else:
hidden_states_0 = self.get_input_embeddings(input_ids)
hidden_states_0: torch.Tensor,
per_layer_inputs: Optional[torch.Tensor],
) -> torch.Tensor:
per_layer_projection = self.per_layer_model_projection(hidden_states_0)
per_layer_projection = per_layer_projection.reshape(
*hidden_states_0.shape[:-1],
@@ -655,14 +637,18 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
)
per_layer_projection = self.per_layer_projection_norm(
per_layer_projection)
if per_layer_inputs is not None:
# Profiling run does not compute per_layer_inputs
per_layer_inputs = per_layer_projection + per_layer_inputs
per_layer_inputs *= self.per_layer_input_scale
else:
per_layer_inputs = per_layer_projection
return per_layer_inputs
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) * self.embed_scale
def altup_embed(self, hidden_states_0: torch.Tensor) -> torch.Tensor:
# Altup embed.
hidden_states = [hidden_states_0] * self.config.altup_num_inputs
target_magnitude = torch.mean(hidden_states_0**2, dim=-1,
@@ -673,11 +659,77 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
dim=-1,
keepdim=True)**0.5
hidden_states[i] *= target_magnitude / torch.maximum(
new_magnitude, self.eps)
hidden_states = torch.stack(hidden_states, dim=0)
new_magnitude, EPS)
hidden_states = torch.stack(hidden_states, dim=-1)
return hidden_states
# Transformer blocks.
for layer_idx, layer in enumerate(self.layers):
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
per_layer_inputs: Optional[torch.Tensor] = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
if inputs_embeds is not None:
hidden_states_0 = inputs_embeds
else:
hidden_states_0 = self.get_input_embeddings(input_ids)
adjusted_per_layer_inputs = self.get_per_layer_inputs(
hidden_states_0, per_layer_inputs)
hidden_states = self.altup_embed(hidden_states_0)
# [altup_num_inputs, num_tokens, hidden_size]
hidden_states = hidden_states.permute(2, 0, 1)
for idx, layer in enumerate(self.decoder_layers):
layer_idx = idx + self.layer_idx_start
# [altup_num_inputs, num_tokens, hidden_size]
hidden_states = layer(
positions=positions,
hidden_states=hidden_states,
per_layer_input=adjusted_per_layer_inputs[:, layer_idx, :],
**kwargs,
)
# [num_tokens, hidden_size, altup_num_inputs]
hidden_states = hidden_states.permute(1, 2, 0)
return hidden_states, adjusted_per_layer_inputs
# This enables torch.compile if --kv-sharing-fast-prefill is passed
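# Compiling the cross-decoder as its own unit lets the fast-prefill path run
# it on just the padded logits positions, independently of the self-decoder
# graph.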
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
kv_sharing_fast_prefill)
class Gemma3nCrossDecoder(nn.Module):
"""
Cross-decoder layers
"""
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
decoder_layers: list[Gemma3nDecoderLayer],
layer_idx_start: int,
):
super().__init__()
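# Holds only the trailing KV-sharing layers; embeddings and per-layer input
# projections remain in the self-decoder.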
self.decoder_layers = decoder_layers
self.layer_idx_start = layer_idx_start
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
per_layer_inputs: torch.Tensor,
**kwargs,
) -> torch.Tensor:
# [altup_num_inputs, num_tokens, hidden_size]
hidden_states = hidden_states.permute(2, 0, 1)
for idx, layer in enumerate(self.decoder_layers):
layer_idx = idx + self.layer_idx_start
# [altup_num_inputs, num_tokens, hidden_size]
hidden_states = layer(
positions=positions,
@@ -685,22 +737,249 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
per_layer_input=per_layer_inputs[:, layer_idx, :],
**kwargs,
)
# [num_tokens, hidden_size, altup_num_inputs]
hidden_states = hidden_states.permute(1, 2, 0)
return hidden_states
# This disables torch.compile if --kv-sharing-fast-prefill is passed
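# (With fast prefill enabled, the self- and cross-decoder submodules are
# compiled separately under their own model tags instead.)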
@support_torch_compile(enable_if=lambda vllm_config: not vllm_config.
cache_config.kv_sharing_fast_prefill)
class Gemma3nTextModel(nn.Module, SupportsQuant):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.altup_unembed_projections = nn.ModuleList([
ColumnParallelLinear(
config.hidden_size,
config.hidden_size,
bias=False,
gather_output=True,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.altup_unembed_projections.{idx-1}",
) for idx in range(1, self.config.altup_num_inputs)
])
# All config.num_hidden_layers decoder layers are created here; the last
# config.num_kv_shared_layers of them are the KV-sharing layers handed to
# the cross-decoder
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Gemma3nDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
first_kv_shared_layer_idx = (config.num_hidden_layers -
config.num_kv_shared_layers)
# NOTE(sarckk): importing this at top level seems to cause issues
# when running tests.
from vllm.compilation.backends import set_model_tag
# Layer idx 0-19 are self-decoder layers in You Only Cache Once (YOCO)
with set_model_tag("self_decoder"):
self.self_decoder = Gemma3nSelfDecoder(
vllm_config=vllm_config,
prefix=f"{prefix}.self_decoder",
decoder_layers=self.layers[:first_kv_shared_layer_idx],
layer_idx_start=0,
)
# Layer idx 20-30 are cross-decoder layers in YOCO
with set_model_tag("cross_decoder"):
self.cross_decoder = Gemma3nCrossDecoder(
vllm_config=vllm_config,
prefix=f"{prefix}.cross_decoder",
decoder_layers=self.layers[first_kv_shared_layer_idx:],
layer_idx_start=first_kv_shared_layer_idx,
)
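# Both submodules reuse the Gemma3nDecoderLayer modules created above: the
# self-decoder gets the first num_hidden_layers - num_kv_shared_layers
# layers, the cross-decoder the rest.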
self.norm = RMSNorm(
config.hidden_size,
eps=config.rms_norm_eps,
)
self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill
if self.fast_prefill_enabled:
# Allocate static buffers for CUDAGraph
# TODO(sarckk): Extract this functionality to interface
max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
device = next(self.parameters()).device
self.positions = torch.zeros(max_num_tokens,
dtype=torch.int64,
device=device)
self.hidden_states = torch.zeros(
(max_num_tokens, config.hidden_size,
self.config.altup_num_inputs),
dtype=self.embed_tokens.weight.dtype,
device=device,
)
self.per_layer_inputs = torch.zeros(
(max_num_tokens, self.config.num_hidden_layers,
self.config.hidden_size_per_layer_input),
dtype=self.embed_tokens.weight.dtype,
device=device,
)
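# fast_prefill_forward() copies into and slices these persistent buffers so
# that CUDA graph replays see inputs at stable device addresses.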
@property
def embed_tokens(self):
return self.self_decoder.embed_tokens
def get_per_layer_input_embeddings(
self, input_ids: torch.Tensor) -> torch.Tensor:
return self.self_decoder.get_per_layer_input_embeddings(input_ids)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.self_decoder.get_input_embeddings(input_ids)
def fast_prefill_forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
per_layer_inputs: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
logits_indices_padded, num_logits_indices = None, None
attn_metadata = get_forward_context().attn_metadata
# attn_metadata is None during dummy runs
if (self.fast_prefill_enabled and attn_metadata is not None):
assert isinstance(attn_metadata, dict)
# Last layer is a KV sharing layer
layer_attn_metadata = attn_metadata[
self.layers[-1].self_attn.attn.layer_name]
if (isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata)):
logits_indices_padded = (
layer_attn_metadata.logits_indices_padded)
num_logits_indices = layer_attn_metadata.num_logits_indices
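# logits_indices_padded selects only the token positions whose logits are
# actually needed, so the cross-decoder below can skip the remaining prefill
# tokens; the padding keeps shapes CUDA-graph friendly.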
# Copy inputs for cudagraph
batch_size = positions.size(0)
self.positions[:batch_size].copy_(positions)
self_decoder_hidden_states, per_layer_inputs_adjusted = \
self.self_decoder(
input_ids=input_ids,
positions=self.positions[:batch_size],
inputs_embeds=inputs_embeds,
per_layer_inputs=per_layer_inputs,
**kwargs,
)
if logits_indices_padded is None:
logits_indices_padded = torch.arange(
positions.size(0),
dtype=positions.dtype,
device=positions.device,
)
# NOTE(sarckk): There is currently a bug where vLLM converts the output
# of the last piecewise CUDA graph to a weakref, causing memory to be
# freed prematurely when there are multiple compilation units.
# Keep .clone() until the fix in
# https://github.com/vllm-project/vllm/pull/22282 lands.
hidden_states = self_decoder_hidden_states.clone()
# Copy inputs for cudagraph
num_padded_logits_indices = logits_indices_padded.size(0)
self.positions[:num_padded_logits_indices].copy_(
positions[logits_indices_padded])
self.hidden_states[:num_padded_logits_indices].copy_(
self_decoder_hidden_states[logits_indices_padded])
self.per_layer_inputs[:num_padded_logits_indices].copy_(
per_layer_inputs_adjusted[logits_indices_padded])
cross_decoder_hidden_states = self.cross_decoder(
positions=self.positions[:num_padded_logits_indices],
hidden_states=self.hidden_states[:num_padded_logits_indices],
per_layer_inputs=self.per_layer_inputs[:num_padded_logits_indices],
**kwargs,
)
if num_logits_indices is not None:
assert num_logits_indices > 0
# Merge cross-decoder and self-decoder hidden states
hidden_states[logits_indices_padded[:num_logits_indices]] = (
cross_decoder_hidden_states[:num_logits_indices])
else:
hidden_states = cross_decoder_hidden_states
return hidden_states
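# Default path when --kv-sharing-fast-prefill is not enabled: every token is
# run through both the self-decoder and the cross-decoder.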
def normal_forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
per_layer_inputs: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
hidden_states, per_layer_inputs = self.self_decoder(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
per_layer_inputs=per_layer_inputs,
**kwargs,
)
hidden_states = self.cross_decoder(
positions=positions,
hidden_states=hidden_states,
per_layer_inputs=per_layer_inputs,
**kwargs,
)
return hidden_states
def altup_unembed(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
# Altup unembed.
target_magnitude = torch.mean(hidden_states[0]**2,
target_magnitude = torch.mean(hidden_states[..., 0]**2,
dim=-1,
keepdim=True)**0.5
for i in range(1, self.config.altup_num_inputs):
hidden_states[i] = self.altup_unembed_projections[i - 1](
hidden_states[i])
new_magnitude = torch.mean(hidden_states[i]**2,
hidden_states[..., i] = self.altup_unembed_projections[i - 1](
hidden_states[..., i])
new_magnitude = torch.mean(hidden_states[..., i]**2,
dim=-1,
keepdim=True)**0.5
hidden_states[i] *= target_magnitude / torch.maximum(
new_magnitude, self.eps)
# [altup_num_inputs,num_tokens,hidden_size] -> [num_tokens,hidden_size]
hidden_states = torch.mean(hidden_states, dim=0)
hidden_states[..., i] *= target_magnitude / torch.maximum(
new_magnitude, EPS)
# [num_tokens, hidden_size, altup_num_inputs] -> [num_tokens, hidden_size]
hidden_states = torch.mean(hidden_states, dim=-1)
return hidden_states
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
per_layer_inputs: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
if self.fast_prefill_enabled:
hidden_states = self.fast_prefill_forward(
input_ids,
positions,
inputs_embeds,
per_layer_inputs,
**kwargs,
)
else:
hidden_states = self.normal_forward(
input_ids,
positions,
inputs_embeds,
per_layer_inputs,
**kwargs,
)
hidden_states = self.altup_unembed(hidden_states)
return self.norm(hidden_states)
def load_weights(self, weights: Iterable[tuple[str,
@@ -716,6 +995,13 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
# Decoder layer weights, altup_unembed_projections and the final norm are
# initialized in the text model; all other weights live in the self-decoder
if (not name.startswith('layers')
and not name.startswith('altup_unembed_projections')
and not name.startswith('norm')):
name = f"self_decoder.{name}"
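# e.g. "embed_tokens.weight" is loaded as "self_decoder.embed_tokens.weight"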
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for compressed-tensors quantization