[Kernel] Add FP8 support with FlashMLA backend (#22668)

Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
This commit is contained in:
Matthew Bonanni
2025-08-21 22:26:32 -04:00
committed by GitHub
parent 480bdf5a7b
commit 19fe1a0510
19 changed files with 235 additions and 109 deletions

View File

@ -19,7 +19,7 @@ else()
FetchContent_Declare(
flashmla
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
GIT_TAG 0e43e774597682284358ff2c54530757b654b8d1
GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de
GIT_PROGRESS TRUE
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
@ -37,13 +37,14 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
set(FlashMLA_SOURCES
${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu
${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu)
${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu)
set(FlashMLA_INCLUDES
${flashmla_SOURCE_DIR}/csrc/cutlass/include
${flashmla_SOURCE_DIR}/csrc/include)
${flashmla_SOURCE_DIR}/csrc)
set_gencode_flags_for_srcs(
SRCS "${FlashMLA_SOURCES}"

View File

@ -40,9 +40,11 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype);
void gather_cache(
void gather_and_maybe_dequant_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
int64_t batch_size, const std::string& kv_cache_dtype,
torch::Tensor const& scale,
std::optional<torch::Tensor> seq_starts = std::nullopt);

View File

@ -624,9 +624,9 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
namespace vllm {
// grid is launched with dimensions (batch, num_splits)
template <typename scalar_t>
__global__ void gather_cache(
const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void gather_and_maybe_dequant_cache(
const cache_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
// ENTRIES...]
scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...]
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
@ -634,6 +634,7 @@ __global__ void gather_cache(
const int32_t block_size, const int32_t entry_size,
const int64_t block_table_stride, const int64_t cache_block_stride,
const int64_t cache_entry_stride, const int64_t dst_entry_stride,
const float* __restrict__ scale,
const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per
// batch
@ -675,10 +676,16 @@ __global__ void gather_cache(
if (partial_block_size) full_blocks_end -= 1;
}
auto copy_entry = [&](const scalar_t* __restrict__ _src,
auto copy_entry = [&](const cache_t* __restrict__ _src,
scalar_t* __restrict__ _dst) {
for (int i = threadIdx.x; i < entry_size; i += blockDim.x)
_dst[i] = _src[i];
for (int i = threadIdx.x; i < entry_size; i += blockDim.x) {
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
_dst[i] = static_cast<scalar_t>(_src[i]);
} else {
_dst[i] =
fp8::scaled_convert<scalar_t, cache_t, kv_dt>(_src[i], *scale);
}
}
};
for (int pid = split_start; pid < full_blocks_end; ++pid) {
@ -705,25 +712,31 @@ __global__ void gather_cache(
} // namespace vllm
// Macro to dispatch the kernel based on the data type.
#define CALL_GATHER_CACHE(CPY_DTYPE) \
vllm::gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()), \
reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()), \
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
block_size, entry_size, block_table_stride, cache_block_stride, \
cache_entry_stride, dst_entry_stride, seq_starts_ptr);
// SCALAR_T is the data type of the destination tensor.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \
vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<CACHE_T*>(src_cache.data_ptr()), \
reinterpret_cast<SCALAR_T*>(dst.data_ptr()), \
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
block_size, entry_size, block_table_stride, cache_block_stride, \
cache_entry_stride, dst_entry_stride, \
reinterpret_cast<const float*>(scale.data_ptr()), seq_starts_ptr);
// Gather sequences from the cache into the destination tensor.
// - cu_seq_lens contains the cumulative sequence lengths for each batch
// - block_table contains the cache block indices for each sequence
// - Optionally, seq_starts (if provided) offsets the starting block index by
// (seq_starts[bid] / page_size)
void gather_cache(
void gather_and_maybe_dequant_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size,
int64_t batch_size, const std::string& kv_cache_dtype,
torch::Tensor const& scale,
std::optional<torch::Tensor> seq_starts = std::nullopt) {
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@ -761,20 +774,8 @@ void gather_cache(
dim3 grid(batch_size, num_splits);
dim3 block(1024);
TORCH_CHECK(src_cache.dtype() == dst.dtype(),
"src_cache and dst must have the same dtype");
const int dtype_bits = src_cache.element_size() * 8;
const int32_t* seq_starts_ptr =
seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;
if (dtype_bits == 32) {
CALL_GATHER_CACHE(uint32_t);
} else if (dtype_bits == 16) {
CALL_GATHER_CACHE(uint16_t);
} else if (dtype_bits == 8) {
CALL_GATHER_CACHE(uint8_t);
} else {
TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
}
DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE);
}

View File

@ -672,11 +672,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
"str kv_cache_dtype) -> ()");
cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
// Gather cache blocks from src_cache to dst.
// Gather cache blocks from src_cache to dst, dequantizing from
// src_cache's dtype to dst's dtype if necessary.
cache_ops.def(
"gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache);
"gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, "
" Tensor block_table, Tensor cu_seq_lens, "
" int batch_size, "
" str kv_cache_dtype, "
" Tensor scale, Tensor? seq_starts) -> ()");
cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA,
&gather_and_maybe_dequant_cache);
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {

View File

@ -709,14 +709,15 @@ def test_swap_blocks_mla(
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype",
["auto"]) # You can also test "fp8" if needed.
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
num_blocks, max_seq_len, batch_size, dtype,
kv_cache_dtype, device):
def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
block_size, num_blocks,
max_seq_len, batch_size, dtype,
kv_cache_dtype, device):
entry_size = kv_lora_rank + qk_rope_head_dim
scale = torch.tensor(0.1, dtype=torch.float32, device=device)
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device)
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
@ -742,9 +743,7 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
perm = torch.randperm(num_blocks, device=device)
block_table[b, :] = perm
dst = torch.zeros((total_tokens, entry_size),
dtype=src_cache.dtype,
device=device)
dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device)
expected_batches = []
for b in range(batch_size):
@ -756,21 +755,38 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
gathered_rows = []
for i in range(tot - 1):
gathered_rows.append(src_cache[blocks[i]])
block_data = src_cache[blocks[i]]
if kv_cache_dtype == "fp8":
dequantized_block = torch.empty_like(block_data, dtype=dtype)
ops.convert_fp8(dequantized_block, block_data, scale.item())
gathered_rows.append(dequantized_block)
else:
gathered_rows.append(block_data)
remaining = s - (tot - 1) * block_size
gathered_rows.append(src_cache[blocks[-1], :remaining, :])
last_block_data = src_cache[blocks[-1], :remaining, :]
if kv_cache_dtype == "fp8":
dequantized_last_block = torch.empty_like(last_block_data,
dtype=dtype)
ops.convert_fp8(dequantized_last_block, last_block_data,
scale.item())
gathered_rows.append(dequantized_last_block)
else:
gathered_rows.append(last_block_data)
batch_expected = torch.cat(gathered_rows, dim=0)
expected_batches.append(batch_expected)
expected = torch.cat(expected_batches, dim=0)
opcheck(
torch.ops._C_cache_ops.gather_cache,
(src_cache, dst, block_table, cu_seq_lens, batch_size, None),
torch.ops._C_cache_ops.gather_and_maybe_dequant_cache,
(src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype,
scale, None),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, kv_cache_dtype,
scale, None)
torch.testing.assert_close(dst, expected)

View File

@ -13,11 +13,17 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
from vllm.triton_utils import triton
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
def cal_diff(x: torch.Tensor,
y: torch.Tensor,
name: str,
use_fp8: bool = False) -> None:
x, y = x.double(), y.double()
cos_diff = 1 - 2 * (x * y).sum().item() / max(
(x * x + y * y).sum().item(), 1e-12)
assert cos_diff < 1e-5
if (use_fp8):
assert cos_diff < 1e-4
else:
assert cos_diff < 1e-5
FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
if not is_flashmla_supported()[0] else "FlashMLA is supported"
@ -27,7 +33,7 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
reason=FLASH_MLA_UNSUPPORTED_REASON)
@pytest.mark.parametrize("b", [128])
@pytest.mark.parametrize("s_q", [1, 2])
@pytest.mark.parametrize("mean_sk", [4096, 8192])
@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384])
@pytest.mark.parametrize("h_q", [16, 32, 64, 128])
@pytest.mark.parametrize("h_kv", [1])
@pytest.mark.parametrize("d", [576])
@ -35,20 +41,26 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
@pytest.mark.parametrize("block_size", [64])
@pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("torch_dtype",
[torch.bfloat16, torch.float16, torch.float8_e4m3fn])
@torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
varlen, dtype):
varlen, torch_dtype):
device = torch.device("cuda:0")
torch.set_default_dtype(dtype)
if torch_dtype == torch.float8_e4m3fn:
init_dtype = torch.bfloat16
else:
init_dtype = torch_dtype
torch.set_default_dtype(init_dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.manual_seed(0)
random.seed(0)
print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
f"{d=}, {dv=}, {causal=}, {varlen=}, {dtype=}")
f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}")
use_fp8 = torch_dtype == torch.float8_e4m3fn
cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
if varlen:
for i in range(b):
@ -71,6 +83,19 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens, s_q * h_q // h_kv, h_kv)
init_dtype = q.dtype
if use_fp8:
fp8_dtype = torch.float8_e4m3fn
descale_q = torch.ones((1), dtype=torch.float32)
descale_k = torch.ones((1), dtype=torch.float32)
q = q.to(fp8_dtype)
blocked_k = blocked_k.to(fp8_dtype)
blocked_v = blocked_v.to(fp8_dtype)
else:
descale_q = None
descale_k = None
def flash_mla():
return flash_mla_with_kvcache(
q,
@ -81,6 +106,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
tile_scheduler_metadata,
num_splits,
causal=causal,
descale_q=descale_q,
descale_k=descale_k,
)
def scaled_dot_product_attention(query, key, value, is_causal=False):
@ -104,29 +131,35 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
return attn_weight @ value, lse
def ref_mla():
q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
blocked_k_ = (blocked_k.to(torch.float) *
descale_k).to(init_dtype) if use_fp8 else blocked_k
blocked_v_ = (blocked_v.to(torch.float) *
descale_k).to(init_dtype) if use_fp8 else blocked_v
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b):
begin = i * max_seqlen_pad
end = begin + cache_seqlens[i]
ref_O, LSE = scaled_dot_product_attention(
q[i].transpose(0, 1),
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
out_i, lse_i = scaled_dot_product_attention(
q_[i].transpose(0, 1),
blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
is_causal=causal,
)
out[i] = ref_O.transpose(0, 1)
lse[i] = LSE
out[i] = out_i.transpose(0, 1)
lse[i] = lse_i
return out, lse
out_flash, lse_flash = flash_mla()
out_torch, lse_torch = ref_mla()
cal_diff(out_flash, out_torch, "out")
cal_diff(out_flash, out_torch, "out", use_fp8)
cal_diff(lse_flash, lse_torch, "lse")
t = triton.testing.do_bench(flash_mla)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d +
b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} "
f"TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s")
bytes = (total_seqlens * h_kv * d +
b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (
b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,",
f"{bytes / 10 ** 6 / t:.0f} GB/s")

View File

@ -1589,14 +1589,18 @@ def convert_fp8(output: torch.Tensor,
torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)
def gather_cache(src_cache: torch.Tensor,
dst: torch.Tensor,
block_table: torch.Tensor,
cu_seq_lens: torch.Tensor,
batch_size: int,
seq_starts: Optional[torch.Tensor] = None) -> None:
torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, seq_starts)
def gather_and_maybe_dequant_cache(
src_cache: torch.Tensor,
dst: torch.Tensor,
block_table: torch.Tensor,
cu_seq_lens: torch.Tensor,
batch_size: int,
kv_cache_dtype: str,
scale: torch.Tensor,
seq_starts: Optional[torch.Tensor] = None) -> None:
torch.ops._C_cache_ops.gather_and_maybe_dequant_cache(
src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype,
scale, seq_starts)
def get_device_attribute(attribute: int, device: int) -> int:

View File

@ -837,8 +837,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
self.context_chunk_workspace_size // num_prefills_with_context
# align max_context_chunk to page_size by rounding down,
# currently the `gather_cache` kernel cannot handle
# `context_chunk_starts` that are not aligned to page_size
# currently the `gather_and_maybe_dequant_cache` kernel cannot
# handle `context_chunk_starts` that are not aligned to page_size
max_context_chunk = round_down(max_context_chunk, self.page_size)
assert max_context_chunk > 0
num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk)
@ -1082,6 +1082,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor,
):
prefill_metadata = attn_metadata.prefill_metadata
assert prefill_metadata is not None
@ -1103,12 +1104,14 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
for i in range(iters):
toks = prefill_metadata.context_chunk_seq_tot[i]
ops.gather_cache(
ops.gather_and_maybe_dequant_cache(
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_tables,
cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i],
batch_size=prefill_metadata.num_prefills,
kv_cache_dtype=self.kv_cache_dtype,
scale=k_scale,
seq_starts=prefill_metadata.context_chunk_starts[i],
)
@ -1165,6 +1168,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor,
) -> torch.Tensor:
prefill_metadata = attn_metadata.prefill_metadata
@ -1197,7 +1201,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# ROCm flash_attn_varlen_func will return 3 objects instead of 2
suffix_output, suffix_lse = output
context_output, context_lse = self._compute_prefill_context( \
q, kv_c_and_k_pe_cache, attn_metadata)
q, kv_c_and_k_pe_cache, attn_metadata, k_scale)
output = torch.empty_like(suffix_output)
merge_attn_states(
@ -1287,7 +1291,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
if has_prefill:
output[:num_prefill_tokens] = self._forward_prefill(
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
attn_metadata)
attn_metadata, layer._k_scale)
if has_decode:
decode_q_nope, decode_q_pe = decode_q.split(

View File

@ -67,6 +67,8 @@ def flash_mla_with_kvcache(
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
@ -81,6 +83,8 @@ def flash_mla_with_kvcache(
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
descale_q: (batch_size), torch.float32. Descaling factors for Q.
descale_k: (batch_size), torch.float32. Descaling factors for K.
Return:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
@ -98,6 +102,8 @@ def flash_mla_with_kvcache(
causal,
tile_scheduler_metadata,
num_splits,
descale_q,
descale_k,
)
return out, softmax_lse

View File

@ -1445,10 +1445,9 @@ class EngineArgs:
recommend_to_remove=False)
return False
# No Fp8 KV cache so far.
if self.kv_cache_dtype != "auto":
supported = current_platform.is_kv_cache_dtype_supported(
self.kv_cache_dtype)
self.kv_cache_dtype, model_config)
if not supported:
_raise_or_fallback(feature_name="--kv-cache-dtype",
recommend_to_remove=False)

View File

@ -481,16 +481,41 @@ class CudaPlatformBase(Platform):
return cuda_device_count_stateless()
@classmethod
def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
model_config: "ModelConfig") -> bool:
fp8_attention = kv_cache_dtype.startswith("fp8")
will_use_fa = (not envs.is_set("VLLM_ATTENTION_BACKEND")
) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
attention_backend = envs.VLLM_ATTENTION_BACKEND
supported = False
if cls.is_device_capability(100):
supported = True
elif fp8_attention and will_use_fa:
from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
supported = flash_attn_supports_fp8()
if model_config is not None and model_config.use_mla:
# Default to CutlassMLA for blackwell,
# FlashMLA otherwise
if attention_backend is None:
if cls.is_device_capability(100):
attention_backend = "CUTLASS_MLA"
else:
attention_backend = "FLASHMLA"
# Only FlashMLA supports fp8
if attention_backend == "FLASHMLA":
supported = True
else:
supported = (not fp8_attention)
else:
# Default to FlashAttention
if attention_backend is None:
attention_backend = "FLASH_ATTN_VLLM_V1"
# All Blackwell backends support fp8
if cls.is_device_capability(100):
supported = True
elif attention_backend == "FLASH_ATTN_VLLM_V1":
if fp8_attention:
from vllm.attention.utils.fa_utils import (
flash_attn_supports_fp8)
supported = flash_attn_supports_fp8()
else:
supported = True
return supported

View File

@ -565,7 +565,8 @@ class Platform:
raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
@classmethod
def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
model_config: "ModelConfig") -> bool:
"""
Returns if the kv_cache_dtype is supported by the current platform.
"""

View File

@ -459,5 +459,6 @@ class RocmPlatform(Platform):
return cuda_device_count_stateless()
@classmethod
def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
model_config: "ModelConfig") -> bool:
return True

View File

@ -196,7 +196,8 @@ class TpuPlatform(Platform):
raise ValueError("Torch XLA does not support per-request seed.")
@classmethod
def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
model_config: "ModelConfig") -> bool:
return True

View File

@ -631,8 +631,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
if self.aot_schedule:
# align max_context_chunk to page_size by rounding down,
# currently the `gather_cache` kernel cannot handle
# `context_chunk_starts` that are not aligned to page_size
# currently the `gather_and_maybe_dequant_cache` kernel
# cannot handle `context_chunk_starts` that are not aligned
# to page_size
max_context_chunk = round_down(max_context_chunk,
self.page_size)
@ -1005,6 +1006,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor,
):
assert attn_metadata.prefill is not None
prefill_metadata = attn_metadata.prefill
@ -1017,12 +1019,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]
ops.gather_cache(
ops.gather_and_maybe_dequant_cache(
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_table,
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
batch_size=attn_metadata.num_prefills,
kv_cache_dtype=self.kv_cache_dtype,
scale=k_scale,
seq_starts=prefill_metadata.chunked_context.starts[i],
)
@ -1073,6 +1077,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor,
) -> torch.Tensor:
assert attn_metadata.prefill is not None
@ -1095,7 +1100,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if has_context:
suffix_output, suffix_lse = output
context_output, context_lse = self._compute_prefill_context( \
q, kv_c_and_k_pe_cache, attn_metadata)
q, kv_c_and_k_pe_cache, attn_metadata, k_scale)
output = torch.empty_like(suffix_output)
merge_attn_states(
@ -1119,6 +1124,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: M,
layer: AttentionLayer,
) -> torch.Tensor:
raise NotImplementedError
@ -1146,6 +1152,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# same expert outputs.
return output.fill_(0)
fp8_attention = self.kv_cache_dtype.startswith("fp8")
num_actual_toks = attn_metadata.num_actual_tokens
# Inputs and outputs may be padded for CUDA graphs
@ -1180,10 +1188,13 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
scale=layer._k_scale,
)
if fp8_attention:
kv_cache = kv_cache.view(current_platform.fp8_dtype())
if has_prefill:
output[num_decode_tokens:] = self._forward_prefill(
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
attn_metadata)
attn_metadata, layer._k_scale)
if has_decode:
assert attn_metadata.decode is not None
@ -1196,7 +1207,21 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)
if fp8_attention:
ql_nope_shape = decode_ql_nope.shape
decode_ql_nope, _ = ops.scaled_fp8_quant(
decode_ql_nope.reshape([
ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2]
]), layer._q_scale)
decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape)
q_pe_shape = decode_q_pe.shape
decode_q_pe, _ = ops.scaled_fp8_quant(
decode_q_pe.reshape(
[q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]),
layer._q_scale)
decode_q_pe = decode_q_pe.reshape(q_pe_shape)
output[:num_decode_tokens] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer)
return output_padded

View File

@ -7,7 +7,7 @@ from typing import ClassVar, Optional
import torch
import vllm._custom_ops as ops
from vllm.attention.backends.abstract import (AttentionType,
from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
is_quantized_kv_cache)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
@ -278,6 +278,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
if self._use_old_cutlass_mla:
# TODO: Remove the old cutlass MLA kernel after more extensive

View File

@ -6,8 +6,7 @@ from typing import ClassVar, Optional
import torch
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_supported)
@ -166,16 +165,13 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"are not implemented for "
"FlashMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlashMLA V1 with FP8 KV cache not yet supported")
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
@ -194,6 +190,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
descale_q=layer._q_scale.reshape(1),
descale_k=layer._k_scale.reshape(1),
)
return self._v_up_proj(o)

View File

@ -7,6 +7,7 @@ from typing import ClassVar, Optional
import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionLayer
from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
from vllm.config import VllmConfig
from vllm.utils import cdiv
@ -221,6 +222,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: AiterMLAMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None

View File

@ -6,7 +6,7 @@ from typing import Optional
import torch
from vllm import envs
from vllm.attention.backends.abstract import (AttentionType,
from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
is_quantized_kv_cache)
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.attention.ops.triton_flash_attention import triton_attention
@ -127,6 +127,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None