Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Model] PP support for Mamba-like models (#10992)
Signed-off-by: mzusman <mor.zusmann@gmail.com>
@@ -128,7 +128,7 @@ Text Generation
       - FalconMamba
       - :code:`tiiuae/falcon-mamba-7b`, :code:`tiiuae/falcon-mamba-7b-instruct`, etc.
       - ✅︎
-      -
+      - ✅︎
     * - :code:`GemmaForCausalLM`
       - Gemma
       - :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc.
@@ -193,7 +193,7 @@ Text Generation
       - Jamba
       - :code:`ai21labs/AI21-Jamba-1.5-Large`, :code:`ai21labs/AI21-Jamba-1.5-Mini`, :code:`ai21labs/Jamba-v0.1`, etc.
       - ✅︎
-      -
+      - ✅︎
     * - :code:`LlamaForCausalLM`
       - Llama 3.1, Llama 3, Llama 2, LLaMA, Yi
       - :code:`meta-llama/Meta-Llama-3.1-405B-Instruct`, :code:`meta-llama/Meta-Llama-3.1-70B`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-70b-hf`, :code:`01-ai/Yi-34B`, etc.
@@ -203,7 +203,7 @@ Text Generation
       - Mamba
       - :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc.
       -
-      -
+      - ✅︎
     * - :code:`MiniCPMForCausalLM`
      - MiniCPM
      - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, :code:`openbmb/MiniCPM-S-1B-sft`, etc.
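The doc change above marks FalconMamba, Jamba, and Mamba as pipeline-parallel capable. A minimal offline-inference sketch of what that enables (the checkpoint and the 2-GPU stage split are illustrative assumptions, not part of this commit):

from vllm import LLM, SamplingParams

# Assumes 2 GPUs; pipeline_parallel_size splits the decoder stack into two
# pipeline stages instead of sharding each layer with tensor parallelism.
llm = LLM(model="ai21labs/Jamba-tiny-dev",
          pipeline_parallel_size=2,
          tensor_parallel_size=1)

outputs = llm.generate(["Mamba layers keep a recurrent state, so"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)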
@@ -156,13 +156,13 @@ TEXT_GENERATION_MODELS = {
     # "internlm/internlm-chat-7b": PPTestSettings.fast(),
     "internlm/internlm2-chat-7b": PPTestSettings.fast(trust_remote_code=True),
     "inceptionai/jais-13b-chat": PPTestSettings.fast(),
-    # TODO: Implement PP
-    # "ai21labs/AI21-Jamba-1.5-Mini": PPTestSettings.fast(),
+    "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
     "meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),
     "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(trust_remote_code=True),
     "openbmb/MiniCPM3-4B": PPTestSettings.fast(trust_remote_code=True),
     # Uses Llama
     # "mistralai/Mistral-7B-Instruct-v0.1": PPTestSettings.fast(),
+    "state-spaces/mamba-130m-hf": PPTestSettings.fast(),
     "mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(tp_base=4),
     "mosaicml/mpt-7b": PPTestSettings.fast(),
     "nvidia/Minitron-8B-Base": PPTestSettings.fast(),
@@ -234,6 +234,8 @@ TEST_MODELS = [
     "OpenGVLab/InternVL2-1B",
     "microsoft/Phi-3-vision-128k-instruct",
     "fixie-ai/ultravox-v0_3",
+    # [LANGUAGE GENERATION - HYBRID ARCH]
+    "ai21labs/Jamba-tiny-dev",
 ]
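A hedged sketch of how a pipeline-parallel smoke test over the newly added entries might be parametrized; the names below are stand-ins for illustration, not vLLM's actual test harness:

import pytest

# Illustrative model/stage pairs; the real suite drives these via PPTestSettings.
PP_MODELS = [
    ("ai21labs/Jamba-tiny-dev", 2),
    ("state-spaces/mamba-130m-hf", 2),
]

@pytest.mark.parametrize("model_id,pp_size", PP_MODELS)
def test_pp_generates(model_id: str, pp_size: int):
    from vllm import LLM, SamplingParams
    llm = LLM(model=model_id, pipeline_parallel_size=pp_size)
    out = llm.generate(["Hello"], SamplingParams(max_tokens=8))
    assert out and out[0].outputs[0].text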
@@ -27,8 +27,8 @@ from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,
     get_hf_text_config, get_pooling_config,
     get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
-from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
-                        print_warning_once, random_uuid,
+from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
+                        get_cpu_memory, print_warning_once, random_uuid,
                         resolve_obj_by_qualname)

 if TYPE_CHECKING:
@@ -284,6 +284,7 @@ class ModelConfig:
         self._verify_tokenizer_mode()

         self.is_attention_free = self._init_attention_free()
+        self.is_hybrid = self._init_is_hybrid()
         self.has_inner_state = self._init_has_inner_state()

         if current_platform.is_neuron():
@@ -340,6 +341,10 @@ class ModelConfig:
         architectures = getattr(self.hf_config, "architectures", [])
         return ModelRegistry.is_attention_free_model(architectures)

+    def _init_is_hybrid(self) -> bool:
+        architectures = getattr(self.hf_config, "architectures", [])
+        return ModelRegistry.is_hybrid_model(architectures)
+
     def _init_has_inner_state(self) -> bool:
         architectures = getattr(self.hf_config, "architectures", [])
         return ModelRegistry.model_has_inner_state(architectures)
@@ -669,26 +674,51 @@ class ModelConfig:
         num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
         return num_heads // parallel_config.tensor_parallel_size

-    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
+    def get_layers_start_end_indices(
+            self, parallel_config: "ParallelConfig") -> Tuple[int, int]:
         from vllm.distributed.utils import get_pp_indices
         total_num_hidden_layers = getattr(self.hf_text_config,
                                           "num_hidden_layers", 0)
         pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
         pp_size = parallel_config.pipeline_parallel_size
         start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
+        return start, end
+
+    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
+        start, end = self.get_layers_start_end_indices(parallel_config)
         return end - start

-    def get_num_attention_layers(self,
-                                 parallel_config: "ParallelConfig") -> int:
-        if self.is_attention_free:
-            return 0
+    def get_num_layers_by_block_type(
+        self,
+        parallel_config: "ParallelConfig",
+        block_type: LayerBlockType = LayerBlockType.attention,
+    ) -> int:
+        # This function relies on 'layers_block_type' in hf_config,
+        # for w/o this attribute, we will need to have workarounds like so
+        attn_block_type = block_type == LayerBlockType.attention
+        is_transformer = not self.is_hybrid and not self.is_attention_free
+        start, end = self.get_layers_start_end_indices(parallel_config)

-        num_layers = self.get_num_layers(parallel_config)
-
-        # Transformers supports layers_block_type @property
-        layers = getattr(self.hf_config, "layers_block_type",
-                         ["attention"] * num_layers)
-        return len([t for t in layers if t == "attention"])
+        if is_transformer:
+            # Handle the basic case first
+            return end - start if attn_block_type else 0
+        elif self.is_attention_free:
+            # Attention free
+            # Note that this code assumes there
+            # is only one type of attention-free block type.
+            return 0 if attn_block_type else end - start
+        else:
+            # Hybrid model
+            layers_block_type_value = getattr(self.hf_config,
+                                              "layers_block_type", None)
+            if layers_block_type_value is None:
+                raise ValueError("The model is an hybrid without a"
+                                 "layers_block_type in the hf_config,"
+                                 "cannot determine the num of "
+                                 f"{block_type.value} layers")
+
+            return sum(t == block_type.value
+                       for t in layers_block_type_value[start:end])

     def get_multimodal_config(self) -> "MultiModalConfig":
         """
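The heart of get_num_layers_by_block_type is: take this rank's [start, end) slice of the layer stack, then count how many layers in that slice match the requested block type. A standalone sketch of that logic (toy layers_block_type and an even split, purely illustrative):

from typing import List, Tuple


def pp_indices(num_layers: int, pp_rank: int, pp_size: int) -> Tuple[int, int]:
    # Toy even split; vLLM's get_pp_indices also supports custom partitions.
    per_rank = num_layers // pp_size
    start = pp_rank * per_rank
    end = num_layers if pp_rank == pp_size - 1 else start + per_rank
    return start, end


def count_layers(layers_block_type: List[str], block_type: str,
                 pp_rank: int, pp_size: int) -> int:
    start, end = pp_indices(len(layers_block_type), pp_rank, pp_size)
    return sum(t == block_type for t in layers_block_type[start:end])


# A Jamba-like hybrid stack with one attention layer per four layers (toy values).
blocks = ["mamba", "mamba", "attention", "mamba"] * 2
assert count_layers(blocks, "attention", pp_rank=0, pp_size=2) == 1
assert count_layers(blocks, "mamba", pp_rank=1, pp_size=2) == 3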
@@ -363,6 +363,43 @@ def is_attention_free(
     return isinstance(model, IsAttentionFree)


+@runtime_checkable
+class IsHybrid(Protocol):
+    """The interface required for all models like Jamba that have both
+    attention and mamba blocks, indicates that
+    hf_config has 'layers_block_type'"""
+
+    is_hybrid: ClassVar[Literal[True]] = True
+    """
+        A flag that indicates this model has both mamba and attention blocks
+        , also indicates that the model's hf_config has
+        'layers_block_type' """
+
+
+@runtime_checkable
+class _IsHybridType(Protocol):
+    is_hybrid: ClassVar[Literal[True]]
+
+
+@overload
+def is_hybrid(model: object) -> TypeIs[IsHybrid]:
+    ...
+
+
+@overload
+def is_hybrid(model: Type[object]) -> TypeIs[Type[IsHybrid]]:
+    ...
+
+
+def is_hybrid(
+    model: Union[Type[object], object]
+) -> Union[TypeIs[Type[IsHybrid]], TypeIs[IsHybrid]]:
+    if isinstance(model, type):
+        return isinstance(model, _IsHybridType)
+
+    return isinstance(model, IsHybrid)
+
+
 @runtime_checkable
 class SupportsCrossEncoding(Protocol):
     """The interface required for all models that support cross encoding."""
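Because IsHybrid is a runtime_checkable Protocol keyed on a class variable, the registry can detect hybrid models without importing or instantiating their modeling code. A small sketch of how the check behaves (the Toy* classes are made up for illustration):

from typing import ClassVar, Literal

from vllm.model_executor.models.interfaces import is_hybrid


class ToyHybridModel:
    # Matching the protocol only requires this class variable.
    is_hybrid: ClassVar[Literal[True]] = True


class ToyPlainModel:
    pass


assert is_hybrid(ToyHybridModel) and is_hybrid(ToyHybridModel())
assert not is_hybrid(ToyPlainModel)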
@@ -9,6 +9,7 @@ from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.layer import Attention
 from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import get_pp_group
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -25,9 +26,12 @@ from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                      MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
+from vllm.utils import LayerBlockType

-from .interfaces import HasInnerState, SupportsLoRA
-from .utils import maybe_prefix
+from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
+from .utils import (is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)

 KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -281,16 +285,24 @@ class JambaModel(nn.Module):
             org_num_embeddings=config.vocab_size,
         )

-        decoder_layers = []
-        for i in range(config.num_hidden_layers):
-            layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
-            decoder_layers.append(
-                layer_class(config,
-                            layer_idx=i,
-                            cache_config=cache_config,
-                            quant_config=quant_config,
-                            prefix=f"{prefix}.layers.{i}"))
-        self.layers = nn.ModuleList(decoder_layers)
+        def get_layer(prefix: str):
+            layer_idx = int(prefix.rsplit(".", 1)[1])
+            layer_class = ALL_DECODER_LAYER_TYPES[
+                config.layers_block_type[layer_idx]]
+            return layer_class(
+                config,
+                layer_idx,
+                cache_config,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))

         self.final_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
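make_layers replaces the eager nn.ModuleList construction: it builds only this rank's [start_layer, end_layer) slice and leaves placeholders for layers owned by other stages, so layer indices stay global while memory is only spent on local layers. A toy sketch of the idea (simplified even split; the real helper lives in the models' .utils module and also handles uneven partitions):

from typing import Callable, List, Tuple

import torch.nn as nn


def toy_make_layers(
    num_layers: int,
    layer_fn: Callable[[str], nn.Module],
    pp_rank: int,
    pp_size: int,
    prefix: str = "model.layers",
) -> Tuple[int, int, nn.ModuleList]:
    # Even split for illustration only.
    per_rank = num_layers // pp_size
    start = pp_rank * per_rank
    end = num_layers if pp_rank == pp_size - 1 else start + per_rank
    layers: List[nn.Module] = []
    for i in range(num_layers):
        if start <= i < end:
            layers.append(layer_fn(f"{prefix}.{i}"))
        else:
            layers.append(nn.Identity())  # placeholder for remote layers
    return start, end, nn.ModuleList(layers)


# The per-layer factory receives the dotted prefix and recovers the global
# index from it, which is exactly the trick get_layer() uses above.
start, end, layers = toy_make_layers(
    8, lambda p: nn.Linear(4, 4), pp_rank=1, pp_size=2)
assert (start, end) == (4, 8)
assert isinstance(layers[0], nn.Identity) and isinstance(layers[4], nn.Linear)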
@@ -304,26 +316,34 @@ class JambaModel(nn.Module):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         mamba_cache_params: MambaCacheParams,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if inputs_embeds is not None:
-            hidden_states = inputs_embeds
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
         else:
-            hidden_states = self.get_input_embeddings(input_ids)
-        residual = None
-        for i in range(len(self.layers)):
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+
+        kv_cache_index = 0
+        mamba_cache_index = 0
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             kv_cache = None
             layer_mamba_cache_params = None
             if isinstance(layer, JambaAttentionDecoderLayer):
-                kv_cache = kv_caches[(i - self.config.attn_layer_offset) //
-                                     self.config.attn_layer_period]
+                kv_cache = kv_caches[kv_cache_index]
+                kv_cache_index += 1
             if isinstance(layer, JambaMambaDecoderLayer):
-                current_state_layer = i - (1 +
-                                           (i - self.config.attn_layer_offset)
-                                           // self.config.attn_layer_period)
+                current_state_layer = mamba_cache_index
                 layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
                     current_state_layer)
+                mamba_cache_index += 1

             hidden_states, residual = layer(
                 positions=positions,
@@ -332,11 +352,17 @@ class JambaModel(nn.Module):
                 attn_metadata=attn_metadata,
                 residual=residual,
                 mamba_cache_params=layer_mamba_cache_params)
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
         hidden_states, _ = self.final_layernorm(hidden_states, residual)
         return hidden_states


-class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
+class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
+                       IsHybrid):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -368,6 +394,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):

         super().__init__()
         self.config = config
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
         self.scheduler_config = scheduler_config
         self.model = JambaModel(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
@@ -390,6 +418,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
                                                 config.vocab_size)
         self.sampler = get_sampler()

+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)

@@ -406,10 +437,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
                 self.scheduler_config.max_num_seqs) if self.scheduler_config
                               else max(_BATCH_SIZES_TO_CAPTURE) + 2)

-            layers_type = self.config.layers_block_type
-            num_mamba_layers = sum(
-                [layer_type == "mamba" for layer_type in layers_type])
-
+            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
+                self.vllm_config.parallel_config, LayerBlockType.mamba)
             self.mamba_cache = MambaCacheManager(
                 self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
                 *self._get_mamba_cache_shape())
@@ -423,7 +452,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
                                                  state_indices_tensor)
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata, mamba_cache_params,
-                                   inputs_embeds)
+                                   intermediate_tensors, inputs_embeds)
         return hidden_states

     def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
@@ -504,8 +533,12 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
                     continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
+
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -520,6 +553,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
                 if weight_name not in name:
                     continue

+                if is_pp_missing_parameter(name, self):
+                    continue
                 name = name.replace(weight_name, param_name)
                 param = params_dict[name]
                 weight_loader = param.weight_loader
@@ -533,6 +568,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue

                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
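The forward changes above follow the standard vLLM PP pattern: the first rank computes embeddings, later ranks receive hidden_states and residual packed in an IntermediateTensors, and only the last rank applies the final norm. A condensed toy sketch of that control flow (standalone illustration, not the Jamba module itself):

from typing import Optional

import torch
import torch.nn as nn

from vllm.sequence import IntermediateTensors


def toy_stage_forward(
    layers: nn.ModuleList,              # this rank's [start_layer, end_layer) slice
    is_first_rank: bool,
    is_last_rank: bool,
    embed: nn.Module,
    final_norm: nn.Module,
    input_ids: Optional[torch.Tensor] = None,
    intermediate_tensors: Optional[IntermediateTensors] = None,
):
    if is_first_rank:
        hidden_states = embed(input_ids)
        residual = torch.zeros_like(hidden_states)  # the real code starts from None
    else:
        # Hand-off produced by the previous pipeline stage.
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]

    for layer in layers:
        hidden_states = layer(hidden_states)
        # (toy: real decoder layers also update and return the residual)

    if not is_last_rank:
        # Ship activations to the next stage; no logits are computed here.
        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual,
        })
    return final_norm(hidden_states + residual)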
@@ -8,6 +8,7 @@ from transformers import MambaConfig
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import get_pp_group
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
@@ -18,13 +19,16 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import (HasInnerState,
-                                                    IsAttentionFree)
+                                                    IsAttentionFree, SupportsPP)
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
+from vllm.utils import LayerBlockType

-from .utils import maybe_prefix
+from .utils import (is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)

 KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -95,15 +99,17 @@ class MambaModel(nn.Module):
             org_num_embeddings=config.vocab_size,
         )

-        decoder_layers = []
-        for i in range(config.num_hidden_layers):
-            decoder_layers.append(
-                MambaDecoderLayer(config,
-                                  cache_config=cache_config,
-                                  quant_config=quant_config))
-        self.layers = nn.ModuleList(decoder_layers)
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: MambaDecoderLayer(
+                config, cache_config=cache_config, quant_config=quant_config),
+            prefix=f"{prefix}.layers")

         self.norm_f = RMSNorm(config.hidden_size,
                               eps=config.layer_norm_epsilon)
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embeddings(input_ids)
@@ -114,29 +120,40 @@ class MambaModel(nn.Module):
         positions: torch.Tensor,
         attn_metadata: AttentionMetadata,
         mamba_cache_params: MambaCacheParams,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
-        if inputs_embeds is not None:
-            hidden_states = inputs_embeds
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
         else:
-            hidden_states = self.get_input_embeddings(input_ids)
-        residual = None
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]

-        for i in range(len(self.layers)):
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions=positions,
                 hidden_states=hidden_states,
                 attn_metadata=attn_metadata,
                 residual=residual,
-                mamba_cache_params=mamba_cache_params.at_layer_idx(i))
+                mamba_cache_params=mamba_cache_params.at_layer_idx(
+                    i - self.start_layer))
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
         hidden_states, _ = self.norm_f(hidden_states, residual)

         return hidden_states


-class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
+class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
@@ -148,7 +165,9 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):

         super().__init__()
         self.config = config
+        self.vllm_config = vllm_config
         self.scheduler_config = scheduler_config
+        self.model_config = vllm_config.model_config
         self.backbone = MambaModel(vllm_config=vllm_config,
                                    prefix=maybe_prefix(prefix, "backbone"))
         self.unpadded_vocab_size = config.vocab_size
@@ -174,6 +193,9 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
                                                 config.vocab_size)
         self.sampler = get_sampler()

+        self.make_empty_intermediate_tensors = (
+            self.backbone.make_empty_intermediate_tensors)
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.backbone.get_input_embeddings(input_ids)

@@ -189,9 +211,12 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
         max_batch_size = (VllmConfig.get_graph_batch_size(
             self.scheduler_config.max_num_seqs) if self.scheduler_config
                           else max(_BATCH_SIZES_TO_CAPTURE) + 2)
+
+        num_mamba_layers = self.model_config.get_num_layers_by_block_type(
+            self.vllm_config.parallel_config, LayerBlockType.mamba)
         self.mamba_cache = MambaCacheManager(
-            self.lm_head.weight.dtype, self.config.num_hidden_layers,
-            max_batch_size, *self._get_mamba_cache_shape())
+            self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
+            *self._get_mamba_cache_shape())

         (
             mamba_cache_tensors,
@@ -204,7 +229,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
                                                  state_indices_tensor)

         hidden_states = self.backbone(input_ids, positions, attn_metadata,
-                                      mamba_cache_params, inputs_embeds)
+                                      mamba_cache_params, intermediate_tensors,
+                                      inputs_embeds)

         return hidden_states

@@ -252,6 +278,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue

                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
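With PP, each rank sizes its MambaCacheManager for its local mamba layers only, so the cache is indexed by a rank-local position rather than the global layer index (hence at_layer_idx(i - self.start_layer) in MambaModel and the running mamba_cache_index in JambaModel). A toy sketch of that global-to-local mapping:

from typing import List

# Global layer indices owned by a hypothetical rank 1 of 2 (layers 4..7).
start_layer, end_layer = 4, 8
layers_block_type: List[str] = ["mamba", "mamba", "attention", "mamba"] * 2

# The cache only holds state for this rank's mamba layers.
local_mamba_slots = [
    i for i in range(start_layer, end_layer)
    if layers_block_type[i] == "mamba"
]
num_local_mamba = len(local_mamba_slots)          # -> 3 cache slots

# Map a global layer index to its slot in the local cache.
cache_index = {gi: slot for slot, gi in enumerate(local_mamba_slots)}
assert num_local_mamba == 3
assert cache_index[4] == 0 and cache_index[7] == 2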
@@ -21,7 +21,7 @@ from vllm.logger import init_logger
 from vllm.platforms import current_platform

 from .adapters import as_embedding_model
-from .interfaces import (has_inner_state, is_attention_free,
+from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
                          supports_cross_encoding, supports_multimodal,
                          supports_pp)
 from .interfaces_base import is_pooling_model, is_text_generation_model
@@ -218,6 +218,7 @@ class _ModelInfo:
     supports_pp: bool
     has_inner_state: bool
     is_attention_free: bool
+    is_hybrid: bool

     @staticmethod
     def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
@@ -239,6 +240,7 @@ class _ModelInfo:
             supports_pp=supports_pp(model),
             has_inner_state=has_inner_state(model),
             is_attention_free=is_attention_free(model),
+            is_hybrid=is_hybrid(model),
         )


@@ -484,6 +486,13 @@ class _ModelRegistry:
         model_cls, _ = self.inspect_model_cls(architectures)
         return model_cls.is_attention_free

+    def is_hybrid_model(
+        self,
+        architectures: Union[str, List[str]],
+    ) -> bool:
+        model_cls, _ = self.inspect_model_cls(architectures)
+        return model_cls.is_hybrid
+

 ModelRegistry = _ModelRegistry({
     model_arch: _LazyRegisteredModel(
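After registration, the new flag is queryable by architecture name just like the existing ones, which is how ModelConfig._init_is_hybrid works without instantiating the model. A short usage sketch (the printed values assume Jamba and Mamba are registered as in this commit):

from vllm import ModelRegistry

# Hybrid: both attention and mamba blocks, hf_config carries 'layers_block_type'.
print(ModelRegistry.is_hybrid_model(["JambaForCausalLM"]))           # True
# Pure SSM: attention-free, so it needs no KV cache at all.
print(ModelRegistry.is_attention_free_model(["MambaForCausalLM"]))   # True
print(ModelRegistry.is_hybrid_model(["MambaForCausalLM"]))           # False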
@@ -170,6 +170,11 @@ class Device(enum.Enum):
     CPU = enum.auto()


+class LayerBlockType(enum.Enum):
+    attention = "attention"
+    mamba = "mamba"
+
+
 class Counter:

     def __init__(self, start: int = 0) -> None:
@@ -15,8 +15,8 @@ from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MultiModalKwargs
 from vllm.sampling_params import SamplingType
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv,
-                        is_pin_memory_available)
+from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
+                        LayerBlockType, cdiv, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                    FlashAttentionMetadata)
 from vllm.v1.outputs import ModelRunnerOutput
@@ -68,8 +68,8 @@ class GPUModelRunner:
         self.max_num_tokens = scheduler_config.max_num_batched_tokens

         # Model-related.
-        self.num_attn_layers = model_config.get_num_attention_layers(
-            parallel_config)
+        self.num_attn_layers = model_config.get_num_layers_by_block_type(
+            parallel_config, LayerBlockType.attention)
         self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
         self.head_size = model_config.get_head_size()
         self.hidden_size = model_config.get_hidden_size()
@@ -14,7 +14,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size
 from vllm.v1.core.scheduler import SchedulerOutput
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
@@ -260,8 +260,8 @@ def _get_cache_block_size(
 ) -> int:
     head_size = model_config.get_head_size()
     num_heads = model_config.get_num_kv_heads(parallel_config)
-    num_attention_layers = model_config.get_num_attention_layers(
-        parallel_config)
+    num_attention_layers = model_config.get_num_layers_by_block_type(
+        parallel_config, LayerBlockType.attention)

     key_cache_block = cache_config.block_size * num_heads * head_size
     value_cache_block = key_cache_block
@@ -6,8 +6,8 @@ import torch
 from vllm.attention import get_attn_backend
 from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
 from vllm.logger import init_logger
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size,
-                        is_pin_memory_available)
+from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
+                        get_dtype_size, is_pin_memory_available)

 logger = init_logger(__name__)

@@ -34,8 +34,8 @@ class CacheEngine:

         self.head_size = model_config.get_head_size()
         # Models like Jamba, have mixed typed layers, E.g Mamba
-        self.num_attention_layers = model_config.get_num_attention_layers(
-            parallel_config)
+        self.num_attention_layers = model_config.get_num_layers_by_block_type(
+            parallel_config, LayerBlockType.attention)
         self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)

         self.block_size = cache_config.block_size
@@ -105,8 +105,8 @@ class CacheEngine:
     ) -> int:
         head_size = model_config.get_head_size()
         num_heads = model_config.get_num_kv_heads(parallel_config)
-        num_attention_layers = model_config.get_num_attention_layers(
-            parallel_config)
+        num_attention_layers = model_config.get_num_layers_by_block_type(
+            parallel_config, LayerBlockType.attention)

         key_cache_block = cache_config.block_size * num_heads * head_size
         value_cache_block = key_cache_block
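Since the attention-layer count now comes from get_num_layers_by_block_type, mamba layers no longer inflate the KV-cache block size: only attention layers contribute key/value blocks, and attention-free models contribute none. A worked sketch of the computation these call sites share (toy numbers; the real values come from the model and cache configs):

import torch

# Toy figures for one rank of a hybrid model.
block_size = 16            # tokens per KV-cache block
num_kv_heads = 8
head_size = 128
num_attention_layers = 4   # counted by block type; mamba layers excluded
dtype_size = torch.tensor([], dtype=torch.float16).element_size()

key_cache_block = block_size * num_kv_heads * head_size
value_cache_block = key_cache_block
total = num_attention_layers * (key_cache_block + value_cache_block)
cache_block_bytes = dtype_size * total
print(cache_block_bytes)   # 2 * 4 * 2 * 16 * 8 * 128 = 262144 bytes per block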