Files
vllm/tests/kernels/attention/test_flashmla_sparse.py
Yongye Zhu b3230e1ac0 [New Model] DeepSeek-V3.2 (Rebased to Main) (#25896)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
Signed-off-by: Lucia Fang <fanglu@meta.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com>
Co-authored-by: Lucia Fang <fanglu@meta.com>
Co-authored-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Siyuan Fu <siyuanf@nvidia.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Xiaozhu Meng <mxz297@gmail.com>
Co-authored-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-30 22:36:24 -07:00

120 lines
4.3 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
def _cuda_sm90_available() -> bool:
    """Report whether CUDA is usable and the device is capability 9.x.

    Returns:
        ``True`` only when ``torch.cuda.is_available()`` is true and the
        major version returned by ``torch.cuda.get_device_capability()``
        (called with no arguments) equals 9; ``False`` otherwise.
    """
    if torch.cuda.is_available():
        # Only the major component matters; the minor version is ignored.
        capability = torch.cuda.get_device_capability()
        return capability[0] == 9
    return False
def test_sparse_flashmla_metadata_smoke():
    """Smoke-test ``get_mla_metadata`` when ``topk`` and FP8 cache are set.

    Skips unless ``is_flashmla_supported()`` reports support and
    ``_cuda_sm90_available()`` is true. Only the dtypes of the two returned
    tensors are asserted; their contents are not inspected.
    """
    import vllm.attention.ops.flashmla as fm

    supported, skip_reason = fm.is_flashmla_supported()
    if not (supported and _cuda_sm90_available()):
        # Use the library's own reason when it gives one.
        pytest.skip(skip_reason or "SM90 not available")

    cuda = torch.device("cuda")

    # One request, one query token, 128 query heads sharing one KV head.
    n_requests, q_len = 1, 1
    q_heads, kv_heads = 128, 1
    sparse_topk = 128

    seq_lens = torch.zeros(n_requests, dtype=torch.int32, device=cuda)
    tile_metadata, split_counts = fm.get_mla_metadata(
        seq_lens,
        q_len * q_heads // kv_heads,
        kv_heads,
        num_heads_q=q_heads,
        topk=sparse_topk,
        is_fp8_kvcache=True,
    )

    assert tile_metadata.dtype == torch.int32
    assert split_counts.dtype == torch.int32
def test_sparse_flashmla_decode_smoke():
    """Smoke-test ``flash_mla_with_kvcache`` with ``indices`` and FP8 cache.

    Skips unless ``is_flashmla_supported()`` reports support and
    ``_cuda_sm90_available()`` is true. Every input tensor is zero-filled;
    only the leading dimension of both outputs and the trailing dimension
    of ``out`` are asserted, not any numerical values.
    """
    import vllm.attention.ops.flashmla as fm

    supported, skip_reason = fm.is_flashmla_supported()
    if not (supported and _cuda_sm90_available()):
        # Use the library's own reason when it gives one.
        pytest.skip(skip_reason or "SM90 not available")

    cuda = torch.device("cuda")

    def zeros(shape, dtype):
        # Every input in this test is a zero-filled tensor on the CUDA device.
        return torch.zeros(shape, dtype=dtype, device=cuda)

    n_requests, q_len = 1, 1
    q_heads, kv_heads = 1, 1
    qk_dim, v_dim = 576, 512
    block_size = 64
    # NOTE(review): 656 bytes/token presumably matches FlashMLA's packed FP8
    # KV-cache entry layout -- confirm against the kernel documentation.
    token_bytes = 656
    sparse_topk = 128
    block_table_width = 128

    # Scheduling metadata, built from all-zero cache sequence lengths.
    seq_lens = zeros(n_requests, torch.int32)
    tile_metadata, split_counts = fm.get_mla_metadata(
        seq_lens,
        q_len * q_heads // kv_heads,
        kv_heads,
        num_heads_q=q_heads,
        topk=sparse_topk,
        is_fp8_kvcache=True,
    )

    # Kernel inputs: bf16 query, uint8 cache, int32 indices and block table.
    query = zeros((n_requests, q_len, q_heads, qk_dim), torch.bfloat16)
    kv_cache = zeros((1, block_size, kv_heads, token_bytes), torch.uint8)
    topk_indices = zeros((n_requests, q_len, sparse_topk), torch.int32)
    block_table = zeros((n_requests, block_table_width), torch.int32)

    out, lse = fm.flash_mla_with_kvcache(
        query,
        kv_cache,
        block_table,
        seq_lens,
        v_dim,
        tile_metadata,
        split_counts,
        indices=topk_indices,
        is_fp8_kvcache=True,
    )

    assert out.shape[0] == n_requests
    assert out.shape[-1] == v_dim
    assert lse.shape[0] == n_requests
def test_sparse_flashmla_prefill_smoke():
    """Smoke-test ``flash_mla_sparse_prefill`` on zero-filled inputs.

    Skips unless ``is_flashmla_supported()`` reports support and
    ``_cuda_sm90_available()`` is true. Only the shapes of the three
    returned tensors are asserted, not any numerical values.
    """
    import vllm.attention.ops.flashmla as fm

    supported, skip_reason = fm.is_flashmla_supported()
    if not (supported and _cuda_sm90_available()):
        # Use the library's own reason when it gives one.
        pytest.skip(skip_reason or "SM90 not available")

    cuda = torch.device("cuda")

    q_len, kv_len = 1, 1
    q_heads = 64  # kernel expects multiple of 64
    kv_heads = 1
    qk_dim, v_dim = 576, 512
    sparse_topk = 128

    query = torch.zeros((q_len, q_heads, qk_dim),
                        dtype=torch.bfloat16,
                        device=cuda)
    key_value = torch.zeros((kv_len, kv_heads, qk_dim),
                            dtype=torch.bfloat16,
                            device=cuda)
    topk_indices = torch.zeros((q_len, kv_heads, sparse_topk),
                               dtype=torch.int32,
                               device=cuda)

    # The fourth and fifth positional arguments are 1.0 and the value dim.
    out, max_logits, lse = fm.flash_mla_sparse_prefill(
        query, key_value, topk_indices, 1.0, v_dim)

    per_head_shape = (q_len, q_heads)
    assert out.shape == (q_len, q_heads, v_dim)
    assert max_logits.shape == per_head_shape
    assert lse.shape == per_head_shape