[neuron] add reshape_and_cache (#14391)
tests/neuron/test_cache.py (new file, 83 lines)
@@ -0,0 +1,83 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from vllm.attention.ops.nki_flash_attn import reshape_and_cache


@pytest.mark.parametrize(
    "num_tokens, n_kv_head, d_head, num_blocks, block_size",
    [
        # Small model configuration (e.g., GPT-2 small)
        (32, 12, 64, 4, 128),  # Typical sequence processing
        (1, 12, 64, 4, 128),  # Single token update
        (128, 12, 64, 4, 128),  # Longer sequence

        # Medium model configuration (e.g., GPT-2 medium)
        (64, 16, 96, 8, 256),  # Standard batch
        (256, 16, 96, 8, 256),  # Large batch

        # Large model configuration (e.g., GPT-3 style)
        (48, 32, 128, 16, 512),  # Typical processing window
        (512, 32, 128, 16, 512),  # Full context window

        # Edge cases and stress tests
        (1024, 8, 32, 32, 32),  # Many tokens, small heads
        (16, 64, 256, 4, 64),  # Few tokens, many heads
        (2048, 24, 128, 64, 128),  # Large scale test

        # Minimal configurations for debugging
        (4, 2, 16, 2, 16),  # Tiny test case
        (1, 1, 8, 1, 8),  # Minimal possible
    ])
def test_reshape_and_cache(num_tokens, n_kv_head, d_head, num_blocks,
                           block_size):
    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Create CPU tensors for reference implementation
    key_cpu = torch.randn(num_tokens, n_kv_head, d_head) / torch.sqrt(
        torch.tensor(d_head))
    value_cpu = torch.randn(num_tokens, n_kv_head, d_head) / torch.sqrt(
        torch.tensor(d_head))
    key_cache_cpu = torch.zeros(num_blocks, n_kv_head, block_size, d_head)
    value_cache_cpu = torch.zeros(num_blocks, n_kv_head, block_size, d_head)
    slot_mapping_cpu = torch.randperm(num_blocks * block_size)[:num_tokens]

    # Run reference implementation on CPU
    block_indices = torch.div(slot_mapping_cpu,
                              block_size,
                              rounding_mode="floor")
    block_offsets = slot_mapping_cpu % block_size

    for i in range(num_tokens):
        block_idx = block_indices[i]
        block_offset = block_offsets[i]
        key_cache_cpu[block_idx, :, block_offset, :] = key_cpu[i]
        value_cache_cpu[block_idx, :, block_offset, :] = value_cpu[i]

    # Create XLA device tensors
    device = torch.device('xla')
    key = key_cpu.to(device)
    value = value_cpu.to(device)
    key_cache = torch.zeros_like(key_cache_cpu, device=device)
    value_cache = torch.zeros_like(value_cache_cpu, device=device)
    slot_mapping = slot_mapping_cpu.to(device)

    # Run vectorized implementation on XLA device
    reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)

    # Move results back to CPU for comparison
    key_cache_result = key_cache.cpu()
    value_cache_result = value_cache.cpu()

    # Assert results match
    torch.testing.assert_close(key_cache_result,
                               key_cache_cpu,
                               rtol=1e-5,
                               atol=1e-5)
    torch.testing.assert_close(value_cache_result,
                               value_cache_cpu,
                               rtol=1e-5,
                               atol=1e-5)
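For intuition about the reference loop above: slot_mapping holds flat slot ids into a paged cache of num_blocks * block_size slots, and each id splits into a block index and an in-block offset. A minimal sketch of that decomposition (not part of this commit; the values are made up for illustration):

import torch

block_size = 4
slot_mapping = torch.tensor([5, 0, 7])  # one flat slot id per token

# Slot 5 lands in block 5 // 4 = 1 at offset 5 % 4 = 1, and so on.
block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_offsets = slot_mapping % block_size

print(block_indices)  # tensor([1, 0, 1])
print(block_offsets)  # tensor([1, 0, 3])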
vllm/attention/ops/nki_flash_attn.py
@@ -869,3 +869,46 @@ def flash_attn_varlen_nkifunc(
    o = flash_paged_attention[1, n_kv_head](**kwargs)
    return o


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """
    Writes key-value pairs to the KV cache at specified positions.

    Args:
        key (torch.Tensor): Key tensor with shape
            (num_tokens, n_kv_head, d_head)
        value (torch.Tensor): Value tensor with shape
            (num_tokens, n_kv_head, d_head)
        key_cache (torch.Tensor): Key cache tensor with shape
            (num_blocks, n_kv_head, block_size, d_head)
        value_cache (torch.Tensor): Value cache tensor with shape
            (num_blocks, n_kv_head, block_size, d_head)
        slot_mapping (torch.Tensor): Mapping tensor indicating cache positions
            with shape (num_tokens)

    Returns:
        None: Updates the key_cache and value_cache tensors in-place
    """
    block_size = key_cache.size(2)

    # Calculate indices with explicit floor division
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_offsets = slot_mapping % block_size

    # Update caches using index_put_
    key_cache.index_put_(
        (block_indices.unsqueeze(1),
         torch.arange(key_cache.size(1),
                      device=key.device), block_offsets.unsqueeze(1)), key)

    value_cache.index_put_(
        (block_indices.unsqueeze(1),
         torch.arange(value_cache.size(1),
                      device=value.device), block_offsets.unsqueeze(1)), value)
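The two index_put_ calls rely on index broadcasting: index tensors of shapes (num_tokens, 1), (n_kv_head,), and (num_tokens, 1) broadcast to (num_tokens, n_kv_head), so all heads of a token are written in a single scatter, matching the per-token loop in the test. A small CPU-only sanity sketch of the same pattern (not part of this commit; shapes are made up for illustration):

import torch

num_tokens, n_kv_head, d_head = 3, 2, 4
num_blocks, block_size = 2, 4

key = torch.randn(num_tokens, n_kv_head, d_head)
key_cache = torch.zeros(num_blocks, n_kv_head, block_size, d_head)
slot_mapping = torch.tensor([5, 0, 7])

block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_offsets = slot_mapping % block_size

# Same broadcast pattern as the kernel's index_put_ above.
key_cache.index_put_(
    (block_indices.unsqueeze(1),
     torch.arange(n_kv_head), block_offsets.unsqueeze(1)), key)

# Token 0 was assigned slot 5 -> block 1, offset 1; all heads should match.
assert torch.equal(key_cache[1, :, 1, :], key[0])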