# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
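"""Smoke tests for the sparse FlashMLA ops on SM90 (Hopper) GPUs.

These are launch-and-shape smoke checks: every input tensor is zeros, and the
assertions cover output shapes and dtypes rather than numerics.
"""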

import pytest
import torch


def _cuda_sm90_available() -> bool:
    """Return True on a compute capability 9.x (Hopper) device."""
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability()
    return major == 9
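

# get_mla_metadata produces the tile-scheduling metadata and split counts that
# flash_mla_with_kvcache consumes (see the decode test below); here we only
# check that the op launches and returns int32 tensors.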
def test_sparse_flashmla_metadata_smoke():
    import vllm.attention.ops.flashmla as fm

    ok, reason = fm.is_flashmla_supported()
    if not ok or not _cuda_sm90_available():
        pytest.skip(reason or "SM90 not available")

    device = torch.device("cuda")
    batch_size = 1
    seqlen_q = 1
    num_heads_q = 128
    num_heads_k = 1
    # Fold the query heads into the per-KV-head sequence dimension.
    q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k
    topk = 128

    cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)

    tile_md, num_splits = fm.get_mla_metadata(cache_seqlens,
                                              q_seq_per_hk,
                                              num_heads_k,
                                              num_heads_q=num_heads_q,
                                              topk=topk,
                                              is_fp8_kvcache=True)
    assert tile_md.dtype == torch.int32
    assert num_splits.dtype == torch.int32
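

# The decode path reads an FP8 KV cache packed at 656 bytes per token. For
# head_dim_k = 576 this plausibly corresponds to 512 fp8 "NoPE" values plus
# fp32 scales plus 64 bf16 RoPE values (512 + 16 + 128 = 656 bytes); the
# exact layout is an assumption about the kernel's cache format, not
# something this test asserts.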
def test_sparse_flashmla_decode_smoke():
    import vllm.attention.ops.flashmla as fm

    ok, reason = fm.is_flashmla_supported()
    if not ok or not _cuda_sm90_available():
        pytest.skip(reason or "SM90 not available")

    device = torch.device("cuda")
    batch_size = 1
    seqlen_q = 1
    num_heads_q = 1
    head_dim_k = 576
    head_dim_v = 512
    num_heads_k = 1
    page_block_size = 64
    bytes_per_token = 656
    topk = 128

    # Metadata
    q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k
    cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
    tile_md, num_splits = fm.get_mla_metadata(cache_seqlens,
                                              q_seq_per_hk,
                                              num_heads_k,
                                              num_heads_q=num_heads_q,
                                              topk=topk,
                                              is_fp8_kvcache=True)

    # Inputs
    q = torch.zeros((batch_size, seqlen_q, num_heads_q, head_dim_k),
                    dtype=torch.bfloat16,
                    device=device)
    k_cache = torch.zeros((1, page_block_size, num_heads_k, bytes_per_token),
                          dtype=torch.uint8,
                          device=device)
    indices = torch.zeros((batch_size, seqlen_q, topk),
                          dtype=torch.int32,
                          device=device)
    block_table = torch.zeros((batch_size, 128),
                              dtype=torch.int32,
                              device=device)

    out, lse = fm.flash_mla_with_kvcache(q,
                                         k_cache,
                                         block_table,
                                         cache_seqlens,
                                         head_dim_v,
                                         tile_md,
                                         num_splits,
                                         indices=indices,
                                         is_fp8_kvcache=True)
    assert out.shape[0] == batch_size
    assert out.shape[-1] == head_dim_v
    assert lse.shape[0] == batch_size
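

# In the prefill call below, the positional 1.0 is read as the softmax scale
# (sm_scale) and d_v as the output value head dimension; this reading is an
# assumption from the call site, not something the smoke test verifies.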
def test_sparse_flashmla_prefill_smoke():
    import vllm.attention.ops.flashmla as fm

    ok, reason = fm.is_flashmla_supported()
    if not ok or not _cuda_sm90_available():
        pytest.skip(reason or "SM90 not available")

    device = torch.device("cuda")
    s_q = 1
    s_kv = 1
    h_q = 64  # kernel expects a multiple of 64
    h_kv = 1
    d_qk = 576
    d_v = 512
    topk = 128

    q = torch.zeros((s_q, h_q, d_qk), dtype=torch.bfloat16, device=device)
    kv = torch.zeros((s_kv, h_kv, d_qk), dtype=torch.bfloat16, device=device)
    indices = torch.zeros((s_q, h_kv, topk), dtype=torch.int32, device=device)

    out, max_logits, lse = fm.flash_mla_sparse_prefill(q, kv, indices, 1.0,
                                                       d_v)
    assert out.shape == (s_q, h_q, d_v)
    assert max_logits.shape == (s_q, h_q)
    assert lse.shape == (s_q, h_q)