[Perf] Optimize reshape_and_cache_flash CUDA Kernel (#22036)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2025-08-01 19:18:51 -04:00
committed by GitHub
parent 88faa466d7
commit eefbf4a68b
2 changed files with 225 additions and 23 deletions

View File

@ -0,0 +1,156 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import random
import time
import torch
from tabulate import tabulate
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import (
STR_DTYPE_TO_TORCH_DTYPE,
FlexibleArgumentParser,
create_kv_caches_with_random_flash,
)
logger = init_logger(__name__)
@torch.inference_mode()
def run_benchmark(
num_tokens: int,
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
kv_cache_dtype: str,
kv_cache_layout: str,
num_iters: int,
device: str = "cuda",
) -> float:
"""Return latency (seconds) for given num_tokens."""
if kv_cache_dtype == "fp8" and head_size % 16:
raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")
current_platform.seed_everything(42)
torch.set_default_device(device)
# create random key / value tensors [T, H, D].
key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)
value = torch.randn_like(key)
# prepare the slot mapping.
# each token is assigned a unique slot in the KV-cache.
num_slots = block_size * num_blocks
if num_tokens > num_slots:
raise ValueError("num_tokens cannot exceed the total number of cache slots")
slot_mapping_lst = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
key_caches, value_caches = create_kv_caches_with_random_flash(
num_blocks,
block_size,
1, # num_layers
num_heads,
head_size,
kv_cache_dtype,
dtype,
device=device,
cache_layout=kv_cache_layout,
)
key_cache, value_cache = key_caches[0], value_caches[0]
# compute per-kernel scaling factors for fp8 conversion (if used).
k_scale = (key.amax() / 64.0).to(torch.float32)
v_scale = (value.amax() / 64.0).to(torch.float32)
def run_cuda_benchmark(n_iters: int) -> float:
nonlocal key, value, key_cache, value_cache, slot_mapping
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(n_iters):
ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
kv_cache_dtype,
k_scale,
v_scale,
)
torch.cuda.synchronize()
end = time.perf_counter()
return (end - start) / n_iters
# warm-up
run_cuda_benchmark(3)
lat = run_cuda_benchmark(num_iters)
# free tensors to mitigate OOM when sweeping
del key, value, key_cache, value_cache, slot_mapping
torch.cuda.empty_cache()
return lat
def main(args):
rows = []
for layout in ["NHD", "HND"]:
for exp in range(1, 17):
n_tok = 2**exp
lat = run_benchmark(
num_tokens=n_tok,
num_heads=args.num_heads,
head_size=args.head_size,
block_size=args.block_size,
num_blocks=args.num_blocks,
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
kv_cache_dtype=args.kv_cache_dtype,
kv_cache_layout=layout,
num_iters=args.iters,
device="cuda",
)
rows.append([n_tok, layout, f"{lat * 1e6:.3f}"])
print(tabulate(rows, headers=["num_tokens", "layout", "latency (µs)"]))
if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument("--num-heads", type=int, default=128)
parser.add_argument(
"--head-size",
type=int,
choices=[64, 80, 96, 112, 120, 128, 192, 256],
default=128,
)
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
parser.add_argument("--num-blocks", type=int, default=128 * 512)
parser.add_argument(
"--dtype",
type=str,
choices=["half", "bfloat16", "float"],
default="bfloat16",
)
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8"],
default="auto",
)
parser.add_argument("--iters", type=int, default=100)
args = parser.parse_args()
main(args)

View File

@ -5,6 +5,7 @@
#include "cuda_utils.h"
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/vectorization_utils.cuh"
#ifdef USE_ROCM
#include "quantization/fp8/amd/quant_utils.cuh"
@ -261,14 +262,26 @@ __global__ void reshape_and_cache_kernel(
}
}
// Used by vectorization_utils to copy/convert one element
template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
struct CopyWithScaleOp {
float scale;
__device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
dst = static_cast<OutT>(src);
} else {
dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
}
}
};
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_flash_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
cache_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads,
// head_size]
cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads,
// head_size]
cache_t* __restrict__ key_cache, // NHD or HND, shape see comments below
cache_t* __restrict__ value_cache, // same above
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int64_t block_stride, const int64_t page_stride,
const int64_t head_stride, const int64_t key_stride,
@ -282,25 +295,58 @@ __global__ void reshape_and_cache_flash_kernel(
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
const int n = num_heads * head_size;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int64_t src_key_idx = token_idx * key_stride + i;
const int64_t src_value_idx = token_idx * value_stride + i;
const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int64_t tgt_key_value_idx = block_idx * block_stride +
block_offset * page_stride +
head_idx * head_stride + head_offset;
scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx];
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
key_cache[tgt_key_value_idx] = tgt_key;
value_cache[tgt_key_value_idx] = tgt_value;
} else {
key_cache[tgt_key_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
value_cache[tgt_key_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
const int n_elems = num_heads * head_size;
// pointers to the beginning of the source row for this token.
const scalar_t* __restrict__ key_src = key + token_idx * key_stride;
const scalar_t* __restrict__ value_src = value + token_idx * value_stride;
// find the start position inside the kv-cache for this token.
cache_t* __restrict__ key_dst =
key_cache + block_idx * block_stride + block_offset * page_stride;
cache_t* __restrict__ value_dst =
value_cache + block_idx * block_stride + block_offset * page_stride;
// this is true for the NHD layout where `head_stride == head_size`
const bool is_contiguous_heads = (head_stride == head_size);
float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;
constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4;
CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
if (is_contiguous_heads) {
// NHD layout
// kv cache: [num_blocks, block_size, num_heads, head_size]
vectorize_with_alignment<VEC_SIZE>(key_src, key_dst, n_elems, threadIdx.x,
blockDim.x, k_op);
vectorize_with_alignment<VEC_SIZE>(value_src, value_dst, n_elems,
threadIdx.x, blockDim.x, v_op);
} else {
// HND layout: heads are strided, but each head_size segment is contiguous
// kv cache: [num_blocks, num_heads, block_size, head_size]
const int lane = threadIdx.x & 31; // 0..31 within warp
const int warp_id = threadIdx.x >> 5; // warp index within block
const int warps_per_block = blockDim.x >> 5;
for (int head = warp_id; head < num_heads; head += warps_per_block) {
const scalar_t* __restrict__ k_src_h = key_src + head * head_size;
const scalar_t* __restrict__ v_src_h = value_src + head * head_size;
cache_t* __restrict__ k_dst_h =
key_dst + static_cast<int64_t>(head) * head_stride;
cache_t* __restrict__ v_dst_h =
value_dst + static_cast<int64_t>(head) * head_stride;
// within each head, let the 32 threads of the warp perform the vector
// copy
vectorize_with_alignment<VEC_SIZE>(k_src_h, k_dst_h, head_size, lane, 32,
k_op);
vectorize_with_alignment<VEC_SIZE>(v_src_h, v_dst_h, head_size, lane, 32,
v_op);
}
}
}