Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[CI Perf] Prune tests in tests/kernels/attention/ (#22936)
Signed-off-by: mgoin <mgoin64@gmail.com>
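Context for the pruning: stacked `@pytest.mark.parametrize` decorators generate the full cross-product of their argument lists, so trimming even one list shrinks the number of collected cases multiplicatively. The sketch below is illustrative only (not part of the commit); it counts just the flash-attention parameters visible in the hunks that follow, and ignores any parametrization outside the shown diff context.

```python
import math

# Parameter-list lengths taken from the flash-attention hunks below.
before = {
    "num_heads": 3,       # [(4, 4), (8, 2), (16, 2)]
    "head_size": 2,       # [128, 256]
    "block_size": 2,      # [16, 32]
    "dtype": 2,           # [torch.float16, torch.bfloat16]
    "soft_cap": 3,        # [None, 10.0, 50.0]
    "num_blocks": 2,      # [32768, 2048]
    "sliding_window": 2,  # [None, 256]
    "fa_version": 2,      # [2, 3]
    "q_dtype": 2,         # [None, torch.float8_e4m3fn]
}
# After pruning, NUM_HEADS, BLOCK_SIZES, DTYPES and SOFT_CAPS are shorter.
after = dict(before, num_heads=2, block_size=1, dtype=1, soft_cap=2)

print(math.prod(before.values()))  # 1152 combinations per test function
print(math.prod(after.values()))   # 128 combinations per test function
```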
@@ -9,10 +9,10 @@ import torch
 import vllm.v1.attention.backends.rocm_aiter_fa  # noqa: F401
 from vllm.platforms import current_platform
 
-NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
+NUM_HEADS = [(4, 4), (8, 2)]
 HEAD_SIZES = [128, 256]
-BLOCK_SIZES = [16, 32]
-DTYPES = [torch.float16, torch.bfloat16]
+BLOCK_SIZES = [16]
+DTYPES = [torch.bfloat16]
 QDTYPES = [None]
 # one value large enough to test overflow in index calculation.
 # one value small enough to test the schema op check
@@ -29,17 +29,14 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
 NUM_BLOCKS = 4321  # Arbitrary values for testing
 PARTITION_SIZE = 512
 PARTITION_SIZE_ROCM = 256
-# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
-DTYPES = [
-    torch.half, torch.bfloat16, torch.float
-] if not current_platform.is_rocm() else [torch.half, torch.bfloat16]
+DTYPES = [torch.bfloat16]
 NUM_GEN_SEQS = [7]  # Arbitrary values for testing
 NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
 NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
 
 # This should be sync with get_supported_head_sizes() in
 # vllm.attention.ops.paged_attn.PagedAttention
-HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]
+HEAD_SIZES = [32, 80, 128, 256]
 
 BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
@@ -11,11 +11,11 @@ from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
 
 COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
-DTYPES = [torch.half, torch.bfloat16, torch.float]
+DTYPES = [torch.bfloat16, torch.float]
 NUM_TOKENS = [42]  # Arbitrary values for testing
 NUM_LAYERS = [1]  # Arbitrary values for testing
 NUM_HEADS = [8]  # Arbitrary values for testing
-HEAD_SIZES = [64, 80, 120, 256]
+HEAD_SIZES = [64, 80, 256]
 BLOCK_SIZES = [8, 16, 32]
 CACHE_LAYOUTS = ["NHD", "HND"]
 
@@ -12,14 +12,16 @@ from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
                                   flash_attn_with_kvcache,
                                   is_fa_version_supported)
 
-NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
+NUM_HEADS = [(4, 4), (8, 2)]
 HEAD_SIZES = [128, 256]
-BLOCK_SIZES = [16, 32]
-DTYPES = [torch.float16, torch.bfloat16]
+BLOCK_SIZES = [16]
+DTYPES = [torch.bfloat16]
 QDTYPES = [None, torch.float8_e4m3fn]
 # one value large enough to test overflow in index calculation.
 # one value small enough to test the schema op check
 NUM_BLOCKS = [32768, 2048]
+SOFT_CAPS = [None, 50.0]
+SLIDING_WINDOWS = [None, 256]
 
 
 def ref_paged_attn(
@@ -83,9 +85,9 @@ def ref_paged_attn(
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
+@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
-@pytest.mark.parametrize("sliding_window", [None, 256])
+@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
 @pytest.mark.parametrize("fa_version", [2, 3])
 @pytest.mark.parametrize("q_dtype", QDTYPES)
 @torch.inference_mode()
@@ -198,9 +200,9 @@ def test_flash_attn_with_paged_kv(
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
-@pytest.mark.parametrize("sliding_window", [None, 256])
+@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
 @pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
+@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("fa_version", [2, 3])
 @pytest.mark.parametrize("q_dtype", QDTYPES)
@@ -9,11 +9,13 @@ import torch
 
 from vllm.platforms import current_platform
 
-NUM_HEADS = [(16, 16), (32, 8), (64, 8), (6, 1)]
+NUM_HEADS = [(32, 8), (6, 1)]
 HEAD_SIZES = [128, 256]
 BLOCK_SIZES = [16, 32]
-DTYPES = [torch.float16, torch.bfloat16]
+DTYPES = [torch.bfloat16]
 NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
+SOFT_CAPS = [None, 30.0]
+SLIDING_WINDOWS = [None, 64]
 
 
 def ref_paged_attn(
@@ -76,8 +78,8 @@ def ref_paged_attn(
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
-@pytest.mark.parametrize("sliding_window", [None, 64])
+@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
+@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
 @torch.inference_mode
 def test_flashinfer_decode_with_paged_kv(
     kv_lens: list[int],
@@ -173,8 +175,8 @@ def test_flashinfer_decode_with_paged_kv(
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
-@pytest.mark.parametrize("sliding_window", [None, 64])
+@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
+@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
 @torch.inference_mode
 def test_flashinfer_prefill_with_paged_kv(
     seq_lens: list[tuple[int, int]],
@@ -278,11 +280,11 @@ def test_flashinfer_prefill_with_paged_kv(
 
 
 @pytest.mark.parametrize("seq_lens", [[(1, 132), (5, 18)]])
-@pytest.mark.parametrize("num_heads", [(32, 8), (6, 1)])
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
+@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
 def test_flashinfer_prefill_with_paged_fp8_kv(
         seq_lens: list[tuple[int, int]], num_heads: tuple[int, int],
         head_size: int, dtype: torch.dtype, block_size: int,
@@ -385,11 +387,12 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
 
 
 @pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
-@pytest.mark.parametrize("num_heads", [(32, 8), (64, 8), (6, 1)])
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
+@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
+@pytest.mark.skip(reason="TODO: fix the accuracy issue")
 @torch.inference_mode
 def test_flashinfer_decode_with_paged_fp8_kv(
     kv_lens: list[int],
@@ -399,7 +402,6 @@ def test_flashinfer_decode_with_paged_fp8_kv(
     block_size: int,
     soft_cap: Optional[float],
 ) -> None:
-    pytest.skip("TODO: fix the accuracy issue")
     # test doesn't work for num_heads = (16,16)
     torch.set_default_device("cuda")
     current_platform.seed_everything(0)
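A side note on the skip change in the two hunks above: the in-body `pytest.skip(...)` call is replaced by a `@pytest.mark.skip(...)` decorator, so the heavily parametrized fp8-decode test is reported as skipped during setup instead of each case being set up and entered first. A minimal sketch of the two styles, using hypothetical test names:

```python
import pytest


@pytest.mark.parametrize("x", range(100))
def test_skipped_at_runtime(x: int) -> None:
    # Each of the 100 cases is set up and entered before the skip fires.
    pytest.skip("TODO: fix the accuracy issue")


@pytest.mark.parametrize("x", range(100))
@pytest.mark.skip(reason="TODO: fix the accuracy issue")
def test_skipped_by_marker(x: int) -> None:
    # The marker skips each case during setup; the body never runs.
    assert x >= 0
```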
@@ -20,11 +20,11 @@ FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 MAX_Q_LEN = 1024
 MAX_KV_LEN = 4096
 BATCH_SIZES = [4, 12]
-NUM_HEADS = [(64, 8), (16, 16), (40, 8), (32, 8)]
+NUM_HEADS = [(16, 16), (40, 8)]
 HEAD_SIZES = [128]
-BLOCK_SIZES = [16, 32]
+BLOCK_SIZES = [16]
 KV_LAYOUTS = ["HND"]
-DTYPES = [torch.float16, torch.bfloat16]
+DTYPES = [torch.bfloat16]
 KV_CACHE_DTYPES = [None, current_platform.fp8_dtype()]
 NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
 SOFT_CAPS = [None, 50.0]
@@ -19,13 +19,13 @@ from vllm.platforms import current_platform
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 
 NUM_HEADS = [64]
-NUM_QUERIES_PER_KV = [1, 8, 64]
-HEAD_SIZES = [128, 96, 24]
+NUM_QUERIES_PER_KV = [1, 64]
+HEAD_SIZES = [24, 128]
 DTYPES = [torch.float16]
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ]
-SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048]
+SLIDING_WINDOW = [0, 16, 2048]
 KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]
 
 OPS = [chunked_prefill_paged_decode, context_attention_fwd]
@@ -9,11 +9,11 @@ import torch
 from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.platforms import current_platform
 
-NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
+NUM_HEADS = [(4, 4), (8, 2)]
 HEAD_SIZES = [128, 256]
-BLOCK_SIZES = [16, 32]
+BLOCK_SIZES = [16]
 
-DTYPES = [torch.float16, torch.bfloat16]
+DTYPES = [torch.bfloat16]
 QDTYPES = [None, torch.float8_e4m3fn] if not current_platform.is_rocm() else [
     None, torch.float8_e4m3fnuz
 ]
@@ -85,7 +85,7 @@ def ref_paged_attn(
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("sliding_window", [None, 256])
 @pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
+@pytest.mark.parametrize("soft_cap", [None, 50.0])
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("q_dtype", QDTYPES)
 @torch.inference_mode()
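To see the effect of a pruning like this locally, one option (hypothetical, not part of the commit) is to ask pytest for the collected test ids before and after editing the parameter lists:

```python
import pytest

# "--collect-only" lists test cases without running them; "-q" prints one id per line.
pytest.main(["--collect-only", "-q", "tests/kernels/attention/"])
```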