Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Perf] Mem align KV caches for CUDA devices (MLA perf improvement) (#12676)
Signed-off-by: simon-mo <xmo@berkeley.edu>
Signed-off-by: Lucas Wilkinson <lcwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
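For context on the numbers involved (an illustration of the arithmetic behind this change, not part of the diff): an MLA entry in the joint KV cache is kv_lora_rank + qk_rope_head_dim elements, 512 + 64 = 576 in the tests below, which in fp16/bf16 is only 128-byte aligned; padding each entry to the next 256-byte boundary costs roughly 11% extra cache memory.

# Illustrative arithmetic only; values taken from the MLA test parameters below.
kv_lora_rank, qk_rope_head_dim = 512, 64
elem_bytes = 2                                  # fp16 / bf16

entry_elems = kv_lora_rank + qk_rope_head_dim   # 576
entry_bytes = entry_elems * elem_bytes          # 1152 -> multiple of 128, not 256

elems_per_256b = 256 // elem_bytes              # 128 elements per 256 bytes
padded_elems = -(-entry_elems // elems_per_256b) * elems_per_256b   # 640
padded_bytes = padded_elems * elem_bytes        # 1280 -> 256-byte aligned

print(entry_bytes, padded_bytes)                # 1152 1280 (~11% larger)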
@@ -15,6 +15,9 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
                  std::vector<torch::Tensor> const& value_caches,
                  const torch::Tensor& block_mapping);

+void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
+                     const torch::Tensor& block_mapping);
+
 void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                        torch::Tensor& key_cache, torch::Tensor& value_cache,
                        torch::Tensor& slot_mapping,
@@ -46,7 +46,10 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
   char* src_ptr = static_cast<char*>(src.data_ptr());
   char* dst_ptr = static_cast<char*>(dst.data_ptr());

-  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
+  // We use the stride instead of numel in case the cache is padded for memory
+  // alignment reasons, we assume the blocks data (inclusive of any padding)
+  // is contiguous in memory
+  const int64_t block_size_in_bytes = src.element_size() * src.stride(0);
   const at::cuda::OptionalCUDAGuard device_guard(
       src_device.is_cuda() ? src_device : dst_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
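To illustrate why the byte count now comes from stride(0): once a cache is allocated with per-entry padding and sliced back to its logical size (as the CacheEngine change further down does), numel() no longer reflects the real per-block footprint, but the strides do. A standalone sketch with made-up sizes, not code from the patch:

import torch

# Hypothetical padded MLA cache: 8 blocks of 16 slots, entries padded from
# 576 to 640 fp16 elements and then sliced back to the logical entry size.
full = torch.zeros(8, 16, 640, dtype=torch.float16)
cache = full[..., :576]

print(cache[0].numel())  # 9216  (16 * 576) -- ignores the padding
print(cache.stride(0))   # 10240 (16 * 640) -- actual per-block footprint
print(cache.stride(1))   # 640   -- per-entry stride (the entry_stride used below)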
@@ -93,6 +96,24 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
   }
 }

+// Kernel for MLA, which works on a single joint kv_cache
+// Grid: (num_layers, num_pairs)
+template <typename scalar_t>
+__global__ void copy_blocks_mla_kernel(
+    int64_t* cache_ptrs, const int64_t* __restrict__ block_mapping,
+    const int mem_footprint_per_block) {
+  const int layer_idx = blockIdx.x;
+  const int pair_idx = blockIdx.y;
+  scalar_t* cache = reinterpret_cast<scalar_t*>(cache_ptrs[layer_idx]);
+  int64_t src_block = block_mapping[2 * pair_idx];
+  int64_t dst_block = block_mapping[2 * pair_idx + 1];
+  int64_t src_offset = src_block * mem_footprint_per_block;
+  int64_t dst_offset = dst_block * mem_footprint_per_block;
+  for (int i = threadIdx.x; i < mem_footprint_per_block; i += blockDim.x) {
+    cache[dst_offset + i] = cache[src_offset + i];
+  }
+}
+
 }  // namespace vllm

 // Note: the key_caches and value_caches vectors are constant but
@@ -147,6 +168,42 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
       }));
 }

+// copy blocks kernel for MLA (assumes a joint KV-cache)
+void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
+                     const torch::Tensor& block_mapping) {
+  int num_layers = kv_caches.size();
+  if (num_layers == 0) {
+    return;
+  }
+  torch::Device cache_device = kv_caches[0].device();
+  TORCH_CHECK(cache_device.is_cuda(), "kv_cache must be on CUDA");
+
+  std::vector<int64_t> cache_ptrs(num_layers);
+  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
+    cache_ptrs[layer_idx] =
+        reinterpret_cast<int64_t>(kv_caches[layer_idx].data_ptr());
+  }
+  torch::Tensor cache_ptrs_tensor =
+      torch::from_blob(cache_ptrs.data(), {num_layers}, torch::kInt64)
+          .to(cache_device);
+
+  int num_pairs = block_mapping.size(0);
+  // We use the stride instead of numel in case the cache is padded for memory
+  // alignment reasons, we assume the blocks data (inclusive of any padding)
+  // is contiguous in memory
+  int mem_footprint_per_block = kv_caches[0].stride(0);
+  dim3 grid(num_layers, num_pairs);
+  dim3 block(std::min(1024, mem_footprint_per_block));
+  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
+      kv_caches[0].scalar_type(), "copy_blocks_mla_kernel", ([&] {
+        vllm::copy_blocks_mla_kernel<scalar_t><<<grid, block, 0, stream>>>(
+            cache_ptrs_tensor.data_ptr<int64_t>(),
+            block_mapping.data_ptr<int64_t>(), mem_footprint_per_block);
+      }));
+}
+
 namespace vllm {

 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
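To make the launch geometry concrete: the host function launches one CUDA block per (layer, copy pair) and the kernel reads the mapping as a flat list of (src, dst) block indices. A minimal sketch of what a caller would pass to the new op, with hypothetical block indices:

import torch

# Apply two copies to every layer's joint MLA cache: block 0 -> 3 and 1 -> 5.
# Each row is a (src, dst) pair; dtype must be int64 and the tensor should sit
# on the same CUDA device as the caches, since the kernel reads it directly.
block_mapping = torch.tensor([[0, 3], [1, 5]], dtype=torch.int64, device="cuda")

# With L layer caches, the launch is grid = (L, block_mapping.size(0)) and each
# CUDA block copies stride(0) elements, i.e. the padded per-block footprint.
# from vllm import _custom_ops as ops
# ops.copy_blocks_mla(kv_caches, block_mapping)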
@@ -254,6 +311,7 @@ __global__ void concat_and_cache_mla_kernel(
                                //  + pe_dim)]
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
     const int block_stride,                    //
+    const int entry_stride,                    //
     const int kv_c_stride,                     //
     const int k_pe_stride,                     //
     const int kv_lora_rank,                    //
@@ -274,9 +332,8 @@ __global__ void concat_and_cache_mla_kernel(
                         int src_stride, int dst_stride, int size, int offset) {
     for (int i = threadIdx.x; i < size; i += blockDim.x) {
       const int64_t src_idx = token_idx * src_stride + i;
-      const int64_t dst_idx = block_idx * block_stride +
-                              block_offset * (kv_lora_rank + pe_dim) + i +
-                              offset;
+      const int64_t dst_idx =
+          block_idx * block_stride + block_offset * entry_stride + i + offset;
       if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
         dst[dst_idx] = src[src_idx];
       } else {
@@ -391,14 +448,14 @@ void reshape_and_cache_flash(
 // KV_T is the stored data type of kv-cache.
 // CACHE_T is the data type of key and value tensors.
 // KV_DTYPE is the real data type of kv-cache.
-#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE)                 \
-  vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE>               \
-      <<<grid, block, 0, stream>>>(                                        \
-          reinterpret_cast<KV_T*>(kv_c.data_ptr()),                        \
-          reinterpret_cast<KV_T*>(k_pe.data_ptr()),                        \
-          reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),                 \
-          slot_mapping.data_ptr<int64_t>(), block_stride, kv_c_stride,     \
-          k_pe_stride, kv_lora_rank, pe_dim, block_size,                   \
+#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE)                 \
+  vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE>               \
+      <<<grid, block, 0, stream>>>(                                        \
+          reinterpret_cast<KV_T*>(kv_c.data_ptr()),                        \
+          reinterpret_cast<KV_T*>(k_pe.data_ptr()),                        \
+          reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),                 \
+          slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride,    \
+          kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size,      \
           reinterpret_cast<const float*>(scale.data_ptr()));

 void concat_and_cache_mla(
@@ -428,6 +485,7 @@ void concat_and_cache_mla(
   int kv_c_stride = kv_c.stride(0);
   int k_pe_stride = k_pe.stride(0);
   int block_stride = kv_cache.stride(0);
+  int entry_stride = kv_cache.stride(1);

   dim3 grid(num_tokens);
   dim3 block(std::min(kv_lora_rank, 512));
@@ -450,6 +450,10 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       "Tensor block_mapping) -> ()");
   cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);

+  cache_ops.def(
+      "copy_blocks_mla(Tensor(a!)[] kv_caches, Tensor block_mapping) -> ()");
+  cache_ops.impl("copy_blocks_mla", torch::kCUDA, &copy_blocks_mla);
+
   // Reshape the key and value tensors and cache them.
   cache_ops.def(
       "reshape_and_cache(Tensor key, Tensor value,"
@@ -9,6 +9,7 @@ import torch
 from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
+from vllm.utils import align_to_256bytes

 COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -18,6 +19,13 @@ NUM_HEADS = [8]  # Arbitrary values for testing
 HEAD_SIZES = [64, 80, 120, 256]
 BLOCK_SIZES = [8, 16, 32]

+# Parameters for MLA tests.
+KV_LORA_RANKS = [512]
+QK_ROPE_HEAD_DIMS = [64]
+NUM_TOKENS_MLA = [42]
+BLOCK_SIZES_MLA = [16]
+NUM_BLOCKS_MLA = [8]
+
 # Arbitrary values for testing
 # don't make it too large. e.g. [1024, 36000] will OOM
 NUM_BLOCKS = [1024, 10000]
@@ -432,3 +440,257 @@ def test_fp8_e4m3_conversion(
     ops.convert_fp8(converted_cache, cache_fp8)

     torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
+
+
+def _create_mla_cache(
+    num_blocks: int,
+    block_size: int,
+    entry_size: int,
+    dtype: torch.dtype,
+    kv_cache_dtype: str,
+    device: str,
+    align_cache: bool,
+) -> torch.Tensor:
+    cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype
+
+    if align_cache:
+        alloc_entry_size = align_to_256bytes(entry_size, cache_dtype)
+        alloc_shape = (num_blocks, block_size, alloc_entry_size)
+        cache_full = torch.zeros(alloc_shape, dtype=cache_dtype, device=device)
+        cache = cache_full[..., :entry_size]
+    else:
+        cache = torch.zeros(num_blocks,
+                            block_size,
+                            entry_size,
+                            dtype=cache_dtype,
+                            device=device)
+    return cache
+
+
+def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
+    rand_dtype = torch.float16 if kv_cache_dtype == "fp8" else cache.dtype
+
+    vals = torch.randn(*cache.shape, device=cache.device, dtype=rand_dtype)
+    if kv_cache_dtype == "fp8":
+        temp = torch.zeros_like(cache)
+        ops.convert_fp8(temp, vals, 1.0, kv_dtype=kv_cache_dtype)
+        vals = temp
+    cache.copy_(vals)
+
+
+@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
+@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
+@pytest.mark.parametrize("align_cache", [False])
+@torch.inference_mode()
+def test_concat_and_cache_mla(
+    kv_lora_rank: int,
+    qk_rope_head_dim: int,
+    num_tokens: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+    kv_cache_dtype: str,
+    align_cache: bool,
+) -> None:
+    current_platform.seed_everything(seed)
+    torch.set_default_device(device)
+
+    total_slots = num_blocks * block_size
+    slot_mapping_lst = random.sample(range(total_slots), num_tokens)
+    slot_mapping = torch.tensor(slot_mapping_lst,
+                                dtype=torch.long,
+                                device=device)
+
+    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
+    k_pe = torch.randn(num_tokens,
+                       qk_rope_head_dim,
+                       dtype=dtype,
+                       device=device)
+    entry_size = kv_lora_rank + qk_rope_head_dim
+
+    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
+    kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
+                                 kv_cache_dtype, device, align_cache)
+    ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
+
+    for i in range(num_tokens):
+        slot = slot_mapping[i].item()
+        block_idx = slot // block_size
+        block_offset = slot % block_size
+        ref_temp[block_idx, block_offset, :kv_lora_rank] = kv_c[i]
+        ref_temp[block_idx, block_offset, kv_lora_rank:] = k_pe[i]
+
+    if kv_cache_dtype == "fp8":
+        ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
+        ops.convert_fp8(ref_kv_cache,
+                        ref_temp,
+                        scale.item(),
+                        kv_dtype=kv_cache_dtype)
+    else:
+        ref_kv_cache = ref_temp
+
+    opcheck(
+        torch.ops._C_cache_ops.concat_and_cache_mla,
+        (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
+        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
+    )
+
+    ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
+                             kv_cache_dtype, scale)
+
+    if kv_cache_dtype == "fp8":
+        result_temp = torch.empty_like(kv_cache, dtype=torch.float16)
+        ops.convert_fp8(result_temp,
+                        kv_cache.contiguous(),
+                        scale.item(),
+                        kv_dtype=kv_cache_dtype)
+        expected_temp = torch.empty_like(ref_kv_cache, dtype=torch.float16)
+        ops.convert_fp8(expected_temp,
+                        ref_kv_cache,
+                        scale.item(),
+                        kv_dtype=kv_cache_dtype)
+        torch.testing.assert_close(result_temp,
+                                   expected_temp,
+                                   atol=0.001,
+                                   rtol=0.1)
+    else:
+        torch.testing.assert_close(kv_cache, ref_kv_cache)
+
+
+@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
+@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
+@pytest.mark.parametrize("num_layers", NUM_LAYERS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
+@pytest.mark.parametrize("align_cache", [False, True])
+@torch.inference_mode()
+def test_copy_blocks_mla(
+    kv_lora_rank: int,
+    qk_rope_head_dim: int,
+    block_size: int,
+    num_blocks: int,
+    num_layers: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+    kv_cache_dtype: str,
+    align_cache: bool,
+) -> None:
+    current_platform.seed_everything(seed)
+    torch.set_default_device(device)
+
+    entry_size = kv_lora_rank + qk_rope_head_dim
+
+    kv_caches = []
+    for _ in range(num_layers):
+        kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
+                                     kv_cache_dtype, device, align_cache)
+        _fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype)
+        kv_caches.append(kv_cache)
+
+    ref_caches = [kv_cache.clone() for kv_cache in kv_caches]
+
+    num_mappings = min(2, num_blocks // 2)
+    src_blocks = random.sample(range(num_blocks), num_mappings)
+    remaining = list(set(range(num_blocks)) - set(src_blocks))
+    dst_blocks = random.sample(remaining, 2 * num_mappings)
+    block_mapping = []
+    for i in range(num_mappings):
+        src = src_blocks[i]
+        dst1 = dst_blocks[2 * i]
+        dst2 = dst_blocks[2 * i + 1]
+        block_mapping.append((src, dst1))
+        block_mapping.append((src, dst2))
+    block_mapping_tensor = torch.tensor(block_mapping,
+                                        dtype=torch.int64,
+                                        device=device).view(-1, 2)
+
+    for src, dst in block_mapping:
+        for ref_cache in ref_caches:
+            ref_cache[dst].copy_(ref_cache[src])
+
+    opcheck(
+        torch.ops._C_cache_ops.copy_blocks_mla,
+        (kv_caches, block_mapping_tensor),
+        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
+    )
+    ops.copy_blocks_mla(kv_caches, block_mapping_tensor)
+
+    for kv_cache, ref_cache in zip(kv_caches, ref_caches):
+        torch.testing.assert_close(kv_cache, ref_cache)
+
+
+@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
+@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
+@pytest.mark.parametrize("align_cache", [False, True])
+@torch.inference_mode()
+def test_swap_blocks_mla(
+    kv_lora_rank: int,
+    qk_rope_head_dim: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+    kv_cache_dtype: str,
+    align_cache: bool,
+) -> None:
+    current_platform.seed_everything(seed)
+    torch.set_default_device(device)
+
+    entry_size = kv_lora_rank + qk_rope_head_dim
+
+    src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
+                                  kv_cache_dtype, device, align_cache)
+    dst_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
+                                  kv_cache_dtype, device, align_cache)
+
+    _fill_mla_cache(src_cache, kv_cache_dtype)
+    _fill_mla_cache(dst_cache, kv_cache_dtype)
+
+    src_cache_clone = src_cache.clone()
+
+    num_mappings = min(2, num_blocks // 2)
+    src_blocks = random.sample(range(num_blocks), num_mappings)
+    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
+    dst_blocks = random.sample(remaining_blocks, num_mappings)
+    block_mapping = list(zip(src_blocks, dst_blocks))
+    block_mapping_tensor = torch.tensor(block_mapping,
+                                        dtype=torch.int64,
+                                        device="cpu").view(-1, 2)
+
+    opcheck(
+        torch.ops._C_cache_ops.swap_blocks,
+        (src_cache, dst_cache, block_mapping_tensor),
+        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
+        cond=(kv_lora_rank == KV_LORA_RANKS[0]
+              and qk_rope_head_dim == QK_ROPE_HEAD_DIMS[0]),
+    )
+
+    ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor)
+
+    for src, dst in block_mapping:
+        torch.testing.assert_close(
+            src_cache_clone[src].cpu(),
+            dst_cache[dst].cpu(),
+            msg=f"Block {src} from src should have been swapped to block "
+            f"{dst} in dst_cache.")
@@ -1037,6 +1037,11 @@ def copy_blocks(key_caches: List[torch.Tensor],
     torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)


+def copy_blocks_mla(kv_caches: List[torch.Tensor],
+                    block_mapping: torch.Tensor) -> None:
+    torch.ops._C_cache_ops.copy_blocks_mla(kv_caches, block_mapping)
+
+
 def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                 block_mapping: torch.Tensor) -> None:
     torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping)
@@ -26,7 +26,6 @@ from vllm.attention.backends.mla.utils import MLACommonImpl, MLACommonMetadata
 from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
-from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad

@@ -72,14 +71,14 @@ class TritonMLABackend(AttentionBackend):
         dst_kv_cache: torch.Tensor,
         src_to_dst: torch.Tensor,
     ) -> None:
-        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+        ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
         src_to_dists: torch.Tensor,
     ) -> None:
-        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+        ops.copy_blocks_mla(kv_caches, src_to_dists)

     @staticmethod
     def get_supported_head_sizes() -> List[int]:
@@ -204,10 +204,10 @@ def _decode_att_m_fwd(
         Req_to_tokens.stride(0),
         q.stride(0),
         q.stride(1),
-        k_buffer.stride(-2),
-        k_buffer.stride(-1),
-        v_buffer.stride(-2),
-        v_buffer.stride(-1),
+        k_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
+        k_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
+        v_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
+        v_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        att_out.stride(0),
        att_out.stride(1),
        att_out.stride(2),
@@ -438,10 +438,10 @@ def _decode_grouped_att_m_fwd(
         Req_to_tokens.stride(0),
         q.stride(0),
         q.stride(1),
-        k_buffer.stride(-2),
-        k_buffer.stride(-1),
-        v_buffer.stride(-2),
-        v_buffer.stride(-1),
+        k_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
+        k_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
+        v_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
+        v_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        att_out.stride(0),
        att_out.stride(1),
        att_out.stride(2),
vllm/envs.py
@@ -82,6 +82,7 @@ if TYPE_CHECKING:
     VLLM_MLA_DISABLE: bool = False
     VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True
     VLLM_MLA_DISABLE_REQUANTIZATION: bool = False
+    VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE: bool = True
     VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False

@@ -539,6 +540,15 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     "VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON":
     lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
                  ),
+
+    # When on a Nvidia GPU aligns single entries (within a page) so they are 256
+    # byte aligned for better performance, this increases the memory usage of
+    # the cache. Currently this only affects MLA that results in non-256
+    # byte aligned entries. This matches the alignment the CUDA runtime uses
+    # for all allocations. Currently this primarily affects MLA, for most other
+    # models the alignment is already naturally aligned to 256 bytes.
+    "VLLM_CUDA_MEM_ALIGN_KV_CACHE":
+    lambda: bool(int(os.getenv("VLLM_CUDA_MEM_ALIGN_KV_CACHE", "1"))),
 }

 # end-env-vars-definition
@@ -563,6 +563,10 @@ def cdiv(a: int, b: int) -> int:
     return -(a // -b)


+def round_up(x: int, y: int) -> int:
+    return ((x + y - 1) // y) * y
+
+
 def _generate_random_fp8(
     tensor: torch.Tensor,
     low: float,
@@ -794,6 +798,12 @@ def get_dtype_size(dtype: torch.dtype) -> int:
     return torch.tensor([], dtype=dtype).element_size()


+def align_to_256bytes(extent: int, dtype: torch.dtype) -> int:
+    dtype_size = get_dtype_size(dtype)
+    eles_per_256bytes = 256 // dtype_size
+    return round_up(extent, eles_per_256bytes)
+
+
 # `collections` helpers
 def is_list_of(
     value: object,
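A quick check of what the new helper returns for the MLA entry size used in the tests (576 fp16 elements); the values follow directly from the definitions above, assuming the helpers land in vllm.utils as shown:

import torch

from vllm.utils import align_to_256bytes, round_up

assert round_up(576, 128) == 640
# fp16 is 2 bytes, so 128 elements span 256 bytes and 576 rounds up to 640,
# i.e. each padded entry occupies 640 * 2 = 1280 bytes.
assert align_to_256bytes(576, torch.float16) == 640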
@@ -2,13 +2,17 @@
 """CacheEngine class for managing the KV cache."""
 from typing import List

+import numpy as np
 import torch

+from vllm import envs
 from vllm.attention import get_attn_backend
 from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
-                        get_dtype_size, is_pin_memory_available)
+                        align_to_256bytes, get_dtype_size,
+                        is_pin_memory_available)

 logger = init_logger(__name__)

@@ -38,6 +42,7 @@ class CacheEngine:
         self.num_attention_layers = model_config.get_num_layers_by_block_type(
             parallel_config, LayerBlockType.attention)
         self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
+        self.align_cache = self._align_cache(model_config)

         self.block_size = cache_config.block_size
         self.num_gpu_blocks = cache_config.num_gpu_blocks
@@ -75,15 +80,39 @@ class CacheEngine:
             num_blocks, self.block_size, self.num_kv_heads, self.head_size)
         pin_memory = is_pin_memory_available() if device == "cpu" else False
         kv_cache: List[torch.Tensor] = []
+
+        # Align entries so they are 256 byte aligned for better performance
+        # Primarily targets MLA as this typically only ends up having entries
+        # be 128 byte aligned.
+        if self.align_cache:
+            # We assume the cache shape is:
+            #   (TOTAL_PAGES, PAGE_SIZE, entry_shape...)
+            # NOTE this assumption currently only holds for MLA so we only apply
+            # this optimization when `use_mla` is true
+            entry_shape = kv_cache_shape[2:]
+            entry_size = np.prod(entry_shape)
+            alloc_entry_size = align_to_256bytes(entry_size, self.dtype)
+            alloc_shape = (*kv_cache_shape[:2], alloc_entry_size)
+        else:
+            alloc_shape = kv_cache_shape
+
         for _ in range(self.num_attention_layers):
             # null block in CpuGpuBlockAllocator requires at least that
             # block to be zeroed-out.
             # We zero-out everything for simplicity.
-            kv_cache.append(
-                torch.zeros(kv_cache_shape,
-                            dtype=self.dtype,
-                            pin_memory=pin_memory,
-                            device=device))
+            layer_kv_cache = torch.zeros(alloc_shape,
+                                         dtype=self.dtype,
+                                         pin_memory=pin_memory,
+                                         device=device)
+
+            # If we allocated with padding for alignment reasons truncate the
+            # shape while preserving the aligned stride
+            if self.align_cache:
+                layer_kv_cache = layer_kv_cache[..., :entry_size]
+
+            # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
+            # when entry_shape is higher than 1D
+            kv_cache.append(layer_kv_cache.view(kv_cache_shape))
         return kv_cache

     def swap_in(self, src_to_dst: torch.Tensor) -> None:
@@ -99,6 +128,14 @@ class CacheEngine:
     def copy(self, src_to_dsts: torch.Tensor) -> None:
         self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)

+    @staticmethod
+    def _align_cache(model_config: ModelConfig):
+        # Currently align_cache only applies to MLA models since the other
+        # cache kernels haven't been updated yet to support non-contiguous
+        # tensors
+        return model_config.use_mla and current_platform.is_cuda() \
+            and envs.VLLM_CUDA_MEM_ALIGN_KV_CACHE
+
     @staticmethod
     def get_cache_block_size(
         cache_config: CacheConfig,
@@ -110,14 +147,21 @@ class CacheEngine:
         num_attention_layers = model_config.get_num_layers_by_block_type(
             parallel_config, LayerBlockType.attention)

-        key_cache_block = cache_config.block_size * num_heads * head_size
-        # For MLA there is no value cache, since the latent vector
-        # is joint keys and values.
-        value_cache_block = key_cache_block if not model_config.use_mla else 0
-        total = num_attention_layers * (key_cache_block + value_cache_block)
         if cache_config.cache_dtype == "auto":
             dtype = model_config.dtype
         else:
             dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

+        key_cache_entry = num_heads * head_size
+        if CacheEngine._align_cache(model_config):
+            key_cache_entry = align_to_256bytes(key_cache_entry,
+                                                model_config.dtype)
+
+        # For MLA there is no value cache, since the latent vector
+        # is joint keys and values.
+        value_cache_entry = key_cache_entry if not model_config.use_mla else 0
+        total = num_attention_layers * cache_config.block_size * \
+            (key_cache_entry + value_cache_entry)
+
         dtype_size = get_dtype_size(dtype)
         return dtype_size * total