# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest import mock

import pytest
import torch

from tests.v1.attention.utils import (BatchSpec, _Backend,
                                      create_common_attn_metadata,
                                      create_standard_kv_cache_spec,
                                      get_attention_backend)
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
                         VllmConfig)
from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.spec_decode.eagle import EagleProposer

mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base"


def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
    """Create an MTP proposer with unified model configuration."""
    model_config = ModelConfig(model=mimo_7b_dir,
                               runner="generate",
                               max_model_len=100,
                               trust_remote_code=True)

    speculative_config = SpeculativeConfig(
        target_model_config=model_config,
        target_parallel_config=ParallelConfig(),
        model=mimo_7b_dir,
        method="mtp",
        num_speculative_tokens=num_speculative_tokens,
    )

    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=CacheConfig(),
        speculative_config=speculative_config,
        device_config=DeviceConfig(device=current_platform.device_type),
        parallel_config=ParallelConfig(),
        load_config=LoadConfig(),
        scheduler_config=SchedulerConfig())

    return EagleProposer(vllm_config=vllm_config,
                         device=current_platform.device_type)


@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
def test_mtp_load_model_unified(mock_get_model, mock_get_layers,
                                mock_get_pp_group):
    """Test MTP-specific model loading with unified model approach."""

    # Setup mocks
    mock_model = mock.MagicMock()
    mock_model.model.embed_tokens.weight.shape = (131072, 4096)
    mock_get_model.return_value = mock_model

    target_attn_layers = {"target_attn_1": mock.MagicMock()}
    all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()}
    target_indexer_layers: dict = {}
    all_indexer_layers: dict = {}

    mock_get_layers.side_effect = [
        target_attn_layers, target_indexer_layers, all_attn_layers,
        all_indexer_layers
    ]

    mock_pp_group = mock.MagicMock()
    mock_pp_group.world_size = 1
    mock_get_pp_group.return_value = mock_pp_group

    # Create target model
    class _TargetModelStub(LlamaForCausalLM):
        model: mock.MagicMock
        lm_head: mock.MagicMock

    target_model = mock.create_autospec(_TargetModelStub, instance=True)
    target_model.model = mock.MagicMock()
    target_model.model.embed_tokens.weight.shape = (131072, 4096)
    target_model.lm_head = mock.MagicMock()

    # Create MTP proposer
    proposer = _create_mtp_proposer(num_speculative_tokens=4)
    proposer.load_model(target_model)

    # Verify MTP-specific behavior:
    # Model is loaded
    mock_get_model.assert_called_once()
    # MTP shares lm_head with target model
    assert proposer.model.lm_head == target_model.lm_head
    # MTP shares embed_tokens with target model
    assert proposer.model.model.embed_tokens == target_model.model.embed_tokens


@pytest.mark.parametrize("num_speculative_tokens", [1])
def test_mtp_propose(num_speculative_tokens, monkeypatch):
    """Test that MTP's forward method returns hidden states directly"""
    device = torch.device(current_platform.device_type)
    batch_size = 2
    seq_lens = [5, 3]
    total_tokens = sum(seq_lens)
    vocab_size = 100

    proposer = _create_mtp_proposer(num_speculative_tokens)
    hidden_size = proposer.hidden_size

    # Mock the MTP model to verify it returns hidden states directly
    model_mock = mock.MagicMock()

    # MTP returns hidden states directly
    if num_speculative_tokens == 1:
        model_mock.return_value = torch.zeros(total_tokens,
                                              hidden_size,
                                              device=device)
    else:
        # Multiple forward passes for multi-token speculation
        forward_returns = []
        for i in range(num_speculative_tokens):
            if i == 0:
                h_states = torch.zeros(total_tokens,
                                       hidden_size,
                                       device=device)
            else:
                h_states = torch.zeros(batch_size, hidden_size, device=device)
            forward_returns.append(h_states)
        model_mock.side_effect = forward_returns

    # Mock compute_logits
    def create_deterministic_logits(batch_size, vocab_size, token_offset):
        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
        logits[:, token_offset] = 100.0
        return logits

    if num_speculative_tokens == 1:
        model_mock.compute_logits.return_value = create_deterministic_logits(
            batch_size, vocab_size, 42)
    else:
        logits_returns = [
            create_deterministic_logits(batch_size, vocab_size, 42 + i)
            for i in range(num_speculative_tokens)
        ]
        model_mock.compute_logits.side_effect = logits_returns

    proposer.model = model_mock
    proposer.attn_layer_names = ["layer.0"]

    # Prepare inputs
    batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
    common_attn_metadata = create_common_attn_metadata(batch_spec,
                                                       block_size=16,
                                                       device=device)

    target_token_ids = torch.randint(0,
                                     vocab_size, (total_tokens, ),
                                     device=device)
    target_positions = torch.cat([
        torch.arange(seq_lens[0], device=device),
        torch.arange(seq_lens[1], device=device)
    ])
    target_hidden_states = torch.randn(total_tokens,
                                       hidden_size,
                                       device=device)
    next_token_ids = torch.randint(0,
                                   vocab_size, (batch_size, ),
                                   dtype=torch.int32,
                                   device=device)
    sampling_metadata = mock.MagicMock()

    # Setup attention metadata
    attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN)

    attn_metadata_builder = attn_metadata_builder_cls(
        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
        layer_names=proposer.attn_layer_names,
        vllm_config=proposer.vllm_config,
        device=device,
    )

    proposer.runner = mock.MagicMock()
    proposer.attn_metadata_builder = attn_metadata_builder

    # Run propose
    result = proposer.propose(target_token_ids=target_token_ids,
                              target_positions=target_positions,
                              target_hidden_states=target_hidden_states,
                              next_token_ids=next_token_ids,
                              last_token_indices=None,
                              common_attn_metadata=common_attn_metadata,
                              sampling_metadata=sampling_metadata)

    # Verify the model was called correctly
    assert model_mock.called

    # Verify output shape
    assert result.shape == (batch_size, num_speculative_tokens)
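

# Usage sketch (assumptions, not part of the suite): these tests expect an
# accelerator, since the proposer and every tensor above are created on
# current_platform.device_type and the attention metadata builder is resolved
# for the FLASH_ATTN backend. The module path below is an assumption about
# where this file lives in the vLLM repo:
#
#   pytest tests/v1/spec_decode/test_mtp.py -v
#
# Widening the parametrize list in test_mtp_propose, e.g. to [1, 3], would
# also exercise the multi-token side_effect branches mocked above.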