[Model] Support MiMo-7B inference with MTP (#17433)

Signed-off-by: wp-alpha <wangpeng66@xiaomi.com>
Co-authored-by: wangpeng66 <wangpeng66@xiaomi.com>
bwshen-mi authored 2025-05-13 07:25:33 +08:00, committed by GitHub
parent f065de4e88
commit acee8f48aa
7 changed files with 507 additions and 4 deletions
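
What this change enables, as a minimal usage sketch (assumed engine arguments of this vLLM generation, speculative_model / num_speculative_tokens; exact names may differ across versions — MTP reuses the target checkpoint as its own draft, per the SpeculativeConfig change below):

from vllm import LLM, SamplingParams

# The MTP draft weights live inside the target checkpoint, so the same repo id
# is passed for both the target and the speculative (draft) model.
llm = LLM(
    model="XiaomiMiMo/MiMo-7B-RL",
    trust_remote_code=True,
    speculative_model="XiaomiMiMo/MiMo-7B-RL",
    num_speculative_tokens=1,  # MiMo ships a single next-token-prediction layer
)
outputs = llm.generate(["Briefly explain multi-token prediction."],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)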

docs/source/models/supported_models.md

@ -592,6 +592,11 @@ Specified using `--task generate`.
  - `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc.
  -
  -
* - `MiMoForCausalLM`
  - MiMo
  - `XiaomiMiMo/MiMo-7B-RL`, etc.
  -
  -
:::
:::{note}

tests/models/registry.py

@ -242,6 +242,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
is_available_online=False,
trust_remote_code=True),
"Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"),
"MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
trust_remote_code=True),
# [Encoder-decoder]
"BartModel": _HfExamplesInfo("facebook/bart-base"),
"BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
@ -403,6 +405,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
trust_remote_code=True,
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
"MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
trust_remote_code=True,
speculative_model="XiaomiMiMo/MiMo-7B-RL")
}
_TRANSFORMERS_MODELS = {

vllm/config.py

@ -1139,7 +1139,8 @@ class ModelConfig:
def get_layers_start_end_indices(
self, parallel_config: "ParallelConfig") -> tuple[int, int]:
from vllm.distributed.utils import get_pp_indices
-if self.hf_text_config.model_type == "deepseek_mtp":
+if (self.hf_text_config.model_type == "deepseek_mtp"
+        or self.hf_config.model_type == "mimo_mtp"):
total_num_hidden_layers = getattr(self.hf_text_config,
"num_nextn_predict_layers", 0)
else:
@ -2357,6 +2358,17 @@ class SpeculativeConfig:
"n_predict": n_predict,
"architectures": ["DeepSeekMTPModel"]
})
if hf_config.architectures[0] == "MiMoForCausalLM":
hf_config.model_type = "mimo_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"num_hidden_layers": 0,
"n_predict": n_predict,
"architectures": ["MiMoMTPModel"]
})
return hf_config
return hf_config
def __post_init__(self):
@ -2373,8 +2385,10 @@ class SpeculativeConfig:
# TODO(Shangming): Refactor mtp configuration logic when supporting
# mtp acceleration for more models besides deepseek_v3
if self.target_model_config and \
-    self.target_model_config.hf_text_config.model_type \
-        == "deepseek_v3":
+    (self.target_model_config.hf_text_config.model_type \
+        == "deepseek_v3" or
+    self.target_model_config.hf_text_config.model_type \
+        == "mimo"):
# use the draft model from the same model:
self.model = self.target_model_config.model
elif self.method in ("ngram", "[ngram]"):
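
A standalone sketch of what the MiMo branch above does to the draft config (illustrative values; the real object is the checkpoint's own PretrainedConfig, and num_nextn_predict_layers is read from it rather than hard-coded):

from transformers import PretrainedConfig

# Stand-in for the target config; attribute values here are assumptions.
cfg = PretrainedConfig(architectures=["MiMoForCausalLM"],
                       num_nextn_predict_layers=1)

if cfg.architectures[0] == "MiMoForCausalLM":
    cfg.model_type = "mimo_mtp"          # switched from "mimo"
    cfg.update({
        "num_hidden_layers": 0,          # the draft carries no full decoder stack
        "n_predict": getattr(cfg, "num_nextn_predict_layers", None),
        "architectures": ["MiMoMTPModel"],  # resolved to mimo_mtp.MiMoMTP by the registry
    })

print(cfg.model_type, cfg.architectures, cfg.n_predict)  # mimo_mtp ['MiMoMTPModel'] 1

With model_type rewritten to "mimo_mtp", get_layers_start_end_indices above then sizes the draft model by num_nextn_predict_layers instead of num_hidden_layers.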

vllm/model_executor/models/mimo.py

@ -0,0 +1,190 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
# Copyright 2025 Xiaomi Corporation.
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiMo model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union
import torch
import torch.nn as nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM, Qwen2Model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix
logger = init_logger(__name__)
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
})
class MiMoModel(Qwen2Model):
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(
positions,
hidden_states,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states = hidden_states + residual
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "mtp_layers" in name:
continue
if "rotary_emb.inv_freq" in name:
continue
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class MiMoForCausalLM(Qwen2ForCausalLM, nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = MiMoModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
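# NOTE: unlike Qwen2Model, MiMoModel.forward above returns hidden states without
# the final norm; the norm is applied here instead, which also leaves the raw
# hidden states available as previous_hidden_states for the MTP draft (assumed rationale).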
hidden_states = self.model.norm(hidden_states)
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

vllm/model_executor/models/mimo_mtp.py

@ -0,0 +1,283 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/deepseek_mtp.py
# Copyright 2025 Xiaomi Corporation.
# Copyright 2023 The vLLM team.
# Copyright 2024 DeepSeek-AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiMo-MTP model."""
from typing import Iterable, Optional, Set, Tuple
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .utils import maybe_prefix
class MiMoMultiTokenPredictorLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
prefix: str,
model_config: ModelConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.token_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.hidden_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.input_proj = nn.Linear(config.hidden_size * 2,
config.hidden_size,
bias=False)
self.mtp_block = Qwen2DecoderLayer(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix)
self.final_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
inputs_embeds: torch.Tensor,
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
spec_step_index: int = 0,
) -> torch.Tensor:
assert inputs_embeds is not None
# masking inputs at position 0, as not needed by MTP
inputs_embeds[positions == 0] = 0
inputs_embeds = self.token_layernorm(inputs_embeds)
previous_hidden_states = self.hidden_layernorm(previous_hidden_states)
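# Fuse the normalized token embedding with the target model's hidden state at the
# same position: concatenate along the feature dim (2 * hidden_size) and project
# back to hidden_size before running the single Qwen2 decoder block.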
hidden_states = self.input_proj(
torch.cat([previous_hidden_states, inputs_embeds], dim=-1))
hidden_states, residual = self.mtp_block(positions=positions,
hidden_states=hidden_states,
residual=None)
hidden_states = residual + hidden_states
return self.final_layernorm(hidden_states)
class MiMoMultiTokenPredictor(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = config.num_nextn_predict_layers
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.mtp_layers = torch.nn.ModuleDict({
str(idx):
MiMoMultiTokenPredictorLayer(
config,
f"{prefix}.layers.{idx}",
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
quant_config=vllm_config.quant_config,
)
for idx in range(self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers)
})
self.logits_processor = LogitsProcessor(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
return self.mtp_layers[str(self.mtp_start_layer_idx + spec_step_idx)](
inputs_embeds,
positions,
previous_hidden_states,
spec_step_idx,
)
def compute_logits(
self,
hidden_states: torch.Tensor,
lm_head: ParallelLMHead,
sampling_metadata: SamplingMetadata,
spec_step_idx: int = 0,
) -> torch.Tensor:
self.mtp_layers[str(self.mtp_start_layer_idx + spec_step_idx)]
logits = self.logits_processor(lm_head, hidden_states,
sampling_metadata)
return logits
class MiMoMTP(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.config = vllm_config.model_config.hf_config
self.model = MiMoMultiTokenPredictor(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "model"))
self.lm_head = ParallelLMHead(self.config.vocab_size,
self.config.hidden_size)
self.sampler = get_sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
assert spec_step_idx == 0, "mimo_mtp only support predict one token now"
hidden_states = self.model(input_ids, positions,
previous_hidden_states, inputs_embeds,
spec_step_idx)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
spec_step_idx: int = 0,
) -> Optional[torch.Tensor]:
return self.model.compute_logits(hidden_states, self.lm_head,
sampling_metadata, spec_step_idx)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
name = self.map_model_name_to_mtp_param_name(name)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
if "mtp_layers" not in name:
break
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if "mtp_layers" not in name and ("embed_tokens" not in name
and "lm_head" not in name):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
def map_model_name_to_mtp_param_name(self, name: str) -> str:
import re
name_without_prefix = [
"token_layernorm", "hidden_layernorm", "input_proj",
"final_layernorm"
]
for sub_name in name_without_prefix:
if sub_name in name:
return name
pattern = r"model.mtp_layers.(\d+)."
group = re.match(pattern, name)
if group is not None:
name = name.replace(group.group(), group.group() + "mtp_block.")
return name
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
"""
Rewrite the weight name to match the format of the original model.
Add .mtp_block for modules in transformer layer block for spec layer
"""
spec_layer_weight_names = [
"embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
]
spec_layer_weight = False
for weight_name in spec_layer_weight_names:
if weight_name in name:
spec_layer_weight = True
break
if not spec_layer_weight:
# treat rest weights as weights for transformer layer block
name = name.replace(f"model.layers.{spec_layer}.",
f"model.layers.{spec_layer}.mtp_block.")
return name
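
The name-mapping helper above exists because the checkpoint stores MTP decoder-block weights directly under model.mtp_layers.<idx>.*, while this module wraps that block in an extra mtp_block submodule (the deepseek-style _rewrite_spec_layer_name is carried over but is not referenced by load_weights here). A standalone sketch of the mapping, with an illustrative layer index:

import re

def map_name(name: str) -> str:
    # Mirrors MiMoMTP.map_model_name_to_mtp_param_name above (copied for illustration).
    if any(s in name for s in ("token_layernorm", "hidden_layernorm",
                               "input_proj", "final_layernorm")):
        return name  # MTP-specific projections/norms keep their checkpoint names
    m = re.match(r"model.mtp_layers.(\d+).", name)
    if m is not None:
        # Decoder-block weights gain the extra wrapper segment.
        name = name.replace(m.group(), m.group() + "mtp_block.")
    return name

assert map_name("model.mtp_layers.36.self_attn.q_proj.weight") \
    == "model.mtp_layers.36.mtp_block.self_attn.q_proj.weight"
assert map_name("model.mtp_layers.36.input_proj.weight") \
    == "model.mtp_layers.36.input_proj.weight"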

vllm/model_executor/models/registry.py

@ -88,6 +88,7 @@ _TEXT_GENERATION_MODELS = {
# transformers's mpt class has lower case
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
"NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
@ -215,6 +216,7 @@ _MULTIMODAL_MODELS = {
}
_SPECULATIVE_DECODING_MODELS = {
"MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
"EAGLEModel": ("eagle", "EAGLE"),
"EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),

vllm/worker/worker.py

@ -71,7 +71,11 @@ class Worker(LocalOrDistributedWorkerBase):
or (speculative_config.draft_model_config.hf_config.model_type ==
model_config.hf_config.model_type) \
or (speculative_config.draft_model_config.hf_config.model_type
not in ("medusa", "mlp_speculator", "eagle", "deepseek_mtp")) \
not in ("medusa",
"mlp_speculator",
"eagle",
"deepseek_mtp",
"mimo_mtp")) \
else {"return_hidden_states": True}
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner