[CI/Build] tests(v1): feed Triton attention the (num_blocks, 2, …) KV cache layout in backend-correctness tests (#26663)
Signed-off-by: Huamin Li <3ericli@gmail.com>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
@@ -423,13 +423,14 @@ def _test_backend_correctness(
     for backend_name in backend_to_test:
         # FlashAttention + FlexAttention:
         # [2, num_blocks, block_size, num_kv_heads, head_size]
-        # FlashInfer:
+        # FlashInfer + Triton:
         # [num_blocks, 2, block_size, num_kv_heads, head_size]
         # Select the appropriate KV cache format for each backend
         kv_cache_for_backend = kv_cache
-        if backend_name == _Backend.FLASHINFER:
+        if backend_name in (_Backend.FLASHINFER, _Backend.TRITON_ATTN):
             kv_cache_for_backend = kv_cache.transpose(0, 1)
 
-            # For FlashInfer, default to HND layout
-            kv_cache_for_backend = (
-                kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3)
+            if backend_name == _Backend.FLASHINFER:
+                # For FlashInfer, default to HND layout
+                kv_cache_for_backend = (
+                    kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3)
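For context, here is an illustrative sketch (not part of the commit) of what these layout conversions do to a toy KV cache tensor. The dimension sizes below are arbitrary, and the sketch assumes standard PyTorch semantics: transpose returns a zero-copy view, while contiguous copies the data into the currently viewed order.

import torch

# Toy dimensions, chosen only for illustration.
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8

# FlashAttention / FlexAttention layout:
# [2, num_blocks, block_size, num_kv_heads, head_size]
kv_cache = torch.randn(2, num_blocks, block_size, num_kv_heads, head_size)

# FlashInfer / Triton layout:
# [num_blocks, 2, block_size, num_kv_heads, head_size]
# transpose(0, 1) swaps the K/V axis with the block axis without copying.
kv_cache_nhd = kv_cache.transpose(0, 1)
assert kv_cache_nhd.shape == (num_blocks, 2, block_size, num_kv_heads, head_size)

# FlashInfer's default HND layout: reorder the underlying storage to
# [num_blocks, 2, num_kv_heads, block_size, head_size] (heads before tokens).
# transpose(2, 3) views the tensor in HND order, contiguous() materializes
# that memory order, and the second transpose(2, 3) restores the NHD-shaped
# view over the now HND-ordered storage.
kv_cache_hnd = kv_cache_nhd.transpose(2, 3).contiguous().transpose(2, 3)
assert kv_cache_hnd.shape == kv_cache_nhd.shape
assert not kv_cache_hnd.is_contiguous()

The transpose/contiguous/transpose idiom lets a backend that expects HND-ordered memory read the buffer directly, while the test keeps addressing the cache through the familiar NHD-shaped view.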