mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[Perf] Optimize reshape_and_cache_flash
CUDA Kernel (#22036)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
156
benchmarks/kernels/benchmark_reshape_and_cache_flash.py
Normal file
156
benchmarks/kernels/benchmark_reshape_and_cache_flash.py
Normal file
@ -0,0 +1,156 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
import time
|
||||
|
||||
import torch
|
||||
from tabulate import tabulate
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import (
|
||||
STR_DTYPE_TO_TORCH_DTYPE,
|
||||
FlexibleArgumentParser,
|
||||
create_kv_caches_with_random_flash,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_benchmark(
|
||||
num_tokens: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str,
|
||||
kv_cache_layout: str,
|
||||
num_iters: int,
|
||||
device: str = "cuda",
|
||||
) -> float:
|
||||
"""Return latency (seconds) for given num_tokens."""
|
||||
|
||||
if kv_cache_dtype == "fp8" and head_size % 16:
|
||||
raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")
|
||||
|
||||
current_platform.seed_everything(42)
|
||||
torch.set_default_device(device)
|
||||
|
||||
# create random key / value tensors [T, H, D].
|
||||
key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)
|
||||
value = torch.randn_like(key)
|
||||
|
||||
# prepare the slot mapping.
|
||||
# each token is assigned a unique slot in the KV-cache.
|
||||
num_slots = block_size * num_blocks
|
||||
if num_tokens > num_slots:
|
||||
raise ValueError("num_tokens cannot exceed the total number of cache slots")
|
||||
slot_mapping_lst = random.sample(range(num_slots), num_tokens)
|
||||
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
|
||||
|
||||
key_caches, value_caches = create_kv_caches_with_random_flash(
|
||||
num_blocks,
|
||||
block_size,
|
||||
1, # num_layers
|
||||
num_heads,
|
||||
head_size,
|
||||
kv_cache_dtype,
|
||||
dtype,
|
||||
device=device,
|
||||
cache_layout=kv_cache_layout,
|
||||
)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# compute per-kernel scaling factors for fp8 conversion (if used).
|
||||
k_scale = (key.amax() / 64.0).to(torch.float32)
|
||||
v_scale = (value.amax() / 64.0).to(torch.float32)
|
||||
|
||||
def run_cuda_benchmark(n_iters: int) -> float:
|
||||
nonlocal key, value, key_cache, value_cache, slot_mapping
|
||||
torch.cuda.synchronize()
|
||||
start = time.perf_counter()
|
||||
for _ in range(n_iters):
|
||||
ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
end = time.perf_counter()
|
||||
return (end - start) / n_iters
|
||||
|
||||
# warm-up
|
||||
run_cuda_benchmark(3)
|
||||
|
||||
lat = run_cuda_benchmark(num_iters)
|
||||
|
||||
# free tensors to mitigate OOM when sweeping
|
||||
del key, value, key_cache, value_cache, slot_mapping
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return lat
|
||||
|
||||
|
||||
def main(args):
|
||||
rows = []
|
||||
for layout in ["NHD", "HND"]:
|
||||
for exp in range(1, 17):
|
||||
n_tok = 2**exp
|
||||
lat = run_benchmark(
|
||||
num_tokens=n_tok,
|
||||
num_heads=args.num_heads,
|
||||
head_size=args.head_size,
|
||||
block_size=args.block_size,
|
||||
num_blocks=args.num_blocks,
|
||||
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
|
||||
kv_cache_dtype=args.kv_cache_dtype,
|
||||
kv_cache_layout=layout,
|
||||
num_iters=args.iters,
|
||||
device="cuda",
|
||||
)
|
||||
rows.append([n_tok, layout, f"{lat * 1e6:.3f}"])
|
||||
|
||||
print(tabulate(rows, headers=["num_tokens", "layout", "latency (µs)"]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser()
|
||||
|
||||
parser.add_argument("--num-heads", type=int, default=128)
|
||||
parser.add_argument(
|
||||
"--head-size",
|
||||
type=int,
|
||||
choices=[64, 80, 96, 112, 120, 128, 192, 256],
|
||||
default=128,
|
||||
)
|
||||
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
|
||||
parser.add_argument("--num-blocks", type=int, default=128 * 512)
|
||||
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
choices=["half", "bfloat16", "float"],
|
||||
default="bfloat16",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--kv-cache-dtype",
|
||||
type=str,
|
||||
choices=["auto", "fp8"],
|
||||
default="auto",
|
||||
)
|
||||
|
||||
parser.add_argument("--iters", type=int, default=100)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
@ -5,6 +5,7 @@
|
||||
#include "cuda_utils.h"
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
#include "quantization/vectorization_utils.cuh"
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include "quantization/fp8/amd/quant_utils.cuh"
|
||||
@ -261,14 +262,26 @@ __global__ void reshape_and_cache_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
// Used by vectorization_utils to copy/convert one element
|
||||
template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
|
||||
struct CopyWithScaleOp {
|
||||
float scale;
|
||||
|
||||
__device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
|
||||
dst = static_cast<OutT>(src);
|
||||
} else {
|
||||
dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
|
||||
__global__ void reshape_and_cache_flash_kernel(
|
||||
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
||||
cache_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads,
|
||||
// head_size]
|
||||
cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads,
|
||||
// head_size]
|
||||
cache_t* __restrict__ key_cache, // NHD or HND, shape see comments below
|
||||
cache_t* __restrict__ value_cache, // same above
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int64_t block_stride, const int64_t page_stride,
|
||||
const int64_t head_stride, const int64_t key_stride,
|
||||
@ -282,25 +295,58 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
}
|
||||
const int64_t block_idx = slot_idx / block_size;
|
||||
const int64_t block_offset = slot_idx % block_size;
|
||||
const int n = num_heads * head_size;
|
||||
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
||||
const int64_t src_key_idx = token_idx * key_stride + i;
|
||||
const int64_t src_value_idx = token_idx * value_stride + i;
|
||||
const int head_idx = i / head_size;
|
||||
const int head_offset = i % head_size;
|
||||
const int64_t tgt_key_value_idx = block_idx * block_stride +
|
||||
block_offset * page_stride +
|
||||
head_idx * head_stride + head_offset;
|
||||
scalar_t tgt_key = key[src_key_idx];
|
||||
scalar_t tgt_value = value[src_value_idx];
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
|
||||
key_cache[tgt_key_value_idx] = tgt_key;
|
||||
value_cache[tgt_key_value_idx] = tgt_value;
|
||||
const int n_elems = num_heads * head_size;
|
||||
|
||||
// pointers to the beginning of the source row for this token.
|
||||
const scalar_t* __restrict__ key_src = key + token_idx * key_stride;
|
||||
const scalar_t* __restrict__ value_src = value + token_idx * value_stride;
|
||||
|
||||
// find the start position inside the kv-cache for this token.
|
||||
cache_t* __restrict__ key_dst =
|
||||
key_cache + block_idx * block_stride + block_offset * page_stride;
|
||||
cache_t* __restrict__ value_dst =
|
||||
value_cache + block_idx * block_stride + block_offset * page_stride;
|
||||
|
||||
// this is true for the NHD layout where `head_stride == head_size`
|
||||
const bool is_contiguous_heads = (head_stride == head_size);
|
||||
|
||||
float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
|
||||
float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;
|
||||
constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4;
|
||||
CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
|
||||
CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
|
||||
if (is_contiguous_heads) {
|
||||
// NHD layout
|
||||
// kv cache: [num_blocks, block_size, num_heads, head_size]
|
||||
vectorize_with_alignment<VEC_SIZE>(key_src, key_dst, n_elems, threadIdx.x,
|
||||
blockDim.x, k_op);
|
||||
|
||||
vectorize_with_alignment<VEC_SIZE>(value_src, value_dst, n_elems,
|
||||
threadIdx.x, blockDim.x, v_op);
|
||||
|
||||
} else {
|
||||
key_cache[tgt_key_value_idx] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
|
||||
value_cache[tgt_key_value_idx] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
|
||||
// HND layout: heads are strided, but each head_size segment is contiguous
|
||||
// kv cache: [num_blocks, num_heads, block_size, head_size]
|
||||
const int lane = threadIdx.x & 31; // 0..31 within warp
|
||||
const int warp_id = threadIdx.x >> 5; // warp index within block
|
||||
const int warps_per_block = blockDim.x >> 5;
|
||||
|
||||
for (int head = warp_id; head < num_heads; head += warps_per_block) {
|
||||
const scalar_t* __restrict__ k_src_h = key_src + head * head_size;
|
||||
const scalar_t* __restrict__ v_src_h = value_src + head * head_size;
|
||||
|
||||
cache_t* __restrict__ k_dst_h =
|
||||
key_dst + static_cast<int64_t>(head) * head_stride;
|
||||
cache_t* __restrict__ v_dst_h =
|
||||
value_dst + static_cast<int64_t>(head) * head_stride;
|
||||
|
||||
// within each head, let the 32 threads of the warp perform the vector
|
||||
// copy
|
||||
vectorize_with_alignment<VEC_SIZE>(k_src_h, k_dst_h, head_size, lane, 32,
|
||||
k_op);
|
||||
|
||||
vectorize_with_alignment<VEC_SIZE>(v_src_h, v_dst_h, head_size, lane, 32,
|
||||
v_op);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user