Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)
Add tree attention backend for v1 (part 1) (#20401)
Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>
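For orientation: throughout this change, the speculative token tree is carried around as a string of path tuples, where element i of a tuple is the child index chosen at depth i. The snippet below is only an illustration of that notation (the two strings are the ones used in the new tests and in SpeculativeConfig further down):

import ast

chain = "[(0,), (0, 0), (0, 0, 0)]"  # three drafts in a single chain
tree = "[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]"  # two children per node
ast.literal_eval(tree)[2]  # -> (0, 0): first child of the first level-0 draft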
@@ -17,7 +17,7 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec

 BACKENDS_TO_TEST = [
     _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1,
-    _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1
+    _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN
 ]

 # Remove flashinfer from the list if it's not available
@@ -109,11 +109,11 @@ def create_common_attn_metadata(

 def get_attention_backend(backend_name: _Backend):
     """Set up attention backend classes for testing.

     Args:
         backend_name: Name of the backend ("flash_attn", "flashinfer", etc.)
         vllm_config: VllmConfig instance

     Returns:
         Tuple of (backend_builder_class, backend_impl_class)
     """
@@ -126,6 +126,8 @@ def get_attention_backend(backend_name: _Backend):
         "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",
         _Backend.TRITON_ATTN_VLLM_V1:
         "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",
+        _Backend.TREE_ATTN:
+        "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
     }

     if backend_name not in backend_map:
@@ -202,7 +202,9 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,


 @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
-def test_propose(num_speculative_tokens):
+@pytest.mark.parametrize("backend",
+                         [_Backend.FLASH_ATTN_VLLM_V1, _Backend.TREE_ATTN])
+def test_propose(num_speculative_tokens, backend):
     # Use GPU device
     device = torch.device(current_platform.device_type)

@@ -301,8 +303,7 @@ def test_propose(num_speculative_tokens):
                                      device=device)
     sampling_metadata = mock.MagicMock()

-    attn_metadata_builder_cls, _ = get_attention_backend(
-        _Backend.FLASH_ATTN_VLLM_V1)
+    attn_metadata_builder_cls, _ = get_attention_backend(backend)
     attn_metadata_builder = attn_metadata_builder_cls(
         kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
         layer_names=proposer.attn_layer_names,
tests/v1/spec_decode/test_tree_attention.py (new file, 299 lines)
@@ -0,0 +1,299 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from typing import Optional

import torch

from tests.v1.attention.utils import (_Backend, create_standard_kv_cache_spec,
                                      create_vllm_config,
                                      get_attention_backend)
from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.v1.attention.backends.utils import CommonAttentionMetadata


class MockAttentionLayer(torch.nn.Module):
    _q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    _k_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    _v_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


def forward_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kv_cache: torch.Tensor,
    block_table: torch.Tensor,
    slot_mapping: torch.Tensor,
    seqlen_k: int,
    backend: _Backend,
    spec_token_tree: Optional[str] = None,
    num_spec_tokens: int = 0,
) -> torch.Tensor:
    batch_size, q_len, num_heads, dim_per_head = q.shape
    num_kv_heads = k.shape[-2]
    # Initialize the query and KV sequence lengths.
    query_start_loc = q_len * torch.arange(
        batch_size + 1, device=q.device, dtype=torch.int32)
    query_lens = torch.diff(query_start_loc)
    seq_lens = torch.full((batch_size, ),
                          seqlen_k,
                          device=q.device,
                          dtype=torch.int32)
    context_lens = seq_lens - query_lens
    max_query_len = q_len
    num_actual_tokens = query_start_loc[-1]

    softmax_scale = q.shape[-1]**(-0.5)
    layer = MockAttentionLayer()

    # Build common metadata.
    model_name = "meta-llama/Meta-Llama-3-8B"
    builder_cls, impl_cls = get_attention_backend(backend)
    vllm_config = create_vllm_config(model_name=model_name,
                                     max_model_len=max(seq_lens))
    if spec_token_tree is not None:
        # Create speculative config if token tree is specified.
        vllm_config.speculative_config = SpeculativeConfig(
            target_model_config=vllm_config.model_config,
            target_parallel_config=ParallelConfig(),
            model=model_name,
            method="eagle",
            num_speculative_tokens=num_spec_tokens,
            speculative_token_tree=spec_token_tree)
    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
    builder = builder_cls(kv_cache_spec, [], vllm_config, q.device)
    common_attn_metadata = CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc.cpu(),
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens.cpu(),
        num_computed_tokens_cpu=context_lens.cpu(),
        num_reqs=batch_size,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        block_table_tensor=block_table,
        slot_mapping=slot_mapping,
    )

    # Build attention metadata.
    attn_metadata = builder.build(
        common_prefix_len=0,
        common_attn_metadata=common_attn_metadata,
    )

    # Initialize the backend implementation.
    instance = impl_cls(
        num_heads=num_heads,
        head_size=dim_per_head,
        scale=softmax_scale,
        num_kv_heads=num_kv_heads,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype="auto",
    )

    # Run forward pass and return output.
    query = q.view(-1, num_heads, dim_per_head)
    key = k.view(-1, num_kv_heads, dim_per_head)
    value = v.view(-1, num_kv_heads, dim_per_head)
    output = torch.empty_like(query)
    return instance.forward(
        layer=layer,
        query=query,
        key=key,
        value=value,
        kv_cache=kv_cache.clone(),
        attn_metadata=attn_metadata,
        output=output,
    )


def test_tree_attn_correctness() -> None:
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    device = "cuda"
    tree_attn_masks = {
        # Chain.
        "[(0,), (0, 0), (0, 0, 0)]":
        torch.tensor(
            [
                [1, 0, 0, 0],
                [1, 1, 0, 0],
                [1, 1, 1, 0],
                [1, 1, 1, 1],
            ],
            device=device,
            dtype=torch.int32,
        ),
        # Tree.
        "[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]":
        torch.tensor(
            [
                [1, 0, 0, 0, 0, 0, 0],
                [1, 1, 0, 0, 0, 0, 0],
                [1, 0, 1, 0, 0, 0, 0],
                [1, 1, 0, 1, 0, 0, 0],
                [1, 1, 0, 0, 1, 0, 0],
                [1, 0, 1, 0, 0, 1, 0],
                [1, 0, 1, 0, 0, 0, 1],
            ],
            device=device,
            dtype=torch.int32,
        ),
    }

    dim_per_head = 128
    num_kv_heads = 2
    block_size = 128
    max_sequence_length = 8192
    randomize_blocks = True
    for batch_size in [1, 16, 32]:
        for num_heads in [2, 4]:
            for sequence_position in [16, 1024, 2048]:
                for spec_token_tree, tree_attn_mask in tree_attn_masks.items():
                    # Assert that the number of heads is divisible
                    # by the number of KV heads.
                    assert num_heads % num_kv_heads == 0

                    # Initialize q, k, and v.
                    tree_size_q = tree_attn_mask.shape[0]
                    seqlen_k = sequence_position + tree_size_q
                    q = torch.randn(
                        (batch_size, tree_size_q, num_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )
                    k = torch.randn(
                        (batch_size, tree_size_q, num_kv_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )
                    v = torch.randn(
                        (batch_size, tree_size_q, num_kv_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )

                    # Setup the block table and KV cache for paged KV.
                    assert max_sequence_length % block_size == 0
                    max_blocks_per_batch = max_sequence_length // block_size
                    kv_cache = torch.randn(
                        (
                            2,
                            batch_size * max_blocks_per_batch,
                            block_size,
                            num_kv_heads,
                            dim_per_head,
                        ),
                        device=q.device,
                        dtype=torch.bfloat16,
                    )
                    num_alloc_blocks_per_batch = math.ceil(seqlen_k /
                                                           block_size)
                    block_table = torch.zeros(
                        (batch_size, max_blocks_per_batch),
                        device=q.device,
                        dtype=torch.int32,
                    )
                    block_ids = torch.arange(
                        0,
                        batch_size * num_alloc_blocks_per_batch,
                        device=q.device,
                        dtype=torch.int32,
                    )
                    if randomize_blocks:
                        # Randomize the block ids.
                        block_ids = block_ids[torch.randperm(
                            block_ids.numel())]
                    block_table[:, :
                                num_alloc_blocks_per_batch] = block_ids.view(
                                    -1, num_alloc_blocks_per_batch)

                    # Setup the slot mapping for the input KVs.
                    tree_positions = sequence_position + torch.arange(
                        0,
                        tree_size_q,
                        device=q.device,
                        dtype=torch.int64,
                    ).repeat(batch_size, 1)
                    tree_slot_mapping = _gen_slot_mapping(
                        tree_positions, block_table, block_size)

                    # Compute attention for the tree.
                    tree_attn_output = forward_attention(
                        q=q,
                        k=k,
                        v=v,
                        kv_cache=kv_cache,
                        block_table=block_table,
                        slot_mapping=tree_slot_mapping,
                        seqlen_k=seqlen_k,
                        backend=_Backend.TREE_ATTN,
                        spec_token_tree=spec_token_tree,
                        num_spec_tokens=tree_size_q - 1,
                    ).view(batch_size, -1, num_heads, dim_per_head)

                    # Verify that the chain attention output for each
                    # branch of the tree (computed using FA3) matches
                    # the tree attention output.
                    for q_index in range(tree_size_q):
                        # Get the q, k, and v for the branch.
                        branch_mask = tree_attn_mask[q_index, :]
                        branch_indices = torch.nonzero(branch_mask,
                                                       as_tuple=True)[0]
                        q_len = branch_indices.shape[0]
                        q_branch = q[:, branch_indices]
                        k_branch = k[:, branch_indices]
                        v_branch = v[:, branch_indices]

                        # Setup slot mapping for the branch.
                        branch_positions = sequence_position + torch.arange(
                            0,
                            q_len,
                            device=q.device,
                            dtype=torch.int64,
                        ).repeat(batch_size, 1)
                        branch_slot_mapping = _gen_slot_mapping(
                            branch_positions, block_table, block_size)

                        # Compute flash attention for the branch.
                        flash_attn_output = forward_attention(
                            q=q_branch,
                            k=k_branch,
                            v=v_branch,
                            kv_cache=kv_cache,
                            block_table=block_table,
                            slot_mapping=branch_slot_mapping,
                            seqlen_k=sequence_position + q_len,
                            backend=_Backend.FLASH_ATTN_VLLM_V1,
                        ).view(batch_size, -1, num_heads, dim_per_head)

                        # Compare the outputs.
                        assert torch.allclose(
                            tree_attn_output[:, branch_indices],
                            flash_attn_output,
                            atol=7.81e-3,
                        ), (f"outputs are not close for "
                            f"batch_size: {batch_size}, "
                            f"num_heads: {num_heads}, "
                            f"sequence_position: {sequence_position}, "
                            f"tree_attn_mask: {tree_attn_mask}, "
                            f"q_index: {q_index}.")


def _gen_slot_mapping(positions: torch.Tensor, block_table: torch.Tensor,
                      block_size: int):
    block_indices = positions // block_size
    blocks = block_table.gather(dim=1, index=block_indices)
    return (blocks * block_size + positions % block_size).view(-1)
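For illustration, a tiny, hypothetical invocation of the _gen_slot_mapping helper above (the concrete numbers are made up, not taken from the test):

import torch

# With block_size = 4 and one request whose block-table row is [7, 2, 5, 0],
# positions [0, 1, 5, 9] fall into logical blocks [0, 0, 1, 2], which the
# table maps to physical blocks [7, 7, 2, 5]; each slot is then
# block * block_size + position % block_size.
positions = torch.tensor([[0, 1, 5, 9]])
block_table = torch.tensor([[7, 2, 5, 0]])
print(_gen_slot_mapping(positions, block_table, block_size=4))
# tensor([28, 29,  9, 21])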
@@ -55,6 +55,7 @@ def kernel_unified_attention_2d(
         block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
         seq_lens_ptr,  # [num_seqs]
         alibi_slopes_ptr,  # [num_query_heads]
+        qq_bias_ptr,  # [num_query_tokens, num_query_tokens]
         scale,  # float32
         k_scale,  # float32
         v_scale,  # float32
@@ -66,10 +67,12 @@ def kernel_unified_attention_2d(
         query_stride_1: tl.int64,  # int, should be equal to head_size
         output_stride_0: tl.int64,  # int
         output_stride_1: tl.int64,  # int, should be equal to head_size
+        qq_bias_stride_0: tl.int64,  # int
         BLOCK_SIZE: tl.constexpr,  # int
         HEAD_SIZE: tl.constexpr,  # int
         HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
         USE_ALIBI_SLOPES: tl.constexpr,  # bool
+        USE_QQ_BIAS: tl.constexpr,  # bool
         USE_SOFTCAP: tl.constexpr,  # bool
         SLIDING_WINDOW: tl.constexpr,  # int
         stride_k_cache_0: tl.int64,  # int
@@ -144,6 +147,11 @@ def kernel_unified_attention_2d(
                           mask=query_mask_1,
                           other=0.0)

+    # query-query attention bias
+    if USE_QQ_BIAS:
+        qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
+                            )  # shape: [BLOCK_M]
+
     # compute the length of the longest sequence prefix spanned by any
     # query token in the current q_block (q_block_local_idx)
     max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
@@ -223,6 +231,18 @@ def kernel_unified_attention_2d(
         if USE_ALIBI_SLOPES:
             S += alibi_slope[:, None] * (seq_offset - context_len)

+        if USE_QQ_BIAS:
+            # compute key positions relative to query section
+            key_rel_pos = seq_offset - context_len  # shape: [BLOCK_SIZE]
+            # load bias only for keys that correspond to queries
+            is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0
+            qq_bias = tl.load(
+                qq_bias_row_ptrs + key_rel_pos[None, :],
+                mask=is_query_key[None, :],  # avoid OOB for context keys
+                other=0.0,
+            )
+            S += qq_bias
+
         # compute running maximum
         # m_j : (BLOCK_M,)
         m_j = tl.maximum(M, tl.max(S, axis=1))
@@ -275,6 +295,7 @@ def kernel_unified_attention_3d(
         block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
         seq_lens_ptr,  # [num_seqs]
         alibi_slopes_ptr,  # [num_query_heads]
+        qq_bias_ptr,  # [num_query_tokens, num_query_tokens]
         scale,  # float32
         k_scale,  # float32
         v_scale,  # float32
@@ -284,10 +305,12 @@ def kernel_unified_attention_3d(
         block_table_stride: tl.int64,  # int
         query_stride_0: tl.int64,  # int
         query_stride_1: tl.int64,  # int, should be equal to head_size
+        qq_bias_stride_0: tl.int64,  # int
         BLOCK_SIZE: tl.constexpr,  # int
         HEAD_SIZE: tl.constexpr,  # int
         HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
         USE_ALIBI_SLOPES: tl.constexpr,  # bool
+        USE_QQ_BIAS: tl.constexpr,  # bool
         USE_SOFTCAP: tl.constexpr,  # bool
         SLIDING_WINDOW: tl.constexpr,  # int
         stride_k_cache_0: tl.int64,  # int
@@ -373,6 +396,11 @@ def kernel_unified_attention_3d(
                           mask=query_mask_1,
                           other=0.0)

+    # query-query attention bias
+    if USE_QQ_BIAS:
+        qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
+                            )  # shape: [BLOCK_M]
+
     num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)

     # iterate through tiles within current segment
@@ -442,6 +470,18 @@ def kernel_unified_attention_3d(
         if USE_ALIBI_SLOPES:
             S += alibi_slope[:, None] * (seq_offset - context_len)

+        if USE_QQ_BIAS:
+            # compute key positions relative to query section
+            key_rel_pos = seq_offset - context_len  # shape: [BLOCK_SIZE]
+            # load bias only for keys that correspond to queries
+            is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0
+            qq_bias = tl.load(
+                qq_bias_row_ptrs + key_rel_pos[None, :],
+                mask=is_query_key[None, :],  # avoid OOB for context keys
+                other=0.0,
+            )
+            S += qq_bias
+
         # compute running maximum
         # m_j : (BLOCK_M,)
         m_j = tl.maximum(M, tl.max(S, axis=1))
@@ -586,6 +626,7 @@ def unified_attention(
     k_descale,
     v_descale,
     alibi_slopes=None,
+    qq_bias=None,
 ):
     assert causal, "Only causal attention is supported"
     assert q_descale is None, "Q scales not supported"
@@ -595,6 +636,7 @@ def unified_attention(
         "Block size must be at least 32 for fp8"

     use_alibi_slopes = alibi_slopes is not None
+    use_qq_bias = qq_bias is not None

     block_size = v.shape[1]
     num_seqs = len(seqused_k)
@@ -630,6 +672,7 @@ def unified_attention(
             block_tables_ptr=block_table,
             seq_lens_ptr=seqused_k,
             alibi_slopes_ptr=alibi_slopes,
+            qq_bias_ptr=qq_bias,
             scale=softmax_scale,
             k_scale=k_descale,
             v_scale=v_descale,
@@ -641,10 +684,12 @@ def unified_attention(
             query_stride_1=q.stride(1),
             output_stride_0=out.stride(0),
             output_stride_1=out.stride(1),
+            qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
             BLOCK_SIZE=block_size,
             HEAD_SIZE=head_size,
             HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
             USE_ALIBI_SLOPES=use_alibi_slopes,
+            USE_QQ_BIAS=use_qq_bias,
             USE_SOFTCAP=(softcap > 0),
             SLIDING_WINDOW=(1 + window_size[0]),
             stride_k_cache_0=k.stride(0),
@@ -699,6 +744,7 @@ def unified_attention(
             block_tables_ptr=block_table,
             seq_lens_ptr=seqused_k,
             alibi_slopes_ptr=alibi_slopes,
+            qq_bias_ptr=qq_bias,
             scale=softmax_scale,
             k_scale=k_descale,
             v_scale=v_descale,
@@ -708,10 +754,12 @@ def unified_attention(
             block_table_stride=block_table.stride(0),
             query_stride_0=q.stride(0),
             query_stride_1=q.stride(1),
+            qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
             BLOCK_SIZE=block_size,
             HEAD_SIZE=head_size,
             HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
             USE_ALIBI_SLOPES=use_alibi_slopes,
+            USE_QQ_BIAS=use_qq_bias,
             USE_SOFTCAP=(softcap > 0),
             SLIDING_WINDOW=(1 + window_size[0]),
             stride_k_cache_0=k.stride(0),
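For intuition, here is a rough PyTorch reference (not part of the diff) for what the USE_QQ_BIAS path computes during tree decode for a single sequence, assuming the in-flight tree tokens sit at the end of the key sequence and qq_bias already encodes the tree structure (0 for visible self/ancestor pairs, -inf elsewhere):

import torch

def ref_qq_bias_attention(q, k, v, qq_bias, scale):
    # q: [q_len, num_heads, head_size]; k, v: [k_len, num_heads, head_size].
    # The last q_len keys are the query (tree) tokens; qq_bias: [q_len, q_len].
    q_len = q.shape[0]
    context_len = k.shape[0] - q_len
    # scores: [num_heads, q_len, k_len]
    scores = torch.einsum("qhd,khd->hqk", q.float(), k.float()) * scale
    # Context keys get no bias; keys in the query section are offset by the
    # tree bias, which suppresses non-ancestor draft tokens with -inf,
    # mirroring the is_query_key mask in the Triton kernels above.
    scores[:, :, context_len:] += qq_bias
    return torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v.float())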
@@ -3049,6 +3049,19 @@ class SpeculativeConfig:
                             f"num_speculative_tokens:{self.num_speculative_tokens}"
                             f" must be divisible by {n_predict=}")

+            if self.speculative_token_tree is None:
+                # Generate chain of tokens.
+                self.speculative_token_tree = str([
+                    (i + 1) * (0, )
+                    for i in range(self.num_speculative_tokens)
+                ])
+            else:
+                # Sort the token tree breadth-first.
+                tree_choices = ast.literal_eval(
+                    self.speculative_token_tree)
+                self.speculative_token_tree = str(
+                    sorted(tree_choices, key=lambda t: (len(t), t)))
+
             self.draft_tensor_parallel_size = \
                 SpeculativeConfig._verify_and_get_draft_tp(
                     self.target_parallel_config,
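A quick standalone illustration of the two branches above (run outside the config class):

import ast

# num_speculative_tokens = 3 with no explicit tree produces a plain chain:
print(str([(i + 1) * (0, ) for i in range(3)]))
# [(0,), (0, 0), (0, 0, 0)]

# An explicit tree is re-serialized in breadth-first order (shorter paths
# first, then lexicographically):
print(sorted(ast.literal_eval("[(0, 0), (1,), (0,)]"), key=lambda t: (len(t), t)))
# [(0,), (1,), (0, 0)]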
@@ -1454,7 +1454,6 @@ class EngineArgs:
                 "Please consider using other speculative decoding methods "
                 "such as ngram, medusa, eagle, or deepseek_mtp.")

-        # No XFormers so far.
         V1_BACKENDS = [
             "FLASH_ATTN_VLLM_V1",
             "FLASH_ATTN",
@@ -1469,6 +1468,7 @@ class EngineArgs:
             "ROCM_AITER_MLA",
             "TORCH_SDPA_VLLM_V1",
             "FLEX_ATTENTION",
+            "TREE_ATTN",
         ]
         if (envs.is_set("VLLM_ATTENTION_BACKEND")
                 and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
@@ -270,6 +270,7 @@ class CudaPlatformBase(Platform):
             FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
             TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
             FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+            TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"  # noqa: E501

             if selected_backend == _Backend.FLASHINFER:
                 logger.info_once("Using FlashInfer backend on V1 engine.")
@@ -287,6 +288,9 @@ class CudaPlatformBase(Platform):
             elif selected_backend == _Backend.FLASH_ATTN:
                 logger.info_once("Using Flash Attention backend on V1 engine.")
                 return FLASH_ATTN_V1
+            elif selected_backend == _Backend.TREE_ATTN:
+                logger.info_once("Using Tree Attention backend on V1 engine.")
+                return TREE_ATTN_V1

             from vllm.attention.selector import is_attn_backend_supported

@@ -62,6 +62,7 @@ class _Backend(enum.Enum):
     DIFFERENTIAL_FLASH_ATTN = enum.auto()
     NO_ATTENTION = enum.auto()
     FLEX_ATTENTION = enum.auto()
+    TREE_ATTN = enum.auto()


 class PlatformEnum(enum.Enum):
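With the registry, platform, and enum entries above in place, the backend can be selected explicitly. A minimal, hypothetical way to do that from Python is through the existing environment variable checked by EngineArgs:

import os

# "TREE_ATTN" is now in V1_BACKENDS, so this passes the V1 engine check and
# CudaPlatformBase resolves it to TreeAttentionBackend on CUDA.
os.environ["VLLM_ATTENTION_BACKEND"] = "TREE_ATTN"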
vllm/v1/attention/backends/tree_attn.py (new file, 452 lines)
@@ -0,0 +1,452 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with TreeAttention."""

import ast
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional

import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder, CommonAttentionMetadata,
    reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.worker.gpu_input_batch import InputBatch

from vllm import _custom_ops as ops

logger = init_logger(__name__)


class TreeAttentionBackend(AttentionBackend):

    accept_output_buffer: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {supported_head_sizes}. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        return "TREE_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["TreeAttentionImpl"]:
        return TreeAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return TreeAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_builder_cls() -> type["TreeAttentionMetadataBuilder"]:
        return TreeAttentionMetadataBuilder

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        return False


@dataclass
class TreeAttentionMetadata:
    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    num_prefill_tokens: int = 0
    num_decode_tokens: int = 0
    num_prefills: int = 0
    num_decodes: int = 0

    tree_attn_bias: Optional[torch.Tensor] = None

    # Cached prefill/decode metadata.
    _cached_prefill_metadata: Optional["TreeAttentionMetadata"] = None
    _cached_decode_metadata: Optional["TreeAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["TreeAttentionMetadata"]:
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            # Recover cached prefill-phase attention metadata structure.
            return self._cached_prefill_metadata

        q_start_loc = self.query_start_loc[self.num_decodes:]
        q_seqlens = torch.diff(q_start_loc)
        kv_seqlens = self.seq_lens[self.num_decodes:]
        # Construct & cache prefill-phase attention metadata structure.
        self._cached_prefill_metadata = TreeAttentionMetadata(
            num_actual_tokens=self.num_prefill_tokens,
            max_query_len=int(q_seqlens.max().item()),
            query_start_loc=q_start_loc - q_start_loc[0],
            max_seq_len=int(kv_seqlens.max().item()),
            seq_lens=kv_seqlens,
            block_table=self.block_table[self.num_decodes:],
            slot_mapping=self.slot_mapping[self.num_decode_tokens:],
        )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["TreeAttentionMetadata"]:
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            # Recover cached decode-phase attention metadata structure.
            return self._cached_decode_metadata

        q_start_loc = self.query_start_loc[:self.num_decodes + 1]
        q_seqlens = torch.diff(q_start_loc)
        kv_seqlens = self.seq_lens[:self.num_decodes]
        # Construct & cache decode-phase attention metadata structure.
        self._cached_decode_metadata = TreeAttentionMetadata(
            num_actual_tokens=self.num_decode_tokens,
            max_query_len=int(q_seqlens.max().item()),
            query_start_loc=q_start_loc,
            max_seq_len=int(kv_seqlens.max().item()),
            seq_lens=kv_seqlens,
            block_table=self.block_table[:self.num_decodes],
            slot_mapping=self.slot_mapping[:self.num_decode_tokens],
            tree_attn_bias=self.tree_attn_bias,
        )
        return self._cached_decode_metadata


class TreeAttentionMetadataBuilder(
        AttentionMetadataBuilder[TreeAttentionMetadata]):

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        self.kv_cache_spec = kv_cache_spec
        self.block_size = kv_cache_spec.block_size

        spec_config = vllm_config.speculative_config
        spec_token_tree = (spec := spec_config) and spec.speculative_token_tree
        tree_choices: list[tuple[int, ...]] = (
            ast.literal_eval(spec_token_tree)
            if spec_token_tree is not None else [(0, )])
        # Construct the tree attention bias.
        depth_counts = _get_depth_counts(tree_choices)
        self.tree_attn_bias = _prepare_tree_attn_bias(
            tree_choices,
            depth_counts,
            dtype=torch.float32,
            device=device,
        )

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        return reorder_batch_to_split_decodes_and_prefills(
            input_batch,
            scheduler_output,
            decode_threshold=self.tree_attn_bias.shape[0])

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> TreeAttentionMetadata:
        decode_threshold = self.tree_attn_bias.shape[0]
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(common_attn_metadata,
                                       decode_threshold=decode_threshold))

        num_actual_tokens = common_attn_metadata.num_actual_tokens
        q_start_loc = common_attn_metadata.query_start_loc
        max_query_len = common_attn_metadata.max_query_len
        kv_seqlens = common_attn_metadata.seq_lens
        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
        block_table = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping

        return TreeAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            max_query_len=max_query_len,
            query_start_loc=q_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=kv_seqlens,
            block_table=block_table,
            slot_mapping=slot_mapping,
            tree_attn_bias=self.tree_attn_bias,
        )

    def build_for_drafting(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        draft_index: int,
    ) -> TreeAttentionMetadata:
        # Cache the original tree attention bias.
        orig_tree_attn_bias = self.tree_attn_bias

        if draft_index == 0:
            # Use prefill for drafting at the root level.
            self.tree_attn_bias = torch.empty(0)
        else:
            # Slice the tree attention bias for drafting.
            query_len = common_attn_metadata.max_query_len
            start, end = draft_index, draft_index + query_len
            self.tree_attn_bias = self.tree_attn_bias[
                start:end, start:end].contiguous()

        # Build attention metadata.
        attn_metadata = self.build(0, common_attn_metadata, fast_build=True)

        # Reset the tree attention bias to the original value.
        self.tree_attn_bias = orig_tree_attn_bias
        return attn_metadata


def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]:
    # Count the number of choices at each depth of the tree.
    depth_counts = []
    prev_depth = 0
    for path in sorted_tree_choices:
        depth = len(path)
        if depth != prev_depth:
            depth_counts.append(0)
        depth_counts[depth - 1] += 1
        prev_depth = depth
    return depth_counts


def _prepare_tree_attn_bias(
    sorted_tree_choices: list[tuple[int, ...]],
    depth_counts: list[int],
    dtype: Optional[torch.dtype],
    device: Optional[torch.device],
) -> torch.Tensor:
    # +1 comes from the additional root node.
    tree_len = len(sorted_tree_choices) + 1
    tree_attn_mask = torch.full((tree_len, tree_len),
                                -torch.inf,
                                device=device,
                                dtype=dtype)

    # Set diagonal to all zeros. Each token should attend to itself.
    mask_val = 0
    for i in range(tree_len):
        tree_attn_mask[i, i] = mask_val

    # Set root to all zeros. All tokens attend to it.
    tree_attn_mask[:, 0] = mask_val

    # Set all ancestors to zeros.
    start = 0
    for i in range(len(depth_counts)):
        for j in range(depth_counts[i]):
            cur_tree_choice = sorted_tree_choices[start + j]
            # Retrieve ancestor position.
            if len(cur_tree_choice) == 1:
                continue
            ancestor_idx = []
            for c in range(len(cur_tree_choice) - 1):
                ancestor_idx.append(
                    sorted_tree_choices.index(cur_tree_choice[:c + 1]) + 1)
            tree_attn_mask[j + start + 1, ancestor_idx] = mask_val
        start += depth_counts[i]
    return tree_attn_mask


class TreeAttentionImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        if blocksparse_params is not None:
            raise ValueError(
                "TreeAttention does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if logits_soft_cap is None:
            # Setting logits_soft_cap to 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)

        TreeAttentionBackend.validate_head_size(head_size)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TreeAttentionImpl.")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: TreeAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with TreeAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for TreeAttentionImpl")

        if attn_metadata is None:
            # Profiling run.
            return output

        # Cache the input KVs.
        key_cache, value_cache = kv_cache.unbind(0)
        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        num_actual_tokens = attn_metadata.num_actual_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        descale_shape = (attn_metadata.query_start_loc.shape[0] - 1,
                         key.shape[1])
        if prefill_meta := attn_metadata.prefill_metadata:
            unified_attention(
                q=query[num_decode_tokens:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[num_decode_tokens:num_actual_tokens],
                cu_seqlens_q=prefill_meta.query_start_loc,
                max_seqlen_q=prefill_meta.max_query_len,
                seqused_k=prefill_meta.seq_lens,
                max_seqlen_k=prefill_meta.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=prefill_meta.block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )

        if decode_meta := attn_metadata.decode_metadata:
            unified_attention(
                q=query[:num_decode_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_decode_tokens],
                cu_seqlens_q=decode_meta.query_start_loc,
                max_seqlen_q=decode_meta.max_query_len,
                seqused_k=decode_meta.seq_lens,
                max_seqlen_k=decode_meta.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                qq_bias=decode_meta.tree_attn_bias,
                window_size=self.sliding_window,
                block_table=decode_meta.block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )
        return output
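To see what the helpers above produce, a small sketch using the two-branch tree from the tests (run with the private functions defined in this file):

import torch

tree_choices = [(0, ), (1, ), (0, 0), (0, 1), (1, 0), (1, 1)]
depth_counts = _get_depth_counts(tree_choices)  # [2, 4]
bias = _prepare_tree_attn_bias(tree_choices, depth_counts,
                               dtype=torch.float32, device="cpu")
# bias is 7x7 (root + 6 draft nodes): 0 where a node may attend (itself, the
# root, and its ancestors), -inf elsewhere. Row 3, node (0, 0), is
# [0, 0, -inf, 0, -inf, -inf, -inf], matching the third row of the 0/1 mask
# in tests/v1/spec_decode/test_tree_attention.py.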
@@ -214,6 +214,26 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
         return self.build(common_prefix_len=0,
                           common_attn_metadata=common_attn_metadata)

+    def build_for_drafting(
+        self,
+        common_attn_metadata: CommonAttentionMetadata,
+        draft_index: int,
+    ) -> M:
+        """
+        Build attention metadata for draft model. Uses build by default.
+
+        Args:
+            common_attn_metadata: The common attention metadata.
+            draft_index: The index of the current draft operation.
+                When speculating a chain of tokens, this index refers to the
+                draft attempt for the i-th token.
+                For tree-based attention, this index instead refers to the
+                draft attempt for the i-th level in the tree of tokens.
+        """
+        return self.build(common_prefix_len=0,
+                          common_attn_metadata=common_attn_metadata,
+                          fast_build=True)
+
     def use_cascade_attention(
         self,
         common_prefix_len: int,
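A hypothetical call pattern from a drafter, to show the intended use of the new hook (builder here stands for any concrete AttentionMetadataBuilder; the names follow the surrounding code):

# Draft index 0 behaves like a normal prefill-style build; higher indices let
# a backend specialize, e.g. TreeAttentionMetadataBuilder slices its tree
# bias to the current level (see tree_attn.py above).
attn_metadata = builder.build_for_drafting(
    common_attn_metadata=common_attn_metadata, draft_index=0)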
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import ast
+from dataclasses import replace
 from typing import Optional

 import numpy as np
@@ -17,6 +19,8 @@ from vllm.model_executor.models import supports_multimodal
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.utils import is_pin_memory_available
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
+                                                  TreeAttentionMetadataBuilder)
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.sample.metadata import SamplingMetadata
@@ -74,18 +78,52 @@ class EagleProposer:
             (self.max_num_tokens, self.hidden_size),
             dtype=self.dtype,
             device=device)
-        # We need +1 here because the arange is used to set query_start_loc,
-        # which has one more element than batch_size.
-        self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
-                                   1,
-                                   device=device,
-                                   dtype=torch.int32)
+        max_batch_size = vllm_config.scheduler_config.max_num_seqs
+        self.arange = torch.arange(
+            # We need +1 here because the arange is used to set query_start_loc,
+            # which has one more element than batch_size.
+            max_batch_size + 1,
+            device=device,
+            dtype=torch.int32,
+        )
+
         self.inputs_embeds = torch.zeros(
             (self.max_num_tokens, self.hidden_size),
             dtype=self.dtype,
             device=device)
+
+        # Parse the speculative token tree.
+        spec_token_tree = self.speculative_config.speculative_token_tree
+        self.tree_choices: list[tuple[int,
+                                      ...]] = ast.literal_eval(spec_token_tree)
+        tree_depth = len(self.tree_choices[-1])
+        # Precompute per-level properties of the tree.
+        num_drafts_per_level = [0] * tree_depth
+        for node in self.tree_choices:
+            num_drafts_per_level[len(node) - 1] += 1
+        self.cu_drafts_per_level = [num_drafts_per_level[0]]
+        self.child_drafts_per_level = [num_drafts_per_level[0]]
+        for level in range(1, tree_depth):
+            self.cu_drafts_per_level.append(self.cu_drafts_per_level[-1] +
+                                            num_drafts_per_level[level])
+            self.child_drafts_per_level.append(num_drafts_per_level[level] //
+                                               num_drafts_per_level[level - 1])
+        # Find the first level where the tree branches off into one or more
+        # children.
+        self.first_branching_level = None
+        for level in range(tree_depth):
+            if self.cu_drafts_per_level[level] > level + 1:
+                self.first_branching_level = level
+                break
+        # Precompute draft position offsets in flattened tree.
+        self.tree_draft_pos_offsets = torch.arange(
+            1,
+            len(self.tree_choices) + 1,
+            device=device,
+            dtype=torch.int32,
+        ).repeat(max_batch_size, 1)

     def propose(
         self,
         # [num_tokens]
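To make the precomputed level properties above concrete, a small standalone sketch (hypothetical values, mirroring the loop in __init__) for the two-branch tree used in the tests:

tree_choices = [(0, ), (1, ), (0, 0), (0, 1), (1, 0), (1, 1)]
tree_depth = len(tree_choices[-1])            # 2
num_drafts_per_level = [0] * tree_depth
for node in tree_choices:
    num_drafts_per_level[len(node) - 1] += 1  # [2, 4]
cu_drafts_per_level = [2, 6]     # running total of drafts through each level
child_drafts_per_level = [2, 2]  # children sampled per parent at each level
# cu_drafts_per_level[0] == 2 > 1, so first_branching_level == 0 and the
# proposer hands off to propose_tree() right after the first forward pass.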
@ -120,11 +158,9 @@ class EagleProposer:
|
|||||||
assert self.runner is not None
|
assert self.runner is not None
|
||||||
|
|
||||||
# FIXME: need to consider multiple kv_cache_groups
|
# FIXME: need to consider multiple kv_cache_groups
|
||||||
attn_metadata = self.runner.attn_metadata_builders[0].build(
|
attn_metadata = self.runner.attn_metadata_builders[
|
||||||
common_prefix_len=0,
|
0].build_for_drafting(common_attn_metadata=common_attn_metadata,
|
||||||
common_attn_metadata=common_attn_metadata,
|
draft_index=0)
|
||||||
fast_build=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# At this moment, we assume all eagle layers belong to the same KV
|
# At this moment, we assume all eagle layers belong to the same KV
|
||||||
# cache group, thus using the same attention metadata.
|
# cache group, thus using the same attention metadata.
|
||||||
@ -167,6 +203,22 @@ class EagleProposer:
|
|||||||
last_hidden_states, hidden_states = ret_hidden_states
|
last_hidden_states, hidden_states = ret_hidden_states
|
||||||
sample_hidden_states = last_hidden_states[last_token_indices]
|
sample_hidden_states = last_hidden_states[last_token_indices]
|
||||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||||
|
positions = target_positions[last_token_indices]
|
||||||
|
hidden_states = hidden_states[last_token_indices]
|
||||||
|
if self.first_branching_level == 0:
|
||||||
|
# Branching has occurred at the root level. Draft using tree
|
||||||
|
# attention.
|
||||||
|
draft_token_ids_list = self.propose_tree(
|
||||||
|
tree_root_level=0,
|
||||||
|
batch_size=batch_size,
|
||||||
|
logits=logits,
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
common_attn_metadata=common_attn_metadata,
|
||||||
|
)
|
||||||
|
# [batch_size, num_tree_tokens]
|
||||||
|
return torch.cat(draft_token_ids_list, dim=1)
|
||||||
|
|
||||||
draft_token_ids = logits.argmax(dim=-1)
|
draft_token_ids = logits.argmax(dim=-1)
|
||||||
|
|
||||||
# Early exit if there is only one draft token to be generated.
|
# Early exit if there is only one draft token to be generated.
|
||||||
@ -178,16 +230,15 @@ class EagleProposer:
|
|||||||
# one layer. Adapt this code to support multiple layers once
|
# one layer. Adapt this code to support multiple layers once
|
||||||
# there's a multi-layer MTP module.
|
# there's a multi-layer MTP module.
|
||||||
|
|
||||||
# Currently FlashAttention is the only backend that supports
|
# Currently, only FlashAttention and TreeAttention support multi-token
|
||||||
# multi-token eagle spec decode. This is because the code below
|
# eagle spec decode. This is because the code below
|
||||||
# makes assumptions about attn_metadata attributes available.
|
# makes assumptions about attn_metadata attributes available.
|
||||||
assert isinstance(attn_metadata, FlashAttentionMetadata)
|
assert isinstance(attn_metadata,
|
||||||
|
(FlashAttentionMetadata, TreeAttentionMetadata))
|
||||||
|
|
||||||
# Generate the remaining draft tokens.
|
# Generate the remaining draft tokens.
|
||||||
draft_token_ids_list = [draft_token_ids]
|
draft_token_ids_list = [draft_token_ids]
|
||||||
|
|
||||||
positions = target_positions[last_token_indices]
|
|
||||||
hidden_states = hidden_states[last_token_indices]
|
|
||||||
if self.use_cuda_graph and \
|
if self.use_cuda_graph and \
|
||||||
batch_size <= self.cudagraph_batch_sizes[-1]:
|
batch_size <= self.cudagraph_batch_sizes[-1]:
|
||||||
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
|
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
|
||||||
@ -196,7 +247,7 @@ class EagleProposer:
|
|||||||
attn_metadata.num_actual_tokens = batch_size
|
attn_metadata.num_actual_tokens = batch_size
|
||||||
attn_metadata.max_query_len = 1
|
attn_metadata.max_query_len = 1
|
||||||
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
|
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
|
||||||
for _ in range(self.num_speculative_tokens - 1):
|
for token_index in range(self.num_speculative_tokens - 1):
|
||||||
# Update the inputs.
|
# Update the inputs.
|
||||||
# cast to int32 is crucial when eagle model is compiled.
|
# cast to int32 is crucial when eagle model is compiled.
|
||||||
# tensor.argmax() returns int64 by default.
|
# tensor.argmax() returns int64 by default.
|
||||||
@@ -265,7 +316,20 @@ class EagleProposer:
             logits = self.model.compute_logits(last_hidden_states[:batch_size],
                                                None)
 
-            # TODO(wenlong): get more than one token for tree attention
+            if self.first_branching_level == token_index + 1:
+                # Branching has occurred. The remaining tokens are drafted
+                # using tree attention.
+                draft_token_ids_list += self.propose_tree(
+                    tree_root_level=token_index + 1,
+                    batch_size=batch_size,
+                    logits=logits,
+                    positions=positions,
+                    hidden_states=hidden_states,
+                    common_attn_metadata=common_attn_metadata,
+                )
+                # [batch_size, num_tree_tokens]
+                return torch.cat(draft_token_ids_list, dim=1)
+
             draft_token_ids = logits.argmax(dim=-1)
             draft_token_ids_list.append(draft_token_ids)
 
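The branch added in this hunk hands off to `propose_tree` as soon as the loop reaches `first_branching_level`: levels before it are drafted as a plain chain (one argmax token per step), and every level from the branching point on comes from the tree proposer. A hedged sketch of that split for an assumed tree shape (names below are illustrative, not from the diff):

# Illustrative only: which levels are chain-drafted vs. tree-drafted for an
# assumed per-level child count of [1, 1, 2, 2].
children_per_level = [1, 1, 2, 2]
first_branching_level = next(
    (i for i, c in enumerate(children_per_level) if c > 1),
    len(children_per_level))
chain_levels = list(range(first_branching_level))            # [0, 1] -> argmax drafting
tree_levels = list(range(first_branching_level,
                         len(children_per_level)))           # [2, 3] -> propose_tree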
@@ -273,6 +337,175 @@ class EagleProposer:
         draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
         return draft_token_ids
 
+    def propose_tree(
+        self,
+        tree_root_level: int,
+        batch_size: int,
+        # [num_tokens, vocab_size]
+        logits: torch.Tensor,
+        # [num_tokens]
+        positions: torch.Tensor,
+        # [num_tokens, hidden_size]
+        hidden_states: torch.Tensor,
+        common_attn_metadata: CommonAttentionMetadata,
+    ) -> list[torch.Tensor]:
+        tree_attn_metadata_builder = self.runner.attn_metadata_builders[0]
+        assert isinstance(tree_attn_metadata_builder,
+                          TreeAttentionMetadataBuilder)
+
+        total_num_drafts = self.cu_drafts_per_level[tree_root_level]
+        level_num_drafts = total_num_drafts
+        # Sample a draft token for each child at the tree root level.
+        num_children = self.child_drafts_per_level[tree_root_level]
+        if num_children == 1:
+            draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
+        else:
+            draft_token_ids = torch.topk(logits, num_children,
+                                         dim=-1).indices.view(batch_size, -1)
+        draft_token_ids_list = [draft_token_ids]
+        draft_hidden_states = hidden_states.view(batch_size, 1, -1)
+
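# --- Illustrative aside (not part of the diff): shapes produced by the root-level
# --- sampling above, using assumed toy sizes rather than a real model's vocab.
import torch

_batch_size, _vocab_size = 2, 8
_toy_logits = torch.randn(_batch_size, _vocab_size)            # [num_tokens, vocab_size]
_topk_ids = torch.topk(_toy_logits, 2,
                       dim=-1).indices.view(_batch_size, -1)
assert _topk_ids.shape == (2, 2)    # num_children == 2: two root drafts per request
_argmax_ids = _toy_logits.argmax(dim=-1).view(_batch_size, -1)
assert _argmax_ids.shape == (2, 1)  # num_children == 1: a single root draft per request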
+        # Initialize empty tensors for concatenation with the level outputs.
+        tree_input_ids = torch.empty(0,
+                                     device=self.input_ids.device,
+                                     dtype=self.input_ids.dtype)
+        tree_positions = torch.empty(0,
+                                     device=self.positions.device,
+                                     dtype=self.positions.dtype)
+        tree_hidden_states = torch.empty(0,
+                                         device=self.hidden_states.device,
+                                         dtype=self.hidden_states.dtype)
+        # Precompute the draft token positions.
+        flattened_draft_positions = (
+            positions.view(batch_size, -1) +
+            self.tree_draft_pos_offsets[:batch_size, :])
+        tree_depth = len(self.cu_drafts_per_level)
+        for level in range(tree_root_level, tree_depth - 1):
+            # Get draft positions for RoPE.
+            draft_positions = positions + (level + 1)
+            exceeds_max_model_len = (positions +
+                                     total_num_drafts) >= self.max_model_len
+            # Mask out the position ids that exceed the max model length.
+            # Otherwise, we may get out-of-range error in RoPE.
+            clamped_draft_positions = torch.where(
+                exceeds_max_model_len,
+                0,
+                draft_positions,
+            )
+            if level_num_drafts > 1:
+                # Repeat the positions for each draft at this level.
+                draft_positions = clamped_draft_positions.repeat_interleave(
+                    level_num_drafts).reshape(batch_size, -1)
+
+            if num_children > 1:
+                # Repeat draft hidden states for each child.
+                draft_hidden_states = draft_hidden_states.repeat_interleave(
+                    num_children, dim=1)
+
+            # Concatenate the draft tokens, positions, and hidden states.
+            tree_input_ids = torch.cat([tree_input_ids, draft_token_ids],
+                                       dim=1)
+            tree_positions = torch.cat([tree_positions, draft_positions],
+                                       dim=1)
+            tree_hidden_states = torch.cat(
+                [tree_hidden_states, draft_hidden_states], dim=1)
+
+            # Build new attention metadata for the next level of drafts.
+            # This is necessary to support tree attention.
+            query_len = total_num_drafts - tree_root_level
+            common_attn_metadata = replace(
+                common_attn_metadata,
+                query_start_loc=query_len * self.arange[:batch_size + 1],
+                seq_lens=common_attn_metadata.seq_lens + level_num_drafts,
+                num_actual_tokens=batch_size * query_len,
+                max_query_len=query_len,
+            )
+            attn_metadata = tree_attn_metadata_builder.build_for_drafting(
+                common_attn_metadata=common_attn_metadata,
+                draft_index=tree_root_level + 1,
+            )
+
+            # Apply new attention metadata to all layers.
+            per_layer_attn_metadata = {}
+            for layer_name in self.attn_layer_names:
+                per_layer_attn_metadata[layer_name] = attn_metadata
+
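# --- Illustrative aside (not part of the diff): the per-level metadata arithmetic
# --- above, with assumed toy values (batch_size 2, tree_root_level 0, and a level
# --- that adds 2 drafts per request, so query_len == level_num_drafts == 2).
import torch

_batch_size, _query_len, _level_num_drafts = 2, 2, 2
_arange = torch.arange(_batch_size + 1, dtype=torch.int32)
_query_start_loc = _query_len * _arange              # tensor([0, 2, 4]): 2 query tokens per request
_prev_seq_lens = torch.tensor([10, 17], dtype=torch.int32)
_seq_lens = _prev_seq_lens + _level_num_drafts       # tensor([12, 19]): KV grows by the new drafts
_num_actual_tokens = _batch_size * _query_len        # 4 tokens are fed to the drafter at this level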
+            # Consider max model length.
+            attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
+                                            self.max_model_len)
+            # For the requests that exceed the max model length, we set the
+            # sequence length to 1 to minimize their overheads in attention.
+            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
+
+            # Compute the slot mapping.
+            query_positions = flattened_draft_positions[:, level:level +
+                                                        query_len]
+            block_numbers = query_positions // self.block_size
+            block_ids = attn_metadata.block_table.gather(dim=1,
+                                                         index=block_numbers)
+            slot_mapping = (block_ids * self.block_size +
+                            query_positions % self.block_size)
+            # Mask out the slot mappings that exceed the max model length.
+            # Otherwise, the KV cache will be inadvertently updated with the
+            # padding tokens.
+            slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID
+            attn_metadata.slot_mapping = slot_mapping.view(-1)
+
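# --- Illustrative aside (not part of the diff): the slot-mapping arithmetic above,
# --- with assumed toy values (block_size 16, a draft at position 35 whose third
# --- logical block maps to physical block id 7 in the block table).
_block_size = 16
_query_position = 35
_block_number = _query_position // _block_size            # 2 -> third logical block
_assumed_block_id = 7                                     # stand-in for block_table[req, 2]
_slot = _assumed_block_id * _block_size + _query_position % _block_size
assert _slot == 115                                       # 7 * 16 + 35 % 16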
+            # Copy inputs to buffer for cudagraph.
+            num_tokens = attn_metadata.num_actual_tokens
+            input_ids = tree_input_ids.view(-1)
+            self.input_ids[:num_tokens] = input_ids
+            self.positions[:num_tokens] = tree_positions.view(-1)
+            self.hidden_states[:num_tokens] = tree_hidden_states.view(
+                num_tokens, -1)
+
+            if self.use_cuda_graph and \
+                num_tokens <= self.cudagraph_batch_sizes[-1]:
+                num_input_tokens = self.vllm_config.pad_for_cudagraph(
+                    num_tokens)
+            else:
+                num_input_tokens = num_tokens
+            # Run the model.
+            with set_forward_context(per_layer_attn_metadata,
+                                     self.vllm_config,
+                                     num_tokens=num_input_tokens):
+                last_hidden_states, hidden_states = self.model(
+                    input_ids=self.input_ids[:num_input_tokens],
+                    positions=self.positions[:num_input_tokens],
+                    hidden_states=self.hidden_states[:num_input_tokens],
+                    inputs_embeds=None,
+                )
+
+            # Get the output hidden states for the draft tokens.
+            draft_hidden_states = hidden_states[:num_tokens].view(
+                batch_size, query_len, -1)[:, -level_num_drafts:]
+            draft_last_hidden_states = last_hidden_states[:num_tokens].view(
+                batch_size, query_len, -1)[:, -level_num_drafts:]
+
+            # Get the output logits for the draft tokens.
+            logits = self.model.compute_logits(
+                draft_last_hidden_states.reshape(batch_size * level_num_drafts,
+                                                 -1),
+                None,
+            )
+
+            # Sample a draft token for each child at the next tree level.
+            num_children = self.child_drafts_per_level[level + 1]
+            if num_children == 1:
+                draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
+            else:
+                draft_token_ids = torch.topk(logits, num_children,
+                                             dim=-1).indices.view(
+                                                 batch_size, -1)
+            draft_token_ids_list.append(draft_token_ids)
+
+            # Update the # drafts counters for the next tree level.
+            level_num_drafts = self.cu_drafts_per_level[level +
+                                                        1] - total_num_drafts
+            total_num_drafts = self.cu_drafts_per_level[level + 1]
+
+        return draft_token_ids_list
+
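# --- Illustrative aside (not part of the diff): the shape of propose_tree's output.
# --- Each list entry holds one level's drafts as [batch_size, drafts_at_that_level];
# --- the callers concatenate them along dim=1 into [batch_size, num_tree_tokens].
# --- Toy values assumed: batch_size 2, levels contributing 2 and 4 drafts per request.
import torch

_level0_drafts = torch.zeros(2, 2, dtype=torch.long)
_level1_drafts = torch.zeros(2, 4, dtype=torch.long)
_tree_tokens = torch.cat([_level0_drafts, _level1_drafts], dim=1)
assert _tree_tokens.shape == (2, 6)   # [batch_size, num_tree_tokens]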
     def prepare_inputs(
         self,
         common_attn_metadata: CommonAttentionMetadata,