Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)

[NVIDIA] Support Flashinfer TRTLLM FP8-q/kv/out Attention Kernel (#21716)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
@@ -631,6 +631,7 @@ steps:
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
- vllm/v1/attention/backends/flashinfer.py
- vllm/compilation/fusion.py
- vllm/compilation/fusion_attn.py
commands:
- nvidia-smi
- python3 examples/offline_inference/basic/chat.py
@@ -647,6 +648,7 @@ steps:
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
# Fusion
- pytest -v -s tests/compile/test_fusion_all_reduce.py
- pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern

##### 1 GPU test #####
##### multi gpus test #####
@@ -3,16 +3,14 @@

import csv
import os
import random
from datetime import datetime
from typing import Optional

import flashinfer
import torch

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8

# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
FP8_DTYPE = torch.float8_e4m3fn


def to_float8(x, dtype=torch.float8_e4m3fn):
@@ -26,65 +24,107 @@ def to_float8(x, dtype=torch.float8_e4m3fn):

@torch.no_grad()
def benchmark_decode(
num_seqs,
max_seq_len,
page_size=16,
dtype=torch.bfloat16,
kv_layout="HND",
num_kv_heads=8,
kv_cache_dtype="auto",
head_dim=128,
warmup=10,
trials=20,
dtype: torch.dtype,
quant_dtypes: tuple[
Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
],
batch_size: int,
max_seq_len: int,
num_heads: tuple[int, int] = (64, 8),
head_size: int = 128,
kv_layout: str = "HND",
block_size: int = 16,
warmup: int = 10,
trials: int = 20,
):
torch.set_default_device("cuda")
device = "cuda"
torch.manual_seed(0)

HEAD_GRP_SIZE = 8
MAX_SEQ_LEN = max_seq_len
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
q_quant_dtype = q_quant_dtype or dtype
kv_quant_dtype = kv_quant_dtype or dtype
o_quant_dtype = o_quant_dtype or dtype

num_qo_heads, num_kv_heads = num_heads
assert num_qo_heads % num_kv_heads == 0

sm_scale = float(1.0 / (head_size**0.5))

# large number to reduce kv_cache reuse
NUM_BLOCKS = int(256000 / page_size)
NUM_BLOCKS = int(256000 / block_size)

workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device)
kv_cache_shape = None
if kv_layout == "NHD":
kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
elif kv_layout == "HND":
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
else:
raise ValueError(f"Invalid kv_layout: {kv_layout}")

# For decode, batch_size is num_decode_token
num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
sm_scale = float(1.0 / (head_dim**0.5))
q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype)
kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
if q_quant_dtype == FP8_DTYPE:
query, q_scale = to_float8(query)
ref_query = query.to(dtype) * q_scale
else:
q_scale = 1.0
ref_query = query

max_kv_len = max(kv_lens)
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device)
max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size
kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
kv_lens[-1] = max_seq_len

seq_lens = kv_lens
max_seq_len = torch.max(seq_lens).item()

kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
if kv_quant_dtype == FP8_DTYPE:
kv_cache, kv_scale = to_float8(kv_cache)
ref_kv_cache = kv_cache.to(dtype) * kv_scale
else:
kv_scale = 1.0
ref_kv_cache = kv_cache
k_scale = v_scale = kv_scale

max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint(
0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
)
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(batch_size):
seq_len = seq_lens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)

kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype)
k_scale = v_scale = 1.0
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)

if kv_cache_dtype.startswith("fp8"):
kv_cache, _ = to_float8(kv_cache)

output_trtllm = torch.empty(q.shape, dtype=dtype)

# Benchmark TRT decode
def trt_decode():
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
q,
kv_cache,
workspace_buffer,
block_tables,
kv_lens_tensor,
max_kv_len,
bmm1_scale=k_scale * sm_scale,
bmm2_scale=v_scale,
out=output_trtllm,
)
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout,
use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
)
wrapper.plan(
kv_indptr,
kv_indices,
kv_last_page_lens,
num_qo_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
sm_scale=sm_scale,
q_data_type=dtype,
kv_data_type=dtype,
)

def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize()
@@ -101,74 +141,51 @@ def benchmark_decode(
times.append(start.elapsed_time(end))  # ms
return sum(times) / len(times), torch.std(torch.tensor(times))

# TRT Decode
trt_mean, trt_std = time_fn(trt_decode)

kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(num_seqs):
seq_len = kv_lens[i]
assert seq_len > 0
num_blocks = (seq_len + page_size - 1) // page_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % page_size
if kv_last_page_len == 0:
kv_last_page_len = page_size
kv_last_page_lens.append(kv_last_page_len)

kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

output_baseline = torch.empty(q.shape, dtype=dtype)

wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout,
use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
)

wrapper.plan(
kv_indptr,
kv_indices,
kv_last_page_lens,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
"NONE",
q_data_type=dtype,
kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype,
)
o_scale = 1.0
output_baseline = torch.empty(ref_query.shape, dtype=dtype)
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)

def baseline_decode():
return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale, output_baseline)
return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)

def trtllm_decode():
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
query=query,
kv_cache=kv_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables,
seq_lens=seq_lens,
max_seq_len=max_seq_len,
bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale,
out=output_trtllm,
)

baseline_mean, baseline_std = time_fn(baseline_decode)
trtllm_mean, trtllm_std = time_fn(trtllm_decode)

# Calculate percentage speedup (positive means TRT is faster)
speedup_percent = (baseline_mean - trt_mean) / baseline_mean
speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean

print(
f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}"
f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:.3f}\t{trtllm_std.item():.3f}"
f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
)

# Return results for CSV writing
return {
"num_seqs": num_seqs,
"trt_mean": trt_mean,
"trt_std": trt_std.item(),
"batch_size": batch_size,
"trtllm_mean": trtllm_mean,
"trtllm_std": trtllm_std.item(),
"baseline_mean": baseline_mean,
"baseline_std": baseline_std.item(),
"speedup_percent": speedup_percent,
"q_dtype": str(dtype),
"kv_cache_dtype": kv_cache_dtype,
"page_size": page_size,
"q_dtype": str(q_quant_dtype),
"kv_cache_dtype": str(kv_quant_dtype),
"output_dtype": str(o_quant_dtype),
"block_size": block_size,
"num_kv_heads": num_kv_heads,
"head_dim": head_dim,
"head_size": head_size,
"max_seq_len": max_seq_len,
}

@@ -180,17 +197,18 @@ def write_results_to_csv(results, filename=None):
filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"

fieldnames = [
"num_seqs",
"trt_mean",
"trt_std",
"batch_size",
"trtllm_mean",
"trtllm_std",
"baseline_mean",
"baseline_std",
"speedup_percent",
"q_dtype",
"kv_cache_dtype",
"page_size",
"output_dtype",
"block_size",
"num_kv_heads",
"head_dim",
"head_size",
"max_seq_len",
]

@@ -209,45 +227,42 @@ def write_results_to_csv(results, filename=None):


if __name__ == "__main__":
num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
all_results = []

print(
"Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, "
"output_dtype: bfloat16"
)
print(
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
"baseline_std\tspeedup_percent"
)
for max_seq_len in max_seq_lens:
for bs in num_seqs:
result = benchmark_decode(
bs,
max_seq_len,
dtype=torch.bfloat16,
kv_cache_dtype="auto",
)
all_results.append(result)
dtype = torch.bfloat16
quant_dtypes = [
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(None, None, None),
(None, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
]

print(
"Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8, "
"output_dtype: bfloat16"
)
print(
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
"baseline_std\tspeedup_percent"
)
for max_seq_len in max_seq_lens:
for bs in num_seqs:
result = benchmark_decode(
bs,
max_seq_len,
dtype=torch.bfloat16,
kv_cache_dtype="fp8",
)
all_results.append(result)
for quant_dtype in quant_dtypes:
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
q_quant_dtype = q_quant_dtype or dtype
kv_quant_dtype = kv_quant_dtype or dtype
o_quant_dtype = o_quant_dtype or dtype

print(
f"Running benchmark for q_dtype = {q_quant_dtype}, "
f"kv_cache_dtype: {kv_quant_dtype}, "
f"output_dtype: {o_quant_dtype}"
)
print(
"\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
"baseline_std\tspeedup_percent"
)
for max_seq_len in max_seq_lens:
for bs in batch_sizes:
result = benchmark_decode(
dtype=dtype,
quant_dtypes=quant_dtype,
batch_size=bs,
max_seq_len=max_seq_len,
)
all_results.append(result)

# Write all results to CSV
write_results_to_csv(all_results)
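Note: the body of the `to_float8` helper used throughout the decode benchmark above falls outside the diff hunks. Based on its call sites (it returns a quantized tensor plus a reciprocal scale, and the companion test file shows the return line `return x_scl_sat.to(dtype), scale.float().reciprocal()`), a minimal per-tensor symmetric FP8 quantization sketch could look like the following. This is an illustrative reconstruction, not the PR's exact code.

```python
import torch

def to_float8(x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
    """Per-tensor symmetric FP8 quantization (illustrative sketch).

    Returns the quantized tensor and the dequantization scale, so that
    x ~= x_fp8.to(x.dtype) * scale, matching how the benchmark builds
    ref_query = query.to(dtype) * q_scale.
    """
    finfo = torch.finfo(dtype)
    # Choose the scale so the largest |x| maps onto the FP8 max value.
    amax = x.abs().max().clamp(min=1e-12)
    scale = finfo.max / amax
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()
```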
@@ -3,16 +3,14 @@

import csv
import os
import random
from datetime import datetime
from typing import Optional

import flashinfer
import torch

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8

# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
FP8_DTYPE = torch.float8_e4m3fn


def to_float8(x, dtype=torch.float8_e4m3fn):
@@ -26,84 +24,99 @@ def to_float8(x, dtype=torch.float8_e4m3fn):

@torch.no_grad()
def benchmark_prefill(
num_seqs,
max_seq_len,
page_size=16,
dtype=torch.bfloat16,
kv_layout="HND",
num_kv_heads=8,
kv_cache_dtype="auto",
head_dim=128,
warmup=10,
trials=20,
dtype: torch.dtype,
quant_dtypes: tuple[
Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
],
batch_size: int,
max_seq_len: int,
num_heads: tuple[int, int] = (64, 8),
head_size: int = 128,
kv_layout: str = "HND",
block_size: int = 16,
warmup: int = 10,
trials: int = 20,
):
torch.set_default_device("cuda")
torch.manual_seed(0)

HEAD_GRP_SIZE = 8
MAX_SEQ_LEN = max_seq_len
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
q_quant_dtype = q_quant_dtype or dtype
kv_quant_dtype = kv_quant_dtype or dtype
o_quant_dtype = o_quant_dtype or dtype

max_q_len = max_kv_len = max_seq_len

num_qo_heads, num_kv_heads = num_heads
assert num_qo_heads % num_kv_heads == 0

sm_scale = float(1.0 / (head_size**0.5))

# large number to reduce kv_cache reuse
NUM_BLOCKS = int(256000 / page_size)
NUM_BLOCKS = int(256000 / block_size)

workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8)
kv_cache_shape = None
if kv_layout == "NHD":
kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
elif kv_layout == "HND":
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
else:
raise ValueError(f"Invalid kv_layout: {kv_layout}")

num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
sm_scale = float(1.0 / (head_dim**0.5))

q_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
q_lens[-1] = MAX_SEQ_LEN
max_q_len = max(q_lens)
q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32)
q_lens[-1] = max_q_len
q_indptr = torch.cat(
[
torch.tensor([0], dtype=torch.int32),
torch.cumsum(
torch.tensor(q_lens, dtype=torch.int32), dim=0, dtype=torch.int32
),
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
]
)
q = torch.randn(sum(q_lens), num_qo_heads, head_dim, dtype=dtype)

kv_lens = [random.randint(0, MAX_SEQ_LEN) for _ in range(num_seqs)]
kv_lens[-1] = MAX_SEQ_LEN
query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype)
if q_quant_dtype == FP8_DTYPE:
query, q_scale = to_float8(query)
ref_query = query.to(dtype) * q_scale
else:
q_scale = 1.0
ref_query = query

seq_lens = [q_len + kv_len for q_len, kv_len in zip(q_lens, kv_lens)]
max_seq_len = max(seq_lens)
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32)
kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
kv_lens[-1] = max_kv_len

max_num_blocks_per_seq = (max_seq_len + page_size - 1) // page_size
seq_lens = kv_lens + q_lens
max_seq_len = torch.max(seq_lens).item()

kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
if kv_quant_dtype == FP8_DTYPE:
kv_cache, kv_scale = to_float8(kv_cache)
ref_kv_cache = kv_cache.to(dtype) * kv_scale
else:
kv_scale = 1.0
ref_kv_cache = kv_cache
k_scale = v_scale = kv_scale

max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint(
0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
)

kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
kv_cache = torch.randn(size=kv_cache_shape, dtype=dtype)
k_scale = v_scale = 1.0

if kv_cache_dtype.startswith("fp8"):
kv_cache, _ = to_float8(kv_cache)

output_trtllm = torch.empty(q.shape, dtype=dtype)

kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(num_seqs):
for i in range(batch_size):
seq_len = seq_lens[i]
assert seq_len > 0
num_blocks = (seq_len + page_size - 1) // page_size
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % page_size
kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
kv_last_page_len = page_size
kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)

kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

output_baseline = torch.empty(q.shape, dtype=dtype)
workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)

wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout
@@ -115,12 +128,12 @@ def benchmark_prefill(
kv_last_page_lens,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
head_size,
block_size,
causal=True,
sm_scale=sm_scale,
q_data_type=dtype,
kv_data_type=kv_cache.dtype,
kv_data_type=dtype,
)

def time_fn(fn, warmup=10, trials=20):
@@ -138,52 +151,55 @@ def benchmark_prefill(
times.append(start.elapsed_time(end))  # ms
return sum(times) / len(times), torch.std(torch.tensor(times))

def baseline_prefill():
return wrapper.run(
q, kv_cache, k_scale=k_scale, v_scale=v_scale, out=output_baseline
)
o_scale = 1.0
output_baseline = torch.empty(ref_query.shape, dtype=dtype)
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)

def trt_prefill():
def baseline_prefill():
return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)

def trtllm_prefill():
return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
query=q,
query=query,
kv_cache=kv_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables,
seq_lens=seq_lens_tensor,
seq_lens=seq_lens,
max_q_len=max_q_len,
max_kv_len=max_seq_len,
bmm1_scale=k_scale * sm_scale,
bmm2_scale=v_scale,
batch_size=num_seqs,
bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale,
batch_size=batch_size,
cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr,
out=output_trtllm,
)

trt_mean, trt_std = time_fn(trt_prefill)
baseline_mean, baseline_std = time_fn(baseline_prefill)
trtllm_mean, trtllm_std = time_fn(trtllm_prefill)

# Calculate percentage speedup (positive means TRT is faster)
speedup_percent = (baseline_mean - trt_mean) / baseline_mean
speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean

print(
f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.5f}\t{trt_std.item():.5f}"
f"\t{baseline_mean:.5f}\t{baseline_std.item():.5f}\t{speedup_percent:.5f}"
f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:8.3f}\t{trtllm_std.item():8.3f}"
f"\t{baseline_mean:8.3f}\t{baseline_std.item():8.3f}\t{speedup_percent:8.3f}"
)

# Return results for CSV writing
return {
"num_seqs": num_seqs,
"trt_mean": trt_mean,
"trt_std": trt_std.item(),
"batch_size": batch_size,
"trtllm_mean": trtllm_mean,
"trtllm_std": trtllm_std.item(),
"baseline_mean": baseline_mean,
"baseline_std": baseline_std.item(),
"speedup_percent": speedup_percent,
"q_dtype": str(dtype),
"kv_cache_dtype": kv_cache_dtype,
"page_size": page_size,
"q_dtype": str(q_quant_dtype),
"kv_cache_dtype": str(kv_quant_dtype),
"output_dtype": str(o_quant_dtype),
"block_size": block_size,
"num_kv_heads": num_kv_heads,
"head_dim": head_dim,
"head_size": head_size,
"max_seq_len": max_seq_len,
}

@@ -195,17 +211,18 @@ def write_results_to_csv(results, filename=None):
filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"

fieldnames = [
"num_seqs",
"trt_mean",
"trt_std",
"batch_size",
"trtllm_mean",
"trtllm_std",
"baseline_mean",
"baseline_std",
"speedup_percent",
"q_dtype",
"kv_cache_dtype",
"page_size",
"output_dtype",
"block_size",
"num_kv_heads",
"head_dim",
"head_size",
"max_seq_len",
]

@@ -224,27 +241,41 @@ def write_results_to_csv(results, filename=None):


if __name__ == "__main__":
num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
all_results = []

print(
"Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, "
"output_dtype: bfloat16"
)
print(
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
"baseline_std\tspeedup_percent"
)
for max_seq_len in max_seq_lens:
for bs in num_seqs:
result = benchmark_prefill(
bs,
max_seq_len,
dtype=torch.bfloat16,
kv_cache_dtype="auto",
)
all_results.append(result)
dtype = torch.bfloat16
quant_dtypes = [
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(None, None, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
]

for quant_dtype in quant_dtypes:
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
q_quant_dtype = q_quant_dtype or dtype
kv_quant_dtype = kv_quant_dtype or dtype
o_quant_dtype = o_quant_dtype or dtype

print(
f"Running benchmark for q_dtype = {q_quant_dtype}, "
f"kv_cache_dtype: {kv_quant_dtype}, "
f"output_dtype: {o_quant_dtype}"
)
print(
"\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
"baseline_std\tspeedup_percent"
)
for max_seq_len in max_seq_lens:
for bs in batch_sizes:
result = benchmark_prefill(
dtype=dtype,
quant_dtypes=quant_dtype,
batch_size=bs,
max_seq_len=max_seq_len,
)
all_results.append(result)

# Write all results to CSV
write_results_to_csv(all_results)
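Note: both benchmark scripts time the kernels with a `time_fn` helper whose body is mostly outside the diff hunks; the visible fragments show `torch.cuda.synchronize()`, `start.elapsed_time(end)`, and a mean/std return. A hedged sketch of CUDA-event based timing consistent with those fragments:

```python
import torch

def time_fn(fn, warmup: int = 10, trials: int = 20):
    """Time `fn` with CUDA events; returns (mean_ms, std_ms). Sketch only."""
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    times = []
    for _ in range(warmup):
        fn()                      # warm up kernels and autotuning caches
    for _ in range(trials):
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()  # wait so elapsed_time() is valid
        times.append(start.elapsed_time(end))  # milliseconds
    return sum(times) / len(times), torch.std(torch.tensor(times))
```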
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from typing import Optional

import pytest
@@ -7,13 +8,27 @@ import torch._dynamo

from tests.compile.backend import TestBackend
from tests.models.utils import check_outputs_equal
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata)
from vllm import LLM, SamplingParams
from vllm.attention import Attention
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
ModelConfig, PassConfig, SchedulerConfig, VllmConfig,
set_current_vllm_config)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp)
from vllm.platforms import current_platform
from vllm.v1.kv_cache_interface import AttentionSpec

FP8_DTYPE = current_platform.fp8_dtype()

# globals needed for string-import custom Dynamo backend field
backend: Optional[TestBackend] = None
@@ -132,3 +147,235 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,

# Reset backend to make sure llm2 gets released
backend = None


class TestAttentionStaticQuantPatternModel(torch.nn.Module):
"""Test model for AttentionStaticQuantPattern fusion."""

def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int,
kv_cache_dtype: torch.dtype, device: torch.device,
vllm_config: VllmConfig):
super().__init__()
self.num_qo_heads = num_qo_heads
self.num_kv_heads = num_kv_heads
self.head_size = head_size
self.kv_cache_dtype = kv_cache_dtype
self.device = device
self.vllm_config = vllm_config

self.attn = Attention(
num_heads=self.num_qo_heads,
head_size=self.head_size,
scale=1.0 / (self.head_size**0.5),
num_kv_heads=self.num_kv_heads,
cache_config=vllm_config.cache_config,
prefix="model.layers.0.self_attn.attn",
)

self.fp8_linear = Fp8LinearOp(
act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR)
self.wscale = torch.tensor([1.0], dtype=torch.float32)
self.scale = torch.tensor([1.0], dtype=torch.float32)

self.block_size = 16

# Initialize attn MetadataBuilder
self.builder = self.attn.attn_backend.get_builder_cls()(
kv_cache_spec=AttentionSpec(
block_size=self.block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_dtype,
use_mla=False,
),
layer_names=[self.attn.layer_name],
vllm_config=self.vllm_config,
device=self.device,
)

def build_attn_metadata(self, batch_size: int):
"""Initialize attention metadata."""

# Create common attn metadata
batch_spec = BatchSpec(seq_lens=[1] * batch_size,
query_lens=[1] * batch_size)
common_attn_metadata = create_common_attn_metadata(
batch_spec,
self.block_size,
self.device,
arange_block_indices=True)

max_blocks = (max(batch_spec.seq_lens) + self.block_size -
1) // self.block_size
num_blocks = batch_size * max_blocks

# Create dummy KV cache for FlashInfer TRTLLM
# - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
# - HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
# Create kv_cache in HND layout and permute to NHD layout
# (later will be permuted back to HND layout in forward pass)
kv_cache = torch.zeros(num_blocks,
2,
self.num_kv_heads,
self.block_size,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device)
kv_cache = kv_cache.permute(0, 1, 3, 2, 4)
self.attn.kv_cache = [kv_cache]

# Build attn metadata
self.attn_metadata = self.builder.build(
common_prefix_len=0, common_attn_metadata=common_attn_metadata)

return self.attn_metadata

def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
w: torch.Tensor):
"""Forward pass that creates the pattern to be fused."""
attn_output = self.attn(q, k, v)
return self.fp8_linear.apply(input=attn_output,
weight=w,
weight_scale=self.wscale,
input_scale=self.scale)


@pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)])
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("batch_size", [7, 256, 533])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize(
"model_name, quant_key",
[("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", kFp8StaticTensorSym)])
@pytest.mark.parametrize("backend", [_Backend.FLASHINFER])
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(not current_platform.is_device_capability((10, 0)),
reason="Only test on SM100(Blackwell)")
def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
head_size: int, batch_size: int,
dtype: torch.dtype, model_name: str,
quant_key: QuantKey, backend: _Backend,
monkeypatch, dist_init):
"""Test AttentionStaticQuantPattern fusion pass"""

monkeypatch.setenv("VLLM_USE_V1", "1")

device = torch.device("cuda:0")
torch.manual_seed(42)

vllm_config = VllmConfig(
model_config=ModelConfig(
model=model_name,
max_model_len=2048,
),
scheduler_config=SchedulerConfig(max_num_seqs=1024),
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
custom_ops=["+quant_fp8"],
),
cache_config=CacheConfig(cache_dtype="fp8"))

# Create test inputs
hidden_size = num_qo_heads * head_size
q = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
k = torch.randn(batch_size,
num_kv_heads * head_size,
dtype=dtype,
device=device)
v = torch.randn(batch_size,
num_kv_heads * head_size,
dtype=dtype,
device=device)
linear_w = torch.randn(hidden_size, hidden_size).to(FP8_DTYPE).t()

# Mark first dimension as dynamic for realistic testing
torch._dynamo.mark_dynamic(q, 0)
torch._dynamo.mark_dynamic(k, 0)
torch._dynamo.mark_dynamic(v, 0)

# Run model directly without compilation and fusion
vllm_config_unfused = copy.deepcopy(vllm_config)
with set_current_vllm_config(vllm_config_unfused), set_forward_context(
attn_metadata=None, vllm_config=vllm_config_unfused
), global_force_attn_backend_context_manager(backend):
model_unfused = TestAttentionStaticQuantPatternModel(
num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device,
vllm_config_unfused)
model_unfused = model_unfused.to(device)

forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(
batch_size)

# Run model directly without compilation and fusion
result_unfused = model_unfused(q, k, v, linear_w)

# Run model with attn fusion enabled
vllm_config.compilation_config.pass_config = PassConfig(
enable_attn_fusion=True, enable_noop=True)
with set_current_vllm_config(vllm_config), set_forward_context(
attn_metadata=None, vllm_config=vllm_config
), global_force_attn_backend_context_manager(backend):
model_fused = TestAttentionStaticQuantPatternModel(
num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device,
vllm_config)
model_fused = model_fused.to(device)

forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size)

# Create test backend with fusion passes enabled
noop_pass = NoOpEliminationPass(vllm_config)
attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw)
test_backend = TestBackend(noop_pass, attn_pass)

# Compile model with fusion enabled
model_compiled = torch.compile(model_fused,
backend=test_backend,
fullgraph=True)
assert model_compiled.attn._o_scale_float is None
result_fused_1 = model_compiled(q, k, v, linear_w)

# After the 1st round of the forward pass, output quant scale should be
# loaded into the attn layer's _o_scale_float, the 2nd round should
# reuse the loaded _o_scale_float
assert model_compiled.attn._o_scale_float is not None
result_fused_2 = model_compiled(q, k, v, linear_w)
assert model_compiled.attn._o_scale_float is not None

# Check attn fusion support
attn_fusion_supported = [
layer.impl.fused_output_quant_supported(quant_key.dtype,
quant_key.static,
quant_key.group_shape) for key,
layer in vllm_config.compilation_config.static_forward_context.items()
]
if any(attn_fusion_supported):
# Check quantization ops in the graph before and after fusion
test_backend.check_before_ops([QUANT_OPS[quant_key]],
fully_replaced=True)

# Check attention ops in the graph before and after fusion
attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
attn_nodes_post = list(find_op_nodes(ATTN_OP,
test_backend.graph_post_pass))

assert len(attn_nodes_pre) > 0, "Should have attention nodes before fusion"
assert len(attn_nodes_pre) == len(attn_nodes_post), \
"Should have same number of attention nodes before and after fusion"
assert attn_nodes_pre[0].kwargs.get("output_scale") is None, \
"Attention should not have output_scale before fusion"
assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \
"Attention should have output_scale after fusion"

# Check that results are close
torch.testing.assert_close(result_unfused,
result_fused_1,
atol=1e-2,
rtol=1e-2)
torch.testing.assert_close(result_unfused,
result_fused_2,
atol=1e-2,
rtol=1e-2)
@ -13,21 +13,7 @@ if not current_platform.is_device_capability(100):
|
||||
allow_module_level=True)
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
|
||||
# KV Cache Layout for TRT-LLM
|
||||
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
|
||||
|
||||
MAX_Q_LEN = 1024
|
||||
MAX_KV_LEN = 4096
|
||||
BATCH_SIZES = [4, 12]
|
||||
NUM_HEADS = [(16, 16), (40, 8)]
|
||||
HEAD_SIZES = [128]
|
||||
BLOCK_SIZES = [16]
|
||||
KV_LAYOUTS = ["HND"]
|
||||
DTYPES = [torch.bfloat16]
|
||||
KV_CACHE_DTYPES = [None, current_platform.fp8_dtype()]
|
||||
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
|
||||
SOFT_CAPS = [None, 50.0]
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||
@ -39,42 +25,59 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||
return x_scl_sat.to(dtype), scale.float().reciprocal()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
DTYPE = [torch.bfloat16]
|
||||
QUANT_DTYPES = [
|
||||
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
|
||||
(None, None, None),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
|
||||
]
|
||||
BATCH_SIZE = [4, 12]
|
||||
MAX_SEQ_LENS = [(1024, 4096)]
|
||||
NUM_HEADS = [(64, 8), (40, 8)]
|
||||
HEAD_SIZE = [128]
|
||||
KV_LAYOUT = ["HND"] # currently only HND is supported
|
||||
BLOCK_SIZE = [16]
|
||||
SOFT_CAP = [None, 50.0]
|
||||
|
||||
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPE)
|
||||
@pytest.mark.parametrize("quant_dtypes", QUANT_DTYPES)
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
|
||||
@pytest.mark.parametrize("max_seq_lens", MAX_SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("kv_layout", KV_LAYOUTS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZE)
|
||||
@pytest.mark.parametrize("kv_layout", KV_LAYOUT)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
|
||||
@pytest.mark.parametrize("soft_cap", SOFT_CAP)
|
||||
@torch.inference_mode
|
||||
def test_flashinfer_trtllm_decode_with_baseline(
|
||||
dtype: torch.dtype,
|
||||
quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype],
|
||||
Optional[torch.dtype]],
|
||||
batch_size: int,
|
||||
max_seq_lens: tuple[int, int],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
kv_layout: str,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[torch.dtype],
|
||||
block_size: int,
|
||||
soft_cap: Optional[float],
|
||||
) -> None:
|
||||
kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
kv_lens = torch.randint(1, MAX_KV_LEN, (batch_size, ), dtype=torch.int32)
|
||||
kv_lens[-1] = MAX_KV_LEN
|
||||
max_kv_len = torch.max(kv_lens).item()
|
||||
num_seqs = len(kv_lens)
|
||||
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
|
||||
q_quant_dtype = q_quant_dtype or dtype
|
||||
kv_quant_dtype = kv_quant_dtype or dtype
|
||||
o_quant_dtype = o_quant_dtype or dtype
|
||||
|
||||
num_query_heads = num_heads[0]
|
||||
num_kv_heads = num_heads[1]
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
_, max_kv_len = max_seq_lens
|
||||
|
||||
scale = head_size**-0.5
|
||||
num_qo_heads, num_kv_heads = num_heads
|
||||
assert num_qo_heads % num_kv_heads == 0
|
||||
|
||||
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
|
||||
sm_scale = float(1.0 / (head_size**0.5))
|
||||
|
||||
kv_cache_shape = None
|
||||
if kv_layout == "NHD":
|
||||
@ -83,156 +86,39 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
|
||||
else:
|
||||
raise ValueError(f"Invalid kv_layout: {kv_layout}")
|
||||
key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
kv_scale = 1.0
|
||||
if kv_cache_dtype is current_platform.fp8_dtype():
|
||||
key_value_cache, kv_scale = to_float8(key_value_cache,
|
||||
current_platform.fp8_dtype())
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
NUM_BLOCKS,
|
||||
(num_seqs, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
k_scale = v_scale = kv_scale
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(num_seqs):
|
||||
seq_len = kv_lens[i]
|
||||
assert seq_len > 0
|
||||
num_blocks = (seq_len + block_size - 1) // block_size
|
||||
kv_indices.extend(block_tables[i, :num_blocks])
|
||||
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||
kv_last_page_len = seq_len % block_size
|
||||
if kv_last_page_len == 0:
|
||||
kv_last_page_len = block_size
|
||||
kv_last_page_lens.append(kv_last_page_len)
|
||||
query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
|
||||
if q_quant_dtype == FP8_DTYPE:
|
||||
query, q_scale = to_float8(query)
|
||||
ref_query = query.to(dtype) * q_scale
|
||||
else:
|
||||
q_scale = 1.0
|
||||
ref_query = query
|
||||
|
||||
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
kv_lens = torch.randint(1, max_kv_len, (batch_size, ), dtype=torch.int32)
|
||||
kv_lens[-1] = max_kv_len
|
||||
|
||||
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
|
||||
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer,
|
||||
kv_layout,
|
||||
use_tensor_cores=((num_query_heads // num_kv_heads) > 4))
|
||||
wrapper.plan(kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_query_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
"NONE",
|
||||
sm_scale=scale,
|
||||
q_data_type=dtype,
|
||||
kv_data_type=kv_cache_dtype,
|
||||
logits_soft_cap=soft_cap)
|
||||
|
||||
output = torch.empty(query.shape, dtype=dtype)
|
||||
wrapper.run(query,
|
||||
key_value_cache,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
out=output)
|
||||
|
||||
# TRTLLM Decode
|
||||
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
|
||||
output_trtllm = torch.empty(query.shape, dtype=dtype)
|
||||
flashinfer.decode.trtllm_batch_decode_with_kv_cache(
|
||||
query=query.contiguous(),
|
||||
kv_cache=key_value_cache,
|
||||
workspace_buffer=workspace_buffer,
|
||||
block_tables=block_tables,
|
||||
seq_lens=kv_lens_tensor,
|
||||
max_seq_len=max_kv_len,
|
||||
bmm1_scale=k_scale * scale,
|
||||
bmm2_scale=v_scale,
|
||||
out=output_trtllm,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \
|
||||
f"{torch.max(torch.abs(output - output_trtllm))}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("kv_layout", KV_LAYOUTS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", [None])
|
||||
@torch.inference_mode
|
||||
def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
batch_size: int,
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
kv_layout: str,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[torch.dtype],
|
||||
soft_cap: Optional[float],
|
||||
) -> None:
|
||||
kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
|
||||
if dtype != kv_cache_dtype:
|
||||
pytest.skip(f"Not supported dtype({dtype}) with "
|
||||
"kv_cache_dtype({kv_cache_dtype})")
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
q_lens = torch.randint(1, MAX_Q_LEN, (batch_size, ), dtype=torch.int32)
|
||||
q_lens[-1] = MAX_Q_LEN
|
||||
max_q_len = torch.max(q_lens).item()
|
||||
q_indptr = torch.cat([
|
||||
torch.tensor([0], dtype=torch.int32),
|
||||
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
|
||||
])
|
||||
|
||||
kv_lens = torch.randint(0, MAX_KV_LEN, (batch_size, ), dtype=torch.int32)
|
||||
kv_lens[-1] = MAX_KV_LEN
|
||||
|
||||
seq_lens = kv_lens + q_lens
|
||||
seq_lens = kv_lens
|
||||
max_seq_len = torch.max(seq_lens).item()
|
||||
num_seqs = len(seq_lens)
|
||||
|
||||
num_query_heads = num_heads[0]
|
||||
num_kv_heads = num_heads[1]
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
|
||||
scale = head_size**-0.5
|
||||
|
||||
query = torch.randn(torch.sum(q_lens).item(),
|
||||
num_query_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
|
||||
kv_cache_shape = None
|
||||
if kv_layout == "NHD":
|
||||
kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
|
||||
elif kv_layout == "HND":
|
||||
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
|
||||
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
if kv_quant_dtype == FP8_DTYPE:
|
||||
kv_cache, kv_scale = to_float8(kv_cache)
|
||||
ref_kv_cache = kv_cache.to(dtype) * kv_scale
|
||||
else:
|
||||
raise ValueError(f"Invalid kv_layout: {kv_layout}")
|
||||
key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
kv_scale = 1.0
|
||||
if kv_cache_dtype is current_platform.fp8_dtype():
|
||||
key_value_cache, kv_scale = to_float8(key_value_cache,
|
||||
current_platform.fp8_dtype())
|
||||
kv_scale = 1.0
|
||||
ref_kv_cache = kv_cache
|
||||
k_scale = v_scale = kv_scale
|
||||
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
NUM_BLOCKS,
|
||||
(num_seqs, max_num_blocks_per_seq),
|
||||
(batch_size, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
k_scale = v_scale = kv_scale
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(num_seqs):
|
||||
for i in range(batch_size):
|
||||
seq_len = seq_lens[i]
|
||||
assert seq_len > 0
|
||||
num_blocks = (seq_len + block_size - 1) // block_size
|
||||
@ -246,48 +132,206 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
|
||||
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
|
||||
|
||||
# Baseline Decode
|
||||
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer,
|
||||
kv_layout,
|
||||
use_tensor_cores=((num_qo_heads // num_kv_heads) > 4))
|
||||
wrapper.plan(kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
"NONE",
|
||||
sm_scale=sm_scale,
|
||||
q_data_type=dtype,
|
||||
kv_data_type=dtype,
|
||||
logits_soft_cap=soft_cap)
|
||||
|
||||
output = torch.empty(ref_query.shape, dtype=dtype)
|
||||
wrapper.run(ref_query, ref_kv_cache, out=output)
|
||||
o_scale = 1.0
|
||||
if o_quant_dtype == FP8_DTYPE:
|
||||
_, o_scale = to_float8(output)
|
||||
|
||||
# TRTLLM Decode
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
flashinfer.decode.trtllm_batch_decode_with_kv_cache(
|
||||
query=query,
|
||||
kv_cache=kv_cache,
|
||||
workspace_buffer=workspace_buffer,
|
||||
block_tables=block_tables,
|
||||
seq_lens=seq_lens,
|
||||
max_seq_len=max_seq_len,
|
||||
bmm1_scale=q_scale * k_scale * sm_scale,
|
||||
bmm2_scale=v_scale / o_scale,
|
||||
out=output_trtllm,
|
||||
)
|
||||
if o_quant_dtype == FP8_DTYPE:
|
||||
output_trtllm = output_trtllm.to(dtype) * o_scale
|
||||
|
||||
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
|
||||
rtol, atol = 5e-2, 7e-2
|
||||
else:
|
||||
rtol, atol = 1e-2, 1e-2
|
||||
|
||||
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
|
||||
f"{torch.max(torch.abs(output - output_trtllm))}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPE)
|
||||
@pytest.mark.parametrize("quant_dtypes", QUANT_DTYPES)
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
|
||||
@pytest.mark.parametrize("max_seq_lens", MAX_SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZE)
|
||||
@pytest.mark.parametrize("kv_layout", KV_LAYOUT)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
|
||||
@pytest.mark.parametrize("soft_cap", [None])
|
||||
@torch.inference_mode
|
||||
def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
dtype: torch.dtype,
|
||||
quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype],
|
||||
Optional[torch.dtype]],
|
||||
batch_size: int,
|
||||
max_seq_lens: tuple[int, int],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
kv_layout: str,
|
||||
block_size: int,
|
||||
soft_cap: Optional[float],
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
|
||||
q_quant_dtype = q_quant_dtype or dtype
|
||||
kv_quant_dtype = kv_quant_dtype or dtype
|
||||
o_quant_dtype = o_quant_dtype or dtype
|
||||
|
||||
max_q_len, max_kv_len = max_seq_lens
|
||||
|
||||
num_qo_heads, num_kv_heads = num_heads
|
||||
assert num_qo_heads % num_kv_heads == 0
|
||||
|
||||
sm_scale = float(1.0 / (head_size**0.5))
|
||||
|
||||
kv_cache_shape = None
|
||||
if kv_layout == "NHD":
|
||||
kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
|
||||
elif kv_layout == "HND":
|
||||
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
|
||||
else:
|
||||
raise ValueError(f"Invalid kv_layout: {kv_layout}")
|
||||
|
||||
q_lens = torch.randint(1, max_q_len, (batch_size, ), dtype=torch.int32)
|
||||
q_lens[-1] = max_q_len
|
||||
q_indptr = torch.cat([
|
||||
torch.tensor([0], dtype=torch.int32),
|
||||
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
|
||||
])
|
||||
|
||||
query = torch.randn(torch.sum(q_lens).item(),
|
||||
num_qo_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
if q_quant_dtype == FP8_DTYPE:
|
||||
query, q_scale = to_float8(query)
|
||||
ref_query = query.to(dtype) * q_scale
|
||||
else:
|
||||
q_scale = 1.0
|
||||
ref_query = query
|
||||
|
||||
kv_lens = torch.randint(0, max_kv_len, (batch_size, ), dtype=torch.int32)
|
||||
kv_lens[-1] = max_kv_len
|
||||
|
||||
seq_lens = kv_lens + q_lens
|
||||
max_seq_len = torch.max(seq_lens).item()
|
||||
|
||||
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
if kv_quant_dtype == FP8_DTYPE:
|
||||
kv_cache, kv_scale = to_float8(kv_cache)
|
||||
ref_kv_cache = kv_cache.to(dtype) * kv_scale
|
||||
else:
|
||||
kv_scale = 1.0
|
||||
ref_kv_cache = kv_cache
|
||||
k_scale = v_scale = kv_scale
|
||||
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
NUM_BLOCKS,
|
||||
(batch_size, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(batch_size):
|
||||
seq_len = seq_lens[i]
|
||||
assert seq_len > 0
|
||||
num_blocks = (seq_len + block_size - 1) // block_size
|
||||
kv_indices.extend(block_tables[i, :num_blocks])
|
||||
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||
kv_last_page_len = seq_len % block_size
|
||||
if kv_last_page_len == 0:
|
||||
kv_last_page_len = block_size
|
||||
kv_last_page_lens.append(kv_last_page_len)
|
||||
|
||||
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
|
||||
|
||||
# Baseline Prefill
|
||||
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
|
||||
workspace_buffer, kv_layout)
|
||||
wrapper.plan(q_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_query_heads,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
causal=True,
|
||||
sm_scale=scale,
|
||||
sm_scale=sm_scale,
|
||||
q_data_type=dtype,
|
||||
kv_data_type=kv_cache_dtype,
|
||||
kv_data_type=dtype,
|
||||
logits_soft_cap=soft_cap)
|
||||
|
||||
output = torch.empty(query.shape, dtype=dtype)
|
||||
wrapper.run(query,
|
||||
key_value_cache,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
out=output)
|
||||
output = torch.empty(ref_query.shape, dtype=dtype)
|
||||
wrapper.run(ref_query, ref_kv_cache, out=output)
|
||||
o_scale = 1.0
|
||||
if o_quant_dtype == FP8_DTYPE:
|
||||
_, o_scale = to_float8(output)
|
||||
|
||||
# TRTLLM Decode
|
||||
output_trtllm = torch.empty(query.shape, dtype=dtype)
|
||||
# TRTLLM Prefill
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
flashinfer.prefill.trtllm_batch_context_with_kv_cache(
|
||||
query=query.contiguous(),
|
||||
kv_cache=key_value_cache,
|
||||
query=query,
|
||||
kv_cache=kv_cache,
|
||||
workspace_buffer=workspace_buffer,
|
||||
block_tables=block_tables,
|
||||
seq_lens=seq_lens,
|
||||
max_q_len=max_q_len,
|
||||
max_kv_len=max_seq_len,
|
||||
bmm1_scale=k_scale * scale,
|
||||
bmm2_scale=v_scale,
|
||||
batch_size=num_seqs,
|
||||
bmm1_scale=q_scale * k_scale * sm_scale,
|
||||
bmm2_scale=v_scale / o_scale,
|
||||
batch_size=batch_size,
|
||||
cum_seq_lens_q=q_indptr,
|
||||
cum_seq_lens_kv=kv_indptr,
|
||||
out=output_trtllm,
|
||||
)
|
||||
if o_quant_dtype == FP8_DTYPE:
|
||||
output_trtllm = output_trtllm.to(dtype) * o_scale
|
||||
|
||||
torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \
|
||||
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
|
||||
rtol, atol = 5e-2, 7e-2
|
||||
else:
|
||||
rtol, atol = 1e-2, 1e-2
|
||||
|
||||
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
|
||||
f"{torch.max(torch.abs(output - output_trtllm))}"
|
||||
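Note: the tests and benchmarks above never dequantize Q, K, or V explicitly; the FP8 dequantization scales are folded into the two attention GEMMs as `bmm1_scale = q_scale * k_scale * sm_scale` and `bmm2_scale = v_scale / o_scale` (the latter also re-quantizes the output by `o_scale`). A toy scalar check of that folding, with made-up scale values (the real kernels operate on tensors):

```python
# Illustrative scalar check of how dequant scales fold into the attention GEMMs.
head_size = 128
sm_scale = 1.0 / (head_size ** 0.5)
q_scale, k_scale, v_scale, o_scale = 0.02, 0.03, 0.05, 0.10  # made-up scales

q_fp8, k_fp8, v_fp8 = 4.0, 5.0, 6.0  # stand-ins for quantized Q/K/V values
prob = 0.25                           # a softmax weight, identical either way

# BMM1: dequantizing first vs. folding the scales into the GEMM epilogue.
logits_unfused = (q_fp8 * q_scale) * (k_fp8 * k_scale) * sm_scale
logits_fused = (q_fp8 * k_fp8) * (q_scale * k_scale * sm_scale)  # bmm1_scale
assert abs(logits_unfused - logits_fused) < 1e-12

# BMM2: dequantize V then re-quantize the output by o_scale,
# vs. folding both into a single bmm2_scale = v_scale / o_scale.
out_unfused = (prob * (v_fp8 * v_scale)) / o_scale
out_fused = (prob * v_fp8) * (v_scale / o_scale)                 # bmm2_scale
assert abs(out_unfused - out_fused) < 1e-12
```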
@@ -128,11 +128,17 @@ class Attention(nn.Module):
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

# We also keep the float32 versions of k/v_scale for attention
# backends that don't support tensors (Flashinfer)
# We also keep q/k/v_scale on host (cpu) memory for attention
# backends that require the scales to be on host instead of on device.
# e.g. Flashinfer
self._q_scale_float = 1.0
self._k_scale_float = 1.0
self._v_scale_float = 1.0

# The output scale on host memory. This should be the input scale of
# the quant op after this attention layer.
self._o_scale_float: Optional[float] = None

self.use_mla = use_mla
self.num_heads = num_heads
self.head_size = head_size
@@ -291,6 +297,7 @@ class Attention(nn.Module):
self._q_scale.copy_(torch.abs(query).max() / self.q_range)
self._k_scale.copy_(torch.abs(key).max() / self.k_range)
self._v_scale.copy_(torch.abs(value).max() / self.v_range)
self._q_scale_float = self._q_scale.item()
self._k_scale_float = self._k_scale.item()
self._v_scale_float = self._v_scale.item()
# We only calculate the scales once
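Note: the layer.py hunk above keeps host-side float mirrors of the q/k/v scales next to the device tensors, because the FlashInfer TRT-LLM path wants plain Python floats. A minimal sketch of that pattern in isolation (a hypothetical class, not vLLM's actual Attention layer):

```python
import torch
from typing import Optional

class ScaleCache:
    """Sketch: device-tensor scales plus cached host-float mirrors."""

    def __init__(self) -> None:
        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
        # Host copies for backends that need plain floats (e.g. FlashInfer).
        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0
        # Filled lazily from the following quant op's input scale.
        self._o_scale_float: Optional[float] = None

    def update(self, query, key, value, q_range, k_range, v_range) -> None:
        # Mirrors the diff: update the device tensors, then cache .item()
        # floats once (the diff notes the scales are only calculated once).
        self._q_scale.copy_(torch.abs(query).max() / q_range)
        self._k_scale.copy_(torch.abs(key).max() / k_range)
        self._v_scale.copy_(torch.abs(value).max() / v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
```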
@ -9,7 +9,7 @@ from torch._subclasses.fake_tensor import (FakeTensorMode,
|
||||
unset_fake_temporarily)
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@ -18,23 +18,32 @@ from .vllm_inductor_pass import VllmInductorPass

logger = init_logger(__name__)

FP8_DTYPE = current_platform.fp8_dtype()

ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
RESHAPE_OP = torch.ops.aten.reshape.default


class AttentionStaticQuantPattern:
"""
Fusion for Attention+StaticQuant.

Only triggers when the attention implementation returns True in
`fused_output_quant_supported()`. If the pattern is found, the StaticQuant
op will be removed from the graph, and its scale will be passed into
Attention op as the `output_scale` argument.
"""

def __init__(
self,
layer_name: str,
num_heads: int,
head_size: int,
layer: Attention,
quant_dtype: torch.dtype,
symmetric=True,
):
self.layer_name = layer_name
self.num_heads = num_heads
self.head_size = head_size
self.layer = layer
self.layer_name = layer.layer_name
self.num_heads = layer.num_heads
self.head_size = layer.head_size
self.quant_dtype = quant_dtype
self.quant_key = QuantKey(dtype=quant_dtype,
static=True,
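Note: the rewrite this pattern performs is purely algebraic. Instead of running attention in bf16 and then statically quantizing its output, the quant op is deleted and its scale is handed to the attention op, which divides by that scale while writing the output. A self-contained sketch of the equivalence (hypothetical helper, plain torch; this is not the FX pattern itself):

import torch

FP8 = torch.float8_e4m3fn

def static_fp8_quant(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Static per-tensor quantization: scale is a precomputed input scale.
    finfo = torch.finfo(FP8)
    return (x / scale).clamp(finfo.min, finfo.max).to(FP8)

attn_out = torch.randn(5, 32 * 128, dtype=torch.bfloat16)  # [tokens, heads*head_size]
scale = torch.tensor(0.02)

# Unfused graph: attention writes bf16, a separate quant op follows.
unfused = static_fp8_quant(attn_out.float(), scale)

# Fused graph: the attention kernel receives output_scale and emits FP8 itself.
finfo = torch.finfo(FP8)
fused = (attn_out.float() / scale).clamp(finfo.min, finfo.max).to(FP8)

torch.testing.assert_close(unfused.float(), fused.float())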
@ -48,11 +57,10 @@ class AttentionStaticQuantPattern:
kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
return torch.empty(*args, **kwargs)

def register_if_supported(self, pm_pass: PatternMatcherPass,
layer: Attention):
if layer.impl.fused_output_quant_supported(self.quant_dtype,
self.quant_key.static,
self.quant_key.group_shape):
def register_if_supported(self, pm_pass: PatternMatcherPass):
if self.layer.impl.fused_output_quant_supported(
self.quant_dtype, self.quant_key.static,
self.quant_key.group_shape):
self._register(pm_pass)

def _register(self, pm_pass: PatternMatcherPass):
@ -60,19 +68,15 @@ class AttentionStaticQuantPattern:
def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
output_attn: torch.Tensor, output_quant: torch.Tensor,
scale: torch.Tensor):
view_7 = RESHAPE_OP(output_attn,
[-1, self.num_heads, self.head_size])

at1 = auto_functionalized(ATTN_OP,
query=q,
key=k,
value=v,
output=view_7,
output=output_attn,
layer_name=self.layer_name,
output_scale=None)
attn_out_view = RESHAPE_OP(at1[1],
[-1, self.num_heads * self.head_size])

at2 = auto_functionalized(self.QUANT_OP,
result=output_quant,
input=attn_out_view,
@ -82,17 +86,19 @@ class AttentionStaticQuantPattern:
def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
output_attn: torch.Tensor, output_quant: torch.Tensor,
scale: torch.Tensor):
view_7 = RESHAPE_OP(output_quant,
[-1, self.num_heads, self.head_size])

# attn output in quant_dtype
output_attn = torch.ops.aten.full.default(
[q.shape[0], self.num_heads, self.head_size],
0.0,
dtype=self.quant_dtype,
device=q.device)
at1 = auto_functionalized(ATTN_OP,
query=q,
key=k,
value=v,
output=view_7,
output=output_attn,
layer_name=self.layer_name,
output_scale=scale)

return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])

# Need custom fake mode, otherwise tracing happens with real tensors.
@ -102,7 +108,7 @@ class AttentionStaticQuantPattern:
empty_bf16(5, self.num_heads, self.head_size), # q
empty_bf16(5, self.num_heads, self.head_size), # k
empty_bf16(5, self.num_heads, self.head_size), # v
empty_bf16(5, self.num_heads * self.head_size), # attn_output
empty_bf16(5, self.num_heads, self.head_size), # attn_output
self.empty_quant(5, self.num_heads *
self.head_size), # quant_output
empty_fp32(1, 1) # scale
@ -140,27 +146,30 @@ class AttnFusionPass(VllmInductorPass):

def __init__(self, config: VllmConfig):
super().__init__(config)
self.static_fwd_ctx = config.compilation_config.static_forward_context

self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")

for key, layer in self.static_fwd_ctx.items():
pattern = AttentionStaticQuantPattern(key, layer.num_heads,
layer.head_size,
current_platform.fp8_dtype())
pattern.register_if_supported(self.patterns, layer)
if len(self.static_fwd_ctx) == 0:
attn_layers = get_layers_from_vllm_config(config, Attention)
for layer_name, layer in attn_layers.items():
pattern = AttentionStaticQuantPattern(layer, FP8_DTYPE)
pattern.register_if_supported(self.patterns)
if len(attn_layers) == 0:
logger.warning(
"Attention + quant fusion is enabled, but "
"CompilationConfig.static_forward_context is empty. "
"Cannot access attention layers so no fusion "
"patterns were registered.")
"Attention + quant fusion is enabled, but no attention layers "
"were found in CompilationConfig.static_forward_context "
"so no fusion patterns were registered.")

def __call__(self, graph: torch.fx.graph.Graph) -> None:
self.begin()
self.dump_graph(graph, "before_attn_fusion")

count = self.patterns.apply(graph)

# TODO: Move this to pass_manager.py after the fx graph broken issue
# has been resolved.
# see https://github.com/vllm-project/vllm/issues/23091
graph.eliminate_dead_code()

logger.debug("Fused quantization onto %s attention nodes", count)
self.dump_graph(graph, "after_attn_fusion")
self.end_and_log()
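Note: after the pattern matcher rewrites attention+quant pairs, the now-orphaned quant nodes and buffers are swept up by graph.eliminate_dead_code(), which the TODO above plans to hoist into the pass manager. A standalone torch.fx illustration of just that clean-up step (unrelated toy module, runs on CPU):

import torch
import torch.fx as fx

class Toy(torch.nn.Module):
    def forward(self, x):
        dead = x * 2          # becomes dead once a rewrite stops using it
        return torch.relu(x)

gm = fx.symbolic_trace(Toy())
before = len(list(gm.graph.nodes))
gm.graph.eliminate_dead_code()   # same call the pass issues after rewriting
gm.recompile()
print(before, "->", len(list(gm.graph.nodes)))  # the unused mul node is gone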
@ -174,21 +174,30 @@ def supports_trtllm_attention() -> tuple[bool, Optional[str]]:


def use_trtllm_attention(
num_qo_heads: int,
num_kv_heads: int,
num_tokens: int,
max_seq_len: int,
kv_cache_dtype: str,
num_qo_heads: Optional[int],
num_kv_heads: Optional[int],
attn_head_size: Optional[int],
q_dtype: torch.dtype,
is_prefill: bool,
has_sinks: bool = False,
) -> bool:
use_trtllm, env_value = supports_trtllm_attention()
if not use_trtllm:
return False

# Check if the dimensions are supported by TRTLLM decode attention
if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None
or num_qo_heads % num_kv_heads != 0):
if num_qo_heads % num_kv_heads != 0:
return False

# Must use TRTLLM attention if query is FP8 quantized
if q_dtype == current_platform.fp8_dtype():
logger.info_once("Using TRTLLM attention (query is quantized).")
return True

# TRTLLM prefill attention does not support FP8 kv cache with
# non-quantized query
if is_prefill and kv_cache_dtype.startswith("fp8"):
return False

# If sinks are being used, we must use TRTLLM attention as it's
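Note: the new signature makes the decision depend on the query dtype and the phase, not only on shapes: an FP8 query forces the TRT-LLM path (it is the only path here that accepts quantized queries), while a prefill with an FP8 KV cache but an unquantized query falls back to the FlashInfer wrappers. A condensed, pure-Python restatement of that gating (the real function also consults the environment override, sinks and size heuristics, which are omitted here):

import torch

FP8 = torch.float8_e4m3fn

def would_use_trtllm(num_qo_heads: int, num_kv_heads: int, kv_cache_dtype: str,
                     q_dtype: torch.dtype, is_prefill: bool) -> bool:
    if num_qo_heads % num_kv_heads != 0:
        return False
    if q_dtype == FP8:
        # An FP8 query can only be consumed by the TRT-LLM kernels.
        return True
    if is_prefill and kv_cache_dtype.startswith("fp8"):
        # TRT-LLM prefill cannot pair an FP8 KV cache with a bf16 query.
        return False
    return True

print(would_use_trtllm(64, 8, "fp8", FP8, is_prefill=True))             # True
print(would_use_trtllm(64, 8, "fp8", torch.bfloat16, is_prefill=True))  # False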
@ -290,6 +299,7 @@ __all__ = [
"has_flashinfer_moe",
"has_flashinfer_cutlass_fused_moe",
"has_nvidia_artifactory",
"supports_trtllm_attention",
"use_trtllm_attention",
"flashinfer_scaled_fp4_mm",
]
@ -15,12 +15,17 @@ from flashinfer.decode import (_get_range_buf, get_seq_lens,
from flashinfer.prefill import trtllm_batch_context_with_kv_cache

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType)
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.platforms import current_platform
from vllm.utils import cdiv, is_pin_memory_available
from vllm.utils.flashinfer import use_trtllm_attention
from vllm.utils.flashinfer import (supports_trtllm_attention,
use_trtllm_attention)
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
# yapf conflicts with isort for this block
# yapf: disable
@ -35,6 +40,8 @@ from vllm.v1.kv_cache_interface import AttentionSpec

FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024

FP8_DTYPE = current_platform.fp8_dtype()

logger = init_logger(__name__)

@ -519,22 +526,27 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
else:
kv_cache_dtype = self.kv_cache_spec.dtype

num_qo_heads = self.vllm_config.model_config.get_num_attention_heads(
self.vllm_config.parallel_config)
config = self.vllm_config
num_qo_heads = config.model_config.get_num_attention_heads(
config.parallel_config)
num_kv_heads = self.kv_cache_spec.num_kv_heads
head_dim = self.kv_cache_spec.head_size

# Check if any layer uses sinks (requires TRTLLM attention)
has_sinks = self.global_hyperparameters.has_sinks

# currently prefill trtllm attention does not support fp8 kv cache
prefill_use_trtllm = not cache_dtype.startswith("fp8") \
and use_trtllm_attention(
num_prefill_tokens, max_seq_len, cache_dtype,
num_qo_heads, num_kv_heads, head_dim, has_sinks)
# Insert FP8 quant for query if FP8 kv cache and attn fusion enabled
q_dtype = config.model_config.dtype
enable_fusion = config.compilation_config.pass_config.enable_attn_fusion
if cache_dtype.startswith("fp8") and enable_fusion:
q_dtype = kv_cache_dtype

prefill_use_trtllm = use_trtllm_attention(
num_qo_heads, num_kv_heads, num_prefill_tokens, max_seq_len,
cache_dtype, q_dtype, is_prefill=True, has_sinks=has_sinks)
decode_use_trtllm = use_trtllm_attention(
num_decode_tokens, max_seq_len, cache_dtype,
num_qo_heads, num_kv_heads, head_dim, has_sinks)
num_qo_heads, num_kv_heads, num_decode_tokens, max_seq_len,
cache_dtype, q_dtype, is_prefill=False, has_sinks=has_sinks)

attn_metadata = FlashInferMetadata(
num_actual_tokens=num_actual_tokens,
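Note: the builder now decides the query dtype up front. If the KV cache is FP8 and the attention+quant fusion pass is enabled, the query is expected to arrive already quantized to the KV-cache dtype, and that dtype is advertised through q_data_type so the wrappers and the TRT-LLM path agree. A small sketch of just that selection, with stand-in config types (not the real vLLM classes):

import torch
from dataclasses import dataclass

@dataclass
class Cfg:                    # stand-in for the relevant vLLM config fields
    model_dtype: torch.dtype
    cache_dtype: str          # e.g. "auto", "fp8", "fp8_e4m3"
    enable_attn_fusion: bool

def pick_q_dtype(cfg: Cfg, kv_cache_dtype: torch.dtype) -> torch.dtype:
    q_dtype = cfg.model_dtype
    if cfg.cache_dtype.startswith("fp8") and cfg.enable_attn_fusion:
        # The fusion pass inserts the FP8 quant of the query, so the metadata
        # must advertise the quantized dtype.
        q_dtype = kv_cache_dtype
    return q_dtype

cfg = Cfg(torch.bfloat16, "fp8", enable_attn_fusion=True)
print(pick_q_dtype(cfg, torch.float8_e4m3fn))   # torch.float8_e4m3fn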
@ -548,7 +560,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
head_dim=head_dim,
page_size=page_size,
kv_data_type=kv_cache_dtype,
q_data_type=self.vllm_config.model_config.dtype,
q_data_type=q_dtype,
slot_mapping=common_attn_metadata.slot_mapping,
max_q_len=max_q_len,
max_seq_len=max_seq_len,
@ -622,6 +634,8 @@ class FlashInferImpl(AttentionImpl):
self.sliding_window = (-1, -1)
else:
self.sliding_window = (sliding_window - 1, 0)
self.window_left = (self.sliding_window[0]
if self.sliding_window is not None else -1)
self.kv_cache_dtype = kv_cache_dtype
self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
@ -644,6 +658,19 @@ class FlashInferImpl(AttentionImpl):
)
self.sinks = sinks

self.support_trtllm_attn = (supports_trtllm_attention() and
num_heads % num_kv_heads == 0)
self.bmm1_scale: Optional[float] = None
self.bmm2_scale: Optional[float] = None

def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
group_shape: GroupShape):
supported_quant_type = (dtype == FP8_DTYPE and static and
group_shape == GroupShape.PER_TENSOR)
return (self.support_trtllm_attn
and self.kv_cache_dtype.startswith("fp8")
and supported_quant_type)

def forward(
self,
layer: torch.nn.Module,
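Note: fused_output_quant_supported() is the predicate the compilation pattern queries before rewriting a layer: only static, per-tensor FP8 quantization on top of an FP8 KV cache and a TRT-LLM-capable head configuration is accepted. A minimal, dependency-free restatement of the check (a plain tuple stands in for GroupShape.PER_TENSOR):

import torch

FP8 = torch.float8_e4m3fn
PER_TENSOR = (-1, -1)   # stand-in for GroupShape.PER_TENSOR

def fused_output_quant_supported(support_trtllm_attn: bool, kv_cache_dtype: str,
                                 dtype: torch.dtype, static: bool,
                                 group_shape: tuple[int, int]) -> bool:
    supported_quant_type = (dtype == FP8 and static
                            and group_shape == PER_TENSOR)
    return (support_trtllm_attn
            and kv_cache_dtype.startswith("fp8")
            and supported_quant_type)

print(fused_output_quant_supported(True, "fp8", FP8, True, PER_TENSOR))   # True
print(fused_output_quant_supported(True, "auto", FP8, True, PER_TENSOR))  # False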
@ -672,15 +699,42 @@ class FlashInferImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."

if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashInferImpl")

if attn_metadata is None:
# Profiling run.
return output

if self.bmm1_scale is None:
self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float *
self.scale)

if self.bmm2_scale is None:
self.bmm2_scale = layer._v_scale_float

# The attn+quant fusion happens when output_scale is provided.
if output_scale is None:
assert attn_metadata.q_data_type != FP8_DTYPE, \
"Query can only be FP8 if output fusion happened."
else:
assert attn_metadata.q_data_type == FP8_DTYPE, \
"Query must be FP8 when attn+quant fusion happened."
assert (attn_metadata.prefill_use_trtllm and
attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn"
assert output.dtype == FP8_DTYPE, \
"Output must be FP8 when attn+quant fusion happened."

# TRTLLM attn kernel requires o scale as a host scalar, store the
# o scale to host scalar in warmup run with cuda graph not enabled
if layer._o_scale_float is None:
layer._o_scale_float = output_scale.cpu().item()
self.bmm2_scale = self.bmm2_scale / layer._o_scale_float

# Insert FP8 quant for query
num_tokens, num_heads, head_size = query.shape
query, _ = ops.scaled_fp8_quant(
query.reshape((num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))

# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
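Note: two things happen here when the fusion kicked in: the output scale is folded into bmm2_scale exactly once (cached as a host float so CUDA-graph replays never re-read a device tensor), and the bf16 query is quantized with the layer's static q scale before being handed to the TRT-LLM kernel. A torch-only sketch of the query-side step and the scale bookkeeping, where a clamp-based cast stands in for ops.scaled_fp8_quant:

import torch

FP8 = torch.float8_e4m3fn

def quant_query_static(query: torch.Tensor, q_scale: torch.Tensor) -> torch.Tensor:
    # [num_tokens, num_heads, head_size] -> flatten, quantize, restore shape,
    # mirroring the reshape / quant / reshape sequence above.
    num_tokens, num_heads, head_size = query.shape
    flat = query.reshape(num_tokens, num_heads * head_size).float()
    finfo = torch.finfo(FP8)
    q_fp8 = (flat / q_scale).clamp(finfo.min, finfo.max).to(FP8)
    return q_fp8.reshape(num_tokens, num_heads, head_size)

q_scale, k_scale, v_scale = 0.01, 0.02, 0.03
sm_scale = 1.0 / 128**0.5
o_scale = 0.05

bmm1_scale = q_scale * k_scale * sm_scale   # dequantizes Q and K in the softmax
bmm2_scale = v_scale / o_scale              # dequantizes V, re-quantizes the output

query = torch.randn(4, 32, 128, dtype=torch.bfloat16)
print(quant_query_static(query, torch.tensor(q_scale)).dtype)  # torch.float8_e4m3fn
print(bmm1_scale, bmm2_scale)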
@ -718,9 +772,6 @@ class FlashInferImpl(AttentionImpl):
self.kv_cache_dtype)
kv_cache = kv_cache.view(torch_dtype)

window_left = (self.sliding_window[0]
if self.sliding_window is not None else -1)

# Inputs and outputs may be padded for CUDA graphs
query = query[:num_actual_tokens]
output_padded = output
@ -748,7 +799,7 @@ class FlashInferImpl(AttentionImpl):

if not attn_metadata.prefill_use_trtllm:
assert prefill_wrapper._causal
assert prefill_wrapper._window_left == window_left
assert prefill_wrapper._window_left == self.window_left
assert prefill_wrapper._logits_soft_cap == (
self.logits_soft_cap or 0.0)
assert prefill_wrapper._sm_scale == self.scale
@ -783,12 +834,12 @@ class FlashInferImpl(AttentionImpl):
seq_lens=seq_lens_prefill,
max_q_len=attn_metadata.max_q_len,
max_kv_len=attn_metadata.max_seq_len,
bmm1_scale=layer._k_scale_float * self.scale,
bmm2_scale=layer._v_scale_float,
bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale,
batch_size=attn_metadata.num_prefills,
cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
window_left=window_left,
window_left=self.window_left,
sinks=self.sinks,
out=output[num_decode_tokens:],
)
@ -800,7 +851,7 @@ class FlashInferImpl(AttentionImpl):
assert decode_wrapper is not None

if not attn_metadata.decode_use_trtllm:
assert decode_wrapper._window_left == window_left
assert decode_wrapper._window_left == self.window_left
assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
or 0.0)
assert decode_wrapper._sm_scale == self.scale
@ -815,8 +866,8 @@ class FlashInferImpl(AttentionImpl):
# decode_query may be non-contiguous
decode_query = decode_query.contiguous()
workspace_buffer = decode_wrapper._float_workspace_buffer
block_tables_decode = attn_metadata.block_table_tensor[:
num_decode_tokens]
block_tables_decode = attn_metadata.\
block_table_tensor[:num_decode_tokens]
seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]

# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
@ -834,9 +885,9 @@ class FlashInferImpl(AttentionImpl):
block_tables=block_tables_decode,
seq_lens=seq_lens_decode,
max_seq_len=attn_metadata.max_seq_len,
bmm1_scale=layer._k_scale_float * self.scale,
bmm2_scale=layer._v_scale_float,
window_left=window_left,
bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale,
window_left=self.window_left,
sinks=self.sinks,
out=output[:num_decode_tokens],
)
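Note: taken together, the FP8 query/KV/output path is opt-in. It needs an FP8 KV cache, the HND cache layout noted in the comment above, and the attention+quant fusion pass so that output_scale reaches this forward(). A hedged end-to-end sketch follows; the model name is a placeholder, and the exact compilation_config layout accepted by a given vLLM version may differ, so treat everything beyond kv_cache_dtype, VLLM_KV_CACHE_LAYOUT and pass_config.enable_attn_fusion (all referenced in this diff) as an assumption.

# Sketch only, not a verified invocation.
import os
os.environ["VLLM_KV_CACHE_LAYOUT"] = "HND"

from vllm import LLM, SamplingParams

llm = LLM(
    model="nvidia/Llama-3.1-8B-Instruct-FP8",   # placeholder FP8 checkpoint
    kv_cache_dtype="fp8",
    compilation_config={"pass_config": {"enable_attn_fusion": True}},
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)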