[V1][Kernel] Add triton implementation for reshape_and_cache_flash (#24503)

Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
Burkhard Ringlein
2025-09-23 18:52:40 +02:00
committed by GitHub
parent 527821d191
commit 100b630a60
4 changed files with 276 additions and 20 deletions

View File

@ -9,6 +9,9 @@ import torch
from tabulate import tabulate
from vllm import _custom_ops as ops
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import (
@ -31,6 +34,8 @@ def run_benchmark(
kv_cache_dtype: str,
kv_cache_layout: str,
num_iters: int,
implementation: str,
benchmark_mode: str,
device: str = "cuda",
) -> float:
"""Return latency (seconds) for given num_tokens."""
@ -38,6 +43,14 @@ def run_benchmark(
if kv_cache_dtype == "fp8" and head_size % 16:
raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")
if implementation not in ("cuda", "triton"):
raise ValueError(
f"Unsupported implementation: {implementation}. "
"Only 'cuda' and 'triton' are supported."
)
if implementation == "triton" and kv_cache_layout == "HND":
return float("nan") # Triton does not support HND layout yet.
current_platform.seed_everything(42)
torch.set_default_device(device)
@ -65,27 +78,49 @@ def run_benchmark(
cache_layout=kv_cache_layout,
)
key_cache, value_cache = key_caches[0], value_caches[0]
# to free unused memory
del key_caches, value_caches
# compute per-kernel scaling factors for fp8 conversion (if used).
k_scale = (key.amax() / 64.0).to(torch.float32)
v_scale = (value.amax() / 64.0).to(torch.float32)
if implementation == "cuda":
function_under_test = lambda: ops.reshape_and_cache_flash(
key, # noqa: F821
value, # noqa: F821
key_cache, # noqa: F821
value_cache, # noqa: F821
slot_mapping, # noqa: F821
kv_cache_dtype,
k_scale,
v_scale,
)
else:
function_under_test = lambda: triton_reshape_and_cache_flash(
key, # noqa: F821
value, # noqa: F821
key_cache, # noqa: F821
value_cache, # noqa: F821
slot_mapping, # noqa: F821
kv_cache_dtype,
k_scale,
v_scale,
)
if benchmark_mode == "cudagraph":
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
function_under_test()
torch.cuda.synchronize()
function_under_test = lambda: g.replay()
def run_cuda_benchmark(n_iters: int) -> float:
nonlocal key, value, key_cache, value_cache, slot_mapping
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(n_iters):
ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
kv_cache_dtype,
k_scale,
v_scale,
)
torch.cuda.synchronize()
function_under_test()
torch.cuda.synchronize()
end = time.perf_counter()
return (end - start) / n_iters
@ -116,10 +151,16 @@ def main(args):
kv_cache_dtype=args.kv_cache_dtype,
kv_cache_layout=layout,
num_iters=args.iters,
implementation=args.implementation,
benchmark_mode=args.mode,
device="cuda",
)
rows.append([n_tok, layout, f"{lat * 1e6:.3f}"])
print(
f"Benchmark results for implementation {args.implementation}"
f" (measuring with {args.mode}):"
)
print(tabulate(rows, headers=["num_tokens", "layout", "latency (µs)"]))
@ -151,6 +192,21 @@ if __name__ == "__main__":
)
parser.add_argument("--iters", type=int, default=100)
parser.add_argument(
"--implementation",
type=str,
choices=["cuda", "triton"],
default="cuda",
)
parser.add_argument(
"--mode",
type=str,
choices=["cudagraph", "no_graph"],
default="cudagraph",
)
args = parser.parse_args()
main(args)

View File

@ -39,6 +39,8 @@ CUDA_DEVICES = [
# We assume fp8 is always enabled for testing.
KV_CACHE_DTYPE = ["auto", "fp8"]
RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"]
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@ -223,6 +225,7 @@ def test_reshape_and_cache(
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
@pytest.mark.parametrize("implementation", RESHAPE_FLASH_IMPLEMENTATIONS)
@torch.inference_mode()
def test_reshape_and_cache_flash(
kv_cache_factory_flashinfer,
@ -236,9 +239,13 @@ def test_reshape_and_cache_flash(
device: str,
kv_cache_dtype: str,
kv_cache_layout: str,
implementation: str,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
assert implementation in ["cuda", "triton"]
if implementation == "triton" and kv_cache_layout == "HND":
pytest.skip("Triton implementation only supports NHD layout.")
# fp8 conversion requires continugous memory buffer. Reduce the number of
# blocks and tokens to consume less memory.
@ -298,12 +305,20 @@ def test_reshape_and_cache_flash(
cloned_key_cache = key_cache_compact.clone()
cloned_value_cache = value_cache_compact.clone()
# Call the reshape_and_cache kernel.
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]))
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, k_scale, v_scale)
if implementation == "cuda":
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
(key, value, key_cache, value_cache, slot_mapping,
kv_cache_dtype, k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]))
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, k_scale,
v_scale)
elif implementation == "triton":
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash)
triton_reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, k_scale,
v_scale)
key_cache_compact = permute_and_compact(key_cache)
value_cache_compact = permute_and_compact(value_cache)

View File

@ -0,0 +1,176 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import triton
import triton.language as tl
from vllm.platforms import current_platform
@triton.jit
def reshape_and_cache_kernel_flash(
key_ptr, # [num_tokens, num_heads, head_size]
value_ptr, # [num_tokens, num_heads, head_size]
key_cache_ptr, # [num_blocks, block_size, num_heads, head_size]
value_cache_ptr, # [num_blocks, block_size, num_heads, head_size]
slot_mapping_ptr, # [num_tokens]
k_scale, # float32
v_scale, # float32
# strides
key_stride: tl.int64,
value_stride: tl.int64,
block_stride: tl.int64,
page_stride: tl.int64,
num_heads: tl.constexpr,
head_size: tl.constexpr,
block_size: tl.constexpr,
# FP8 flags
FP8_KV_CACHE: tl.constexpr,
# tune parameters
TILE_SIZE: tl.constexpr,
):
token_idx = tl.program_id(axis=0)
slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
if slot_idx < 0:
# Padding token that should be ignored.
return
tile_i = tl.program_id(axis=1)
tile_offs = tl.arange(0, TILE_SIZE)
tile_pos = tile_i * TILE_SIZE + tile_offs
block_idx = slot_idx // block_size
block_offset = slot_idx % block_size
src_key_idx = token_idx * key_stride
src_value_idx = token_idx * value_stride
tgt_idx = block_idx * block_stride + block_offset * page_stride
# [TILE_SIZE]
key_load = tl.load(key_ptr + src_key_idx + tile_pos,
mask=tile_pos < (num_heads * head_size))
if FP8_KV_CACHE:
if key_load.dtype.is_fp8():
key_tile = key_load
else:
# tl.store will do the correct implicit cast to fp8,
# based on the key_cache_ptr.dtype.element_ty
key_tile = key_load / tl.load(k_scale)
else:
key_tile = key_load
# [TILE_SIZE]
value_load = tl.load(value_ptr + src_value_idx + tile_pos,
mask=tile_pos < (num_heads * head_size))
if FP8_KV_CACHE:
if value_load.dtype.is_fp8():
value_tile = value_load
else:
# tl.store will do the correct implicit cast to fp8,
# based on the value_cache_ptr.dtype.element_ty
value_tile = value_load / tl.load(v_scale)
else:
value_tile = value_load
tl.store(
key_cache_ptr + tgt_idx + tile_pos,
key_tile,
mask=tile_pos < (num_heads * head_size),
)
tl.store(
value_cache_ptr + tgt_idx + tile_pos,
value_tile,
mask=tile_pos < (num_heads * head_size),
)
return
def triton_reshape_and_cache_flash(
key: torch.Tensor, # [num_tokens, num_heads, head_size]
value: torch.Tensor, # [num_tokens, num_heads, head_size]
# [num_blocks, block_size, num_heads, head_size]
key_cache: torch.Tensor,
# [num_blocks, block_size, num_heads, head_size]
value_cache: torch.Tensor,
slot_mapping: torch.Tensor, # [num_tokens]
kv_cache_dtype: str, # "auto", "fp8"
k_scale: torch.Tensor, # float32
v_scale: torch.Tensor, # float32
):
num_tokens = key.shape[0]
num_heads = key.shape[1]
head_size = key.shape[2]
block_size = key_cache.shape[1]
n = num_heads * head_size
key_stride = key.stride()[0]
value_stride = value.stride()[0]
block_stride = key_cache.stride()[0]
page_stride = key_cache.stride()[1]
head_stride = key_cache.stride()[2]
assert head_stride == head_size, "only continous heads are supported"
assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), \
f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
kv_cache_torch_dtype = current_platform.fp8_dtype() if \
kv_cache_dtype.startswith("fp8") else key_cache.dtype
if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith(
"fp8"):
# to avoid erounous implicit cast in triton kernel (tl.store to uint8)
# (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
key_cache = key_cache.view(kv_cache_torch_dtype)
value_cache = value_cache.view(kv_cache_torch_dtype)
assert kv_cache_dtype != torch.uint8, "explicit fp8 cast and store to "\
"uint8 is not supported by triton reshape_and_cache_flash"
FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
torch.float8_e4m3fn, torch.float8_e5m2, torch.uint8,
torch.float8_e4m3fnuz], \
"unsupported dtype of KV cache tensor, got "\
"{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, " \
"fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz."
# heuristics instead of autotuning
TILE_SIZE = min(2048, triton.next_power_of_2(n))
if torch.version.hip:
num_stages = 4
num_warps = 8
else: # cuda
num_stages = 10
num_warps = 16
if torch.cuda.get_device_capability(key.device)[0] < 9:
TILE_SIZE = min(512, TILE_SIZE)
# TODO(ngl): maybe replace with static launch grid to avoid overhead if
# using cudagraphs
grid = lambda meta: (int(num_tokens), triton.cdiv(n, meta["TILE_SIZE"]))
reshape_and_cache_kernel_flash[grid](
key_ptr=key,
value_ptr=value,
key_cache_ptr=key_cache,
value_cache_ptr=value_cache,
slot_mapping_ptr=slot_mapping,
k_scale=k_scale,
v_scale=v_scale,
# strides
key_stride=key_stride,
value_stride=value_stride,
block_stride=block_stride,
page_stride=page_stride,
num_heads=num_heads,
head_size=head_size,
block_size=block_size,
# FP8 flags
FP8_KV_CACHE=FP8_KV_CACHE,
# autotune parameters
TILE_SIZE=TILE_SIZE,
num_warps=num_warps,
num_stages=num_stages,
)

View File

@ -8,6 +8,8 @@ import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
@ -291,7 +293,13 @@ class TritonAttentionImpl(AttentionImpl):
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
ops.reshape_and_cache_flash(
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
# triton kernel does not support uint8 kv_cache
# (because some explicit casts (e.g. float8_e4m3fnuz)
# are not supported)
triton_reshape_and_cache_flash(
key,
value,
key_cache,
@ -303,8 +311,9 @@ class TritonAttentionImpl(AttentionImpl):
)
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
if key_cache.dtype != self.fp8_dtype:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
num_tokens, num_heads, head_size = query.shape
assert layer._q_scale_float == 1.0, \
"A non 1.0 q_scale is not currently supported."