[Kernel] Add FP8 support with FlashMLA backend (#22668)
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
@@ -19,7 +19,7 @@ else()
   FetchContent_Declare(
         flashmla
         GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
-        GIT_TAG 0e43e774597682284358ff2c54530757b654b8d1
+        GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de
         GIT_PROGRESS TRUE
         CONFIGURE_COMMAND ""
         BUILD_COMMAND ""
@@ -37,13 +37,14 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
 if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
     set(FlashMLA_SOURCES
         ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
-        ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
+        ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu
         ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
-        ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu)
+        ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
+        ${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu)

     set(FlashMLA_INCLUDES
         ${flashmla_SOURCE_DIR}/csrc/cutlass/include
-        ${flashmla_SOURCE_DIR}/csrc/include)
+        ${flashmla_SOURCE_DIR}/csrc)

     set_gencode_flags_for_srcs(
         SRCS "${FlashMLA_SOURCES}"
@@ -40,9 +40,11 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
 void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                  const double scale, const std::string& kv_cache_dtype);

-void gather_cache(
+void gather_and_maybe_dequant_cache(
     torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
     torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
     torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
     torch::Tensor const& cu_seq_lens,  // [BATCH+1]
-    int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
+    int64_t batch_size, const std::string& kv_cache_dtype,
+    torch::Tensor const& scale,
+    std::optional<torch::Tensor> seq_starts = std::nullopt);
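The two new arguments mirror the other cache ops: `kv_cache_dtype` names the stored dtype of the cache ("auto" means no conversion) and `scale` is the per-layer dequantization factor. As a rough Python reference of the intended semantics (a sketch only; the helper name, the shapes, and the assumption that fp8 entries are stored as bytes and dequantized as `float(x) * scale` are ours, not part of the PR):

import torch

def gather_and_maybe_dequant_cache_ref(src_cache, dst, block_table, cu_seq_lens,
                                        batch_size, kv_cache_dtype, scale,
                                        seq_starts=None):
    """Per-batch gather of cache blocks with optional fp8 dequantization."""
    block_size = src_cache.shape[1]
    for b in range(batch_size):
        start, end = int(cu_seq_lens[b]), int(cu_seq_lens[b + 1])
        seq_len = end - start
        # seq_starts (if given) skips whole blocks at the front of the sequence.
        offset = int(seq_starts[b]) // block_size if seq_starts is not None else 0
        tot_blocks = (seq_len + block_size - 1) // block_size
        rows = []
        for i in range(tot_blocks):
            block = src_cache[block_table[b, offset + i]]
            if kv_cache_dtype != "auto":
                # Assumed fp8 layout: bytes reinterpreted as e4m3, then scaled.
                block = block.view(torch.float8_e4m3fn).float() * scale
            rows.append(block.to(dst.dtype))
        dst[start:end] = torch.cat(rows, dim=0)[:seq_len]
    return dst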
@@ -624,9 +624,9 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
 namespace vllm {

 // grid is launched with dimensions (batch, num_splits)
-template <typename scalar_t>
-__global__ void gather_cache(
-    const scalar_t* __restrict__ src_cache,   // [NUM_BLOCKS, BLOCK_SIZE,
+template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
+__global__ void gather_and_maybe_dequant_cache(
+    const cache_t* __restrict__ src_cache,    // [NUM_BLOCKS, BLOCK_SIZE,
                                               // ENTRIES...]
     scalar_t* __restrict__ dst,               // [TOT_TOKENS, ENTRIES...]
     const int32_t* __restrict__ block_table,  // [BATCH, BLOCK_INDICES]
@@ -634,6 +634,7 @@ __global__ void gather_cache(
     const int32_t block_size, const int32_t entry_size,
     const int64_t block_table_stride, const int64_t cache_block_stride,
     const int64_t cache_entry_stride, const int64_t dst_entry_stride,
+    const float* __restrict__ scale,
     const int32_t* __restrict__ seq_starts) {  // Optional: starting offsets per
                                                // batch

@@ -675,10 +676,16 @@ __global__ void gather_cache(
     if (partial_block_size) full_blocks_end -= 1;
   }

-  auto copy_entry = [&](const scalar_t* __restrict__ _src,
+  auto copy_entry = [&](const cache_t* __restrict__ _src,
                         scalar_t* __restrict__ _dst) {
-    for (int i = threadIdx.x; i < entry_size; i += blockDim.x)
-      _dst[i] = _src[i];
+    for (int i = threadIdx.x; i < entry_size; i += blockDim.x) {
+      if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+        _dst[i] = static_cast<scalar_t>(_src[i]);
+      } else {
+        _dst[i] =
+            fp8::scaled_convert<scalar_t, cache_t, kv_dt>(_src[i], *scale);
+      }
+    }
   };

   for (int pid = split_start; pid < full_blocks_end; ++pid) {
@@ -705,25 +712,31 @@ __global__ void gather_cache(
 }  // namespace vllm

 // Macro to dispatch the kernel based on the data type.
-#define CALL_GATHER_CACHE(CPY_DTYPE) \
-  vllm::gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>( \
-      reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()), \
-      reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()), \
-      block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
-      block_size, entry_size, block_table_stride, cache_block_stride, \
-      cache_entry_stride, dst_entry_stride, seq_starts_ptr);
+// SCALAR_T is the data type of the destination tensor.
+// CACHE_T is the stored data type of kv-cache.
+// KV_DTYPE is the real data type of kv-cache.
+#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \
+  vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE> \
+      <<<grid, block, 0, stream>>>( \
+          reinterpret_cast<CACHE_T*>(src_cache.data_ptr()), \
+          reinterpret_cast<SCALAR_T*>(dst.data_ptr()), \
+          block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
+          block_size, entry_size, block_table_stride, cache_block_stride, \
+          cache_entry_stride, dst_entry_stride, \
+          reinterpret_cast<const float*>(scale.data_ptr()), seq_starts_ptr);

 // Gather sequences from the cache into the destination tensor.
 // - cu_seq_lens contains the cumulative sequence lengths for each batch
 // - block_table contains the cache block indices for each sequence
 // - Optionally, seq_starts (if provided) offsets the starting block index by
 //   (seq_starts[bid] / page_size)
-void gather_cache(
+void gather_and_maybe_dequant_cache(
     torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
     torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
     torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
     torch::Tensor const& cu_seq_lens,  // [BATCH+1]
-    int64_t batch_size,
+    int64_t batch_size, const std::string& kv_cache_dtype,
+    torch::Tensor const& scale,
     std::optional<torch::Tensor> seq_starts = std::nullopt) {
   at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -761,20 +774,8 @@ void gather_cache(
   dim3 grid(batch_size, num_splits);
   dim3 block(1024);

-  TORCH_CHECK(src_cache.dtype() == dst.dtype(),
-              "src_cache and dst must have the same dtype");
-
-  const int dtype_bits = src_cache.element_size() * 8;
   const int32_t* seq_starts_ptr =
       seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;

-  if (dtype_bits == 32) {
-    CALL_GATHER_CACHE(uint32_t);
-  } else if (dtype_bits == 16) {
-    CALL_GATHER_CACHE(uint16_t);
-  } else if (dtype_bits == 8) {
-    CALL_GATHER_CACHE(uint8_t);
-  } else {
-    TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
-  }
+  DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE);
 }
@@ -672,11 +672,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       "str kv_cache_dtype) -> ()");
   cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);

-  // Gather cache blocks from src_cache to dst.
+  // Gather cache blocks from src_cache to dst, dequantizing from
+  // src_cache's dtype to dst's dtype if necessary.
   cache_ops.def(
-      "gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
-      "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
-  cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache);
+      "gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, "
+      "    Tensor block_table, Tensor cu_seq_lens, "
+      "    int batch_size, "
+      "    str kv_cache_dtype, "
+      "    Tensor scale, Tensor? seq_starts) -> ()");
+  cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA,
+                 &gather_and_maybe_dequant_cache);
 }

 TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
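With this schema the op is reachable directly through `torch.ops`, which is how the updated opcheck test later in this diff exercises it. A minimal sketch of the positional call (placeholder buffers; assumes the compiled `_C` extension and a CUDA device are available):

import torch
import vllm._custom_ops  # noqa: F401  (loads the extension and registers _C_cache_ops)

src_cache = torch.zeros(1, 16, 576, dtype=torch.uint8, device="cuda")  # fp8 bytes
dst = torch.empty(16, 576, dtype=torch.bfloat16, device="cuda")
block_table = torch.zeros(1, 1, dtype=torch.int32, device="cuda")
cu_seq_lens = torch.tensor([0, 16], dtype=torch.int32, device="cuda")
scale = torch.tensor(0.1, dtype=torch.float32, device="cuda")

# Arguments follow the registered schema; the trailing None is the optional seq_starts.
torch.ops._C_cache_ops.gather_and_maybe_dequant_cache(
    src_cache, dst, block_table, cu_seq_lens, 1, "fp8", scale, None)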
@@ -709,14 +709,15 @@ def test_swap_blocks_mla(
 @pytest.mark.parametrize("max_seq_len", [512])
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("dtype", [torch.float32])
-@pytest.mark.parametrize("kv_cache_dtype",
-                         ["auto"])  # You can also test "fp8" if needed.
+@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
-                          num_blocks, max_seq_len, batch_size, dtype,
-                          kv_cache_dtype, device):
+def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
+                                            block_size, num_blocks,
+                                            max_seq_len, batch_size, dtype,
+                                            kv_cache_dtype, device):
     entry_size = kv_lora_rank + qk_rope_head_dim
+    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
     src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                   kv_cache_dtype, device)
     _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
@@ -742,9 +743,7 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
         perm = torch.randperm(num_blocks, device=device)
         block_table[b, :] = perm

-    dst = torch.zeros((total_tokens, entry_size),
-                      dtype=src_cache.dtype,
-                      device=device)
+    dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device)

     expected_batches = []
     for b in range(batch_size):
@@ -756,21 +755,38 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,

         gathered_rows = []
         for i in range(tot - 1):
-            gathered_rows.append(src_cache[blocks[i]])
+            block_data = src_cache[blocks[i]]
+            if kv_cache_dtype == "fp8":
+                dequantized_block = torch.empty_like(block_data, dtype=dtype)
+                ops.convert_fp8(dequantized_block, block_data, scale.item())
+                gathered_rows.append(dequantized_block)
+            else:
+                gathered_rows.append(block_data)
         remaining = s - (tot - 1) * block_size
-        gathered_rows.append(src_cache[blocks[-1], :remaining, :])
+        last_block_data = src_cache[blocks[-1], :remaining, :]
+        if kv_cache_dtype == "fp8":
+            dequantized_last_block = torch.empty_like(last_block_data,
+                                                      dtype=dtype)
+            ops.convert_fp8(dequantized_last_block, last_block_data,
+                            scale.item())
+            gathered_rows.append(dequantized_last_block)
+        else:
+            gathered_rows.append(last_block_data)

         batch_expected = torch.cat(gathered_rows, dim=0)
         expected_batches.append(batch_expected)
     expected = torch.cat(expected_batches, dim=0)

     opcheck(
-        torch.ops._C_cache_ops.gather_cache,
-        (src_cache, dst, block_table, cu_seq_lens, batch_size, None),
+        torch.ops._C_cache_ops.gather_and_maybe_dequant_cache,
+        (src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype,
+         scale, None),
         test_utils=DEFAULT_OPCHECK_TEST_UTILS,
     )

-    ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
+    ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table,
+                                       cu_seq_lens, batch_size, kv_cache_dtype,
+                                       scale, None)
     torch.testing.assert_close(dst, expected)

@@ -13,11 +13,17 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
 from vllm.triton_utils import triton


-def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
+def cal_diff(x: torch.Tensor,
+             y: torch.Tensor,
+             name: str,
+             use_fp8: bool = False) -> None:
     x, y = x.double(), y.double()
     cos_diff = 1 - 2 * (x * y).sum().item() / max(
         (x * x + y * y).sum().item(), 1e-12)
-    assert cos_diff < 1e-5
+    if (use_fp8):
+        assert cos_diff < 1e-4
+    else:
+        assert cos_diff < 1e-5

 FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
     if not is_flashmla_supported()[0] else "FlashMLA is supported"
@@ -27,7 +33,7 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
                     reason=FLASH_MLA_UNSUPPORTED_REASON)
 @pytest.mark.parametrize("b", [128])
 @pytest.mark.parametrize("s_q", [1, 2])
-@pytest.mark.parametrize("mean_sk", [4096, 8192])
+@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384])
 @pytest.mark.parametrize("h_q", [16, 32, 64, 128])
 @pytest.mark.parametrize("h_kv", [1])
 @pytest.mark.parametrize("d", [576])
@@ -35,20 +41,26 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
 @pytest.mark.parametrize("block_size", [64])
 @pytest.mark.parametrize("causal", [True])
 @pytest.mark.parametrize("varlen", [False, True])
-@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("torch_dtype",
+                         [torch.bfloat16, torch.float16, torch.float8_e4m3fn])
 @torch.inference_mode()
 def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
-                   varlen, dtype):
+                   varlen, torch_dtype):
     device = torch.device("cuda:0")
-    torch.set_default_dtype(dtype)
+    if torch_dtype == torch.float8_e4m3fn:
+        init_dtype = torch.bfloat16
+    else:
+        init_dtype = torch_dtype
+    torch.set_default_dtype(init_dtype)
     torch.set_default_device(device)
     torch.cuda.set_device(device)
     torch.manual_seed(0)
     random.seed(0)

     print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
-          f"{d=}, {dv=}, {causal=}, {varlen=}, {dtype=}")
+          f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}")

+    use_fp8 = torch_dtype == torch.float8_e4m3fn
     cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
     if varlen:
         for i in range(b):
@@ -71,6 +83,19 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
     tile_scheduler_metadata, num_splits = get_mla_metadata(
         cache_seqlens, s_q * h_q // h_kv, h_kv)

+    init_dtype = q.dtype
+    if use_fp8:
+        fp8_dtype = torch.float8_e4m3fn
+        descale_q = torch.ones((1), dtype=torch.float32)
+        descale_k = torch.ones((1), dtype=torch.float32)
+
+        q = q.to(fp8_dtype)
+        blocked_k = blocked_k.to(fp8_dtype)
+        blocked_v = blocked_v.to(fp8_dtype)
+    else:
+        descale_q = None
+        descale_k = None
+
     def flash_mla():
         return flash_mla_with_kvcache(
             q,
@@ -81,6 +106,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
             tile_scheduler_metadata,
             num_splits,
             causal=causal,
+            descale_q=descale_q,
+            descale_k=descale_k,
         )

     def scaled_dot_product_attention(query, key, value, is_causal=False):
@@ -104,29 +131,35 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
         return attn_weight @ value, lse

     def ref_mla():
+        q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
+        blocked_k_ = (blocked_k.to(torch.float) *
+                      descale_k).to(init_dtype) if use_fp8 else blocked_k
+        blocked_v_ = (blocked_v.to(torch.float) *
+                      descale_k).to(init_dtype) if use_fp8 else blocked_v
         out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
         lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
         for i in range(b):
             begin = i * max_seqlen_pad
             end = begin + cache_seqlens[i]
-            ref_O, LSE = scaled_dot_product_attention(
-                q[i].transpose(0, 1),
-                blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
-                blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
+            out_i, lse_i = scaled_dot_product_attention(
+                q_[i].transpose(0, 1),
+                blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1),
+                blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
                 is_causal=causal,
             )
-            out[i] = ref_O.transpose(0, 1)
-            lse[i] = LSE
+            out[i] = out_i.transpose(0, 1)
+            lse[i] = lse_i
         return out, lse

     out_flash, lse_flash = flash_mla()
     out_torch, lse_torch = ref_mla()
-    cal_diff(out_flash, out_torch, "out")
+    cal_diff(out_flash, out_torch, "out", use_fp8)
     cal_diff(lse_flash, lse_torch, "lse")

     t = triton.testing.do_bench(flash_mla)
     FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
-    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d +
-             b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
-    print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} "
-          f"TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s")
+    bytes = (total_seqlens * h_kv * d +
+             b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (
+                 b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
+    print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,",
+          f"{bytes / 10 ** 6 / t:.0f} GB/s")
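The FP8 path of the test quantizes the bf16 inputs to `float8_e4m3fn` with unit descale factors and compares against a reference that simply casts back up, which is why `cal_diff` gets the looser 1e-4 bound. A standalone sketch of that round trip (tolerances are illustrative):

import torch

torch.manual_seed(0)
x = torch.randn(4, 576, dtype=torch.bfloat16, device="cuda")
descale = torch.ones(1, dtype=torch.float32, device="cuda")

# Quantize to fp8 storage, then dequantize the way ref_mla() does above.
x_fp8 = x.to(torch.float8_e4m3fn)
x_ref = (x_fp8.to(torch.float) * descale).to(torch.bfloat16)

# fp8 round-tripping is lossy, hence the relaxed tolerance.
assert torch.allclose(x, x_ref, atol=0.125, rtol=0.125)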
@@ -1589,14 +1589,18 @@ def convert_fp8(output: torch.Tensor,
     torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)


-def gather_cache(src_cache: torch.Tensor,
-                 dst: torch.Tensor,
-                 block_table: torch.Tensor,
-                 cu_seq_lens: torch.Tensor,
-                 batch_size: int,
-                 seq_starts: Optional[torch.Tensor] = None) -> None:
-    torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
-                                        cu_seq_lens, batch_size, seq_starts)
+def gather_and_maybe_dequant_cache(
+        src_cache: torch.Tensor,
+        dst: torch.Tensor,
+        block_table: torch.Tensor,
+        cu_seq_lens: torch.Tensor,
+        batch_size: int,
+        kv_cache_dtype: str,
+        scale: torch.Tensor,
+        seq_starts: Optional[torch.Tensor] = None) -> None:
+    torch.ops._C_cache_ops.gather_and_maybe_dequant_cache(
+        src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype,
+        scale, seq_starts)


 def get_device_attribute(attribute: int, device: int) -> int:
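Call sites in the MLA backends (changed later in this diff) pass the new arguments by keyword. A hedged usage example for the unquantized "auto" path, where the op reduces to a plain gather and `scale` is carried along but not applied (placeholder shapes):

import torch
from vllm import _custom_ops as ops

src_cache = torch.randn(4, 16, 576, dtype=torch.bfloat16, device="cuda")
dst = torch.empty(16, 576, dtype=torch.bfloat16, device="cuda")
block_table = torch.tensor([[2]], dtype=torch.int32, device="cuda")
cu_seq_lens = torch.tensor([0, 16], dtype=torch.int32, device="cuda")

ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table, cu_seq_lens,
                                   batch_size=1, kv_cache_dtype="auto",
                                   scale=torch.ones(1, device="cuda"),
                                   seq_starts=None)
torch.testing.assert_close(dst, src_cache[2])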
@@ -837,8 +837,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
                 self.context_chunk_workspace_size // num_prefills_with_context

             # align max_context_chunk to page_size by rounding down,
-            # currently the `gather_cache` kernel cannot handle
-            # `context_chunk_starts` that are not aligned to page_size
+            # currently the `gather_and_maybe_dequant_cache` kernel cannot
+            # handle `context_chunk_starts` that are not aligned to page_size
             max_context_chunk = round_down(max_context_chunk, self.page_size)
             assert max_context_chunk > 0
             num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk)
@@ -1082,6 +1082,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         q: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
+        k_scale: torch.Tensor,
     ):
         prefill_metadata = attn_metadata.prefill_metadata
         assert prefill_metadata is not None
@@ -1103,12 +1104,14 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         for i in range(iters):
             toks = prefill_metadata.context_chunk_seq_tot[i]

-            ops.gather_cache(
+            ops.gather_and_maybe_dequant_cache(
                 src_cache=kv_c_and_k_pe_cache,
                 dst=workspace,
                 block_table=prefill_metadata.block_tables,
                 cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i],
                 batch_size=prefill_metadata.num_prefills,
+                kv_cache_dtype=self.kv_cache_dtype,
+                scale=k_scale,
                 seq_starts=prefill_metadata.context_chunk_starts[i],
             )

@@ -1165,6 +1168,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         k_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
+        k_scale: torch.Tensor,
     ) -> torch.Tensor:

         prefill_metadata = attn_metadata.prefill_metadata
@@ -1197,7 +1201,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             # ROCm flash_attn_varlen_func will return 3 objects instead of 2
             suffix_output, suffix_lse = output
             context_output, context_lse = self._compute_prefill_context( \
-                q, kv_c_and_k_pe_cache, attn_metadata)
+                q, kv_c_and_k_pe_cache, attn_metadata, k_scale)

             output = torch.empty_like(suffix_output)
             merge_attn_states(
@@ -1287,7 +1291,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         if has_prefill:
             output[:num_prefill_tokens] = self._forward_prefill(
                 prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
-                attn_metadata)
+                attn_metadata, layer._k_scale)

         if has_decode:
             decode_q_nope, decode_q_pe = decode_q.split(
@@ -67,6 +67,8 @@ def flash_mla_with_kvcache(
     num_splits: torch.Tensor,
     softmax_scale: Optional[float] = None,
     causal: bool = False,
+    descale_q: Optional[torch.Tensor] = None,
+    descale_k: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Arguments:
@@ -81,6 +83,8 @@ def flash_mla_with_kvcache(
         softmax_scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(head_dim).
         causal: bool. Whether to apply causal attention mask.
+        descale_q: (batch_size), torch.float32. Descaling factors for Q.
+        descale_k: (batch_size), torch.float32. Descaling factors for K.

     Return:
         out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
@@ -98,6 +102,8 @@ def flash_mla_with_kvcache(
         causal,
         tile_scheduler_metadata,
         num_splits,
+        descale_q,
+        descale_k,
     )
     return out, softmax_lse

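A sketch of an FP8 decode call through the extended API, following the updated test (unit descale factors; `d=576`, `dv=512`, `h_kv=1`; shapes are illustrative and running this requires an SM90 build of FlashMLA):

import torch
from vllm.attention.ops.flashmla import flash_mla_with_kvcache, get_mla_metadata

b, s_q, h_q, h_kv, d, dv, block_size = 2, 1, 16, 1, 576, 512, 64
device = torch.device("cuda")
cache_seqlens = torch.full((b, ), 128, dtype=torch.int32, device=device)
q = torch.randn(b, s_q, h_q, d, dtype=torch.bfloat16, device=device)
blocked_k = torch.randn(b * 2, block_size, h_kv, d, dtype=torch.bfloat16, device=device)
block_table = torch.arange(b * 2, dtype=torch.int32, device=device).view(b, 2)

tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens, s_q * h_q // h_kv, h_kv)

# FP8 path: quantize Q and the KV cache and pass descale factors (ones here).
out, lse = flash_mla_with_kvcache(
    q.to(torch.float8_e4m3fn),
    blocked_k.to(torch.float8_e4m3fn),
    block_table,
    cache_seqlens,
    dv,
    tile_scheduler_metadata,
    num_splits,
    causal=True,
    descale_q=torch.ones(1, dtype=torch.float32, device=device),
    descale_k=torch.ones(1, dtype=torch.float32, device=device),
)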
@@ -1445,10 +1445,9 @@ class EngineArgs:
                                    recommend_to_remove=False)
             return False

-        # No Fp8 KV cache so far.
         if self.kv_cache_dtype != "auto":
             supported = current_platform.is_kv_cache_dtype_supported(
-                self.kv_cache_dtype)
+                self.kv_cache_dtype, model_config)
             if not supported:
                 _raise_or_fallback(feature_name="--kv-cache-dtype",
                                    recommend_to_remove=False)
@@ -481,16 +481,41 @@ class CudaPlatformBase(Platform):
         return cuda_device_count_stateless()

     @classmethod
-    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
+                                    model_config: "ModelConfig") -> bool:
         fp8_attention = kv_cache_dtype.startswith("fp8")
-        will_use_fa = (not envs.is_set("VLLM_ATTENTION_BACKEND")
-                       ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
+        attention_backend = envs.VLLM_ATTENTION_BACKEND
+
         supported = False
-        if cls.is_device_capability(100):
-            supported = True
-        elif fp8_attention and will_use_fa:
-            from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
-            supported = flash_attn_supports_fp8()
+        if model_config is not None and model_config.use_mla:
+            # Default to CutlassMLA for blackwell,
+            # FlashMLA otherwise
+            if attention_backend is None:
+                if cls.is_device_capability(100):
+                    attention_backend = "CUTLASS_MLA"
+                else:
+                    attention_backend = "FLASHMLA"
+
+            # Only FlashMLA supports fp8
+            if attention_backend == "FLASHMLA":
+                supported = True
+            else:
+                supported = (not fp8_attention)
+        else:
+            # Default to FlashAttention
+            if attention_backend is None:
+                attention_backend = "FLASH_ATTN_VLLM_V1"
+
+            # All Blackwell backends support fp8
+            if cls.is_device_capability(100):
+                supported = True
+            elif attention_backend == "FLASH_ATTN_VLLM_V1":
+                if fp8_attention:
+                    from vllm.attention.utils.fa_utils import (
+                        flash_attn_supports_fp8)
+                    supported = flash_attn_supports_fp8()
+                else:
+                    supported = True
         return supported

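The check is now model-aware: for MLA models the default backend is assumed (CutlassMLA on Blackwell, FlashMLA elsewhere) and only FlashMLA reports fp8 support; for other models the previous FlashAttention logic applies. A hedged sketch of how a caller such as `EngineArgs` (updated above) consults it:

from vllm.platforms import current_platform

def kv_cache_dtype_ok(kv_cache_dtype: str, model_config) -> bool:
    # "auto" always passes; anything else defers to the platform, which now
    # also looks at the model (MLA or not) and the selected attention backend.
    return (kv_cache_dtype == "auto"
            or current_platform.is_kv_cache_dtype_supported(
                kv_cache_dtype, model_config))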
@@ -565,7 +565,8 @@ class Platform:
         raise RuntimeError(f"Unsupported torch distributed backend: {backend}")

     @classmethod
-    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
+                                    model_config: "ModelConfig") -> bool:
         """
         Returns if the kv_cache_dtype is supported by the current platform.
         """
@@ -459,5 +459,6 @@ class RocmPlatform(Platform):
         return cuda_device_count_stateless()

     @classmethod
-    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
+                                    model_config: "ModelConfig") -> bool:
         return True
@@ -196,7 +196,8 @@ class TpuPlatform(Platform):
             raise ValueError("Torch XLA does not support per-request seed.")

     @classmethod
-    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
+                                    model_config: "ModelConfig") -> bool:
         return True

@@ -631,8 +631,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):

             if self.aot_schedule:
                 # align max_context_chunk to page_size by rounding down,
-                # currently the `gather_cache` kernel cannot handle
-                # `context_chunk_starts` that are not aligned to page_size
+                # currently the `gather_and_maybe_dequant_cache` kernel
+                # cannot handle `context_chunk_starts` that are not aligned
+                # to page_size
                 max_context_chunk = round_down(max_context_chunk,
                                                self.page_size)

@@ -1005,6 +1006,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         q: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
+        k_scale: torch.Tensor,
     ):
         assert attn_metadata.prefill is not None
         prefill_metadata = attn_metadata.prefill
@@ -1017,12 +1019,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]

-            ops.gather_cache(
+            ops.gather_and_maybe_dequant_cache(
                 src_cache=kv_c_and_k_pe_cache,
                 dst=workspace,
                 block_table=prefill_metadata.block_table,
                 cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
                 batch_size=attn_metadata.num_prefills,
+                kv_cache_dtype=self.kv_cache_dtype,
+                scale=k_scale,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
             )

@@ -1073,6 +1077,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         k_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
+        k_scale: torch.Tensor,
     ) -> torch.Tensor:
         assert attn_metadata.prefill is not None

@@ -1095,7 +1100,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         if has_context:
             suffix_output, suffix_lse = output
             context_output, context_lse = self._compute_prefill_context( \
-                q, kv_c_and_k_pe_cache, attn_metadata)
+                q, kv_c_and_k_pe_cache, attn_metadata, k_scale)

             output = torch.empty_like(suffix_output)
             merge_attn_states(
@@ -1119,6 +1124,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         q_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: M,
+        layer: AttentionLayer,
     ) -> torch.Tensor:
         raise NotImplementedError

@@ -1146,6 +1152,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             # same expert outputs.
             return output.fill_(0)

+        fp8_attention = self.kv_cache_dtype.startswith("fp8")
+
         num_actual_toks = attn_metadata.num_actual_tokens

         # Inputs and outputs may be padded for CUDA graphs
@@ -1180,10 +1188,13 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
                 scale=layer._k_scale,
             )

+        if fp8_attention:
+            kv_cache = kv_cache.view(current_platform.fp8_dtype())
+
         if has_prefill:
             output[num_decode_tokens:] = self._forward_prefill(
                 prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
-                attn_metadata)
+                attn_metadata, layer._k_scale)

         if has_decode:
             assert attn_metadata.decode is not None
@@ -1196,7 +1207,21 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             # Convert from (N, B, L) to (B, N, L)
             decode_ql_nope = decode_ql_nope.transpose(0, 1)

+            if fp8_attention:
+                ql_nope_shape = decode_ql_nope.shape
+                decode_ql_nope, _ = ops.scaled_fp8_quant(
+                    decode_ql_nope.reshape([
+                        ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2]
+                    ]), layer._q_scale)
+                decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape)
+                q_pe_shape = decode_q_pe.shape
+                decode_q_pe, _ = ops.scaled_fp8_quant(
+                    decode_q_pe.reshape(
+                        [q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]),
+                    layer._q_scale)
+                decode_q_pe = decode_q_pe.reshape(q_pe_shape)
+
             output[:num_decode_tokens] = self._forward_decode(
-                decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
+                decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer)

         return output_padded
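The decode-side quantization above flattens the head dimensions before calling `ops.scaled_fp8_quant` (which, as used here, takes a 2-D input plus a static scale) and then restores the original shape. A condensed sketch of that pattern (the helper name is ours):

import torch
from vllm import _custom_ops as ops

def quantize_decode_q(x: torch.Tensor, q_scale: torch.Tensor) -> torch.Tensor:
    """Quantize a (tokens, heads, head_dim) query tensor to fp8 with a fixed scale."""
    shape = x.shape
    x_q, _ = ops.scaled_fp8_quant(x.reshape(shape[0], shape[1] * shape[2]), q_scale)
    return x_q.reshape(shape)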
|
@ -7,7 +7,7 @@ from typing import ClassVar, Optional
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm._custom_ops as ops
|
import vllm._custom_ops as ops
|
||||||
from vllm.attention.backends.abstract import (AttentionType,
|
from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
|
||||||
is_quantized_kv_cache)
|
is_quantized_kv_cache)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||||
@ -278,6 +278,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
|||||||
q_pe: torch.Tensor,
|
q_pe: torch.Tensor,
|
||||||
kv_c_and_k_pe_cache: torch.Tensor,
|
kv_c_and_k_pe_cache: torch.Tensor,
|
||||||
attn_metadata: MLACommonMetadata,
|
attn_metadata: MLACommonMetadata,
|
||||||
|
layer: AttentionLayer,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if self._use_old_cutlass_mla:
|
if self._use_old_cutlass_mla:
|
||||||
# TODO: Remove the old cutlass MLA kernel after more extensive
|
# TODO: Remove the old cutlass MLA kernel after more extensive
|
||||||
|
@@ -6,8 +6,7 @@ from typing import ClassVar, Optional

 import torch

-from vllm.attention.backends.abstract import (AttentionType,
-                                              is_quantized_kv_cache)
+from vllm.attention.backends.abstract import AttentionLayer, AttentionType
 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)
@@ -166,16 +165,13 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
                                       "are not implemented for "
                                       "FlashMLAImpl")

-        if is_quantized_kv_cache(self.kv_cache_dtype):
-            raise NotImplementedError(
-                "FlashMLA V1 with FP8 KV cache not yet supported")
-
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
         q_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: FlashMLAMetadata,
+        layer: AttentionLayer,
     ) -> torch.Tensor:
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None
@@ -194,6 +190,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
             num_splits=attn_metadata.decode.num_splits,
             softmax_scale=self.scale,
             causal=True,
+            descale_q=layer._q_scale.reshape(1),
+            descale_k=layer._k_scale.reshape(1),
         )

         return self._v_up_proj(o)
|
@ -7,6 +7,7 @@ from typing import ClassVar, Optional
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
|
from vllm.attention.backends.abstract import AttentionLayer
|
||||||
from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
|
from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.utils import cdiv
|
from vllm.utils import cdiv
|
||||||
@ -221,6 +222,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
|||||||
q_pe: torch.Tensor,
|
q_pe: torch.Tensor,
|
||||||
kv_c_and_k_pe_cache: torch.Tensor,
|
kv_c_and_k_pe_cache: torch.Tensor,
|
||||||
attn_metadata: AiterMLAMetadata,
|
attn_metadata: AiterMLAMetadata,
|
||||||
|
layer: AttentionLayer,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert kv_c_and_k_pe_cache.numel() > 0
|
assert kv_c_and_k_pe_cache.numel() > 0
|
||||||
assert attn_metadata.decode is not None
|
assert attn_metadata.decode is not None
|
||||||
|
@ -6,7 +6,7 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.attention.backends.abstract import (AttentionType,
|
from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
|
||||||
is_quantized_kv_cache)
|
is_quantized_kv_cache)
|
||||||
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
|
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||||
from vllm.attention.ops.triton_flash_attention import triton_attention
|
from vllm.attention.ops.triton_flash_attention import triton_attention
|
||||||
@ -127,6 +127,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
|||||||
q_pe: torch.Tensor,
|
q_pe: torch.Tensor,
|
||||||
kv_c_and_k_pe_cache: torch.Tensor,
|
kv_c_and_k_pe_cache: torch.Tensor,
|
||||||
attn_metadata: MLACommonMetadata,
|
attn_metadata: MLACommonMetadata,
|
||||||
|
layer: AttentionLayer,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert kv_c_and_k_pe_cache.numel() > 0
|
assert kv_c_and_k_pe_cache.numel() > 0
|
||||||
assert attn_metadata.decode is not None
|
assert attn_metadata.decode is not None
|
||||||