mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
This commit is contained in:
@ -26,6 +26,7 @@ from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
|
||||
GeluAndMul,
|
||||
@ -44,6 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata
|
||||
|
||||
from .interfaces import SupportsQuant
|
||||
from .utils import (AutoWeightsLoader, extract_layer_index,
|
||||
@ -51,6 +53,8 @@ from .utils import (AutoWeightsLoader, extract_layer_index,
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
EPS = torch.tensor(torch.finfo().min)
|
||||
|
||||
|
||||
class Gemma3nAltUp(nn.Module):
|
||||
"""Alternating updates (Altup)
|
||||
@ -532,16 +536,29 @@ class Gemma3nDecoderLayer(nn.Module):
|
||||
return corrected_predictions
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||
# This enables torch.compile if --kv-sharing-fast-prefill passed
|
||||
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
|
||||
kv_sharing_fast_prefill)
|
||||
class Gemma3nSelfDecoder(nn.Module):
|
||||
"""
|
||||
Includes altup embedding and self decoder layers
|
||||
"""
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
decoder_layers: list[Gemma3nDecoderLayer],
|
||||
layer_idx_start: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.decoder_layers = decoder_layers
|
||||
self.layer_idx_start = layer_idx_start
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
@ -594,32 +611,6 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||
prefix=f"{prefix}.altup_projections.{idx-1}",
|
||||
) for idx in range(1, self.config.altup_num_inputs)
|
||||
])
|
||||
self.altup_unembed_projections = nn.ModuleList([
|
||||
ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
gather_output=True,
|
||||
return_bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.altup_unembed_projections.{idx-1}",
|
||||
) for idx in range(1, self.config.altup_num_inputs)
|
||||
])
|
||||
|
||||
# Transformer blocks.
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: Gemma3nDecoderLayer(
|
||||
config, cache_config, quant_config, prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
self.norm = RMSNorm(
|
||||
config.hidden_size,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
self.eps = torch.tensor(torch.finfo().min)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids) * self.embed_scale
|
||||
|
||||
def get_per_layer_input_embeddings(
|
||||
self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
@ -633,20 +624,11 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||
return self.embed_tokens_per_layer(
|
||||
per_layer_inputs_tokens) * self.embed_scale_per_layer
|
||||
|
||||
def forward(
|
||||
def get_per_layer_inputs(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
per_layer_inputs: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states_0 = inputs_embeds
|
||||
else:
|
||||
hidden_states_0 = self.get_input_embeddings(input_ids)
|
||||
|
||||
hidden_states_0: torch.Tensor,
|
||||
per_layer_inputs: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
per_layer_projection = self.per_layer_model_projection(hidden_states_0)
|
||||
per_layer_projection = per_layer_projection.reshape(
|
||||
*hidden_states_0.shape[:-1],
|
||||
@ -655,14 +637,18 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||
)
|
||||
per_layer_projection = self.per_layer_projection_norm(
|
||||
per_layer_projection)
|
||||
|
||||
if per_layer_inputs is not None:
|
||||
# Profiling run does not compute per_layer_inputs
|
||||
per_layer_inputs = per_layer_projection + per_layer_inputs
|
||||
per_layer_inputs *= self.per_layer_input_scale
|
||||
else:
|
||||
per_layer_inputs = per_layer_projection
|
||||
return per_layer_inputs
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids) * self.embed_scale
|
||||
|
||||
def altup_embed(self, hidden_states_0: torch.Tensor) -> torch.Tensor:
|
||||
# Altup embed.
|
||||
hidden_states = [hidden_states_0] * self.config.altup_num_inputs
|
||||
target_magnitude = torch.mean(hidden_states_0**2, dim=-1,
|
||||
@ -673,11 +659,77 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||
dim=-1,
|
||||
keepdim=True)**0.5
|
||||
hidden_states[i] *= target_magnitude / torch.maximum(
|
||||
new_magnitude, self.eps)
|
||||
hidden_states = torch.stack(hidden_states, dim=0)
|
||||
new_magnitude, EPS)
|
||||
hidden_states = torch.stack(hidden_states, dim=-1)
|
||||
return hidden_states
|
||||
|
||||
# Transformer blocks.
|
||||
for layer_idx, layer in enumerate(self.layers):
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
per_layer_inputs: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states_0 = inputs_embeds
|
||||
else:
|
||||
hidden_states_0 = self.get_input_embeddings(input_ids)
|
||||
|
||||
adjusted_per_layer_inputs = self.get_per_layer_inputs(
|
||||
hidden_states_0, per_layer_inputs)
|
||||
hidden_states = self.altup_embed(hidden_states_0)
|
||||
|
||||
# [altnum_inputs, num_tokens, hidden_size]
|
||||
hidden_states = hidden_states.permute(2, 0, 1)
|
||||
|
||||
for idx, layer in enumerate(self.decoder_layers):
|
||||
layer_idx = idx + self.layer_idx_start
|
||||
# [altup_num_inputs, num_tokens, hidden_size]
|
||||
hidden_states = layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
per_layer_input=adjusted_per_layer_inputs[:, layer_idx, :],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# [num_tokens, hidden_size, altnum_inputs]
|
||||
hidden_states = hidden_states.permute(1, 2, 0)
|
||||
|
||||
return hidden_states, adjusted_per_layer_inputs
|
||||
|
||||
|
||||
# This enables torch.compile if --kv-sharing-fast-prefill passed
|
||||
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
|
||||
kv_sharing_fast_prefill)
|
||||
class Gemma3nCrossDecoder(nn.Module):
|
||||
"""
|
||||
Cross-decoder layers
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
decoder_layers: list[Gemma3nDecoderLayer],
|
||||
layer_idx_start: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.decoder_layers = decoder_layers
|
||||
self.layer_idx_start = layer_idx_start
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
per_layer_inputs: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
# [altnum_inputs, num_tokens, hidden_size]
|
||||
hidden_states = hidden_states.permute(2, 0, 1)
|
||||
for idx, layer in enumerate(self.decoder_layers):
|
||||
layer_idx = idx + self.layer_idx_start
|
||||
# [altup_num_inputs, num_tokens, hidden_size]
|
||||
hidden_states = layer(
|
||||
positions=positions,
|
||||
@ -685,22 +737,249 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||
per_layer_input=per_layer_inputs[:, layer_idx, :],
|
||||
**kwargs,
|
||||
)
|
||||
# [num_tokens, hidden_size, altnum_inputs]
|
||||
hidden_states = hidden_states.permute(1, 2, 0)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# This disables torch.compile if --kv-sharing-fast-prefill passed
|
||||
@support_torch_compile(enable_if=lambda vllm_config: not vllm_config.
|
||||
cache_config.kv_sharing_fast_prefill)
|
||||
class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.altup_unembed_projections = nn.ModuleList([
|
||||
ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
gather_output=True,
|
||||
return_bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.altup_unembed_projections.{idx-1}",
|
||||
) for idx in range(1, self.config.altup_num_inputs)
|
||||
])
|
||||
|
||||
# Allocate config.num_kv_shared_layers layers for self-decoder
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: Gemma3nDecoderLayer(
|
||||
config, cache_config, quant_config, prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
|
||||
first_kv_shared_layer_idx = (config.num_hidden_layers -
|
||||
config.num_kv_shared_layers)
|
||||
|
||||
# NOTE(sarckk): importing this top level seems to cause issues
|
||||
# during running of tests.
|
||||
from vllm.compilation.backends import set_model_tag
|
||||
|
||||
# Layer idx 0-19 are self-decoder layers in You Only Cache Once (YOCO)
|
||||
with set_model_tag("self_decoder"):
|
||||
self.self_decoder = Gemma3nSelfDecoder(
|
||||
vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.self_decoder",
|
||||
decoder_layers=self.layers[:first_kv_shared_layer_idx],
|
||||
layer_idx_start=0,
|
||||
)
|
||||
# Layer idx 20-30 are cross-decoder layers in YOCO
|
||||
with set_model_tag("cross_decoder"):
|
||||
self.cross_decoder = Gemma3nCrossDecoder(
|
||||
vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.cross_decoder",
|
||||
decoder_layers=self.layers[first_kv_shared_layer_idx:],
|
||||
layer_idx_start=first_kv_shared_layer_idx,
|
||||
)
|
||||
|
||||
self.norm = RMSNorm(
|
||||
config.hidden_size,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
|
||||
self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill
|
||||
|
||||
if self.fast_prefill_enabled:
|
||||
# Allocate static buffers for CUDAGraph
|
||||
# TODO(sarckk): Extract this functionality to interface
|
||||
max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||
device = next(self.parameters()).device
|
||||
self.positions = torch.zeros(max_num_tokens,
|
||||
dtype=torch.int64,
|
||||
device=device)
|
||||
self.hidden_states = torch.zeros(
|
||||
(max_num_tokens, config.hidden_size,
|
||||
self.config.altup_num_inputs),
|
||||
dtype=self.embed_tokens.weight.dtype,
|
||||
device=device,
|
||||
)
|
||||
self.per_layer_inputs = torch.zeros(
|
||||
(max_num_tokens, self.config.num_hidden_layers,
|
||||
self.config.hidden_size_per_layer_input),
|
||||
dtype=self.embed_tokens.weight.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
@property
|
||||
def embed_tokens(self):
|
||||
return self.self_decoder.embed_tokens
|
||||
|
||||
def get_per_layer_input_embeddings(
|
||||
self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.self_decoder.get_per_layer_input_embeddings(input_ids)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.self_decoder.get_input_embeddings(input_ids)
|
||||
|
||||
def fast_prefill_forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
per_layer_inputs: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
logits_indices_padded, num_logits_indices = None, None
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
|
||||
# attn_metadata is None during dummy runs
|
||||
if (self.fast_prefill_enabled and attn_metadata is not None):
|
||||
assert isinstance(attn_metadata, dict)
|
||||
# Last layer is a KV sharing layer
|
||||
layer_attn_metadata = attn_metadata[
|
||||
self.layers[-1].self_attn.attn.layer_name]
|
||||
if (isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata)):
|
||||
logits_indices_padded = (
|
||||
layer_attn_metadata.logits_indices_padded)
|
||||
num_logits_indices = layer_attn_metadata.num_logits_indices
|
||||
|
||||
# Copy inputs for cudagraph
|
||||
batch_size = positions.size(0)
|
||||
self.positions[:batch_size].copy_(positions)
|
||||
self_decoder_hidden_states, per_layer_inputs_adjusted = \
|
||||
self.self_decoder(
|
||||
input_ids=input_ids,
|
||||
positions=self.positions[:batch_size],
|
||||
inputs_embeds=inputs_embeds,
|
||||
per_layer_inputs=per_layer_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if logits_indices_padded is None:
|
||||
logits_indices_padded = torch.arange(
|
||||
positions.size(0),
|
||||
dtype=positions.dtype,
|
||||
device=positions.device,
|
||||
)
|
||||
|
||||
# NOTE(sarckk): There is currently a bug caused by
|
||||
# vLLM converting output of last piecewise CUDA graph
|
||||
# to weakref, causing memory to be prematurely freed
|
||||
# when there are multiple compilation units
|
||||
# Keep .clone() until fix in
|
||||
# https://github.com/vllm-project/vllm/pull/22282
|
||||
hidden_states = self_decoder_hidden_states.clone()
|
||||
|
||||
# Copy inputs for cudagraph
|
||||
num_padded_logits_indices = logits_indices_padded.size(0)
|
||||
self.positions[:num_padded_logits_indices].copy_(
|
||||
positions[logits_indices_padded])
|
||||
self.hidden_states[:num_padded_logits_indices].copy_(
|
||||
self_decoder_hidden_states[logits_indices_padded])
|
||||
self.per_layer_inputs[:num_padded_logits_indices].copy_(
|
||||
per_layer_inputs_adjusted[logits_indices_padded])
|
||||
cross_decoder_hidden_states = self.cross_decoder(
|
||||
positions=self.positions[:num_padded_logits_indices],
|
||||
hidden_states=self.hidden_states[:num_padded_logits_indices],
|
||||
per_layer_inputs=self.per_layer_inputs[:num_padded_logits_indices],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if num_logits_indices is not None:
|
||||
assert num_logits_indices > 0
|
||||
# Merge cross-decoder and self-decoder hidden states
|
||||
hidden_states[logits_indices_padded[:num_logits_indices]] = (
|
||||
cross_decoder_hidden_states[:num_logits_indices])
|
||||
else:
|
||||
hidden_states = cross_decoder_hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
def normal_forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
per_layer_inputs: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
hidden_states, per_layer_inputs = self.self_decoder(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
per_layer_inputs=per_layer_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = self.cross_decoder(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
per_layer_inputs=per_layer_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def altup_unembed(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Altup unembed.
|
||||
target_magnitude = torch.mean(hidden_states[0]**2,
|
||||
target_magnitude = torch.mean(hidden_states[..., 0]**2,
|
||||
dim=-1,
|
||||
keepdim=True)**0.5
|
||||
for i in range(1, self.config.altup_num_inputs):
|
||||
hidden_states[i] = self.altup_unembed_projections[i - 1](
|
||||
hidden_states[i])
|
||||
new_magnitude = torch.mean(hidden_states[i]**2,
|
||||
hidden_states[..., i] = self.altup_unembed_projections[i - 1](
|
||||
hidden_states[..., i])
|
||||
new_magnitude = torch.mean(hidden_states[..., i]**2,
|
||||
dim=-1,
|
||||
keepdim=True)**0.5
|
||||
hidden_states[i] *= target_magnitude / torch.maximum(
|
||||
new_magnitude, self.eps)
|
||||
# [altup_num_inputs,num_tokens,hidden_size] -> [num_tokens,hidden_size]
|
||||
hidden_states = torch.mean(hidden_states, dim=0)
|
||||
hidden_states[..., i] *= target_magnitude / torch.maximum(
|
||||
new_magnitude, EPS)
|
||||
# [num_tokens,hidden_size, altup_num_inputs] -> [num_tokens,hidden_size]
|
||||
hidden_states = torch.mean(hidden_states, dim=-1)
|
||||
return hidden_states
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
per_layer_inputs: Optional[torch.Tensor] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if self.fast_prefill_enabled:
|
||||
hidden_states = self.fast_prefill_forward(
|
||||
input_ids,
|
||||
positions,
|
||||
inputs_embeds,
|
||||
per_layer_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = self.normal_forward(
|
||||
input_ids,
|
||||
positions,
|
||||
inputs_embeds,
|
||||
per_layer_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = self.altup_unembed(hidden_states)
|
||||
return self.norm(hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
@ -716,6 +995,13 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
# decoder layer weights, altup_unembed_projections and rmsnorm
|
||||
# are initialized in text model, others are in self decoder
|
||||
if (not name.startswith('layers')
|
||||
and not name.startswith('altup_unembed_projections')
|
||||
and not name.startswith('norm')):
|
||||
name = f"self_decoder.{name}"
|
||||
|
||||
if (self.quant_config is not None and
|
||||
(scale_name := self.quant_config.get_cache_scale(name))):
|
||||
# Loading kv cache scales for compressed-tensors quantization
|
||||
|
Reference in New Issue
Block a user