[spec decode] Consolidate speculative decode method name for MTP (#25232)

Signed-off-by: zixi-qi <qizixi@meta.com>
This commit is contained in:
qizixi
2025-09-26 15:27:05 -07:00
committed by GitHub
parent cf89202855
commit c70ac4b8ff
6 changed files with 287 additions and 40 deletions

View File

@ -54,6 +54,7 @@ def parse_args():
"--method",
type=str,
default="eagle",
choices=["ngram", "eagle", "eagle3", "mtp"],
)
parser.add_argument("--num-spec-tokens", type=int, default=2)
parser.add_argument("--prompt-lookup-max", type=int, default=5)
@ -118,9 +119,9 @@ def main(args):
"prompt_lookup_max": args.prompt_lookup_max,
"prompt_lookup_min": args.prompt_lookup_min,
}
elif args.method.endswith("mtp"):
elif args.method == "mtp":
speculative_config = {
"method": args.method,
"method": "mtp",
"num_speculative_tokens": args.num_spec_tokens,
}
else:

View File

@ -15,6 +15,8 @@ from vllm.assets.image import VLM_IMAGES_DIR
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.platforms import current_platform
MTP_SIMILARITY_RATE = 0.8
def get_test_prompts(mm_enabled: bool):
prompt_types = ["repeat", "sentence"]
@ -222,3 +224,66 @@ def test_eagle_correctness(
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
(("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
(("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
],
ids=["mimo", "deepseek"])
def test_mtp_correctness(
monkeypatch: pytest.MonkeyPatch,
sampling_config: SamplingParams,
model_setup: tuple[str, str, int],
mm_enabled: bool,
):
# Generate test prompts inside the function instead of using fixture
test_prompts = get_test_prompts(mm_enabled)
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using MTP speculative decoding.
model_setup: (method, model_name, tp_size)
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_MLA_DISABLE", "1")
method, model_name, tp_size = model_setup
ref_llm = LLM(model=model_name,
max_model_len=2048,
tensor_parallel_size=tp_size,
trust_remote_code=True)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
spec_llm = LLM(
model=model_name,
trust_remote_code=True,
tensor_parallel_size=tp_size,
speculative_config={
"method": method,
"num_speculative_tokens": 1,
"max_model_len": 2048,
},
max_model_len=2048,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
# Heuristic: expect at least 80% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(MTP_SIMILARITY_RATE * len(ref_outputs))
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

View File

@ -0,0 +1,195 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest import mock
import pytest
import torch
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata,
create_standard_kv_cache_spec,
get_attention_backend)
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig)
from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.spec_decode.eagle import EagleProposer
mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base"
def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
"""Create an MTP proposer with unified model configuration."""
model_config = ModelConfig(model=mimo_7b_dir,
runner="generate",
max_model_len=100,
trust_remote_code=True)
speculative_config = SpeculativeConfig(
target_model_config=model_config,
target_parallel_config=ParallelConfig(),
model=mimo_7b_dir,
method="mtp",
num_speculative_tokens=num_speculative_tokens,
)
vllm_config = VllmConfig(
model_config=model_config,
cache_config=CacheConfig(),
speculative_config=speculative_config,
device_config=DeviceConfig(device=current_platform.device_type),
parallel_config=ParallelConfig(),
load_config=LoadConfig(),
scheduler_config=SchedulerConfig())
return EagleProposer(vllm_config=vllm_config,
device=current_platform.device_type)
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
def test_mtp_load_model_unified(mock_get_model, mock_get_layers,
mock_get_pp_group):
"""Test MTP-specific model loading with unified model approach."""
# Setup mocks
mock_model = mock.MagicMock()
mock_model.model.embed_tokens.weight.shape = (131072, 4096)
mock_get_model.return_value = mock_model
target_attn_layers = {"target_attn_1": mock.MagicMock()}
all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()}
mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]
mock_pp_group = mock.MagicMock()
mock_pp_group.world_size = 1
mock_get_pp_group.return_value = mock_pp_group
# Create target model
class _TargetModelStub(LlamaForCausalLM):
model: mock.MagicMock
lm_head: mock.MagicMock
target_model = mock.create_autospec(_TargetModelStub, instance=True)
target_model.model = mock.MagicMock()
target_model.model.embed_tokens.weight.shape = (131072, 4096)
target_model.lm_head = mock.MagicMock()
# Create MTP proposer
proposer = _create_mtp_proposer(num_speculative_tokens=4)
proposer.load_model(target_model)
# Verify MTP-specific behavior:
# Model is loaded
mock_get_model.assert_called_once()
# MTP shares lm_head with target model
assert proposer.model.lm_head == target_model.lm_head
# MTP shares embed_tokens with target model
assert proposer.model.model.embed_tokens == target_model.model.embed_tokens
@pytest.mark.parametrize("num_speculative_tokens", [1])
def test_mtp_propose(num_speculative_tokens, monkeypatch):
"""Test that MTP's forward method returns hidden states directly"""
device = torch.device(current_platform.device_type)
batch_size = 2
seq_lens = [5, 3]
total_tokens = sum(seq_lens)
vocab_size = 100
proposer = _create_mtp_proposer(num_speculative_tokens)
hidden_size = proposer.hidden_size
# Mock the MTP model to verify it returns hidden states directly
model_mock = mock.MagicMock()
# MTP returns hidden states directly
if num_speculative_tokens == 1:
model_mock.return_value = torch.zeros(total_tokens,
hidden_size,
device=device)
else:
# Multiple forward passes for multi-token speculation
forward_returns = []
for i in range(num_speculative_tokens):
if i == 0:
h_states = torch.zeros(total_tokens,
hidden_size,
device=device)
else:
h_states = torch.zeros(batch_size, hidden_size, device=device)
forward_returns.append(h_states)
model_mock.side_effect = forward_returns
# Mock compute_logits
def create_deterministic_logits(batch_size, vocab_size, token_offset):
logits = torch.full((batch_size, vocab_size), -100.0, device=device)
logits[:, token_offset] = 100.0
return logits
if num_speculative_tokens == 1:
model_mock.compute_logits.return_value = create_deterministic_logits(
batch_size, vocab_size, 42)
else:
logits_returns = [
create_deterministic_logits(batch_size, vocab_size, 42 + i)
for i in range(num_speculative_tokens)
]
model_mock.compute_logits.side_effect = logits_returns
proposer.model = model_mock
proposer.attn_layer_names = ["layer.0"]
# Prepare inputs
batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
common_attn_metadata = create_common_attn_metadata(batch_spec,
block_size=16,
device=device)
target_token_ids = torch.randint(0,
vocab_size, (total_tokens, ),
device=device)
target_positions = torch.cat([
torch.arange(seq_lens[0], device=device),
torch.arange(seq_lens[1], device=device)
])
target_hidden_states = torch.randn(total_tokens,
hidden_size,
device=device)
next_token_ids = torch.randint(0,
vocab_size, (batch_size, ),
dtype=torch.int32,
device=device)
sampling_metadata = mock.MagicMock()
# Setup attention metadata
attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN)
attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names,
vllm_config=proposer.vllm_config,
device=device,
)
proposer.runner = mock.MagicMock()
proposer.attn_metadata_builder = attn_metadata_builder
# Run propose
result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=None,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata)
# Verify the model was called correctly
assert model_mock.called
# Verify output shape
assert result.shape == (batch_size, num_speculative_tokens)

View File

@ -32,7 +32,9 @@ logger = init_logger(__name__)
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
"mlp_speculator", "draft_model", "deepseek_mtp",
"ernie_mtp", "qwen3_next_mtp", "mimo_mtp",
"longcat_flash_mtp"]
"longcat_flash_mtp", "mtp"]
MTP_MODEL_TYPES = ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp", "ernie_mtp",
"qwen3_next_mtp", "longcat_flash_mtp")
@config
@ -207,11 +209,16 @@ class SpeculativeConfig:
# can not be detected, it will be considered as the "draft_model" by
# default.
if self.method in MTP_MODEL_TYPES:
logger.warning("method `%s` is deprecated and replaced with mtp.",
self.method)
self.method = "mtp"
if self.model is None and self.num_speculative_tokens is not None:
# TODO(Shangming): Refactor mtp configuration logic when supporting
if (self.target_model_config
and self.target_model_config.hf_text_config.model_type
in ("deepseek_v3", "mimo", "ernie4_5_moe", "qwen3_next")):
if self.method == "mtp":
assert (
self.target_model_config
is not None), "target_model_config must be present for mtp"
# use the draft model from the same model:
self.model = self.target_model_config.model
# Align the quantization of draft model for cases such as
@ -312,31 +319,13 @@ class SpeculativeConfig:
"mlp_speculator"):
self.method = "mlp_speculator"
elif (self.draft_model_config.hf_config.model_type
in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
self.method = "deepseek_mtp"
in MTP_MODEL_TYPES):
self.method = "mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Deepseek MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type ==
"ernie_mtp"):
self.method = "ernie_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Ernie MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type ==
"qwen3_next_mtp"):
self.method = "qwen3_next_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Qwen3Next MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
"Enabling num_speculative_tokens > 1 will run" \
"multiple times of forward on same MTP layer" \
",which may result in lower acceptance rate" \
)
elif (self.draft_model_config.hf_config.model_type
in ("longcat_flash_mtp")):
@ -353,7 +342,7 @@ class SpeculativeConfig:
"Speculative decoding with draft model is not "
"supported yet. Please consider using other "
"speculative decoding methods such as ngram, medusa, "
"eagle, or deepseek_mtp.")
"eagle, or mtp.")
# Replace hf_config for EAGLE draft_model
if self.method in ("eagle", "eagle3"):
@ -562,8 +551,7 @@ class SpeculativeConfig:
return self.num_speculative_tokens
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
"qwen3_next_mtp", "longcat_flash_mtp")
return self.method in ("eagle", "eagle3", "mtp")
def __repr__(self) -> str:
method = self.method

View File

@ -1481,7 +1481,7 @@ class EngineArgs:
raise NotImplementedError(
"Draft model speculative decoding is not supported yet. "
"Please consider using other speculative decoding methods "
"such as ngram, medusa, eagle, or deepseek_mtp.")
"such as ngram, medusa, eagle, or mtp.")
V1_BACKENDS = [
"FLASH_ATTN",

View File

@ -222,8 +222,7 @@ class EagleProposer:
hidden_states=self.hidden_states[:num_input_tokens],
inputs_embeds=inputs_embeds,
)
if self.method in ("deepseek_mtp", "ernie_mtp", "qwen3_next_mtp",
"longcat_flash_mtp"):
if self.method == "mtp":
last_hidden_states = ret_hidden_states
hidden_states = last_hidden_states
else:
@ -352,8 +351,7 @@ class EagleProposer:
hidden_states=self.hidden_states[:input_batch_size],
inputs_embeds=inputs_embeds,
)
if self.method in ("deepseek_mtp", "ernie_mtp",
"qwen3_next_mtp", "longcat_flash_mtp"):
if self.method == "mtp":
last_hidden_states = ret_hidden_states
hidden_states = ret_hidden_states
else:
@ -888,10 +886,10 @@ class EagleProposer:
def _get_attention_metadata_builder(
self) -> list[AttentionMetadataBuilder]:
"""Find and return the attention metadata builders for EAGLE layers.
Returns:
The metadata builders for EAGLE layers.
Raises:
AssertionError: If no metadata builders are found for EAGLE layers.
"""