632 lines
23 KiB
Python
632 lines
23 KiB
Python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from https://github.com/vllm-project/vllm/blob/94d8ec8d2bcb4ec55e33022b313c7e978edf05e1/vllm/model_executor/models/bamba.py
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
"""Inference-only NemotronH model."""
|
|
from collections.abc import Iterable
|
|
from typing import Optional
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from vllm import envs
|
|
from vllm.attention.layer import Attention
|
|
from vllm.compilation.decorators import support_torch_compile
|
|
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
|
from vllm.distributed.parallel_state import get_pp_group
|
|
from vllm.forward_context import get_forward_context
|
|
from vllm.model_executor.layers.activation import ReLUSquaredActivation
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|
QKVParallelLinear,
|
|
RowParallelLinear)
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
|
Mamba2Metadata, prepare_mamba2_metadata)
|
|
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
|
|
from vllm.model_executor.layers.mamba.mamba_utils import (
|
|
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
|
|
SupportsLoRA, SupportsPP,
|
|
SupportsQuant)
|
|
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
|
MambaCacheParams)
|
|
from vllm.model_executor.models.utils import (
|
|
AutoWeightsLoader, make_empty_intermediate_tensors_factory, make_layers,
|
|
maybe_prefix)
|
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|
from vllm.sequence import IntermediateTensors
|
|
from vllm.transformers_utils.configs import NemotronHConfig
|
|
from vllm.utils import LayerBlockType
|
|
|
|
|
|
class NemotronHMLP(nn.Module):
    """Feed-forward mixer: up_proj -> ReLUSquaredActivation -> down_proj.

    ``up_proj`` is column-parallel and ``down_proj`` row-parallel, so the
    intermediate activation stays sharded across tensor-parallel ranks.
    """

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Ordinal of this layer among the "-" (MLP) entries of the pattern,
        # counting the layer itself; used to pick a per-MLP width.
        mlp_ordinal = config.hybrid_override_pattern[:layer_idx + 1].count(
            "-") - 1

        # ``intermediate_size`` may be a scalar, a single-element list shared
        # by every MLP, or a list with one width per MLP layer.
        sizes = config.intermediate_size
        if not isinstance(sizes, list):
            width = sizes
        elif len(sizes) == 1:
            width = sizes[0]
        else:
            width = sizes[mlp_ordinal]

        self.up_proj = ColumnParallelLinear(
            input_size=config.hidden_size,
            output_size=width,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=width,
            output_size=config.hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = ReLUSquaredActivation()

    def forward(self, x: torch.Tensor):
        """Project up, apply the activation, project back to hidden size."""
        projected, _ = self.up_proj(x)
        activated = self.act_fn(projected)
        result, _ = self.down_proj(activated)
        return result
|
|
|
|
|
|
class NemotronHMLPDecoderLayer(nn.Module):
    """Pre-norm decoder layer wrapping a ``NemotronHMLP`` mixer.

    ``model_config`` and ``cache_config`` are accepted so every decoder
    layer type shares one constructor signature; they are not used here.
    """

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        self.mixer = NemotronHMLP(
            config,
            layer_idx=layer_idx,
            quant_config=quant_config,
            bias=config.mlp_bias,
            prefix=f"{prefix}.mixer",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        **kwargs,
    ):
        """Normalize (folding in ``residual`` when present), then run the MLP.

        Returns ``(mixer_output, residual)``; extra keyword arguments passed
        by the model loop are ignored.
        """
        if residual is not None:
            # Two-argument RMSNorm returns (normalized, updated_residual).
            hidden_states, residual = self.norm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)

        return self.mixer(hidden_states), residual
|
|
|
|
|
|
class NemotronHMambaDecoderLayer(nn.Module):
    """Pre-norm decoder layer whose mixer is a ``MambaMixer2`` block."""

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        n_heads = config.mamba_num_heads
        per_head = config.mamba_head_dim
        self.mixer = MambaMixer2(
            hidden_size=config.hidden_size,
            ssm_state_size=config.ssm_state_size,
            conv_kernel_size=config.conv_kernel,
            # Inner width of the mixer is heads x head_dim.
            intermediate_size=n_heads * per_head,
            use_conv_bias=config.use_conv_bias,
            use_bias=config.use_bias,
            n_groups=config.n_groups,
            num_heads=n_heads,
            head_dim=per_head,
            rms_norm_eps=config.rms_norm_eps,
            activation=config.mamba_hidden_act,
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.mixer",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        mamba_cache_params: MambaCacheParams,
        mamba2_metadata: Mamba2Metadata,
        **kwargs,
    ):
        """Normalize, run the Mamba2 mixer, return ``(output, residual)``."""
        if residual is not None:
            # Two-argument RMSNorm returns (normalized, updated_residual).
            hidden_states, residual = self.norm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)

        # The mixer is handed a preallocated buffer, which is then returned
        # as the layer output (the mixer fills it in place).
        mixer_out = torch.empty_like(hidden_states)
        self.mixer(hidden_states, mixer_out, mamba_cache_params,
                   mamba2_metadata)
        return mixer_out, residual
|
|
|
|
|
|
class NemotronHAttention(nn.Module):
    """Self-attention mixer with a fused QKV projection.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    partitioned when there are at least as many as ranks, otherwise
    replicated. ``layer_idx`` and ``model_config`` are accepted for a
    uniform constructor signature but are not used.
    """

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads < tp_size:
            # Fewer KV heads than ranks: each KV head is replicated over
            # several ranks, so the rank count must be a multiple.
            assert tp_size % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as ranks: partition them evenly.
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # An explicit, non-None ``head_dim`` wins; otherwise derive it.
        configured_head_dim = getattr(config, "head_dim", None)
        if configured_head_dim is None:
            self.head_dim = config.hidden_size // self.total_num_heads
        else:
            self.head_dim = configured_head_dim

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Project to Q/K/V, attend, and project back to hidden size."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        projected, _ = self.o_proj(self.attn(query, key, value))
        return projected
|
|
|
|
|
|
class NemotronHAttentionDecoderLayer(nn.Module):
    """Pre-norm decoder layer wrapping a ``NemotronHAttention`` mixer."""

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.mixer = NemotronHAttention(
            config,
            layer_idx,
            model_config,
            cache_config,
            quant_config,
            prefix=f"{prefix}.mixer",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        **kwargs,
    ):
        """Normalize, run attention, return ``(output, residual)``.

        ``positions`` is accepted but not forwarded to the mixer.
        """
        if residual is not None:
            # Two-argument RMSNorm returns (normalized, updated_residual).
            hidden_states, residual = self.norm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)

        attended = self.mixer(hidden_states=hidden_states)
        return attended, residual
|
|
|
|
|
|
# Maps each character of ``config.hybrid_override_pattern`` to the decoder
# layer class instantiated at that position.
ALL_DECODER_LAYER_TYPES = dict([
    ("M", NemotronHMambaDecoderLayer),  # Mamba2 mixer layer
    ("-", NemotronHMLPDecoderLayer),  # feed-forward layer
    ("*", NemotronHAttentionDecoderLayer),  # self-attention layer
])
|
|
|
|
|
|
@support_torch_compile
class NemotronHModel(nn.Module):
    """NemotronH backbone: token embedding, hybrid decoder stack, final norm.

    The layer at index ``i`` is built from the class that
    ``ALL_DECODER_LAYER_TYPES`` maps ``config.hybrid_override_pattern[i]``
    to. With pipeline parallelism only ``[start_layer, end_layer)`` are real
    modules on this rank (as returned by ``make_layers``).
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config: NemotronHConfig = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        # The embedding table reserves extra rows for LoRA-added tokens.
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        def get_layer(prefix: str):
            # ``make_layers`` passes prefixes ending in ".<layer index>".
            layer_idx = int(prefix.rsplit(".", 1)[1])
            layer_class = ALL_DECODER_LAYER_TYPES[
                config.hybrid_override_pattern[layer_idx]]
            return layer_class(
                config,
                layer_idx,
                model_config,
                cache_config,
                quant_config=quant_config,
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            len(config.hybrid_override_pattern),
            get_layer,
            prefix=f"{prefix}.layers")
        self.make_empty_intmd_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size)

        self.norm_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mamba_cache_params: MambaCacheParams,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run this rank's slice of the decoder stack.

        Returns the final-normed hidden states on the last pipeline rank;
        on other ranks returns an ``IntermediateTensors`` carrying
        ``hidden_states`` and ``residual`` for the next rank.
        """
        attn_metadata = get_forward_context().attn_metadata

        if not envs.VLLM_USE_V1:
            mamba2_metadata = prepare_mamba2_metadata(
                chunk_size=self.config.chunk_size,
                attn_metadata=attn_metadata,
            )
        else:
            # v1 get mamba2_metadata from forward_context
            mamba2_metadata = None

        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            # Keep the residual stream handed over by the previous rank;
            # resetting it here would drop it from the computation.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Only [start_layer, end_layer) are real layers on this rank; the
        # others are pipeline-parallel placeholders and must not be called.
        # The Mamba cache slot is the number of Mamba layers preceding this
        # one on this rank (identical to the previous ``i -
        # num_non_mamba_layers`` arithmetic when every layer is local).
        # NOTE(review): assumes the V0 cache is sized per PP rank via
        # get_num_layers_by_block_type(parallel_config, ...) -- confirm.
        mamba_layer_idx = 0
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            layer_mamba_cache_params = None
            if isinstance(layer, NemotronHMambaDecoderLayer):
                if mamba_cache_params is not None:
                    layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
                        mamba_layer_idx)
                mamba_layer_idx += 1

            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
                mamba_cache_params=layer_mamba_cache_params,
                mamba2_metadata=mamba2_metadata,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm_f(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Checkpoint names are adapted on the fly: ``embeddings`` becomes
        ``embed_tokens`` and ``A_log`` becomes ``A``. Separate ``q_proj`` /
        ``k_proj`` / ``v_proj`` tensors are routed into the fused
        ``qkv_proj`` parameter as shard ``"q"`` / ``"k"`` / ``"v"``.
        ``A_log``, ``D`` and ``dt_bias`` tensors are cast to float32 before
        loading. Tensors belonging to layers hosted on another
        pipeline-parallel rank are skipped.

        Returns:
            The set of (renamed) parameter names that were loaded.
        """
        # Function-scope import keeps this change self-contained; it comes
        # from the same module as the file's other model utils.
        from vllm.model_executor.models.utils import is_pp_missing_parameter

        attn_params_mapping = {
            "q_proj": "q",
            "k_proj": "k",
            "v_proj": "v",
        }

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "embeddings" in name:
                name = name.replace("embeddings", "embed_tokens")

            if "A_log" in name:
                name = name.replace("A_log", "A")
                loaded_weight = loaded_weight.to(torch.float32)

            if "D" in name or "dt_bias" in name:
                loaded_weight = loaded_weight.to(torch.float32)

            # Separate q/k/v checkpoint tensors map onto the fused qkv_proj.
            shard_source = next(
                (proj for proj in attn_params_mapping if proj in name), None)
            if shard_source is not None:
                name = name.replace(shard_source, "qkv_proj")

            # Parameters of layers owned by another PP rank do not exist in
            # this module; skip them instead of raising KeyError.
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if shard_source is not None:
                param.weight_loader(param, loaded_weight,
                                    attn_params_mapping[shard_source])
            else:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

            loaded_params.add(name)
        return loaded_params
|
|
|
|
|
|
class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                           IsHybrid, SupportsQuant):
    """NemotronH causal LM: ``NemotronHModel`` backbone plus LM head.

    Under the V0 engine (``envs.VLLM_USE_V1`` false) this class also owns
    the Mamba state cache (``MambaCacheManager``), created lazily on the
    first forward call.
    """
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        """Return the Mamba2 state-cache dtypes for ``vllm_config``.

        Delegates to ``MambaStateDtypeCalculator.mamba2_state_dtype`` with
        the model dtype and the configured Mamba cache dtypes.
        """
        return MambaStateDtypeCalculator.mamba2_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
        use_v1: bool = True,
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        """Calculate shapes for Mamba's convolutional and state caches.

        Args:
            vllm_config: vLLM config
            use_v1: Get shapes for V1 (or V0)

        Returns:
            Tuple containing:
            - conv_state_shape: Shape for convolutional state cache
            - temporal_state_shape: Shape for state space model cache
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_config
        intermediate_size = hf_config.mamba_num_heads * hf_config.mamba_head_dim

        return MambaStateShapeCalculator.mamba2_state_shape(
            intermediate_size=intermediate_size,
            tp_world_size=parallel_config.tensor_parallel_size,
            n_groups=hf_config.n_groups,
            num_heads=hf_config.mamba_num_heads,
            head_dim=hf_config.mamba_head_dim,
            state_size=hf_config.ssm_state_size,
            conv_kernel=hf_config.conv_kernel,
            use_v1=use_v1,
        )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        lora_config = vllm_config.lora_config
        scheduler_config = vllm_config.scheduler_config
        assert not cache_config.enable_prefix_caching, \
            "NemotronH currently does not support prefix caching"

        self.quant_config = vllm_config.quant_config

        super().__init__()
        self.config = config
        self.scheduler_config = scheduler_config
        self.model = NemotronHModel(vllm_config=vllm_config,
                                    prefix=maybe_prefix(prefix, "model"))
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        # Used to track and store by the Mamba cache between steps.
        self.mamba_cache: Optional[MambaCacheManager] = None

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)

        self.make_empty_intmd_tensors = (self.model.make_empty_intmd_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs):
        """Run the backbone; under V0, first resolve this step's Mamba cache.

        Extra ``kwargs`` are forwarded to
        ``MambaCacheManager.current_run_tensors`` (V0 only).
        """
        mamba_cache_params = None
        if not envs.VLLM_USE_V1:
            if self.mamba_cache is None:
                # Lazily size the V0 cache from the config on first use.
                num_mamba_layers = \
                    self.model_config.get_num_layers_by_block_type(
                        self.vllm_config.parallel_config,
                        LayerBlockType.mamba
                    )
                mamba_state_shape = \
                    self.get_mamba_state_shape_from_config(
                        self.vllm_config, use_v1=False)
                mamba_state_dtype = \
                    self.get_mamba_state_dtype_from_config(
                        self.vllm_config)
                self.mamba_cache = MambaCacheManager(self.vllm_config,
                                                     num_mamba_layers,
                                                     *mamba_state_shape,
                                                     *mamba_state_dtype)

            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

        hidden_states = self.model(input_ids, positions, mamba_cache_params,
                                   intermediate_tensors, inputs_embeds)

        return hidden_states

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        # V0 only: requires the cache created in forward().
        return self.mamba_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        # V0 only: requires the cache created in forward().
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load a checkpoint, mapping the ``backbone`` prefix to ``model``.

        Renaming is done lazily with a generator so checkpoint tensors are
        streamed through ``AutoWeightsLoader`` one at a time instead of all
        being held in a list before loading starts.

        Returns:
            The set of parameter names reported as loaded by the loader.
        """
        renamed_weights = ((name.replace("backbone", "model"), loaded_weight)
                           for name, loaded_weight in weights)
        loader = AutoWeightsLoader(self)
        return loader.load_weights(renamed_weights)
|