mirror of https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00

[spec decode] Consolidate speculative decode method name for MTP (#25232)
Signed-off-by: zixi-qi <qizixi@meta.com>
@@ -54,6 +54,7 @@ def parse_args():
        "--method",
        type=str,
        default="eagle",
        choices=["ngram", "eagle", "eagle3", "mtp"],
    )
    parser.add_argument("--num-spec-tokens", type=int, default=2)
    parser.add_argument("--prompt-lookup-max", type=int, default=5)
@@ -118,9 +119,9 @@ def main(args):
            "prompt_lookup_max": args.prompt_lookup_max,
            "prompt_lookup_min": args.prompt_lookup_min,
        }
-   elif args.method.endswith("mtp"):
+   elif args.method == "mtp":
        speculative_config = {
-           "method": args.method,
+           "method": "mtp",
            "num_speculative_tokens": args.num_spec_tokens,
        }
    else:
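For reference, the consolidated name is all a user has to pass through the offline API as well. A minimal sketch, assuming only what the diff shows: the checkpoint is the MiMo model used by the tests below, and the prompt and sampling values are placeholders.

from vllm import LLM, SamplingParams

# Minimal sketch: any MTP-capable checkpoint (DeepSeek-V3, MiMo, ...) is now
# requested with the single method name "mtp" rather than per-model names
# such as "deepseek_mtp" or "mimo_mtp". Model and prompt are placeholders.
llm = LLM(
    model="XiaomiMiMo/MiMo-7B-Base",
    trust_remote_code=True,
    speculative_config={
        "method": "mtp",
        "num_speculative_tokens": 2,
    },
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0, max_tokens=32))
print(outputs[0].outputs[0].text)

This dictionary is the same one the --method mtp branch of the example script above builds.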
@@ -15,6 +15,8 @@ from vllm.assets.image import VLM_IMAGES_DIR
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.platforms import current_platform

MTP_SIMILARITY_RATE = 0.8


def get_test_prompts(mm_enabled: bool):
    prompt_types = ["repeat", "sentence"]
@@ -222,3 +224,66 @@ def test_eagle_correctness(
    del spec_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()


@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
    (("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
    (("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
],
                         ids=["mimo", "deepseek"])
def test_mtp_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, int],
    mm_enabled: bool,
):
    # Generate test prompts inside the function instead of using fixture
    test_prompts = get_test_prompts(mm_enabled)
    '''
    Compare the outputs of an original LLM and a speculative LLM.
    They should be the same when using MTP speculative decoding.
    model_setup: (method, model_name, tp_size)
    '''
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        m.setenv("VLLM_MLA_DISABLE", "1")

        method, model_name, tp_size = model_setup

        ref_llm = LLM(model=model_name,
                      max_model_len=2048,
                      tensor_parallel_size=tp_size,
                      trust_remote_code=True)
        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        del ref_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

        spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            speculative_config={
                "method": method,
                "num_speculative_tokens": 1,
                "max_model_len": 2048,
            },
            max_model_len=2048,
        )
        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: expect at least 80% of the prompts to match exactly
        # Upon failure, inspect the outputs to check for inaccuracy.
        assert matches > int(MTP_SIMILARITY_RATE * len(ref_outputs))

        del spec_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()
tests/v1/spec_decode/test_mtp.py (new file, 195 lines)
@@ -0,0 +1,195 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest import mock

import pytest
import torch

from tests.v1.attention.utils import (BatchSpec, _Backend,
                                      create_common_attn_metadata,
                                      create_standard_kv_cache_spec,
                                      get_attention_backend)
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
                         VllmConfig)
from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.spec_decode.eagle import EagleProposer

mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base"


def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
    """Create an MTP proposer with unified model configuration."""
    model_config = ModelConfig(model=mimo_7b_dir,
                               runner="generate",
                               max_model_len=100,
                               trust_remote_code=True)

    speculative_config = SpeculativeConfig(
        target_model_config=model_config,
        target_parallel_config=ParallelConfig(),
        model=mimo_7b_dir,
        method="mtp",
        num_speculative_tokens=num_speculative_tokens,
    )

    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=CacheConfig(),
        speculative_config=speculative_config,
        device_config=DeviceConfig(device=current_platform.device_type),
        parallel_config=ParallelConfig(),
        load_config=LoadConfig(),
        scheduler_config=SchedulerConfig())

    return EagleProposer(vllm_config=vllm_config,
                         device=current_platform.device_type)
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
def test_mtp_load_model_unified(mock_get_model, mock_get_layers,
                                mock_get_pp_group):
    """Test MTP-specific model loading with unified model approach."""

    # Setup mocks
    mock_model = mock.MagicMock()
    mock_model.model.embed_tokens.weight.shape = (131072, 4096)
    mock_get_model.return_value = mock_model

    target_attn_layers = {"target_attn_1": mock.MagicMock()}
    all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()}
    mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]

    mock_pp_group = mock.MagicMock()
    mock_pp_group.world_size = 1
    mock_get_pp_group.return_value = mock_pp_group

    # Create target model
    class _TargetModelStub(LlamaForCausalLM):
        model: mock.MagicMock
        lm_head: mock.MagicMock

    target_model = mock.create_autospec(_TargetModelStub, instance=True)
    target_model.model = mock.MagicMock()
    target_model.model.embed_tokens.weight.shape = (131072, 4096)
    target_model.lm_head = mock.MagicMock()

    # Create MTP proposer
    proposer = _create_mtp_proposer(num_speculative_tokens=4)
    proposer.load_model(target_model)

    # Verify MTP-specific behavior:
    # Model is loaded
    mock_get_model.assert_called_once()
    # MTP shares lm_head with target model
    assert proposer.model.lm_head == target_model.lm_head
    # MTP shares embed_tokens with target model
    assert proposer.model.model.embed_tokens == target_model.model.embed_tokens
@pytest.mark.parametrize("num_speculative_tokens", [1])
def test_mtp_propose(num_speculative_tokens, monkeypatch):
    """Test that MTP's forward method returns hidden states directly."""

    device = torch.device(current_platform.device_type)
    batch_size = 2
    seq_lens = [5, 3]
    total_tokens = sum(seq_lens)
    vocab_size = 100

    proposer = _create_mtp_proposer(num_speculative_tokens)
    hidden_size = proposer.hidden_size

    # Mock the MTP model to verify it returns hidden states directly
    model_mock = mock.MagicMock()

    # MTP returns hidden states directly
    if num_speculative_tokens == 1:
        model_mock.return_value = torch.zeros(total_tokens,
                                              hidden_size,
                                              device=device)
    else:
        # Multiple forward passes for multi-token speculation
        forward_returns = []
        for i in range(num_speculative_tokens):
            if i == 0:
                h_states = torch.zeros(total_tokens,
                                       hidden_size,
                                       device=device)
            else:
                h_states = torch.zeros(batch_size, hidden_size, device=device)
            forward_returns.append(h_states)
        model_mock.side_effect = forward_returns

    # Mock compute_logits
    def create_deterministic_logits(batch_size, vocab_size, token_offset):
        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
        logits[:, token_offset] = 100.0
        return logits

    if num_speculative_tokens == 1:
        model_mock.compute_logits.return_value = create_deterministic_logits(
            batch_size, vocab_size, 42)
    else:
        logits_returns = [
            create_deterministic_logits(batch_size, vocab_size, 42 + i)
            for i in range(num_speculative_tokens)
        ]
        model_mock.compute_logits.side_effect = logits_returns

    proposer.model = model_mock
    proposer.attn_layer_names = ["layer.0"]

    # Prepare inputs
    batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
    common_attn_metadata = create_common_attn_metadata(batch_spec,
                                                       block_size=16,
                                                       device=device)

    target_token_ids = torch.randint(0,
                                     vocab_size, (total_tokens, ),
                                     device=device)
    target_positions = torch.cat([
        torch.arange(seq_lens[0], device=device),
        torch.arange(seq_lens[1], device=device)
    ])
    target_hidden_states = torch.randn(total_tokens,
                                       hidden_size,
                                       device=device)
    next_token_ids = torch.randint(0,
                                   vocab_size, (batch_size, ),
                                   dtype=torch.int32,
                                   device=device)
    sampling_metadata = mock.MagicMock()

    # Setup attention metadata
    attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN)

    attn_metadata_builder = attn_metadata_builder_cls(
        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
        layer_names=proposer.attn_layer_names,
        vllm_config=proposer.vllm_config,
        device=device,
    )

    proposer.runner = mock.MagicMock()
    proposer.attn_metadata_builder = attn_metadata_builder

    # Run propose
    result = proposer.propose(target_token_ids=target_token_ids,
                              target_positions=target_positions,
                              target_hidden_states=target_hidden_states,
                              next_token_ids=next_token_ids,
                              last_token_indices=None,
                              common_attn_metadata=common_attn_metadata,
                              sampling_metadata=sampling_metadata)

    # Verify the model was called correctly
    assert model_mock.called
    # Verify output shape
    assert result.shape == (batch_size, num_speculative_tokens)
@@ -32,7 +32,9 @@ logger = init_logger(__name__)
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
                            "mlp_speculator", "draft_model", "deepseek_mtp",
                            "ernie_mtp", "qwen3_next_mtp", "mimo_mtp",
-                           "longcat_flash_mtp"]
+                           "longcat_flash_mtp", "mtp"]
+MTP_MODEL_TYPES = ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp", "ernie_mtp",
+                   "qwen3_next_mtp", "longcat_flash_mtp")


@config
@@ -207,11 +209,16 @@ class SpeculativeConfig:
        # can not be detected, it will be considered as the "draft_model" by
        # default.

+       if self.method in MTP_MODEL_TYPES:
+           logger.warning("method `%s` is deprecated and replaced with mtp.",
+                          self.method)
+           self.method = "mtp"
+
        if self.model is None and self.num_speculative_tokens is not None:
            # TODO(Shangming): Refactor mtp configuration logic when supporting
-           if (self.target_model_config
-                   and self.target_model_config.hf_text_config.model_type
-                   in ("deepseek_v3", "mimo", "ernie4_5_moe", "qwen3_next")):
+           if self.method == "mtp":
+               assert (
+                   self.target_model_config
+                   is not None), "target_model_config must be present for mtp"
                # use the draft model from the same model:
                self.model = self.target_model_config.model
                # Align the quantization of draft model for cases such as
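The added block above is the whole deprecation story: any legacy per-model MTP method name is rewritten to the consolidated "mtp" with a warning. A standalone sketch of that normalization, using only the names shown in the hunk (this is not the actual SpeculativeConfig code path):

import logging

logger = logging.getLogger(__name__)

# Same tuple as MTP_MODEL_TYPES in the hunk above.
MTP_MODEL_TYPES = ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp", "ernie_mtp",
                   "qwen3_next_mtp", "longcat_flash_mtp")


def normalize_spec_method(method: str) -> str:
    """Map a deprecated per-model MTP name onto the consolidated "mtp"."""
    if method in MTP_MODEL_TYPES:
        logger.warning("method `%s` is deprecated and replaced with mtp.",
                       method)
        return "mtp"
    return method


assert normalize_spec_method("deepseek_mtp") == "mtp"
assert normalize_spec_method("eagle3") == "eagle3"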
@@ -312,31 +319,13 @@ class SpeculativeConfig:
                       "mlp_speculator"):
                    self.method = "mlp_speculator"
                elif (self.draft_model_config.hf_config.model_type
-                     in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
-                   self.method = "deepseek_mtp"
+                     in MTP_MODEL_TYPES):
+                   self.method = "mtp"
                    if self.num_speculative_tokens > 1:
                        logger.warning(
-                           "All Deepseek MTP models only have " \
-                           "one layer. Might need some code changes " \
-                           "to support multiple layers."
-                       )
-               elif (self.draft_model_config.hf_config.model_type ==
-                       "ernie_mtp"):
-                   self.method = "ernie_mtp"
-                   if self.num_speculative_tokens > 1:
-                       logger.warning(
-                           "All Ernie MTP models only have " \
-                           "one layer. Might need some code changes " \
-                           "to support multiple layers."
-                       )
-               elif (self.draft_model_config.hf_config.model_type ==
-                       "qwen3_next_mtp"):
-                   self.method = "qwen3_next_mtp"
-                   if self.num_speculative_tokens > 1:
-                       logger.warning(
-                           "All Qwen3Next MTP models only have " \
-                           "one layer. Might need some code changes " \
-                           "to support multiple layers."
+                           "Enabling num_speculative_tokens > 1 will run "
+                           "the same MTP layer multiple times, which may "
+                           "result in a lower acceptance rate."
                        )
                elif (self.draft_model_config.hf_config.model_type
                      in ("longcat_flash_mtp")):
@@ -353,7 +342,7 @@ class SpeculativeConfig:
                "Speculative decoding with draft model is not "
                "supported yet. Please consider using other "
                "speculative decoding methods such as ngram, medusa, "
-               "eagle, or deepseek_mtp.")
+               "eagle, or mtp.")

        # Replace hf_config for EAGLE draft_model
        if self.method in ("eagle", "eagle3"):
@@ -562,8 +551,7 @@ class SpeculativeConfig:
        return self.num_speculative_tokens

    def use_eagle(self) -> bool:
-       return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
-                              "qwen3_next_mtp", "longcat_flash_mtp")
+       return self.method in ("eagle", "eagle3", "mtp")

    def __repr__(self) -> str:
        method = self.method
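One consequence of this one-line change is worth spelling out: use_eagle() now reports True for "mtp", so MTP drafting is driven by the same EagleProposer path as EAGLE and EAGLE3, which is exactly what the new tests/v1/spec_decode/test_mtp.py exercises by building an EagleProposer with method="mtp". A trivial sketch of the consolidated check (illustrative only, not the vLLM source):

# Methods that are served by the EAGLE proposer path after this change.
EAGLE_LIKE_METHODS = ("eagle", "eagle3", "mtp")


def uses_eagle_proposer(method: str) -> bool:
    # Before the consolidation, this check had to enumerate deepseek_mtp,
    # ernie_mtp, qwen3_next_mtp and longcat_flash_mtp individually.
    return method in EAGLE_LIKE_METHODS


assert uses_eagle_proposer("mtp")
assert not uses_eagle_proposer("ngram")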
@@ -1481,7 +1481,7 @@ class EngineArgs:
            raise NotImplementedError(
                "Draft model speculative decoding is not supported yet. "
                "Please consider using other speculative decoding methods "
-               "such as ngram, medusa, eagle, or deepseek_mtp.")
+               "such as ngram, medusa, eagle, or mtp.")

V1_BACKENDS = [
    "FLASH_ATTN",
@@ -222,8 +222,7 @@ class EagleProposer:
                hidden_states=self.hidden_states[:num_input_tokens],
                inputs_embeds=inputs_embeds,
            )
-           if self.method in ("deepseek_mtp", "ernie_mtp", "qwen3_next_mtp",
-                              "longcat_flash_mtp"):
+           if self.method == "mtp":
                last_hidden_states = ret_hidden_states
                hidden_states = last_hidden_states
            else:
@@ -352,8 +351,7 @@ class EagleProposer:
                    hidden_states=self.hidden_states[:input_batch_size],
                    inputs_embeds=inputs_embeds,
                )
-               if self.method in ("deepseek_mtp", "ernie_mtp",
-                                  "qwen3_next_mtp", "longcat_flash_mtp"):
+               if self.method == "mtp":
                    last_hidden_states = ret_hidden_states
                    hidden_states = ret_hidden_states
                else:
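Both hunks above rely on the same convention: an MTP draft model's forward() returns a single hidden-states tensor, whereas EAGLE-style drafts return a pair that has to be unpacked; the new test_mtp_propose above mocks exactly this single-tensor return. A minimal sketch of the normalization, with the helper name assumed for illustration rather than taken from the vLLM source:

import torch


def unpack_draft_output(method: str, ret_hidden_states):
    """Normalize the draft model's forward() output.

    Assumption for this sketch: "mtp" returns one tensor of shape
    [num_tokens, hidden_size]; EAGLE-style drafts return a
    (last_hidden_states, hidden_states) tuple.
    """
    if method == "mtp":
        last_hidden_states = ret_hidden_states
        hidden_states = ret_hidden_states
    else:
        last_hidden_states, hidden_states = ret_hidden_states
    return last_hidden_states, hidden_states


# Example with a fake MTP output for 8 tokens and hidden size 16.
last, hidden = unpack_draft_output("mtp", torch.zeros(8, 16))
assert last is hidden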
@@ -888,10 +886,10 @@ class EagleProposer:
    def _get_attention_metadata_builder(
            self) -> list[AttentionMetadataBuilder]:
        """Find and return the attention metadata builders for EAGLE layers.

        Returns:
            The metadata builders for EAGLE layers.

        Raises:
            AssertionError: If no metadata builders are found for EAGLE layers.
        """