Add tree attention backend for v1 (part 1) (#20401)

Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>
Author: Giancarlo Delfin
Date: 2025-08-03 22:13:26 -07:00
Committed by: GitHub
parent c2e75b3c11
commit aa7012eb6d
12 changed files with 1098 additions and 25 deletions

View File

@ -17,7 +17,7 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec
BACKENDS_TO_TEST = [
_Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1,
_Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN
]
# Remove flashinfer from the list if it's not available

View File

@ -109,11 +109,11 @@ def create_common_attn_metadata(
def get_attention_backend(backend_name: _Backend):
"""Set up attention backend classes for testing.
Args:
backend_name: Name of the backend ("flash_attn", "flashinfer", etc.)
vllm_config: VllmConfig instance
Returns:
Tuple of (backend_builder_class, backend_impl_class)
"""
@ -126,6 +126,8 @@ def get_attention_backend(backend_name: _Backend):
"vllm.v1.attention.backends.flex_attention.FlexAttentionBackend", "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",
_Backend.TRITON_ATTN_VLLM_V1: _Backend.TRITON_ATTN_VLLM_V1:
"vllm.v1.attention.backends.triton_attn.TritonAttentionBackend", "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",
_Backend.TREE_ATTN:
"vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
}
if backend_name not in backend_map:

View File

@ -202,7 +202,9 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8]) @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(num_speculative_tokens): @pytest.mark.parametrize("backend",
[_Backend.FLASH_ATTN_VLLM_V1, _Backend.TREE_ATTN])
def test_propose(num_speculative_tokens, backend):
# Use GPU device # Use GPU device
device = torch.device(current_platform.device_type) device = torch.device(current_platform.device_type)
@ -301,8 +303,7 @@ def test_propose(num_speculative_tokens):
device=device)
sampling_metadata = mock.MagicMock()
attn_metadata_builder_cls, _ = get_attention_backend(backend)
attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names,

View File

@ -0,0 +1,299 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from typing import Optional
import torch
from tests.v1.attention.utils import (_Backend, create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend)
from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
class MockAttentionLayer(torch.nn.Module):
_q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
_k_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
_v_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
def __init__(self):
super().__init__()
def forward(self, x):
return x
def forward_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kv_cache: torch.Tensor,
block_table: torch.Tensor,
slot_mapping: torch.Tensor,
seqlen_k: int,
backend: _Backend,
spec_token_tree: Optional[str] = None,
num_spec_tokens: int = 0,
) -> torch.Tensor:
batch_size, q_len, num_heads, dim_per_head = q.shape
num_kv_heads = k.shape[-2]
# Initialize the query and KV sequence lengths.
query_start_loc = q_len * torch.arange(
batch_size + 1, device=q.device, dtype=torch.int32)
query_lens = torch.diff(query_start_loc)
seq_lens = torch.full(
(batch_size, ),
seqlen_k,
device=q.device,
dtype=torch.int32,
)
context_lens = seq_lens - query_lens
max_query_len = q_len
num_actual_tokens = query_start_loc[-1]
softmax_scale = q.shape[-1]**(-0.5)
layer = MockAttentionLayer()
# Build common metadata.
model_name = "meta-llama/Meta-Llama-3-8B"
builder_cls, impl_cls = get_attention_backend(backend)
vllm_config = create_vllm_config(model_name=model_name,
max_model_len=max(seq_lens))
if spec_token_tree is not None:
# Create speculative config if token tree is specified.
vllm_config.speculative_config = SpeculativeConfig(
target_model_config=vllm_config.model_config,
target_parallel_config=ParallelConfig(),
model=model_name,
method="eagle",
num_speculative_tokens=num_spec_tokens,
speculative_token_tree=spec_token_tree)
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
builder = builder_cls(kv_cache_spec, [], vllm_config, q.device)
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc.cpu(),
seq_lens=seq_lens,
seq_lens_cpu=seq_lens.cpu(),
num_computed_tokens_cpu=context_lens.cpu(),
num_reqs=batch_size,
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
block_table_tensor=block_table,
slot_mapping=slot_mapping,
)
# Build attention metadata.
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
# Initialize the backend implementation.
instance = impl_cls(
num_heads=num_heads,
head_size=dim_per_head,
scale=softmax_scale,
num_kv_heads=num_kv_heads,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype="auto",
)
# Run forward pass and return output.
query = q.view(-1, num_heads, dim_per_head)
key = k.view(-1, num_kv_heads, dim_per_head)
value = v.view(-1, num_kv_heads, dim_per_head)
output = torch.empty_like(query)
return instance.forward(
layer=layer,
query=query,
key=key,
value=value,
kv_cache=kv_cache.clone(),
attn_metadata=attn_metadata,
output=output,
)
def test_tree_attn_correctness() -> None:
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
device = "cuda"
tree_attn_masks = {
# Chain.
"[(0,), (0, 0), (0, 0, 0)]":
torch.tensor(
[
[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0],
[1, 1, 1, 1],
],
device=device,
dtype=torch.int32,
),
# Tree.
"[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]":
torch.tensor(
[
[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 0, 1, 0, 0, 0, 0],
[1, 1, 0, 1, 0, 0, 0],
[1, 1, 0, 0, 1, 0, 0],
[1, 0, 1, 0, 0, 1, 0],
[1, 0, 1, 0, 0, 0, 1],
],
device=device,
dtype=torch.int32,
),
}
dim_per_head = 128
num_kv_heads = 2
block_size = 128
max_sequence_length = 8192
randomize_blocks = True
for batch_size in [1, 16, 32]:
for num_heads in [2, 4]:
for sequence_position in [16, 1024, 2048]:
for spec_token_tree, tree_attn_mask in tree_attn_masks.items():
# Assert that the number of heads is divisible
# by the number of KV heads.
assert num_heads % num_kv_heads == 0
# Initialize q, k, and v.
tree_size_q = tree_attn_mask.shape[0]
seqlen_k = sequence_position + tree_size_q
q = torch.randn(
(batch_size, tree_size_q, num_heads, dim_per_head),
device=device,
dtype=torch.bfloat16,
)
k = torch.randn(
(batch_size, tree_size_q, num_kv_heads, dim_per_head),
device=device,
dtype=torch.bfloat16,
)
v = torch.randn(
(batch_size, tree_size_q, num_kv_heads, dim_per_head),
device=device,
dtype=torch.bfloat16,
)
# Setup the block table and KV cache for paged KV.
assert max_sequence_length % block_size == 0
max_blocks_per_batch = max_sequence_length // block_size
kv_cache = torch.randn(
(
2,
batch_size * max_blocks_per_batch,
block_size,
num_kv_heads,
dim_per_head,
),
device=q.device,
dtype=torch.bfloat16,
)
num_alloc_blocks_per_batch = math.ceil(seqlen_k /
block_size)
block_table = torch.zeros(
(batch_size, max_blocks_per_batch),
device=q.device,
dtype=torch.int32,
)
block_ids = torch.arange(
0,
batch_size * num_alloc_blocks_per_batch,
device=q.device,
dtype=torch.int32,
)
if randomize_blocks:
# Randomize the block ids.
block_ids = block_ids[torch.randperm(
block_ids.numel())]
block_table[:, :
num_alloc_blocks_per_batch] = block_ids.view(
-1, num_alloc_blocks_per_batch)
# Setup the slot mapping for the input KVs.
tree_positions = sequence_position + torch.arange(
0,
tree_size_q,
device=q.device,
dtype=torch.int64,
).repeat(batch_size, 1)
tree_slot_mapping = _gen_slot_mapping(
tree_positions, block_table, block_size)
# Compute attention for the tree.
tree_attn_output = forward_attention(
q=q,
k=k,
v=v,
kv_cache=kv_cache,
block_table=block_table,
slot_mapping=tree_slot_mapping,
seqlen_k=seqlen_k,
backend=_Backend.TREE_ATTN,
spec_token_tree=spec_token_tree,
num_spec_tokens=tree_size_q - 1,
).view(batch_size, -1, num_heads, dim_per_head)
# Verify that the chain attention output for each
# branch of the tree (computed using FA3) matches
# the tree attention output.
for q_index in range(tree_size_q):
# Get the q, k, and v for the branch.
branch_mask = tree_attn_mask[q_index, :]
branch_indices = torch.nonzero(branch_mask,
as_tuple=True)[0]
q_len = branch_indices.shape[0]
q_branch = q[:, branch_indices]
k_branch = k[:, branch_indices]
v_branch = v[:, branch_indices]
# Setup slot mapping for the branch.
branch_positions = sequence_position + torch.arange(
0,
q_len,
device=q.device,
dtype=torch.int64,
).repeat(batch_size, 1)
branch_slot_mapping = _gen_slot_mapping(
branch_positions, block_table, block_size)
# Compute flash attention for the branch.
flash_attn_output = forward_attention(
q=q_branch,
k=k_branch,
v=v_branch,
kv_cache=kv_cache,
block_table=block_table,
slot_mapping=branch_slot_mapping,
seqlen_k=sequence_position + q_len,
backend=_Backend.FLASH_ATTN_VLLM_V1,
).view(batch_size, -1, num_heads, dim_per_head)
# Compare the outputs.
assert torch.allclose(
tree_attn_output[:, branch_indices],
flash_attn_output,
atol=7.81e-3,
), (f"outputs are not close for "
f"batch_size: {batch_size}, "
f"num_heads: {num_heads}, "
f"sequence_position: {sequence_position}, "
f"tree_attn_mask: {tree_attn_mask}, "
f"q_index: {q_index}.")
def _gen_slot_mapping(positions: torch.Tensor, block_table: torch.Tensor,
block_size: int):
block_indices = positions // block_size
blocks = block_table.gather(dim=1, index=block_indices)
return (blocks * block_size + positions % block_size).view(-1)
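For reference, a small worked example of the slot-mapping helper above; the numbers are illustrative and only serve to show how a token position is translated into a physical KV-cache slot through the block table.

import torch

# Illustrative values: one request, block_size=4.
block_size = 4
positions = torch.tensor([[5, 6, 7]])       # token positions in the sequence
block_table = torch.tensor([[9, 2, 7, 0]])  # logical block -> physical block

block_indices = positions // block_size                  # [[1, 1, 1]]
blocks = block_table.gather(dim=1, index=block_indices)  # [[2, 2, 2]]
slots = (blocks * block_size + positions % block_size).view(-1)
print(slots)  # tensor([ 9, 10, 11]): slot = physical_block * block_size + offset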

View File

@ -55,6 +55,7 @@ def kernel_unified_attention_2d(
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads]
qq_bias_ptr, # [num_query_tokens, num_query_tokens]
scale, # float32
k_scale, # float32
v_scale, # float32
@ -66,10 +67,12 @@ def kernel_unified_attention_2d(
query_stride_1: tl.int64, # int, should be equal to head_size
output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size
qq_bias_stride_0: tl.int64, # int
BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool
USE_QQ_BIAS: tl.constexpr, # bool
USE_SOFTCAP: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int
stride_k_cache_0: tl.int64, # int
@ -144,6 +147,11 @@ def kernel_unified_attention_2d(
mask=query_mask_1,
other=0.0)
# query-query attention bias
if USE_QQ_BIAS:
qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
) # shape: [BLOCK_M]
# compute the length of the longest sequence prefix spanned by any
# query token in the current q_block (q_block_local_idx)
max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
@ -223,6 +231,18 @@ def kernel_unified_attention_2d(
if USE_ALIBI_SLOPES:
S += alibi_slope[:, None] * (seq_offset - context_len)
if USE_QQ_BIAS:
# compute key positions relative to query section
key_rel_pos = seq_offset - context_len # shape: [BLOCK_SIZE]
# load bias only for keys that correspond to queries
is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0
qq_bias = tl.load(
qq_bias_row_ptrs + key_rel_pos[None, :],
mask=is_query_key[None, :], # avoid OOB for context keys
other=0.0,
)
S += qq_bias
# compute running maximum
# m_j : (BLOCK_M,)
m_j = tl.maximum(M, tl.max(S, axis=1))
@ -275,6 +295,7 @@ def kernel_unified_attention_3d(
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads]
qq_bias_ptr, # [num_query_tokens, num_query_tokens]
scale, # float32
k_scale, # float32
v_scale, # float32
@ -284,10 +305,12 @@ def kernel_unified_attention_3d(
block_table_stride: tl.int64, # int
query_stride_0: tl.int64, # int
query_stride_1: tl.int64, # int, should be equal to head_size
qq_bias_stride_0: tl.int64, # int
BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool
USE_QQ_BIAS: tl.constexpr, # bool
USE_SOFTCAP: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int
stride_k_cache_0: tl.int64, # int
@ -373,6 +396,11 @@ def kernel_unified_attention_3d(
mask=query_mask_1,
other=0.0)
# query-query attention bias
if USE_QQ_BIAS:
qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
) # shape: [BLOCK_M]
num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
# iterate through tiles within current segment
@ -442,6 +470,18 @@ def kernel_unified_attention_3d(
if USE_ALIBI_SLOPES:
S += alibi_slope[:, None] * (seq_offset - context_len)
if USE_QQ_BIAS:
# compute key positions relative to query section
key_rel_pos = seq_offset - context_len # shape: [BLOCK_SIZE]
# load bias only for keys that correspond to queries
is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0
qq_bias = tl.load(
qq_bias_row_ptrs + key_rel_pos[None, :],
mask=is_query_key[None, :], # avoid OOB for context keys
other=0.0,
)
S += qq_bias
# compute running maximum
# m_j : (BLOCK_M,)
m_j = tl.maximum(M, tl.max(S, axis=1))
@ -586,6 +626,7 @@ def unified_attention(
k_descale,
v_descale,
alibi_slopes=None,
qq_bias=None,
):
assert causal, "Only causal attention is supported"
assert q_descale is None, "Q scales not supported"
@ -595,6 +636,7 @@ def unified_attention(
"Block size must be at least 32 for fp8" "Block size must be at least 32 for fp8"
use_alibi_slopes = alibi_slopes is not None use_alibi_slopes = alibi_slopes is not None
use_qq_bias = qq_bias is not None
block_size = v.shape[1] block_size = v.shape[1]
num_seqs = len(seqused_k) num_seqs = len(seqused_k)
@ -630,6 +672,7 @@ def unified_attention(
block_tables_ptr=block_table,
seq_lens_ptr=seqused_k,
alibi_slopes_ptr=alibi_slopes,
qq_bias_ptr=qq_bias,
scale=softmax_scale,
k_scale=k_descale,
v_scale=v_descale,
@ -641,10 +684,12 @@ def unified_attention(
query_stride_1=q.stride(1),
output_stride_0=out.stride(0),
output_stride_1=out.stride(1),
qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
BLOCK_SIZE=block_size,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
USE_QQ_BIAS=use_qq_bias,
USE_SOFTCAP=(softcap > 0),
SLIDING_WINDOW=(1 + window_size[0]),
stride_k_cache_0=k.stride(0),
@ -699,6 +744,7 @@ def unified_attention(
block_tables_ptr=block_table,
seq_lens_ptr=seqused_k,
alibi_slopes_ptr=alibi_slopes,
qq_bias_ptr=qq_bias,
scale=softmax_scale,
k_scale=k_descale,
v_scale=v_descale,
@ -708,10 +754,12 @@ def unified_attention(
block_table_stride=block_table.stride(0),
query_stride_0=q.stride(0),
query_stride_1=q.stride(1),
qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
BLOCK_SIZE=block_size,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
USE_QQ_BIAS=use_qq_bias,
USE_SOFTCAP=(softcap > 0),
SLIDING_WINDOW=(1 + window_size[0]),
stride_k_cache_0=k.stride(0),
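To make the new qq_bias path easier to follow, here is a rough PyTorch sketch of the score computation it implements (names and shapes are illustrative, not the kernel's API): the bias matrix is indexed by key position relative to the start of the query section, so only query-to-query score entries receive a bias, while keys from the cached context are left untouched.

import torch

def scores_with_qq_bias(q, k, context_len, qq_bias, scale):
    # q: [num_query_tokens, head_dim], k: [context_len + num_query_tokens, head_dim]
    # qq_bias: [num_query_tokens, num_query_tokens], e.g. 0 for visible pairs
    # and -inf for masked pairs of a tree attention mask.
    scores = scale * (q @ k.t())
    key_rel_pos = torch.arange(k.shape[0]) - context_len
    is_query_key = (key_rel_pos >= 0) & (key_rel_pos < qq_bias.shape[1])
    bias = torch.zeros_like(scores)
    # Only keys that are themselves query tokens pick up a bias entry.
    bias[:, is_query_key] = qq_bias[:, key_rel_pos[is_query_key]]
    return scores + bias

# Toy example: 2 context tokens followed by 3 query (draft) tokens.
neg_inf = float("-inf")
qq_bias = torch.tensor([[0.0, neg_inf, neg_inf],
                        [0.0, 0.0, neg_inf],
                        [0.0, neg_inf, 0.0]])
out = scores_with_qq_bias(torch.randn(3, 8), torch.randn(5, 8),
                          context_len=2, qq_bias=qq_bias, scale=8 ** -0.5)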

View File

@ -3049,6 +3049,19 @@ class SpeculativeConfig:
f"num_speculative_tokens:{self.num_speculative_tokens}" f"num_speculative_tokens:{self.num_speculative_tokens}"
f" must be divisible by {n_predict=}") f" must be divisible by {n_predict=}")
if self.speculative_token_tree is None:
# Generate chain of tokens.
self.speculative_token_tree = str([
(i + 1) * (0, )
for i in range(self.num_speculative_tokens)
])
else:
# Sort the token tree breadth-first.
tree_choices = ast.literal_eval(
self.speculative_token_tree)
self.speculative_token_tree = str(
sorted(tree_choices, key=lambda t: (len(t), t)))
self.draft_tensor_parallel_size = \
SpeculativeConfig._verify_and_get_draft_tp(
self.target_parallel_config,
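For illustration, a standalone sketch of the normalization performed above: when no tree is given, num_speculative_tokens is expanded into a simple chain, and an explicitly provided tree is re-serialized in breadth-first order (shorter paths first, ties broken lexicographically).

import ast

def normalize_token_tree(spec_token_tree, num_speculative_tokens):
    # Simplified re-statement of the SpeculativeConfig logic above.
    if spec_token_tree is None:
        # Chain of tokens: [(0,), (0, 0), (0, 0, 0), ...]
        return str([(i + 1) * (0,) for i in range(num_speculative_tokens)])
    tree_choices = ast.literal_eval(spec_token_tree)
    return str(sorted(tree_choices, key=lambda t: (len(t), t)))

print(normalize_token_tree(None, 3))
# [(0,), (0, 0), (0, 0, 0)]
print(normalize_token_tree("[(0, 0), (1,), (0,), (0, 1)]", 4))
# [(0,), (1,), (0, 0), (0, 1)]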

View File

@ -1454,7 +1454,6 @@ class EngineArgs:
"Please consider using other speculative decoding methods " "Please consider using other speculative decoding methods "
"such as ngram, medusa, eagle, or deepseek_mtp.") "such as ngram, medusa, eagle, or deepseek_mtp.")
# No XFormers so far.
V1_BACKENDS = [ V1_BACKENDS = [
"FLASH_ATTN_VLLM_V1", "FLASH_ATTN_VLLM_V1",
"FLASH_ATTN", "FLASH_ATTN",
@ -1469,6 +1468,7 @@ class EngineArgs:
"ROCM_AITER_MLA", "ROCM_AITER_MLA",
"TORCH_SDPA_VLLM_V1", "TORCH_SDPA_VLLM_V1",
"FLEX_ATTENTION", "FLEX_ATTENTION",
"TREE_ATTN",
] ]
if (envs.is_set("VLLM_ATTENTION_BACKEND") if (envs.is_set("VLLM_ATTENTION_BACKEND")
and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):

View File

@ -270,6 +270,7 @@ class CudaPlatformBase(Platform):
FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501
if selected_backend == _Backend.FLASHINFER:
logger.info_once("Using FlashInfer backend on V1 engine.")
@ -287,6 +288,9 @@ class CudaPlatformBase(Platform):
elif selected_backend == _Backend.FLASH_ATTN:
logger.info_once("Using Flash Attention backend on V1 engine.")
return FLASH_ATTN_V1
elif selected_backend == _Backend.TREE_ATTN:
logger.info_once("Using Tree Attention backend on V1 engine.")
return TREE_ATTN_V1
from vllm.attention.selector import is_attn_backend_supported
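A hypothetical end-to-end invocation, only to show where the new backend plugs in: VLLM_ATTENTION_BACKEND is the selector validated in arg_utils above, and the speculative_config keys mirror the SpeculativeConfig fields touched by this commit. The draft model path and the exact set of accepted keys are assumptions and may differ from what a given vLLM release expects.

import os

# Select the new V1 tree attention backend before engine construction.
os.environ["VLLM_ATTENTION_BACKEND"] = "TREE_ATTN"

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    speculative_config={
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",  # assumed EAGLE draft model
        "num_speculative_tokens": 6,
        # Two children at the root, two grandchildren under each child.
        "speculative_token_tree": "[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]",
    },
)
outputs = llm.generate(["Tree attention speculative decoding test"],
                       SamplingParams(max_tokens=32))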

View File

@ -62,6 +62,7 @@ class _Backend(enum.Enum):
DIFFERENTIAL_FLASH_ATTN = enum.auto()
NO_ATTENTION = enum.auto()
FLEX_ATTENTION = enum.auto()
TREE_ATTN = enum.auto()
class PlatformEnum(enum.Enum):

View File

@ -0,0 +1,452 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with TreeAttention."""
import ast
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm import _custom_ops as ops
logger = init_logger(__name__)
class TreeAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes.")
@staticmethod
def get_name() -> str:
return "TREE_ATTN_VLLM_V1"
@staticmethod
def get_impl_cls() -> type["TreeAttentionImpl"]:
return TreeAttentionImpl
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
return TreeAttentionMetadata
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def get_builder_cls() -> type["TreeAttentionMetadataBuilder"]:
return TreeAttentionMetadataBuilder
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
@dataclass
class TreeAttentionMetadata:
num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
num_prefill_tokens: int = 0
num_decode_tokens: int = 0
num_prefills: int = 0
num_decodes: int = 0
tree_attn_bias: Optional[torch.Tensor] = None
# Cached Prefill/decode metadata.
_cached_prefill_metadata: Optional["TreeAttentionMetadata"] = None
_cached_decode_metadata: Optional["TreeAttentionMetadata"] = None
@property
def prefill_metadata(self) -> Optional["TreeAttentionMetadata"]:
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
# Recover cached prefill-phase attention
# metadata structure
return self._cached_prefill_metadata
q_start_loc = self.query_start_loc[self.num_decodes:]
q_seqlens = torch.diff(q_start_loc)
kv_seqlens = self.seq_lens[self.num_decodes:]
# Construct & cache prefill-phase attention metadata structure
self._cached_prefill_metadata = TreeAttentionMetadata(
num_actual_tokens=self.num_prefill_tokens,
max_query_len=int(q_seqlens.max().item()),
query_start_loc=q_start_loc - q_start_loc[0],
max_seq_len=int(kv_seqlens.max().item()),
seq_lens=kv_seqlens,
block_table=self.block_table[self.num_decodes:],
slot_mapping=self.slot_mapping[self.num_decode_tokens:],
)
return self._cached_prefill_metadata
@property
def decode_metadata(self) -> Optional["TreeAttentionMetadata"]:
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
# Recover cached decode-phase attention
# metadata structure
return self._cached_decode_metadata
q_start_loc = self.query_start_loc[:self.num_decodes + 1]
q_seqlens = torch.diff(q_start_loc)
kv_seqlens = self.seq_lens[:self.num_decodes]
# Construct & cache decode-phase attention metadata structure
self._cached_decode_metadata = TreeAttentionMetadata(
num_actual_tokens=self.num_decode_tokens,
max_query_len=int(q_seqlens.max().item()),
query_start_loc=q_start_loc,
max_seq_len=int(kv_seqlens.max().item()),
seq_lens=kv_seqlens,
block_table=self.block_table[:self.num_decodes],
slot_mapping=self.slot_mapping[:self.num_decode_tokens],
tree_attn_bias=self.tree_attn_bias,
)
return self._cached_decode_metadata
class TreeAttentionMetadataBuilder(
AttentionMetadataBuilder[TreeAttentionMetadata]):
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
self.kv_cache_spec = kv_cache_spec
self.block_size = kv_cache_spec.block_size
spec_config = vllm_config.speculative_config
spec_token_tree = (spec := spec_config) and spec.speculative_token_tree
tree_choices: list[tuple[int,
...]] = (ast.literal_eval(spec_token_tree)
if spec_token_tree is not None else
[(0, )])
# Construct the tree attention bias.
depth_counts = _get_depth_counts(tree_choices)
self.tree_attn_bias = _prepare_tree_attn_bias(
tree_choices,
depth_counts,
dtype=torch.float32,
device=device,
)
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
return reorder_batch_to_split_decodes_and_prefills(
input_batch,
scheduler_output,
decode_threshold=self.tree_attn_bias.shape[0])
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> TreeAttentionMetadata:
decode_threshold = self.tree_attn_bias.shape[0]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=decode_threshold))
num_actual_tokens = common_attn_metadata.num_actual_tokens
q_start_loc = common_attn_metadata.query_start_loc
max_query_len = common_attn_metadata.max_query_len
kv_seqlens = common_attn_metadata.seq_lens
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
block_table = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
return TreeAttentionMetadata(
num_actual_tokens=num_actual_tokens,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
num_decodes=num_decodes,
max_query_len=max_query_len,
query_start_loc=q_start_loc,
max_seq_len=max_seq_len,
seq_lens=kv_seqlens,
block_table=block_table,
slot_mapping=slot_mapping,
tree_attn_bias=self.tree_attn_bias,
)
def build_for_drafting(
self,
common_attn_metadata: CommonAttentionMetadata,
draft_index: int,
) -> TreeAttentionMetadata:
# Cache the original tree attention bias.
orig_tree_attn_bias = self.tree_attn_bias
if draft_index == 0:
# Use prefill for drafting at the root level.
self.tree_attn_bias = torch.empty(0)
else:
# Slice the tree attention bias for drafting.
query_len = common_attn_metadata.max_query_len
start, end = draft_index, draft_index + query_len
self.tree_attn_bias = self.tree_attn_bias[start:end,
start:end].contiguous()
# Build attention bias.
attn_metadata = self.build(0, common_attn_metadata, fast_build=True)
# Reset the tree attention bias to the original value.
self.tree_attn_bias = orig_tree_attn_bias
return attn_metadata
def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]:
# Count the number of choices at each depth of the tree.
depth_counts = []
prev_depth = 0
for path in sorted_tree_choices:
depth = len(path)
if depth != prev_depth:
depth_counts.append(0)
depth_counts[depth - 1] += 1
prev_depth = depth
return depth_counts
def _prepare_tree_attn_bias(
sorted_tree_choices: list[tuple[int, ...]],
depth_counts: list[int],
dtype: Optional[torch.dtype],
device: Optional[torch.device],
) -> torch.Tensor:
# +1 comes from the additional root node.
tree_len = len(sorted_tree_choices) + 1
tree_attn_mask = torch.full((tree_len, tree_len),
-torch.inf,
device=device,
dtype=dtype)
# Set diagonal to all zeros. Each token should
# attend to itself.
mask_val = 0
for i in range(tree_len):
tree_attn_mask[i, i] = mask_val
# Set root to all zeros. All tokens attend to it.
tree_attn_mask[:, 0] = mask_val
# Set all ancestors to zeros.
start = 0
for i in range(len(depth_counts)):
for j in range(depth_counts[i]):
cur_tree_choice = sorted_tree_choices[start + j]
# Retrieve ancestor position.
if len(cur_tree_choice) == 1:
continue
ancestor_idx = []
for c in range(len(cur_tree_choice) - 1):
ancestor_idx.append(
sorted_tree_choices.index(cur_tree_choice[:c + 1]) + 1)
tree_attn_mask[j + start + 1, ancestor_idx] = mask_val
start += depth_counts[i]
return tree_attn_mask
class TreeAttentionImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
) -> None:
if blocksparse_params is not None:
raise ValueError(
"TreeAttention does not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
if logits_soft_cap is None:
# Setting logits_soft_cap to 0 means no soft cap.
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
if sliding_window is None:
self.sliding_window = (-1, -1)
else:
self.sliding_window = (sliding_window - 1, 0)
TreeAttentionBackend.validate_head_size(head_size)
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"TreeAttentionImpl.")
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: TreeAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with TreeAttention.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for TreeAttentionImpl")
if attn_metadata is None:
# Profiling run.
return output
# Cache the input KVs.
key_cache, value_cache = kv_cache.unbind(0)
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
num_actual_tokens = attn_metadata.num_actual_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
descale_shape = (attn_metadata.query_start_loc.shape[0] - 1,
key.shape[1])
if prefill_meta := attn_metadata.prefill_metadata:
unified_attention(
q=query[num_decode_tokens:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[num_decode_tokens:num_actual_tokens],
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
seqused_k=prefill_meta.seq_lens,
max_seqlen_k=prefill_meta.max_seq_len,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=prefill_meta.block_table,
softcap=self.logits_soft_cap,
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
if decode_meta := attn_metadata.decode_metadata:
unified_attention(
q=query[:num_decode_tokens],
k=key_cache,
v=value_cache,
out=output[:num_decode_tokens],
cu_seqlens_q=decode_meta.query_start_loc,
max_seqlen_q=decode_meta.max_query_len,
seqused_k=decode_meta.seq_lens,
max_seqlen_k=decode_meta.max_seq_len,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
qq_bias=decode_meta.tree_attn_bias,
window_size=self.sliding_window,
block_table=decode_meta.block_table,
softcap=self.logits_soft_cap,
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
return output
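As a sanity check on the helpers above, a short sketch (assuming the module path introduced by this commit) that builds the additive bias for the two-branch tree used in the tests; zero entries mark positions a draft token may attend to (itself, the root, and its ancestors), everything else is -inf.

import torch
from vllm.v1.attention.backends.tree_attn import (_get_depth_counts,
                                                  _prepare_tree_attn_bias)

tree_choices = [(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]
depth_counts = _get_depth_counts(tree_choices)  # [2, 4]
bias = _prepare_tree_attn_bias(tree_choices, depth_counts,
                               dtype=torch.float32, device="cpu")
# bias is 7x7 (root + 6 draft positions); (bias == 0) reproduces the 0/1
# tree mask asserted against FlashAttention in test_tree_attention.py.
print((bias == 0).int())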

View File

@ -214,6 +214,26 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
return self.build(common_prefix_len=0,
common_attn_metadata=common_attn_metadata)
def build_for_drafting(
self,
common_attn_metadata: CommonAttentionMetadata,
draft_index: int,
) -> M:
"""
Build attention metadata for draft model. Uses build by default.
Args:
common_attn_metadata: The common attention metadata.
draft_index: The index of the current draft operation.
When speculating a chain of tokens, this index refers to the
draft attempt for the i-th token.
For tree-based attention, this index instead refers to the
draft attempt for the i-th level in the tree of tokens.
"""
return self.build(common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
fast_build=True)
def use_cascade_attention(
self,
common_prefix_len: int,

View File

@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
from dataclasses import replace
from typing import Optional
import numpy as np
@ -17,6 +19,8 @@ from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.utils import is_pin_memory_available
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
TreeAttentionMetadataBuilder)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
@ -74,18 +78,52 @@ class EagleProposer:
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
device=device)
max_batch_size = vllm_config.scheduler_config.max_num_seqs
self.arange = torch.arange(
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
max_batch_size + 1,
device=device,
dtype=torch.int32,
)
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
device=device)
# Parse the speculative token tree.
spec_token_tree = self.speculative_config.speculative_token_tree
self.tree_choices: list[tuple[int,
...]] = ast.literal_eval(spec_token_tree)
tree_depth = len(self.tree_choices[-1])
# Precompute per-level properties of the tree.
num_drafts_per_level = [0] * tree_depth
for node in self.tree_choices:
num_drafts_per_level[len(node) - 1] += 1
self.cu_drafts_per_level = [num_drafts_per_level[0]]
self.child_drafts_per_level = [num_drafts_per_level[0]]
for level in range(1, tree_depth):
self.cu_drafts_per_level.append(self.cu_drafts_per_level[-1] +
num_drafts_per_level[level])
self.child_drafts_per_level.append(num_drafts_per_level[level] //
num_drafts_per_level[level - 1])
# Find the first level where the tree branches off into one or more
# children.
self.first_branching_level = None
for level in range(tree_depth):
if self.cu_drafts_per_level[level] > level + 1:
self.first_branching_level = level
break
# Precompute draft position offsets in flattened tree.
self.tree_draft_pos_offsets = torch.arange(
1,
len(self.tree_choices) + 1,
device=device,
dtype=torch.int32,
).repeat(max_batch_size, 1)
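To make the per-level bookkeeping above concrete, here is what those fields evaluate to for the two-branch example tree, recomputed with the same loop on plain Python lists (values hand-checked):

tree_choices = [(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]
tree_depth = len(tree_choices[-1])            # 2 levels
num_drafts_per_level = [0] * tree_depth
for node in tree_choices:
    num_drafts_per_level[len(node) - 1] += 1  # [2, 4]

cu_drafts_per_level = [num_drafts_per_level[0]]     # cumulative drafts: [2, 6]
child_drafts_per_level = [num_drafts_per_level[0]]  # children per parent: [2, 2]
for level in range(1, tree_depth):
    cu_drafts_per_level.append(cu_drafts_per_level[-1] +
                               num_drafts_per_level[level])
    child_drafts_per_level.append(num_drafts_per_level[level] //
                                  num_drafts_per_level[level - 1])

# cu_drafts_per_level[0] == 2 > 0 + 1, so first_branching_level is 0 and
# drafting switches to propose_tree() right after the first forward pass.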
def propose(
self,
# [num_tokens]
@ -120,11 +158,9 @@ class EagleProposer:
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_metadata_builders[
0].build_for_drafting(common_attn_metadata=common_attn_metadata,
draft_index=0)
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
@ -167,6 +203,22 @@ class EagleProposer:
last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
positions = target_positions[last_token_indices]
hidden_states = hidden_states[last_token_indices]
if self.first_branching_level == 0:
# Branching has occurred at the root level. Draft using tree
# attention.
draft_token_ids_list = self.propose_tree(
tree_root_level=0,
batch_size=batch_size,
logits=logits,
positions=positions,
hidden_states=hidden_states,
common_attn_metadata=common_attn_metadata,
)
# [batch_size, num_tree_tokens]
return torch.cat(draft_token_ids_list, dim=1)
draft_token_ids = logits.argmax(dim=-1)
# Early exit if there is only one draft token to be generated.
@ -178,16 +230,15 @@ class EagleProposer:
# one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module.
# Currently, only FlashAttention and TreeAttention support multi-token
# eagle spec decode. This is because the code below
# makes assumptions about attn_metadata attributes available.
assert isinstance(attn_metadata,
(FlashAttentionMetadata, TreeAttentionMetadata))
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
if self.use_cuda_graph and \
batch_size <= self.cudagraph_batch_sizes[-1]:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
@ -196,7 +247,7 @@ class EagleProposer:
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
for token_index in range(self.num_speculative_tokens - 1):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
@ -265,7 +316,20 @@ class EagleProposer:
logits = self.model.compute_logits(last_hidden_states[:batch_size],
None)
if self.first_branching_level == token_index + 1:
# Branching has occurred. The remaining tokens are drafted
# using tree attention.
draft_token_ids_list += self.propose_tree(
tree_root_level=token_index + 1,
batch_size=batch_size,
logits=logits,
positions=positions,
hidden_states=hidden_states,
common_attn_metadata=common_attn_metadata,
)
# [batch_size, num_tree_tokens]
return torch.cat(draft_token_ids_list, dim=1)
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids)
@ -273,6 +337,175 @@ class EagleProposer:
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids
def propose_tree(
self,
tree_root_level: int,
batch_size: int,
# [num_tokens, vocab_size]
logits: torch.Tensor,
# [num_tokens]
positions: torch.Tensor,
# [num_tokens, hidden_size]
hidden_states: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata,
) -> list[torch.Tensor]:
tree_attn_metadata_builder = self.runner.attn_metadata_builders[0]
assert isinstance(tree_attn_metadata_builder,
TreeAttentionMetadataBuilder)
total_num_drafts = self.cu_drafts_per_level[tree_root_level]
level_num_drafts = total_num_drafts
# Sample a draft token for each child at the tree root level.
num_children = self.child_drafts_per_level[tree_root_level]
if num_children == 1:
draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
else:
draft_token_ids = torch.topk(logits, num_children,
dim=-1).indices.view(batch_size, -1)
draft_token_ids_list = [draft_token_ids]
draft_hidden_states = hidden_states.view(batch_size, 1, -1)
# Initialize empty tensors for concatenation with the level outputs.
tree_input_ids = torch.empty(0,
device=self.input_ids.device,
dtype=self.input_ids.dtype)
tree_positions = torch.empty(0,
device=self.positions.device,
dtype=self.positions.dtype)
tree_hidden_states = torch.empty(0,
device=self.hidden_states.device,
dtype=self.hidden_states.dtype)
# Precompute the draft token positions.
flattened_draft_positions = (
positions.view(batch_size, -1) +
self.tree_draft_pos_offsets[:batch_size, :])
tree_depth = len(self.cu_drafts_per_level)
for level in range(tree_root_level, tree_depth - 1):
# Get draft positions for RoPE.
draft_positions = positions + (level + 1)
exceeds_max_model_len = (positions +
total_num_drafts) >= self.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_draft_positions = torch.where(
exceeds_max_model_len,
0,
draft_positions,
)
if level_num_drafts > 1:
# Repeat the positions for each draft at this level.
draft_positions = clamped_draft_positions.repeat_interleave(
level_num_drafts).reshape(batch_size, -1)
if num_children > 1:
# Repeat draft hidden states for each child.
draft_hidden_states = draft_hidden_states.repeat_interleave(
num_children, dim=1)
# Concatenate the draft tokens, positions, and hidden states.
tree_input_ids = torch.cat([tree_input_ids, draft_token_ids],
dim=1)
tree_positions = torch.cat([tree_positions, draft_positions],
dim=1)
tree_hidden_states = torch.cat(
[tree_hidden_states, draft_hidden_states], dim=1)
# Build new attention metadata for the next level of drafts.
# This is necessary to support tree attention.
query_len = total_num_drafts - tree_root_level
common_attn_metadata = replace(
common_attn_metadata,
query_start_loc=query_len * self.arange[:batch_size + 1],
seq_lens=common_attn_metadata.seq_lens + level_num_drafts,
num_actual_tokens=batch_size * query_len,
max_query_len=query_len,
)
attn_metadata = tree_attn_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata,
draft_index=tree_root_level + 1,
)
# Apply new attention metadata to all layers.
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
# Consider max model length.
attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
self.max_model_len)
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
# Compute the slot mapping.
query_positions = flattened_draft_positions[:, level:level +
query_len]
block_numbers = query_positions // self.block_size
block_ids = attn_metadata.block_table.gather(dim=1,
index=block_numbers)
slot_mapping = (block_ids * self.block_size +
query_positions % self.block_size)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID
attn_metadata.slot_mapping = slot_mapping.view(-1)
# Copy inputs to buffer for cudagraph.
num_tokens = attn_metadata.num_actual_tokens
input_ids = tree_input_ids.view(-1)
self.input_ids[:num_tokens] = input_ids
self.positions[:num_tokens] = tree_positions.view(-1)
self.hidden_states[:num_tokens] = tree_hidden_states.view(
num_tokens, -1)
if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_tokens)
else:
num_input_tokens = num_tokens
# Run the model.
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
hidden_states=self.hidden_states[:num_input_tokens],
inputs_embeds=None,
)
# Get the output hidden states for the draft tokens.
draft_hidden_states = hidden_states[:num_tokens].view(
batch_size, query_len, -1)[:, -level_num_drafts:]
draft_last_hidden_states = last_hidden_states[:num_tokens].view(
batch_size, query_len, -1)[:, -level_num_drafts:]
# Get the output logits for the draft tokens.
logits = self.model.compute_logits(
draft_last_hidden_states.reshape(batch_size * level_num_drafts,
-1),
None,
)
# Sample a draft token for each child at the next tree level.
num_children = self.child_drafts_per_level[level + 1]
if num_children == 1:
draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
else:
draft_token_ids = torch.topk(logits, num_children,
dim=-1).indices.view(
batch_size, -1)
draft_token_ids_list.append(draft_token_ids)
# Update the # drafts counters for the next tree level.
level_num_drafts = self.cu_drafts_per_level[level +
1] - total_num_drafts
total_num_drafts = self.cu_drafts_per_level[level + 1]
return draft_token_ids_list
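Finally, a small sketch of the shape bookkeeping for the value returned by propose_tree: each list entry holds one tree level's drafts, and concatenating them along dim=1 (as done by the caller in propose) yields one column per node of the flattened tree.

import torch

# Assumed example: batch_size=3 and the two-branch tree with
# cu_drafts_per_level = [2, 6], i.e. 2 drafts at level 0 and 4 at level 1.
batch_size = 3
draft_token_ids_list = [
    torch.zeros(batch_size, 2, dtype=torch.long),  # level 0 drafts
    torch.zeros(batch_size, 4, dtype=torch.long),  # level 1 drafts
]
draft_token_ids = torch.cat(draft_token_ids_list, dim=1)
print(draft_token_ids.shape)  # torch.Size([3, 6]) == [batch_size, num_tree_tokens]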
def prepare_inputs(
self,
common_attn_metadata: CommonAttentionMetadata,