mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Model] Support MiMo-7B inference with MTP (#17433)
Signed-off-by: wp-alpha <wangpeng66@xiaomi.com> Co-authored-by: wangpeng66 <wangpeng66@xiaomi.com>
This commit is contained in:
@ -592,6 +592,11 @@ Specified using `--task generate`.
|
||||
* `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc.
|
||||
*
|
||||
*
|
||||
- * `MiMoForCausalLM`
|
||||
* MiMo
|
||||
* `XiaomiMiMo/MiMo-7B-RL`, etc.
|
||||
*
|
||||
*
|
||||
:::
|
||||
|
||||
:::{note}
|
||||
|
@ -242,6 +242,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
is_available_online=False,
|
||||
trust_remote_code=True),
|
||||
"Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"),
|
||||
"MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
|
||||
trust_remote_code=True),
|
||||
# [Encoder-decoder]
|
||||
"BartModel": _HfExamplesInfo("facebook/bart-base"),
|
||||
"BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
|
||||
@ -403,6 +405,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
||||
trust_remote_code=True,
|
||||
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
|
||||
tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
|
||||
"MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
|
||||
trust_remote_code=True,
|
||||
speculative_model="XiaomiMiMo/MiMo-7B-RL")
|
||||
}
|
||||
|
||||
_TRANSFORMERS_MODELS = {
|
||||
|
@ -1139,7 +1139,8 @@ class ModelConfig:
|
||||
def get_layers_start_end_indices(
|
||||
self, parallel_config: "ParallelConfig") -> tuple[int, int]:
|
||||
from vllm.distributed.utils import get_pp_indices
|
||||
if self.hf_text_config.model_type == "deepseek_mtp":
|
||||
if (self.hf_text_config.model_type == "deepseek_mtp"
|
||||
or self.hf_config.model_type == "mimo_mtp"):
|
||||
total_num_hidden_layers = getattr(self.hf_text_config,
|
||||
"num_nextn_predict_layers", 0)
|
||||
else:
|
||||
@ -2357,6 +2358,17 @@ class SpeculativeConfig:
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["DeepSeekMTPModel"]
|
||||
})
|
||||
|
||||
if hf_config.architectures[0] == "MiMoForCausalLM":
|
||||
hf_config.model_type = "mimo_mtp"
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update({
|
||||
"num_hidden_layers": 0,
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["MiMoMTPModel"]
|
||||
})
|
||||
return hf_config
|
||||
|
||||
return hf_config
|
||||
|
||||
def __post_init__(self):
|
||||
@ -2373,8 +2385,10 @@ class SpeculativeConfig:
|
||||
# TODO(Shangming): Refactor mtp configuration logic when supporting
|
||||
# mtp acceleration for more models besides deepseek_v3
|
||||
if self.target_model_config and \
|
||||
self.target_model_config.hf_text_config.model_type \
|
||||
== "deepseek_v3":
|
||||
(self.target_model_config.hf_text_config.model_type \
|
||||
== "deepseek_v3" or
|
||||
self.target_model_config.hf_text_config.model_type \
|
||||
== "mimo"):
|
||||
# use the draft model from the same model:
|
||||
self.model = self.target_model_config.model
|
||||
elif self.method in ("ngram", "[ngram]"):
|
||||
|
190
vllm/model_executor/models/mimo.py
Normal file
190
vllm/model_executor/models/mimo.py
Normal file
@ -0,0 +1,190 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
|
||||
# Copyright 2025 Xiaomi Corporation.
|
||||
# Copyright 2024 The Qwen team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only MiMo model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM, Qwen2Model
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims={
|
||||
"input_ids": 0,
|
||||
"positions": -1,
|
||||
"intermediate_tensors": 0,
|
||||
"inputs_embeds": 0,
|
||||
})
|
||||
class MiMoModel(Qwen2Model):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
residual,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
hidden_states = hidden_states + residual
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "mtp_layers" in name:
|
||||
continue
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if (self.quant_config is not None and
|
||||
(scale_name := self.quant_config.get_cache_scale(name))):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
|
||||
loaded_weight[0])
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class MiMoForCausalLM(Qwen2ForCausalLM, nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.model = MiMoModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "lm_head"))
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = get_sampler()
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
hidden_states = self.model.norm(hidden_states)
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
283
vllm/model_executor/models/mimo_mtp.py
Normal file
283
vllm/model_executor/models/mimo_mtp.py
Normal file
@ -0,0 +1,283 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/deepseek_mtp.py
|
||||
# Copyright 2025 Xiaomi Corporation.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2024 DeepSeek-AI team.
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only MiMo-MTP model."""
|
||||
from typing import Iterable, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .utils import maybe_prefix
|
||||
|
||||
|
||||
class MiMoMultiTokenPredictorLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
prefix: str,
|
||||
model_config: ModelConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.token_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.hidden_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.input_proj = nn.Linear(config.hidden_size * 2,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
self.mtp_block = Qwen2DecoderLayer(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
self.final_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
spec_step_index: int = 0,
|
||||
) -> torch.Tensor:
|
||||
assert inputs_embeds is not None
|
||||
# masking inputs at position 0, as not needed by MTP
|
||||
inputs_embeds[positions == 0] = 0
|
||||
inputs_embeds = self.token_layernorm(inputs_embeds)
|
||||
previous_hidden_states = self.hidden_layernorm(previous_hidden_states)
|
||||
|
||||
hidden_states = self.input_proj(
|
||||
torch.cat([previous_hidden_states, inputs_embeds], dim=-1))
|
||||
|
||||
hidden_states, residual = self.mtp_block(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=None)
|
||||
hidden_states = residual + hidden_states
|
||||
return self.final_layernorm(hidden_states)
|
||||
|
||||
|
||||
class MiMoMultiTokenPredictor(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.mtp_start_layer_idx = config.num_hidden_layers
|
||||
self.num_mtp_layers = config.num_nextn_predict_layers
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
|
||||
self.mtp_layers = torch.nn.ModuleDict({
|
||||
str(idx):
|
||||
MiMoMultiTokenPredictorLayer(
|
||||
config,
|
||||
f"{prefix}.layers.{idx}",
|
||||
model_config=vllm_config.model_config,
|
||||
cache_config=vllm_config.cache_config,
|
||||
quant_config=vllm_config.quant_config,
|
||||
)
|
||||
for idx in range(self.mtp_start_layer_idx,
|
||||
self.mtp_start_layer_idx + self.num_mtp_layers)
|
||||
})
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
return self.mtp_layers[str(self.mtp_start_layer_idx + spec_step_idx)](
|
||||
inputs_embeds,
|
||||
positions,
|
||||
previous_hidden_states,
|
||||
spec_step_idx,
|
||||
)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
lm_head: ParallelLMHead,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
self.mtp_layers[str(self.mtp_start_layer_idx + spec_step_idx)]
|
||||
logits = self.logits_processor(lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
||||
class MiMoMTP(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
self.model = MiMoMultiTokenPredictor(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "model"))
|
||||
self.lm_head = ParallelLMHead(self.config.vocab_size,
|
||||
self.config.hidden_size)
|
||||
|
||||
self.sampler = get_sampler()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
assert spec_step_idx == 0, "mimo_mtp only support predict one token now"
|
||||
hidden_states = self.model(input_ids, positions,
|
||||
previous_hidden_states, inputs_embeds,
|
||||
spec_step_idx)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
spec_step_idx: int = 0,
|
||||
) -> Optional[torch.Tensor]:
|
||||
return self.model.compute_logits(hidden_states, self.lm_head,
|
||||
sampling_metadata, spec_step_idx)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
name = self.map_model_name_to_mtp_param_name(name)
|
||||
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
continue
|
||||
if "mtp_layers" not in name:
|
||||
break
|
||||
# We have mlp.experts[0].gate_proj in the checkpoint.
|
||||
# Since we handle the experts below in expert_params_mapping,
|
||||
# we need to skip here BEFORE we update the name, otherwise
|
||||
# name will be updated to mlp.experts[0].gate_up_proj, which
|
||||
# will then be updated below in expert_params_mapping
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
if (("mlp.experts." in name) and name not in params_dict):
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if "mtp_layers" not in name and ("embed_tokens" not in name
|
||||
and "lm_head" not in name):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
def map_model_name_to_mtp_param_name(self, name: str) -> str:
|
||||
import re
|
||||
name_without_prefix = [
|
||||
"token_layernorm", "hidden_layernorm", "input_proj",
|
||||
"final_layernorm"
|
||||
]
|
||||
for sub_name in name_without_prefix:
|
||||
if sub_name in name:
|
||||
return name
|
||||
pattern = r"model.mtp_layers.(\d+)."
|
||||
group = re.match(pattern, name)
|
||||
if group is not None:
|
||||
name = name.replace(group.group(), group.group() + "mtp_block.")
|
||||
return name
|
||||
|
||||
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
|
||||
"""
|
||||
Rewrite the weight name to match the format of the original model.
|
||||
Add .mtp_block for modules in transformer layer block for spec layer
|
||||
"""
|
||||
spec_layer_weight_names = [
|
||||
"embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
|
||||
]
|
||||
spec_layer_weight = False
|
||||
for weight_name in spec_layer_weight_names:
|
||||
if weight_name in name:
|
||||
spec_layer_weight = True
|
||||
break
|
||||
if not spec_layer_weight:
|
||||
# treat rest weights as weights for transformer layer block
|
||||
name = name.replace(f"model.layers.{spec_layer}.",
|
||||
f"model.layers.{spec_layer}.mtp_block.")
|
||||
return name
|
@ -88,6 +88,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
# transformers's mpt class has lower case
|
||||
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
|
||||
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
|
||||
"MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
|
||||
"NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
|
||||
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
|
||||
"Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
|
||||
@ -215,6 +216,7 @@ _MULTIMODAL_MODELS = {
|
||||
}
|
||||
|
||||
_SPECULATIVE_DECODING_MODELS = {
|
||||
"MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
|
||||
"EAGLEModel": ("eagle", "EAGLE"),
|
||||
"EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
|
||||
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
||||
|
@ -71,7 +71,11 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
or (speculative_config.draft_model_config.hf_config.model_type ==
|
||||
model_config.hf_config.model_type) \
|
||||
or (speculative_config.draft_model_config.hf_config.model_type
|
||||
not in ("medusa", "mlp_speculator", "eagle", "deepseek_mtp")) \
|
||||
not in ("medusa",
|
||||
"mlp_speculator",
|
||||
"eagle",
|
||||
"deepseek_mtp",
|
||||
"mimo_mtp")) \
|
||||
else {"return_hidden_states": True}
|
||||
|
||||
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
|
||||
|
Reference in New Issue
Block a user