[V1][Spec Decode] Support multi-layer eagle draft model (#18030)

Signed-off-by: qizixi <qizixi@meta.com>
Author: qizixi
Date: 2025-05-24 02:45:34 -07:00 (committed by GitHub)
Parent: a859320575
Commit: c1e4a4052d

3 changed files with 45 additions and 9 deletions


@@ -246,6 +246,9 @@ def test_propose(num_speculative_tokens):
     # Assign the mock to the proposer
     proposer.model = model_mock

+    # Assign draft attn_layer_names since load_model is not invoked
+    proposer.attn_layer_names = ["layer.0"]
+
     # Create input tensors
     cu_num_tokens = torch.tensor([0, seq_len_1, total_tokens],
                                  dtype=torch.int32,
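Since the unit test never calls load_model, the new plural attn_layer_names attribute has to be stubbed by hand. Below is a minimal standalone sketch (not part of the diff) of the same idea with a hypothetical two-layer draft; FakeProposer and the layer names are illustrative only.

from unittest.mock import MagicMock


class FakeProposer:
    """Illustrative stand-in for a proposer in a unit test."""

    def __init__(self, attn_layer_names):
        # load_model() normally discovers these names from the model config;
        # a test that skips load_model() must assign them directly.
        self.attn_layer_names = attn_layer_names
        self.model = MagicMock()


proposer = FakeProposer(["layer.0", "layer.1"])
assert len(proposer.attn_layer_names) == 2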


@@ -12,6 +12,7 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
                                                     FlashAttentionMetadata)
+from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel
@@ -150,6 +151,11 @@ class EagleProposer:
         else:
             raise ValueError(f"Unsupported method: {self.method}")

+        # At this moment, we assume all eagle layers belong to the same KV
+        # cache group, thus using the same attention metadata.
+        per_layer_attn_metadata = {}
+        for layer_name in self.attn_layer_names:
+            per_layer_attn_metadata[layer_name] = attn_metadata
         if self.use_cuda_graph and \
                 num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
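The loop above fans one attention metadata object out to every draft layer name. Below is a standalone sketch (not part of the diff) of that mapping, using a placeholder AttnMeta class instead of FlashAttentionMetadata; the names are illustrative, not vLLM API.

from dataclasses import dataclass


@dataclass
class AttnMeta:
    """Placeholder for the real attention metadata object."""
    num_actual_tokens: int


def build_per_layer_metadata(attn_layer_names, attn_metadata):
    # Every eagle layer shares the same metadata object because all draft
    # layers are assumed to live in the same KV cache group.
    return {name: attn_metadata for name in attn_layer_names}


meta = AttnMeta(num_actual_tokens=128)
per_layer = build_per_layer_metadata(["layer.0", "layer.1"], meta)
assert per_layer["layer.0"] is per_layer["layer.1"]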
@@ -159,7 +165,7 @@ class EagleProposer:
         self.positions[:num_tokens] = target_positions
         self.hidden_states[:num_tokens] = target_hidden_states

-        with set_forward_context(attn_metadata,
+        with set_forward_context(per_layer_attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
             ret_hidden_states = self.model(
@@ -245,7 +251,7 @@ class EagleProposer:
             self.hidden_states[:batch_size] = hidden_states

             # Run the model.
-            with set_forward_context(attn_metadata,
+            with set_forward_context(per_layer_attn_metadata,
                                      self.vllm_config,
                                      num_tokens=input_batch_size):
                 last_hidden_states, hidden_states = self.model(
@@ -318,8 +324,8 @@ class EagleProposer:
         draft_attn_layer_names = (
             get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
             target_attn_layer_names)
-        assert len(draft_attn_layer_names) == 1
-        self.attn_layer_name = next(iter(draft_attn_layer_names))
+
+        self.attn_layer_names = list(draft_attn_layer_names)

         # share embed_tokens with the target model if needed
         if get_pp_group().world_size == 1:
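With multiple draft layers, the old single-layer assertion no longer holds, so load_model now keeps the whole set difference between all attention layers and the target model's layers. Below is a standalone sketch (not part of the diff) of that discovery step with hypothetical layer names; the real code takes the keys from get_layers_from_vllm_config(self.vllm_config, Attention).

def draft_layer_names(all_layer_names, target_layer_names):
    # Whatever is not a target-model layer must belong to the draft model;
    # with a multi-layer eagle head this can now contain several entries.
    return sorted(set(all_layer_names) - set(target_layer_names))


all_layers = {
    "model.layers.0.attn",
    "model.layers.1.attn",
    "draft.layers.0.attn",
    "draft.layers.1.attn",
}
targets = {"model.layers.0.attn", "model.layers.1.attn"}
assert draft_layer_names(all_layers, targets) == [
    "draft.layers.0.attn", "draft.layers.1.attn"
]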
@@ -355,6 +361,25 @@ class EagleProposer:
             self.hidden_states[:num_tokens],
         )

+    def validate_same_kv_cache_group(self,
+                                     kv_cache_config: KVCacheConfig) -> None:
+        """
+        Validate that all eagle layers belong to the same KVCacheGroup.
+        Need this assumption to ensure all eagle layers can use the
+        same AttentionMetadata.
+        May extend to multiple AttentionMetadata in the future.
+        """
+        kv_cache_groups: dict[str, int] = {}
+        for id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
+            for layer_name in kv_cache_group.layer_names:
+                kv_cache_groups[layer_name] = id
+        assert len(
+            set([
+                kv_cache_groups[layer_name]
+                for layer_name in self.attn_layer_names
+            ])
+        ) == 1, "All eagle layers should belong to the same kv cache group"
+

 # NOTE(woosuk): Currently, the below code is not used and we always use argmax
 # to sample the draft tokens. We will use this after we find a way to manage
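validate_same_kv_cache_group builds a layer-name-to-group-id map and asserts that every draft layer lands in a single group, which is what justifies sharing one AttentionMetadata across all eagle layers. Below is a standalone sketch (not part of the diff) of the same check over plain lists instead of KVCacheConfig; the function and layer names are illustrative.

def all_in_same_group(kv_cache_groups, draft_layer_names):
    # kv_cache_groups: one list of layer names per KV cache group.
    layer_to_group = {
        name: group_id
        for group_id, names in enumerate(kv_cache_groups)
        for name in names
    }
    return len({layer_to_group[name] for name in draft_layer_names}) == 1


groups = [["draft.0.attn", "draft.1.attn"], ["model.0.attn"]]
assert all_in_same_group(groups, ["draft.0.attn", "draft.1.attn"])
assert not all_in_same_group(groups, ["draft.0.attn", "model.0.attn"])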


@@ -1360,11 +1360,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                scheduler_output.num_scheduled_tokens[req_id])
                     next_token_id = req_state.get_token_id(seq_len)
                 next_token_ids.append(next_token_id)
-            next_token_ids = async_tensor_h2d(next_token_ids,
-                                              dtype=torch.int32,
-                                              target_device=self.device,
-                                              pin_memory=True)
-            eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
+            next_token_ids = torch.tensor(next_token_ids,
+                                          dtype=torch.int32,
+                                          device=self.device)
+            # At this moment, we assume all eagle layers belong to the same KV
+            # cache group, thus using the same attention metadata.
+            eagle_attn_metadata = attn_metadata[
+                self.drafter.attn_layer_names[0]]

             # NOTE: deepseek_mtp uses MLA which does not have `block_table`
             if hasattr(eagle_attn_metadata, "block_table"):
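The runner-side change above builds next_token_ids with a direct torch.tensor(...) on the target device and indexes the per-layer metadata by the first draft layer name, which is safe only because all draft layers are validated to share one KV cache group. Below is a standalone sketch (not part of the diff) under those assumptions; the layer names and the shared_meta placeholder are illustrative.

import torch

attn_layer_names = ["draft.0.attn", "draft.1.attn"]
shared_meta = object()  # placeholder for the shared attention metadata
attn_metadata = {name: shared_meta for name in attn_layer_names}

device = "cuda" if torch.cuda.is_available() else "cpu"
next_token_ids = torch.tensor([11, 42, 7], dtype=torch.int32, device=device)

# Any draft layer name selects the same metadata object; the first one is
# used by convention.
eagle_attn_metadata = attn_metadata[attn_layer_names[0]]
assert eagle_attn_metadata is attn_metadata[attn_layer_names[-1]]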
@@ -2018,6 +2020,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 # KV cache specs.
                 raise ValueError("Unknown KV cache spec type.")

+        if self.speculative_config and self.speculative_config.use_eagle():
+            assert isinstance(self.drafter, EagleProposer)
+            # validate all draft model layers belong to the same kv cache
+            # group
+            self.drafter.validate_same_kv_cache_group(kv_cache_config)
+
         bind_kv_cache(
             kv_caches,
             self.vllm_config.compilation_config.static_forward_context,