[CI/Build] tests(v1): feed Triton attention the (num_blocks, 2, …) KV cache layout in backend-correctness tests (#26663)

Signed-off-by: Huamin Li <3ericli@gmail.com>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
This commit is contained in:
Huamin Li
2025-10-17 21:11:26 -07:00
committed by GitHub
parent c981f0ea78
commit c312320764

View File

@ -423,13 +423,14 @@ def _test_backend_correctness(
for backend_name in backend_to_test:
# FlashAttentionm + FlexAttention:
# [2, num_blocks, block_size, num_kv_heads, head_size]
# FlashInfer:
# FlashInfer + Triton:
# [num_blocks, 2, block_size, num_kv_heads, head_size]
# Select the appropriate KV cache format for each backend
kv_cache_for_backend = kv_cache
if backend_name == _Backend.FLASHINFER:
if backend_name in (_Backend.FLASHINFER, _Backend.TRITON_ATTN):
kv_cache_for_backend = kv_cache.transpose(0, 1)
if backend_name == _Backend.FLASHINFER:
# For FlashInfer default to HND layout and
kv_cache_for_backend = (
kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3)