[Kernel] Add FP8 support with FlashMLA backend (#22668)
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
@@ -19,7 +19,7 @@ else()
   FetchContent_Declare(
         flashmla
         GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
-        GIT_TAG 0e43e774597682284358ff2c54530757b654b8d1
+        GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de
         GIT_PROGRESS TRUE
         CONFIGURE_COMMAND ""
         BUILD_COMMAND ""
@@ -37,13 +37,14 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
 if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
     set(FlashMLA_SOURCES
         ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
-        ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
+        ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu
         ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
-        ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu)
+        ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
+        ${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu)

     set(FlashMLA_INCLUDES
         ${flashmla_SOURCE_DIR}/csrc/cutlass/include
-        ${flashmla_SOURCE_DIR}/csrc/include)
+        ${flashmla_SOURCE_DIR}/csrc)

     set_gencode_flags_for_srcs(
         SRCS "${FlashMLA_SOURCES}"
@@ -40,9 +40,11 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
 void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                  const double scale, const std::string& kv_cache_dtype);

-void gather_cache(
+void gather_and_maybe_dequant_cache(
     torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
     torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
     torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
     torch::Tensor const& cu_seq_lens,  // [BATCH+1]
-    int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
+    int64_t batch_size, const std::string& kv_cache_dtype,
+    torch::Tensor const& scale,
+    std::optional<torch::Tensor> seq_starts = std::nullopt);
@@ -624,9 +624,9 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
 namespace vllm {

 // grid is launched with dimensions (batch, num_splits)
-template <typename scalar_t>
-__global__ void gather_cache(
-    const scalar_t* __restrict__ src_cache,   // [NUM_BLOCKS, BLOCK_SIZE,
+template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
+__global__ void gather_and_maybe_dequant_cache(
+    const cache_t* __restrict__ src_cache,    // [NUM_BLOCKS, BLOCK_SIZE,
                                               // ENTRIES...]
     scalar_t* __restrict__ dst,               // [TOT_TOKENS, ENTRIES...]
     const int32_t* __restrict__ block_table,  // [BATCH, BLOCK_INDICES]
@@ -634,6 +634,7 @@ __global__ void gather_cache(
     const int32_t block_size, const int32_t entry_size,
     const int64_t block_table_stride, const int64_t cache_block_stride,
     const int64_t cache_entry_stride, const int64_t dst_entry_stride,
+    const float* __restrict__ scale,
    const int32_t* __restrict__ seq_starts) {  // Optional: starting offsets per
                                               // batch

@@ -675,10 +676,16 @@ __global__ void gather_cache(
     if (partial_block_size) full_blocks_end -= 1;
   }

-  auto copy_entry = [&](const scalar_t* __restrict__ _src,
+  auto copy_entry = [&](const cache_t* __restrict__ _src,
                         scalar_t* __restrict__ _dst) {
-    for (int i = threadIdx.x; i < entry_size; i += blockDim.x)
-      _dst[i] = _src[i];
+    for (int i = threadIdx.x; i < entry_size; i += blockDim.x) {
+      if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+        _dst[i] = static_cast<scalar_t>(_src[i]);
+      } else {
+        _dst[i] =
+            fp8::scaled_convert<scalar_t, cache_t, kv_dt>(_src[i], *scale);
+      }
+    }
   };

   for (int pid = split_start; pid < full_blocks_end; ++pid) {
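The copy lambda now branches on the cache dtype: the kAuto path is a plain element-wise cast, while the fp8 path dequantizes each element with fp8::scaled_convert using the single per-tensor scale. A rough Python sketch of the per-element semantics, assuming scaled_convert(x, scale) amounts to float(x) * scale (illustrative only; the authoritative code is the CUDA above):

    import torch

    def gather_and_maybe_dequant_block(src_block: torch.Tensor,
                                       scale: float,
                                       out_dtype: torch.dtype,
                                       is_fp8: bool) -> torch.Tensor:
        # kAuto path: a plain element-wise cast, no scaling.
        if not is_fp8:
            return src_block.to(out_dtype)
        # fp8 path: dequantize each element as float(x) * scale.
        return (src_block.to(torch.float32) * scale).to(out_dtype)

    block = torch.randn(16, 576).to(torch.float8_e4m3fn)
    out = gather_and_maybe_dequant_block(block, 0.1, torch.bfloat16, True)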
@@ -705,25 +712,31 @@ __global__ void gather_cache(
 }  // namespace vllm

 // Macro to dispatch the kernel based on the data type.
-#define CALL_GATHER_CACHE(CPY_DTYPE)                                     \
-  vllm::gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>(             \
-      reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()),                \
-      reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()),                      \
-      block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(),  \
-      block_size, entry_size, block_table_stride, cache_block_stride,    \
-      cache_entry_stride, dst_entry_stride, seq_starts_ptr);
+// SCALAR_T is the data type of the destination tensor.
+// CACHE_T is the stored data type of kv-cache.
+// KV_DTYPE is the real data type of kv-cache.
+#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE)                      \
+  vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE>         \
+      <<<grid, block, 0, stream>>>(                                         \
+          reinterpret_cast<CACHE_T*>(src_cache.data_ptr()),                 \
+          reinterpret_cast<SCALAR_T*>(dst.data_ptr()),                      \
+          block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
+          block_size, entry_size, block_table_stride, cache_block_stride,   \
+          cache_entry_stride, dst_entry_stride,                             \
+          reinterpret_cast<const float*>(scale.data_ptr()), seq_starts_ptr);

 // Gather sequences from the cache into the destination tensor.
 // - cu_seq_lens contains the cumulative sequence lengths for each batch
 // - block_table contains the cache block indices for each sequence
 // - Optionally, seq_starts (if provided) offsets the starting block index by
 //   (seq_starts[bid] / page_size)
-void gather_cache(
+void gather_and_maybe_dequant_cache(
     torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
     torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
     torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
     torch::Tensor const& cu_seq_lens,  // [BATCH+1]
-    int64_t batch_size,
+    int64_t batch_size, const std::string& kv_cache_dtype,
+    torch::Tensor const& scale,
     std::optional<torch::Tensor> seq_starts = std::nullopt) {
   at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -761,20 +774,8 @@ void gather_cache(
   dim3 grid(batch_size, num_splits);
   dim3 block(1024);

-  TORCH_CHECK(src_cache.dtype() == dst.dtype(),
-              "src_cache and dst must have the same dtype");
-
-  const int dtype_bits = src_cache.element_size() * 8;
   const int32_t* seq_starts_ptr =
       seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;

-  if (dtype_bits == 32) {
-    CALL_GATHER_CACHE(uint32_t);
-  } else if (dtype_bits == 16) {
-    CALL_GATHER_CACHE(uint16_t);
-  } else if (dtype_bits == 8) {
-    CALL_GATHER_CACHE(uint8_t);
-  } else {
-    TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
-  }
+  DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE);
 }
@@ -672,11 +672,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       "str kv_cache_dtype) -> ()");
   cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);

-  // Gather cache blocks from src_cache to dst.
+  // Gather cache blocks from src_cache to dst, dequantizing from
+  // src_cache's dtype to dst's dtype if necessary.
   cache_ops.def(
-      "gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
-      "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
-  cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache);
+      "gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, "
+      " Tensor block_table, Tensor cu_seq_lens, "
+      " int batch_size, "
+      " str kv_cache_dtype, "
+      " Tensor scale, Tensor? seq_starts) -> ()");
+  cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA,
+                 &gather_and_maybe_dequant_cache);
 }

 TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
@@ -709,14 +709,15 @@ def test_swap_blocks_mla(
 @pytest.mark.parametrize("max_seq_len", [512])
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("dtype", [torch.float32])
-@pytest.mark.parametrize("kv_cache_dtype",
-                         ["auto"])  # You can also test "fp8" if needed.
+@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
-                          num_blocks, max_seq_len, batch_size, dtype,
-                          kv_cache_dtype, device):
+def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
+                                            block_size, num_blocks,
+                                            max_seq_len, batch_size, dtype,
+                                            kv_cache_dtype, device):
     entry_size = kv_lora_rank + qk_rope_head_dim
+    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
     src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                   kv_cache_dtype, device)
     _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
@@ -742,9 +743,7 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
         perm = torch.randperm(num_blocks, device=device)
         block_table[b, :] = perm

-    dst = torch.zeros((total_tokens, entry_size),
-                      dtype=src_cache.dtype,
-                      device=device)
+    dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device)

     expected_batches = []
     for b in range(batch_size):
@@ -756,21 +755,38 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,

         gathered_rows = []
         for i in range(tot - 1):
-            gathered_rows.append(src_cache[blocks[i]])
+            block_data = src_cache[blocks[i]]
+            if kv_cache_dtype == "fp8":
+                dequantized_block = torch.empty_like(block_data, dtype=dtype)
+                ops.convert_fp8(dequantized_block, block_data, scale.item())
+                gathered_rows.append(dequantized_block)
+            else:
+                gathered_rows.append(block_data)
         remaining = s - (tot - 1) * block_size
-        gathered_rows.append(src_cache[blocks[-1], :remaining, :])
+        last_block_data = src_cache[blocks[-1], :remaining, :]
+        if kv_cache_dtype == "fp8":
+            dequantized_last_block = torch.empty_like(last_block_data,
+                                                      dtype=dtype)
+            ops.convert_fp8(dequantized_last_block, last_block_data,
+                            scale.item())
+            gathered_rows.append(dequantized_last_block)
+        else:
+            gathered_rows.append(last_block_data)

         batch_expected = torch.cat(gathered_rows, dim=0)
         expected_batches.append(batch_expected)
     expected = torch.cat(expected_batches, dim=0)

     opcheck(
-        torch.ops._C_cache_ops.gather_cache,
-        (src_cache, dst, block_table, cu_seq_lens, batch_size, None),
+        torch.ops._C_cache_ops.gather_and_maybe_dequant_cache,
+        (src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype,
+         scale, None),
         test_utils=DEFAULT_OPCHECK_TEST_UTILS,
     )

-    ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
+    ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table,
+                                       cu_seq_lens, batch_size, kv_cache_dtype,
+                                       scale, None)
     torch.testing.assert_close(dst, expected)
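The expected tensor in this test is built by dequantizing each fp8 cache block on the host with ops.convert_fp8. A small round-trip sketch of that helper under the same per-tensor-scale convention (values, shapes and the printed error are illustrative; requires a CUDA build of vLLM):

    import torch
    from vllm import _custom_ops as ops

    x = torch.randn(16, 64, dtype=torch.float32, device="cuda")
    scale = 0.1

    # Quantize into an fp8 cache buffer (stored as uint8): roughly x / scale.
    x_fp8 = torch.empty_like(x, dtype=torch.uint8)
    ops.convert_fp8(x_fp8, x, scale, kv_dtype="fp8")

    # Dequantize back: roughly float(x_fp8) * scale; the round trip is lossy.
    x_deq = torch.empty_like(x)
    ops.convert_fp8(x_deq, x_fp8, scale, kv_dtype="fp8")
    print((x_deq - x).abs().max())  # small, but not zero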
@@ -13,11 +13,17 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
 from vllm.triton_utils import triton


-def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
+def cal_diff(x: torch.Tensor,
+             y: torch.Tensor,
+             name: str,
+             use_fp8: bool = False) -> None:
     x, y = x.double(), y.double()
     cos_diff = 1 - 2 * (x * y).sum().item() / max(
         (x * x + y * y).sum().item(), 1e-12)
-    assert cos_diff < 1e-5
+    if (use_fp8):
+        assert cos_diff < 1e-4
+    else:
+        assert cos_diff < 1e-5


 FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
     if not is_flashmla_supported()[0] else "FlashMLA is supported"
@@ -27,7 +33,7 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
                     reason=FLASH_MLA_UNSUPPORTED_REASON)
 @pytest.mark.parametrize("b", [128])
 @pytest.mark.parametrize("s_q", [1, 2])
-@pytest.mark.parametrize("mean_sk", [4096, 8192])
+@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384])
 @pytest.mark.parametrize("h_q", [16, 32, 64, 128])
 @pytest.mark.parametrize("h_kv", [1])
 @pytest.mark.parametrize("d", [576])
@@ -35,20 +41,26 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
 @pytest.mark.parametrize("block_size", [64])
 @pytest.mark.parametrize("causal", [True])
 @pytest.mark.parametrize("varlen", [False, True])
-@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("torch_dtype",
+                         [torch.bfloat16, torch.float16, torch.float8_e4m3fn])
 @torch.inference_mode()
 def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
-                   varlen, dtype):
+                   varlen, torch_dtype):
     device = torch.device("cuda:0")
-    torch.set_default_dtype(dtype)
+    if torch_dtype == torch.float8_e4m3fn:
+        init_dtype = torch.bfloat16
+    else:
+        init_dtype = torch_dtype
+    torch.set_default_dtype(init_dtype)
     torch.set_default_device(device)
     torch.cuda.set_device(device)
     torch.manual_seed(0)
     random.seed(0)

     print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
-          f"{d=}, {dv=}, {causal=}, {varlen=}, {dtype=}")
+          f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}")

+    use_fp8 = torch_dtype == torch.float8_e4m3fn
     cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
     if varlen:
         for i in range(b):
@@ -71,6 +83,19 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
     tile_scheduler_metadata, num_splits = get_mla_metadata(
         cache_seqlens, s_q * h_q // h_kv, h_kv)

+    init_dtype = q.dtype
+    if use_fp8:
+        fp8_dtype = torch.float8_e4m3fn
+        descale_q = torch.ones((1), dtype=torch.float32)
+        descale_k = torch.ones((1), dtype=torch.float32)
+
+        q = q.to(fp8_dtype)
+        blocked_k = blocked_k.to(fp8_dtype)
+        blocked_v = blocked_v.to(fp8_dtype)
+    else:
+        descale_q = None
+        descale_k = None
+
     def flash_mla():
         return flash_mla_with_kvcache(
             q,
@@ -81,6 +106,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
             tile_scheduler_metadata,
             num_splits,
             causal=causal,
+            descale_q=descale_q,
+            descale_k=descale_k,
         )

     def scaled_dot_product_attention(query, key, value, is_causal=False):
@@ -104,29 +131,35 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
         return attn_weight @ value, lse

     def ref_mla():
+        q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
+        blocked_k_ = (blocked_k.to(torch.float) *
+                      descale_k).to(init_dtype) if use_fp8 else blocked_k
+        blocked_v_ = (blocked_v.to(torch.float) *
+                      descale_k).to(init_dtype) if use_fp8 else blocked_v
         out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
         lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
         for i in range(b):
             begin = i * max_seqlen_pad
             end = begin + cache_seqlens[i]
-            ref_O, LSE = scaled_dot_product_attention(
-                q[i].transpose(0, 1),
-                blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
-                blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
+            out_i, lse_i = scaled_dot_product_attention(
+                q_[i].transpose(0, 1),
+                blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1),
+                blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
                 is_causal=causal,
             )
-            out[i] = ref_O.transpose(0, 1)
-            lse[i] = LSE
+            out[i] = out_i.transpose(0, 1)
+            lse[i] = lse_i
         return out, lse

     out_flash, lse_flash = flash_mla()
     out_torch, lse_torch = ref_mla()
-    cal_diff(out_flash, out_torch, "out")
+    cal_diff(out_flash, out_torch, "out", use_fp8)
     cal_diff(lse_flash, lse_torch, "lse")

     t = triton.testing.do_bench(flash_mla)
     FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
-    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d +
-             b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
-    print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} "
-          f"TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s")
+    bytes = (total_seqlens * h_kv * d +
+             b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (
+                 b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
+    print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,",
+          f"{bytes / 10 ** 6 / t:.0f} GB/s")
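In the fp8 case the reference path dequantizes the quantized inputs with the same descale factors that are handed to the kernel, i.e. for a per-tensor scheme x ≈ x_fp8.float() * descale (the test simply uses descale = 1.0). A standalone sketch of that relationship, purely illustrative:

    import torch

    x = torch.randn(4, 8, dtype=torch.bfloat16, device="cuda")

    # Pick a descale so x / descale fits the e4m3 range, then store fp8.
    descale = x.abs().amax().float() / torch.finfo(torch.float8_e4m3fn).max
    x_fp8 = (x.float() / descale).to(torch.float8_e4m3fn)

    # Dequantized reference input: recovers x only approximately.
    x_ref = (x_fp8.float() * descale).to(torch.bfloat16)
    print((x_ref - x).abs().max())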
@@ -1589,14 +1589,18 @@ def convert_fp8(output: torch.Tensor,
     torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)


-def gather_cache(src_cache: torch.Tensor,
-                 dst: torch.Tensor,
-                 block_table: torch.Tensor,
-                 cu_seq_lens: torch.Tensor,
-                 batch_size: int,
-                 seq_starts: Optional[torch.Tensor] = None) -> None:
-    torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
-                                        cu_seq_lens, batch_size, seq_starts)
+def gather_and_maybe_dequant_cache(
+        src_cache: torch.Tensor,
+        dst: torch.Tensor,
+        block_table: torch.Tensor,
+        cu_seq_lens: torch.Tensor,
+        batch_size: int,
+        kv_cache_dtype: str,
+        scale: torch.Tensor,
+        seq_starts: Optional[torch.Tensor] = None) -> None:
+    torch.ops._C_cache_ops.gather_and_maybe_dequant_cache(
+        src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype,
+        scale, seq_starts)


 def get_device_attribute(attribute: int, device: int) -> int:
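A minimal usage sketch of the new wrapper, mirroring how the MLA backends call it; the tensor layouts follow the comments in csrc/cache.h, while the concrete sizes and the "auto" cache dtype here are illustrative only:

    import torch
    from vllm import _custom_ops as ops

    num_blocks, block_size, entry_size = 32, 16, 576
    batch_size, tot_tokens = 2, 48

    src_cache = torch.randn(num_blocks, block_size, entry_size,
                            dtype=torch.bfloat16, device="cuda")
    dst = torch.empty(tot_tokens, entry_size,
                      dtype=torch.bfloat16, device="cuda")
    block_table = torch.arange(num_blocks, dtype=torch.int32,
                               device="cuda").repeat(batch_size, 1)
    cu_seq_lens = torch.tensor([0, 24, 48], dtype=torch.int32, device="cuda")
    scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")

    # kv_cache_dtype="auto" copies without dequantization; with an fp8 cache,
    # passing "fp8" dequantizes into dst using `scale`.
    ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table,
                                       cu_seq_lens, batch_size, "auto",
                                       scale, None)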
@@ -837,8 +837,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
                 self.context_chunk_workspace_size // num_prefills_with_context

             # align max_context_chunk to page_size by rounding down,
-            # currently the `gather_cache` kernel cannot handle
-            # `context_chunk_starts` that are not aligned to page_size
+            # currently the `gather_and_maybe_dequant_cache` kernel cannot
+            # handle `context_chunk_starts` that are not aligned to page_size
             max_context_chunk = round_down(max_context_chunk, self.page_size)
             assert max_context_chunk > 0
             num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk)
@@ -1082,6 +1082,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         q: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
+        k_scale: torch.Tensor,
     ):
         prefill_metadata = attn_metadata.prefill_metadata
         assert prefill_metadata is not None
@@ -1103,12 +1104,14 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         for i in range(iters):
             toks = prefill_metadata.context_chunk_seq_tot[i]

-            ops.gather_cache(
+            ops.gather_and_maybe_dequant_cache(
                 src_cache=kv_c_and_k_pe_cache,
                 dst=workspace,
                 block_table=prefill_metadata.block_tables,
                 cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i],
                 batch_size=prefill_metadata.num_prefills,
+                kv_cache_dtype=self.kv_cache_dtype,
+                scale=k_scale,
                 seq_starts=prefill_metadata.context_chunk_starts[i],
             )

@@ -1165,6 +1168,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         k_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
+        k_scale: torch.Tensor,
     ) -> torch.Tensor:

         prefill_metadata = attn_metadata.prefill_metadata
@@ -1197,7 +1201,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             # ROCm flash_attn_varlen_func will return 3 objects instead of 2
             suffix_output, suffix_lse = output
             context_output, context_lse = self._compute_prefill_context( \
-                q, kv_c_and_k_pe_cache, attn_metadata)
+                q, kv_c_and_k_pe_cache, attn_metadata, k_scale)

             output = torch.empty_like(suffix_output)
             merge_attn_states(
@@ -1287,7 +1291,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         if has_prefill:
             output[:num_prefill_tokens] = self._forward_prefill(
                 prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
-                attn_metadata)
+                attn_metadata, layer._k_scale)

         if has_decode:
             decode_q_nope, decode_q_pe = decode_q.split(
@@ -67,6 +67,8 @@ def flash_mla_with_kvcache(
     num_splits: torch.Tensor,
     softmax_scale: Optional[float] = None,
     causal: bool = False,
+    descale_q: Optional[torch.Tensor] = None,
+    descale_k: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Arguments:
@@ -81,6 +83,8 @@ def flash_mla_with_kvcache(
         softmax_scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(head_dim).
         causal: bool. Whether to apply causal attention mask.
+        descale_q: (batch_size), torch.float32. Descaling factors for Q.
+        descale_k: (batch_size), torch.float32. Descaling factors for K.

     Return:
         out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
@@ -98,6 +102,8 @@ def flash_mla_with_kvcache(
         causal,
         tile_scheduler_metadata,
         num_splits,
+        descale_q,
+        descale_k,
     )
     return out, softmax_lse

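A sketch of an FP8 decode call wired through the new arguments, following the pattern in the test above (requires a Hopper GPU with FlashMLA available; the sizes match the test's parametrization and the unit descale factors are illustrative):

    import torch
    from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                             get_mla_metadata)

    b, s_q, h_q, h_kv, d, dv, block_size = 2, 1, 16, 1, 576, 512, 64
    cache_seqlens = torch.full((b, ), 256, dtype=torch.int32, device="cuda")
    num_blocks = b * 256 // block_size

    q = torch.randn(b, s_q, h_q, d, dtype=torch.bfloat16, device="cuda")
    blocked_k = torch.randn(num_blocks, block_size, h_kv, d,
                            dtype=torch.bfloat16, device="cuda")
    block_table = torch.arange(num_blocks, dtype=torch.int32,
                               device="cuda").view(b, -1)

    tile_scheduler_metadata, num_splits = get_mla_metadata(
        cache_seqlens, s_q * h_q // h_kv, h_kv)

    # FP8 path: cast Q and the blocked KV cache to float8_e4m3fn and pass
    # per-batch descale factors (1.0 here, as in the test above).
    q_fp8 = q.to(torch.float8_e4m3fn)
    k_fp8 = blocked_k.to(torch.float8_e4m3fn)
    descale = torch.ones(1, dtype=torch.float32, device="cuda")

    out, lse = flash_mla_with_kvcache(q_fp8, k_fp8, block_table, cache_seqlens,
                                      dv, tile_scheduler_metadata, num_splits,
                                      causal=True, descale_q=descale,
                                      descale_k=descale)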
@@ -1445,10 +1445,9 @@ class EngineArgs:
                                recommend_to_remove=False)
             return False

-        # No Fp8 KV cache so far.
         if self.kv_cache_dtype != "auto":
             supported = current_platform.is_kv_cache_dtype_supported(
-                self.kv_cache_dtype)
+                self.kv_cache_dtype, model_config)
             if not supported:
                 _raise_or_fallback(feature_name="--kv-cache-dtype",
                                    recommend_to_remove=False)
@@ -481,16 +481,41 @@ class CudaPlatformBase(Platform):
         return cuda_device_count_stateless()

     @classmethod
-    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
+                                    model_config: "ModelConfig") -> bool:
         fp8_attention = kv_cache_dtype.startswith("fp8")
-        will_use_fa = (not envs.is_set("VLLM_ATTENTION_BACKEND")
-                       ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
+        attention_backend = envs.VLLM_ATTENTION_BACKEND

         supported = False
-        if cls.is_device_capability(100):
-            supported = True
-        elif fp8_attention and will_use_fa:
-            from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
-            supported = flash_attn_supports_fp8()
+        if model_config is not None and model_config.use_mla:
+            # Default to CutlassMLA for blackwell,
+            # FlashMLA otherwise
+            if attention_backend is None:
+                if cls.is_device_capability(100):
+                    attention_backend = "CUTLASS_MLA"
+                else:
+                    attention_backend = "FLASHMLA"
+
+            # Only FlashMLA supports fp8
+            if attention_backend == "FLASHMLA":
+                supported = True
+            else:
+                supported = (not fp8_attention)
+        else:
+            # Default to FlashAttention
+            if attention_backend is None:
+                attention_backend = "FLASH_ATTN_VLLM_V1"
+
+            # All Blackwell backends support fp8
+            if cls.is_device_capability(100):
+                supported = True
+            elif attention_backend == "FLASH_ATTN_VLLM_V1":
+                if fp8_attention:
+                    from vllm.attention.utils.fa_utils import (
+                        flash_attn_supports_fp8)
+                    supported = flash_attn_supports_fp8()
+                else:
+                    supported = True
         return supported

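The support check is now backend-aware: for MLA models only the FlashMLA backend accepts an fp8 KV cache, while non-MLA models keep the FlashAttention/Blackwell rules. A hedged illustration of the call site that arg_utils now uses (passing None for model_config exercises the non-MLA branch; needs a CUDA build):

    from vllm.platforms import current_platform

    # With no model config (treated as the non-MLA path here), fp8 support
    # reduces to FlashAttention's fp8 support or a Blackwell-class GPU.
    print(current_platform.is_kv_cache_dtype_supported("fp8", None))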
@@ -565,7 +565,8 @@ class Platform:
         raise RuntimeError(f"Unsupported torch distributed backend: {backend}")

     @classmethod
-    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
+                                    model_config: "ModelConfig") -> bool:
         """
         Returns if the kv_cache_dtype is supported by the current platform.
         """
@@ -459,5 +459,6 @@ class RocmPlatform(Platform):
         return cuda_device_count_stateless()

     @classmethod
-    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
+                                    model_config: "ModelConfig") -> bool:
         return True
@@ -196,7 +196,8 @@ class TpuPlatform(Platform):
             raise ValueError("Torch XLA does not support per-request seed.")

     @classmethod
-    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
+                                    model_config: "ModelConfig") -> bool:
         return True

@@ -631,8 +631,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):

             if self.aot_schedule:
                 # align max_context_chunk to page_size by rounding down,
-                # currently the `gather_cache` kernel cannot handle
-                # `context_chunk_starts` that are not aligned to page_size
+                # currently the `gather_and_maybe_dequant_cache` kernel
+                # cannot handle `context_chunk_starts` that are not aligned
+                # to page_size
                 max_context_chunk = round_down(max_context_chunk,
                                                self.page_size)

@@ -1005,6 +1006,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         q: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
+        k_scale: torch.Tensor,
     ):
         assert attn_metadata.prefill is not None
         prefill_metadata = attn_metadata.prefill
@@ -1017,12 +1019,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]

-            ops.gather_cache(
+            ops.gather_and_maybe_dequant_cache(
                 src_cache=kv_c_and_k_pe_cache,
                 dst=workspace,
                 block_table=prefill_metadata.block_table,
                 cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
                 batch_size=attn_metadata.num_prefills,
+                kv_cache_dtype=self.kv_cache_dtype,
+                scale=k_scale,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
             )

@@ -1073,6 +1077,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         k_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
+        k_scale: torch.Tensor,
     ) -> torch.Tensor:
         assert attn_metadata.prefill is not None

@@ -1095,7 +1100,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         if has_context:
             suffix_output, suffix_lse = output
             context_output, context_lse = self._compute_prefill_context( \
-                q, kv_c_and_k_pe_cache, attn_metadata)
+                q, kv_c_and_k_pe_cache, attn_metadata, k_scale)

             output = torch.empty_like(suffix_output)
             merge_attn_states(
@@ -1119,6 +1124,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         q_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: M,
+        layer: AttentionLayer,
     ) -> torch.Tensor:
         raise NotImplementedError

@@ -1146,6 +1152,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             # same expert outputs.
             return output.fill_(0)

+        fp8_attention = self.kv_cache_dtype.startswith("fp8")
+
         num_actual_toks = attn_metadata.num_actual_tokens

         # Inputs and outputs may be padded for CUDA graphs
@@ -1180,10 +1188,13 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
                 scale=layer._k_scale,
             )

+        if fp8_attention:
+            kv_cache = kv_cache.view(current_platform.fp8_dtype())
+
         if has_prefill:
             output[num_decode_tokens:] = self._forward_prefill(
                 prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
-                attn_metadata)
+                attn_metadata, layer._k_scale)

         if has_decode:
             assert attn_metadata.decode is not None
@@ -1196,7 +1207,21 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             # Convert from (N, B, L) to (B, N, L)
             decode_ql_nope = decode_ql_nope.transpose(0, 1)

+            if fp8_attention:
+                ql_nope_shape = decode_ql_nope.shape
+                decode_ql_nope, _ = ops.scaled_fp8_quant(
+                    decode_ql_nope.reshape([
+                        ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2]
+                    ]), layer._q_scale)
+                decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape)
+                q_pe_shape = decode_q_pe.shape
+                decode_q_pe, _ = ops.scaled_fp8_quant(
+                    decode_q_pe.reshape(
+                        [q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]),
+                    layer._q_scale)
+                decode_q_pe = decode_q_pe.reshape(q_pe_shape)
+
             output[:num_decode_tokens] = self._forward_decode(
-                decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
+                decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer)

         return output_padded
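Decode-side queries are quantized with the layer's static q-scale before being handed to FlashMLA; ops.scaled_fp8_quant expects a 2D (num_tokens, hidden) input, hence the flatten/restore around it. A standalone sketch of that pattern (shapes and the unit scale are illustrative):

    import torch
    from vllm import _custom_ops as ops

    B, N, L = 4, 16, 512  # tokens, heads, head dim
    q = torch.randn(B, N, L, dtype=torch.bfloat16, device="cuda")
    q_scale = torch.ones(1, dtype=torch.float32, device="cuda")

    # Flatten the head dimensions, quantize with the static scale,
    # then restore the original (B, N, L) shape in float8_e4m3fn.
    q_fp8, _ = ops.scaled_fp8_quant(q.reshape(B, N * L), q_scale)
    q_fp8 = q_fp8.reshape(B, N, L)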
@@ -7,7 +7,7 @@ from typing import ClassVar, Optional
 import torch

 import vllm._custom_ops as ops
-from vllm.attention.backends.abstract import (AttentionType,
+from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
                                               is_quantized_kv_cache)
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
@@ -278,6 +278,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
         q_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
+        layer: AttentionLayer,
     ) -> torch.Tensor:
         if self._use_old_cutlass_mla:
             # TODO: Remove the old cutlass MLA kernel after more extensive
@@ -6,8 +6,7 @@ from typing import ClassVar, Optional

 import torch

-from vllm.attention.backends.abstract import (AttentionType,
-                                              is_quantized_kv_cache)
+from vllm.attention.backends.abstract import AttentionLayer, AttentionType
 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)
@@ -166,16 +165,13 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
                                       "are not implemented for "
                                       "FlashMLAImpl")

-        if is_quantized_kv_cache(self.kv_cache_dtype):
-            raise NotImplementedError(
-                "FlashMLA V1 with FP8 KV cache not yet supported")
-
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
         q_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: FlashMLAMetadata,
+        layer: AttentionLayer,
     ) -> torch.Tensor:
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None
@@ -194,6 +190,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
             num_splits=attn_metadata.decode.num_splits,
             softmax_scale=self.scale,
             causal=True,
+            descale_q=layer._q_scale.reshape(1),
+            descale_k=layer._k_scale.reshape(1),
         )

         return self._v_up_proj(o)
@@ -7,6 +7,7 @@ from typing import ClassVar, Optional
 import torch

 import vllm.envs as envs
+from vllm.attention.backends.abstract import AttentionLayer
 from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
 from vllm.config import VllmConfig
 from vllm.utils import cdiv
@@ -221,6 +222,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
         q_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: AiterMLAMetadata,
+        layer: AttentionLayer,
     ) -> torch.Tensor:
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None
@@ -6,7 +6,7 @@ from typing import Optional
 import torch

 from vllm import envs
-from vllm.attention.backends.abstract import (AttentionType,
+from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
                                               is_quantized_kv_cache)
 from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
 from vllm.attention.ops.triton_flash_attention import triton_attention
@@ -127,6 +127,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
         q_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
+        layer: AttentionLayer,
     ) -> torch.Tensor:
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None