[NVIDIA] Support Flashinfer TRTLLM FP8-q/kv/out Attention Kernel (#21716)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Authored by elvischenv on 2025-08-19 20:22:15 +08:00, committed by GitHub
parent 40f26734b9
commit 03752dba8f
9 changed files with 916 additions and 500 deletions

View File

@ -631,6 +631,7 @@ steps:
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
- vllm/v1/attention/backends/flashinfer.py
- vllm/compilation/fusion.py
- vllm/compilation/fusion_attn.py
commands:
- nvidia-smi
- python3 examples/offline_inference/basic/chat.py
@ -647,6 +648,7 @@ steps:
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
# Fusion
- pytest -v -s tests/compile/test_fusion_all_reduce.py
- pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
##### 1 GPU test #####
##### multi gpus test #####

View File

@ -3,16 +3,14 @@
import csv
import os
import random
from datetime import datetime
from typing import Optional
import flashinfer
import torch
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
FP8_DTYPE = torch.float8_e4m3fn
def to_float8(x, dtype=torch.float8_e4m3fn):
@ -26,65 +24,107 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
@torch.no_grad()
def benchmark_decode(
num_seqs,
max_seq_len,
page_size=16,
dtype=torch.bfloat16,
kv_layout="HND",
num_kv_heads=8,
kv_cache_dtype="auto",
head_dim=128,
warmup=10,
trials=20,
dtype: torch.dtype,
quant_dtypes: tuple[
Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
],
batch_size: int,
max_seq_len: int,
num_heads: tuple[int, int] = (64, 8),
head_size: int = 128,
kv_layout: str = "HND",
block_size: int = 16,
warmup: int = 10,
trials: int = 20,
):
torch.set_default_device("cuda")
device = "cuda"
torch.manual_seed(0)
HEAD_GRP_SIZE = 8
MAX_SEQ_LEN = max_seq_len
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
q_quant_dtype = q_quant_dtype or dtype
kv_quant_dtype = kv_quant_dtype or dtype
o_quant_dtype = o_quant_dtype or dtype
num_qo_heads, num_kv_heads = num_heads
assert num_qo_heads % num_kv_heads == 0
sm_scale = float(1.0 / (head_size**0.5))
# large number to reduce kv_cache reuse
NUM_BLOCKS = int(256000 / page_size)
NUM_BLOCKS = int(256000 / block_size)
workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device)
kv_cache_shape = None
if kv_layout == "NHD":
kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
elif kv_layout == "HND":
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
else:
raise ValueError(f"Invalid kv_layout: {kv_layout}")
# For decode, batch_size is num_decode_token
num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
sm_scale = float(1.0 / (head_dim**0.5))
q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype)
kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
if q_quant_dtype == FP8_DTYPE:
query, q_scale = to_float8(query)
ref_query = query.to(dtype) * q_scale
else:
q_scale = 1.0
ref_query = query
max_kv_len = max(kv_lens)
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device)
max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size
kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
kv_lens[-1] = max_seq_len
seq_lens = kv_lens
max_seq_len = torch.max(seq_lens).item()
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
if kv_quant_dtype == FP8_DTYPE:
kv_cache, kv_scale = to_float8(kv_cache)
ref_kv_cache = kv_cache.to(dtype) * kv_scale
else:
kv_scale = 1.0
ref_kv_cache = kv_cache
k_scale = v_scale = kv_scale
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint(
0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
)
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(batch_size):
seq_len = seq_lens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype)
k_scale = v_scale = 1.0
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)
if kv_cache_dtype.startswith("fp8"):
kv_cache, _ = to_float8(kv_cache)
output_trtllm = torch.empty(q.shape, dtype=dtype)
# Benchmark TRT decode
def trt_decode():
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
q,
kv_cache,
workspace_buffer,
block_tables,
kv_lens_tensor,
max_kv_len,
bmm1_scale=k_scale * sm_scale,
bmm2_scale=v_scale,
out=output_trtllm,
)
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout,
use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
)
wrapper.plan(
kv_indptr,
kv_indices,
kv_last_page_lens,
num_qo_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
sm_scale=sm_scale,
q_data_type=dtype,
kv_data_type=dtype,
)
def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize()
@ -101,74 +141,51 @@ def benchmark_decode(
times.append(start.elapsed_time(end)) # ms
return sum(times) / len(times), torch.std(torch.tensor(times))
# TRT Decode
trt_mean, trt_std = time_fn(trt_decode)
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(num_seqs):
seq_len = kv_lens[i]
assert seq_len > 0
num_blocks = (seq_len + page_size - 1) // page_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % page_size
if kv_last_page_len == 0:
kv_last_page_len = page_size
kv_last_page_lens.append(kv_last_page_len)
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
output_baseline = torch.empty(q.shape, dtype=dtype)
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout,
use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
)
wrapper.plan(
kv_indptr,
kv_indices,
kv_last_page_lens,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
"NONE",
q_data_type=dtype,
kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype,
)
o_scale = 1.0
output_baseline = torch.empty(ref_query.shape, dtype=dtype)
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
def baseline_decode():
return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale, output_baseline)
return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
def trtllm_decode():
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
query=query,
kv_cache=kv_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables,
seq_lens=seq_lens,
max_seq_len=max_seq_len,
bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale,
out=output_trtllm,
)
baseline_mean, baseline_std = time_fn(baseline_decode)
trtllm_mean, trtllm_std = time_fn(trtllm_decode)
# Calculate percentage speedup (positive means TRT is faster)
speedup_percent = (baseline_mean - trt_mean) / baseline_mean
speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean
print(
f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}"
f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:.3f}\t{trtllm_std.item():.3f}"
f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
)
# Return results for CSV writing
return {
"num_seqs": num_seqs,
"trt_mean": trt_mean,
"trt_std": trt_std.item(),
"batch_size": batch_size,
"trtllm_mean": trtllm_mean,
"trtllm_std": trtllm_std.item(),
"baseline_mean": baseline_mean,
"baseline_std": baseline_std.item(),
"speedup_percent": speedup_percent,
"q_dtype": str(dtype),
"kv_cache_dtype": kv_cache_dtype,
"page_size": page_size,
"q_dtype": str(q_quant_dtype),
"kv_cache_dtype": str(kv_quant_dtype),
"output_dtype": str(o_quant_dtype),
"block_size": block_size,
"num_kv_heads": num_kv_heads,
"head_dim": head_dim,
"head_size": head_size,
"max_seq_len": max_seq_len,
}
@ -180,17 +197,18 @@ def write_results_to_csv(results, filename=None):
filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"
fieldnames = [
"num_seqs",
"trt_mean",
"trt_std",
"batch_size",
"trtllm_mean",
"trtllm_std",
"baseline_mean",
"baseline_std",
"speedup_percent",
"q_dtype",
"kv_cache_dtype",
"page_size",
"output_dtype",
"block_size",
"num_kv_heads",
"head_dim",
"head_size",
"max_seq_len",
]
@ -209,45 +227,42 @@ def write_results_to_csv(results, filename=None):
if __name__ == "__main__":
num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
all_results = []
print(
"Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, "
"output_dtype: bfloat16"
)
print(
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
"baseline_std\tspeedup_percent"
)
for max_seq_len in max_seq_lens:
for bs in num_seqs:
result = benchmark_decode(
bs,
max_seq_len,
dtype=torch.bfloat16,
kv_cache_dtype="auto",
)
all_results.append(result)
dtype = torch.bfloat16
quant_dtypes = [
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(None, None, None),
(None, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
]
print(
"Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8, "
"output_dtype: bfloat16"
)
print(
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
"baseline_std\tspeedup_percent"
)
for max_seq_len in max_seq_lens:
for bs in num_seqs:
result = benchmark_decode(
bs,
max_seq_len,
dtype=torch.bfloat16,
kv_cache_dtype="fp8",
)
all_results.append(result)
for quant_dtype in quant_dtypes:
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
q_quant_dtype = q_quant_dtype or dtype
kv_quant_dtype = kv_quant_dtype or dtype
o_quant_dtype = o_quant_dtype or dtype
print(
f"Running benchmark for q_dtype = {q_quant_dtype}, "
f"kv_cache_dtype: {kv_quant_dtype}, "
f"output_dtype: {o_quant_dtype}"
)
print(
"\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
"baseline_std\tspeedup_percent"
)
for max_seq_len in max_seq_lens:
for bs in batch_sizes:
result = benchmark_decode(
dtype=dtype,
quant_dtypes=quant_dtype,
batch_size=bs,
max_seq_len=max_seq_len,
)
all_results.append(result)
# Write all results to CSV
write_results_to_csv(all_results)
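For orientation, both benchmark scripts build their FP8 inputs with a per-tensor helper that returns the quantized tensor together with the dequantization (reciprocal) scale, so a bf16 reference copy can be rebuilt as quantized.to(dtype) * scale. The helper body is elided by the hunks above; the following is a minimal sketch of the conventional implementation implied by the `return x_scl_sat.to(dtype), scale.float().reciprocal()` line shown later in this commit (the 1e-12 clamp is an assumption):

import torch

FP8_DTYPE = torch.float8_e4m3fn

def to_float8(x: torch.Tensor, dtype: torch.dtype = FP8_DTYPE):
    # Per-tensor symmetric quantization: map the tensor's absmax onto the FP8 range.
    finfo = torch.finfo(dtype)
    scale = finfo.max / x.abs().max().clamp(min=1e-12)
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    # Return the quantized tensor and the *dequantization* scale (reciprocal).
    return x_scl_sat.to(dtype), scale.float().reciprocal()

# Usage mirroring the benchmark: keep a bf16 reference next to the FP8 tensor.
query = torch.randn(4, 64, 128, dtype=torch.bfloat16)
q_fp8, q_scale = to_float8(query)
ref_query = q_fp8.to(torch.bfloat16) * q_scale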

View File

@ -3,16 +3,14 @@
import csv
import os
import random
from datetime import datetime
from typing import Optional
import flashinfer
import torch
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
FP8_DTYPE = torch.float8_e4m3fn
def to_float8(x, dtype=torch.float8_e4m3fn):
@ -26,84 +24,99 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
@torch.no_grad()
def benchmark_prefill(
num_seqs,
max_seq_len,
page_size=16,
dtype=torch.bfloat16,
kv_layout="HND",
num_kv_heads=8,
kv_cache_dtype="auto",
head_dim=128,
warmup=10,
trials=20,
dtype: torch.dtype,
quant_dtypes: tuple[
Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
],
batch_size: int,
max_seq_len: int,
num_heads: tuple[int, int] = (64, 8),
head_size: int = 128,
kv_layout: str = "HND",
block_size: int = 16,
warmup: int = 10,
trials: int = 20,
):
torch.set_default_device("cuda")
torch.manual_seed(0)
HEAD_GRP_SIZE = 8
MAX_SEQ_LEN = max_seq_len
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
q_quant_dtype = q_quant_dtype or dtype
kv_quant_dtype = kv_quant_dtype or dtype
o_quant_dtype = o_quant_dtype or dtype
max_q_len = max_kv_len = max_seq_len
num_qo_heads, num_kv_heads = num_heads
assert num_qo_heads % num_kv_heads == 0
sm_scale = float(1.0 / (head_size**0.5))
# large number to reduce kv_cache reuse
NUM_BLOCKS = int(256000 / page_size)
NUM_BLOCKS = int(256000 / block_size)
workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8)
kv_cache_shape = None
if kv_layout == "NHD":
kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
elif kv_layout == "HND":
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
else:
raise ValueError(f"Invalid kv_layout: {kv_layout}")
num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
sm_scale = float(1.0 / (head_dim**0.5))
q_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
q_lens[-1] = MAX_SEQ_LEN
max_q_len = max(q_lens)
q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32)
q_lens[-1] = max_q_len
q_indptr = torch.cat(
[
torch.tensor([0], dtype=torch.int32),
torch.cumsum(
torch.tensor(q_lens, dtype=torch.int32), dim=0, dtype=torch.int32
),
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
]
)
q = torch.randn(sum(q_lens), num_qo_heads, head_dim, dtype=dtype)
kv_lens = [random.randint(0, MAX_SEQ_LEN) for _ in range(num_seqs)]
kv_lens[-1] = MAX_SEQ_LEN
query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype)
if q_quant_dtype == FP8_DTYPE:
query, q_scale = to_float8(query)
ref_query = query.to(dtype) * q_scale
else:
q_scale = 1.0
ref_query = query
seq_lens = [q_len + kv_len for q_len, kv_len in zip(q_lens, kv_lens)]
max_seq_len = max(seq_lens)
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32)
kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
kv_lens[-1] = max_kv_len
max_num_blocks_per_seq = (max_seq_len + page_size - 1) // page_size
seq_lens = kv_lens + q_lens
max_seq_len = torch.max(seq_lens).item()
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
if kv_quant_dtype == FP8_DTYPE:
kv_cache, kv_scale = to_float8(kv_cache)
ref_kv_cache = kv_cache.to(dtype) * kv_scale
else:
kv_scale = 1.0
ref_kv_cache = kv_cache
k_scale = v_scale = kv_scale
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint(
0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
)
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
kv_cache = torch.randn(size=kv_cache_shape, dtype=dtype)
k_scale = v_scale = 1.0
if kv_cache_dtype.startswith("fp8"):
kv_cache, _ = to_float8(kv_cache)
output_trtllm = torch.empty(q.shape, dtype=dtype)
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(num_seqs):
for i in range(batch_size):
seq_len = seq_lens[i]
assert seq_len > 0
num_blocks = (seq_len + page_size - 1) // page_size
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % page_size
kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
kv_last_page_len = page_size
kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
output_baseline = torch.empty(q.shape, dtype=dtype)
workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout
@ -115,12 +128,12 @@ def benchmark_prefill(
kv_last_page_lens,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
head_size,
block_size,
causal=True,
sm_scale=sm_scale,
q_data_type=dtype,
kv_data_type=kv_cache.dtype,
kv_data_type=dtype,
)
def time_fn(fn, warmup=10, trials=20):
@ -138,52 +151,55 @@ def benchmark_prefill(
times.append(start.elapsed_time(end)) # ms
return sum(times) / len(times), torch.std(torch.tensor(times))
def baseline_prefill():
return wrapper.run(
q, kv_cache, k_scale=k_scale, v_scale=v_scale, out=output_baseline
)
o_scale = 1.0
output_baseline = torch.empty(ref_query.shape, dtype=dtype)
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
def trt_prefill():
def baseline_prefill():
return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
def trtllm_prefill():
return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
query=q,
query=query,
kv_cache=kv_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables,
seq_lens=seq_lens_tensor,
seq_lens=seq_lens,
max_q_len=max_q_len,
max_kv_len=max_seq_len,
bmm1_scale=k_scale * sm_scale,
bmm2_scale=v_scale,
batch_size=num_seqs,
bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale,
batch_size=batch_size,
cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr,
out=output_trtllm,
)
trt_mean, trt_std = time_fn(trt_prefill)
baseline_mean, baseline_std = time_fn(baseline_prefill)
trtllm_mean, trtllm_std = time_fn(trtllm_prefill)
# Calculate percentage speedup (positive means TRT is faster)
speedup_percent = (baseline_mean - trt_mean) / baseline_mean
speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean
print(
f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.5f}\t{trt_std.item():.5f}"
f"\t{baseline_mean:.5f}\t{baseline_std.item():.5f}\t{speedup_percent:.5f}"
f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:8.3f}\t{trtllm_std.item():8.3f}"
f"\t{baseline_mean:8.3f}\t{baseline_std.item():8.3f}\t{speedup_percent:8.3f}"
)
# Return results for CSV writing
return {
"num_seqs": num_seqs,
"trt_mean": trt_mean,
"trt_std": trt_std.item(),
"batch_size": batch_size,
"trtllm_mean": trtllm_mean,
"trtllm_std": trtllm_std.item(),
"baseline_mean": baseline_mean,
"baseline_std": baseline_std.item(),
"speedup_percent": speedup_percent,
"q_dtype": str(dtype),
"kv_cache_dtype": kv_cache_dtype,
"page_size": page_size,
"q_dtype": str(q_quant_dtype),
"kv_cache_dtype": str(kv_quant_dtype),
"output_dtype": str(o_quant_dtype),
"block_size": block_size,
"num_kv_heads": num_kv_heads,
"head_dim": head_dim,
"head_size": head_size,
"max_seq_len": max_seq_len,
}
@ -195,17 +211,18 @@ def write_results_to_csv(results, filename=None):
filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"
fieldnames = [
"num_seqs",
"trt_mean",
"trt_std",
"batch_size",
"trtllm_mean",
"trtllm_std",
"baseline_mean",
"baseline_std",
"speedup_percent",
"q_dtype",
"kv_cache_dtype",
"page_size",
"output_dtype",
"block_size",
"num_kv_heads",
"head_dim",
"head_size",
"max_seq_len",
]
@ -224,27 +241,41 @@ def write_results_to_csv(results, filename=None):
if __name__ == "__main__":
num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
all_results = []
print(
"Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, "
"output_dtype: bfloat16"
)
print(
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
"baseline_std\tspeedup_percent"
)
for max_seq_len in max_seq_lens:
for bs in num_seqs:
result = benchmark_prefill(
bs,
max_seq_len,
dtype=torch.bfloat16,
kv_cache_dtype="auto",
)
all_results.append(result)
dtype = torch.bfloat16
quant_dtypes = [
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(None, None, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
]
for quant_dtype in quant_dtypes:
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
q_quant_dtype = q_quant_dtype or dtype
kv_quant_dtype = kv_quant_dtype or dtype
o_quant_dtype = o_quant_dtype or dtype
print(
f"Running benchmark for q_dtype = {q_quant_dtype}, "
f"kv_cache_dtype: {kv_quant_dtype}, "
f"output_dtype: {o_quant_dtype}"
)
print(
"\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
"baseline_std\tspeedup_percent"
)
for max_seq_len in max_seq_lens:
for bs in batch_sizes:
result = benchmark_prefill(
dtype=dtype,
quant_dtypes=quant_dtype,
batch_size=bs,
max_seq_len=max_seq_len,
)
all_results.append(result)
# Write all results to CSV
write_results_to_csv(all_results)
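Across both scripts the quantization scales never reach the TRTLLM kernel individually; they are folded into the two fused GEMM scalars passed as bmm1_scale and bmm2_scale. A small illustrative summary of that arithmetic (the concrete values below are placeholders, not benchmark output):

head_size = 128
sm_scale = 1.0 / (head_size ** 0.5)
q_scale, k_scale, v_scale, o_scale = 0.02, 0.03, 0.05, 1.0  # placeholder dequant scales

# First GEMM: dequantize Q and K inside softmax(Q K^T * sm_scale).
bmm1_scale = q_scale * k_scale * sm_scale
# Second GEMM: dequantize V and, when the output is FP8, re-quantize it by 1/o_scale.
bmm2_scale = v_scale / o_scale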

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from typing import Optional
import pytest
@ -7,13 +8,27 @@ import torch._dynamo
from tests.compile.backend import TestBackend
from tests.models.utils import check_outputs_equal
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata)
from vllm import LLM, SamplingParams
from vllm.attention import Attention
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
ModelConfig, PassConfig, SchedulerConfig, VllmConfig,
set_current_vllm_config)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp)
from vllm.platforms import current_platform
from vllm.v1.kv_cache_interface import AttentionSpec
FP8_DTYPE = current_platform.fp8_dtype()
# globals needed for string-import custom Dynamo backend field
backend: Optional[TestBackend] = None
@ -132,3 +147,235 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
# Reset backend to make sure llm2 gets released
backend = None
class TestAttentionStaticQuantPatternModel(torch.nn.Module):
"""Test model for AttentionStaticQuantPattern fusion."""
def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int,
kv_cache_dtype: torch.dtype, device: torch.device,
vllm_config: VllmConfig):
super().__init__()
self.num_qo_heads = num_qo_heads
self.num_kv_heads = num_kv_heads
self.head_size = head_size
self.kv_cache_dtype = kv_cache_dtype
self.device = device
self.vllm_config = vllm_config
self.attn = Attention(
num_heads=self.num_qo_heads,
head_size=self.head_size,
scale=1.0 / (self.head_size**0.5),
num_kv_heads=self.num_kv_heads,
cache_config=vllm_config.cache_config,
prefix="model.layers.0.self_attn.attn",
)
self.fp8_linear = Fp8LinearOp(
act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR)
self.wscale = torch.tensor([1.0], dtype=torch.float32)
self.scale = torch.tensor([1.0], dtype=torch.float32)
self.block_size = 16
# Initialize attn MetadataBuilder
self.builder = self.attn.attn_backend.get_builder_cls()(
kv_cache_spec=AttentionSpec(
block_size=self.block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_dtype,
use_mla=False,
),
layer_names=[self.attn.layer_name],
vllm_config=self.vllm_config,
device=self.device,
)
def build_attn_metadata(self, batch_size: int):
"""Initialize attention metadata."""
# Create common attn metadata
batch_spec = BatchSpec(seq_lens=[1] * batch_size,
query_lens=[1] * batch_size)
common_attn_metadata = create_common_attn_metadata(
batch_spec,
self.block_size,
self.device,
arange_block_indices=True)
max_blocks = (max(batch_spec.seq_lens) + self.block_size -
1) // self.block_size
num_blocks = batch_size * max_blocks
# Create dummy KV cache for FlashInfer TRTLLM
# - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
# - HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
# Create kv_cache in HND layout and permute to NHD layout
# (later will be permuted back to HND layout in forward pass)
kv_cache = torch.zeros(num_blocks,
2,
self.num_kv_heads,
self.block_size,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device)
kv_cache = kv_cache.permute(0, 1, 3, 2, 4)
self.attn.kv_cache = [kv_cache]
# Build attn metadata
self.attn_metadata = self.builder.build(
common_prefix_len=0, common_attn_metadata=common_attn_metadata)
return self.attn_metadata
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
w: torch.Tensor):
"""Forward pass that creates the pattern to be fused."""
attn_output = self.attn(q, k, v)
return self.fp8_linear.apply(input=attn_output,
weight=w,
weight_scale=self.wscale,
input_scale=self.scale)
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)])
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("batch_size", [7, 256, 533])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize(
"model_name, quant_key",
[("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", kFp8StaticTensorSym)])
@pytest.mark.parametrize("backend", [_Backend.FLASHINFER])
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(not current_platform.is_device_capability((10, 0)),
reason="Only test on SM100(Blackwell)")
def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
head_size: int, batch_size: int,
dtype: torch.dtype, model_name: str,
quant_key: QuantKey, backend: _Backend,
monkeypatch, dist_init):
"""Test AttentionStaticQuantPattern fusion pass"""
monkeypatch.setenv("VLLM_USE_V1", "1")
device = torch.device("cuda:0")
torch.manual_seed(42)
vllm_config = VllmConfig(
model_config=ModelConfig(
model=model_name,
max_model_len=2048,
),
scheduler_config=SchedulerConfig(max_num_seqs=1024),
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
custom_ops=["+quant_fp8"],
),
cache_config=CacheConfig(cache_dtype="fp8"))
# Create test inputs
hidden_size = num_qo_heads * head_size
q = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
k = torch.randn(batch_size,
num_kv_heads * head_size,
dtype=dtype,
device=device)
v = torch.randn(batch_size,
num_kv_heads * head_size,
dtype=dtype,
device=device)
linear_w = torch.randn(hidden_size, hidden_size).to(FP8_DTYPE).t()
# Mark first dimension as dynamic for realistic testing
torch._dynamo.mark_dynamic(q, 0)
torch._dynamo.mark_dynamic(k, 0)
torch._dynamo.mark_dynamic(v, 0)
# Run model directly without compilation and fusion
vllm_config_unfused = copy.deepcopy(vllm_config)
with set_current_vllm_config(vllm_config_unfused), set_forward_context(
attn_metadata=None, vllm_config=vllm_config_unfused
), global_force_attn_backend_context_manager(backend):
model_unfused = TestAttentionStaticQuantPatternModel(
num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device,
vllm_config_unfused)
model_unfused = model_unfused.to(device)
forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(
batch_size)
# Run model directly without compilation and fusion
result_unfused = model_unfused(q, k, v, linear_w)
# Run model with attn fusion enabled
vllm_config.compilation_config.pass_config = PassConfig(
enable_attn_fusion=True, enable_noop=True)
with set_current_vllm_config(vllm_config), set_forward_context(
attn_metadata=None, vllm_config=vllm_config
), global_force_attn_backend_context_manager(backend):
model_fused = TestAttentionStaticQuantPatternModel(
num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device,
vllm_config)
model_fused = model_fused.to(device)
forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size)
# Create test backend with fusion passes enabled
noop_pass = NoOpEliminationPass(vllm_config)
attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw
)
test_backend = TestBackend(noop_pass, attn_pass)
# Compile model with fusion enabled
model_compiled = torch.compile(model_fused,
backend=test_backend,
fullgraph=True)
assert model_compiled.attn._o_scale_float is None
result_fused_1 = model_compiled(q, k, v, linear_w)
# After the 1st round of the forward pass, output quant scale should be
# loaded into the attn layer's _o_scale_float, the 2nd round should
# reuse the loaded _o_scale_float
assert model_compiled.attn._o_scale_float is not None
result_fused_2 = model_compiled(q, k, v, linear_w)
assert model_compiled.attn._o_scale_float is not None
# Check attn fusion support
attn_fusion_supported = [
layer.impl.fused_output_quant_supported(quant_key.dtype,
quant_key.static,
quant_key.group_shape) for key,
layer in vllm_config.compilation_config.static_forward_context.items()
]
if any(attn_fusion_supported):
# Check quantization ops in the graph before and after fusion
test_backend.check_before_ops([QUANT_OPS[quant_key]],
fully_replaced=True)
# Check attention ops in the graph before and after fusion
attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
attn_nodes_post = list(find_op_nodes(ATTN_OP,
test_backend.graph_post_pass))
assert len(attn_nodes_pre) > 0, "Should have attention nodes before fusion"
assert len(attn_nodes_pre) == len(attn_nodes_post), \
"Should have same number of attention nodes before and after fusion"
assert attn_nodes_pre[0].kwargs.get("output_scale") is None, \
"Attention should not have output_scale before fusion"
assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \
"Attention should have output_scale after fusion"
# Check that results are close
torch.testing.assert_close(result_unfused,
result_fused_1,
atol=1e-2,
rtol=1e-2)
torch.testing.assert_close(result_unfused,
result_fused_2,
atol=1e-2,
rtol=1e-2)
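At a graph level, the pattern this test exercises replaces "attention followed by a static FP8 quant" with a single attention call that receives the quant scale and writes an FP8 buffer directly. A simplified pseudo-graph (op names shortened; the real ops are torch.ops.vllm.unified_attention_with_output and the QUANT_OPS entry for kFp8StaticTensorSym):

# Before fusion: two ops, bf16 intermediate buffer
#   attn_out  = unified_attention(q, k, v, output=buf_bf16, output_scale=None)
#   quant_out = static_scaled_fp8_quant(attn_out, scale)
#
# After fusion: one op, FP8 output buffer allocated up front
#   quant_out = unified_attention(q, k, v, output=buf_fp8, output_scale=scale)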

View File

@ -13,21 +13,7 @@ if not current_platform.is_device_capability(100):
allow_module_level=True)
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
MAX_Q_LEN = 1024
MAX_KV_LEN = 4096
BATCH_SIZES = [4, 12]
NUM_HEADS = [(16, 16), (40, 8)]
HEAD_SIZES = [128]
BLOCK_SIZES = [16]
KV_LAYOUTS = ["HND"]
DTYPES = [torch.bfloat16]
KV_CACHE_DTYPES = [None, current_platform.fp8_dtype()]
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
SOFT_CAPS = [None, 50.0]
FP8_DTYPE = current_platform.fp8_dtype()
def to_float8(x, dtype=torch.float8_e4m3fn):
@ -39,42 +25,59 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
return x_scl_sat.to(dtype), scale.float().reciprocal()
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
DTYPE = [torch.bfloat16]
QUANT_DTYPES = [
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(None, None, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
]
BATCH_SIZE = [4, 12]
MAX_SEQ_LENS = [(1024, 4096)]
NUM_HEADS = [(64, 8), (40, 8)]
HEAD_SIZE = [128]
KV_LAYOUT = ["HND"] # currently only HND is supported
BLOCK_SIZE = [16]
SOFT_CAP = [None, 50.0]
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("quant_dtypes", QUANT_DTYPES)
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("max_seq_lens", MAX_SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("kv_layout", KV_LAYOUTS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@pytest.mark.parametrize("head_size", HEAD_SIZE)
@pytest.mark.parametrize("kv_layout", KV_LAYOUT)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("soft_cap", SOFT_CAP)
@torch.inference_mode
def test_flashinfer_trtllm_decode_with_baseline(
dtype: torch.dtype,
quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype],
Optional[torch.dtype]],
batch_size: int,
max_seq_lens: tuple[int, int],
num_heads: tuple[int, int],
head_size: int,
block_size: int,
kv_layout: str,
dtype: torch.dtype,
kv_cache_dtype: Optional[torch.dtype],
block_size: int,
soft_cap: Optional[float],
) -> None:
kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
torch.set_default_device("cuda")
current_platform.seed_everything(0)
kv_lens = torch.randint(1, MAX_KV_LEN, (batch_size, ), dtype=torch.int32)
kv_lens[-1] = MAX_KV_LEN
max_kv_len = torch.max(kv_lens).item()
num_seqs = len(kv_lens)
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
q_quant_dtype = q_quant_dtype or dtype
kv_quant_dtype = kv_quant_dtype or dtype
o_quant_dtype = o_quant_dtype or dtype
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
_, max_kv_len = max_seq_lens
scale = head_size**-0.5
num_qo_heads, num_kv_heads = num_heads
assert num_qo_heads % num_kv_heads == 0
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
sm_scale = float(1.0 / (head_size**0.5))
kv_cache_shape = None
if kv_layout == "NHD":
@ -83,156 +86,39 @@ def test_flashinfer_trtllm_decode_with_baseline(
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
else:
raise ValueError(f"Invalid kv_layout: {kv_layout}")
key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)
kv_scale = 1.0
if kv_cache_dtype is current_platform.fp8_dtype():
key_value_cache, kv_scale = to_float8(key_value_cache,
current_platform.fp8_dtype())
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
k_scale = v_scale = kv_scale
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(num_seqs):
seq_len = kv_lens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)
query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
if q_quant_dtype == FP8_DTYPE:
query, q_scale = to_float8(query)
ref_query = query.to(dtype) * q_scale
else:
q_scale = 1.0
ref_query = query
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
kv_lens = torch.randint(1, max_kv_len, (batch_size, ), dtype=torch.int32)
kv_lens[-1] = max_kv_len
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout,
use_tensor_cores=((num_query_heads // num_kv_heads) > 4))
wrapper.plan(kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
sm_scale=scale,
q_data_type=dtype,
kv_data_type=kv_cache_dtype,
logits_soft_cap=soft_cap)
output = torch.empty(query.shape, dtype=dtype)
wrapper.run(query,
key_value_cache,
k_scale=k_scale,
v_scale=v_scale,
out=output)
# TRTLLM Decode
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
output_trtllm = torch.empty(query.shape, dtype=dtype)
flashinfer.decode.trtllm_batch_decode_with_kv_cache(
query=query.contiguous(),
kv_cache=key_value_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables,
seq_lens=kv_lens_tensor,
max_seq_len=max_kv_len,
bmm1_scale=k_scale * scale,
bmm2_scale=v_scale,
out=output_trtllm,
)
torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - output_trtllm))}"
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("kv_layout", KV_LAYOUTS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("soft_cap", [None])
@torch.inference_mode
def test_flashinfer_trtllm_prefill_with_baseline(
batch_size: int,
num_heads: tuple[int, int],
head_size: int,
block_size: int,
kv_layout: str,
dtype: torch.dtype,
kv_cache_dtype: Optional[torch.dtype],
soft_cap: Optional[float],
) -> None:
kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
if dtype != kv_cache_dtype:
pytest.skip(f"Not supported dtype({dtype}) with "
"kv_cache_dtype({kv_cache_dtype})")
torch.set_default_device("cuda")
current_platform.seed_everything(0)
q_lens = torch.randint(1, MAX_Q_LEN, (batch_size, ), dtype=torch.int32)
q_lens[-1] = MAX_Q_LEN
max_q_len = torch.max(q_lens).item()
q_indptr = torch.cat([
torch.tensor([0], dtype=torch.int32),
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
])
kv_lens = torch.randint(0, MAX_KV_LEN, (batch_size, ), dtype=torch.int32)
kv_lens[-1] = MAX_KV_LEN
seq_lens = kv_lens + q_lens
seq_lens = kv_lens
max_seq_len = torch.max(seq_lens).item()
num_seqs = len(seq_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
scale = head_size**-0.5
query = torch.randn(torch.sum(q_lens).item(),
num_query_heads,
head_size,
dtype=dtype)
kv_cache_shape = None
if kv_layout == "NHD":
kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
elif kv_layout == "HND":
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
if kv_quant_dtype == FP8_DTYPE:
kv_cache, kv_scale = to_float8(kv_cache)
ref_kv_cache = kv_cache.to(dtype) * kv_scale
else:
raise ValueError(f"Invalid kv_layout: {kv_layout}")
key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)
kv_scale = 1.0
if kv_cache_dtype is current_platform.fp8_dtype():
key_value_cache, kv_scale = to_float8(key_value_cache,
current_platform.fp8_dtype())
kv_scale = 1.0
ref_kv_cache = kv_cache
k_scale = v_scale = kv_scale
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS,
(num_seqs, max_num_blocks_per_seq),
(batch_size, max_num_blocks_per_seq),
dtype=torch.int32)
k_scale = v_scale = kv_scale
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(num_seqs):
for i in range(batch_size):
seq_len = seq_lens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
@ -246,48 +132,206 @@ def test_flashinfer_trtllm_prefill_with_baseline(
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
# Baseline Decode
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout,
use_tensor_cores=((num_qo_heads // num_kv_heads) > 4))
wrapper.plan(kv_indptr,
kv_indices,
kv_last_page_lens,
num_qo_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
sm_scale=sm_scale,
q_data_type=dtype,
kv_data_type=dtype,
logits_soft_cap=soft_cap)
output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(ref_query, ref_kv_cache, out=output)
o_scale = 1.0
if o_quant_dtype == FP8_DTYPE:
_, o_scale = to_float8(output)
# TRTLLM Decode
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
flashinfer.decode.trtllm_batch_decode_with_kv_cache(
query=query,
kv_cache=kv_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables,
seq_lens=seq_lens,
max_seq_len=max_seq_len,
bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale,
out=output_trtllm,
)
if o_quant_dtype == FP8_DTYPE:
output_trtllm = output_trtllm.to(dtype) * o_scale
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
rtol, atol = 5e-2, 7e-2
else:
rtol, atol = 1e-2, 1e-2
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - output_trtllm))}"
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("quant_dtypes", QUANT_DTYPES)
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("max_seq_lens", MAX_SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZE)
@pytest.mark.parametrize("kv_layout", KV_LAYOUT)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("soft_cap", [None])
@torch.inference_mode
def test_flashinfer_trtllm_prefill_with_baseline(
dtype: torch.dtype,
quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype],
Optional[torch.dtype]],
batch_size: int,
max_seq_lens: tuple[int, int],
num_heads: tuple[int, int],
head_size: int,
kv_layout: str,
block_size: int,
soft_cap: Optional[float],
) -> None:
torch.set_default_device("cuda")
current_platform.seed_everything(0)
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
q_quant_dtype = q_quant_dtype or dtype
kv_quant_dtype = kv_quant_dtype or dtype
o_quant_dtype = o_quant_dtype or dtype
max_q_len, max_kv_len = max_seq_lens
num_qo_heads, num_kv_heads = num_heads
assert num_qo_heads % num_kv_heads == 0
sm_scale = float(1.0 / (head_size**0.5))
kv_cache_shape = None
if kv_layout == "NHD":
kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
elif kv_layout == "HND":
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
else:
raise ValueError(f"Invalid kv_layout: {kv_layout}")
q_lens = torch.randint(1, max_q_len, (batch_size, ), dtype=torch.int32)
q_lens[-1] = max_q_len
q_indptr = torch.cat([
torch.tensor([0], dtype=torch.int32),
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
])
query = torch.randn(torch.sum(q_lens).item(),
num_qo_heads,
head_size,
dtype=dtype)
if q_quant_dtype == FP8_DTYPE:
query, q_scale = to_float8(query)
ref_query = query.to(dtype) * q_scale
else:
q_scale = 1.0
ref_query = query
kv_lens = torch.randint(0, max_kv_len, (batch_size, ), dtype=torch.int32)
kv_lens[-1] = max_kv_len
seq_lens = kv_lens + q_lens
max_seq_len = torch.max(seq_lens).item()
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
if kv_quant_dtype == FP8_DTYPE:
kv_cache, kv_scale = to_float8(kv_cache)
ref_kv_cache = kv_cache.to(dtype) * kv_scale
else:
kv_scale = 1.0
ref_kv_cache = kv_cache
k_scale = v_scale = kv_scale
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS,
(batch_size, max_num_blocks_per_seq),
dtype=torch.int32)
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(batch_size):
seq_len = seq_lens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
# Baseline Prefill
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout)
wrapper.plan(q_indptr,
kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_qo_heads,
num_kv_heads,
head_size,
block_size,
causal=True,
sm_scale=scale,
sm_scale=sm_scale,
q_data_type=dtype,
kv_data_type=kv_cache_dtype,
kv_data_type=dtype,
logits_soft_cap=soft_cap)
output = torch.empty(query.shape, dtype=dtype)
wrapper.run(query,
key_value_cache,
k_scale=k_scale,
v_scale=v_scale,
out=output)
output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(ref_query, ref_kv_cache, out=output)
o_scale = 1.0
if o_quant_dtype == FP8_DTYPE:
_, o_scale = to_float8(output)
# TRTLLM Decode
output_trtllm = torch.empty(query.shape, dtype=dtype)
# TRTLLM Prefill
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
flashinfer.prefill.trtllm_batch_context_with_kv_cache(
query=query.contiguous(),
kv_cache=key_value_cache,
query=query,
kv_cache=kv_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables,
seq_lens=seq_lens,
max_q_len=max_q_len,
max_kv_len=max_seq_len,
bmm1_scale=k_scale * scale,
bmm2_scale=v_scale,
batch_size=num_seqs,
bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale,
batch_size=batch_size,
cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr,
out=output_trtllm,
)
if o_quant_dtype == FP8_DTYPE:
output_trtllm = output_trtllm.to(dtype) * o_scale
torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
rtol, atol = 5e-2, 7e-2
else:
rtol, atol = 1e-2, 1e-2
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - output_trtllm))}"

View File

@ -128,11 +128,17 @@ class Attention(nn.Module):
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
# We also keep the float32 versions of k/v_scale for attention
# backends that don't support tensors (Flashinfer)
# We also keep q/k/v_scale on host (cpu) memory for attention
# backends that require the scales to be on host instead of on device.
# e.g. Flashinfer
self._q_scale_float = 1.0
self._k_scale_float = 1.0
self._v_scale_float = 1.0
# The output scale on host memory. This should be the input scale of
# the quant op after this attention layer.
self._o_scale_float: Optional[float] = None
self.use_mla = use_mla
self.num_heads = num_heads
self.head_size = head_size
@ -291,6 +297,7 @@ class Attention(nn.Module):
self._q_scale.copy_(torch.abs(query).max() / self.q_range)
self._k_scale.copy_(torch.abs(key).max() / self.k_range)
self._v_scale.copy_(torch.abs(value).max() / self.v_range)
self._q_scale_float = self._q_scale.item()
self._k_scale_float = self._k_scale.item()
self._v_scale_float = self._v_scale.item()
# We only calculate the scales once
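These host-side float copies are what the FlashInfer backend later folds into the kernel's fused scalars, so no device-to-host sync is needed on the hot path; the hunks further down do, in effect:

if self.bmm1_scale is None:
    self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float
                       * self.scale)
if self.bmm2_scale is None:
    self.bmm2_scale = layer._v_scale_float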

View File

@ -9,7 +9,7 @@ from torch._subclasses.fake_tensor import (FakeTensorMode,
unset_fake_temporarily)
from vllm.attention import Attention
from vllm.config import VllmConfig
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
@ -18,23 +18,32 @@ from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
RESHAPE_OP = torch.ops.aten.reshape.default
class AttentionStaticQuantPattern:
"""
Fusion for Attention+StaticQuant.
Only triggers when the attention implementation returns True in
`fused_output_quant_supported()`. If the pattern is found, the StaticQuant
op will be removed from the graph, and its scale will be passed into
Attention op as the `output_scale` argument.
"""
def __init__(
self,
layer_name: str,
num_heads: int,
head_size: int,
layer: Attention,
quant_dtype: torch.dtype,
symmetric=True,
):
self.layer_name = layer_name
self.num_heads = num_heads
self.head_size = head_size
self.layer = layer
self.layer_name = layer.layer_name
self.num_heads = layer.num_heads
self.head_size = layer.head_size
self.quant_dtype = quant_dtype
self.quant_key = QuantKey(dtype=quant_dtype,
static=True,
@ -48,11 +57,10 @@ class AttentionStaticQuantPattern:
kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
return torch.empty(*args, **kwargs)
def register_if_supported(self, pm_pass: PatternMatcherPass,
layer: Attention):
if layer.impl.fused_output_quant_supported(self.quant_dtype,
self.quant_key.static,
self.quant_key.group_shape):
def register_if_supported(self, pm_pass: PatternMatcherPass):
if self.layer.impl.fused_output_quant_supported(
self.quant_dtype, self.quant_key.static,
self.quant_key.group_shape):
self._register(pm_pass)
def _register(self, pm_pass: PatternMatcherPass):
@ -60,19 +68,15 @@ class AttentionStaticQuantPattern:
def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
output_attn: torch.Tensor, output_quant: torch.Tensor,
scale: torch.Tensor):
view_7 = RESHAPE_OP(output_attn,
[-1, self.num_heads, self.head_size])
at1 = auto_functionalized(ATTN_OP,
query=q,
key=k,
value=v,
output=view_7,
output=output_attn,
layer_name=self.layer_name,
output_scale=None)
attn_out_view = RESHAPE_OP(at1[1],
[-1, self.num_heads * self.head_size])
at2 = auto_functionalized(self.QUANT_OP,
result=output_quant,
input=attn_out_view,
@ -82,17 +86,19 @@ class AttentionStaticQuantPattern:
def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
output_attn: torch.Tensor, output_quant: torch.Tensor,
scale: torch.Tensor):
view_7 = RESHAPE_OP(output_quant,
[-1, self.num_heads, self.head_size])
# attn output in quant_dtype
output_attn = torch.ops.aten.full.default(
[q.shape[0], self.num_heads, self.head_size],
0.0,
dtype=self.quant_dtype,
device=q.device)
at1 = auto_functionalized(ATTN_OP,
query=q,
key=k,
value=v,
output=view_7,
output=output_attn,
layer_name=self.layer_name,
output_scale=scale)
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
# Need custom fake mode, otherwise tracing happens with real tensors.
@ -102,7 +108,7 @@ class AttentionStaticQuantPattern:
empty_bf16(5, self.num_heads, self.head_size), # q
empty_bf16(5, self.num_heads, self.head_size), # k
empty_bf16(5, self.num_heads, self.head_size), # v
empty_bf16(5, self.num_heads * self.head_size), # attn_output
empty_bf16(5, self.num_heads, self.head_size), # attn_output
self.empty_quant(5, self.num_heads *
self.head_size), # quant_output
empty_fp32(1, 1) # scale
@ -140,27 +146,30 @@ class AttnFusionPass(VllmInductorPass):
def __init__(self, config: VllmConfig):
super().__init__(config)
self.static_fwd_ctx = config.compilation_config.static_forward_context
self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")
for key, layer in self.static_fwd_ctx.items():
pattern = AttentionStaticQuantPattern(key, layer.num_heads,
layer.head_size,
current_platform.fp8_dtype())
pattern.register_if_supported(self.patterns, layer)
if len(self.static_fwd_ctx) == 0:
attn_layers = get_layers_from_vllm_config(config, Attention)
for layer_name, layer in attn_layers.items():
pattern = AttentionStaticQuantPattern(layer, FP8_DTYPE)
pattern.register_if_supported(self.patterns)
if len(attn_layers) == 0:
logger.warning(
"Attention + quant fusion is enabled, but "
"CompilationConfig.static_forward_context is empty. "
"Cannot access attention layers so no fusion "
"patterns were registered.")
"Attention + quant fusion is enabled, but no attention layers "
"were found in CompilationConfig.static_forward_context "
"so no fusion patterns were registered.")
def __call__(self, graph: torch.fx.graph.Graph) -> None:
self.begin()
self.dump_graph(graph, "before_attn_fusion")
count = self.patterns.apply(graph)
# TODO: Move this to pass_manager.py once the broken fx graph issue
# has been resolved.
# See https://github.com/vllm-project/vllm/issues/23091
graph.eliminate_dead_code()
logger.debug("Fused quantization onto %s attention nodes", count)
self.dump_graph(graph, "after_attn_fusion")
self.end_and_log()
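Enabling this pass follows the pattern used by the new test: switch it on through the compilation pass config and let AttnFusionPass discover attention layers from the config. A short usage sketch (assumes a populated VllmConfig as constructed in the test above):

vllm_config.compilation_config.pass_config = PassConfig(
    enable_attn_fusion=True, enable_noop=True)
attn_pass = AttnFusionPass(vllm_config)  # registers one pattern per Attention layer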

View File

@ -174,21 +174,30 @@ def supports_trtllm_attention() -> tuple[bool, Optional[str]]:
def use_trtllm_attention(
num_qo_heads: int,
num_kv_heads: int,
num_tokens: int,
max_seq_len: int,
kv_cache_dtype: str,
num_qo_heads: Optional[int],
num_kv_heads: Optional[int],
attn_head_size: Optional[int],
q_dtype: torch.dtype,
is_prefill: bool,
has_sinks: bool = False,
) -> bool:
use_trtllm, env_value = supports_trtllm_attention()
if not use_trtllm:
return False
# Check if the dimensions are supported by TRTLLM decode attention
if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None
or num_qo_heads % num_kv_heads != 0):
if num_qo_heads % num_kv_heads != 0:
return False
# Must use TRTLLM attention if query is FP8 quantized
if q_dtype == current_platform.fp8_dtype():
logger.info_once("Using TRTLLM attention (query is quantized).")
return True
# TRTLLM prefill attention does not support FP8 kv cache with
# non-quantized query
if is_prefill and kv_cache_dtype.startswith("fp8"):
return False
# If sinks are being used, we must use TRTLLM attention as it's
@ -290,6 +299,7 @@ __all__ = [
"has_flashinfer_moe",
"has_flashinfer_cutlass_fused_moe",
"has_nvidia_artifactory",
"supports_trtllm_attention",
"use_trtllm_attention",
"flashinfer_scaled_fp4_mm",
]
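The new gating in use_trtllm_attention can be read as a short decision list: bail out if TRTLLM support or head divisibility is missing, force TRTLLM when the query is already FP8, refuse prefill with an FP8 KV cache but a non-FP8 query, and require TRTLLM when sinks are used. A condensed sketch of that logic (the `_sketch` suffix marks it as illustrative; the real function also applies token-count and sequence-length heuristics elided by the hunk):

def use_trtllm_attention_sketch(num_qo_heads, num_kv_heads, kv_cache_dtype,
                                q_dtype, is_prefill, has_sinks=False):
    supported, _ = supports_trtllm_attention()
    if not supported or num_qo_heads % num_kv_heads != 0:
        return False
    if q_dtype == current_platform.fp8_dtype():
        return True    # an FP8 query can only be served by the TRTLLM kernels
    if is_prefill and kv_cache_dtype.startswith("fp8"):
        return False   # TRTLLM prefill: FP8 KV cache requires an FP8 query
    if has_sinks:
        return True    # attention sinks are only supported on the TRTLLM path
    # Placeholder: the remaining token-count / seq-len heuristics are elided here.
    return False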

View File

@ -15,12 +15,17 @@ from flashinfer.decode import (_get_range_buf, get_seq_lens,
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType)
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.platforms import current_platform
from vllm.utils import cdiv, is_pin_memory_available
from vllm.utils.flashinfer import use_trtllm_attention
from vllm.utils.flashinfer import (supports_trtllm_attention,
use_trtllm_attention)
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
# yapf conflicts with isort for this block
# yapf: disable
@ -35,6 +40,8 @@ from vllm.v1.kv_cache_interface import AttentionSpec
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__)
@ -519,22 +526,27 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
else:
kv_cache_dtype = self.kv_cache_spec.dtype
num_qo_heads = self.vllm_config.model_config.get_num_attention_heads(
self.vllm_config.parallel_config)
config = self.vllm_config
num_qo_heads = config.model_config.get_num_attention_heads(
config.parallel_config)
num_kv_heads = self.kv_cache_spec.num_kv_heads
head_dim = self.kv_cache_spec.head_size
# Check if any layer uses sinks (requires TRTLLM attention)
has_sinks = self.global_hyperparameters.has_sinks
# currently prefill trtllm attention does not support fp8 kv cache
prefill_use_trtllm = not cache_dtype.startswith("fp8") \
and use_trtllm_attention(
num_prefill_tokens, max_seq_len, cache_dtype,
num_qo_heads, num_kv_heads, head_dim, has_sinks)
# Insert FP8 quant for query if FP8 kv cache and attn fusion enabled
q_dtype = config.model_config.dtype
enable_fusion = config.compilation_config.pass_config.enable_attn_fusion
if cache_dtype.startswith("fp8") and enable_fusion:
q_dtype = kv_cache_dtype
prefill_use_trtllm = use_trtllm_attention(
num_qo_heads, num_kv_heads, num_prefill_tokens, max_seq_len,
cache_dtype, q_dtype, is_prefill=True, has_sinks=has_sinks)
decode_use_trtllm = use_trtllm_attention(
num_decode_tokens, max_seq_len, cache_dtype,
num_qo_heads, num_kv_heads, head_dim, has_sinks)
num_qo_heads, num_kv_heads, num_decode_tokens, max_seq_len,
cache_dtype, q_dtype, is_prefill=False, has_sinks=has_sinks)
attn_metadata = FlashInferMetadata(
num_actual_tokens=num_actual_tokens,
@ -548,7 +560,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
head_dim=head_dim,
page_size=page_size,
kv_data_type=kv_cache_dtype,
q_data_type=self.vllm_config.model_config.dtype,
q_data_type=q_dtype,
slot_mapping=common_attn_metadata.slot_mapping,
max_q_len=max_q_len,
max_seq_len=max_seq_len,
@ -622,6 +634,8 @@ class FlashInferImpl(AttentionImpl):
self.sliding_window = (-1, -1)
else:
self.sliding_window = (sliding_window - 1, 0)
self.window_left = (self.sliding_window[0]
if self.sliding_window is not None else -1)
self.kv_cache_dtype = kv_cache_dtype
self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
@ -644,6 +658,19 @@ class FlashInferImpl(AttentionImpl):
)
self.sinks = sinks
self.support_trtllm_attn = (supports_trtllm_attention() and
num_heads % num_kv_heads == 0)
self.bmm1_scale: Optional[float] = None
self.bmm2_scale: Optional[float] = None
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
group_shape: GroupShape):
supported_quant_type = (dtype == FP8_DTYPE and static and
group_shape == GroupShape.PER_TENSOR)
return (self.support_trtllm_attn
and self.kv_cache_dtype.startswith("fp8")
and supported_quant_type)
def forward(
self,
layer: torch.nn.Module,
@ -672,15 +699,42 @@ class FlashInferImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashInferImpl")
if attn_metadata is None:
# Profiling run.
return output
if self.bmm1_scale is None:
self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float *
self.scale)
if self.bmm2_scale is None:
self.bmm2_scale = layer._v_scale_float
# The attn+quant fusion happens when output_scale is provided.
if output_scale is None:
assert attn_metadata.q_data_type != FP8_DTYPE, \
"Query can only be FP8 if output fusion happened."
else:
assert attn_metadata.q_data_type == FP8_DTYPE, \
"Query must be FP8 when attn+quant fusion happened."
assert (attn_metadata.prefill_use_trtllm and
attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn"
assert output.dtype == FP8_DTYPE, \
"Output must be FP8 when attn+quant fusion happened."
# The TRTLLM attn kernel requires the output scale as a host scalar; store
# it during the warmup run (CUDA graphs not enabled) so it can be reused.
if layer._o_scale_float is None:
layer._o_scale_float = output_scale.cpu().item()
self.bmm2_scale = self.bmm2_scale / layer._o_scale_float
# Insert FP8 quant for query
num_tokens, num_heads, head_size = query.shape
query, _ = ops.scaled_fp8_quant(
query.reshape((num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))
# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
@ -718,9 +772,6 @@ class FlashInferImpl(AttentionImpl):
self.kv_cache_dtype)
kv_cache = kv_cache.view(torch_dtype)
window_left = (self.sliding_window[0]
if self.sliding_window is not None else -1)
# Inputs and outputs may be padded for CUDA graphs
query = query[:num_actual_tokens]
output_padded = output
@ -748,7 +799,7 @@ class FlashInferImpl(AttentionImpl):
if not attn_metadata.prefill_use_trtllm:
assert prefill_wrapper._causal
assert prefill_wrapper._window_left == window_left
assert prefill_wrapper._window_left == self.window_left
assert prefill_wrapper._logits_soft_cap == (
self.logits_soft_cap or 0.0)
assert prefill_wrapper._sm_scale == self.scale
@ -783,12 +834,12 @@ class FlashInferImpl(AttentionImpl):
seq_lens=seq_lens_prefill,
max_q_len=attn_metadata.max_q_len,
max_kv_len=attn_metadata.max_seq_len,
bmm1_scale=layer._k_scale_float * self.scale,
bmm2_scale=layer._v_scale_float,
bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale,
batch_size=attn_metadata.num_prefills,
cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
window_left=window_left,
window_left=self.window_left,
sinks=self.sinks,
out=output[num_decode_tokens:],
)
@ -800,7 +851,7 @@ class FlashInferImpl(AttentionImpl):
assert decode_wrapper is not None
if not attn_metadata.decode_use_trtllm:
assert decode_wrapper._window_left == window_left
assert decode_wrapper._window_left == self.window_left
assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
or 0.0)
assert decode_wrapper._sm_scale == self.scale
@ -815,8 +866,8 @@ class FlashInferImpl(AttentionImpl):
# decode_query may be non-contiguous
decode_query = decode_query.contiguous()
workspace_buffer = decode_wrapper._float_workspace_buffer
block_tables_decode = attn_metadata.block_table_tensor[:
num_decode_tokens]
block_tables_decode = attn_metadata.\
block_table_tensor[:num_decode_tokens]
seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
@ -834,9 +885,9 @@ class FlashInferImpl(AttentionImpl):
block_tables=block_tables_decode,
seq_lens=seq_lens_decode,
max_seq_len=attn_metadata.max_seq_len,
bmm1_scale=layer._k_scale_float * self.scale,
bmm2_scale=layer._v_scale_float,
window_left=window_left,
bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale,
window_left=self.window_left,
sinks=self.sinks,
out=output[:num_decode_tokens],
)
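Taken together, the fused-output-quant path in FlashInferImpl.forward reduces to the following condensed restatement of the hunks above (non-fused branch, asserts, and reshape-to-cache details omitted):

# Fold static scales into host floats once; reused across steps.
if self.bmm1_scale is None:
    self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float
                       * self.scale)
if self.bmm2_scale is None:
    self.bmm2_scale = layer._v_scale_float

if output_scale is not None:
    # Attn+quant fusion: capture the output scale as a host scalar during the
    # warmup run (CUDA graphs not yet capturing) and fold it into bmm2_scale.
    if layer._o_scale_float is None:
        layer._o_scale_float = output_scale.cpu().item()
        self.bmm2_scale = self.bmm2_scale / layer._o_scale_float
    # Quantize the query to FP8 so the TRTLLM kernels run end-to-end in FP8.
    num_tokens, num_heads, head_size = query.shape
    query, _ = ops.scaled_fp8_quant(
        query.reshape((num_tokens, num_heads * head_size)).contiguous(),
        layer._q_scale)
    query = query.reshape((num_tokens, num_heads, head_size))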