vllm/tests/v1/attention/utils.py
Yongye Zhu b3230e1ac0 [New Model] DeepSeek-V3.2 (Rebased to Main) (#25896)


# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility functions for attention-related v1 tests."""

from dataclasses import dataclass
from typing import Union

import pytest
import torch

from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
                         LoadConfig, ModelConfig, ModelDType, ParallelConfig,
                         SchedulerConfig, VllmConfig)
from vllm.platforms import _Backend, current_platform
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec


@dataclass
class BatchSpec:
    """Specification for a batch configuration (workload shape only)."""
    seq_lens: list[int]
    query_lens: list[int]

    name: str = "unnamed"

    @property
    def batch_size(self):
        return len(self.seq_lens)

    def __post_init__(self):
        assert len(self.seq_lens) == len(self.query_lens)

    def compute_num_tokens(self):
        return sum(self.query_lens)
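

# Illustrative usage sketch (the `_example_*` helpers below are not used by
# the test suite): a small batch mixing a decode-style request (1 query token)
# with a prefill-style request. seq_lens must be >= query_lens so the derived
# context lengths stay non-negative.
def _example_batch_spec() -> BatchSpec:
    spec = BatchSpec(seq_lens=[32, 128], query_lens=[1, 128], name="mixed")
    assert spec.batch_size == 2
    assert spec.compute_num_tokens() == 129
    return spec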


def create_common_attn_metadata(
        batch_spec: BatchSpec,
        block_size: int,
        device: torch.device,
        max_block_idx: int = 1000,
        arange_block_indices: bool = False) -> CommonAttentionMetadata:
    """Create a CommonAttentionMetadata instance for the given BatchSpec,
    block size, and device."""
    # Create query start locations
    query_start_loc = torch.zeros(batch_spec.batch_size + 1,
                                  dtype=torch.int32,
                                  device=device)
    query_start_loc[1:] = torch.tensor(batch_spec.query_lens,
                                       dtype=torch.int32,
                                       device=device).cumsum(0)
    query_start_loc_cpu = query_start_loc.cpu()
    num_tokens = batch_spec.compute_num_tokens()

    # Create sequence lengths
    seq_lens = torch.tensor(batch_spec.seq_lens,
                            dtype=torch.int32,
                            device=device)
    seq_lens_cpu = seq_lens.cpu()
    max_seq_len = int(seq_lens_cpu.max())

    # Create computed tokens (context length for each sequence)
    context_lens = [
        batch_spec.seq_lens[i] - batch_spec.query_lens[i]
        for i in range(batch_spec.batch_size)
    ]
    num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)

    # Create block table and slot mapping
    max_blocks = (max(batch_spec.seq_lens) + block_size - 1) // block_size
    if arange_block_indices:
        num_blocks = batch_spec.batch_size * max_blocks
        block_table_tensor = torch.arange(num_blocks,
                                          dtype=torch.int32,
                                          device=device).view(
                                              batch_spec.batch_size,
                                              max_blocks)
        slot_mapping = torch.arange(num_tokens,
                                    dtype=torch.int64,
                                    device=device).view(num_tokens)
    else:
        block_table_tensor = torch.randint(0,
                                           max_block_idx,
                                           (batch_spec.batch_size, max_blocks),
                                           dtype=torch.int32,
                                           device=device)
        slot_mapping = torch.randint(0,
                                     max_block_idx, (num_tokens, ),
                                     dtype=torch.int64,
                                     device=device)

    # Calculate max query length
    max_query_len = max(batch_spec.query_lens)

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens_cpu,
        num_computed_tokens_cpu=num_computed_tokens_cpu,
        num_reqs=batch_spec.batch_size,
        num_actual_tokens=num_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
        causal=True,
    )
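

# Illustrative usage sketch (not used by the tests): metadata for the batch
# shape above, built on CPU with deterministic block indices so the result is
# reproducible. block_size=16 is an arbitrary choice.
def _example_common_attn_metadata() -> CommonAttentionMetadata:
    spec = BatchSpec(seq_lens=[32, 128], query_lens=[1, 128])
    return create_common_attn_metadata(
        spec,
        block_size=16,
        device=torch.device("cpu"),
        arange_block_indices=True)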


def get_attention_backend(backend_name: _Backend):
    """Set up attention backend classes for testing.

    Args:
        backend_name: The _Backend enum member identifying the backend.

    Returns:
        Tuple of (backend_builder_class, backend_impl_class)
    """
    backend_map = {
        _Backend.FLASH_ATTN:
        ("vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
         if current_platform.is_cuda() else
         "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
         ),
        _Backend.FLASHINFER:
        "vllm.v1.attention.backends.flashinfer.FlashInferBackend",
        _Backend.FLEX_ATTENTION:
        "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",
        _Backend.TRITON_ATTN:
        "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",
        _Backend.TREE_ATTN:
        "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
        _Backend.XFORMERS:
        "vllm.v1.attention.backends.xformers.XFormersAttentionBackend",
        _Backend.CUTLASS_MLA:
        "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",
        _Backend.FLASHMLA:
        "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
        _Backend.FLASH_ATTN_MLA:
        "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",
        _Backend.FLASHINFER_MLA:
        "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend",
        _Backend.TRITON_MLA:
        "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
    }
    if backend_name not in backend_map:
        raise ValueError(f"Unknown backend: {backend_name}")
    backend_class_name = backend_map[backend_name]

    try:
        backend_class = resolve_obj_by_qualname(backend_class_name)
        return backend_class.get_builder_cls(), backend_class.get_impl_cls()
    except ImportError as e:
        pytest.skip(f"{backend_name} not available: {e}")
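

# Illustrative usage sketch (not used by the tests): resolve the builder and
# impl classes for one backend. Inside a real test this call skips the test
# rather than failing if the backend cannot be imported on this platform.
def _example_get_attention_backend():
    builder_cls, impl_cls = get_attention_backend(_Backend.TRITON_ATTN)
    return builder_cls, impl_cls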


def create_standard_kv_cache_spec(
        vllm_config: VllmConfig) -> FullAttentionSpec:
    """Create a FullAttentionSpec from a VllmConfig's model/cache config."""
    return FullAttentionSpec(
        block_size=vllm_config.cache_config.block_size,
        num_kv_heads=vllm_config.model_config.get_num_kv_heads(
            vllm_config.parallel_config),
        head_size=vllm_config.model_config.get_head_size(),
        dtype=vllm_config.model_config.dtype,
        sliding_window=vllm_config.model_config.get_sliding_window(),
    )
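

# Illustrative usage sketch (not used by the tests): derive a KV-cache spec
# from a test config built with create_vllm_config (defined below). Assumes
# the Hugging Face config for that helper's default model is available
# locally or downloadable.
def _example_kv_cache_spec() -> FullAttentionSpec:
    vllm_config = create_vllm_config()
    return create_standard_kv_cache_spec(vllm_config)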


def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B",
                       tensor_parallel_size: int = 1,
                       max_model_len: int = 1024,
                       dtype: Union[ModelDType, torch.dtype] = "auto",
                       num_gpu_blocks: int = 1000,
                       block_size: int = 16,
                       max_num_seqs: int = 256,
                       max_num_batched_tokens: int = 8192,
                       enable_chunked_prefill: bool = True,
                       add_mock_model_methods: bool = True) -> VllmConfig:
    """Create a VllmConfig for testing with reasonable defaults."""
    model_config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        trust_remote_code=False,
        dtype=dtype,
        seed=0,
        max_model_len=max_model_len,
    )

    cache_config = CacheConfig(
        block_size=block_size,
        cache_dtype="auto",
        swap_space=0,
    )
    # Set cache blocks for testing
    # (these may be set during initialization normally)
    cache_config.num_gpu_blocks = num_gpu_blocks
    cache_config.num_cpu_blocks = 0

    parallel_config = ParallelConfig(
        tensor_parallel_size=tensor_parallel_size, )

    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        enable_chunked_prefill=enable_chunked_prefill,
    )

    device_config = DeviceConfig()
    load_config = LoadConfig()
    compilation_config = CompilationConfig()

    if add_mock_model_methods:
        # Add mock methods to satisfy backends that need them
        # This is a workaround because tests don't build full, real models,
        # but some backends expect to query the model for layer-specific
        # parameters
        import types
        model_config.get_num_layers = types.MethodType(lambda self: 1,
                                                       model_config)
        model_config.get_sliding_window_for_layer = types.MethodType(
            lambda self, i: None, model_config)
        model_config.get_logits_soft_cap_for_layer = types.MethodType(
            lambda self, i: 0.0, model_config)
        model_config.get_sm_scale_for_layer = types.MethodType(
            lambda self, i: 1.0 / model_config.get_head_size()**0.5,
            model_config)

    return VllmConfig(
        model_config=model_config,
        cache_config=cache_config,
        parallel_config=parallel_config,
        scheduler_config=scheduler_config,
        device_config=device_config,
        load_config=load_config,
        compilation_config=compilation_config,
    )
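

# Illustrative usage sketch (not used by the tests): a smaller config than the
# defaults. The values are arbitrary and only need to be mutually consistent
# for the test at hand (here 512 tokens fit comfortably in 128 blocks of 16
# slots each).
def _example_vllm_config() -> VllmConfig:
    return create_vllm_config(
        max_model_len=512, num_gpu_blocks=128, block_size=16)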


def create_dummy_kv_cache(block_size: int,
                          num_kv_heads: int,
                          head_size: int,
                          dtype: torch.dtype,
                          device: torch.device,
                          num_blocks: int = 100) -> torch.Tensor:
    """Create a dummy KV cache tensor for testing."""
    kv_cache = torch.randn(
        num_blocks,
        2,  # K and V
        block_size,
        num_kv_heads,
        head_size,
        dtype=dtype,
        device=device)
    return kv_cache
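

# Illustrative usage sketch (not used by the tests): a tiny CPU cache. The
# resulting layout is (num_blocks, 2, block_size, num_kv_heads, head_size),
# with dimension 1 holding the K and V planes.
def _example_dummy_kv_cache() -> torch.Tensor:
    cache = create_dummy_kv_cache(
        block_size=16,
        num_kv_heads=4,
        head_size=64,
        dtype=torch.float32,
        device=torch.device("cpu"),
        num_blocks=8)
    assert cache.shape == (8, 2, 16, 4, 64)
    return cache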