[Perf] Optimize reshape_and_cache CUDA Kernel (#25955)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Co-authored-by: Liu-congo <1502632128@qq.com>
This commit is contained in:
Jiangyun Zhu
2025-10-03 16:33:46 +08:00
committed by GitHub
parent 0ad9951c41
commit eb0fa43868
2 changed files with 225 additions and 45 deletions

View File

@ -0,0 +1,174 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import random
import time
import torch
from tabulate import tabulate
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import (
STR_DTYPE_TO_TORCH_DTYPE,
FlexibleArgumentParser,
create_kv_caches_with_random,
)
logger = init_logger(__name__)
@torch.inference_mode()
def run_benchmark(
num_tokens: int,
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
kv_cache_dtype: str,
num_iters: int,
benchmark_mode: str,
device: str = "cuda",
) -> float:
"""Return latency (seconds) for given num_tokens."""
if kv_cache_dtype == "fp8" and head_size % 16:
raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")
current_platform.seed_everything(42)
torch.set_default_device(device)
# create random key / value tensors [T, H, D].
key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)
value = torch.randn_like(key)
# prepare the slot mapping.
# each token is assigned a unique slot in the KV-cache.
num_slots = block_size * num_blocks
if num_tokens > num_slots:
raise ValueError("num_tokens cannot exceed the total number of cache slots")
slot_mapping_lst = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
key_caches, value_caches = create_kv_caches_with_random(
num_blocks,
block_size,
1, # num_layers
num_heads,
head_size,
kv_cache_dtype,
dtype,
device=device,
)
key_cache, value_cache = key_caches[0], value_caches[0]
# to free unused memory
del key_caches, value_caches
# compute per-kernel scaling factors for fp8 conversion (if used).
k_scale = (key.amax() / 64.0).to(torch.float32)
v_scale = (value.amax() / 64.0).to(torch.float32)
function_under_test = lambda: ops.reshape_and_cache(
key, # noqa: F821
value, # noqa: F821
key_cache, # noqa: F821
value_cache, # noqa: F821
slot_mapping, # noqa: F821
kv_cache_dtype,
k_scale,
v_scale,
)
if benchmark_mode == "cudagraph":
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
function_under_test()
torch.cuda.synchronize()
function_under_test = lambda: g.replay()
def run_cuda_benchmark(n_iters: int) -> float:
nonlocal key, value, key_cache, value_cache, slot_mapping
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(n_iters):
function_under_test()
torch.cuda.synchronize()
end = time.perf_counter()
return (end - start) / n_iters
# warm-up
run_cuda_benchmark(3)
lat = run_cuda_benchmark(num_iters)
# free tensors to mitigate OOM when sweeping
del key, value, key_cache, value_cache, slot_mapping
torch.cuda.empty_cache()
return lat
def main(args):
rows = []
for exp in range(1, 17):
n_tok = 2**exp
lat = run_benchmark(
num_tokens=n_tok,
num_heads=args.num_heads,
head_size=args.head_size,
block_size=args.block_size,
num_blocks=args.num_blocks,
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
kv_cache_dtype=args.kv_cache_dtype,
num_iters=args.iters,
benchmark_mode=args.mode,
device="cuda",
)
rows.append([n_tok, lat * 1e6]) # convert to microseconds
print(f"Benchmark results for implementation cuda (measuring with {args.mode}):")
print(tabulate(rows, headers=["num_tokens", "latency (µs)"], floatfmt=".3f"))
if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument("--num-heads", type=int, default=128)
parser.add_argument(
"--head-size",
type=int,
choices=[64, 80, 96, 112, 120, 128, 192, 256],
default=128,
)
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
parser.add_argument("--num-blocks", type=int, default=128 * 128)
parser.add_argument(
"--dtype",
type=str,
choices=["half", "bfloat16", "float"],
default="bfloat16",
)
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8"],
default="auto",
)
parser.add_argument("--iters", type=int, default=200)
parser.add_argument(
"--mode",
type=str,
choices=["cudagraph", "no_graph"],
default="cudagraph",
)
args = parser.parse_args()
main(args)

View File

@ -16,8 +16,7 @@
#include <algorithm>
#include <cassert>
#include <map>
#include <vector>
#include <cfloat> // FLT_MIN
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
@ -209,6 +208,20 @@ void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
namespace vllm {
// Used to copy/convert one element
template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
struct CopyWithScaleOp {
float scale;
__device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
dst = static_cast<OutT>(src);
} else {
dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
}
}
};
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
@ -224,59 +237,51 @@ __global__ void reshape_and_cache_kernel(
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) {
// Padding token that should be ignored.
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
const int h_block_count = head_size / x; // head_size//x
const int n = num_heads * head_size;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int64_t src_key_idx = token_idx * key_stride + i;
const int64_t src_value_idx = token_idx * value_stride + i;
const int h_block_idx = threadIdx.x;
if (h_block_idx >= num_heads * h_block_count) {
return;
}
const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int x_idx = head_offset / x;
const int x_offset = head_offset % x;
const int head_idx = h_block_idx / h_block_count;
const int h_block = h_block_idx % h_block_count;
const int64_t tgt_key_idx =
block_idx * num_heads * (head_size / x) * block_size * x +
head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
block_offset * x + x_offset;
const int64_t tgt_value_idx =
block_idx * num_heads * head_size * block_size +
head_idx * head_size * block_size + head_offset * block_size +
block_offset;
scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx];
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
key_cache[tgt_key_idx] = tgt_key;
value_cache[tgt_value_idx] = tgt_value;
} else {
key_cache[tgt_key_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
value_cache[tgt_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
}
const scalar_t* __restrict__ key_src =
key + token_idx * key_stride + head_idx * head_size + h_block * x;
const int64_t src_value_start =
token_idx * value_stride + head_idx * head_size + h_block * x;
cache_t* __restrict__ key_dst =
key_cache + block_idx * num_heads * h_block_count * block_size * x +
head_idx * h_block_count * block_size * x + h_block * block_size * x +
block_offset * x;
const int64_t tgt_value_start =
block_idx * num_heads * h_block_count * x * block_size +
head_idx * h_block_count * x * block_size + h_block * x * block_size +
block_offset;
constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4;
float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;
CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
vectorize_with_alignment<VEC_SIZE>(key_src, key_dst, x, 0, 1, k_op);
const scalar_t* __restrict__ value_src = value + src_value_start;
cache_t* __restrict__ value_dst = value_cache + tgt_value_start;
#pragma unroll
for (int i = 0; i < x; i++) {
v_op(value_dst[i * block_size], value_src[i]);
}
}
// Used by vectorization_utils to copy/convert one element
template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
struct CopyWithScaleOp {
float scale;
__device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
dst = static_cast<OutT>(src);
} else {
dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
}
}
};
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_flash_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
@ -601,9 +606,10 @@ void reshape_and_cache(
int key_stride = key.stride(0);
int value_stride = value.stride(0);
int head_div_x = head_size / x;
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
dim3 block(std::min(num_heads * head_div_x, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();