vllm/tests/v1/spec_decode/test_tree_attention.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from typing import Optional

import torch

from tests.v1.attention.utils import (_Backend, create_standard_kv_cache_spec,
                                      create_vllm_config,
                                      get_attention_backend)
from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.v1.attention.backends.utils import CommonAttentionMetadata


class MockAttentionLayer(torch.nn.Module):
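    """Minimal layer stub exposing the per-layer q/k/v scale tensors that the
    attention backend implementations look up on the layer object."""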

    _q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    _k_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    _v_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


def forward_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kv_cache: torch.Tensor,
    block_table: torch.Tensor,
    slot_mapping: torch.Tensor,
    seqlen_k: int,
    backend: _Backend,
    spec_token_tree: Optional[str] = None,
    num_spec_tokens: int = 0,
) -> torch.Tensor:
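    """Run one forward pass of `backend` over the given q/k/v.

    Builds the varlen metadata (query_start_loc, seq_lens, block table, slot
    mapping) that vLLM's v1 attention backends expect, instantiates the
    backend implementation, and returns its output of shape
    (batch_size * q_len, num_heads, dim_per_head).
    """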
    batch_size, q_len, num_heads, dim_per_head = q.shape
    num_kv_heads = k.shape[-2]

    # Initialize the query and KV sequence lengths.
    query_start_loc = q_len * torch.arange(
        batch_size + 1, device=q.device, dtype=torch.int32)
    query_lens = torch.diff(query_start_loc)
    seq_lens = torch.full(
        (batch_size, ),
        seqlen_k,
        device=q.device,
        dtype=torch.int32,
    )
    context_lens = seq_lens - query_lens
    max_seq_len = int(seq_lens.max())
    max_query_len = q_len
    num_actual_tokens = query_start_loc[-1]
    softmax_scale = q.shape[-1]**(-0.5)
    layer = MockAttentionLayer()

    # Build common metadata.
    model_name = "meta-llama/Meta-Llama-3-8B"
    builder_cls, impl_cls = get_attention_backend(backend)
    vllm_config = create_vllm_config(model_name=model_name,
                                     max_model_len=max(seq_lens))
    if spec_token_tree is not None:
        # Create speculative config if token tree is specified.
        vllm_config.speculative_config = SpeculativeConfig(
            target_model_config=vllm_config.model_config,
            target_parallel_config=ParallelConfig(),
            model=model_name,
            method="eagle",
            num_speculative_tokens=num_spec_tokens,
            speculative_token_tree=spec_token_tree)
    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
    builder = builder_cls(kv_cache_spec, [], vllm_config, q.device)
    common_attn_metadata = CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc.cpu(),
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens.cpu(),
        num_computed_tokens_cpu=context_lens.cpu(),
        num_reqs=batch_size,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=block_table,
        slot_mapping=slot_mapping,
    )

    # Build attention metadata.
    attn_metadata = builder.build(
        common_prefix_len=0,
        common_attn_metadata=common_attn_metadata,
    )

    # Initialize the backend implementation.
    instance = impl_cls(
        num_heads=num_heads,
        head_size=dim_per_head,
        scale=softmax_scale,
        num_kv_heads=num_kv_heads,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype="auto",
    )

    # Run forward pass and return output.
    query = q.view(-1, num_heads, dim_per_head)
    key = k.view(-1, num_kv_heads, dim_per_head)
    value = v.view(-1, num_kv_heads, dim_per_head)
    output = torch.empty_like(query)
    return instance.forward(
        layer=layer,
        query=query,
        key=key,
        value=value,
        kv_cache=kv_cache.clone(),
        attn_metadata=attn_metadata,
        output=output,
    )


def test_tree_attn_correctness() -> None:
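    """Tree attention must agree with chain (flash) attention on every branch.

    For each speculative token tree, the TREE_ATTN output at a node is
    compared against FLASH_ATTN run on the chain of that node's ancestors,
    across several batch sizes, head counts, and context lengths.
    """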
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    device = "cuda"
    tree_attn_masks = {
        # Chain.
        "[(0,), (0, 0), (0, 0, 0)]":
        torch.tensor(
            [
                [1, 0, 0, 0],
                [1, 1, 0, 0],
                [1, 1, 1, 0],
                [1, 1, 1, 1],
            ],
            device=device,
            dtype=torch.int32,
        ),
        # Tree.
        "[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]":
        torch.tensor(
            [
                [1, 0, 0, 0, 0, 0, 0],
                [1, 1, 0, 0, 0, 0, 0],
                [1, 0, 1, 0, 0, 0, 0],
                [1, 1, 0, 1, 0, 0, 0],
                [1, 1, 0, 0, 1, 0, 0],
                [1, 0, 1, 0, 0, 1, 0],
                [1, 0, 1, 0, 0, 0, 1],
            ],
            device=device,
            dtype=torch.int32,
        ),
    }

    dim_per_head = 128
    num_kv_heads = 2
    block_size = 32
    max_sequence_length = 8192
    randomize_blocks = True
    for batch_size in [1, 16, 32]:
        for num_heads in [2, 4]:
            for sequence_position in [16, 1024, 2048]:
                for spec_token_tree, tree_attn_mask in tree_attn_masks.items():
                    # Assert that the number of heads is divisible
                    # by the number of KV heads.
                    assert num_heads % num_kv_heads == 0

                    # Initialize q, k, and v.
                    tree_size_q = tree_attn_mask.shape[0]
                    seqlen_k = sequence_position + tree_size_q
                    q = torch.randn(
                        (batch_size, tree_size_q, num_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )
                    k = torch.randn(
                        (batch_size, tree_size_q, num_kv_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )
                    v = torch.randn(
                        (batch_size, tree_size_q, num_kv_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )

                    # Set up the block table and KV cache for paged KV.
                    assert max_sequence_length % block_size == 0
                    max_blocks_per_batch = max_sequence_length // block_size
                    kv_cache = torch.randn(
                        (
                            2,
                            batch_size * max_blocks_per_batch,
                            block_size,
                            num_kv_heads,
                            dim_per_head,
                        ),
                        device=q.device,
                        dtype=torch.bfloat16,
                    )
                    num_alloc_blocks_per_batch = math.ceil(seqlen_k /
                                                           block_size)
                    block_table = torch.zeros(
                        (batch_size, max_blocks_per_batch),
                        device=q.device,
                        dtype=torch.int32,
                    )
                    block_ids = torch.arange(
                        0,
                        batch_size * num_alloc_blocks_per_batch,
                        device=q.device,
                        dtype=torch.int32,
                    )
                    if randomize_blocks:
                        # Randomize the block ids.
                        block_ids = block_ids[torch.randperm(
                            block_ids.numel())]
                    block_table[:, :num_alloc_blocks_per_batch] = (
                        block_ids.view(-1, num_alloc_blocks_per_batch))

                    # Set up the slot mapping for the input KVs.
                    tree_positions = sequence_position + torch.arange(
                        0,
                        tree_size_q,
                        device=q.device,
                        dtype=torch.int64,
                    ).repeat(batch_size, 1)
                    tree_slot_mapping = _gen_slot_mapping(
                        tree_positions, block_table, block_size)

                    # Compute attention for the tree.
                    tree_attn_output = forward_attention(
                        q=q,
                        k=k,
                        v=v,
                        kv_cache=kv_cache,
                        block_table=block_table,
                        slot_mapping=tree_slot_mapping,
                        seqlen_k=seqlen_k,
                        backend=_Backend.TREE_ATTN,
                        spec_token_tree=spec_token_tree,
                        num_spec_tokens=tree_size_q - 1,
                    ).view(batch_size, -1, num_heads, dim_per_head)

                    # Verify that the chain attention output for each
                    # branch of the tree (computed using FA3) matches
                    # the tree attention output.
                    for q_index in range(tree_size_q):
                        # Get the q, k, and v for the branch.
                        branch_mask = tree_attn_mask[q_index, :]
                        branch_indices = torch.nonzero(branch_mask,
                                                       as_tuple=True)[0]
                        q_len = branch_indices.shape[0]
                        q_branch = q[:, branch_indices]
                        k_branch = k[:, branch_indices]
                        v_branch = v[:, branch_indices]

                        # Setup slot mapping for the branch.
                        branch_positions = sequence_position + torch.arange(
                            0,
                            q_len,
                            device=q.device,
                            dtype=torch.int64,
                        ).repeat(batch_size, 1)
                        branch_slot_mapping = _gen_slot_mapping(
                            branch_positions, block_table, block_size)

                        # Compute flash attention for the branch.
                        flash_attn_output = forward_attention(
                            q=q_branch,
                            k=k_branch,
                            v=v_branch,
                            kv_cache=kv_cache,
                            block_table=block_table,
                            slot_mapping=branch_slot_mapping,
                            seqlen_k=sequence_position + q_len,
                            backend=_Backend.FLASH_ATTN,
                        ).view(batch_size, -1, num_heads, dim_per_head)

                        # Compare the outputs.
                        assert torch.allclose(
                            tree_attn_output[:, branch_indices],
                            flash_attn_output,
                            atol=7.81e-3,
                        ), (f"outputs are not close for "
                            f"batch_size: {batch_size}, "
                            f"num_heads: {num_heads}, "
                            f"sequence_position: {sequence_position}, "
                            f"tree_attn_mask: {tree_attn_mask}, "
                            f"q_index: {q_index}.")


def _gen_slot_mapping(positions: torch.Tensor, block_table: torch.Tensor,
                      block_size: int):
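    """Map per-token positions to flat slot indices in the paged KV cache.

    Position p of request b lives in block block_table[b, p // block_size]
    at offset p % block_size; e.g. with block_size=32, position 37 maps to
    slot block_table[b, 1] * 32 + 5.
    """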
    block_indices = positions // block_size
    blocks = block_table.gather(dim=1, index=block_indices)
    return (blocks * block_size + positions % block_size).view(-1)
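

# Convenience entry point: runs the correctness check directly, outside
# pytest. Assumes a CUDA-capable GPU is available.
if __name__ == "__main__":
    test_tree_attn_correctness()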