[Spec decode] automatically disable mm for text-only draft models (#25667)

Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
This commit is contained in:
Jonas M. Kübler
2025-09-27 02:10:21 +02:00
committed by GitHub
parent 4e33a7ea85
commit 6f5c0931c1
2 changed files with 78 additions and 62 deletions

View File

@ -8,7 +8,7 @@ from typing import Any, Union
import pytest
import torch
from tests.utils import get_attn_backend_list_based_on_platform
from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
@ -88,69 +88,66 @@ def test_ngram_correctness(
Compare the outputs of an original LLM and a speculative LLM
should be the same when using ngram speculative decoding.
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
test_prompts = get_test_prompts(mm_enabled=False)
test_prompts = get_test_prompts(mm_enabled=False)
ref_llm = LLM(model=model_name, max_model_len=1024)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
ref_llm = LLM(model=model_name, max_model_len=1024)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
spec_llm = LLM(
model=model_name,
speculative_config={
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": 3,
},
max_model_len=1024,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
spec_llm = LLM(
model=model_name,
speculative_config={
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": 3,
},
max_model_len=1024,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches >= int(0.66 * len(ref_outputs))
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches >= int(0.66 * len(ref_outputs))
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
pytest.param(
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
False,
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
pytest.param(
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
True,
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
(("eagle", "eagle618/deepseek-v3-random",
"eagle618/eagle-deepseek-v3-random", 1), False),
],
ids=[
"qwen3_eagle3", "llama3_eagle", "llama3_eagle3",
"llama4_eagle", "llama4_eagle_mm",
"deepseek_eagle"
])
@pytest.mark.parametrize(
["model_setup", "mm_enabled"],
[
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
False,
marks=large_gpu_mark(min_gb=80)), # works on 4x H100
pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
True,
marks=large_gpu_mark(min_gb=80)), # works on 4x H100
(("eagle", "eagle618/deepseek-v3-random",
"eagle618/eagle-deepseek-v3-random", 1), False),
],
ids=[
"qwen3_eagle3", "llama3_eagle", "llama3_eagle3", "llama4_eagle",
"llama4_eagle_mm", "deepseek_eagle"
])
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
def test_eagle_correctness(
@ -174,9 +171,14 @@ def test_eagle_correctness(
model_setup: (method, model_name, eagle_model_name, tp_size)
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_MLA_DISABLE", "1")
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
# Scout requires default backend selection
# because vision encoder has head_dim 88 being incompatible
# with FLASH_ATTN and needs to fall back to Flex Attn
pass
else:
m.setenv("VLLM_MLA_DISABLE", "1")
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()):
pytest.skip("TRITON_ATTN does not support "

View File

@ -804,6 +804,20 @@ class EagleProposer:
self.attn_layer_names = list(draft_attn_layer_names)
if self.is_multimodal_model:
# Even if the target model is multimodal, we can also use
# text-only draft models
try:
dummy_input_ids = torch.tensor([[1]],
device=self.input_ids.device)
self.model.get_input_embeddings(dummy_input_ids,
multimodal_embeddings=None)
except (NotImplementedError, AttributeError, TypeError):
logger.warning(
"Draft model does not support multimodal inputs, "
"falling back to text-only mode")
self.is_multimodal_model = False
if supports_multimodal(target_model):
# handle multimodality
self.model.config.image_token_index = (