[V1][Spec Decode] Support multi-layer eagle draft model (#18030)
Signed-off-by: qizixi <qizixi@meta.com>
@@ -246,6 +246,9 @@ def test_propose(num_speculative_tokens):
     # Assign the mock to the proposer
     proposer.model = model_mock
 
+    # Assign draft attn_layer_names since load_model is not invoked
+    proposer.attn_layer_names = ["layer.0"]
+
     # Create input tensors
     cu_num_tokens = torch.tensor([0, seq_len_1, total_tokens],
                                  dtype=torch.int32,

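Note: with attn_layer_names stored as a plain list on the proposer, the same test setup extends to a multi-layer draft model. A minimal sketch, assuming a hypothetical two-layer draft; the layer names and the mock fixture below are invented for illustration and are not part of this change:

# Sketch only: a proposer stub configured with two hypothetical draft layers.
from unittest import mock

proposer = mock.MagicMock()  # stand-in for the real EagleProposer test fixture
proposer.attn_layer_names = ["layer.0", "layer.1"]  # invented layer names

# Anything that builds per-layer attention metadata can iterate over both names.
per_layer = {name: "shared_metadata" for name in proposer.attn_layer_names}
assert set(per_layer) == {"layer.0", "layer.1"}
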
@@ -12,6 +12,7 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
                                                    FlashAttentionMetadata)
+from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel
 

@@ -150,6 +151,11 @@ class EagleProposer:
         else:
             raise ValueError(f"Unsupported method: {self.method}")
 
+        # At this moment, we assume all eagle layers belong to the same KV
+        # cache group, thus using the same attention metadata.
+        per_layer_attn_metadata = {}
+        for layer_name in self.attn_layer_names:
+            per_layer_attn_metadata[layer_name] = attn_metadata
         if self.use_cuda_graph and \
                 num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)

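Note: the loop above is the core of multi-layer support in propose(): every draft attention layer is mapped to the same attention metadata object, so a draft model with several layers still runs under one forward context. A minimal sketch of that invariant, with invented layer names and a placeholder object standing in for the real metadata:

# Sketch only: all draft layers share a single metadata object.
attn_layer_names = ["draft.layers.0.attn", "draft.layers.1.attn"]  # invented
shared_metadata = object()  # stand-in for the real FlashAttentionMetadata

per_layer_attn_metadata = {name: shared_metadata for name in attn_layer_names}
assert all(m is shared_metadata for m in per_layer_attn_metadata.values())
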
@@ -159,7 +165,7 @@ class EagleProposer:
         self.positions[:num_tokens] = target_positions
         self.hidden_states[:num_tokens] = target_hidden_states
 
-        with set_forward_context(attn_metadata,
+        with set_forward_context(per_layer_attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
             ret_hidden_states = self.model(

@@ -245,7 +251,7 @@ class EagleProposer:
             self.hidden_states[:batch_size] = hidden_states
 
             # Run the model.
-            with set_forward_context(attn_metadata,
+            with set_forward_context(per_layer_attn_metadata,
                                      self.vllm_config,
                                      num_tokens=input_batch_size):
                 last_hidden_states, hidden_states = self.model(

@@ -318,8 +324,8 @@ class EagleProposer:
         draft_attn_layer_names = (
             get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
             target_attn_layer_names)
-        assert len(draft_attn_layer_names) == 1
-        self.attn_layer_name = next(iter(draft_attn_layer_names))
+
+        self.attn_layer_names = list(draft_attn_layer_names)
 
         # share embed_tokens with the target model if needed
         if get_pp_group().world_size == 1:

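Note: dropping the single-layer assertion is what enables multi-layer drafts; the draft layer names are simply whatever Attention layers exist in the vLLM config but not in the target model, however many there are. An illustrative sketch with invented layer names (the real names come from get_layers_from_vllm_config):

# Sketch only: draft layer names as a set difference over invented names.
all_attn_layers = {
    "model.layers.0.self_attn.attn",    # target model
    "model.layers.1.self_attn.attn",    # target model
    "draft.layers.0.self_attn.attn",    # eagle draft layer 0
    "draft.layers.1.self_attn.attn",    # eagle draft layer 1
}
target_attn_layer_names = {
    "model.layers.0.self_attn.attn",
    "model.layers.1.self_attn.attn",
}

attn_layer_names = list(all_attn_layers - target_attn_layer_names)
assert len(attn_layer_names) == 2  # no longer restricted to exactly one layer
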
@@ -355,6 +361,25 @@ class EagleProposer:
             self.hidden_states[:num_tokens],
         )
 
+    def validate_same_kv_cache_group(self,
+                                     kv_cache_config: KVCacheConfig) -> None:
+        """
+        Validate that all eagle layers belong to the same KVCacheGroup.
+        Need this assumption to ensure all eagle layers can use the
+        same AttentionMetadata.
+        May extend to multiple AttentionMetadata in the future.
+        """
+        kv_cache_groups: dict[str, int] = {}
+        for id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
+            for layer_name in kv_cache_group.layer_names:
+                kv_cache_groups[layer_name] = id
+        assert len(
+            set([
+                kv_cache_groups[layer_name]
+                for layer_name in self.attn_layer_names
+            ])
+        ) == 1, "All eagle layers should belong to the same kv cache group"
+
 
 # NOTE(woosuk): Currently, the below code is not used and we always use argmax
 # to sample the draft tokens. We will use this after we find a way to manage

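Note: the validation reduces to checking that every eagle layer maps to a single KV cache group index. Its behaviour can be illustrated with toy stand-ins for the KV cache config; the dataclasses below are simplified placeholders, not the real types from vllm.v1.kv_cache_interface:

# Sketch only: toy stand-ins showing what validate_same_kv_cache_group asserts.
from dataclasses import dataclass


@dataclass
class ToyKVCacheGroup:
    layer_names: list[str]


@dataclass
class ToyKVCacheConfig:
    kv_cache_groups: list[ToyKVCacheGroup]


config = ToyKVCacheConfig(kv_cache_groups=[
    ToyKVCacheGroup(layer_names=["draft.0", "draft.1"]),  # one shared group
    ToyKVCacheGroup(layer_names=["target.0"]),
])
eagle_layers = ["draft.0", "draft.1"]

layer_to_group = {
    name: gid
    for gid, group in enumerate(config.kv_cache_groups)
    for name in group.layer_names
}
# Passes here; splitting the eagle layers across two groups would fail,
# mirroring the assertion message in the method above.
assert len({layer_to_group[name] for name in eagle_layers}) == 1
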
@@ -1360,11 +1360,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                scheduler_output.num_scheduled_tokens[req_id])
                     next_token_id = req_state.get_token_id(seq_len)
                 next_token_ids.append(next_token_id)
-            next_token_ids = async_tensor_h2d(next_token_ids,
-                                              dtype=torch.int32,
-                                              target_device=self.device,
-                                              pin_memory=True)
-            eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
+            next_token_ids = torch.tensor(next_token_ids,
+                                          dtype=torch.int32,
+                                          device=self.device)
+            # At this moment, we assume all eagle layers belong to the same KV
+            # cache group, thus using the same attention metadata.
+            eagle_attn_metadata = attn_metadata[
+                self.drafter.attn_layer_names[0]]
 
             # NOTE: deepseek_mtp uses MLA which does not have `block_table`
             if hasattr(eagle_attn_metadata, "block_table"):

@@ -2018,6 +2020,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 # KV cache specs.
                 raise ValueError("Unknown KV cache spec type.")
 
+        if self.speculative_config and self.speculative_config.use_eagle():
+            assert isinstance(self.drafter, EagleProposer)
+            # validate all draft model layers belong to the same kv cache
+            # group
+            self.drafter.validate_same_kv_cache_group(kv_cache_config)
+
         bind_kv_cache(
             kv_caches,
             self.vllm_config.compilation_config.static_forward_context,

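Note: the runner only invokes the validator for eagle-style speculative decoding, and it does so after load_model() has populated drafter.attn_layer_names and once the KV cache groups are known. A minimal sketch of that ordering, using placeholder objects rather than the real runner and config types:

# Sketch only: the ordering the runner relies on (placeholder objects).
class ToyDrafter:
    def __init__(self) -> None:
        self.attn_layer_names: list[str] = []

    def load_model(self) -> None:
        # In the real proposer these names are derived from the vLLM config.
        self.attn_layer_names = ["draft.0", "draft.1"]

    def validate_same_kv_cache_group(self, kv_cache_config) -> None:
        assert self.attn_layer_names, "load_model() must run before validation"


drafter = ToyDrafter()
drafter.load_model()                        # 1. discover the draft layers
drafter.validate_same_kv_cache_group(None)  # 2. validate at KV cache init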